diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index 40bfe563fac..e8c0a918b13 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -392,6 +392,82 @@ public static Tensor fromBlob(DoubleBuffer data, long[] shape) { return new Tensor_float64(data, shape); } + /** + * Creates a new Tensor instance with given data-type and all elements initialized to one. + * + * @param shape Tensor shape + * @param dtype Tensor data-type + */ + public static Tensor ones(long[] shape, DType dtype) { + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + int numElements = (int) numel(shape); + switch (dtype) { + case UINT8: + byte[] uInt8Data = new byte[numElements]; + Arrays.fill(uInt8Data, (byte) 1); + return Tensor.fromBlobUnsigned(uInt8Data, shape); + case INT8: + byte[] int8Data = new byte[numElements]; + Arrays.fill(int8Data, (byte) 1); + return Tensor.fromBlob(int8Data, shape); + case INT32: + int[] int32Data = new int[numElements]; + Arrays.fill(int32Data, 1); + return Tensor.fromBlob(int32Data, shape); + case FLOAT: + float[] float32Data = new float[numElements]; + Arrays.fill(float32Data, 1.0f); + return Tensor.fromBlob(float32Data, shape); + case INT64: + long[] int64Data = new long[numElements]; + Arrays.fill(int64Data, 1L); + return Tensor.fromBlob(int64Data, shape); + case DOUBLE: + double[] float64Data = new double[numElements]; + Arrays.fill(float64Data, 1.0); + return Tensor.fromBlob(float64Data, shape); + default: + throw new IllegalArgumentException( + String.format("Tensor.ones() cannot be used with DType %s", dtype)); + } + } + + /** + * Creates a new Tensor instance with given data-type and all elements initialized to zero. + * + * @param shape Tensor shape + * @param dtype Tensor data-type + */ + public static Tensor zeros(long[] shape, DType dtype) { + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + int numElements = (int) numel(shape); + switch (dtype) { + case UINT8: + byte[] uInt8Data = new byte[numElements]; + return Tensor.fromBlobUnsigned(uInt8Data, shape); + case INT8: + byte[] int8Data = new byte[numElements]; + return Tensor.fromBlob(int8Data, shape); + case INT32: + int[] int32Data = new int[numElements]; + return Tensor.fromBlob(int32Data, shape); + case FLOAT: + float[] float32Data = new float[numElements]; + return Tensor.fromBlob(float32Data, shape); + case INT64: + long[] int64Data = new long[numElements]; + return Tensor.fromBlob(int64Data, shape); + case DOUBLE: + double[] float64Data = new double[numElements]; + return Tensor.fromBlob(float64Data, shape); + default: + throw new IllegalArgumentException( + String.format("Tensor.zeros() cannot be used with DType %s", dtype)); + } + } + @DoNotStrip private HybridData mHybridData; private Tensor(long[] shape) { diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt index 98a8a97822e..f1947ef8aef 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt @@ -254,24 +254,24 @@ class TensorTest { assertEquals(tensor.dtype(), DType.FLOAT) assertThatThrownBy { tensor.dataAsByteArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") assertThatThrownBy { tensor.dataAsUnsignedByteArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") assertThatThrownBy { tensor.dataAsIntArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") assertThatThrownBy { tensor.dataAsDoubleArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") assertThatThrownBy { tensor.dataAsLongArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") } @Test @@ -281,20 +281,20 @@ class TensorTest { val mismatchShape = longArrayOf(1, 2) assertThatThrownBy { Tensor.fromBlob(null as FloatArray?, mismatchShape) } - .isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Data array must be not null") + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Data array must be not null") assertThatThrownBy { Tensor.fromBlob(data, null) } - .isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Shape must be not null") + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Shape must be not null") assertThatThrownBy { Tensor.fromBlob(data, shapeWithNegativeValues) } - .isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Shape elements must be non negative") + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Shape elements must be non negative") assertThatThrownBy { Tensor.fromBlob(data, mismatchShape) } - .isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") } @Test @@ -336,4 +336,32 @@ class TensorTest { assertEquals(shape[i], deserShape[i]) } } + + @Test + fun testOnes_DTypeIsFloat() { + val shape = longArrayOf(2, 2) + val tensor = Tensor.ones(shape, DType.FLOAT) + val data = tensor.dataAsFloatArray + assertEquals(DType.FLOAT, tensor.dtype()) + for (i in shape.indices) { + assertEquals(shape[i], tensor.shape[i]) + } + for (i in data.indices) { + assertEquals(data[i], 1.0f, 1e-5.toFloat()) + } + } + + @Test + fun testZeros_DTypeIsFloat() { + val shape = longArrayOf(2, 2) + val tensor = Tensor.zeros(shape, DType.FLOAT) + val data = tensor.dataAsFloatArray + assertEquals(DType.FLOAT, tensor.dtype()) + for (i in shape.indices) { + assertEquals(shape[i], tensor.shape[i]) + } + for (i in data.indices) { + assertEquals(data[i], 0.0f, 1e-5.toFloat()) + } + } }