Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: port to numpy 2.0 #60

Merged
merged 20 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"awkward>=2.5.0",
"awkward>=2.6.7",
]

[project.optional-dependencies]
Expand Down
20 changes: 20 additions & 0 deletions src/ragged/_helper_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE
from __future__ import annotations

import numpy as np


def regularise_to_float(t: np.dtype, /) -> np.dtype:
# Ensure compatibility with numpy 2.0.0
if np.__version__ >= "2.1":
# Just pass and return the input type if the numpy version is not 2.0.0
return t

if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif t in [np.int16, np.uint16]:
return np.float32
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return t
8 changes: 4 additions & 4 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ def __str__(self) -> str:
if len(self._shape) == 0:
return f"{self._impl}"
elif len(self._shape) == 1:
return f"{ak._prettyprint.valuestr(self._impl, 1, 80)}"
return f"{ak.prettyprint.valuestr(self._impl, 1, 80)}"
else:
prep = ak._prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace(
prep = ak.prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace(
"\n ", "\n "
)
return f"[\n {prep}\n]"
Expand All @@ -259,9 +259,9 @@ def __repr__(self) -> str:
if len(self._shape) == 0:
return f"ragged.array({self._impl})"
elif len(self._shape) == 1:
return f"ragged.array({ak._prettyprint.valuestr(self._impl, 1, 80 - 14)})"
return f"ragged.array({ak.prettyprint.valuestr(self._impl, 1, 80 - 14)})"
else:
prep = ak._prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace(
prep = ak.prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace(
"\n ", "\n "
)
return f"ragged.array([\n {prep}\n])"
Expand Down
5 changes: 3 additions & 2 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np

from ._helper_functions import regularise_to_float
from ._spec_array_object import _box, _unbox, array


Expand Down Expand Up @@ -414,7 +415,7 @@ def ceil(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ceil.html
"""

return _box(type(x), np.ceil(*_unbox(x)), dtype=x.dtype)
return _box(type(x), np.ceil(*_unbox(x)), dtype=regularise_to_float(x.dtype))


def conj(x: array, /) -> array:
Expand Down Expand Up @@ -586,7 +587,7 @@ def floor(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor.html
"""

return _box(type(x), np.floor(*_unbox(x)), dtype=x.dtype)
return _box(type(x), np.floor(*_unbox(x)), dtype=regularise_to_float(x.dtype))


def floor_divide(x1: array, x2: array, /) -> array:
Expand Down
72 changes: 61 additions & 11 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,27 @@

with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

import pytest

import ragged
from ragged._helper_functions import regularise_to_float

has_complex_dtype = True
numpy_has_array_api = False

devices = ["cpu"]

try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

numpy_has_array_api = True
has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes
except ModuleNotFoundError:
import numpy as xp # noqa: ICN001

try:
import cupy as cp

Expand Down Expand Up @@ -374,17 +388,34 @@ def test_ceil(device, x):
assert xp.ceil(first(x)).dtype == result.dtype


@pytest.mark.skipif(
not numpy_has_array_api,
reason=f"testing only in numpy version 1, but got numpy version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
def test_ceil_int(device, x_int):
def test_ceil_int_1(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == first(result)
assert xp.ceil(first(x_int)).dtype == result.dtype


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
numpy_has_array_api,
reason=f"testing only in numpy version 2, but got numpy version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
def test_ceil_int_2(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == first(result).astype(
regularise_to_float(first(result).dtype)
)
assert xp.ceil(first(x_int)).dtype == regularise_to_float(result.dtype)


@pytest.mark.skipif(
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -487,13 +518,32 @@ def test_floor(device, x):
assert xp.floor(first(x)).dtype == result.dtype


@pytest.mark.skipif(
not numpy_has_array_api,
reason=f"testing only in numpy version 1, but got numpy version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
def test_floor_int(device, x_int):
def test_floor_int_1(device, x_int):
result = ragged.floor(
x_int.to_device(device)
) # always returns float64 regardless of x_int.dtype
assert type(result) is type(x_int)
assert result.shape == x_int.shape


@pytest.mark.skipif(
numpy_has_array_api,
reason=f"testing only in numpy version 2, but got numpy version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
def test_floor_int_2(device, x_int):
result = ragged.floor(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.floor(first(x_int)) == first(result)
assert xp.floor(first(x_int)).dtype == result.dtype
assert xp.floor(first(x_int)) == np.asarray(first(result)).astype(
regularise_to_float(first(result).dtype)
)
assert xp.floor(first(x_int)).dtype == regularise_to_float(result.dtype)


@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -571,7 +621,7 @@ def test_greater_equal_method(device, x, y):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -838,7 +888,7 @@ def test_pow_inplace_method(device, x, y):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -888,7 +938,7 @@ def test_round(device, x):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down
Loading