diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py index 98b34c4ab..b1eaf0980 100644 --- a/trax/tf_numpy/jax_tests/lax_numpy_test.py +++ b/trax/tf_numpy/jax_tests/lax_numpy_test.py @@ -500,6 +500,8 @@ def f(): *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS))) + @unittest.skipIf(onp.__version__ >= onp.lib.NumpyVersion('2.0.0'), + 'tf numpy is implemented to be numpy 1.x compatible') def testOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes, check_dtypes, tolerance, inexact, check_incomplete_shape): # TODO(b/147769803): Remove this skipping