Skip to content

Commit

Permalink
adding helper file & floor, ceil tests for numpy 1 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ohrechykha committed Aug 19, 2024
1 parent ce6af43 commit a567704
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
3 changes: 3 additions & 0 deletions src/ragged/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

from ._helper_functions import regularise_to_float
from ._spec_array_object import array
from ._spec_constants import (
e,
Expand Down Expand Up @@ -292,4 +293,6 @@
# _spec_utility_functions
"all",
"any",
# _helper_functions
"regularise_to_float",
]
13 changes: 13 additions & 0 deletions src/ragged/_helper_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE

import numpy as np

def regularise_to_float(t: np.dtype, /) -> np.dtype:
if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif t in [np.int16, np.uint16]:
return np.float32
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return t
33 changes: 10 additions & 23 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
warnings.simplefilter("ignore")

import pytest

from ragged._helper_functions import regularise_to_float
import ragged

has_complex_dtype = True
Expand Down Expand Up @@ -47,17 +47,6 @@ def first(x: ragged.array) -> Any:
return xp.asarray(out.item(), dtype=x.dtype)


def _wrapper(t: np.dtype, /) -> np.dtype:
if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif t in [np.int16, np.uint16]:
return np.float32
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return t


def test_existence():
assert ragged.abs is not None
assert ragged.acos is not None
Expand Down Expand Up @@ -407,8 +396,6 @@ def test_ceil_int_1(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == np.array(first(result)).astype(first(result).dtype)
assert xp.ceil(first(x_int)).dtype == result.dtype


@pytest.mark.skipif(
Expand All @@ -420,8 +407,10 @@ def test_ceil_int_2(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == first(result).astype(_wrapper(first(result).dtype))
assert xp.ceil(first(x_int)).dtype == _wrapper(result.dtype)
assert xp.ceil(first(x_int)) == first(result).astype(
ragged.regularise_to_float(first(result).dtype)
)
assert xp.ceil(first(x_int)).dtype == ragged.regularise_to_float(result.dtype)


@pytest.mark.skipif(
Expand Down Expand Up @@ -534,13 +523,11 @@ def test_floor(device, x):
)
@pytest.mark.parametrize("device", devices)
def test_floor_int_1(device, x_int):
result = ragged.floor(x_int.to_device(device))
result = ragged.floor(
x_int.to_device(device)
) # always returns float64 regardless of x_int.dtype
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.floor(first(x_int)) == np.asarray(first(result)).astype(
first(result).dtype
)
assert xp.floor(first(x_int)).dtype == result.dtype


@pytest.mark.skipif(
Expand All @@ -553,9 +540,9 @@ def test_floor_int_2(device, x_int):
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.floor(first(x_int)) == np.asarray(first(result)).astype(
_wrapper(first(result).dtype)
ragged.regularise_to_float(first(result).dtype)
)
assert xp.floor(first(x_int)).dtype == _wrapper(result.dtype)
assert xp.floor(first(x_int)).dtype == ragged.regularise_to_float(result.dtype)


@pytest.mark.parametrize("device", devices)
Expand Down

0 comments on commit a567704

Please sign in to comment.