diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py index 98b34c4ab..b2f592c9a 100644 --- a/trax/tf_numpy/jax_tests/lax_numpy_test.py +++ b/trax/tf_numpy/jax_tests/lax_numpy_test.py @@ -272,13 +272,13 @@ def minus(a, b): ] JAX_BITWISE_OP_RECORDS = [ - op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes, + op_record("bitwise_and", 2, int_dtypes, all_shapes, jtu.rand_default, []), - op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes, + op_record("bitwise_not", 1, int_dtypes, all_shapes, jtu.rand_default, []), - op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes, + op_record("bitwise_or", 2, int_dtypes, all_shapes, jtu.rand_default, []), - op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes, + op_record("bitwise_xor", 2, int_dtypes, all_shapes, jtu.rand_default, []), ]