Skip to content

Commit 57acdc8

Browse files
authored
Adding casts to the if test so it passes on GPU (#528)
1 parent ff1e816 commit 57acdc8

File tree

1 file changed

+7
-2
lines changed
  • tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core

1 file changed

+7
-2
lines changed

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.tensorflow.Session;
2828
import org.tensorflow.Signature;
2929
import org.tensorflow.op.Ops;
30+
import org.tensorflow.types.TFloat32;
3031
import org.tensorflow.types.TInt32;
3132

3233
public class IfTest {
@@ -37,15 +38,19 @@ private static Operand<TInt32> basicIf(Ops tf, Operand<TInt32> a, Operand<TInt32
3738
(ops) -> {
3839
Operand<TInt32> a1 = ops.placeholder(TInt32.class);
3940
Operand<TInt32> b1 = ops.placeholder(TInt32.class);
40-
return Signature.builder().input("a", a1).input("b", b1).output("y", a1).build();
41+
Operand<TInt32> y = ops.identity(a1);
42+
return Signature.builder().input("a", a1).input("b", b1).output("y", y).build();
4143
});
4244

4345
ConcreteFunction elseBranch =
4446
ConcreteFunction.create(
4547
(ops) -> {
4648
Operand<TInt32> a1 = ops.placeholder(TInt32.class);
4749
Operand<TInt32> b1 = ops.placeholder(TInt32.class);
48-
Operand<TInt32> y = ops.math.neg(b1);
50+
// Casts around the math.neg operator as it's not implemented correctly for int32 in
51+
// GPUs at some point between TF 2.10 and TF 2.15.
52+
Operand<TInt32> y =
53+
ops.dtypes.cast(ops.math.neg(ops.dtypes.cast(b1, TFloat32.class)), TInt32.class);
4954
return Signature.builder().input("a", a1).input("b", b1).output("y", y).build();
5055
});
5156

0 commit comments

Comments
 (0)