Skip to content

Commit

Permalink
[numpy] Fix test failures under NumPy 2.0.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662343913
  • Loading branch information
hawkinsp authored and copybara-github committed Aug 13, 2024
1 parent df631ba commit 2bf1075
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions trax/tf_numpy/jax_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@
python_scalar_dtypes = {
bool: onp.dtype(onp.bool_),
int: onp.dtype(onp.int_),
float: onp.dtype(onp.float_),
complex: onp.dtype(onp.complex_),
float: onp.dtype(onp.float64),
complex: onp.dtype(onp.complex128),
}


Expand Down
4 changes: 2 additions & 2 deletions trax/tf_numpy/numpy_impl/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@
from numpy import uint32
from numpy import uint64
from numpy import uint8
from numpy import float_
from numpy import float16
from numpy import float32
from numpy import float64
from numpy import complex_
float_ = float64
from numpy import complex64
from numpy import complex128
complex_ = complex128

from numpy import inexact

Expand Down
2 changes: 1 addition & 1 deletion trax/tf_numpy/numpy_impl/tests/array_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def testArray(self):
self.all_arrays, self.all_types, ndmins, [True, False]):
self.match(
array_ops.array(a, dtype=dtype, ndmin=ndmin, copy=copy),
np.array(a, dtype=dtype, ndmin=ndmin, copy=copy))
np.array(a, dtype=dtype, ndmin=ndmin))

zeros_list = array_ops.zeros(5)

Expand Down

0 comments on commit 2bf1075

Please sign in to comment.