Skip to content

Commit 56ec76c

Browse files
committed
Android: Add Tensor.ones() and Tensor.zeros() factory method to create tensors initialized with ones and zeros resp. (#15125)
1 parent 964515c commit 56ec76c

File tree

2 files changed

+122
-18
lines changed
  • extension/android/executorch_android/src

2 files changed

+122
-18
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,82 @@ public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
392392
return new Tensor_float64(data, shape);
393393
}
394394

395+
/**
396+
* Creates a new Tensor instance with given data-type and all elements initialized to one.
397+
*
398+
* @param shape Tensor shape
399+
* @param dtype Tensor data-type
400+
*/
401+
public static Tensor ones(long[] shape, DType dtype) {
402+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
403+
checkShape(shape);
404+
int numElements = (int) numel(shape);
405+
switch (dtype) {
406+
case UINT8:
407+
byte[] uInt8Data = new byte[numElements];
408+
Arrays.fill(uInt8Data, (byte) 1);
409+
return Tensor.fromBlobUnsigned(uInt8Data, shape);
410+
case INT8:
411+
byte[] int8Data = new byte[numElements];
412+
Arrays.fill(int8Data, (byte) 1);
413+
return Tensor.fromBlob(int8Data, shape);
414+
case INT32:
415+
int[] int32Data = new int[numElements];
416+
Arrays.fill(int32Data, 1);
417+
return Tensor.fromBlob(int32Data, shape);
418+
case FLOAT:
419+
float[] float32Data = new float[numElements];
420+
Arrays.fill(float32Data, 1.0f);
421+
return Tensor.fromBlob(float32Data, shape);
422+
case INT64:
423+
long[] int64Data = new long[numElements];
424+
Arrays.fill(int64Data, 1L);
425+
return Tensor.fromBlob(int64Data, shape);
426+
case DOUBLE:
427+
double[] float64Data = new double[numElements];
428+
Arrays.fill(float64Data, 1.0);
429+
return Tensor.fromBlob(float64Data, shape);
430+
default:
431+
throw new IllegalArgumentException(
432+
String.format("Tensor.ones() cannot be used with DType %s", dtype));
433+
}
434+
}
435+
436+
/**
437+
* Creates a new Tensor instance with given data-type and all elements initialized to zero.
438+
*
439+
* @param shape Tensor shape
440+
* @param dtype Tensor data-type
441+
*/
442+
public static Tensor zeros(long[] shape, DType dtype) {
443+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
444+
checkShape(shape);
445+
int numElements = (int) numel(shape);
446+
switch (dtype) {
447+
case UINT8:
448+
byte[] uInt8Data = new byte[numElements];
449+
return Tensor.fromBlobUnsigned(uInt8Data, shape);
450+
case INT8:
451+
byte[] int8Data = new byte[numElements];
452+
return Tensor.fromBlob(int8Data, shape);
453+
case INT32:
454+
int[] int32Data = new int[numElements];
455+
return Tensor.fromBlob(int32Data, shape);
456+
case FLOAT:
457+
float[] float32Data = new float[numElements];
458+
return Tensor.fromBlob(float32Data, shape);
459+
case INT64:
460+
long[] int64Data = new long[numElements];
461+
return Tensor.fromBlob(int64Data, shape);
462+
case DOUBLE:
463+
double[] float64Data = new double[numElements];
464+
return Tensor.fromBlob(float64Data, shape);
465+
default:
466+
throw new IllegalArgumentException(
467+
String.format("Tensor.zeros() cannot be used with DType %s", dtype));
468+
}
469+
}
470+
395471
@DoNotStrip private HybridData mHybridData;
396472

397473
private Tensor(long[] shape) {

extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -254,24 +254,24 @@ class TensorTest {
254254
assertEquals(tensor.dtype(), DType.FLOAT)
255255

256256
assertThatThrownBy { tensor.dataAsByteArray }
257-
.isInstanceOf(IllegalStateException::class.java)
258-
.hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.")
257+
.isInstanceOf(IllegalStateException::class.java)
258+
.hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.")
259259

260260
assertThatThrownBy { tensor.dataAsUnsignedByteArray }
261-
.isInstanceOf(IllegalStateException::class.java)
262-
.hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.")
261+
.isInstanceOf(IllegalStateException::class.java)
262+
.hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.")
263263

264264
assertThatThrownBy { tensor.dataAsIntArray }
265-
.isInstanceOf(IllegalStateException::class.java)
266-
.hasMessage("Tensor of type Tensor_float32 cannot return data as int array.")
265+
.isInstanceOf(IllegalStateException::class.java)
266+
.hasMessage("Tensor of type Tensor_float32 cannot return data as int array.")
267267

268268
assertThatThrownBy { tensor.dataAsDoubleArray }
269-
.isInstanceOf(IllegalStateException::class.java)
270-
.hasMessage("Tensor of type Tensor_float32 cannot return data as double array.")
269+
.isInstanceOf(IllegalStateException::class.java)
270+
.hasMessage("Tensor of type Tensor_float32 cannot return data as double array.")
271271

272272
assertThatThrownBy { tensor.dataAsLongArray }
273-
.isInstanceOf(IllegalStateException::class.java)
274-
.hasMessage("Tensor of type Tensor_float32 cannot return data as long array.")
273+
.isInstanceOf(IllegalStateException::class.java)
274+
.hasMessage("Tensor of type Tensor_float32 cannot return data as long array.")
275275
}
276276

277277
@Test
@@ -281,20 +281,20 @@ class TensorTest {
281281
val mismatchShape = longArrayOf(1, 2)
282282

283283
assertThatThrownBy { Tensor.fromBlob(null as FloatArray?, mismatchShape) }
284-
.isInstanceOf(IllegalArgumentException::class.java)
285-
.hasMessage("Data array must be not null")
284+
.isInstanceOf(IllegalArgumentException::class.java)
285+
.hasMessage("Data array must be not null")
286286

287287
assertThatThrownBy { Tensor.fromBlob(data, null) }
288-
.isInstanceOf(IllegalArgumentException::class.java)
289-
.hasMessage("Shape must be not null")
288+
.isInstanceOf(IllegalArgumentException::class.java)
289+
.hasMessage("Shape must be not null")
290290

291291
assertThatThrownBy { Tensor.fromBlob(data, shapeWithNegativeValues) }
292-
.isInstanceOf(IllegalArgumentException::class.java)
293-
.hasMessage("Shape elements must be non negative")
292+
.isInstanceOf(IllegalArgumentException::class.java)
293+
.hasMessage("Shape elements must be non negative")
294294

295295
assertThatThrownBy { Tensor.fromBlob(data, mismatchShape) }
296-
.isInstanceOf(IllegalArgumentException::class.java)
297-
.hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]")
296+
.isInstanceOf(IllegalArgumentException::class.java)
297+
.hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]")
298298
}
299299

300300
@Test
@@ -336,4 +336,32 @@ class TensorTest {
336336
assertEquals(shape[i], deserShape[i])
337337
}
338338
}
339+
340+
@Test
341+
fun testOnes_DTypeIsFloat() {
342+
val shape = longArrayOf(2, 2)
343+
val tensor = Tensor.ones(shape, DType.FLOAT)
344+
val data = tensor.dataAsFloatArray
345+
assertEquals(DType.FLOAT, tensor.dtype())
346+
for (i in shape.indices) {
347+
assertEquals(shape[i], tensor.shape[i])
348+
}
349+
for (i in data.indices) {
350+
assertEquals(data[i], 1.0f, 1e-5.toFloat())
351+
}
352+
}
353+
354+
@Test
355+
fun testZeros_DTypeIsFloat() {
356+
val shape = longArrayOf(2, 2)
357+
val tensor = Tensor.zeros(shape, DType.FLOAT)
358+
val data = tensor.dataAsFloatArray
359+
assertEquals(DType.FLOAT, tensor.dtype())
360+
for (i in shape.indices) {
361+
assertEquals(shape[i], tensor.shape[i])
362+
}
363+
for (i in data.indices) {
364+
assertEquals(data[i], 0.0f, 1e-5.toFloat())
365+
}
366+
}
339367
}

0 commit comments

Comments
 (0)