Skip to content

Commit

Permalink
chore: port to numpy 2.0 (#60)
Browse files Browse the repository at this point in the history
* move to local branch

* fix array_api import at test_spec_elementwise_functions.py

* adding _wrapper in ceil function at _spec_elementwise_functions.py

* implementing a wrapper to fix test errors

* fixing import, trying ceil & floor w\o wrapper

* pre-commit fixes

* reverting wrapper removal in floor and ceil

* fixing lower bound in the same PR

* fixing test errors for numpy 1.22

* skip wrapping of dtype in numpy 1

* adding helper file & floor, ceil tests for numpy 1 changes

* pre-commit changes

* fix: cleanup

* fix: prettyprint

* fix: pin numpy to a stable version for now

* revert changes

* fix: fix to numpy 2.0 stable version

* fix: ensure compatibility with numpy 2.0.0

* fix: pre-commit

* fix: numpy version

---------

Co-authored-by: ohrechykha <[email protected]>
  • Loading branch information
ianna and ohrechykha authored Aug 21, 2024
1 parent fff5354 commit a83deb4
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"awkward>=2.5.0",
"awkward>=2.6.7",
]

[project.optional-dependencies]
Expand Down
20 changes: 20 additions & 0 deletions src/ragged/_helper_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE
from __future__ import annotations

import numpy as np


def regularise_to_float(t: np.dtype, /) -> np.dtype:
# Ensure compatibility with numpy 2.0.0
if np.__version__ >= "2.1":
# Just pass and return the input type if the numpy version is not 2.0.0
return t

if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif t in [np.int16, np.uint16]:
return np.float32
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return t
8 changes: 4 additions & 4 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ def __str__(self) -> str:
if len(self._shape) == 0:
return f"{self._impl}"
elif len(self._shape) == 1:
return f"{ak._prettyprint.valuestr(self._impl, 1, 80)}"
return f"{ak.prettyprint.valuestr(self._impl, 1, 80)}"
else:
prep = ak._prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace(
prep = ak.prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace(
"\n ", "\n "
)
return f"[\n {prep}\n]"
Expand All @@ -259,9 +259,9 @@ def __repr__(self) -> str:
if len(self._shape) == 0:
return f"ragged.array({self._impl})"
elif len(self._shape) == 1:
return f"ragged.array({ak._prettyprint.valuestr(self._impl, 1, 80 - 14)})"
return f"ragged.array({ak.prettyprint.valuestr(self._impl, 1, 80 - 14)})"
else:
prep = ak._prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace(
prep = ak.prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace(
"\n ", "\n "
)
return f"ragged.array([\n {prep}\n])"
Expand Down
5 changes: 3 additions & 2 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np

from ._helper_functions import regularise_to_float
from ._spec_array_object import _box, _unbox, array


Expand Down Expand Up @@ -414,7 +415,7 @@ def ceil(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ceil.html
"""

return _box(type(x), np.ceil(*_unbox(x)), dtype=x.dtype)
return _box(type(x), np.ceil(*_unbox(x)), dtype=regularise_to_float(x.dtype))


def conj(x: array, /) -> array:
Expand Down Expand Up @@ -586,7 +587,7 @@ def floor(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor.html
"""

return _box(type(x), np.floor(*_unbox(x)), dtype=x.dtype)
return _box(type(x), np.floor(*_unbox(x)), dtype=regularise_to_float(x.dtype))


def floor_divide(x1: array, x2: array, /) -> array:
Expand Down
72 changes: 61 additions & 11 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,27 @@

with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

import pytest

import ragged
from ragged._helper_functions import regularise_to_float

has_complex_dtype = True
numpy_has_array_api = False

devices = ["cpu"]

try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

numpy_has_array_api = True
has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes
except ModuleNotFoundError:
import numpy as xp # noqa: ICN001

try:
import cupy as cp

Expand Down Expand Up @@ -374,17 +388,34 @@ def test_ceil(device, x):
assert xp.ceil(first(x)).dtype == result.dtype


@pytest.mark.skipif(
not numpy_has_array_api,
reason=f"testing only in numpy version 1, but got numpy version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
def test_ceil_int(device, x_int):
def test_ceil_int_1(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == first(result)
assert xp.ceil(first(x_int)).dtype == result.dtype


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
numpy_has_array_api,
reason=f"testing only in numpy version 2, but got numpy version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
def test_ceil_int_2(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == first(result).astype(
regularise_to_float(first(result).dtype)
)
assert xp.ceil(first(x_int)).dtype == regularise_to_float(result.dtype)


@pytest.mark.skipif(
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -487,13 +518,32 @@ def test_floor(device, x):
assert xp.floor(first(x)).dtype == result.dtype


@pytest.mark.skipif(
not numpy_has_array_api,
reason=f"testing only in numpy version 1, but got numpy version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
def test_floor_int(device, x_int):
def test_floor_int_1(device, x_int):
result = ragged.floor(
x_int.to_device(device)
) # always returns float64 regardless of x_int.dtype
assert type(result) is type(x_int)
assert result.shape == x_int.shape


@pytest.mark.skipif(
numpy_has_array_api,
reason=f"testing only in numpy version 2, but got numpy version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
def test_floor_int_2(device, x_int):
result = ragged.floor(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.floor(first(x_int)) == first(result)
assert xp.floor(first(x_int)).dtype == result.dtype
assert xp.floor(first(x_int)) == np.asarray(first(result)).astype(
regularise_to_float(first(result).dtype)
)
assert xp.floor(first(x_int)).dtype == regularise_to_float(result.dtype)


@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -571,7 +621,7 @@ def test_greater_equal_method(device, x, y):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -838,7 +888,7 @@ def test_pow_inplace_method(device, x, y):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -888,7 +938,7 @@ def test_round(device, x):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down

0 comments on commit a83deb4

Please sign in to comment.