Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,82 @@ public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
return new Tensor_float64(data, shape);
}

/**
* Creates a new Tensor instance with given data-type and all elements initialized to one.
*
* @param shape Tensor shape
* @param dtype Tensor data-type
*/
public static Tensor ones(long[] shape, DType dtype) {
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
int numElements = (int) numel(shape);
switch (dtype) {
case UINT8:
byte[] uInt8Data = new byte[numElements];
Arrays.fill(uInt8Data, (byte) 1);
return Tensor.fromBlobUnsigned(uInt8Data, shape);
case INT8:
byte[] int8Data = new byte[numElements];
Arrays.fill(int8Data, (byte) 1);
return Tensor.fromBlob(int8Data, shape);
case INT32:
int[] int32Data = new int[numElements];
Arrays.fill(int32Data, 1);
return Tensor.fromBlob(int32Data, shape);
case FLOAT:
float[] float32Data = new float[numElements];
Arrays.fill(float32Data, 1.0f);
return Tensor.fromBlob(float32Data, shape);
case INT64:
long[] int64Data = new long[numElements];
Arrays.fill(int64Data, 1L);
return Tensor.fromBlob(int64Data, shape);
case DOUBLE:
double[] float64Data = new double[numElements];
Arrays.fill(float64Data, 1.0);
return Tensor.fromBlob(float64Data, shape);
default:
throw new IllegalArgumentException(
String.format("Tensor.ones() cannot be used with DType %s", dtype));
}
}

/**
* Creates a new Tensor instance with given data-type and all elements initialized to zero.
*
* @param shape Tensor shape
* @param dtype Tensor data-type
*/
public static Tensor zeros(long[] shape, DType dtype) {
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
int numElements = (int) numel(shape);
switch (dtype) {
case UINT8:
byte[] uInt8Data = new byte[numElements];
return Tensor.fromBlobUnsigned(uInt8Data, shape);
case INT8:
byte[] int8Data = new byte[numElements];
return Tensor.fromBlob(int8Data, shape);
case INT32:
int[] int32Data = new int[numElements];
return Tensor.fromBlob(int32Data, shape);
case FLOAT:
float[] float32Data = new float[numElements];
return Tensor.fromBlob(float32Data, shape);
case INT64:
long[] int64Data = new long[numElements];
return Tensor.fromBlob(int64Data, shape);
case DOUBLE:
double[] float64Data = new double[numElements];
return Tensor.fromBlob(float64Data, shape);
default:
throw new IllegalArgumentException(
String.format("Tensor.zeros() cannot be used with DType %s", dtype));
}
}

@DoNotStrip private HybridData mHybridData;

private Tensor(long[] shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,24 +254,24 @@ class TensorTest {
assertEquals(tensor.dtype(), DType.FLOAT)

assertThatThrownBy { tensor.dataAsByteArray }
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.")
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.")

assertThatThrownBy { tensor.dataAsUnsignedByteArray }
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.")
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.")

assertThatThrownBy { tensor.dataAsIntArray }
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as int array.")
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as int array.")

assertThatThrownBy { tensor.dataAsDoubleArray }
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as double array.")
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as double array.")

assertThatThrownBy { tensor.dataAsLongArray }
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as long array.")
.isInstanceOf(IllegalStateException::class.java)
.hasMessage("Tensor of type Tensor_float32 cannot return data as long array.")
}

@Test
Expand All @@ -281,20 +281,20 @@ class TensorTest {
val mismatchShape = longArrayOf(1, 2)

assertThatThrownBy { Tensor.fromBlob(null as FloatArray?, mismatchShape) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Data array must be not null")
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Data array must be not null")

assertThatThrownBy { Tensor.fromBlob(data, null) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Shape must be not null")
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Shape must be not null")

assertThatThrownBy { Tensor.fromBlob(data, shapeWithNegativeValues) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Shape elements must be non negative")
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Shape elements must be non negative")

assertThatThrownBy { Tensor.fromBlob(data, mismatchShape) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]")
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]")
}

@Test
Expand Down Expand Up @@ -336,4 +336,32 @@ class TensorTest {
assertEquals(shape[i], deserShape[i])
}
}

@Test
fun testOnes_DTypeIsFloat() {
val shape = longArrayOf(2, 2)
val tensor = Tensor.ones(shape, DType.FLOAT)
val data = tensor.dataAsFloatArray
assertEquals(DType.FLOAT, tensor.dtype())
for (i in shape.indices) {
assertEquals(shape[i], tensor.shape[i])
}
for (i in data.indices) {
assertEquals(data[i], 1.0f, 1e-5.toFloat())
}
}

@Test
fun testZeros_DTypeIsFloat() {
val shape = longArrayOf(2, 2)
val tensor = Tensor.zeros(shape, DType.FLOAT)
val data = tensor.dataAsFloatArray
assertEquals(DType.FLOAT, tensor.dtype())
for (i in shape.indices) {
assertEquals(shape[i], tensor.shape[i])
}
for (i in data.indices) {
assertEquals(data[i], 0.0f, 1e-5.toFloat())
}
}
}
Loading