diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java index 1c5532b1d99..9467c90d061 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java @@ -30,8 +30,31 @@ import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TFloat32; +// FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see +// https://github.com/tensorflow/java/issues/486 public class CustomGradientTest { + @EnabledOnOs(OS.WINDOWS) + @Test + public void customGradientRegistrationUnsupportedOnWindows() { + assertThrows( + UnsupportedOperationException.class, + () -> + TensorFlow.registerCustomGradient( + NthElement.OP_NAME, + (tf, op, gradInputs) -> + Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f)))); + + assertThrows( + UnsupportedOperationException.class, + () -> + TensorFlow.registerCustomGradient( + NthElement.Inputs.class, + (tf, op, gradInputs) -> + Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f)))); + } + + @DisabledOnOs(OS.WINDOWS) @Test public void testAlreadyExisting() { assertFalse( @@ -45,8 +68,6 @@ public void testAlreadyExisting() { })); } - // FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see - // https://github.com/tensorflow/java/issues/486 @DisabledOnOs(OS.WINDOWS) @Test public void testCustomGradient() { @@ -77,26 +98,6 @@ public void testCustomGradient() { } } - @EnabledOnOs(OS.WINDOWS) - @Test - public void testCustomGradientThrowsOnWindows() { - assertThrows( - UnsupportedOperationException.class, - () -> - TensorFlow.registerCustomGradient( - NthElement.OP_NAME, - (tf, op, gradInputs) -> - Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f)))); - - assertThrows( - UnsupportedOperationException.class, - () -> - TensorFlow.registerCustomGradient( - NthElement.Inputs.class, - (tf, op, gradInputs) -> - Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f)))); - } - private static Output[] toArray(Output... outputs) { return outputs; }