@@ -254,24 +254,24 @@ class TensorTest {
254254 assertEquals(tensor.dtype(), DType .FLOAT )
255255
256256 assertThatThrownBy { tensor.dataAsByteArray }
257- .isInstanceOf(IllegalStateException ::class .java)
258- .hasMessage(" Tensor of type Tensor_float32 cannot return data as byte array." )
257+ .isInstanceOf(IllegalStateException ::class .java)
258+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as byte array." )
259259
260260 assertThatThrownBy { tensor.dataAsUnsignedByteArray }
261- .isInstanceOf(IllegalStateException ::class .java)
262- .hasMessage(" Tensor of type Tensor_float32 cannot return data as unsigned byte array." )
261+ .isInstanceOf(IllegalStateException ::class .java)
262+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as unsigned byte array." )
263263
264264 assertThatThrownBy { tensor.dataAsIntArray }
265- .isInstanceOf(IllegalStateException ::class .java)
266- .hasMessage(" Tensor of type Tensor_float32 cannot return data as int array." )
265+ .isInstanceOf(IllegalStateException ::class .java)
266+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as int array." )
267267
268268 assertThatThrownBy { tensor.dataAsDoubleArray }
269- .isInstanceOf(IllegalStateException ::class .java)
270- .hasMessage(" Tensor of type Tensor_float32 cannot return data as double array." )
269+ .isInstanceOf(IllegalStateException ::class .java)
270+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as double array." )
271271
272272 assertThatThrownBy { tensor.dataAsLongArray }
273- .isInstanceOf(IllegalStateException ::class .java)
274- .hasMessage(" Tensor of type Tensor_float32 cannot return data as long array." )
273+ .isInstanceOf(IllegalStateException ::class .java)
274+ .hasMessage(" Tensor of type Tensor_float32 cannot return data as long array." )
275275 }
276276
277277 @Test
@@ -281,20 +281,20 @@ class TensorTest {
281281 val mismatchShape = longArrayOf(1 , 2 )
282282
283283 assertThatThrownBy { Tensor .fromBlob(null as FloatArray? , mismatchShape) }
284- .isInstanceOf(IllegalArgumentException ::class .java)
285- .hasMessage(" Data array must be not null" )
284+ .isInstanceOf(IllegalArgumentException ::class .java)
285+ .hasMessage(" Data array must be not null" )
286286
287287 assertThatThrownBy { Tensor .fromBlob(data, null ) }
288- .isInstanceOf(IllegalArgumentException ::class .java)
289- .hasMessage(" Shape must be not null" )
288+ .isInstanceOf(IllegalArgumentException ::class .java)
289+ .hasMessage(" Shape must be not null" )
290290
291291 assertThatThrownBy { Tensor .fromBlob(data, shapeWithNegativeValues) }
292- .isInstanceOf(IllegalArgumentException ::class .java)
293- .hasMessage(" Shape elements must be non negative" )
292+ .isInstanceOf(IllegalArgumentException ::class .java)
293+ .hasMessage(" Shape elements must be non negative" )
294294
295295 assertThatThrownBy { Tensor .fromBlob(data, mismatchShape) }
296- .isInstanceOf(IllegalArgumentException ::class .java)
297- .hasMessage(" Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]" )
296+ .isInstanceOf(IllegalArgumentException ::class .java)
297+ .hasMessage(" Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]" )
298298 }
299299
300300 @Test
@@ -336,4 +336,32 @@ class TensorTest {
336336 assertEquals(shape[i], deserShape[i])
337337 }
338338 }
339+
340+ @Test
341+ fun testOnes_DTypeIsFloat () {
342+ val shape = longArrayOf(2 , 2 )
343+ val tensor = Tensor .ones(shape, DType .FLOAT )
344+ val data = tensor.dataAsFloatArray
345+ assertEquals(DType .FLOAT , tensor.dtype())
346+ for (i in shape.indices) {
347+ assertEquals(shape[i], tensor.shape[i])
348+ }
349+ for (i in data.indices) {
350+ assertEquals(data[i], 1.0f , 1e- 5 .toFloat())
351+ }
352+ }
353+
354+ @Test
355+ fun testZeros_DTypeIsFloat () {
356+ val shape = longArrayOf(2 , 2 )
357+ val tensor = Tensor .zeros(shape, DType .FLOAT )
358+ val data = tensor.dataAsFloatArray
359+ assertEquals(DType .FLOAT , tensor.dtype())
360+ for (i in shape.indices) {
361+ assertEquals(shape[i], tensor.shape[i])
362+ }
363+ for (i in data.indices) {
364+ assertEquals(data[i], 0.0f , 1e- 5 .toFloat())
365+ }
366+ }
339367}
0 commit comments