From 2bf10753b332c4f502696c781a4619409fd18b0b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 12 Aug 2024 20:15:01 -0700 Subject: [PATCH] [numpy] Fix test failures under NumPy 2.0. PiperOrigin-RevId: 662343913 --- trax/tf_numpy/jax_tests/test_util.py | 4 ++-- trax/tf_numpy/numpy_impl/dtypes.py | 4 ++-- trax/tf_numpy/numpy_impl/tests/array_ops_test.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/trax/tf_numpy/jax_tests/test_util.py b/trax/tf_numpy/jax_tests/test_util.py index f70f968a7..0110f0db6 100644 --- a/trax/tf_numpy/jax_tests/test_util.py +++ b/trax/tf_numpy/jax_tests/test_util.py @@ -89,8 +89,8 @@ python_scalar_dtypes = { bool: onp.dtype(onp.bool_), int: onp.dtype(onp.int_), - float: onp.dtype(onp.float_), - complex: onp.dtype(onp.complex_), + float: onp.dtype(onp.float64), + complex: onp.dtype(onp.complex128), } diff --git a/trax/tf_numpy/numpy_impl/dtypes.py b/trax/tf_numpy/numpy_impl/dtypes.py index 0f696d792..424242b25 100644 --- a/trax/tf_numpy/numpy_impl/dtypes.py +++ b/trax/tf_numpy/numpy_impl/dtypes.py @@ -31,13 +31,13 @@ from numpy import uint32 from numpy import uint64 from numpy import uint8 -from numpy import float_ from numpy import float16 from numpy import float32 from numpy import float64 -from numpy import complex_ +float_ = float64 from numpy import complex64 from numpy import complex128 +complex_ = complex128 from numpy import inexact diff --git a/trax/tf_numpy/numpy_impl/tests/array_ops_test.py b/trax/tf_numpy/numpy_impl/tests/array_ops_test.py index e8130d29e..3080a63e7 100644 --- a/trax/tf_numpy/numpy_impl/tests/array_ops_test.py +++ b/trax/tf_numpy/numpy_impl/tests/array_ops_test.py @@ -261,7 +261,7 @@ def testArray(self): self.all_arrays, self.all_types, ndmins, [True, False]): self.match( array_ops.array(a, dtype=dtype, ndmin=ndmin, copy=copy), - np.array(a, dtype=dtype, ndmin=ndmin, copy=copy)) + np.array(a, dtype=dtype, ndmin=ndmin)) zeros_list = array_ops.zeros(5)