Skip to content

Commit

Permalink
fixing test errors for numpy 1.22
Browse files Browse the repository at this point in the history
  • Loading branch information
ohrechykha committed Aug 12, 2024
1 parent 2f8c081 commit 8f9d3ad
Showing 1 changed file with 40 additions and 3 deletions.
43 changes: 40 additions & 3 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ragged

has_complex_dtype = True
numpy_has_array_api = False

# if np.lib.NumpyVersion(np.__version__) < "2.0.0b1":
# with warnings.catch_warnings():
Expand All @@ -30,11 +31,13 @@
# else:
# xp = np
devices = ["cpu"]

try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

numpy_has_array_api = True
has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes
except ModuleNotFoundError:
import numpy as xp # noqa: ICN001
Expand Down Expand Up @@ -403,8 +406,23 @@ def test_ceil(device, x):
assert xp.ceil(first(x)).dtype == result.dtype


@pytest.mark.skipif(
not numpy_has_array_api, reason=f"testing in numpy version {np.__version__}"
)
@pytest.mark.parametrize("device", devices)
def test_ceil_int_1(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == np.array(first(result)).astype(first(result).dtype)
assert xp.ceil(first(x_int)).dtype == result.dtype


@pytest.mark.skipif(
numpy_has_array_api, reason=f"testing in numpy version {np.__version__}"
)
@pytest.mark.parametrize("device", devices)
def test_ceil_int(device, x_int):
def test_ceil_int_2(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
Expand Down Expand Up @@ -516,12 +534,31 @@ def test_floor(device, x):
assert xp.floor(first(x)).dtype == result.dtype


@pytest.mark.skipif(
not numpy_has_array_api, reason=f"testing in numpy version {np.__version__}"
)
@pytest.mark.parametrize("device", devices)
def test_floor_int(device, x_int):
def test_floor_int_1(device, x_int):
result = ragged.floor(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.floor(first(x_int)) == first(result).astype(_wrapper(first(result).dtype))
assert xp.floor(first(x_int)) == np.asarray(first(result)).astype(
first(result).dtype
)
assert xp.floor(first(x_int)).dtype == result.dtype


@pytest.mark.skipif(
numpy_has_array_api, reason=f"testing in numpy version {np.__version__}"
)
@pytest.mark.parametrize("device", devices)
def test_floor_int_2(device, x_int):
result = ragged.floor(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.floor(first(x_int)) == np.asarray(first(result)).astype(
_wrapper(first(result).dtype)
)
assert xp.floor(first(x_int)).dtype == _wrapper(result.dtype)


Expand Down

0 comments on commit 8f9d3ad

Please sign in to comment.