Skip to content

Commit

Permalink
fixing import, trying ceil & floor w\o wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
ohrechykha committed Aug 8, 2024
1 parent 32eae19 commit d9ebe1b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
30 changes: 20 additions & 10 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,17 +414,17 @@ def ceil(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ceil.html
"""

def _wrapper(t: np.dtype, /) -> np.dtype:
if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif t in [np.int16, np.uint16]:
return np.float32
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return t
# def _wrapper(t: np.dtype, /) -> np.dtype:
# if t in [np.int8, np.uint8, np.bool_, bool]:
# return np.float16
# elif t in [np.int16, np.uint16]:
# return np.float32
# elif t in [np.int32, np.uint32, np.int64, np.uint64]:
# return np.float64
# else:
# return t

return _box(type(x), np.ceil(*_unbox(x)), dtype=_wrapper(x.dtype))
return _box(type(x), np.ceil(*_unbox(x)), dtype=x.dtype)


def conj(x: array, /) -> array:
Expand Down Expand Up @@ -596,6 +596,16 @@ def floor(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor.html
"""

# def _wrapper(t: np.dtype, /) -> np.dtype:
# if t in [np.int8, np.uint8, np.bool_, bool]:
# return np.float16
# elif t in [np.int16, np.uint16]:
# return np.float32
# elif t in [np.int32, np.uint32, np.int64, np.uint64]:
# return np.float64
# else:
# return t

return _box(type(x), np.floor(*_unbox(x)), dtype=x.dtype)


Expand Down
8 changes: 5 additions & 3 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@
# xp = np
devices = ["cpu"]
try:
import numpy.array_api as xp
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes
has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes
except ModuleNotFoundError:
import numpy as xp # noqa: ICN001

try:
import cupy as cp

Expand All @@ -48,7 +51,6 @@ def first(x: ragged.array) -> Any:
out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl
return xp.asarray(out.item(), dtype=x.dtype)


def _wrapper(t: np.dtype, /) -> np.dtype:
if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
Expand Down

0 comments on commit d9ebe1b

Please sign in to comment.