Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: fixing awkward lower bound failure with numpy v2 (issue #61) #62

Closed
wants to merge 9 commits into from
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"awkward>=2.5.0",
"awkward>=2.6.7",
]

[project.optional-dependencies]
Expand Down
24 changes: 22 additions & 2 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,17 @@ def ceil(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ceil.html
"""

return _box(type(x), np.ceil(*_unbox(x)), dtype=x.dtype)
def _wrapper(t: np.dtype, /) -> np.dtype:
if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif t in [np.int16, np.uint16]:
return np.float32
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return t

return _box(type(x), np.ceil(*_unbox(x)), dtype=_wrapper(x.dtype))


def conj(x: array, /) -> array:
Expand Down Expand Up @@ -586,7 +596,17 @@ def floor(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor.html
"""

return _box(type(x), np.floor(*_unbox(x)), dtype=x.dtype)
def _wrapper(t: np.dtype, /) -> np.dtype:
if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif t in [np.int16, np.uint16]:
return np.float32
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return t

return _box(type(x), np.floor(*_unbox(x)), dtype=_wrapper(x.dtype))


def floor_divide(x1: array, x2: array, /) -> array:
Expand Down
47 changes: 38 additions & 9 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,31 @@

with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

import pytest

import ragged

has_complex_dtype = True

# if np.lib.NumpyVersion(np.__version__) < "2.0.0b1":
# with warnings.catch_warnings():
# warnings.simplefilter("ignore")
# import array_api_strict as xp # type: ignore[import-not-found]

# has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes
# else:
# xp = np
devices = ["cpu"]
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes
except ModuleNotFoundError:
import numpy as xp # noqa: ICN001

try:
import cupy as cp

Expand All @@ -34,6 +52,17 @@ def first(x: ragged.array) -> Any:
return xp.asarray(out.item(), dtype=x.dtype)


def _wrapper(t: np.dtype, /) -> np.dtype:
if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif t in [np.int16, np.uint16]:
return np.float32
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return t


def test_existence():
assert ragged.abs is not None
assert ragged.acos is not None
Expand Down Expand Up @@ -379,12 +408,12 @@ def test_ceil_int(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == first(result)
assert xp.ceil(first(x_int)).dtype == result.dtype
assert xp.ceil(first(x_int)) == first(result).astype(_wrapper(first(result).dtype))
assert xp.ceil(first(x_int)).dtype == _wrapper(result.dtype)


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -492,8 +521,8 @@ def test_floor_int(device, x_int):
result = ragged.floor(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.floor(first(x_int)) == first(result)
assert xp.floor(first(x_int)).dtype == result.dtype
assert xp.floor(first(x_int)) == first(result).astype(_wrapper(first(result).dtype))
assert xp.floor(first(x_int)).dtype == _wrapper(result.dtype)


@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -571,7 +600,7 @@ def test_greater_equal_method(device, x, y):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -838,7 +867,7 @@ def test_pow_inplace_method(device, x, y):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -888,7 +917,7 @@ def test_round(device, x):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down
Loading