Skip to content

Commit 3eb3231

Browse files
committed
Fixes java format
1 parent d4d6e48 commit 3eb3231

File tree

5 files changed

+28
-21
lines changed

5 files changed

+28
-21
lines changed

api/src/main/java/ai/djl/modality/Classifications.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,12 @@ public Classifications(List<String> classNames, NDArray probabilities) {
8888
*/
8989
public Classifications(List<String> classNames, NDArray probabilities, int topK) {
9090
this.classNames = classNames;
91-
if (probabilities.getDataType().equals(DataType.FLOAT32)) {
91+
if (probabilities.getDataType() == DataType.FLOAT32) {
9292
// Avoid converting float32 to float64 as this is not supported on MPS device
9393
this.probabilities = new ArrayList<>();
94-
for (float prob : probabilities.toFloatArray())
94+
for (float prob : probabilities.toFloatArray()) {
9595
this.probabilities.add((double) prob);
96+
}
9697
} else {
9798
NDArray array = probabilities.toType(DataType.FLOAT64, false);
9899
this.probabilities =

api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,10 +435,11 @@ default NDArray toTensor() {
435435
result = result.expandDims(0);
436436
}
437437
// For Apple Silicon MPS it is important not to switch to 64-bit float here
438-
if (result.getDataType().equals(DataType.FLOAT32))
438+
if (result.getDataType() == DataType.FLOAT32) {
439439
result = result.div(255.0f).transpose(0, 3, 1, 2);
440-
else
440+
} else {
441441
result = result.div(255.0).transpose(0, 3, 1, 2);
442+
}
442443
if (dim == 3) {
443444
result = result.squeeze(0);
444445
}

basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,11 @@ private NDArray readData(Artifact.Item item, long length) throws IOException {
117117
}
118118

119119
byte[] buf = Utils.toByteArray(is);
120-
try (NDArray array = manager.create(ByteBuffer.wrap(buf),
121-
new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), DataType.UINT8)) {
120+
try (NDArray array =
121+
manager.create(
122+
ByteBuffer.wrap(buf),
123+
new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1),
124+
DataType.UINT8)) {
122125
return array.toType(DataType.FLOAT32, false);
123126
}
124127
}
@@ -131,7 +134,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException {
131134
}
132135

133136
byte[] buf = Utils.toByteArray(is);
134-
try (NDArray array = manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
137+
try (NDArray array =
138+
manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
135139
return array.toType(DataType.FLOAT32, false);
136140
}
137141
}

basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException {
112112
}
113113

114114
byte[] buf = Utils.toByteArray(is);
115-
try (NDArray array = manager.create(ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) {
115+
try (NDArray array =
116+
manager.create(
117+
ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) {
116118
return array.toType(DataType.FLOAT32, false);
117119
}
118120
}
@@ -124,7 +126,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException {
124126
throw new AssertionError("Failed skip data.");
125127
}
126128
byte[] buf = Utils.toByteArray(is);
127-
try (NDArray array = manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
129+
try (NDArray array =
130+
manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
128131
return array.toType(DataType.FLOAT32, false);
129132
}
130133
}

engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ public void testMps() {
4242
}
4343

4444
private static boolean checkMpsCompatible() {
45-
return "aarch64".equals(System.getProperty("os.arch")) &&
46-
System.getProperty("os.name").startsWith("Mac");
45+
return "aarch64".equals(System.getProperty("os.arch"))
46+
&& System.getProperty("os.name").startsWith("Mac");
4747
}
4848

4949
@Test
@@ -54,9 +54,10 @@ public void testToTensorMPS() {
5454

5555
// Test that toTensor does not fail on MPS (e.g. due to use of float64 for division)
5656
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
57-
NDArray array = manager.create(127f).reshape(1, 1, 1, 1);;
57+
NDArray array = manager.create(127f).reshape(1, 1, 1, 1);
58+
;
5859
NDArray tensor = array.getNDArrayInternal().toTensor();
59-
Assert.assertEquals(tensor.toFloatArray(), new float[]{127f/255f});
60+
Assert.assertEquals(tensor.toFloatArray(), new float[] {127f / 255f});
6061
}
6162
}
6263

@@ -66,16 +67,13 @@ public void testClassificationsMPS() {
6667
throw new SkipException("MPS classification test requires Apple Silicon macOS.");
6768
}
6869

69-
// Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to float64)
70-
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
70+
// Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to
71+
// float64)
72+
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
7173
List<String> names = Arrays.asList("First", "Second", "Third", "Fourth", "Fifth");
72-
NDArray tensor = manager.create(new float[]{0f, 0.125f, 1f, 0.5f, 0.25f});
73-
Classifications classifications = new Classifications(
74-
names,
75-
tensor
76-
);
74+
NDArray tensor = manager.create(new float[] {0f, 0.125f, 1f, 0.5f, 0.25f});
75+
Classifications classifications = new Classifications(names, tensor);
7776
Assert.assertNotNull(classifications.topK(1).equals(Arrays.asList("Third")));
7877
}
7978
}
80-
8179
}

0 commit comments

Comments
 (0)