From a83deb4bf4069967cadc51f3c6ea678a81e68fcb Mon Sep 17 00:00:00 2001 From: Ianna Osborne Date: Wed, 21 Aug 2024 16:38:11 +0200 Subject: [PATCH] chore: port to numpy 2.0 (#60) * move to local branch * fix array_api import at test_spec_elementwise_functions.py * adding _wrapper in ceil function at _spec_elementwise_functions.py * implementing a wrapper to fix test errors * fixing import, trying ceil & floor w\o wrapper * pre-commit fixes * reverting wrapper removal in floor and ceil * fixing lower bound in the same PR * fixing test errors for numpy 1.22 * skip wrapping of dtype in numpy 1 * adding helper file & floor, ceil tests for numpy 1 changes * pre-commit changes * fix: cleanup * fix: prettyprint * fix: pin numpy to a stable version for now * revert changes * fix: fix to numpy 2.0 stable version * fix: ensure compatibility with numpy 2.0.0 * fix: pre-commit * fix: numpy version --------- Co-authored-by: ohrechykha --- pyproject.toml | 2 +- src/ragged/_helper_functions.py | 20 +++++++ src/ragged/_spec_array_object.py | 8 +-- src/ragged/_spec_elementwise_functions.py | 5 +- tests/test_spec_elementwise_functions.py | 72 +++++++++++++++++++---- 5 files changed, 89 insertions(+), 18 deletions(-) create mode 100644 src/ragged/_helper_functions.py diff --git a/pyproject.toml b/pyproject.toml index 3b7a8e8..b566795 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "awkward>=2.5.0", + "awkward>=2.6.7", ] [project.optional-dependencies] diff --git a/src/ragged/_helper_functions.py b/src/ragged/_helper_functions.py new file mode 100644 index 0000000..7688543 --- /dev/null +++ b/src/ragged/_helper_functions.py @@ -0,0 +1,20 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE +from __future__ import annotations + +import numpy as np + + +def regularise_to_float(t: np.dtype, /) -> np.dtype: + # Ensure compatibility with numpy 2.0.0 + if np.__version__ >= "2.1": + # Just pass and return the input type if the numpy version is not 2.0.0 + return t + + if t in [np.int8, np.uint8, np.bool_, bool]: + return np.float16 + elif t in [np.int16, np.uint16]: + return np.float32 + elif t in [np.int32, np.uint32, np.int64, np.uint64]: + return np.float64 + else: + return t diff --git a/src/ragged/_spec_array_object.py b/src/ragged/_spec_array_object.py index e83fe2e..3e27993 100644 --- a/src/ragged/_spec_array_object.py +++ b/src/ragged/_spec_array_object.py @@ -244,9 +244,9 @@ def __str__(self) -> str: if len(self._shape) == 0: return f"{self._impl}" elif len(self._shape) == 1: - return f"{ak._prettyprint.valuestr(self._impl, 1, 80)}" + return f"{ak.prettyprint.valuestr(self._impl, 1, 80)}" else: - prep = ak._prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace( + prep = ak.prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace( "\n ", "\n " ) return f"[\n {prep}\n]" @@ -259,9 +259,9 @@ def __repr__(self) -> str: if len(self._shape) == 0: return f"ragged.array({self._impl})" elif len(self._shape) == 1: - return f"ragged.array({ak._prettyprint.valuestr(self._impl, 1, 80 - 14)})" + return f"ragged.array({ak.prettyprint.valuestr(self._impl, 1, 80 - 14)})" else: - prep = ak._prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace( + prep = ak.prettyprint.valuestr(self._impl, 20, 80 - 4)[1:-1].replace( "\n ", "\n " ) return f"ragged.array([\n {prep}\n])" diff --git a/src/ragged/_spec_elementwise_functions.py b/src/ragged/_spec_elementwise_functions.py index 3357c6c..46b8f2a 100644 --- a/src/ragged/_spec_elementwise_functions.py +++ b/src/ragged/_spec_elementwise_functions.py @@ -10,6 +10,7 @@ import numpy as np +from ._helper_functions import regularise_to_float from ._spec_array_object import _box, _unbox, array @@ -414,7 +415,7 @@ def ceil(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.ceil.html """ - return _box(type(x), np.ceil(*_unbox(x)), dtype=x.dtype) + return _box(type(x), np.ceil(*_unbox(x)), dtype=regularise_to_float(x.dtype)) def conj(x: array, /) -> array: @@ -586,7 +587,7 @@ def floor(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor.html """ - return _box(type(x), np.floor(*_unbox(x)), dtype=x.dtype) + return _box(type(x), np.floor(*_unbox(x)), dtype=regularise_to_float(x.dtype)) def floor_divide(x1: array, x2: array, /) -> array: diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index fdac4a9..c8e03a4 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -14,13 +14,27 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - import numpy.array_api as xp import pytest import ragged +from ragged._helper_functions import regularise_to_float + +has_complex_dtype = True +numpy_has_array_api = False devices = ["cpu"] + +try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import numpy.array_api as xp + + numpy_has_array_api = True + has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes +except ModuleNotFoundError: + import numpy as xp # noqa: ICN001 + try: import cupy as cp @@ -374,17 +388,34 @@ def test_ceil(device, x): assert xp.ceil(first(x)).dtype == result.dtype +@pytest.mark.skipif( + not numpy_has_array_api, + reason=f"testing only in numpy version 1, but got numpy version {np.__version__}", +) @pytest.mark.parametrize("device", devices) -def test_ceil_int(device, x_int): +def test_ceil_int_1(device, x_int): result = ragged.ceil(x_int.to_device(device)) assert type(result) is type(x_int) assert result.shape == x_int.shape - assert xp.ceil(first(x_int)) == first(result) - assert xp.ceil(first(x_int)).dtype == result.dtype @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + numpy_has_array_api, + reason=f"testing only in numpy version 2, but got numpy version {np.__version__}", +) +@pytest.mark.parametrize("device", devices) +def test_ceil_int_2(device, x_int): + result = ragged.ceil(x_int.to_device(device)) + assert type(result) is type(x_int) + assert result.shape == x_int.shape + assert xp.ceil(first(x_int)) == first(result).astype( + regularise_to_float(first(result).dtype) + ) + assert xp.ceil(first(x_int)).dtype == regularise_to_float(result.dtype) + + +@pytest.mark.skipif( + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices) @@ -487,13 +518,32 @@ def test_floor(device, x): assert xp.floor(first(x)).dtype == result.dtype +@pytest.mark.skipif( + not numpy_has_array_api, + reason=f"testing only in numpy version 1, but got numpy version {np.__version__}", +) @pytest.mark.parametrize("device", devices) -def test_floor_int(device, x_int): +def test_floor_int_1(device, x_int): + result = ragged.floor( + x_int.to_device(device) + ) # always returns float64 regardless of x_int.dtype + assert type(result) is type(x_int) + assert result.shape == x_int.shape + + +@pytest.mark.skipif( + numpy_has_array_api, + reason=f"testing only in numpy version 2, but got numpy version {np.__version__}", +) +@pytest.mark.parametrize("device", devices) +def test_floor_int_2(device, x_int): result = ragged.floor(x_int.to_device(device)) assert type(result) is type(x_int) assert result.shape == x_int.shape - assert xp.floor(first(x_int)) == first(result) - assert xp.floor(first(x_int)).dtype == result.dtype + assert xp.floor(first(x_int)) == np.asarray(first(result)).astype( + regularise_to_float(first(result).dtype) + ) + assert xp.floor(first(x_int)).dtype == regularise_to_float(result.dtype) @pytest.mark.parametrize("device", devices) @@ -571,7 +621,7 @@ def test_greater_equal_method(device, x, y): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices) @@ -838,7 +888,7 @@ def test_pow_inplace_method(device, x, y): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices) @@ -888,7 +938,7 @@ def test_round(device, x): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices)