From 39c9bbbe5fd0bf64f8402e626a389617616e8161 Mon Sep 17 00:00:00 2001 From: Trax Team Date: Wed, 21 Aug 2024 15:23:35 -0700 Subject: [PATCH] skip testLinSpace and testLogSpace if numpy 2.0 is used PiperOrigin-RevId: 666057286 --- trax/tf_numpy/numpy_impl/tests/math_ops_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/trax/tf_numpy/numpy_impl/tests/math_ops_test.py b/trax/tf_numpy/numpy_impl/tests/math_ops_test.py index 85f5db434..5922dbc88 100644 --- a/trax/tf_numpy/numpy_impl/tests/math_ops_test.py +++ b/trax/tf_numpy/numpy_impl/tests/math_ops_test.py @@ -15,6 +15,7 @@ """Tests for tf numpy mathematical methods.""" import itertools +import unittest from absl.testing import parameterized import numpy as np from six.moves import range @@ -262,6 +263,8 @@ def run_test(arr, *args, **kwargs): run_test([[1, 2], [3, 4]], axis=-1) run_test([[1, 2], [3, 4]], axis=-2) + @unittest.skipIf(np.__version__ >= np.lib.NumpyVersion('2.0.0'), + 'tf numpy is implemented to be numpy 1.x compatible') def testLinSpace(self): array_transforms = [ lambda x: x, # Identity, @@ -291,6 +294,8 @@ def run_test(start, stop, **kwargs): run_test(0, -1, num=10) run_test(0, -1, endpoint=False) + @unittest.skipIf(np.__version__ >= np.lib.NumpyVersion('2.0.0'), + 'tf numpy is implemented to be numpy 1.x compatible') def testLogSpace(self): array_transforms = [ lambda x: x, # Identity,