Skip to content

Commit

Permalink
Disable all custom gradient tests on Windows (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
karllessard authored Feb 18, 2023
1 parent a447b4b commit 60e5473
Showing 1 changed file with 23 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,31 @@
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;

// FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
// https://github.com/tensorflow/java/issues/486
public class CustomGradientTest {

@EnabledOnOs(OS.WINDOWS)
@Test
public void customGradientRegistrationUnsupportedOnWindows() {
assertThrows(
UnsupportedOperationException.class,
() ->
TensorFlow.registerCustomGradient(
NthElement.OP_NAME,
(tf, op, gradInputs) ->
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));

assertThrows(
UnsupportedOperationException.class,
() ->
TensorFlow.registerCustomGradient(
NthElement.Inputs.class,
(tf, op, gradInputs) ->
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
}

@DisabledOnOs(OS.WINDOWS)
@Test
public void testAlreadyExisting() {
assertFalse(
Expand All @@ -45,8 +68,6 @@ public void testAlreadyExisting() {
}));
}

// FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
// https://github.com/tensorflow/java/issues/486
@DisabledOnOs(OS.WINDOWS)
@Test
public void testCustomGradient() {
Expand Down Expand Up @@ -77,26 +98,6 @@ public void testCustomGradient() {
}
}

@EnabledOnOs(OS.WINDOWS)
@Test
public void testCustomGradientThrowsOnWindows() {
assertThrows(
UnsupportedOperationException.class,
() ->
TensorFlow.registerCustomGradient(
NthElement.OP_NAME,
(tf, op, gradInputs) ->
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));

assertThrows(
UnsupportedOperationException.class,
() ->
TensorFlow.registerCustomGradient(
NthElement.Inputs.class,
(tf, op, gradInputs) ->
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
}

private static Output<?>[] toArray(Output<?>... outputs) {
return outputs;
}
Expand Down

0 comments on commit 60e5473

Please sign in to comment.