From 4f1670c0b7b5af30a14ea191bf0c97d4543cd958 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Sat, 30 Dec 2023 20:10:09 +0530 Subject: [PATCH] feat: elementwise functions (mappers) (#5) * feat: elementwise functions (mappers) * abs * acos, acosh * Give up on versioned submodules; it makes typing difficult. * Give up on versioned submodules; it makes typing difficult. * Give up on versioned submodules; it makes typing difficult. * Also checking function values. * atan, atan2, atanh, bitwise_and, bitwise_invert * bitwise_shift_left, bitwise_or, bitwise_shift_right, bitwise_xor, ceil * Use 'numpy.array_api' to test these functions. * conj, cos, cosh, divide, equal * exp, expm1, floor, floor_divide * greater, greater_equal, imag * isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logaddexp, logical_and, logical_not, logical_or, logical_xor * multiply, negative, not_equal, positive, pow, real, remainder, round * sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc; finished all of the free elementwise functions * Implemented all elementwise dunder methods on 'array', too. * pylint in CI is tripping up on https://github.com/pylint-dev/pylint/issues/850. * Also test all of the in-place dunder methods. --- src/ragged/__init__.py | 288 +++++- src/ragged/{common => }/_import.py | 0 src/ragged/{common => }/_spec_array_object.py | 367 ++++++- src/ragged/{common => }/_spec_constants.py | 0 .../{common => }/_spec_creation_functions.py | 30 +- .../{common => }/_spec_data_type_functions.py | 12 +- .../_spec_elementwise_functions.py | 237 ++--- .../{common => }/_spec_indexing_functions.py | 2 +- .../_spec_linear_algebra_functions.py | 8 +- .../_spec_manipulation_functions.py | 20 +- .../{common => }/_spec_searching_functions.py | 8 +- .../{common => }/_spec_set_functions.py | 8 +- .../{common => }/_spec_sorting_functions.py | 4 +- .../_spec_statistical_functions.py | 14 +- .../{common => }/_spec_utility_functions.py | 4 +- src/ragged/{common => }/_typing.py | 25 +- src/ragged/common/__init__.py | 293 ------ src/ragged/v202212/__init__.py | 298 ------ src/ragged/v202212/_spec_array_object.py | 18 - src/ragged/v202212/_spec_constants.py | 11 - .../v202212/_spec_creation_functions.py | 45 - .../v202212/_spec_data_type_functions.py | 25 - .../v202212/_spec_elementwise_functions.py | 131 --- .../v202212/_spec_indexing_functions.py | 11 - .../v202212/_spec_linear_algebra_functions.py | 16 - .../v202212/_spec_manipulation_functions.py | 33 - .../v202212/_spec_searching_functions.py | 16 - src/ragged/v202212/_spec_set_functions.py | 16 - src/ragged/v202212/_spec_sorting_functions.py | 11 - .../v202212/_spec_statistical_functions.py | 19 - src/ragged/v202212/_spec_utility_functions.py | 11 - tests/conftest.py | 22 + tests/test_spec_array_object.py | 37 + tests/test_spec_elementwise_functions.py | 970 ++++++++++++++++++ tests/test_spec_version.py | 2 - 35 files changed, 1804 insertions(+), 1208 deletions(-) rename src/ragged/{common => }/_import.py (100%) rename src/ragged/{common => }/_spec_array_object.py (68%) rename src/ragged/{common => }/_spec_constants.py (100%) rename src/ragged/{common => }/_spec_creation_functions.py (97%) rename src/ragged/{common => }/_spec_data_type_functions.py (97%) rename src/ragged/{common => }/_spec_elementwise_functions.py (89%) rename src/ragged/{common => }/_spec_indexing_functions.py (98%) rename src/ragged/{common => }/_spec_linear_algebra_functions.py (98%) rename src/ragged/{common => }/_spec_manipulation_functions.py (97%) rename src/ragged/{common => }/_spec_searching_functions.py (97%) rename src/ragged/{common => }/_spec_set_functions.py (97%) rename src/ragged/{common => }/_spec_sorting_functions.py (97%) rename src/ragged/{common => }/_spec_statistical_functions.py (98%) rename src/ragged/{common => }/_spec_utility_functions.py (98%) rename src/ragged/{common => }/_typing.py (75%) delete mode 100644 src/ragged/common/__init__.py delete mode 100644 src/ragged/v202212/__init__.py delete mode 100644 src/ragged/v202212/_spec_array_object.py delete mode 100644 src/ragged/v202212/_spec_constants.py delete mode 100644 src/ragged/v202212/_spec_creation_functions.py delete mode 100644 src/ragged/v202212/_spec_data_type_functions.py delete mode 100644 src/ragged/v202212/_spec_elementwise_functions.py delete mode 100644 src/ragged/v202212/_spec_indexing_functions.py delete mode 100644 src/ragged/v202212/_spec_linear_algebra_functions.py delete mode 100644 src/ragged/v202212/_spec_manipulation_functions.py delete mode 100644 src/ragged/v202212/_spec_searching_functions.py delete mode 100644 src/ragged/v202212/_spec_set_functions.py delete mode 100644 src/ragged/v202212/_spec_sorting_functions.py delete mode 100644 src/ragged/v202212/_spec_statistical_functions.py delete mode 100644 src/ragged/v202212/_spec_utility_functions.py create mode 100644 tests/conftest.py diff --git a/src/ragged/__init__.py b/src/ragged/__init__.py index 54b6b77..8377e92 100644 --- a/src/ragged/__init__.py +++ b/src/ragged/__init__.py @@ -4,10 +4,292 @@ Ragged array module. FIXME: needs more documentation! - -Version 2022.12 is current, so `ragged.v202212.*` is identical to `ragged.*`. """ from __future__ import annotations -from .v202212 import * # noqa: F403 # pylint: disable=W0622 +from ._spec_array_object import array +from ._spec_constants import ( + e, + inf, + nan, + newaxis, + pi, +) +from ._spec_creation_functions import ( + arange, + asarray, + empty, + empty_like, + eye, + from_dlpack, + full, + full_like, + linspace, + meshgrid, + ones, + ones_like, + tril, + triu, + zeros, + zeros_like, +) +from ._spec_data_type_functions import ( + astype, + can_cast, + finfo, + iinfo, + isdtype, + result_type, +) +from ._spec_elementwise_functions import ( # pylint: disable=W0622 + abs, + acos, + acosh, + add, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_and, + bitwise_invert, + bitwise_left_shift, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + ceil, + conj, + cos, + cosh, + divide, + equal, + exp, + expm1, + floor, + floor_divide, + greater, + greater_equal, + imag, + isfinite, + isinf, + isnan, + less, + less_equal, + log, + log1p, + log2, + log10, + logaddexp, + logical_and, + logical_not, + logical_or, + logical_xor, + multiply, + negative, + not_equal, + positive, + pow, + real, + remainder, + round, + sign, + sin, + sinh, + sqrt, + square, + subtract, + tan, + tanh, + trunc, +) +from ._spec_indexing_functions import ( + take, +) +from ._spec_linear_algebra_functions import ( + matmul, + matrix_transpose, + tensordot, + vecdot, +) +from ._spec_manipulation_functions import ( + broadcast_arrays, + broadcast_to, + concat, + expand_dims, + flip, + permute_dims, + reshape, + roll, + squeeze, + stack, +) +from ._spec_searching_functions import ( + argmax, + argmin, + nonzero, + where, +) +from ._spec_set_functions import ( + unique_all, + unique_counts, + unique_inverse, + unique_values, +) +from ._spec_sorting_functions import ( + argsort, + sort, +) +from ._spec_statistical_functions import ( # pylint: disable=W0622 + max, + mean, + min, + prod, + std, + sum, + var, +) +from ._spec_utility_functions import ( # pylint: disable=W0622 + all, + any, +) + +__array_api_version__ = "2022.12" + +__all__ = [ + "__array_api_version__", + # _spec_array_object + "array", + # _spec_constants + "e", + "inf", + "nan", + "newaxis", + "pi", + # _spec_creation_functions + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # _spec_data_type_functions + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # _spec_elementwise_functions + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "conj", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "multiply", + "negative", + "not_equal", + "positive", + "pow", + "real", + "remainder", + "round", + "sign", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # _spec_indexing_functions + "take", + # _spec_linear_algebra_functions + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # _spec_manipulation_functions + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "permute_dims", + "reshape", + "roll", + "squeeze", + "stack", + # _spec_searching_functions + "argmax", + "argmin", + "nonzero", + "where", + # _spec_set_functions + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + # _spec_sorting_functions + "argsort", + "sort", + # _spec_statistical_functions + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # _spec_utility_functions + "all", + "any", +] diff --git a/src/ragged/common/_import.py b/src/ragged/_import.py similarity index 100% rename from src/ragged/common/_import.py rename to src/ragged/_import.py diff --git a/src/ragged/common/_spec_array_object.py b/src/ragged/_spec_array_object.py similarity index 68% rename from src/ragged/common/_spec_array_object.py rename to src/ragged/_spec_array_object.py index c237458..493e669 100644 --- a/src/ragged/common/_spec_array_object.py +++ b/src/ragged/_spec_array_object.py @@ -7,7 +7,7 @@ from __future__ import annotations import enum -from numbers import Real +import numbers from typing import TYPE_CHECKING, Any, Union import awkward as ak @@ -29,6 +29,7 @@ Shape, SupportsBufferProtocol, SupportsDLPack, + numeric_types, ) @@ -161,7 +162,24 @@ def __init__( self._impl = obj self._shape, self._dtype = _shape_dtype(self._impl.layout) - elif isinstance(obj, (bool, Real)): + elif hasattr(obj, "__dlpack_device__") and getattr(obj, "shape", None) == (): + device_type, _ = obj.__dlpack_device__() + if ( + isinstance(device_type, enum.Enum) and device_type.value == 1 + ) or device_type == 1: + self._impl = np.array(obj) + self._shape, self._dtype = (), self._impl.dtype + elif ( + isinstance(device_type, enum.Enum) and device_type.value == 2 + ) or device_type == 2: + cp = _import.cupy() + self._impl = cp.array(obj) + self._shape, self._dtype = (), self._impl.dtype + else: + msg = f"unsupported __dlpack_device__ type: {device_type}" + raise TypeError(msg) + + elif isinstance(obj, (bool, numbers.Complex)): self._impl = np.array(obj) self._shape, self._dtype = (), self._impl.dtype @@ -169,7 +187,7 @@ def __init__( self._impl = ak.Array(obj) self._shape, self._dtype = _shape_dtype(self._impl.layout) - if not isinstance(dtype, np.dtype): + if dtype is not None and not isinstance(dtype, np.dtype): dtype = np.dtype(dtype) if dtype is not None and dtype != self._dtype: @@ -188,8 +206,8 @@ def __init__( msg = f"dtype must not have a shape: dtype.shape = {self._dtype.shape}" raise TypeError(msg) - if not issubclass(self._dtype.type, np.number): - msg = f"dtype must be numeric: dtype.type = {self._dtype.type}" + if self._dtype.type not in numeric_types: + msg = f"dtype must be numeric (bool, [u]int*, float*, complex*): dtype.type = {self._dtype.type}" raise TypeError(msg) if device is not None: @@ -197,7 +215,7 @@ def __init__( self._impl = ak.to_backend(self._impl, device) elif isinstance(self._impl, np.ndarray) and device == "cuda": cp = _import.cupy() - self._impl = cp.array(self._impl.item()) + self._impl = cp.array(self._impl) assert copy is None, "TODO" @@ -207,7 +225,7 @@ def __str__(self) -> str: """ if len(self._shape) == 0: - return f"{self._impl.item()}" + return f"{self._impl}" elif len(self._shape) == 1: return f"{ak._prettyprint.valuestr(self._impl, 1, 80)}" else: @@ -222,7 +240,7 @@ def __repr__(self) -> str: """ if len(self._shape) == 0: - return f"ragged.array({self._impl.item()})" + return f"ragged.array({self._impl})" elif len(self._shape) == 1: return f"ragged.array({ak._prettyprint.valuestr(self._impl, 1, 80 - 14)})" else: @@ -266,7 +284,7 @@ def mT(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.mT.html """ - assert False, "TODO" + assert False, "TODO 1" @property def ndim(self) -> int: @@ -329,7 +347,7 @@ def T(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.T.html """ - assert False, "TODO" + assert False, "TODO 2" # methods: https://data-apis.org/array-api/latest/API_specification/array_object.html#methods @@ -340,7 +358,11 @@ def __abs__(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__abs__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + return ns.abs(self) def __add__(self, other: int | float | array, /) -> array: """ @@ -350,7 +372,14 @@ def __add__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__add__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.add(self, other) def __and__(self, other: int | bool | array, /) -> array: """ @@ -360,7 +389,14 @@ def __and__(self, other: int | bool | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__and__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.bitwise_and(self, other) def __array_namespace__(self, *, api_version: None | str = None) -> Any: """ @@ -369,8 +405,13 @@ def __array_namespace__(self, *, api_version: None | str = None) -> Any: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html """ - assert api_version, "TODO" - assert False, "TODO" + import ragged # pylint: disable=C0415,R0401 + + if api_version is not None and api_version != ragged.__array_api_version__: + msg = f"api_version {api_version!r} is not implemented; {ragged.__array_api_version__ = }" + raise NotImplementedError(msg) + + return ragged def __bool__(self) -> bool: # FIXME pylint: disable=E0304 """ @@ -379,7 +420,7 @@ def __bool__(self) -> bool: # FIXME pylint: disable=E0304 https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__bool__.html """ - assert False, "TODO" + return bool(self._impl) def __complex__(self) -> complex: """ @@ -388,7 +429,7 @@ def __complex__(self) -> complex: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__complex__.html """ - assert False, "TODO" + return complex(self._impl) # type: ignore[arg-type] def __dlpack__(self, *, stream: None | int | Any = None) -> PyCapsule: """ @@ -406,7 +447,7 @@ def __dlpack__(self, *, stream: None | int | Any = None) -> PyCapsule: """ assert stream, "TODO" - assert False, "TODO" + assert False, "TODO 9" def __dlpack_device__(self) -> tuple[enum.Enum, int]: """ @@ -418,7 +459,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack_device__.html """ - assert False, "TODO" + assert False, "TODO 10" def __eq__(self, other: int | float | bool | array, /) -> array: # type: ignore[override] """ @@ -428,7 +469,14 @@ def __eq__(self, other: int | float | bool | array, /) -> array: # type: ignore https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__eq__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.equal(self, other) def __float__(self) -> float: """ @@ -437,7 +485,7 @@ def __float__(self) -> float: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__float__.html """ - assert False, "TODO" + return float(self._impl) # type: ignore[arg-type] def __floordiv__(self, other: int | float | array, /) -> array: """ @@ -447,7 +495,14 @@ def __floordiv__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__floordiv__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.floor_divide(self, other) def __ge__(self, other: int | float | array, /) -> array: """ @@ -457,7 +512,14 @@ def __ge__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__ge__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.greater_equal(self, other) def __getitem__(self, key: GetSliceKey, /) -> array: """ @@ -466,7 +528,7 @@ def __getitem__(self, key: GetSliceKey, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__getitem__.html """ - assert False, "TODO" + assert False, "TODO 15" def __gt__(self, other: int | float | array, /) -> array: """ @@ -476,7 +538,14 @@ def __gt__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__gt__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.greater(self, other) def __index__(self) -> int: # FIXME pylint: disable=E0305 """ @@ -485,7 +554,7 @@ def __index__(self) -> int: # FIXME pylint: disable=E0305 https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__index__.html """ - assert False, "TODO" + return self._impl.__index__() # type: ignore[no-any-return, union-attr] def __int__(self) -> int: """ @@ -494,7 +563,7 @@ def __int__(self) -> int: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__int__.html """ - assert False, "TODO" + return int(self._impl) # type: ignore[arg-type] def __invert__(self) -> array: """ @@ -503,7 +572,11 @@ def __invert__(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__invert__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + return ns.bitwise_invert(self) def __le__(self, other: int | float | array, /) -> array: """ @@ -513,7 +586,14 @@ def __le__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__le__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.less_equal(self, other) def __lshift__(self, other: int | array, /) -> array: """ @@ -523,7 +603,14 @@ def __lshift__(self, other: int | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__lshift__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.bitwise_left_shift(self, other) def __lt__(self, other: int | float | array, /) -> array: """ @@ -533,7 +620,14 @@ def __lt__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__lt__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.less(self, other) def __matmul__(self, other: array, /) -> array: """ @@ -542,7 +636,7 @@ def __matmul__(self, other: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__matmul__.html """ - assert False, "TODO" + assert False, "TODO 22" def __mod__(self, other: int | float | array, /) -> array: """ @@ -552,7 +646,14 @@ def __mod__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__mod__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.remainder(self, other) def __mul__(self, other: int | float | array, /) -> array: """ @@ -562,7 +663,14 @@ def __mul__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__mul__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.multiply(self, other) def __ne__(self, other: int | float | bool | array, /) -> array: # type: ignore[override] """ @@ -572,7 +680,14 @@ def __ne__(self, other: int | float | bool | array, /) -> array: # type: ignore https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__ne__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.not_equal(self, other) def __neg__(self) -> array: """ @@ -581,7 +696,11 @@ def __neg__(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__neg__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + return ns.negative(self) def __or__(self, other: int | bool | array, /) -> array: """ @@ -591,7 +710,14 @@ def __or__(self, other: int | bool | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__or__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.bitwise_or(self, other) def __pos__(self) -> array: """ @@ -600,7 +726,11 @@ def __pos__(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__pos__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + return ns.positive(self) def __pow__(self, other: int | float | array, /) -> array: """ @@ -612,7 +742,14 @@ def __pow__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__pow__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.pow(self, other) def __rshift__(self, other: int | array, /) -> array: """ @@ -622,7 +759,14 @@ def __rshift__(self, other: int | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__rshift__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.bitwise_right_shift(self, other) def __setitem__( self, key: SetSliceKey, value: int | float | bool | array, / @@ -633,7 +777,7 @@ def __setitem__( https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__setitem__.html """ - assert False, "TODO" + assert False, "TODO 31" def __sub__(self, other: int | float | array, /) -> array: """ @@ -643,7 +787,14 @@ def __sub__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__sub__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.subtract(self, other) def __truediv__(self, other: int | float | array, /) -> array: """ @@ -653,7 +804,14 @@ def __truediv__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__truediv__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.divide(self, other) def __xor__(self, other: int | bool | array, /) -> array: """ @@ -663,7 +821,14 @@ def __xor__(self, other: int | bool | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__xor__.html """ - assert False, "TODO" + from ragged import ( # pylint: disable=C0415,R0401 + _spec_elementwise_functions as ns, + ) + + if not isinstance(other, array): + other = array(other, device=self._device) + + return ns.bitwise_xor(self, other) def to_device(self, device: Device, /, *, stream: None | int | Any = None) -> array: """ @@ -680,20 +845,25 @@ def to_device(self, device: Device, /, *, stream: None | int | Any = None) -> ar https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html """ - if isinstance(self._impl, ak.Array) and device != ak.backend(self._impl): - assert stream is None, "TODO" - impl = ak.to_backend(self._impl, device) + if isinstance(self._impl, ak.Array): + if device != ak.backend(self._impl): + assert stream is None, "TODO" + impl = ak.to_backend(self._impl, device) + else: + impl = self._impl elif isinstance(self._impl, np.ndarray): + # self._impl is a NumPy 0-dimensional array if device == "cuda": assert stream is None, "TODO" cp = _import.cupy() - impl = cp.array(self._impl.item()) + impl = cp.array(self._impl) else: impl = self._impl else: - impl = np.array(self._impl.item()) if device == "cpu" else self._impl + # self._impl is a CuPy 0-dimensional array + impl = self._impl.get() if device == "cpu" else self._impl # type: ignore[union-attr] return self._new(impl, self._shape, self._dtype, device) @@ -709,6 +879,10 @@ def __iadd__(self, other: int | float | array, /) -> array: out = self + other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __isub__(self, other: int | float | array, /) -> array: @@ -721,6 +895,10 @@ def __isub__(self, other: int | float | array, /) -> array: out = self - other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __imul__(self, other: int | float | array, /) -> array: @@ -733,6 +911,10 @@ def __imul__(self, other: int | float | array, /) -> array: out = self * other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __itruediv__(self, other: int | float | array, /) -> array: @@ -745,6 +927,10 @@ def __itruediv__(self, other: int | float | array, /) -> array: out = self / other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __ifloordiv__(self, other: int | float | array, /) -> array: @@ -757,6 +943,10 @@ def __ifloordiv__(self, other: int | float | array, /) -> array: out = self // other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __ipow__(self, other: int | float | array, /) -> array: @@ -769,6 +959,10 @@ def __ipow__(self, other: int | float | array, /) -> array: out = self**other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __imod__(self, other: int | float | array, /) -> array: @@ -781,6 +975,10 @@ def __imod__(self, other: int | float | array, /) -> array: out = self % other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __imatmul__(self, other: array, /) -> array: @@ -793,6 +991,10 @@ def __imatmul__(self, other: array, /) -> array: out = self @ other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __iand__(self, other: int | bool | array, /) -> array: @@ -805,6 +1007,10 @@ def __iand__(self, other: int | bool | array, /) -> array: out = self & other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __ior__(self, other: int | bool | array, /) -> array: @@ -817,6 +1023,10 @@ def __ior__(self, other: int | bool | array, /) -> array: out = self | other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __ixor__(self, other: int | bool | array, /) -> array: @@ -829,6 +1039,10 @@ def __ixor__(self, other: int | bool | array, /) -> array: out = self ^ other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __ilshift__(self, other: int | array, /) -> array: @@ -841,6 +1055,10 @@ def __ilshift__(self, other: int | array, /) -> array: out = self << other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self def __irshift__(self, other: int | array, /) -> array: @@ -853,6 +1071,10 @@ def __irshift__(self, other: int | array, /) -> array: out = self >> other self._impl, self._device = out._impl, out._device + if isinstance(self._impl, ak.Array): + self._shape, self._dtype = _shape_dtype(self._impl.layout) + else: + self._shape, self._dtype = (), self._impl.dtype # type: ignore[union-attr] return self # reflected operators: https://data-apis.org/array-api/2022.12/API_specification/array_object.html#reflected-operators @@ -870,3 +1092,50 @@ def __irshift__(self, other: int | array, /) -> array: __rxor__ = __xor__ __rlshift__ = __lshift__ __rrshift__ = __rshift__ + + +def _unbox(*inputs: array) -> tuple[ak.Array | SupportsDLPack, ...]: + if len(inputs) > 1 and any(type(inputs[0]) is not type(x) for x in inputs): + types = "\n".join(f"{type(x).__module__}.{type(x).__name__}" for x in inputs) + msg = f"mixed array types: {types}" + raise TypeError(msg) + + return tuple(x._impl for x in inputs) # pylint: disable=W0212 + + +def _box( + cls: type[array], + output: ak.Array | np.number | SupportsDLPack, + *, + dtype: None | Dtype = None, +) -> array: + if isinstance(output, ak.Array): + impl = output + shape, dtype_observed = _shape_dtype(output.layout) + if dtype is not None and dtype != dtype_observed: + impl = ak.values_astype(impl, dtype) + else: + dtype = dtype_observed + device = ak.backend(output) + + elif isinstance(output, np.number): + impl = np.array(output) + shape = output.shape + dtype_observed = output.dtype + if dtype is not None and dtype != dtype_observed: + impl = impl.astype(dtype) + else: + dtype = dtype_observed + device = "cpu" + + else: + impl = output + shape = output.shape # type: ignore[union-attr] + dtype_observed = output.dtype # type: ignore[union-attr] + if dtype is not None and dtype != dtype_observed: + impl = impl.astype(dtype) + else: + dtype = dtype_observed + device = "cpu" if isinstance(output, np.ndarray) else "cuda" + + return cls._new(impl, shape, dtype, device) # pylint: disable=W0212 diff --git a/src/ragged/common/_spec_constants.py b/src/ragged/_spec_constants.py similarity index 100% rename from src/ragged/common/_spec_constants.py rename to src/ragged/_spec_constants.py diff --git a/src/ragged/common/_spec_creation_functions.py b/src/ragged/_spec_creation_functions.py similarity index 97% rename from src/ragged/common/_spec_creation_functions.py rename to src/ragged/_spec_creation_functions.py index 0e64ca7..6b4c66b 100644 --- a/src/ragged/common/_spec_creation_functions.py +++ b/src/ragged/_spec_creation_functions.py @@ -58,7 +58,7 @@ def arange( assert step, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 35" def asarray( @@ -139,7 +139,7 @@ def empty( assert shape, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 36" def empty_like( @@ -164,7 +164,7 @@ def empty_like( assert x, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 37" def eye( @@ -201,7 +201,7 @@ def eye( assert k, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 38" def from_dlpack(x: object, /) -> array: @@ -218,7 +218,7 @@ def from_dlpack(x: object, /) -> array: """ assert x, "TODO" - assert False, "TODO" + assert False, "TODO 39" def full( @@ -257,7 +257,7 @@ def full( assert fill_value, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 40" def full_like( @@ -290,7 +290,7 @@ def full_like( assert fill_value, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 41" def linspace( @@ -349,7 +349,7 @@ def linspace( assert dtype, "TODO" assert device, "TODO" assert endpoint, "TODO" - assert False, "TODO" + assert False, "TODO 42" def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]: @@ -390,7 +390,7 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]: assert arrays, "TODO" assert indexing, "TODO" - assert False, "TODO" + assert False, "TODO 43" def ones( @@ -417,7 +417,7 @@ def ones( assert shape, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 44" def ones_like( @@ -443,7 +443,7 @@ def ones_like( assert x, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 45" def tril(x: array, /, *, k: int = 0) -> array: @@ -468,7 +468,7 @@ def tril(x: array, /, *, k: int = 0) -> array: assert x, "TODO" assert k, "TODO" - assert False, "TODO" + assert False, "TODO 46" def triu(x: array, /, *, k: int = 0) -> array: @@ -493,7 +493,7 @@ def triu(x: array, /, *, k: int = 0) -> array: assert x, "TODO" assert k, "TODO" - assert False, "TODO" + assert False, "TODO 47" def zeros( @@ -520,7 +520,7 @@ def zeros( assert shape, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 48" def zeros_like( @@ -546,4 +546,4 @@ def zeros_like( assert x, "TODO" assert dtype, "TODO" assert device, "TODO" - assert False, "TODO" + assert False, "TODO 49" diff --git a/src/ragged/common/_spec_data_type_functions.py b/src/ragged/_spec_data_type_functions.py similarity index 97% rename from src/ragged/common/_spec_data_type_functions.py rename to src/ragged/_spec_data_type_functions.py index 5c0df34..ebce649 100644 --- a/src/ragged/common/_spec_data_type_functions.py +++ b/src/ragged/_spec_data_type_functions.py @@ -37,7 +37,7 @@ def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: assert x, "TODO" assert dtype, "TODO" assert copy, "TODO" - assert False, "TODO" + assert False, "TODO 50" def can_cast(from_: Dtype | array, to: Dtype, /) -> bool: @@ -58,7 +58,7 @@ def can_cast(from_: Dtype | array, to: Dtype, /) -> bool: assert from_, "TODO" assert to, "TODO" - assert False, "TODO" + assert False, "TODO 51" @dataclass @@ -115,7 +115,7 @@ def finfo(type: Dtype | array, /) -> finfo_object: # pylint: disable=W0622 """ assert type, "TODO" - assert False, "TODO" + assert False, "TODO 52" @dataclass @@ -156,7 +156,7 @@ def iinfo(type: Dtype | array, /) -> iinfo_object: # pylint: disable=W0622 """ assert type, "TODO" - assert False, "TODO" + assert False, "TODO 53" def isdtype(dtype: Dtype, kind: Dtype | str | tuple[Dtype | str, ...]) -> bool: @@ -201,7 +201,7 @@ def isdtype(dtype: Dtype, kind: Dtype | str | tuple[Dtype | str, ...]) -> bool: assert dtype, "TODO" assert kind, "TODO" - assert False, "TODO" + assert False, "TODO 54" def result_type(*arrays_and_dtypes: array | Dtype) -> Dtype: @@ -219,4 +219,4 @@ def result_type(*arrays_and_dtypes: array | Dtype) -> Dtype: """ assert arrays_and_dtypes, "TODO" - assert False, "TODO" + assert False, "TODO 55" diff --git a/src/ragged/common/_spec_elementwise_functions.py b/src/ragged/_spec_elementwise_functions.py similarity index 89% rename from src/ragged/common/_spec_elementwise_functions.py rename to src/ragged/_spec_elementwise_functions.py index 3e8e218..3357c6c 100644 --- a/src/ragged/common/_spec_elementwise_functions.py +++ b/src/ragged/_spec_elementwise_functions.py @@ -6,7 +6,11 @@ from __future__ import annotations -from ._spec_array_object import array +import warnings + +import numpy as np + +from ._spec_array_object import _box, _unbox, array def abs(x: array, /) -> array: # pylint: disable=W0622 @@ -36,8 +40,7 @@ def abs(x: array, /) -> array: # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.abs.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.absolute(*_unbox(x))) def acos(x: array, /) -> array: @@ -66,8 +69,7 @@ def acos(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.acos.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.arccos(*_unbox(x))) def acosh(x: array, /) -> array: @@ -102,8 +104,7 @@ def acosh(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.acosh.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.arccosh(*_unbox(x))) def add(x1: array, x2: array, /) -> array: @@ -122,9 +123,7 @@ def add(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.add.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.add(*_unbox(x1, x2))) def asin(x: array, /) -> array: @@ -153,8 +152,7 @@ def asin(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.asin.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.arcsin(*_unbox(x))) def asinh(x: array, /) -> array: @@ -183,8 +181,7 @@ def asinh(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.asinh.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.arcsinh(*_unbox(x))) def atan(x: array, /) -> array: @@ -209,8 +206,7 @@ def atan(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.atan.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.arctan(*_unbox(x))) def atan2(x1: array, x2: array, /) -> array: @@ -251,9 +247,7 @@ def atan2(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.atan2.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.arctan2(*_unbox(x1, x2))) def atanh(x: array, /) -> array: @@ -282,8 +276,7 @@ def atanh(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.atanh.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.arctanh(*_unbox(x))) def bitwise_and(x1: array, x2: array, /) -> array: @@ -303,9 +296,7 @@ def bitwise_and(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_and.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.bitwise_and(*_unbox(x1, x2))) def bitwise_invert(x: array, /) -> array: @@ -322,8 +313,7 @@ def bitwise_invert(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_invert.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.invert(*_unbox(x))) def bitwise_left_shift(x1: array, x2: array, /) -> array: @@ -344,9 +334,7 @@ def bitwise_left_shift(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_left_shift.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.left_shift(*_unbox(x1, x2))) def bitwise_or(x1: array, x2: array, /) -> array: @@ -366,9 +354,7 @@ def bitwise_or(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_or.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.bitwise_or(*_unbox(x1, x2))) def bitwise_right_shift(x1: array, x2: array, /) -> array: @@ -390,9 +376,7 @@ def bitwise_right_shift(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_right_shift.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.right_shift(*_unbox(x1, x2))) def bitwise_xor(x1: array, x2: array, /) -> array: @@ -412,9 +396,7 @@ def bitwise_xor(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_xor.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.bitwise_xor(*_unbox(x1, x2))) def ceil(x: array, /) -> array: @@ -432,8 +414,7 @@ def ceil(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.ceil.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.ceil(*_unbox(x)), dtype=x.dtype) def conj(x: array, /) -> array: @@ -462,8 +443,7 @@ def conj(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.conj.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.conjugate(*_unbox(x))) def cos(x: array, /) -> array: @@ -483,8 +463,7 @@ def cos(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.cos.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.cos(*_unbox(x))) def cosh(x: array, /) -> array: @@ -507,8 +486,7 @@ def cosh(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.cosh.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.cosh(*_unbox(x))) def divide(x1: array, x2: array, /) -> array: @@ -527,9 +505,7 @@ def divide(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.divide.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.divide(*_unbox(x1, x2))) def equal(x1: array, x2: array, /) -> array: @@ -549,9 +525,7 @@ def equal(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.equal.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.equal(*_unbox(x1, x2))) def exp(x: array, /) -> array: @@ -571,8 +545,7 @@ def exp(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.exp.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.exp(*_unbox(x))) def expm1(x: array, /) -> array: @@ -594,8 +567,7 @@ def expm1(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.expm1.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.expm1(*_unbox(x))) def floor(x: array, /) -> array: @@ -614,8 +586,7 @@ def floor(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.floor(*_unbox(x)), dtype=x.dtype) def floor_divide(x1: array, x2: array, /) -> array: @@ -636,9 +607,7 @@ def floor_divide(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor_divide.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.floor_divide(*_unbox(x1, x2))) def greater(x1: array, x2: array, /) -> array: @@ -658,9 +627,7 @@ def greater(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.greater.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.greater(*_unbox(x1, x2))) def greater_equal(x1: array, x2: array, /) -> array: @@ -680,9 +647,7 @@ def greater_equal(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.greater_equal.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.greater_equal(*_unbox(x1, x2))) def imag(x: array, /) -> array: @@ -702,8 +667,14 @@ def imag(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.imag.html """ - assert x, "TODO" - assert False, "TODO" + (a,) = _unbox(x) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return _box( + type(x), + (a - np.conjugate(a)) / 2j, + dtype=np.dtype(f"f{x.dtype.itemsize // 2}"), + ) def isfinite(x: array, /) -> array: @@ -720,8 +691,7 @@ def isfinite(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.isfinite.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.isfinite(*_unbox(x))) def isinf(x: array, /) -> array: @@ -739,8 +709,7 @@ def isinf(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.isinf.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.isinf(*_unbox(x))) def isnan(x: array, /) -> array: @@ -758,8 +727,7 @@ def isnan(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.isnan.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.isnan(*_unbox(x))) def less(x1: array, x2: array, /) -> array: @@ -779,9 +747,7 @@ def less(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.less.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.less(*_unbox(x1, x2))) def less_equal(x1: array, x2: array, /) -> array: @@ -801,9 +767,7 @@ def less_equal(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.less_equal.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.less_equal(*_unbox(x1, x2))) def log(x: array, /) -> array: @@ -822,8 +786,7 @@ def log(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.log.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.log(*_unbox(x))) def log1p(x: array, /) -> array: @@ -846,8 +809,7 @@ def log1p(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.log1p.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.log1p(*_unbox(x))) def log2(x: array, /) -> array: @@ -866,8 +828,7 @@ def log2(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.log2.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.log2(*_unbox(x))) def log10(x: array, /) -> array: @@ -886,8 +847,7 @@ def log10(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.log10.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.log10(*_unbox(x))) def logaddexp(x1: array, x2: array, /) -> array: @@ -907,9 +867,7 @@ def logaddexp(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logaddexp.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.logaddexp(*_unbox(x1, x2))) def logical_and(x1: array, x2: array, /) -> array: @@ -928,9 +886,7 @@ def logical_and(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_and.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.logical_and(*_unbox(x1, x2))) def logical_not(x: array, /) -> array: @@ -947,8 +903,7 @@ def logical_not(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_not.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.logical_not(*_unbox(x))) def logical_or(x1: array, x2: array, /) -> array: @@ -967,9 +922,7 @@ def logical_or(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_or.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.logical_or(*_unbox(x1, x2))) def logical_xor(x1: array, x2: array, /) -> array: @@ -988,9 +941,7 @@ def logical_xor(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_xor.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.logical_xor(*_unbox(x1, x2))) def multiply(x1: array, x2: array, /) -> array: @@ -1009,9 +960,7 @@ def multiply(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.multiply.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.multiply(*_unbox(x1, x2))) def negative(x: array, /) -> array: @@ -1029,8 +978,7 @@ def negative(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.negative.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.negative(*_unbox(x))) def not_equal(x1: array, x2: array, /) -> array: @@ -1050,9 +998,7 @@ def not_equal(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.not_equal.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.not_equal(*_unbox(x1, x2))) def positive(x: array, /) -> array: @@ -1070,8 +1016,7 @@ def positive(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.positive.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.positive(*_unbox(x))) def pow(x1: array, x2: array, /) -> array: # pylint: disable=W0622 @@ -1094,9 +1039,7 @@ def pow(x1: array, x2: array, /) -> array: # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.pow.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.power(*_unbox(x1, x2))) def real(x: array, /) -> array: @@ -1116,8 +1059,14 @@ def real(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.real.html """ - assert x, "TODO" - assert False, "TODO" + (a,) = _unbox(x) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return _box( + type(x), + (a + np.conjugate(a)) / 2, + dtype=np.dtype(f"f{x.dtype.itemsize // 2}"), + ) def remainder(x1: array, x2: array, /) -> array: @@ -1139,9 +1088,7 @@ def remainder(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.remainder.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.remainder(*_unbox(x1, x2))) def round(x: array, /) -> array: # pylint: disable=W0622 @@ -1167,8 +1114,24 @@ def round(x: array, /) -> array: # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.round.html """ - assert x, "TODO" - assert False, "TODO" + (a,) = _unbox(x) + if x.dtype in (np.complex64, np.complex128): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + a_conj = np.conjugate(a) + dt = np.dtype(f"f{x.dtype.itemsize // 2}") + re = _box(type(x), (a + a_conj) / 2, dtype=dt) + im = _box(type(x), (a - a_conj) / 2j, dtype=dt) + return add(round(re), multiply(round(im), array(1j, device=x.device))) + + else: + frac, whole = np.modf(a) + abs_frac = np.absolute(frac) + return _box( + type(x), + whole + + ((abs_frac == 0.5) * (whole % 2 != 0) + (abs_frac > 0.5)) * np.sign(frac), + ) def sign(x: array, /) -> array: @@ -1193,8 +1156,7 @@ def sign(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.sign.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.sign(*_unbox(x))) def sin(x: array, /) -> array: @@ -1214,8 +1176,7 @@ def sin(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.sin.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.sin(*_unbox(x))) def sinh(x: array, /) -> array: @@ -1238,8 +1199,7 @@ def sinh(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.sinh.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.sinh(*_unbox(x))) def square(x: array, /) -> array: @@ -1260,8 +1220,7 @@ def square(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.square.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.square(*_unbox(x))) def sqrt(x: array, /) -> array: @@ -1280,8 +1239,7 @@ def sqrt(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.sqrt.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.sqrt(*_unbox(x))) def subtract(x1: array, x2: array, /) -> array: @@ -1303,9 +1261,7 @@ def subtract(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.subtract.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO" + return _box(type(x1), np.subtract(*_unbox(x1, x2))) def tan(x: array, /) -> array: @@ -1325,8 +1281,7 @@ def tan(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.tan.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.tan(*_unbox(x))) def tanh(x: array, /) -> array: @@ -1352,8 +1307,7 @@ def tanh(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.tanh.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.tanh(*_unbox(x))) def trunc(x: array, /) -> array: @@ -1371,5 +1325,4 @@ def trunc(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.trunc.html """ - assert x, "TODO" - assert False, "TODO" + return _box(type(x), np.trunc(*_unbox(x))) diff --git a/src/ragged/common/_spec_indexing_functions.py b/src/ragged/_spec_indexing_functions.py similarity index 98% rename from src/ragged/common/_spec_indexing_functions.py rename to src/ragged/_spec_indexing_functions.py index 6c82f19..d4a2b36 100644 --- a/src/ragged/common/_spec_indexing_functions.py +++ b/src/ragged/_spec_indexing_functions.py @@ -40,4 +40,4 @@ def take(x: array, indices: array, /, *, axis: None | int = None) -> array: assert x, "TODO" assert indices, "TODO" assert axis, "TODO" - assert False, "TODO" + assert False, "TODO 109" diff --git a/src/ragged/common/_spec_linear_algebra_functions.py b/src/ragged/_spec_linear_algebra_functions.py similarity index 98% rename from src/ragged/common/_spec_linear_algebra_functions.py rename to src/ragged/_spec_linear_algebra_functions.py index 5690296..b6aec1b 100644 --- a/src/ragged/common/_spec_linear_algebra_functions.py +++ b/src/ragged/_spec_linear_algebra_functions.py @@ -75,7 +75,7 @@ def matmul(x1: array, x2: array, /) -> array: assert x1, "TODO" assert x2, "TODO" - assert False, "TODO" + assert False, "TODO 110" def matrix_transpose(x: array, /) -> array: @@ -94,7 +94,7 @@ def matrix_transpose(x: array, /) -> array: """ assert x, "TODO" - assert False, "TODO" + assert False, "TODO 111" def tensordot( @@ -143,7 +143,7 @@ def tensordot( assert x1, "TODO" assert x2, "TODO" assert axes, "TODO" - assert False, "TODO" + assert False, "TODO 112" def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: @@ -187,4 +187,4 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: assert x1, "TODO" assert x2, "TODO" assert axis, "TODO" - assert False, "TODO" + assert False, "TODO 113" diff --git a/src/ragged/common/_spec_manipulation_functions.py b/src/ragged/_spec_manipulation_functions.py similarity index 97% rename from src/ragged/common/_spec_manipulation_functions.py rename to src/ragged/_spec_manipulation_functions.py index 84c4000..39802a1 100644 --- a/src/ragged/common/_spec_manipulation_functions.py +++ b/src/ragged/_spec_manipulation_functions.py @@ -24,7 +24,7 @@ def broadcast_arrays(*arrays: array) -> list[array]: """ assert arrays, "TODO" - assert False, "TODO" + assert False, "TODO 114" def broadcast_to(x: array, /, shape: tuple[int, ...]) -> array: @@ -45,7 +45,7 @@ def broadcast_to(x: array, /, shape: tuple[int, ...]) -> array: assert x, "TODO" assert shape, "TODO" - assert False, "TODO" + assert False, "TODO 115" def concat( @@ -73,7 +73,7 @@ def concat( assert arrays, "TODO" assert axis, "TODO" - assert False, "TODO" + assert False, "TODO 116" def expand_dims(x: array, /, *, axis: int = 0) -> array: @@ -102,7 +102,7 @@ def expand_dims(x: array, /, *, axis: int = 0) -> array: assert x, "TODO" assert axis, "TODO" - assert False, "TODO" + assert False, "TODO 117" def flip(x: array, /, *, axis: None | int | tuple[int, ...] = None) -> array: @@ -126,7 +126,7 @@ def flip(x: array, /, *, axis: None | int | tuple[int, ...] = None) -> array: assert x, "TODO" assert axis, "TODO" - assert False, "TODO" + assert False, "TODO 118" def permute_dims(x: array, /, axes: tuple[int, ...]) -> array: @@ -147,7 +147,7 @@ def permute_dims(x: array, /, axes: tuple[int, ...]) -> array: assert x, "TODO" assert axes, "TODO" - assert False, "TODO" + assert False, "TODO 119" def reshape(x: array, /, shape: tuple[int, ...], *, copy: None | bool = None) -> array: @@ -175,7 +175,7 @@ def reshape(x: array, /, shape: tuple[int, ...], *, copy: None | bool = None) -> assert x, "TODO" assert shape, "TODO" assert copy, "TODO" - assert False, "TODO" + assert False, "TODO 120" def roll( @@ -216,7 +216,7 @@ def roll( assert x, "TODO" assert shift, "TODO" assert axis, "TODO" - assert False, "TODO" + assert False, "TODO 121" def squeeze(x: array, /, axis: int | tuple[int, ...]) -> array: @@ -236,7 +236,7 @@ def squeeze(x: array, /, axis: int | tuple[int, ...]) -> array: assert x, "TODO" assert axis, "TODO" - assert False, "TODO" + assert False, "TODO 122" def stack(arrays: tuple[array, ...] | list[array], /, *, axis: int = 0) -> array: @@ -268,4 +268,4 @@ def stack(arrays: tuple[array, ...] | list[array], /, *, axis: int = 0) -> array assert arrays, "TODO" assert axis, "TODO" - assert False, "TODO" + assert False, "TODO 123" diff --git a/src/ragged/common/_spec_searching_functions.py b/src/ragged/_spec_searching_functions.py similarity index 97% rename from src/ragged/common/_spec_searching_functions.py rename to src/ragged/_spec_searching_functions.py index fd587b5..d5e717c 100644 --- a/src/ragged/common/_spec_searching_functions.py +++ b/src/ragged/_spec_searching_functions.py @@ -37,7 +37,7 @@ def argmax(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> a assert x, "TODO" assert axis, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 124" def argmin(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> array: @@ -68,7 +68,7 @@ def argmin(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> a assert x, "TODO" assert axis, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 125" def nonzero(x: array, /) -> tuple[array, ...]: @@ -90,7 +90,7 @@ def nonzero(x: array, /) -> tuple[array, ...]: """ assert x, "TODO" - assert False, "TODO" + assert False, "TODO 126" def where(condition: array, x1: array, x2: array, /) -> array: @@ -115,4 +115,4 @@ def where(condition: array, x1: array, x2: array, /) -> array: assert condition, "TODO" assert x1, "TODO" assert x2, "TODO" - assert False, "TODO" + assert False, "TODO 127" diff --git a/src/ragged/common/_spec_set_functions.py b/src/ragged/_spec_set_functions.py similarity index 97% rename from src/ragged/common/_spec_set_functions.py rename to src/ragged/_spec_set_functions.py index 2dd96f4..5b0b97d 100644 --- a/src/ragged/common/_spec_set_functions.py +++ b/src/ragged/_spec_set_functions.py @@ -48,7 +48,7 @@ def unique_all(x: array, /) -> tuple[array, array, array, array]: """ assert x, "TODO" - assert False, "TODO" + assert False, "TODO 128" unique_counts_result = namedtuple( # pylint: disable=C0103 @@ -79,7 +79,7 @@ def unique_counts(x: array, /) -> tuple[array, array]: """ assert x, "TODO" - assert False, "TODO" + assert False, "TODO 129" unique_inverse_result = namedtuple( # pylint: disable=C0103 @@ -110,7 +110,7 @@ def unique_inverse(x: array, /) -> tuple[array, array]: """ assert x, "TODO" - assert False, "TODO" + assert False, "TODO 130" def unique_values(x: array, /) -> array: @@ -130,4 +130,4 @@ def unique_values(x: array, /) -> array: """ assert x, "TODO" - assert False, "TODO" + assert False, "TODO 131" diff --git a/src/ragged/common/_spec_sorting_functions.py b/src/ragged/_spec_sorting_functions.py similarity index 97% rename from src/ragged/common/_spec_sorting_functions.py rename to src/ragged/_spec_sorting_functions.py index d448333..9df0579 100644 --- a/src/ragged/common/_spec_sorting_functions.py +++ b/src/ragged/_spec_sorting_functions.py @@ -38,7 +38,7 @@ def argsort( assert axis, "TODO" assert descending, "TODO" assert stable, "TODO" - assert False, "TODO" + assert False, "TODO 132" def sort( @@ -70,4 +70,4 @@ def sort( assert axis, "TODO" assert descending, "TODO" assert stable, "TODO" - assert False, "TODO" + assert False, "TODO 133" diff --git a/src/ragged/common/_spec_statistical_functions.py b/src/ragged/_spec_statistical_functions.py similarity index 98% rename from src/ragged/common/_spec_statistical_functions.py rename to src/ragged/_spec_statistical_functions.py index eaf55c9..8a8dd0a 100644 --- a/src/ragged/common/_spec_statistical_functions.py +++ b/src/ragged/_spec_statistical_functions.py @@ -38,7 +38,7 @@ def max( # pylint: disable=W0622 assert x, "TODO" assert axis, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 134" def mean( @@ -69,7 +69,7 @@ def mean( assert x, "TODO" assert axis, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 135" def min( # pylint: disable=W0622 @@ -100,7 +100,7 @@ def min( # pylint: disable=W0622 assert x, "TODO" assert axis, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 136" def prod( @@ -158,7 +158,7 @@ def prod( assert axis, "TODO" assert dtype, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 137" def std( @@ -209,7 +209,7 @@ def std( assert axis, "TODO" assert correction, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 138" def sum( # pylint: disable=W0622 @@ -267,7 +267,7 @@ def sum( # pylint: disable=W0622 assert axis, "TODO" assert dtype, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 139" def var( @@ -316,4 +316,4 @@ def var( assert axis, "TODO" assert correction, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 140" diff --git a/src/ragged/common/_spec_utility_functions.py b/src/ragged/_spec_utility_functions.py similarity index 98% rename from src/ragged/common/_spec_utility_functions.py rename to src/ragged/_spec_utility_functions.py index 5184afa..189c7f1 100644 --- a/src/ragged/common/_spec_utility_functions.py +++ b/src/ragged/_spec_utility_functions.py @@ -46,7 +46,7 @@ def all( # pylint: disable=W0622 assert x, "TODO" assert axis, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 141" def any( # pylint: disable=W0622 @@ -86,4 +86,4 @@ def any( # pylint: disable=W0622 assert x, "TODO" assert axis, "TODO" assert keepdims, "TODO" - assert False, "TODO" + assert False, "TODO 142" diff --git a/src/ragged/common/_typing.py b/src/ragged/_typing.py similarity index 75% rename from src/ragged/common/_typing.py rename to src/ragged/_typing.py index 8567153..0b2c675 100644 --- a/src/ragged/common/_typing.py +++ b/src/ragged/_typing.py @@ -6,7 +6,7 @@ from __future__ import annotations -import numbers +import enum import sys from typing import Any, Literal, Optional, Protocol, TypeVar, Union @@ -39,7 +39,7 @@ class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... - def item(self) -> numbers.Number: + def __dlpack_device__(self, /) -> tuple[enum.Enum, int]: ... @@ -47,6 +47,7 @@ def item(self) -> numbers.Number: Dtype = np.dtype[ Union[ + np.bool_, np.int8, np.int16, np.int32, @@ -57,7 +58,25 @@ def item(self) -> numbers.Number: np.uint64, np.float32, np.float64, + np.complex64, + np.complex128, ] ] -Device = Union[Literal["cpu"], Literal["cuda"]] +numeric_types = ( + np.bool_, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float32, + np.float64, + np.complex64, + np.complex128, +) + +Device = Literal["cpu", "cuda"] diff --git a/src/ragged/common/__init__.py b/src/ragged/common/__init__.py deleted file mode 100644 index 076de36..0000000 --- a/src/ragged/common/__init__.py +++ /dev/null @@ -1,293 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -Generic definitions used by the version-specific modules, such as -`ragged.v202212`. - -https://data-apis.org/array-api/latest/API_specification/ -""" - -from __future__ import annotations - -from ._spec_array_object import array -from ._spec_constants import ( - e, - inf, - nan, - newaxis, - pi, -) -from ._spec_creation_functions import ( - arange, - asarray, - empty, - empty_like, - eye, - from_dlpack, - full, - full_like, - linspace, - meshgrid, - ones, - ones_like, - tril, - triu, - zeros, - zeros_like, -) -from ._spec_data_type_functions import ( - astype, - can_cast, - finfo, - iinfo, - isdtype, - result_type, -) -from ._spec_elementwise_functions import ( # pylint: disable=W0622 - abs, - acos, - acosh, - add, - asin, - asinh, - atan, - atan2, - atanh, - bitwise_and, - bitwise_invert, - bitwise_left_shift, - bitwise_or, - bitwise_right_shift, - bitwise_xor, - ceil, - conj, - cos, - cosh, - divide, - equal, - exp, - expm1, - floor, - floor_divide, - greater, - greater_equal, - imag, - isfinite, - isinf, - isnan, - less, - less_equal, - log, - log1p, - log2, - log10, - logaddexp, - logical_and, - logical_not, - logical_or, - logical_xor, - multiply, - negative, - not_equal, - positive, - pow, - real, - remainder, - round, - sign, - sin, - sinh, - sqrt, - square, - subtract, - tan, - tanh, - trunc, -) -from ._spec_indexing_functions import ( - take, -) -from ._spec_linear_algebra_functions import ( - matmul, - matrix_transpose, - tensordot, - vecdot, -) -from ._spec_manipulation_functions import ( - broadcast_arrays, - broadcast_to, - concat, - expand_dims, - flip, - permute_dims, - reshape, - roll, - squeeze, - stack, -) -from ._spec_searching_functions import ( - argmax, - argmin, - nonzero, - where, -) -from ._spec_set_functions import ( - unique_all, - unique_counts, - unique_inverse, - unique_values, -) -from ._spec_sorting_functions import ( - argsort, - sort, -) -from ._spec_statistical_functions import ( # pylint: disable=W0622 - max, - mean, - min, - prod, - std, - sum, - var, -) -from ._spec_utility_functions import ( # pylint: disable=W0622 - all, - any, -) - -__all__ = [ - # _spec_array_object - "array", - # _spec_constants - "e", - "inf", - "nan", - "newaxis", - "pi", - # _spec_creation_functions - "arange", - "asarray", - "empty", - "empty_like", - "eye", - "from_dlpack", - "full", - "full_like", - "linspace", - "meshgrid", - "ones", - "ones_like", - "tril", - "triu", - "zeros", - "zeros_like", - # _spec_data_type_functions - "astype", - "can_cast", - "finfo", - "iinfo", - "isdtype", - "result_type", - # _spec_elementwise_functions - "abs", - "acos", - "acosh", - "add", - "asin", - "asinh", - "atan", - "atan2", - "atanh", - "bitwise_and", - "bitwise_left_shift", - "bitwise_invert", - "bitwise_or", - "bitwise_right_shift", - "bitwise_xor", - "ceil", - "conj", - "cos", - "cosh", - "divide", - "equal", - "exp", - "expm1", - "floor", - "floor_divide", - "greater", - "greater_equal", - "imag", - "isfinite", - "isinf", - "isnan", - "less", - "less_equal", - "log", - "log1p", - "log2", - "log10", - "logaddexp", - "logical_and", - "logical_not", - "logical_or", - "logical_xor", - "multiply", - "negative", - "not_equal", - "positive", - "pow", - "real", - "remainder", - "round", - "sign", - "sin", - "sinh", - "square", - "sqrt", - "subtract", - "tan", - "tanh", - "trunc", - # _spec_indexing_functions - "take", - # _spec_linear_algebra_functions - "matmul", - "matrix_transpose", - "tensordot", - "vecdot", - # _spec_manipulation_functions - "broadcast_arrays", - "broadcast_to", - "concat", - "expand_dims", - "flip", - "permute_dims", - "reshape", - "roll", - "squeeze", - "stack", - # _spec_searching_functions - "argmax", - "argmin", - "nonzero", - "where", - # _spec_set_functions - "unique_all", - "unique_counts", - "unique_inverse", - "unique_values", - # _spec_sorting_functions - "argsort", - "sort", - # _spec_statistical_functions - "max", - "mean", - "min", - "prod", - "std", - "sum", - "var", - # _spec_utility_functions - "all", - "any", -] diff --git a/src/ragged/v202212/__init__.py b/src/ragged/v202212/__init__.py deleted file mode 100644 index 8b9ac6f..0000000 --- a/src/ragged/v202212/__init__.py +++ /dev/null @@ -1,298 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -Defines a ragged array module that is compliant with version 2022.12 of the -Array API. - -This is the current default: `ragged.v202212.*` is imported into `ragged.*`. - -https://data-apis.org/array-api/2022.12/API_specification/ -""" - -from __future__ import annotations - -__array_api_version__ = "2022.12" - -from ._spec_array_object import array -from ._spec_constants import ( - e, - inf, - nan, - newaxis, - pi, -) -from ._spec_creation_functions import ( - arange, - asarray, - empty, - empty_like, - eye, - from_dlpack, - full, - full_like, - linspace, - meshgrid, - ones, - ones_like, - tril, - triu, - zeros, - zeros_like, -) -from ._spec_data_type_functions import ( - astype, - can_cast, - finfo, - iinfo, - isdtype, - result_type, -) -from ._spec_elementwise_functions import ( # pylint: disable=W0622 - abs, - acos, - acosh, - add, - asin, - asinh, - atan, - atan2, - atanh, - bitwise_and, - bitwise_invert, - bitwise_left_shift, - bitwise_or, - bitwise_right_shift, - bitwise_xor, - ceil, - conj, - cos, - cosh, - divide, - equal, - exp, - expm1, - floor, - floor_divide, - greater, - greater_equal, - imag, - isfinite, - isinf, - isnan, - less, - less_equal, - log, - log1p, - log2, - log10, - logaddexp, - logical_and, - logical_not, - logical_or, - logical_xor, - multiply, - negative, - not_equal, - positive, - pow, - real, - remainder, - round, - sign, - sin, - sinh, - sqrt, - square, - subtract, - tan, - tanh, - trunc, -) -from ._spec_indexing_functions import ( - take, -) -from ._spec_linear_algebra_functions import ( - matmul, - matrix_transpose, - tensordot, - vecdot, -) -from ._spec_manipulation_functions import ( - broadcast_arrays, - broadcast_to, - concat, - expand_dims, - flip, - permute_dims, - reshape, - roll, - squeeze, - stack, -) -from ._spec_searching_functions import ( - argmax, - argmin, - nonzero, - where, -) -from ._spec_set_functions import ( - unique_all, - unique_counts, - unique_inverse, - unique_values, -) -from ._spec_sorting_functions import ( - argsort, - sort, -) -from ._spec_statistical_functions import ( # pylint: disable=W0622 - max, - mean, - min, - prod, - std, - sum, - var, -) -from ._spec_utility_functions import ( # pylint: disable=W0622 - all, - any, -) - -__all__ = [ - "__array_api_version__", - # _spec_array_object - "array", - # _spec_constants - "e", - "inf", - "nan", - "newaxis", - "pi", - # _spec_creation_functions - "arange", - "asarray", - "empty", - "empty_like", - "eye", - "from_dlpack", - "full", - "full_like", - "linspace", - "meshgrid", - "ones", - "ones_like", - "tril", - "triu", - "zeros", - "zeros_like", - # _spec_data_type_functions - "astype", - "can_cast", - "finfo", - "iinfo", - "isdtype", - "result_type", - # _spec_elementwise_functions - "abs", - "acos", - "acosh", - "add", - "asin", - "asinh", - "atan", - "atan2", - "atanh", - "bitwise_and", - "bitwise_left_shift", - "bitwise_invert", - "bitwise_or", - "bitwise_right_shift", - "bitwise_xor", - "ceil", - "conj", - "cos", - "cosh", - "divide", - "equal", - "exp", - "expm1", - "floor", - "floor_divide", - "greater", - "greater_equal", - "imag", - "isfinite", - "isinf", - "isnan", - "less", - "less_equal", - "log", - "log1p", - "log2", - "log10", - "logaddexp", - "logical_and", - "logical_not", - "logical_or", - "logical_xor", - "multiply", - "negative", - "not_equal", - "positive", - "pow", - "real", - "remainder", - "round", - "sign", - "sin", - "sinh", - "square", - "sqrt", - "subtract", - "tan", - "tanh", - "trunc", - # _spec_indexing_functions - "take", - # _spec_linear_algebra_functions - "matmul", - "matrix_transpose", - "tensordot", - "vecdot", - # _spec_manipulation_functions - "broadcast_arrays", - "broadcast_to", - "concat", - "expand_dims", - "flip", - "permute_dims", - "reshape", - "roll", - "squeeze", - "stack", - # _spec_searching_functions - "argmax", - "argmin", - "nonzero", - "where", - # _spec_set_functions - "unique_all", - "unique_counts", - "unique_inverse", - "unique_values", - # _spec_sorting_functions - "argsort", - "sort", - # _spec_statistical_functions - "max", - "mean", - "min", - "prod", - "std", - "sum", - "var", - # _spec_utility_functions - "all", - "any", -] diff --git a/src/ragged/v202212/_spec_array_object.py b/src/ragged/v202212/_spec_array_object.py deleted file mode 100644 index b02a87a..0000000 --- a/src/ragged/v202212/_spec_array_object.py +++ /dev/null @@ -1,18 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/array_object.html -""" - -from __future__ import annotations - -from ..common._spec_array_object import array as common_array - - -class array(common_array): # pylint: disable=C0103 - """ - Ragged array class and constructor for data-apis.org/array-api/2022.12. - """ - - -__all__ = ["array"] diff --git a/src/ragged/v202212/_spec_constants.py b/src/ragged/v202212/_spec_constants.py deleted file mode 100644 index d2a410d..0000000 --- a/src/ragged/v202212/_spec_constants.py +++ /dev/null @@ -1,11 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/constants.html -""" - -from __future__ import annotations - -from ..common._spec_constants import e, inf, nan, newaxis, pi - -__all__ = ["e", "inf", "nan", "newaxis", "pi"] diff --git a/src/ragged/v202212/_spec_creation_functions.py b/src/ragged/v202212/_spec_creation_functions.py deleted file mode 100644 index 5a15f63..0000000 --- a/src/ragged/v202212/_spec_creation_functions.py +++ /dev/null @@ -1,45 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/creation_functions.html -""" - -from __future__ import annotations - -from ..common._spec_creation_functions import ( - arange, - asarray, - empty, - empty_like, - eye, - from_dlpack, - full, - full_like, - linspace, - meshgrid, - ones, - ones_like, - tril, - triu, - zeros, - zeros_like, -) - -__all__ = [ - "arange", - "asarray", - "empty", - "empty_like", - "eye", - "from_dlpack", - "full", - "full_like", - "linspace", - "meshgrid", - "ones", - "ones_like", - "tril", - "triu", - "zeros", - "zeros_like", -] diff --git a/src/ragged/v202212/_spec_data_type_functions.py b/src/ragged/v202212/_spec_data_type_functions.py deleted file mode 100644 index 5e4a54a..0000000 --- a/src/ragged/v202212/_spec_data_type_functions.py +++ /dev/null @@ -1,25 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/data_type_functions.html -""" - -from __future__ import annotations - -from ..common._spec_data_type_functions import ( - astype, - can_cast, - finfo, - iinfo, - isdtype, - result_type, -) - -__all__ = [ - "astype", - "can_cast", - "finfo", - "iinfo", - "isdtype", - "result_type", -] diff --git a/src/ragged/v202212/_spec_elementwise_functions.py b/src/ragged/v202212/_spec_elementwise_functions.py deleted file mode 100644 index c2d670d..0000000 --- a/src/ragged/v202212/_spec_elementwise_functions.py +++ /dev/null @@ -1,131 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/elementwise_functions.html -""" - -from __future__ import annotations - -from ..common._spec_elementwise_functions import ( # pylint: disable=W0622 - abs, - acos, - acosh, - add, - asin, - asinh, - atan, - atan2, - atanh, - bitwise_and, - bitwise_invert, - bitwise_left_shift, - bitwise_or, - bitwise_right_shift, - bitwise_xor, - ceil, - conj, - cos, - cosh, - divide, - equal, - exp, - expm1, - floor, - floor_divide, - greater, - greater_equal, - imag, - isfinite, - isinf, - isnan, - less, - less_equal, - log, - log1p, - log2, - log10, - logaddexp, - logical_and, - logical_not, - logical_or, - logical_xor, - multiply, - negative, - not_equal, - positive, - pow, - real, - remainder, - round, - sign, - sin, - sinh, - sqrt, - square, - subtract, - tan, - tanh, - trunc, -) - -__all__ = [ - "abs", - "acos", - "acosh", - "add", - "asin", - "asinh", - "atan", - "atan2", - "atanh", - "bitwise_and", - "bitwise_left_shift", - "bitwise_invert", - "bitwise_or", - "bitwise_right_shift", - "bitwise_xor", - "ceil", - "conj", - "cos", - "cosh", - "divide", - "equal", - "exp", - "expm1", - "floor", - "floor_divide", - "greater", - "greater_equal", - "imag", - "isfinite", - "isinf", - "isnan", - "less", - "less_equal", - "log", - "log1p", - "log2", - "log10", - "logaddexp", - "logical_and", - "logical_not", - "logical_or", - "logical_xor", - "multiply", - "negative", - "not_equal", - "positive", - "pow", - "real", - "remainder", - "round", - "sign", - "sin", - "sinh", - "square", - "sqrt", - "subtract", - "tan", - "tanh", - "trunc", -] diff --git a/src/ragged/v202212/_spec_indexing_functions.py b/src/ragged/v202212/_spec_indexing_functions.py deleted file mode 100644 index 87952c1..0000000 --- a/src/ragged/v202212/_spec_indexing_functions.py +++ /dev/null @@ -1,11 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/indexing_functions.html -""" - -from __future__ import annotations - -from ..common._spec_indexing_functions import take - -__all__ = ["take"] diff --git a/src/ragged/v202212/_spec_linear_algebra_functions.py b/src/ragged/v202212/_spec_linear_algebra_functions.py deleted file mode 100644 index 9840905..0000000 --- a/src/ragged/v202212/_spec_linear_algebra_functions.py +++ /dev/null @@ -1,16 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/linear_algebra_functions.html -""" - -from __future__ import annotations - -from ..common._spec_linear_algebra_functions import ( - matmul, - matrix_transpose, - tensordot, - vecdot, -) - -__all__ = ["matmul", "matrix_transpose", "tensordot", "vecdot"] diff --git a/src/ragged/v202212/_spec_manipulation_functions.py b/src/ragged/v202212/_spec_manipulation_functions.py deleted file mode 100644 index 520185b..0000000 --- a/src/ragged/v202212/_spec_manipulation_functions.py +++ /dev/null @@ -1,33 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/manipulation_functions.html -""" - -from __future__ import annotations - -from ..common._spec_manipulation_functions import ( - broadcast_arrays, - broadcast_to, - concat, - expand_dims, - flip, - permute_dims, - reshape, - roll, - squeeze, - stack, -) - -__all__ = [ - "broadcast_arrays", - "broadcast_to", - "concat", - "expand_dims", - "flip", - "permute_dims", - "reshape", - "roll", - "squeeze", - "stack", -] diff --git a/src/ragged/v202212/_spec_searching_functions.py b/src/ragged/v202212/_spec_searching_functions.py deleted file mode 100644 index 1c9d607..0000000 --- a/src/ragged/v202212/_spec_searching_functions.py +++ /dev/null @@ -1,16 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/searching_functions.html -""" - -from __future__ import annotations - -from ..common._spec_searching_functions import argmax, argmin, nonzero, where - -__all__ = [ - "argmax", - "argmin", - "nonzero", - "where", -] diff --git a/src/ragged/v202212/_spec_set_functions.py b/src/ragged/v202212/_spec_set_functions.py deleted file mode 100644 index 7b53eb4..0000000 --- a/src/ragged/v202212/_spec_set_functions.py +++ /dev/null @@ -1,16 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/set_functions.html -""" - -from __future__ import annotations - -from ..common._spec_set_functions import ( - unique_all, - unique_counts, - unique_inverse, - unique_values, -) - -__all__ = ["unique_all", "unique_counts", "unique_inverse", "unique_values"] diff --git a/src/ragged/v202212/_spec_sorting_functions.py b/src/ragged/v202212/_spec_sorting_functions.py deleted file mode 100644 index e2dfbf4..0000000 --- a/src/ragged/v202212/_spec_sorting_functions.py +++ /dev/null @@ -1,11 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/sorting_functions.html -""" - -from __future__ import annotations - -from ..common._spec_sorting_functions import argsort, sort - -__all__ = ["argsort", "sort"] diff --git a/src/ragged/v202212/_spec_statistical_functions.py b/src/ragged/v202212/_spec_statistical_functions.py deleted file mode 100644 index c072ef5..0000000 --- a/src/ragged/v202212/_spec_statistical_functions.py +++ /dev/null @@ -1,19 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/statistical_functions.html -""" - -from __future__ import annotations - -from ..common._spec_statistical_functions import ( # pylint: disable=W0622 - max, - mean, - min, - prod, - std, - sum, - var, -) - -__all__ = ["max", "mean", "min", "prod", "std", "sum", "var"] diff --git a/src/ragged/v202212/_spec_utility_functions.py b/src/ragged/v202212/_spec_utility_functions.py deleted file mode 100644 index 1e60440..0000000 --- a/src/ragged/v202212/_spec_utility_functions.py +++ /dev/null @@ -1,11 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE - -""" -https://data-apis.org/array-api/2022.12/API_specification/utility_functions.html -""" - -from __future__ import annotations - -from ..common._spec_utility_functions import all, any # pylint: disable=W0622 - -__all__ = ["all", "any"] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..4dd6d6e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,22 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +from __future__ import annotations + +import reprlib + +import pytest + +import ragged + + +@pytest.fixture(scope="session", autouse=True) +def _patch_reprlib(): + if not hasattr(reprlib.Repr, "repr1_original"): + + def repr1(self, x, level): + if isinstance(x, ragged.array): + return self.repr_instance(x, level) + return self.repr1_original(x, level) + + reprlib.Repr.repr1_original = reprlib.Repr.repr1 # type: ignore[attr-defined] + reprlib.Repr.repr1 = repr1 # type: ignore[method-assign] diff --git a/tests/test_spec_array_object.py b/tests/test_spec_array_object.py index 0af1ab1..61d7d3c 100644 --- a/tests/test_spec_array_object.py +++ b/tests/test_spec_array_object.py @@ -6,8 +6,45 @@ from __future__ import annotations +import pytest + import ragged def test_existence(): assert ragged.array is not None + + +def test_namespace(): + assert ragged.array(123).__array_namespace__() is ragged + assert ( + ragged.array(123).__array_namespace__(api_version=ragged.__array_api_version__) + is ragged + ) + with pytest.raises(NotImplementedError): + ragged.array(123).__array_namespace__(api_version="does not exist") + + +def test_bool(): + assert bool(ragged.array(True)) is True + assert bool(ragged.array(False)) is False + + +def test_complex(): + assert isinstance(complex(ragged.array(1.1 + 0.1j)), complex) + assert complex(ragged.array(1.1 + 0.1j)) == 1.1 + 0.1j + + +def test_float(): + assert isinstance(float(ragged.array(1.1)), float) + assert float(ragged.array(1.1)) == 1.1 + + +def test_index(): + assert isinstance(ragged.array(10).__index__(), int) + assert ragged.array(10).__index__() == 10 + + +def test_int(): + assert isinstance(int(ragged.array(10)), int) + assert int(ragged.array(10)) == 10 diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index c8ebc66..089443e 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -6,8 +6,90 @@ from __future__ import annotations +import warnings +from typing import Any + +import awkward as ak +import numpy as np + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import numpy.array_api as xp + +import pytest + import ragged +devices = ["cpu"] +try: + import cupy as cp + + devices.append("cuda") +except ModuleNotFoundError: + cp = None + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x(request): + if request.param == "regular": + return ragged.array(np.array([1.0, 2.0, 3.0])) + elif request.param == "irregular": + return ragged.array(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]])) + else: # request.param == "scalar" + return ragged.array(np.array(10.0)) + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x_lt1(request): + if request.param == "regular": + return ragged.array(np.array([0.1, 0.2, 0.3])) + elif request.param == "irregular": + return ragged.array(ak.Array([[0.1, 0.2, 0.3], [], [0.4, 0.5]])) + else: # request.param == "scalar" + return ragged.array(np.array(0.5)) + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x_bool(request): + if request.param == "regular": + return ragged.array(np.array([False, True, False])) + elif request.param == "irregular": + return ragged.array(ak.Array([[True, True, False], [], [False, False]])) + else: # request.param == "scalar" + return ragged.array(np.array(True)) + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x_int(request): + if request.param == "regular": + return ragged.array(np.array([0, 1, 2], dtype=np.int64)) + elif request.param == "irregular": + return ragged.array(ak.Array([[1, 2, 3], [], [4, 5]])) + else: # request.param == "scalar" + return ragged.array(np.array(10, dtype=np.int64)) + + +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x_complex(request): + if request.param == "regular": + return ragged.array(np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j])) + elif request.param == "irregular": + return ragged.array(ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]])) + else: # request.param == "scalar" + return ragged.array(np.array(10 + 1j)) + + +y = x +y_lt1 = x_lt1 +y_bool = x_bool +y_int = x_int +y_complex = x_complex + + +def first(x: ragged.array) -> Any: + out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl + return xp.asarray(out.item(), dtype=x.dtype) + def test_existence(): assert ragged.abs is not None @@ -69,3 +151,891 @@ def test_existence(): assert ragged.tan is not None assert ragged.tanh is not None assert ragged.trunc is not None + + +@pytest.mark.parametrize("device", devices) +def test_abs(device, x): + result = ragged.abs(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.abs(first(x)) == first(result) + assert xp.abs(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_abs_method(device, x): + result = abs(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.abs(first(x)) == first(result) + assert xp.abs(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_acos(device, x_lt1): + result = ragged.acos(x_lt1.to_device(device)) + assert type(result) is type(x_lt1) + assert result.shape == x_lt1.shape + assert xp.acos(first(x_lt1)) == pytest.approx(first(result)) + assert xp.acos(first(x_lt1)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_acosh(device, x): + result = ragged.acosh(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.acosh(first(x)) == pytest.approx(first(result)) + assert xp.acosh(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_add(device, x, y): + result = ragged.add(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.add(first(x), first(y)) == first(result) + assert xp.add(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_add_method(device, x, y): + result = x.to_device(device) + y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.add(first(x), first(y)) == first(result) + assert xp.add(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_add_inplace_method(device, x, y): + x = x.to_device(device) + y = y.to_device(device) + z = xp.add(first(x), first(y)) + x += y + assert first(x) == z + assert x.dtype == z.dtype + + +@pytest.mark.parametrize("device", devices) +def test_asin(device, x_lt1): + result = ragged.asin(x_lt1.to_device(device)) + assert type(result) is type(x_lt1) + assert result.shape == x_lt1.shape + assert xp.asin(first(x_lt1)) == pytest.approx(first(result)) + assert xp.asin(first(x_lt1)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_asinh(device, x): + result = ragged.asinh(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.asinh(first(x)) == pytest.approx(first(result)) + assert xp.asinh(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_atan(device, x): + result = ragged.atan(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.atan(first(x)) == pytest.approx(first(result)) + assert xp.atan(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_atan2(device, x, y): + result = ragged.atan2(y.to_device(device), x.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.atan2(first(y), first(x)) == pytest.approx(first(result)) + assert xp.atan2(first(y), first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_atanh(device, x_lt1): + result = ragged.atanh(x_lt1.to_device(device)) + assert type(result) is type(x_lt1) + assert result.shape == x_lt1.shape + assert xp.atanh(first(x_lt1)) == pytest.approx(first(result)) + assert xp.atanh(first(x_lt1)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_and(device, x_int, y_int): + result = ragged.bitwise_and(x_int.to_device(device), y_int.to_device(device)) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_and(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_and(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_and_method(device, x_int, y_int): + result = x_int.to_device(device) & y_int.to_device(device) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_and(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_and(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_and_inplace_method(device, x_int, y_int): + x_int = x_int.to_device(device) + y_int = y_int.to_device(device) + z_int = xp.bitwise_and(first(x_int), first(y_int)) + x_int &= y_int + assert first(x_int) == z_int + assert x_int.dtype == z_int.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_invert(device, x_int): + result = ragged.bitwise_invert(x_int.to_device(device)) + assert type(result) is type(x_int) + assert result.shape == x_int.shape + assert xp.bitwise_invert(first(x_int)) == first(result) + assert xp.bitwise_invert(first(x_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_invert_method(device, x_int): + result = ~x_int.to_device(device) + assert type(result) is type(x_int) + assert result.shape == x_int.shape + assert xp.bitwise_invert(first(x_int)) == first(result) + assert xp.bitwise_invert(first(x_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_left_shift(device, x_int, y_int): + result = ragged.bitwise_left_shift(x_int.to_device(device), y_int.to_device(device)) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_left_shift(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_left_shift(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_left_shift_method(device, x_int, y_int): + result = x_int.to_device(device) << y_int.to_device(device) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_left_shift(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_left_shift(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_left_shift_inplace_method(device, x_int, y_int): + x_int = x_int.to_device(device) + y_int = y_int.to_device(device) + z_int = xp.bitwise_left_shift(first(x_int), first(y_int)) + x_int <<= y_int + assert first(x_int) == z_int + assert x_int.dtype == z_int.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_or(device, x_int, y_int): + result = ragged.bitwise_or(x_int.to_device(device), y_int.to_device(device)) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_or(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_or(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_or_method(device, x_int, y_int): + result = x_int.to_device(device) | y_int.to_device(device) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_or(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_or(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_or_inplace_method(device, x_int, y_int): + x_int = x_int.to_device(device) + y_int = y_int.to_device(device) + z_int = xp.bitwise_or(first(x_int), first(y_int)) + x_int |= y_int + assert first(x_int) == z_int + assert x_int.dtype == z_int.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_right_shift(device, x_int, y_int): + result = ragged.bitwise_right_shift( + x_int.to_device(device), y_int.to_device(device) + ) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_right_shift(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_right_shift(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_right_shift_method(device, x_int, y_int): + result = x_int.to_device(device) >> y_int.to_device(device) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_right_shift(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_right_shift(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_right_shift_inplace_method(device, x_int, y_int): + x_int = x_int.to_device(device) + y_int = y_int.to_device(device) + z_int = xp.bitwise_right_shift(first(x_int), first(y_int)) + x_int >>= y_int + assert first(x_int) == z_int + assert x_int.dtype == z_int.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_xor(device, x_int, y_int): + result = ragged.bitwise_xor(x_int.to_device(device), y_int.to_device(device)) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_xor(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_xor(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_xor_method(device, x_int, y_int): + result = x_int.to_device(device) ^ y_int.to_device(device) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.bitwise_xor(first(x_int), first(y_int)) == first(result) + assert xp.bitwise_xor(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_bitwise_xor_inplace_method(device, x_int, y_int): + x_int = x_int.to_device(device) + y_int = y_int.to_device(device) + z_int = xp.bitwise_xor(first(x_int), first(y_int)) + x_int ^= y_int + assert first(x_int) == z_int + assert x_int.dtype == z_int.dtype + + +@pytest.mark.parametrize("device", devices) +def test_ceil(device, x): + result = ragged.ceil(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.ceil(first(x)) == first(result) + assert xp.ceil(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_ceil_int(device, x_int): + result = ragged.ceil(x_int.to_device(device)) + assert type(result) is type(x_int) + assert result.shape == x_int.shape + assert xp.ceil(first(x_int)) == first(result) + assert xp.ceil(first(x_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_conj(device, x_complex): + result = ragged.conj(x_complex.to_device(device)) + assert type(result) is type(x_complex) + assert result.shape == x_complex.shape + assert xp.conj(first(x_complex)) == first(result) + assert xp.conj(first(x_complex)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_cos(device, x): + result = ragged.cos(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.cos(first(x)) == pytest.approx(first(result)) + assert xp.cos(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_cosh(device, x): + result = ragged.cosh(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.cosh(first(x)) == pytest.approx(first(result)) + assert xp.cosh(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_divide(device, x, y): + result = ragged.divide(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.divide(first(x), first(y)) == first(result) + assert xp.divide(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_divide_method(device, x, y): + result = x.to_device(device) / y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.divide(first(x), first(y)) == first(result) + assert xp.divide(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_divide_inplace_method(device, x, y): + x = x.to_device(device) + y = y.to_device(device) + z = xp.divide(first(x), first(y)) + x /= y + assert first(x) == z + assert x.dtype == z.dtype + + +@pytest.mark.parametrize("device", devices) +def test_equal(device, x, y): + result = ragged.equal(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.equal(first(x), first(y)) == first(result) + assert xp.equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_equal_method(device, x, y): + result = x.to_device(device) == y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.equal(first(x), first(y)) == first(result) + assert xp.equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_exp(device, x): + result = ragged.exp(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.exp(first(x)) == pytest.approx(first(result)) + assert xp.exp(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_expm1(device, x): + result = ragged.expm1(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.expm1(first(x)) == pytest.approx(first(result)) + assert xp.expm1(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_floor(device, x): + result = ragged.floor(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.floor(first(x)) == first(result) + assert xp.floor(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_floor_int(device, x_int): + result = ragged.floor(x_int.to_device(device)) + assert type(result) is type(x_int) + assert result.shape == x_int.shape + assert xp.floor(first(x_int)) == first(result) + assert xp.floor(first(x_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_floor_divide(device, x, y): + result = ragged.floor_divide(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.floor_divide(first(x), first(y)) == first(result) + assert xp.floor_divide(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_floor_divide_method(device, x, y): + result = x.to_device(device) // y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.floor_divide(first(x), first(y)) == first(result) + assert xp.floor_divide(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_floor_divide_inplace_method(device, x, y): + x = x.to_device(device) + y = y.to_device(device) + z = xp.floor_divide(first(x), first(y)) + x //= y + assert first(x) == z + assert x.dtype == z.dtype + + +@pytest.mark.parametrize("device", devices) +def test_floor_divide_int(device, x_int, y_int): + with np.errstate(divide="ignore"): + result = ragged.floor_divide(x_int.to_device(device), y_int.to_device(device)) + assert type(result) is type(x_int) is type(y_int) + assert result.shape in (x_int.shape, y_int.shape) + assert xp.floor_divide(first(x_int), first(y_int)) == first(result) + assert xp.floor_divide(first(x_int), first(y_int)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_greater(device, x, y): + result = ragged.greater(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.greater(first(x), first(y)) == first(result) + assert xp.greater(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_greater_method(device, x, y): + result = x.to_device(device) > y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.greater(first(x), first(y)) == first(result) + assert xp.greater(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_greater_equal(device, x, y): + result = ragged.greater_equal(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.greater_equal(first(x), first(y)) == first(result) + assert xp.greater_equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_greater_equal_method(device, x, y): + result = x.to_device(device) >= y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.greater_equal(first(x), first(y)) == first(result) + assert xp.greater_equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_imag(device, x_complex): + result = ragged.imag(x_complex.to_device(device)) + assert type(result) is type(x_complex) + assert result.shape == x_complex.shape + assert xp.imag(first(x_complex)) == first(result) + assert xp.imag(first(x_complex)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_isfinite(device, x): + result = ragged.isfinite(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.isfinite(first(x)) == first(result) + assert xp.isfinite(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_isinf(device, x): + result = ragged.isinf(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.isinf(first(x)) == first(result) + assert xp.isinf(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_isnan(device, x): + result = ragged.isnan(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.isnan(first(x)) == first(result) + assert xp.isnan(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_less(device, x, y): + result = ragged.less(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.less(first(x), first(y)) == first(result) + assert xp.less(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_less_method(device, x, y): + result = x.to_device(device) < y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.less(first(x), first(y)) == first(result) + assert xp.less(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_less_equal(device, x, y): + result = ragged.less_equal(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.less_equal(first(x), first(y)) == first(result) + assert xp.less_equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_less_equal_method(device, x, y): + result = x.to_device(device) <= y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.less_equal(first(x), first(y)) == first(result) + assert xp.less_equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_log(device, x): + result = ragged.log(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.log(first(x)) == pytest.approx(first(result)) + assert xp.log(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_log1p(device, x): + result = ragged.log1p(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.log1p(first(x)) == pytest.approx(first(result)) + assert xp.log1p(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_log2(device, x): + result = ragged.log2(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.log2(first(x)) == pytest.approx(first(result)) + assert xp.log2(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_log10(device, x): + result = ragged.log10(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.log10(first(x)) == pytest.approx(first(result)) + assert xp.log10(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logaddexp(device, x, y): + result = ragged.logaddexp(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.logaddexp(first(x), first(y)) == pytest.approx(first(result)) + assert xp.logaddexp(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logical_and(device, x_bool, y_bool): + result = ragged.logical_and(x_bool.to_device(device), y_bool.to_device(device)) + assert type(result) is type(x_bool) is type(y_bool) + assert result.shape in (x_bool.shape, y_bool.shape) + assert xp.logical_and(first(x_bool), first(y_bool)) == first(result) + assert xp.logical_and(first(x_bool), first(y_bool)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logical_not(device, x_bool): + result = ragged.logical_not(x_bool.to_device(device)) + assert type(result) is type(x_bool) + assert result.shape == x_bool.shape + assert xp.logical_not(first(x_bool)) == first(result) + assert xp.logical_not(first(x_bool)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logical_or(device, x_bool, y_bool): + result = ragged.logical_or(x_bool.to_device(device), y_bool.to_device(device)) + assert type(result) is type(x_bool) is type(y_bool) + assert result.shape in (x_bool.shape, y_bool.shape) + assert xp.logical_or(first(x_bool), first(y_bool)) == first(result) + assert xp.logical_or(first(x_bool), first(y_bool)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logical_xor(device, x_bool, y_bool): + result = ragged.logical_xor(x_bool.to_device(device), y_bool.to_device(device)) + assert type(result) is type(x_bool) is type(y_bool) + assert result.shape in (x_bool.shape, y_bool.shape) + assert xp.logical_xor(first(x_bool), first(y_bool)) == first(result) + assert xp.logical_xor(first(x_bool), first(y_bool)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_multiply(device, x, y): + result = ragged.multiply(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.multiply(first(x), first(y)) == first(result) + assert xp.multiply(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_multiply_method(device, x, y): + result = x.to_device(device) * y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.multiply(first(x), first(y)) == first(result) + assert xp.multiply(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_multiply_inplace_method(device, x, y): + x = x.to_device(device) + y = y.to_device(device) + z = xp.multiply(first(x), first(y)) + x *= y + assert first(x) == z + assert x.dtype == z.dtype + + +@pytest.mark.parametrize("device", devices) +def test_negative(device, x): + result = ragged.negative(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.negative(first(x)) == pytest.approx(first(result)) + assert xp.negative(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_negative_method(device, x): + result = -x.to_device(device) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.negative(first(x)) == pytest.approx(first(result)) + assert xp.negative(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_not_equal(device, x, y): + result = ragged.not_equal(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.not_equal(first(x), first(y)) == first(result) + assert xp.not_equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_not_equal_method(device, x, y): + result = x.to_device(device) != y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.not_equal(first(x), first(y)) == first(result) + assert xp.not_equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_positive(device, x): + result = ragged.positive(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.positive(first(x)) == pytest.approx(first(result)) + assert xp.positive(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_positive_method(device, x): + result = +x.to_device(device) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.positive(first(x)) == pytest.approx(first(result)) + assert xp.positive(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_pow(device, x, y): + result = ragged.pow(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.pow(first(x), first(y)) == first(result) + assert xp.pow(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_pow_method(device, x, y): + result = x.to_device(device) ** y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.pow(first(x), first(y)) == first(result) + assert xp.pow(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_pow_inplace_method(device, x, y): + x = x.to_device(device) + y = y.to_device(device) + z = xp.pow(first(x), first(y)) + x **= y + assert first(x) == z + assert x.dtype == z.dtype + + +@pytest.mark.parametrize("device", devices) +def test_real(device, x_complex): + result = ragged.real(x_complex.to_device(device)) + assert type(result) is type(x_complex) + assert result.shape == x_complex.shape + assert xp.real(first(x_complex)) == first(result) + assert xp.real(first(x_complex)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_remainder(device, x, y): + result = ragged.remainder(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.remainder(first(x), first(y)) == first(result) + assert xp.remainder(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_remainder_method(device, x, y): + result = x.to_device(device) % y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.remainder(first(x), first(y)) == first(result) + assert xp.remainder(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_remainder_inplace_method(device, x, y): + x = x.to_device(device) + y = y.to_device(device) + z = xp.remainder(first(x), first(y)) + x %= y + assert first(x) == z + assert x.dtype == z.dtype + + +@pytest.mark.parametrize("device", devices) +def test_round(device, x): + result = ragged.round(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.round(first(x)) == first(result) + assert xp.round(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_round_complex(device, x_complex): + result = ragged.round(x_complex.to_device(device)) + assert type(result) is type(x_complex) + assert result.shape == x_complex.shape + assert xp.round(first(x_complex)) == first(result) + assert xp.round(first(x_complex)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_sign(device, x): + result = ragged.sign(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.sign(first(x)) == first(result) + assert xp.sign(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_sin(device, x): + result = ragged.sin(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.sin(first(x)) == pytest.approx(first(result)) + assert xp.sin(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_sinh(device, x): + result = ragged.sinh(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.sinh(first(x)) == pytest.approx(first(result)) + assert xp.sinh(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_square(device, x): + result = ragged.square(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.square(first(x)) == first(result) + assert xp.square(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_sqrt(device, x): + result = ragged.sqrt(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.sqrt(first(x)) == pytest.approx(first(result)) + assert xp.sqrt(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_subtract(device, x, y): + result = ragged.subtract(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.subtract(first(x), first(y)) == first(result) + assert xp.subtract(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_subtract_inplace_method(device, x, y): + x = x.to_device(device) + y = y.to_device(device) + z = xp.subtract(first(x), first(y)) + x -= y + assert first(x) == z + assert x.dtype == z.dtype + + +@pytest.mark.parametrize("device", devices) +def test_subtract_method(device, x, y): + result = x.to_device(device) - y.to_device(device) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.subtract(first(x), first(y)) == first(result) + assert xp.subtract(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_tan(device, x): + result = ragged.tan(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.tan(first(x)) == pytest.approx(first(result)) + assert xp.tan(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_tanh(device, x): + result = ragged.tanh(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.tanh(first(x)) == pytest.approx(first(result)) + assert xp.tanh(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_trunc(device, x): + result = ragged.trunc(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.trunc(first(x)) == first(result) + assert xp.trunc(first(x)).dtype == result.dtype diff --git a/tests/test_spec_version.py b/tests/test_spec_version.py index 12ff9b7..ac9b0ca 100644 --- a/tests/test_spec_version.py +++ b/tests/test_spec_version.py @@ -10,6 +10,4 @@ def test_values(): - assert ragged.v202212.__array_api_version__ == "2022.12" - assert ragged.__array_api_version__ == "2022.12"