From e100df79c8f7d3b3be742eeb47300a68176d4e6f Mon Sep 17 00:00:00 2001 From: Trax Team Date: Mon, 9 Sep 2024 16:22:07 -0700 Subject: [PATCH] disable testOp when numpy 2.0 is installed PiperOrigin-RevId: 672712350 --- trax/tf_numpy/jax_tests/lax_numpy_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py index 98b34c4ab..b1eaf0980 100644 --- a/trax/tf_numpy/jax_tests/lax_numpy_test.py +++ b/trax/tf_numpy/jax_tests/lax_numpy_test.py @@ -500,6 +500,8 @@ def f(): *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS))) + @unittest.skipIf(onp.__version__ >= onp.lib.NumpyVersion('2.0.0'), + 'tf numpy is implemented to be numpy 1.x compatible') def testOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes, check_dtypes, tolerance, inexact, check_incomplete_shape): # TODO(b/147769803): Remove this skipping