3030import org .tensorflow .proto .framework .DataType ;
3131import org .tensorflow .types .TFloat32 ;
3232
33+ // FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
34+ // https://github.com/tensorflow/java/issues/486
3335public class CustomGradientTest {
3436
37+ @ EnabledOnOs (OS .WINDOWS )
38+ @ Test
39+ public void customGradientRegistrationUnsupportedOnWindows () {
40+ assertThrows (
41+ UnsupportedOperationException .class ,
42+ () ->
43+ TensorFlow .registerCustomGradient (
44+ NthElement .OP_NAME ,
45+ (tf , op , gradInputs ) ->
46+ Arrays .asList (tf .withName ("inAGrad" ).constant (0f ), tf .constant (0f ))));
47+
48+ assertThrows (
49+ UnsupportedOperationException .class ,
50+ () ->
51+ TensorFlow .registerCustomGradient (
52+ NthElement .Inputs .class ,
53+ (tf , op , gradInputs ) ->
54+ Arrays .asList (tf .withName ("inAGrad" ).constant (0f ), tf .constant (0f ))));
55+ }
56+
57+ @ DisabledOnOs (OS .WINDOWS )
3558 @ Test
3659 public void testAlreadyExisting () {
3760 assertFalse (
@@ -45,8 +68,6 @@ public void testAlreadyExisting() {
4568 }));
4669 }
4770
48- // FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
49- // https://github.com/tensorflow/java/issues/486
5071 @ DisabledOnOs (OS .WINDOWS )
5172 @ Test
5273 public void testCustomGradient () {
@@ -77,26 +98,6 @@ public void testCustomGradient() {
7798 }
7899 }
79100
80- @ EnabledOnOs (OS .WINDOWS )
81- @ Test
82- public void testCustomGradientThrowsOnWindows () {
83- assertThrows (
84- UnsupportedOperationException .class ,
85- () ->
86- TensorFlow .registerCustomGradient (
87- NthElement .OP_NAME ,
88- (tf , op , gradInputs ) ->
89- Arrays .asList (tf .withName ("inAGrad" ).constant (0f ), tf .constant (0f ))));
90-
91- assertThrows (
92- UnsupportedOperationException .class ,
93- () ->
94- TensorFlow .registerCustomGradient (
95- NthElement .Inputs .class ,
96- (tf , op , gradInputs ) ->
97- Arrays .asList (tf .withName ("inAGrad" ).constant (0f ), tf .constant (0f ))));
98- }
99-
100101 private static Output <?>[] toArray (Output <?>... outputs ) {
101102 return outputs ;
102103 }
0 commit comments