Skip to content

Commit

Permalink
Throw exception on custom gradient registration on Windows (#487)
Browse files Browse the repository at this point in the history
  • Loading branch information
karllessard authored Feb 17, 2023
1 parent dc25607 commit a447b4b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.PointerPointer;
Expand Down Expand Up @@ -193,6 +194,9 @@ private static synchronized boolean hasGradient(String opType) {
* <p>Note that this only works with graph gradients, and will eventually be deprecated in favor
* of unified gradient support once it is fully supported by tensorflow core.
*
* <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
* href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
*
* @param opType the type of op to register the gradient for. Should usually be an {@code OP_NAME}
* field, i.e. {@link Add#OP_NAME}.
* @param gradient the gradient function to use
Expand All @@ -201,6 +205,10 @@ private static synchronized boolean hasGradient(String opType) {
*/
public static synchronized boolean registerCustomGradient(
String opType, RawCustomGradient gradient) {
if (isWindowsOs()) {
throw new UnsupportedOperationException(
"Custom gradient registration is not supported on Windows systems.");
}
if (hasGradient(opType)) {
return false;
}
Expand All @@ -216,6 +224,9 @@ public static synchronized boolean registerCustomGradient(
* generated op classes or custom op classes with the correct annotations. To operate on the
* {@link org.tensorflow.GraphOperation} directly use {@link RawCustomGradient}.
*
* <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
* href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
*
* @param inputClass the inputs class of op to register the gradient for.
* @param gradient the gradient function to use
* @return {@code true} if the gradient was registered, {@code false} if there was already a
Expand All @@ -225,8 +236,11 @@ public static synchronized boolean registerCustomGradient(
*/
public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(
Class<T> inputClass, CustomGradient<T> gradient) {
if (isWindowsOs()) {
throw new UnsupportedOperationException(
"Custom gradient registration is not supported on Windows systems.");
}
OpInputsMetadata metadata = inputClass.getAnnotation(OpInputsMetadata.class);

if (metadata == null) {
throw new IllegalArgumentException(
"Inputs Class "
Expand All @@ -253,4 +267,8 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
gradientFuncs.add(g);
return true;
}

private static boolean isWindowsOs() {
return System.getProperty("os.name", "").toLowerCase(Locale.ENGLISH).startsWith("win");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Arrays;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnOs;
import org.junit.jupiter.api.condition.EnabledOnOs;
import org.junit.jupiter.api.condition.OS;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Ops;
Expand All @@ -44,8 +45,8 @@ public void testAlreadyExisting() {
}));
}

// FIXME: Since TF 2.10.1, this test is failing on Windows, because the whole JVM crashes when
// calling the JavaCPP generated binding `NameMap.erase`. Disable it until we find a fix.
// FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
// https://github.com/tensorflow/java/issues/486
@DisabledOnOs(OS.WINDOWS)
@Test
public void testCustomGradient() {
Expand Down Expand Up @@ -76,6 +77,26 @@ public void testCustomGradient() {
}
}

@EnabledOnOs(OS.WINDOWS)
@Test
public void testCustomGradientThrowsOnWindows() {
assertThrows(
UnsupportedOperationException.class,
() ->
TensorFlow.registerCustomGradient(
NthElement.OP_NAME,
(tf, op, gradInputs) ->
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));

assertThrows(
UnsupportedOperationException.class,
() ->
TensorFlow.registerCustomGradient(
NthElement.Inputs.class,
(tf, op, gradInputs) ->
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
}

private static Output<?>[] toArray(Output<?>... outputs) {
return outputs;
}
Expand Down

0 comments on commit a447b4b

Please sign in to comment.