From 62594fc3ad4546ea073c62dade7f5b905f6f8175 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Thu, 28 Dec 2023 18:41:04 +0530 Subject: [PATCH] feat: add function stubs (no implementations yet). (#3) * feat: add function stubs (no implementations yet). * empty * empty_like, eye, from_dlpack, full, full_like, linspace, meshgrid * ones, ones_like, tril, triu, zeros, zeros_like * Data type functions: astype. * can_cast, finfo, iinfo, isdtype, result_type * All the elementwise functions. * Made stub files for the rest; not done yet. * Stubbed all remaining functions in ragged.common. * Done with all function stubs. --- pyproject.toml | 5 + src/ragged/__init__.py | 2 +- src/ragged/common/__init__.py | 268 ++++- src/ragged/common/_const.py | 2 +- src/ragged/common/_creation.py | 549 +++++++++ src/ragged/common/_datatype.py | 222 ++++ src/ragged/common/_elementwise.py | 1375 +++++++++++++++++++++++ src/ragged/common/_indexing.py | 43 + src/ragged/common/_linalg.py | 190 ++++ src/ragged/common/_manipulation.py | 271 +++++ src/ragged/common/_obj.py | 189 ++-- src/ragged/common/_search.py | 118 ++ src/ragged/common/_set.py | 133 +++ src/ragged/common/_sorting.py | 73 ++ src/ragged/common/_statistical.py | 319 ++++++ src/ragged/common/_typing.py | 13 + src/ragged/common/_utility.py | 89 ++ src/ragged/v202212/__init__.py | 271 ++++- src/ragged/v202212/_creation.py | 45 + src/ragged/v202212/_datatype.py | 25 + src/ragged/v202212/_elementwise.py | 131 +++ src/ragged/v202212/_indexing.py | 11 + src/ragged/v202212/_linalg.py | 11 + src/ragged/v202212/_manipulation.py | 33 + src/ragged/v202212/_obj.py | 4 + src/ragged/v202212/_search.py | 16 + src/ragged/v202212/_set.py | 11 + src/ragged/v202212/_sorting.py | 11 + src/ragged/v202212/_statistical.py | 19 + src/ragged/v202212/_utility.py | 11 + tests/test_0001_initial_array_object.py | 118 ++ 31 files changed, 4476 insertions(+), 102 deletions(-) create mode 100644 src/ragged/common/_creation.py create mode 100644 src/ragged/common/_datatype.py create mode 100644 src/ragged/common/_elementwise.py create mode 100644 src/ragged/common/_indexing.py create mode 100644 src/ragged/common/_linalg.py create mode 100644 src/ragged/common/_manipulation.py create mode 100644 src/ragged/common/_search.py create mode 100644 src/ragged/common/_set.py create mode 100644 src/ragged/common/_sorting.py create mode 100644 src/ragged/common/_statistical.py create mode 100644 src/ragged/common/_utility.py create mode 100644 src/ragged/v202212/_creation.py create mode 100644 src/ragged/v202212/_datatype.py create mode 100644 src/ragged/v202212/_elementwise.py create mode 100644 src/ragged/v202212/_indexing.py create mode 100644 src/ragged/v202212/_linalg.py create mode 100644 src/ragged/v202212/_manipulation.py create mode 100644 src/ragged/v202212/_search.py create mode 100644 src/ragged/v202212/_set.py create mode 100644 src/ragged/v202212/_sorting.py create mode 100644 src/ragged/v202212/_statistical.py create mode 100644 src/ragged/v202212/_utility.py diff --git a/pyproject.toml b/pyproject.toml index 6b01751..8c40c14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,6 +146,9 @@ ignore = [ "ISC001", # Conflicts with formatter "RET505", # I like my if (return) elif (return) else (return) pattern "PLR5501", # I like my if (return) elif (return) else (return) pattern + + "PT015", # I'm using `assert False` for not-implemented FIXMEs + "B011", # (can't use NotImplementedError because it's used to incidate abstract methods) ] isort.required-imports = ["from __future__ import annotations"] # Uncomment if using a _compat.typing backport @@ -170,4 +173,6 @@ messages_control.disable = [ "missing-class-docstring", "missing-function-docstring", "R1705", # I like my if (return) elif (return) else (return) pattern + "R0801", # Different files can have similar lines; that's okay + "C0302", # I can have as many lines as I want; what's it with you? ] diff --git a/src/ragged/__init__.py b/src/ragged/__init__.py index d8d04be..54b6b77 100644 --- a/src/ragged/__init__.py +++ b/src/ragged/__init__.py @@ -10,4 +10,4 @@ from __future__ import annotations -from .v202212 import * # noqa: F403 +from .v202212 import * # noqa: F403 # pylint: disable=W0622 diff --git a/src/ragged/common/__init__.py b/src/ragged/common/__init__.py index dcc9188..c1990ac 100644 --- a/src/ragged/common/__init__.py +++ b/src/ragged/common/__init__.py @@ -9,6 +9,272 @@ from __future__ import annotations +from ._creation import ( + arange, + asarray, + empty, + empty_like, + eye, + from_dlpack, + full, + full_like, + linspace, + meshgrid, + ones, + ones_like, + tril, + triu, + zeros, + zeros_like, +) +from ._datatype import ( + astype, + can_cast, + finfo, + iinfo, + isdtype, + result_type, +) +from ._elementwise import ( # pylint: disable=W0622 + abs, + acos, + acosh, + add, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_and, + bitwise_invert, + bitwise_left_shift, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + ceil, + conj, + cos, + cosh, + divide, + equal, + exp, + expm1, + floor, + floor_divide, + greater, + greater_equal, + imag, + isfinite, + isinf, + isnan, + less, + less_equal, + log, + log1p, + log2, + log10, + logaddexp, + logical_and, + logical_not, + logical_or, + logical_xor, + multiply, + negative, + not_equal, + positive, + pow, + real, + remainder, + round, + sign, + sin, + sinh, + sqrt, + square, + subtract, + tan, + tanh, + trunc, +) +from ._indexing import ( + take, +) +from ._linalg import ( + matmul, + matrix_transpose, + tensordot, + vecdot, +) +from ._manipulation import ( + broadcast_arrays, + broadcast_to, + concat, + expand_dims, + flip, + permute_dims, + reshape, + roll, + squeeze, + stack, +) from ._obj import array +from ._search import ( + argmax, + argmin, + nonzero, + where, +) +from ._set import ( + unique_all, + unique_counts, + unique_inverse, + unique_values, +) +from ._sorting import ( + argsort, + sort, +) +from ._statistical import ( # pylint: disable=W0622 + max, + mean, + min, + prod, + std, + sum, + var, +) +from ._utility import ( # pylint: disable=W0622 + all, + any, +) -__all__ = ["array"] +__all__ = [ + # _creation + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # _datatype + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # _elementwise + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "conj", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "multiply", + "negative", + "not_equal", + "positive", + "pow", + "real", + "remainder", + "round", + "sign", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # _indexing + "take", + # _linalg + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # _manipulation + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "permute_dims", + "reshape", + "roll", + "squeeze", + "stack", + # _obj + "array", + # _search + "argmax", + "argmin", + "nonzero", + "where", + # _set + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + # _sorting + "argsort", + "sort", + # _statistical + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # _utility + "all", + "any", +] diff --git a/src/ragged/common/_const.py b/src/ragged/common/_const.py index 362ee71..1aa0464 100644 --- a/src/ragged/common/_const.py +++ b/src/ragged/common/_const.py @@ -1,7 +1,7 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE """ -https://data-apis.org/array-api/latest/API_specification/constants.html +https://data-apis.org/array-api/latest/API_specification/creation_functions.html """ from __future__ import annotations diff --git a/src/ragged/common/_creation.py b/src/ragged/common/_creation.py new file mode 100644 index 0000000..1be7fd8 --- /dev/null +++ b/src/ragged/common/_creation.py @@ -0,0 +1,549 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/creation_functions.html +""" + +from __future__ import annotations + +import awkward as ak + +from ._obj import array +from ._typing import ( + Device, + Dtype, + NestedSequence, + SupportsBufferProtocol, + SupportsDLPack, +) + + +def arange( + start: int | float, + /, + stop: None | int | float = None, + step: int | float = 1, + *, + dtype: None | Dtype = None, + device: None | Device = None, +) -> array: + """ + Returns evenly spaced values within the half-open interval `[start, stop)` + as a one-dimensional array. + + Args: + start: If `stop` is specified, the start of interval (inclusive); + otherwise, the end of the interval (exclusive). If `stop` is not + specified, the default starting value is 0. + stop: The end of the interval. + step: The distance between two adjacent elements `(out[i+1] - out[i])`. + Must not be 0; may be negative, this results in an empty array if + `stop >= start`. + dtype: Output array data type. If dtype is `None`, the output array + data type is inferred from `start`, `stop` and `step`. If those are + all integers, the output array dtype is `np.int64`; if one or more + have type `float`, then the output array dtype is `np.float64`. + device: Device on which to place the created array. + + Returns: + A one-dimensional array containing evenly spaced values. The length of + the output array is `ceil((stop-start)/step)` if `stop - start` and + `step` have the same sign, and length 0 otherwise. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.arange.html + """ + + assert start, "TODO" + assert stop, "TODO" + assert step, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def asarray( + obj: ( + array + | ak.Array + | bool + | int + | float + | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol + | SupportsDLPack + ), + dtype: None | Dtype | type | str = None, + device: None | Device = None, + copy: None | bool = None, +) -> array: + """ + Convert the input to an array. + + Args: + obj: Object to be converted to an array. May be a Python scalar, a + (possibly nested) sequence of Python scalars, or an object + supporting the Python buffer protocol or DLPack. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is inferred from the data type(s) in `obj`. If all input + values are Python scalars, then, in order of precedence, + - if all values are of type `bool`, the output data type is + `bool`. + - if all values are of type `int` or are a mixture of `bool` + and `int`, the output data type is `np.int64`. + - if one or more values are `complex` numbers, the output data + type is `np.complex128`. + - if one or more values are `float`s, the output data type is + `np.float64`. + device: Device on which to place the created array. If device is `None` + and `obj` is an array, the output array device is inferred from + `obj`. If `"cpu"`, the array is backed by NumPy and resides in main + memory; if `"cuda"`, the array is backed by CuPy and resides in + CUDA global memory. + copy: Boolean indicating whether or not to copy the input. If `True`, + this function always copies. If `False`, the function never copies + for input which supports the buffer protocol and raises a + ValueError in case a copy would be necessary. If `None`, the + function reuses the existing memory buffer if possible and copies + otherwise. + + Returns: + An array containing the data from `obj`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.asarray.html + """ + return array(obj, dtype=dtype, device=device, copy=copy) + + +def empty( + shape: int | tuple[int, ...], + *, + dtype: None | Dtype = None, + device: None | Device = None, +) -> array: + """ + Returns an uninitialized array having a specified shape. + + Args: + shape: Output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is `np.float64`. + device: Device on which to place the created array. + + Returns: + An array containing uninitialized data. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty.html + """ + + assert shape, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def empty_like( + x: array, /, *, dtype: None | Dtype = None, device: None | Device = None +) -> array: + """ + Returns an uninitialized array with the same shape as an input array x. + + Args: + x: Input array from which to derive the output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is inferred from `x`. + device: Device on which to place the created array. If `device` is + `None`, output array device is inferred from `x`. + + Returns: + An array having the same shape as `x` and containing uninitialized data. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.empty_like.html + """ + + assert x, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def eye( + n_rows: int, + n_cols: None | int = None, + /, + *, + k: int = 0, + dtype: None | Dtype = None, + device: None | Device = None, +) -> array: + """ + Returns a two-dimensional array with ones on the kth diagonal and zeros elsewhere. + + Args: + n_rows: Number of rows in the output array. + n_cols: Number of columns in the output array. If `None`, the default + number of columns in the output array is equal to `n_rows`. + k: Index of the diagonal. A positive value refers to an upper diagonal, + a negative value to a lower diagonal, and 0 to the main diagonal. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is `np.float64`. + device: Device on which to place the created array. + + Returns: + An array where all elements are equal to zero, except for the kth + diagonal, whose values are equal to one. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.eye.html + """ + + assert n_rows, "TODO" + assert n_cols, "TODO" + assert k, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def from_dlpack(x: object, /) -> array: + """ + Returns a new array containing the data from another (array) object with a `__dlpack__` method. + + Args: + x: Input (array) object. + + Returns: + An array containing the data in `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def full( + shape: int | tuple[int, ...], + fill_value: bool | int | float | complex, + *, + dtype: None | Dtype = None, + device: None | Device = None, +) -> array: + """ + Returns a new array having a specified shape and filled with fill_value. + + Args: + shape: Output array shape. + fill_value: Fill value. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is inferred from `fill_value` according to the following + rules: + - if the fill value is an `int`, the output array data type is + `np.int64`. + - if the fill value is a `float`, the output array data type + is `np.float64`. + - if the fill value is a `complex` number, the output array + data type is `np.complex128`. + - if the fill value is a `bool`, the output array is + `np.bool_`. + device: Device on which to place the created array. + + Returns: + An array where every element is equal to fill_value. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.full.html + """ + + assert shape, "TODO" + assert fill_value, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def full_like( + x: array, + /, + fill_value: bool | int | float | complex, + *, + dtype: None | Dtype = None, + device: None | Device = None, +) -> array: + """ + Returns a new array filled with fill_value and having the same shape as an input array x. + + Args: + x: Input array from which to derive the output array shape. + fill_value: Fill value. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is inferred from `x`. + device: Device on which to place the created array. If `device` is + `None`, the output array device is inferred from `x`. + + Returns: + An array having the same shape as `x` and where every element is equal + to `fill_value`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.full_like.html + """ + + assert x, "TODO" + assert fill_value, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def linspace( + start: int | float | complex, + stop: int | float | complex, + /, + num: int, + *, + dtype: None | Dtype = None, + device: None | Device = None, + endpoint: bool = True, +) -> array: + r""" + Returns evenly spaced numbers over a specified interval. + + Let `N` be the number of generated values (which is either `num` or `num+1` + depending on whether `endpoint` is `True` or `False`, respectively). For + real-valued output arrays, the spacing between values is given by + + $$\Delta_{\textrm{real}} = \frac{\textrm{stop} - \textrm{start}}{N - 1}$$ + + For complex output arrays, let `a = real(start)`, `b = imag(start)`, + `c = real(stop)`, and `d = imag(stop)`. The spacing between complex values + is given by + + $$\Delta_{\textrm{complex}} = \frac{c-a}{N-1} + \frac{d-b}{N-1} j$$ + + Args: + start: The start of the interval. + stop: The end of the interval. If `endpoint` is `False`, the function + generates a sequence of `num+1` evenly spaced numbers starting with + `start` and ending with `stop` and exclude the `stop` from the + returned array such that the returned array consists of evenly + spaced numbers over the half-open interval `[start, stop)`. If + endpoint is `True`, the output array consists of evenly spaced + numbers over the closed interval `[start, stop]`. + num: Number of samples. Must be a nonnegative integer value. + dtype: Output array data type. Should be a floating-point data type. + If `dtype` is `None`, + - if either `start` or `stop` is a `complex` number, the + output data type is `np.complex128`. + - if both `start` and `stop` are real-valued, the output data + type is `np.float64`. + device: Device on which to place the created array. + endpoint: Boolean indicating whether to include `stop` in the interval. + + Returns: + A one-dimensional array containing evenly spaced values. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.linspace.html + """ + + assert start, "TODO" + assert stop, "TODO" + assert num, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert endpoint, "TODO" + assert False, "TODO" + + +def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]: + """ + Returns coordinate matrices from coordinate vectors. + + Args: + arrays: An arbitrary number of one-dimensional arrays representing + grid coordinates. Each array should have the same numeric data type. + indexing: Cartesian `"xy"` or matrix `"ij"` indexing of output. If + provided zero or one one-dimensional vector(s) (i.e., the zero- and + one-dimensional cases, respectively), the `indexing` keyword has no + effect and should be ignored. + + Returns: + List of `N` arrays, where `N` is the number of provided one-dimensional + input arrays. Each returned array must have rank `N`. For `N` + one-dimensional arrays having lengths `Ni = len(xi)`, + - if matrix indexing `"ij"`, then each returned array must have the + shape `(N1, N2, N3, ..., Nn)`. + - if Cartesian indexing `"xy"`, then each returned array must have + shape `(N2, N1, N3, ..., Nn)`. + + Accordingly, for the two-dimensional case with input one-dimensional + arrays of length `M` and `N`, if matrix indexing `"ij"`, then each + returned array must have shape `(M, N)`, and, if Cartesian indexing + `"xy"`, then each returned array must have shape `(N, M)`. + + Similarly, for the three-dimensional case with input one-dimensional + arrays of length `M`, `N`, and `P`, if matrix indexing `"ij"`, then + each returned array must have shape `(M, N, P)`, and, if Cartesian + indexing `"xy"`, then each returned array must have shape `(N, M, P)`. + + Each returned array should have the same data type as the input arrays. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.meshgrid.html + """ + + assert arrays, "TODO" + assert indexing, "TODO" + assert False, "TODO" + + +def ones( + shape: int | tuple[int, ...], + *, + dtype: None | Dtype = None, + device: None | Device = None, +) -> array: + """ + Returns a new array having a specified `shape` and filled with ones. + + Args: + shape: Output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is `np.float64`. + device: Device on which to place the created array. + + Returns: + An array containing ones. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones.html + """ + + assert shape, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def ones_like( + x: array, /, *, dtype: None | Dtype = None, device: None | Device = None +) -> array: + """ + Returns a new array filled with ones and having the same `shape` as an + input array `x`. + + Args: + x: Input array from which to derive the output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is inferred from `x`. + device: Device on which to place the created array. If `device` is + `None`, the output array device is inferred from `x`. + + Returns: + An array having the same shape as x and filled with ones. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones_like.html + """ + + assert x, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def tril(x: array, /, *, k: int = 0) -> array: + """ + Returns the lower triangular part of a matrix (or a stack of matrices) `x`. + + Args: + x: Input array having shape `(..., M, N)` and whose innermost two + dimensions form `M` by `N` matrices. + `k`: Diagonal above which to zero elements. If `k = 0`, the diagonal is + the main diagonal. If `k < 0`, the diagonal is below the main + diagonal. If `k > 0`, the diagonal is above the main diagonal. + + Returns: + An array containing the lower triangular part(s). The returned array + has the same shape and data type as `x`. All elements above the + specified diagonal `k` are zero. The returned array is allocated on the + same device as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.tril.html + """ + + assert x, "TODO" + assert k, "TODO" + assert False, "TODO" + + +def triu(x: array, /, *, k: int = 0) -> array: + """ + Returns the upper triangular part of a matrix (or a stack of matrices) `x`. + + Args: + x: Input array having shape `(..., M, N)` and whose innermost two + dimensions form `M` by `N` matrices. + k: Diagonal below which to zero elements. If `k = 0`, the diagonal is + the main diagonal. If `k < 0`, the diagonal is below the main + diagonal. If `k > 0`, the diagonal is above the main diagonal. + + Returns: + An array containing the upper triangular part(s). The returned array + has the same shape and data type as `x`. All elements below the + specified diagonal `k` are zero. The returned array is allocated on the + same device as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.triu.html + """ + + assert x, "TODO" + assert k, "TODO" + assert False, "TODO" + + +def zeros( + shape: int | tuple[int, ...], + *, + dtype: None | Dtype = None, + device: None | Device = None, +) -> array: + """ + Returns a new array having a specified shape and filled with zeros. + + Args: + shape: Output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is `np.float64`. + device: Device on which to place the created array. + + Returns: + An array containing zeros. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros.html + """ + + assert shape, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def zeros_like( + x: array, /, *, dtype: None | Dtype = None, device: None | Device = None +) -> array: + """ + Returns a new array filled with zeros and having the same `shape` as an + input array `x`. + + Args: + x: Input array from which to derive the output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is inferred from `x`. + device: Device on which to place the created array. If `device` is + `None`, the output array device is inferred from `x`. + + Returns: + An array having the same shape as `x` and filled with zeros. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros_like.html + """ + + assert x, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_datatype.py b/src/ragged/common/_datatype.py new file mode 100644 index 0000000..2c7a2be --- /dev/null +++ b/src/ragged/common/_datatype.py @@ -0,0 +1,222 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/data_type_functions.html +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + +from ._obj import array +from ._typing import Dtype + + +def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: + """ + Copies an array to a specified data type irrespective of type promotion rules. + + Args: + x: Array to cast. + dtype: Desired data type. + copy: Specifies whether to copy an array when the specified `dtype` + matches the data type of the input array `x`. If `True`, a newly + allocated array is always returned. If `False` and the specified + `dtype` matches the data type of the input array, the input array + is returned; otherwise, a newly allocated array is returned. + + Returns: + An array having the specified data type. The returned array has the + same `shape` as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html + """ + + assert x, "TODO" + assert dtype, "TODO" + assert copy, "TODO" + assert False, "TODO" + + +def can_cast(from_: Dtype | array, to: Dtype, /) -> bool: + """ + Determines if one data type can be cast to another data type according type + promotion rules. + + Args: + from: Input data type or array from which to cast. + to: Desired data type. + + Returns: + `True` if the cast can occur according to type promotion rules; + otherwise, `False`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.can_cast.html + """ + + assert from_, "TODO" + assert to, "TODO" + assert False, "TODO" + + +@dataclass +class finfo_object: # pylint: disable=C0103 + """ + Output of `ragged.finfo` with the following attributes. + + - bits (int): number of bits occupied by the real-valued floating-point + data type. + - eps (float): difference between 1.0 and the next smallest representable + real-valued floating-point number larger than 1.0 according to the + IEEE-754 standard. + - max (float): largest representable real-valued number. + - min (float): smallest representable real-valued number. + - smallest_normal (float): smallest positive real-valued floating-point + number with full precision. + - dtype (np.dtype): real-valued floating-point data type. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.finfo.html + """ + + bits: int + eps: float + max: float + min: float + smallest_normal: float + dtype: np.dtype + + +def finfo(type: Dtype | array, /) -> finfo_object: # pylint: disable=W0622 + """ + Machine limits for floating-point data types. + + Args: + type: the kind of floating-point data-type about which to get + information. If complex, the information is about its component + data type. + + Returns: + An object having the following attributes: + + - bits (int): number of bits occupied by the real-valued floating-point + data type. + - eps (float): difference between 1.0 and the next smallest + representable real-valued floating-point number larger than 1.0 + according to the IEEE-754 standard. + - max (float): largest representable real-valued number. + - min (float): smallest representable real-valued number. + - smallest_normal (float): smallest positive real-valued floating-point + number with full precision. + - dtype (np.dtype): real-valued floating-point data type. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.finfo.html + """ + + assert type, "TODO" + assert False, "TODO" + + +@dataclass +class iinfo_object: # pylint: disable=C0103 + """ + Output of `ragged.iinfo` with the following attributes. + + - bits (int): number of bits occupied by the type. + - max (int): largest representable number. + - min (int): smallest representable number. + - dtype (np.dtype): integer data type. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.iinfo.html + """ + + bits: int + max: int + min: int + dtype: np.dtype + + +def iinfo(type: Dtype | array, /) -> iinfo_object: # pylint: disable=W0622 + """ + Machine limits for integer data types. + + Args: + type: The kind of integer data-type about which to get information. + + Returns: + An object having the following attributes: + + - bits (int): number of bits occupied by the type. + - max (int): largest representable number. + - min (int): smallest representable number. + - dtype (np.dtype): integer data type. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.iinfo.html + """ + + assert type, "TODO" + assert False, "TODO" + + +def isdtype(dtype: Dtype, kind: Dtype | str | tuple[Dtype | str, ...]) -> bool: + """ + Returns a boolean indicating whether a provided dtype is of a specified data type "kind". + + Args: + dtype: The input dtype. + kind: Data type kind. + If `kind` is a `dtype`, the function returns a boolean indicating + whether the input `dtype` is equal to the dtype specified by `kind`. + + If `kind` is a string, the function returns a boolean indicating + whether the input `dtype` is of a specified data type kind. The + following dtype kinds must be supported: + + - `"bool"`: boolean data types (e.g., bool). + - `"signed integer"`: signed integer data types (e.g., `int8`, + `int16`, `int32`, `int64`). + - `"unsigned integer"`: unsigned integer data types (e.g., + `uint8`, `uint16`, `uint32`, `uint64`). + - `"integral"`: integer data types. Shorthand for + (`"signed integer"`, `"unsigned integer"`). + - `"real floating"`: real-valued floating-point data types + (e.g., `float32`, `float64`). + - `"complex floating"`: complex floating-point data types + (e.g., `complex64`, `complex128`). + - `"numeric"`: numeric data types. Shorthand for (`"integral"`, + `"real floating"`, `"complex floating"`). + + If `kind` is a tuple, the tuple specifies a union of dtypes and/or + kinds, and the function returns a boolean indicating whether the + input `dtype` is either equal to a specified dtype or belongs to at + least one specified data type kind. + + Returns: + Boolean indicating whether a provided dtype is of a specified data type + kind. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + """ + + assert dtype, "TODO" + assert kind, "TODO" + assert False, "TODO" + + +def result_type(*arrays_and_dtypes: array | Dtype) -> Dtype: + """ + Returns the dtype that results from applying the type promotion rules to + the arguments. + + Args: + arrays_and_dtypes: An arbitrary number of input arrays and/or dtypes. + + Returns: + The dtype resulting from an operation involving the input arrays and dtypes. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.result_type.html + """ + + assert arrays_and_dtypes, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_elementwise.py b/src/ragged/common/_elementwise.py new file mode 100644 index 0000000..12b100d --- /dev/null +++ b/src/ragged/common/_elementwise.py @@ -0,0 +1,1375 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html +""" + +from __future__ import annotations + +from ._obj import array + + +def abs(x: array, /) -> array: # pylint: disable=W0622 + r""" + Calculates the absolute value for each element `x_i` of the input array `x`. + + For real-valued input arrays, the element-wise result has the same + magnitude as the respective element in `x` but has positive sign. + + For complex floating-point operands, the complex absolute value is known as + the norm, modulus, or magnitude and, for a complex number `z = a + bj` is + computed as + + $$\operatorname{abs}(z) = \sqrt{a^2 + b^2}$$ + + Args: + x: Input array. + + Returns: + An array containing the absolute value of each element in `x`. If `x` + has a real-valued data type, the returned array has the same data type + as `x`. If `x` has a complex floating-point data type, the returned + array has a real-valued floating-point data type whose precision + matches the precision of `x` (e.g., if `x` is `complex128`, then the + returned array has a `float64` data type). + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.abs.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def acos(x: array, /) -> array: + r""" + Calculates an approximation of the principal value of the inverse cosine + for each element `x_i` of the input array `x`. + + Each element-wise result is expressed in radians. + + The principal value of the arc cosine of a complex number `z` is + + $$\operatorname{acos}(z) = \frac{1}{2}\pi + j\ \ln(zj + \sqrt{1-z^2})$$ + + For any `z`, + + $$\operatorname{acos}(z) = \pi - \operatorname{acos}(-z)$$ + + Args: + x: Input array. + + Returns: + An array containing the inverse cosine of each element in `x`. The + returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.acos.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def acosh(x: array, /) -> array: + r""" + Calculates an approximation to the inverse hyperbolic cosine for each + element `x_i` of the input array `x`. + + The principal value of the inverse hyperbolic cosine of a complex number + `z` is + + $$\operatorname{acosh}(z) = \ln(z + \sqrt{z+1}\sqrt{z-1})$$ + + For any `z`, + + $$\operatorname{acosh}(z) = \frac{\sqrt{z-1}}{\sqrt{1-z}}\operatorname{acos}(z)$$ + + or simply + + $$\operatorname{acosh}(z) = j\ \operatorname{acos}(z)$$ + + in the upper half of the complex plane. + + Args: + x: Input array whose elements each represent the area of a hyperbolic + sector. + + Returns: + An array containing the inverse hyperbolic cosine of each element in + `x`. The returned array has a floating-point data type determined by + type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.acosh.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def add(x1: array, x2: array, /) -> array: + """ + Calculates the sum for each element `x1_i` of the input array `x1` with the + respective element `x2_i` of the input array `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise sums. The returned array has a + data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.add.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def asin(x: array, /) -> array: + r""" + Calculates an approximation of the principal value of the inverse sine for + each element `x_i` of the input array `x`. + + Each element-wise result is expressed in radians. + + The principal value of the arc sine of a complex number `z` is + + $$\operatorname{asin}(z) = -j\ \ln(zj + \sqrt{1-z^2})$$ + + For any `z`, + + $$\operatorname{asin}(z) = \operatorname{acos}(-z) - \frac{\pi}{2}$$ + + Args: + x: Input array. + + Returns: + An array containing the inverse sine of each element in `x`. The + returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.asin.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def asinh(x: array, /) -> array: + r""" + Calculates an approximation to the inverse hyperbolic sine for each element + `x_i` in the input array `x`. + + The principal value of the inverse hyperbolic sine of a complex number `z` + is + + $$\operatorname{asinh}(z) = \ln(z + \sqrt{1+z^2})$$ + + For any `z`, + + $$\operatorname{asinh}(z) = \frac{\operatorname{asin}(zj)}{j}$$ + + Args: + x: Input array whose elements each represent the area of a hyperbolic + sector. + + Returns: + An array containing the inverse hyperbolic sine of each element in `x`. + The returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.asinh.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def atan(x: array, /) -> array: + r""" + Calculates an approximation of the principal value of the inverse tangent + for each element `x_i` of the input array `x`. + + Each element-wise result is expressed in radians. + + The principal value of the inverse tangent of a complex number `z` is + + $$\operatorname{atan}(z) = -\frac{\ln(1 - zj) - \ln(1 + zj)}{2}j$$ + + Args: + x: Input array. Should have a floating-point data type. + + Returns: + An array containing the inverse tangent of each element in `x`. The + returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.atan.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def atan2(x1: array, x2: array, /) -> array: + """ + Calculates an approximation of the inverse tangent of the quotient `x1/x2`, + having domain `[-infinity, +infinity] \u00d7 [-infinity, +infinity]` (where + the `\u00d7` notation denotes the set of ordered pairs of elements + `(x1_i, x2_i)`) and codomain `[-π, +π]`, for each pair of elements + `(x1_i, x2_i)` of the input arrays `x1` and `x2`, respectively. Each + element-wise result is expressed in radians. + + The mathematical signs of `x1_i` and `x2_i` determine the quadrant of each + element-wise result. The quadrant (i.e., branch) is chosen such that each + element-wise result is the signed angle in radians between the ray ending + at the origin and passing through the point `(1, 0)` and the ray ending at + the origin and passing through the point `(x2_i, x1_i)`. + + Note the role reversal: the "y-coordinate" is the first function parameter; + the "x-coordinate" is the second function parameter. The parameter order is + intentional and traditional for the two-argument inverse tangent function + where the y-coordinate argument is first and the x-coordinate argument is + second. + + By IEEE 754 convention, the inverse tangent of the quotient `x1/x2` is + defined for `x2_i` equal to positive or negative zero and for either or + both of `x1_i` and `x2_i` equal to positive or negative `infinity`. + + Args: + x1: Input array corresponding to the y-coordinates. + x2: Input array corresponding to the x-coordinates. Must be + broadcastable with `x1`. + + Returns: + An array containing the inverse tangent of the quotient `x1/x2`. The + returned array has a real-valued floating-point data type determined by + type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.atan2.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def atanh(x: array, /) -> array: + r""" + Calculates an approximation to the inverse hyperbolic tangent for each + element `x_i` of the input array `x`. + + The principal value of the inverse hyperbolic tangent of a complex number + `z` is + + $$\operatorname{atanh}(z) = \frac{\ln(1+z)-\ln(z-1)}{2}$$ + + For any `z`, + + $$\operatorname{atanh}(z) = \frac{\operatorname{atan}(zj)}{j}$$ + + Args: + x: Input array whose elements each represent the area of a hyperbolic + sector. + + Returns: + An array containing the inverse hyperbolic tangent of each element in + `x`. The returned array has a floating-point data type determined by + type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.atanh.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def bitwise_and(x1: array, x2: array, /) -> array: + """ + Computes the bitwise AND of the underlying binary representation of each + element `x1_i` of the input array `x1` with the respective element `x2_i` + of the input array `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_and.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def bitwise_invert(x: array, /) -> array: + """ + Inverts (flips) each bit for each element `x_i` of the input array `x`. + + Args: + x: Input array. + + Returns: + An array containing the element-wise results. The returned array has + the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_invert.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def bitwise_left_shift(x1: array, x2: array, /) -> array: + """ + Shifts the bits of each element `x1_i` of the input array `x1` to the left + by appending `x2_i` (i.e., the respective element in the input array `x2`) + zeros to the right of `x1_i`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. Each element + must be greater than or equal to 0. + + Returns: + An array containing the element-wise results. The returned array has a + data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_left_shift.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def bitwise_or(x1: array, x2: array, /) -> array: + """ + Computes the bitwise OR of the underlying binary representation of each + element `x1_i` of the input array `x1` with the respective element `x2_i` + of the input array `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_or.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def bitwise_right_shift(x1: array, x2: array, /) -> array: + """ + Shifts the bits of each element `x1_i` of the input array `x1` to the right + according to the respective element `x2_i` of the input array `x2`. + + This operation is equivalent to floor division by a power of two. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. Each element + must be greater than or equal to 0. + + Returns: + An array containing the element-wise results. The returned array has a + data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_right_shift.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def bitwise_xor(x1: array, x2: array, /) -> array: + """ + Computes the bitwise XOR of the underlying binary representation of each + element `x1_i` of the input array `x1` with the respective element `x2_i` + of the input array `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.bitwise_xor.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def ceil(x: array, /) -> array: + """ + Rounds each element `x_i` of the input array `x` to the smallest (i.e., + closest to `-infinity`) integer-valued number that is not less than `x_i`. + + Args: + x: Input array. + + Returns: + An array containing the rounded result for each element in `x`. The + returned array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.ceil.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def conj(x: array, /) -> array: + """ + Returns the complex conjugate for each element `x_i` of the input array + `x`. + + For complex numbers of the form + + $$a + bj$$ + + the complex conjugate is defined as + + $$a - bj$$ + + Hence, the returned complex conjugates is computed by negating the + imaginary component of each element `x_i`. + + Args: + x: Input array. + + Returns: + An array containing the element-wise results. The returned array has + the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.conj.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def cos(x: array, /) -> array: + r""" + Calculates an approximation to the cosine for each element `x_i` of the + input array `x`. + + Each element `x_i` is assumed to be expressed in radians. + + Args: + x: Input array whose elements are each expressed in radians. + + Returns: + An array containing the cosine of each element in `x`. The returned + array has a floating-point data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.cos.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def cosh(x: array, /) -> array: + r""" + Calculates an approximation to the hyperbolic cosine for each element `x_i` + in the input array `x`. + + The mathematical definition of the hyperbolic cosine is + + $$\operatorname{cosh}(x) = \frac{e^x + e^{-x}}{2}$$ + + Args: + x: Input array whose elements each represent a hyperbolic angle. + + Returns: + An array containing the hyperbolic cosine of each element in `x`. The + returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.cosh.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def divide(x1: array, x2: array, /) -> array: + r""" + Calculates the division of each element `x1_i` of the input array `x1` with + the respective element `x2_i` of the input array `x2`. + + Args: + x1: Dividend input array. + x2: Divisor input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + floating-point data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.divide.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def equal(x1: array, x2: array, /) -> array: + r""" + Computes the truth value of `x1_i == x2_i` for each element `x1_i` of the + input array `x1` with the respective element `x2_i` of the input array + `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.equal.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def exp(x: array, /) -> array: + """ + Calculates an approximation to the exponential function for each element + `x_i` of the input array `x` (`e` raised to the power of `x_i`, where `e` + is the base of the natural logarithm). + + Args: + x: Input array. + + Returns: + An array containing the evaluated exponential function result for each + element in `x`. The returned array has a floating-point data type + determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.exp.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def expm1(x: array, /) -> array: + """ + Calculates an approximation to `exp(x)-1` for each element `x_i` of the + input array `x`. + + The purpose of this function is to calculate `exp(x)-1.0` more accurately + when `x` is close to zero. + + Args: + x: Input array. + + Returns: + An array containing the evaluated result for each element in `x`. The + returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.expm1.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def floor(x: array, /) -> array: + """ + Rounds each element `x_i` of the input array `x` to the greatest (i.e., + closest to `+infinity`) integer-valued number that is not greater than + `x_i`. + + Args: + x: Input array. + + Returns: + An array containing the rounded result for each element in `x`. The + returned array must have the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def floor_divide(x1: array, x2: array, /) -> array: + r""" + Rounds the result of dividing each element `x1_i` of the input array `x1` + by the respective element `x2_i` of the input array `x2` to the greatest + (i.e., closest to `+infinity`) integer-value number that is not greater + than the division result. + + Args: + x1: Dividend input array. + x2: Divisor input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor_divide.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def greater(x1: array, x2: array, /) -> array: + """ + Computes the truth value of `x1_i > x2_i` for each element `x1_i` of the + input array `x1` with the respective element `x2_i` of the input array + `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.greater.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def greater_equal(x1: array, x2: array, /) -> array: + """ + Computes the truth value of `x1_i >= x2_i` for each element `x1_i` of the + input array `x1` with the respective element `x2_i` of the input array + `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.greater_equal.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def imag(x: array, /) -> array: + """ + Returns the imaginary component of a complex number for each element `x_i` + of the input array `x`. + + Args: + x: Input array. + + Returns: + An array containing the element-wise results. The returned array has a + floating-point data type with the same floating-point precision as `x` + (e.g., if `x` is `complex64`, the returned array has the floating-point + data type `float32`). + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.imag.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def isfinite(x: array, /) -> array: + """ + Tests each element `x_i` of the input array `x` to determine if finite. + + Args: + x: Input array. + + Returns: + An array containing test results. The returned array has a data type of + `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isfinite.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def isinf(x: array, /) -> array: + """ + Tests each element `x_i` of the input array `x` to determine if equal to + positive or negative infinity. + + Args: + x: Input array. + + Returns: + An array containing test results. The returned array has a data type of + `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isinf.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def isnan(x: array, /) -> array: + """ + Tests each element `x_i` of the input array `x` to determine whether the + element is `NaN`. + + Args: + x: Input array. + + Returns: + An array containing test results. The returned array has a data type of + `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isnan.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def less(x1: array, x2: array, /) -> array: + """ + Computes the truth value of `x1_i < x2_i` for each element `x1_i` of the + input array `x1` with the respective element `x2_i` of the input array + `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.less.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def less_equal(x1: array, x2: array, /) -> array: + """ + Computes the truth value of `x1_i <= x2_i` for each element `x1_i` of the + input array `x1` with the respective element `x2_i` of the input array + `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.less_equal.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def log(x: array, /) -> array: + r""" + Calculates an approximation to the natural (base `e`) logarithm for each + element `x_i` of the input array `x`. + + Args: + x: Input array. + + Returns: + An array containing the evaluated natural logarithm for each element in + `x`. The returned array has a floating-point data type determined by + type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.log.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def log1p(x: array, /) -> array: + r""" + Calculates an approximation to `log(1+x)`, where `log` refers to the + natural (base `e`) logarithm, for each element `x_i` of the input array + `x`. + + The purpose of this function is to calculate `log(1+x)` more accurately + when `x` is close to zero. + + Args: + x: Input array. + + Returns: + An array containing the evaluated result for each element in `x`. The + returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.log1p.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def log2(x: array, /) -> array: + r""" + Calculates an approximation to the base 2 logarithm for each element `x_i` + of the input array `x`. + + Args: + x: Input array. + + Returns: + An array containing the evaluated base 2 logarithm for each element in + `x`. The returned array has a floating-point data type determined by + type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.log2.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def log10(x: array, /) -> array: + r""" + Calculates an approximation to the base 10 logarithm for each element `x_i` + of the input array `x`. + + Args: + x: Input array. + + Returns: + An array containing the evaluated base 10 logarithm for each element in + `x`. The returned array has a floating-point data type determined by + type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.log10.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def logaddexp(x1: array, x2: array, /) -> array: + """ + Calculates the logarithm of the sum of exponentiations + `log(exp(x1) + exp(x2))` for each element `x1_i` of the input array `x1` + with the respective element `x2_i` of the input array `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + real-valued floating-point data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.logaddexp.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def logical_and(x1: array, x2: array, /) -> array: + """ + Computes the logical AND for each element `x1_i` of the input array `x1` + with the respective element `x2_i` of the input array `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_and.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def logical_not(x: array, /) -> array: + """ + Computes the logical NOT for each element `x_i` of the input array `x`. + + Args: + x: Input array. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_not.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def logical_or(x1: array, x2: array, /) -> array: + """ + Computes the logical OR for each element `x1_i` of the input array `x1` + with the respective element `x2_i` of the input array `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_or.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def logical_xor(x1: array, x2: array, /) -> array: + """ + Computes the logical XOR for each element `x1_i` of the input array `x1` + with the respective element `x2_i` of the input array `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_xor.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def multiply(x1: array, x2: array, /) -> array: + r""" + Calculates the product for each element `x1_i` of the input array `x1` with + the respective element `x2_i` of the input array `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise products. The returned array has a + data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.multiply.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def negative(x: array, /) -> array: + """ + Computes the numerical negative of each element `x_i` (i.e., `y_i = -x_i`) + of the input array `x`. + + Args: + x: Input array. + + Returns: + An array containing the evaluated result for each element in `x`. The + returned array has a data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.negative.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def not_equal(x1: array, x2: array, /) -> array: + """ + Computes the truth value of `x1_i != x2_i` for each element `x1_i` of the + input array `x1` with the respective element `x2_i` of the input array + `x2`. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type of `bool`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.not_equal.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def positive(x: array, /) -> array: + """ + Computes the numerical positive of each element `x_i` (i.e., `y_i = +x_i`) + of the input array `x`. + + Args: + x: Input array. + + Returns: + An array containing the evaluated result for each element in `x`. The + returned array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.positive.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def pow(x1: array, x2: array, /) -> array: # pylint: disable=W0622 + r""" + Calculates an approximation of exponentiation by raising each element + `x1_i` (the base) of the input array `x1` to the power of `x2_i` (the + exponent), where `x2_i` is the corresponding element of the input array + `x2`. + + Args: + x1: First input array whose elements correspond to the exponentiation + base. + x2: Second input array whose elements correspond to the exponentiation + exponent. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. The returned array has a + data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.pow.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def real(x: array, /) -> array: + """ + Returns the real component of a complex number for each element `x_i` of + the input array `x`. + + Args: + x: Input array. + + Returns: + An array containing the element-wise results. The returned array has a + floating-point data type with the same floating-point precision as `x` + (e.g., if `x` is `complex64`, the returned array has the floating-point + data type `float32`). + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.real.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def remainder(x1: array, x2: array, /) -> array: + """ + Returns the remainder of division for each element `x1_i` of the input + array `x1` and the respective element `x2_i` of the input array `x2`. + + This function is equivalent to the Python modulus operator `x1_i % x2_i`. + + Args: + x1: Dividend input array. + x2: Divisor input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise results. Each element-wise result + has the same sign as the respective element `x2_i`. The returned array + has a data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.remainder.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def round(x: array, /) -> array: # pylint: disable=W0622 + """ + Rounds each element `x_i` of the input array `x` to the nearest + integer-valued number. + + For complex floating-point operands, real and imaginary components are + independently rounded to the nearest integer-valued number. + + Rounded real and imaginary components are equal to their equivalent rounded + real-valued floating-point counterparts (i.e., for complex-valued `x`, + `real(round(x))` must equal `round(real(x)))` and `imag(round(x))` equals + `round(imag(x))`). + + Args: + x: Input array. + + Returns: + An array containing the rounded result for each element in `x`. The + returned array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.round.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def sign(x: array, /) -> array: + r""" + Returns an indication of the sign of a number for each element `x_i` of the + input array `x`. + + The sign function (also known as the **signum function**) of a number $x_i$ + is defined as + + $$\operatorname{sign}(x_i) = \begin{cases} 0 & \textrm{if } x_i = 0 \\ \frac{x}{|x|} & \textrm{otherwise} \end{cases}$$ + + where $|x_i|$ is the absolute value of $x_i$. + + Args: + x: Input array. + + Returns: + An array containing the evaluated result for each element in `x`. The + returned array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.sign.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def sin(x: array, /) -> array: + r""" + Calculates an approximation to the sine for each element `x_i` of the input + array `x`. + + Each element `x_i` is assumed to be expressed in radians. + + Args: + x: Input array whose elements are each expressed in radians. + + Returns: + An array containing the sine of each element in `x`. The returned array + has a floating-point data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.sin.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def sinh(x: array, /) -> array: + r""" + Calculates an approximation to the hyperbolic sine for each element `x_i` + of the input array `x`. + + The mathematical definition of the hyperbolic sine is + + $$\operatorname{sinh}(x) = \frac{e^x - e^{-x}}{2}$$ + + Args: + x: Input array whose elements each represent a hyperbolic angle. + + Returns: + An array containing the hyperbolic sine of each element in `x`. The + returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.sinh.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def square(x: array, /) -> array: + r""" + Squares each element `x_i` of the input array `x`. + + The square of a number `x_i` is defined as + + $$x_i^2 = x_i \cdot x_i$$ + + Args: + x: Input array. + + Returns: + An array containing the evaluated result for each element in `x`. The + returned array has a data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.square.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def sqrt(x: array, /) -> array: + r""" + Calculates the principal square root for each element `x_i` of the input + array `x`. + + Args: + x: Input array. + + Returns: + An array containing the square root of each element in `x`. The + returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.sqrt.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def subtract(x1: array, x2: array, /) -> array: + """ + Calculates the difference for each element `x1_i` of the input array `x1` + with the respective element `x2_i` of the input array `x2`. + + The result of `x1_i - x2_i` is the same as `x1_i + (-x2_i)` and is governed + by the same floating-point rules as addition. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1`. + + Returns: + An array containing the element-wise differences. The returned array + has a data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.subtract.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def tan(x: array, /) -> array: + r""" + Calculates an approximation to the tangent for each element `x_i` of the + input array `x`. + + Each element `x_i` is assumed to be expressed in radians. + + Args: + x: Input array whose elements are expressed in radians. + + Returns: + An array containing the tangent of each element in `x`. The returned + array has a floating-point data type determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.tan.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def tanh(x: array, /) -> array: + r""" + Calculates an approximation to the hyperbolic tangent for each element + `x_i` of the input array `x`. + + The mathematical definition of the hyperbolic tangent is + + $$\begin{align} \operatorname{tanh}(x) &= \frac{\operatorname{sinh}(x)}{\operatorname{cosh}(x)} \\ &= \frac{e^x - e^{-x}}{e^x + e^{-x}} \end{align}$$ + + where $\operatorname{sinh}(x)$ is the hyperbolic sine and + $\operatorname{cosh}(x)$ is the hyperbolic cosine. + + Args: + x: Input array whose elements each represent a hyperbolic angle. + + Returns: + An array containing the hyperbolic tangent of each element in `x`. The + returned array has a floating-point data type determined by type + promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.tanh.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def trunc(x: array, /) -> array: + """ + Rounds each element `x_i` of the input array `x` to the nearest + integer-valued number that is closer to zero than `x_i`. + + Args: + x: Input array. + + Returns: + An array containing the rounded result for each element in `x`. The + returned array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.trunc.html + """ + + assert x, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_indexing.py b/src/ragged/common/_indexing.py new file mode 100644 index 0000000..27d546b --- /dev/null +++ b/src/ragged/common/_indexing.py @@ -0,0 +1,43 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/indexing_functions.html +""" + +from __future__ import annotations + +from ._obj import array + + +def take(x: array, indices: array, /, *, axis: None | int = None) -> array: + """ + Returns elements of an array along an axis. + + Conceptually, `take(x, indices, axis=3)` is equivalent to + `x[:,:,:,indices,...]`. + + Args: + x: Input array. + indices: Array indices. The array must be one-dimensional and have an + integer data type. + axis: Axis over which to select values. If `axis` is negative, the + function determines the axis along which to select values by + counting from the last dimension. + + If `x` is a one-dimensional array, providing an axis is optional; + however, if `x` has more than one dimension, providing an `axis` is + required. + + Returns: + An array having the same data type as `x`. The output array has the + same rank (i.e., number of dimensions) as `x` and has the same shape as + `x`, except for the axis specified by `axis` whose size must equal the + number of elements in indices. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.take.html + """ + + assert x, "TODO" + assert indices, "TODO" + assert axis, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_linalg.py b/src/ragged/common/_linalg.py new file mode 100644 index 0000000..7c08a86 --- /dev/null +++ b/src/ragged/common/_linalg.py @@ -0,0 +1,190 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from ._obj import array + + +def matmul(x1: array, x2: array, /) -> array: + """ + Computes the matrix product. + + Args: + x1: First input array. Must have at least one dimension. If `x1` is + one-dimensional having shape `(M,)` and `x2` has more than one + dimension, `x1` is promoted to a two-dimensional array by prepending 1 + to its dimensions (i.e., has shape `(1, M)`). After matrix + multiplication, the prepended dimensions in the returned array are + removed. If `x1` has more than one dimension (including after + vector-to-matrix promotion), `shape(x1)[:-2]` is compatible with + `shape(x2)[:-2]` (after vector-to-matrix promotion). If `x1` has shape + `(..., M, K)`, the innermost two dimensions form matrices on which to + perform matrix multiplication. + x2: Second input array. Must have at least one dimension. If `x2` is + one-dimensional having shape `(N,)` and `x1` has more than one + dimension, `x2` is promoted to a two-dimensional array by appending 1 + to its dimensions (i.e., has shape `(N, 1)`). After matrix + multiplication, the appended dimensions in the returned array are + removed. If `x2` has more than one dimension (including after + vector-to-matrix promotion), `shape(x2)[:-2]` is compatible with + `shape(x1)[:-2]` (after vector-to-matrix promotion). If `x2` has shape + `(..., K, N)`, the innermost two dimensions form matrices on which to + perform matrix multiplication. + + Returns: + If both `x1` and `x2` are one-dimensional arrays having shape `(N,)`, a + zero-dimensional array containing the inner product as its only + element. + + If `x1` is a two-dimensional array having shape `(M, K)` and `x2` is a + two-dimensional array having shape `(K, N)`, a two-dimensional array + containing the conventional matrix product and having shape `(M, N)`. + + If `x1` is a one-dimensional array having shape `(K,)` and `x2` is an + array having shape `(..., K, N)`, an array having shape `(..., N)` + (i.e., prepended dimensions during vector-to-matrix promotion are + removed) and containing the conventional matrix product. + + If `x1` is an array having shape `(..., M, K)` and `x2` is a + one-dimensional array having shape `(K,)`, an array having shape + `(..., M)` (i.e., appended dimensions during vector-to-matrix promotion + are removed) and containing the conventional matrix product. + + If `x1` is a two-dimensional array having shape `(M, K)` and `x2` is an + array having shape `(..., K, N)`, an array having shape `(..., M, N)` + and containing the conventional matrix product for each stacked matrix. + + If `x1` is an array having shape `(..., M, K)` and `x2` is a + two-dimensional array having shape `(K, N)`, an array having shape + `(..., M, N)` and containing the conventional matrix product for each + stacked matrix. + + If either `x1` or `x2` has more than two dimensions, an array having a + shape determined by broadcasting `shape(x1)[:-2]` against + `Shape(x2)[:-2]` and containing the conventional matrix product for + each stacked matrix. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" + + +def matrix_transpose(x: array, /) -> array: + """ + Transposes a matrix (or a stack of matrices) x. + + Args: + x: Input array having shape `(..., M, N)` and whose innermost two + dimensions form `M` by `N` matrices. + + Returns: + An array containing the transpose for each matrix and having shape + `(..., N, M)`. The returned array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.matrix_transpose.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def tensordot( + x1: array, x2: array, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2 +) -> array: + """ + Returns a tensor contraction of `x1` and `x2` over specific axes. + + The tensordot function corresponds to the generalized matrix product. + + Args: + x1: First input array. + x2: Second input array. Corresponding contracted axes of `x1` and `x2` + must be equal. + axes: Number of axes (dimensions) to contract or explicit sequences of + axes (dimensions) for `x1` and `x2`, respectively. + + If `axes` is an `int` equal to `N`, then contraction is performed + over the last `N` axes of `x1` and the first `N` axes of `x2` in + order. The size of each corresponding axis (dimension) match. Must + be nonnegative. + + If `N` equals 0, the result is the tensor (outer) product. + + If `N` equals 1, the result is the tensor dot product. + + If `N` equals 2, the result is the tensor double contraction. + + If `axes` is a tuple of two sequences `(x1_axes, x2_axes)`, the + first sequence applies to `x1` and the second sequence to `x2`. + Both sequences must have the same length. Each axis (dimension) + `x1_axes[i]` for `x1` must have the same size as the respective + axis (dimension) `x2_axes[i]` for `x2`. Each sequence must consist + of unique (nonnegative) integers that specify valid axes for each + respective array. + + Returns: + An array containing the tensor contraction whose shape consists of the + non-contracted axes (dimensions) of the first array `x1`, followed by + the non-contracted axes (dimensions) of the second array `x2`. The + returned array has a data type determined by type promotion rules. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.tensordot.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert axes, "TODO" + assert False, "TODO" + + +def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: + r""" + Computes the (vector) dot product of two arrays. + + Let $\mathbf{a}$ be a vector in `x1` and $\mathbf{b}$ be a corresponding + vector in `x2`. The dot product is defined as + + $$\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i$$ + + over the dimension specified by `axis` and where $n$ is the dimension size + and $\overline{a_i}$ denotes the complex conjugate if $a_i$ is complex and + the identity if $a_i$ is real-valued. + + Args: + x1: First input array. + x2: Second input array. Must be broadcastable with `x1` for all + non-contracted axes. The size of the axis over which to compute the + dot product is the same size as the respective axis in `x1`. + + The contracted axis (dimension) is not broadcasted. + axis: Axis over which to compute the dot product. Must be an integer on + the interval `[-N, N)`, where `N` is the rank (number of dimensions) of + the shape determined by broadcasting. If specified as a negative + integer, the function determines the axis along which to compute the + dot product by counting backward from the last dimension (where `-1` + refers to the last dimension). + + Returns: + If `x1` and `x2` are both one-dimensional arrays, a zero-dimensional + containing the dot product; otherwise, a non-zero-dimensional array + containing the dot products and having rank `N - 1`, where `N` is the + rank (number of dimensions) of the shape determined by broadcasting + along the non-contracted axes. The returned array has a data type + determined by type promotion. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.vecdot.html + """ + + assert x1, "TODO" + assert x2, "TODO" + assert axis, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_manipulation.py b/src/ragged/common/_manipulation.py new file mode 100644 index 0000000..b287e38 --- /dev/null +++ b/src/ragged/common/_manipulation.py @@ -0,0 +1,271 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/manipulation_functions.html +""" + +from __future__ import annotations + +from ._obj import array + + +def broadcast_arrays(*arrays: array) -> list[array]: + """ + Broadcasts one or more arrays against one another. + + Args: + arrays: An arbitrary number of to-be broadcasted arrays. + + Returns: + A list of broadcasted arrays. Each array has the same shape. Each array + has the same dtype as its corresponding input array. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_arrays.html + """ + + assert arrays, "TODO" + assert False, "TODO" + + +def broadcast_to(x: array, /, shape: tuple[int, ...]) -> array: + """ + Broadcasts an array to a specified shape. + + Args: + x: Array to broadcast. + shape: Array shape. Must be compatible with `x`. If the array is + incompatible with the specified shape, the function raises an + exception. + + Returns: + An array having a specified shape. Must have the same data type as x. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_to.html + """ + + assert x, "TODO" + assert shape, "TODO" + assert False, "TODO" + + +def concat( + arrays: tuple[array, ...] | list[array], /, *, axis: None | int = 0 +) -> array: + """ + Joins a sequence of arrays along an existing axis. + + Args: + arrays: Input arrays to join. The arrays must have the same shape, + except in the dimension specified by `axis`. + axis: Axis along which the arrays will be joined. If `axis` is `None`, + arrays are flattened before concatenation. If `axis` is negative, + the function determines the axis along which to join by counting + from the last dimension. + + Returns: + An output array containing the concatenated values. If the input arrays + have different data types, normal type promotion rules apply. If the + input arrays have the same data type, the output array has the same + data type as the input arrays. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.concat.html + """ + + assert arrays, "TODO" + assert axis, "TODO" + assert False, "TODO" + + +def expand_dims(x: array, /, *, axis: int = 0) -> array: + """ + Expands the shape of an array by inserting a new axis (dimension) of size + one at the position specified by `axis`. + + Args: + x: Input array. + axis: Axis position (zero-based). If `x` has rank (i.e, number of + dimensions) `N`, a valid `axis` must reside on the closed-interval + `[-N-1, N]`. If provided a negative axis, the axis position at + which to insert a singleton dimension is computed as + `N + axis + 1`. Hence, if provided -1, the resolved axis position + is `N` (i.e., a singleton dimension is appended to the input array + `x`). If provided `-N - 1`, the resolved axis position is 0 (i.e., + a singleton dimension is prepended to the input array x). An + `IndexError` exception is raised if provided an invalid axis + position. + + Returns: + An expanded output array having the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert False, "TODO" + + +def flip(x: array, /, *, axis: None | int | tuple[int, ...] = None) -> array: + """ + Reverses the order of elements in an array along the given axis. The shape + of the array is preserved. + + Args: + x: Input array. + axis: Axis (or axes) along which to flip. If `axis` is `None`, the + function flips all input array axes. If `axis` is negative, the + function counts from the last dimension. If provided more than one + axis, the function flips only the specified axes. + + Returns: + An output array having the same data type and shape as `x` and whose + elements, relative to `x`, are reordered. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.flip.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert False, "TODO" + + +def permute_dims(x: array, /, axes: tuple[int, ...]) -> array: + """ + Permutes the axes (dimensions) of an array `x`. + + Args: + x: Input array. + axes: Tuple containing a permutation of `(0, 1, ..., N-1)` where `N` is + the number of axes (dimensions) of `x`. + + Returns: + An array containing the axes permutation. The returned array has the + same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.permute_dims.html + """ + + assert x, "TODO" + assert axes, "TODO" + assert False, "TODO" + + +def reshape(x: array, /, shape: tuple[int, ...], *, copy: None | bool = None) -> array: + """ + Reshapes an array without changing its data. + + Args: + x: Input array to reshape. + shape: A new shape compatible with the original shape. One shape + dimension is allowed to be -1. When a shape dimension is -1, the + corresponding output array shape dimension is inferred from the + length of the array and the remaining dimensions. + copy: Boolean indicating whether or not to copy the input array. If + `True`, the function always copies. If `False`, the function never + copies and raises a `ValueError` in case a copy would be necessary. + If `None`, the function reuses the existing memory buffer if + possible and copies otherwise. + + Returns: + An output array having the same data type and elements as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.reshape.html + """ + + assert x, "TODO" + assert shape, "TODO" + assert copy, "TODO" + assert False, "TODO" + + +def roll( + x: array, + /, + shift: int | tuple[int, ...], + *, + axis: None | int | tuple[int, ...] = None, +) -> array: + """ + Rolls array elements along a specified axis. Array elements that roll + beyond the last position are re-introduced at the first position. Array + elements that roll beyond the first position are re-introduced at the last + position. + + Args: + x: Input array. + shift: Number of places by which the elements are shifted. If `shift` + is a tuple, then `axis` must be a tuple of the same size, and each + of the given axes must be shifted by the corresponding element in + `shift`. If `shift` is an `int` and `axis` a tuple, then the same + shift is used for all specified axes. If a shift is positive, then + array elements are shifted positively (toward larger indices) along + the dimension of `axis`. If a `shift` is negative, then array + elements are shifted negatively (toward smaller indices) along the + dimension of `axis`. + axis: Axis (or axes) along which elements to shift. If `axis` is + `None`, the array is flattened, shifted, and then restored to its + original shape. + + Returns: + An output array having the same data type as `x` and whose elements, + relative to `x`, are shifted. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.roll.html + """ + + assert x, "TODO" + assert shift, "TODO" + assert axis, "TODO" + assert False, "TODO" + + +def squeeze(x: array, /, axis: int | tuple[int, ...]) -> array: + """ + Removes singleton dimensions (axes) from `x`. + + Args: + x: Input array. + axis: Axis (or axes) to squeeze. If a specified axis has a size + greater than one, a `ValueError` is raised. + + Returns: + An output array having the same data type and elements as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert False, "TODO" + + +def stack(arrays: tuple[array, ...] | list[array], /, *, axis: int = 0) -> array: + """ + Joins a sequence of arrays along a new axis. + + Args: + arrays: Input arrays to join. Each array must have the same shape. + axis: Axis along which the arrays will be joined. Providing an `axis` + specifies the index of the new axis in the dimensions of the + result. For example, if `axis` is 0, the new axis will be the first + dimension and the output array will have shape `(N, A, B, C)`; if + `axis` is 1, the new axis will be the second dimension and the + output array will have shape `(A, N, B, C)`; and, if `axis` is -1, + the new axis will be the last dimension and the output array will + have shape `(A, B, C, N)`. A valid axis must be on the interval + `[-N, N)`, where `N` is the rank (number of dimensions) of `x`. + If provided an `axis` outside of the required interval, the + function raises an exception. + + Returns: + An output array having rank `N + 1`, where `N` is the rank (number of + dimensions) of `x`. If the input arrays have different data types, + normal type promotion rules apply. If the input arrays have the same + data type, the output array has the same data type as the input arrays. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.stack.html + """ + + assert arrays, "TODO" + assert axis, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_obj.py b/src/ragged/common/_obj.py index a6aafc6..c237458 100644 --- a/src/ragged/common/_obj.py +++ b/src/ragged/common/_obj.py @@ -1,5 +1,9 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE +""" +https://data-apis.org/array-api/latest/API_specification/array_object.html +""" + from __future__ import annotations import enum @@ -23,6 +27,7 @@ NestedSequence, PyCapsule, Shape, + SupportsBufferProtocol, SupportsDLPack, ) @@ -101,44 +106,67 @@ def _new(cls, impl: ak.Array, shape: Shape, dtype: Dtype, device: Device) -> arr def __init__( self, - array_like: ( + obj: ( array | ak.Array - | SupportsDLPack | bool | int | float - | NestedSequence[bool | int | float] + | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol + | SupportsDLPack ), dtype: None | Dtype | type | str = None, device: None | Device = None, + copy: None | bool = None, ): """ - Primary array constructor. + Primary array constructor, same as `ragged.asarray`. Args: - array_like: Data to use as or convert into a ragged array. - dtype: NumPy dtype describing the data (subclass of `np.number`, - without `shape` or `fields`). - device: If `"cpu"`, the array is backed by NumPy and resides in - main memory; if `"cuda"`, the array is backed by CuPy and - resides in CUDA global memory. - """ - - if isinstance(array_like, array): - self._impl = array_like._impl - self._shape, self._dtype = array_like._shape, array_like._dtype - - elif isinstance(array_like, ak.Array): - self._impl = array_like + obj: Object to be converted to an array. May be a Python scalar, a + (possibly nested) sequence of Python scalars, or an object + supporting the Python buffer protocol or DLPack. + dtype: Output array data type. If `dtype` is `None`, the output + array data type is inferred from the data type(s) in `obj`. + If all input values are Python scalars, then, in order of + precedence, + - if all values are of type `bool`, the output data type is + `bool`. + - if all values are of type `int` or are a mixture of `bool` + and `int`, the output data type is `np.int64`. + - if one or more values are `complex` numbers, the output + data type is `np.complex128`. + - if one or more values are `float`s, the output data type + is `np.float64`. + device: Device on which to place the created array. If device is + `None` and `obj` is an array, the output array device is + inferred from `obj`. If `"cpu"`, the array is backed by NumPy + and resides in main memory; if `"cuda"`, the array is backed by + CuPy and resides in CUDA global memory. + copy: Boolean indicating whether or not to copy the input. If `True`, + this function always copies. If `False`, the function never + copies for input which supports the buffer protocol and raises + a ValueError in case a copy would be necessary. If `None`, the + function reuses the existing memory buffer if possible and + copies otherwise. + """ + + if isinstance(obj, array): + self._impl = obj._impl + self._shape, self._dtype = obj._shape, obj._dtype + + elif isinstance(obj, ak.Array): + self._impl = obj self._shape, self._dtype = _shape_dtype(self._impl.layout) - elif isinstance(array_like, (bool, Real)): - self._impl = np.array(array_like) + elif isinstance(obj, (bool, Real)): + self._impl = np.array(obj) self._shape, self._dtype = (), self._impl.dtype else: - self._impl = ak.Array(array_like) + self._impl = ak.Array(obj) self._shape, self._dtype = _shape_dtype(self._impl.layout) if not isinstance(dtype, np.dtype): @@ -149,7 +177,7 @@ def __init__( self._impl = ak.values_astype(self._impl, dtype) self._shape, self._dtype = _shape_dtype(self._impl.layout) else: - self._impl = np.array(array_like, dtype=dtype) + self._impl = np.array(obj, dtype=dtype) self._dtype = dtype if self._dtype.fields is not None: @@ -171,6 +199,8 @@ def __init__( cp = _import.cupy() self._impl = cp.array(self._impl.item()) + assert copy is None, "TODO" + def __str__(self) -> str: """ String representation of the array. @@ -236,8 +266,7 @@ def mT(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.mT.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" @property def ndim(self) -> int: @@ -300,8 +329,7 @@ def T(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.T.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" # methods: https://data-apis.org/array-api/latest/API_specification/array_object.html#methods @@ -312,8 +340,7 @@ def __abs__(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__abs__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __add__(self, other: int | float | array, /) -> array: """ @@ -323,8 +350,7 @@ def __add__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__add__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __and__(self, other: int | bool | array, /) -> array: """ @@ -334,8 +360,7 @@ def __and__(self, other: int | bool | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__and__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __array_namespace__(self, *, api_version: None | str = None) -> Any: """ @@ -344,10 +369,8 @@ def __array_namespace__(self, *, api_version: None | str = None) -> Any: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html """ - assert api_version is None, "FIXME" - - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert api_version, "TODO" + assert False, "TODO" def __bool__(self) -> bool: # FIXME pylint: disable=E0304 """ @@ -356,8 +379,7 @@ def __bool__(self) -> bool: # FIXME pylint: disable=E0304 https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__bool__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __complex__(self) -> complex: """ @@ -366,8 +388,7 @@ def __complex__(self) -> complex: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__complex__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __dlpack__(self, *, stream: None | int | Any = None) -> PyCapsule: """ @@ -384,10 +405,8 @@ def __dlpack__(self, *, stream: None | int | Any = None) -> PyCapsule: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html """ - assert stream is None, "FIXME" - - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert stream, "TODO" + assert False, "TODO" def __dlpack_device__(self) -> tuple[enum.Enum, int]: """ @@ -399,8 +418,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack_device__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __eq__(self, other: int | float | bool | array, /) -> array: # type: ignore[override] """ @@ -410,8 +428,7 @@ def __eq__(self, other: int | float | bool | array, /) -> array: # type: ignore https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__eq__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __float__(self) -> float: """ @@ -420,8 +437,7 @@ def __float__(self) -> float: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__float__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __floordiv__(self, other: int | float | array, /) -> array: """ @@ -431,8 +447,7 @@ def __floordiv__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__floordiv__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __ge__(self, other: int | float | array, /) -> array: """ @@ -442,8 +457,7 @@ def __ge__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__ge__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __getitem__(self, key: GetSliceKey, /) -> array: """ @@ -452,8 +466,7 @@ def __getitem__(self, key: GetSliceKey, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__getitem__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __gt__(self, other: int | float | array, /) -> array: """ @@ -463,8 +476,7 @@ def __gt__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__gt__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __index__(self) -> int: # FIXME pylint: disable=E0305 """ @@ -473,8 +485,7 @@ def __index__(self) -> int: # FIXME pylint: disable=E0305 https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__index__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __int__(self) -> int: """ @@ -483,8 +494,7 @@ def __int__(self) -> int: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__int__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __invert__(self) -> array: """ @@ -493,8 +503,7 @@ def __invert__(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__invert__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __le__(self, other: int | float | array, /) -> array: """ @@ -504,8 +513,7 @@ def __le__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__le__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __lshift__(self, other: int | array, /) -> array: """ @@ -515,8 +523,7 @@ def __lshift__(self, other: int | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__lshift__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __lt__(self, other: int | float | array, /) -> array: """ @@ -526,8 +533,7 @@ def __lt__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__lt__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __matmul__(self, other: array, /) -> array: """ @@ -536,8 +542,7 @@ def __matmul__(self, other: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__matmul__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __mod__(self, other: int | float | array, /) -> array: """ @@ -547,8 +552,7 @@ def __mod__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__mod__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __mul__(self, other: int | float | array, /) -> array: """ @@ -558,8 +562,7 @@ def __mul__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__mul__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __ne__(self, other: int | float | bool | array, /) -> array: # type: ignore[override] """ @@ -569,8 +572,7 @@ def __ne__(self, other: int | float | bool | array, /) -> array: # type: ignore https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__ne__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __neg__(self) -> array: """ @@ -579,8 +581,7 @@ def __neg__(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__neg__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __or__(self, other: int | bool | array, /) -> array: """ @@ -590,8 +591,7 @@ def __or__(self, other: int | bool | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__or__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __pos__(self) -> array: """ @@ -600,8 +600,7 @@ def __pos__(self) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__pos__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __pow__(self, other: int | float | array, /) -> array: """ @@ -613,8 +612,7 @@ def __pow__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__pow__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __rshift__(self, other: int | array, /) -> array: """ @@ -624,8 +622,7 @@ def __rshift__(self, other: int | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__rshift__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __setitem__( self, key: SetSliceKey, value: int | float | bool | array, / @@ -636,8 +633,7 @@ def __setitem__( https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__setitem__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __sub__(self, other: int | float | array, /) -> array: """ @@ -647,8 +643,7 @@ def __sub__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__sub__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __truediv__(self, other: int | float | array, /) -> array: """ @@ -658,8 +653,7 @@ def __truediv__(self, other: int | float | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__truediv__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def __xor__(self, other: int | bool | array, /) -> array: """ @@ -669,8 +663,7 @@ def __xor__(self, other: int | bool | array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__xor__.html """ - msg = "not implemented yet, but will be" - raise RuntimeError(msg) + assert False, "TODO" def to_device(self, device: Device, /, *, stream: None | int | Any = None) -> array: """ @@ -688,12 +681,12 @@ def to_device(self, device: Device, /, *, stream: None | int | Any = None) -> ar """ if isinstance(self._impl, ak.Array) and device != ak.backend(self._impl): - assert stream is None, "FIXME: use CuPy stream" + assert stream is None, "TODO" impl = ak.to_backend(self._impl, device) elif isinstance(self._impl, np.ndarray): if device == "cuda": - assert stream is None, "FIXME: use CuPy stream" + assert stream is None, "TODO" cp = _import.cupy() impl = cp.array(self._impl.item()) else: diff --git a/src/ragged/common/_search.py b/src/ragged/common/_search.py new file mode 100644 index 0000000..41e7a90 --- /dev/null +++ b/src/ragged/common/_search.py @@ -0,0 +1,118 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/searching_functions.html +""" + +from __future__ import annotations + +from ._obj import array + + +def argmax(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> array: + """ + Returns the indices of the maximum values along a specified axis. + + When the maximum value occurs multiple times, only the indices + corresponding to the first occurrence are returned. + + Args: + x: Input array. + axis: Axis along which to search. If `None`, the function returns the + index of the maximum value of the flattened array. + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If `axis` is `None`, a zero-dimensional array containing the index of + the first occurrence of the maximum value; otherwise, a + non-zero-dimensional array containing the indices of the maximum + values. The returned array has data type `np.int64`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.argmax.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert keepdims, "TODO" + assert False, "TODO" + + +def argmin(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> array: + """ + Returns the indices of the minimum values along a specified axis. + + When the minimum value occurs multiple times, only the indices + corresponding to the first occurrence are returned. + + Args: + x: Input array. + axis: Axis along which to search. If `None`, the function returns the + index of the minimum value of the flattened array. + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If `axis` is `None`, a zero-dimensional array containing the index of + the first occurrence of the minimum value; otherwise, a + non-zero-dimensional array containing the indices of the minimum + values. The returned array has data type `np.int64`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.argmin.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert keepdims, "TODO" + assert False, "TODO" + + +def nonzero(x: array, /) -> tuple[array, ...]: + """ + Returns the indices of the array elements which are non-zero. + + Args: + x: Input array. Must have a positive rank. If `x` is zero-dimensional, + the function raises an exception. + + Returns: + A tuple of `k` arrays, one for each dimension of `x` and each of size + `n` (where `n` is the total number of non-zero elements), containing + the indices of the non-zero elements in that dimension. The indices + are returned in row-major, C-style order. The returned array has data + type `np.int64`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.nonzero.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def where(condition: array, x1: array, x2: array, /) -> array: + """ + Returns elements chosen from `x1` or `x2` depending on `condition`. + + Args: + condition: When `True`, yield `x1_i`; otherwise, yield `x2_i`. Must be + broadcastable with `x1` and `x2`. + x1: First input array. Must be broadcastable with `condition` and `x2`. + x2: Second input array. Must be broadcastable with `condition` and + `x1`. + + Returns: + An array with elements from `x1` where condition is `True`, and + elements from `x2` elsewhere. The returned array has a data type + determined by type promotion rules with the arrays `x1` and `x2`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html + """ + + assert condition, "TODO" + assert x1, "TODO" + assert x2, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_set.py b/src/ragged/common/_set.py new file mode 100644 index 0000000..a2039f5 --- /dev/null +++ b/src/ragged/common/_set.py @@ -0,0 +1,133 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/set_functions.html +""" + +from __future__ import annotations + +from collections import namedtuple + +from ._obj import array + +unique_all_result = namedtuple( # pylint: disable=C0103 + "unique_all_result", ["values", "indices", "inverse_indices", "counts"] +) + + +def unique_all(x: array, /) -> tuple[array, array, array, array]: + """ + Returns the unique elements of an input array `x`, the first occurring + indices for each unique element in `x`, the indices from the set of unique + elements that reconstruct `x`, and the corresponding `counts` for each + unique element in `x`. + + Args: + x: Input array. If `x` has more than one dimension, the function + flattens `x` and returns the unique elements of the flattened + array. + + Returns: + A namedtuple `(values, indices, inverse_indices, counts)` whose + + - first element has the field name `values` and must be an array + containing the unique elements of `x`. The array has the same data + type as `x`. + - second element has the field name `indices` and is an array containing + the indices (first occurrences) of `x` that result in values. The + array has the same shape as `values` and has the default array index + data type. + - third element has the field name `inverse_indices` and is an array + containing the indices of values that reconstruct `x`. The array has + the same shape as `x` and has data type `np.int64`. + - fourth element has the field name `counts` and is an array containing + the number of times each unique element occurs in `x`. The returned + array has same shape as `values` and has data type `np.int64`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_all.html + """ + + assert x, "TODO" + assert False, "TODO" + + +unique_counts_result = namedtuple( # pylint: disable=C0103 + "unique_counts_result", ["values", "counts"] +) + + +def unique_counts(x: array, /) -> tuple[array, array]: + """ + Returns the unique elements of an input array `x` and the corresponding + counts for each unique element in `x`. + + Args: + x: Input array. If `x` has more than one dimension, the function + flattens `x` and returns the unique elements of the flattened + array. + + Returns: + A namedtuple `(values, counts)` whose + + - first element has the field name `values` and is an array containing + the unique elements of `x`. The array has the same data type as `x`. + - second element has the field name `counts` and is an array containing + the number of times each unique element occurs in `x`. The returned + array has same shape as `values` and has data type `np.int64`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_counts.html + """ + + assert x, "TODO" + assert False, "TODO" + + +unique_inverse_result = namedtuple( # pylint: disable=C0103 + "unique_inverse_result", ["values", "inverse_indices"] +) + + +def unique_inverse(x: array, /) -> tuple[array, array]: + """ + Returns the unique elements of an input array `x` and the indices from the + set of unique elements that reconstruct `x`. + + Args: + x: Input array. If `x` has more than one dimension, the function + flattens `x` and returns the unique elements of the flattened + array. + + Returns: + A namedtuple `(values, inverse_indices)` whose + + - first element has the field name `values` and is an array containing + the unique elements of `x`. The array has the same data type as `x`. + - second element has the field name `inverse_indices` and is an array + containing the indices of `values` that reconstruct `x`. The array + has the same shape as `x` and data type `np.int64`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_inverse.html + """ + + assert x, "TODO" + assert False, "TODO" + + +def unique_values(x: array, /) -> array: + """ + Returns the unique elements of an input array `x`. + + Args: + x: Input array. If `x` has more than one dimension, the function + flattens `x` and returns the unique elements of the flattened + array. + + Returns: + An array containing the set of unique elements in `x`. The returned + array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_values.html + """ + + assert x, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_sorting.py b/src/ragged/common/_sorting.py new file mode 100644 index 0000000..ba7eeec --- /dev/null +++ b/src/ragged/common/_sorting.py @@ -0,0 +1,73 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/sorting_functions.html +""" + +from __future__ import annotations + +from ._obj import array + + +def argsort( + x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> array: + """ + Returns the indices that sort an array `x` along a specified axis. + + Args: + x: Input array. + axis: Axis along which to sort. If set to -1, the function sorts along + the last axis. + descending: Sort order. If `True`, the returned indices sort `x` in + descending order (by value). If `False`, the returned indices sort + `x` in ascending order (by value). + stable: Sort stability. If `True`, the returned indices will maintain + the relative order of `x` values which compare as equal. If + `False`, the returned indices may or may not maintain the relative + order of `x` values which compare as equal. + + Returns: + An array of indices. The returned array has the same shape as `x`. + The returned array has data type `np.int64`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.argsort.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert descending, "TODO" + assert stable, "TODO" + assert False, "TODO" + + +def sort( + x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True +) -> array: + """ + Returns a sorted copy of an input array `x`. + + Args: + x: Input array. + axis: Axis along which to sort. If set to -1, the function sorts along + the last axis. + descending: Sort order. If `True`, the array is sorted in descending + order (by value). If `False`, the array is sorted in ascending + order (by value). + stable: Sort stability. If `True`, the returned array will maintain the + relative order of `x` values which compare as equal. If `False`, + the returned array may or may not maintain the relative order of + `x` values which compare as equal. + + Returns: + A sorted array. The returned array has the same data type and shape as + `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.sort.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert descending, "TODO" + assert stable, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_statistical.py b/src/ragged/common/_statistical.py new file mode 100644 index 0000000..410b955 --- /dev/null +++ b/src/ragged/common/_statistical.py @@ -0,0 +1,319 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/statistical_functions.html +""" + +from __future__ import annotations + +from ._obj import array +from ._typing import Dtype + + +def max( # pylint: disable=W0622 + x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False +) -> array: + """ + Calculates the maximum value of the input array `x`. + + Args: + x: Input array. + axis: Axis or axes along which maximum values are computed. By default, + the maximum value is computed over the entire array. If a tuple of + integers, maximum values must be computed over multiple axes. + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If the maximum value was computed over the entire array, a + zero-dimensional array containing the maximum value; otherwise, a + non-zero-dimensional array containing the maximum values. The returned + array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.max.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert keepdims, "TODO" + assert False, "TODO" + + +def mean( + x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False +) -> array: + """ + Calculates the arithmetic mean of the input array `x`. + + Args: + x: Input array. + axis: Axis or axes along which arithmetic means are computed. By + default, the mean is computed over the entire array. If a tuple of + integers, arithmetic means are computed over multiple axes. + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If the arithmetic mean was computed over the entire array, a + zero-dimensional array containing the arithmetic mean; otherwise, a + non-zero-dimensional array containing the arithmetic means. The + returned array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.mean.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert keepdims, "TODO" + assert False, "TODO" + + +def min( # pylint: disable=W0622 + x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False +) -> array: + """ + Calculates the minimum value of the input array `x`. + + Args: + x: Input array. + axis: Axis or axes along which minimum values are computed. By default, + the minimum value are computed over the entire array. If a tuple of + integers, minimum values are computed over multiple axes. + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If the minimum value was computed over the entire array, a + zero-dimensional array containing the minimum value; otherwise, a + non-zero-dimensional array containing the minimum values. The returned + array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.min.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert keepdims, "TODO" + assert False, "TODO" + + +def prod( + x: array, + /, + *, + axis: None | int | tuple[int, ...] = None, + dtype: None | Dtype = None, + keepdims: bool = False, +) -> array: + """ + Calculates the product of input array `x` elements. + + Args: + x: Input array. + axis: Axis or axes along which products are computed. By default, the + product is computed over the entire array. If a tuple of integers, + products are computed over multiple axes. + dtype: Data type of the returned array. If `None`, + + - if the default data type corresponding to the data type "kind" + (integer, real-valued floating-point, or complex floating-point) + of `x` has a smaller range of values than the data type of `x` + (e.g., `x` has data type `int64` and the default data type is + `int32`, or `x` has data type `uint64` and the default data type + is `int64`), the returned array has the same data type as `x`. + - if `x` has a real-valued floating-point data type, the returned + array has the default real-valued floating-point data type. + - if `x` has a complex floating-point data type, the returned array + has data type `np.complex128`. + - if `x` has a signed integer data type (e.g., `int16`), the + returned array has data type `np.int64`. + - if `x` has an unsigned integer data type (e.g., `uint16`), the + returned array has data type `np.uint64`. + + If the data type (either specified or resolved) differs from the + data type of `x`, the input array will be cast to the specified + data type before computing the product. + + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If the product was computed over the entire array, a zero-dimensional + array containing the product; otherwise, a non-zero-dimensional array + containing the products. The returned array has a data type as + described by the `dtype` parameter above. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.prod.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert dtype, "TODO" + assert keepdims, "TODO" + assert False, "TODO" + + +def std( + x: array, + /, + *, + axis: None | int | tuple[int, ...] = None, + correction: None | int | float = 0.0, + keepdims: bool = False, +) -> array: + """ + Calculates the standard deviation of the input array `x`. + + Args: + x: Input array. + axis: Axis or axes along which standard deviations are computed. By + default, the standard deviation is computed over the entire array. + If a tuple of integers, standard deviations is computed over + multiple axes. + correction: Degrees of freedom adjustment. Setting this parameter to a + value other than 0 has the effect of adjusting the divisor during + the calculation of the standard deviation according to `N - c` + where `N` corresponds to the total number of elements over which + the standard deviation is computed and `c` corresponds to the + provided degrees of freedom adjustment. When computing the standard + deviation of a population, setting this parameter to 0 is the + standard choice (i.e., the provided array contains data + constituting an entire population). When computing the corrected + sample standard deviation, setting this parameter to 1 is the + standard choice (i.e., the provided array contains data sampled + from a larger population; this is commonly referred to as Bessel's + correction). + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If the standard deviation was computed over the entire array, a + zero-dimensional array containing the standard deviation; otherwise, a + non-zero-dimensional array containing the standard deviations. + The returned array has the same data type as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert correction, "TODO" + assert keepdims, "TODO" + assert False, "TODO" + + +def sum( # pylint: disable=W0622 + x: array, + /, + *, + axis: None | int | tuple[int, ...] = None, + dtype: None | Dtype = None, + keepdims: bool = False, +) -> array: + """ + Calculates the sum of the input array `x`. + + Args: + x: Input array. + axis: Axis or axes along which sums are computed. By default, the sum + is computed over the entire array. If a tuple of integers, sums + are computed over multiple axes. + dtype: Data type of the returned array. If `None`, + + - if the default data type corresponding to the data type "kind" + (integer, real-valued floating-point, or complex floating-point) + of `x` has a smaller range of values than the data type of `x` + (e.g., `x` has data type `int64` and the default data type is + `int32`, or `x` has data type `uint64` and the default data type + is `int64`), the returned array has the same data type as `x`. + - if `x` has a real-valued floating-point data type, the returned + array has the default real-valued floating-point data type. + - if `x` has a complex floating-point data type, the returned array + has data type `np.complex128`. + - if `x` has a signed integer data type (e.g., `int16`), the + returned array has data type `np.int64`. + - if `x` has an unsigned integer data type (e.g., `uint16`), the + returned array has data type `np.uint64`. + + If the data type (either specified or resolved) differs from the + data type of `x`, the input array is cast to the specified data + type before computing the sum. + + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If the sum was computed over the entire array, a zero-dimensional array + containing the sum; otherwise, an array containing the sums. The + returned array must have a data type as described by the `dtype` + parameter above. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert dtype, "TODO" + assert keepdims, "TODO" + assert False, "TODO" + + +def var( + x: array, + /, + *, + axis: None | int | tuple[int, ...] = None, + correction: None | int | float = 0.0, + keepdims: bool = False, +) -> array: + """ + Calculates the variance of the input array `x`. + + Args: + x: Input array. + axis: Axis or axes along which variances are computed. By default, the + variance is computed over the entire array. If a tuple of integers, + variances are computed over multiple axes. + correction: Degrees of freedom adjustment. Setting this parameter to a + value other than 0 has the effect of adjusting the divisor during + the calculation of the variance according to `N - c` where `N` + corresponds to the total number of elements over which the variance + is computed and `c` corresponds to the provided degrees of freedom + adjustment. When computing the variance of a population, setting + this parameter to 0 is the standard choice (i.e., the provided + array contains data constituting an entire population). When + computing the unbiased sample variance, setting this parameter to 1 + is the standard choice (i.e., the provided array contains data + sampled from a larger population; this is commonly referred to as + Bessel's correction). + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If the variance was computed over the entire array, a zero-dimensional + array containing the variance; otherwise, a non-zero-dimensional array + containing the variances. The returned array has the same data type as + `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert correction, "TODO" + assert keepdims, "TODO" + assert False, "TODO" diff --git a/src/ragged/common/_typing.py b/src/ragged/common/_typing.py index 72bbca8..8567153 100644 --- a/src/ragged/common/_typing.py +++ b/src/ragged/common/_typing.py @@ -1,8 +1,13 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE +""" +Borrows liberally from https://github.com/numpy/numpy/blob/main/numpy/array_api/_typing.py +""" + from __future__ import annotations import numbers +import sys from typing import Any, Literal, Optional, Protocol, TypeVar, Union import numpy as np @@ -10,6 +15,14 @@ T_co = TypeVar("T_co", covariant=True) +if sys.version_info >= (3, 12): + from collections.abc import ( # pylint: disable=W0611 + Buffer as SupportsBufferProtocol, + ) +else: + SupportsBufferProtocol = Any + + # not actually checked because of https://github.com/python/typing/discussions/1145 class NestedSequence(Protocol[T_co]): def __getitem__(self, key: int, /) -> T_co | NestedSequence[T_co]: diff --git a/src/ragged/common/_utility.py b/src/ragged/common/_utility.py new file mode 100644 index 0000000..0e74e1c --- /dev/null +++ b/src/ragged/common/_utility.py @@ -0,0 +1,89 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/latest/API_specification/utility_functions.html +""" + +from __future__ import annotations + +from ._obj import array + + +def all( # pylint: disable=W0622 + x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False +) -> array: + """ + Tests whether all input array elements evaluate to `True` along a specified + axis. + + Args: + x: Input array. + axis: Axis or axes along which to perform a logical AND reduction. By + default, a logical AND reduction is performed over the entire + array. If a tuple of integers, logical AND reductions are performed + over multiple axes. A valid `axis` must be an integer on the + interval `[-N, N)`, where `N` is the rank (number of dimensions) of + `x`. If an `axis` is specified as a negative integer, the function + must determine the axis along which to perform a reduction by + counting backward from the last dimension (where -1 refers to the + last dimension). If provided an invalid `axis`, the function raises + an exception. + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If a logical AND reduction was performed over the entire array, the + returned array is a zero-dimensional array containing the test result; + otherwise, the returned array is a non-zero-dimensional array + containing the test results. The returned array has data type + `np.bool_`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.all.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert keepdims, "TODO" + assert False, "TODO" + + +def any( # pylint: disable=W0622 + x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False +) -> array: + """ + Tests whether any input array element evaluates to True along a specified + axis. + + Args: + x: Input array. + axis: Axis or axes along which to perform a logical OR reduction. By + default, a logical OR reduction is performed over the entire array. + If a tuple of integers, logical OR reductions aer performed over + multiple axes. A valid `axis` must be an integer on the interval + `[-N, N)`, where `N` is the rank (number of dimensions) of `x`. If + an `axis` is specified as a negative integer, the function + determines the axis along which to perform a reduction by counting + backward from the last dimension (where -1 refers to the last + dimension). If provided an invalid `axis`, the function raises an + exception. + keepdims: If `True`, the reduced axes (dimensions) aer included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If a logical OR reduction was performed over the entire array, the + returned array is a zero-dimensional array containing the test result; + otherwise, the returned array is a non-zero-dimensional array + containing the test results. The returned array has data type + `np.bool_`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.any.html + """ + + assert x, "TODO" + assert axis, "TODO" + assert keepdims, "TODO" + assert False, "TODO" diff --git a/src/ragged/v202212/__init__.py b/src/ragged/v202212/__init__.py index b49396c..b193f44 100644 --- a/src/ragged/v202212/__init__.py +++ b/src/ragged/v202212/__init__.py @@ -11,6 +11,275 @@ from __future__ import annotations +__array_api_version__ = "2022.12" + +from ._creation import ( + arange, + asarray, + empty, + empty_like, + eye, + from_dlpack, + full, + full_like, + linspace, + meshgrid, + ones, + ones_like, + tril, + triu, + zeros, + zeros_like, +) +from ._datatype import ( + astype, + can_cast, + finfo, + iinfo, + isdtype, + result_type, +) +from ._elementwise import ( # pylint: disable=W0622 + abs, + acos, + acosh, + add, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_and, + bitwise_invert, + bitwise_left_shift, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + ceil, + conj, + cos, + cosh, + divide, + equal, + exp, + expm1, + floor, + floor_divide, + greater, + greater_equal, + imag, + isfinite, + isinf, + isnan, + less, + less_equal, + log, + log1p, + log2, + log10, + logaddexp, + logical_and, + logical_not, + logical_or, + logical_xor, + multiply, + negative, + not_equal, + positive, + pow, + real, + remainder, + round, + sign, + sin, + sinh, + sqrt, + square, + subtract, + tan, + tanh, + trunc, +) +from ._indexing import ( + take, +) +from ._linalg import ( + matmul, + matrix_transpose, + tensordot, + vecdot, +) +from ._manipulation import ( + broadcast_arrays, + broadcast_to, + concat, + expand_dims, + flip, + permute_dims, + reshape, + roll, + squeeze, + stack, +) from ._obj import array +from ._search import ( + argmax, + argmin, + nonzero, + where, +) +from ._set import ( + unique_all, + unique_counts, + unique_inverse, + unique_values, +) +from ._sorting import ( + argsort, + sort, +) +from ._statistical import ( # pylint: disable=W0622 + max, + mean, + min, + prod, + std, + sum, + var, +) +from ._utility import ( # pylint: disable=W0622 + all, + any, +) -__all__ = ["array"] +__all__ = [ + "__array_api_version__", + # _creation + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # _datatype + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # _elementwise + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "conj", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "multiply", + "negative", + "not_equal", + "positive", + "pow", + "real", + "remainder", + "round", + "sign", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # _indexing + "take", + # _linalg + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # _manipulation + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "permute_dims", + "reshape", + "roll", + "squeeze", + "stack", + # _obj + "array", + # _search + "argmax", + "argmin", + "nonzero", + "where", + # _set + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + # _sorting + "argsort", + "sort", + # _statistical + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # _utility + "all", + "any", +] diff --git a/src/ragged/v202212/_creation.py b/src/ragged/v202212/_creation.py new file mode 100644 index 0000000..2148440 --- /dev/null +++ b/src/ragged/v202212/_creation.py @@ -0,0 +1,45 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/creation_functions.html +""" + +from __future__ import annotations + +from ..common._creation import ( + arange, + asarray, + empty, + empty_like, + eye, + from_dlpack, + full, + full_like, + linspace, + meshgrid, + ones, + ones_like, + tril, + triu, + zeros, + zeros_like, +) + +__all__ = [ + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", +] diff --git a/src/ragged/v202212/_datatype.py b/src/ragged/v202212/_datatype.py new file mode 100644 index 0000000..35d9923 --- /dev/null +++ b/src/ragged/v202212/_datatype.py @@ -0,0 +1,25 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/data_type_functions.html +""" + +from __future__ import annotations + +from ..common._datatype import ( + astype, + can_cast, + finfo, + iinfo, + isdtype, + result_type, +) + +__all__ = [ + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", +] diff --git a/src/ragged/v202212/_elementwise.py b/src/ragged/v202212/_elementwise.py new file mode 100644 index 0000000..6960051 --- /dev/null +++ b/src/ragged/v202212/_elementwise.py @@ -0,0 +1,131 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/elementwise_functions.html +""" + +from __future__ import annotations + +from ..common._elementwise import ( # pylint: disable=W0622 + abs, + acos, + acosh, + add, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_and, + bitwise_invert, + bitwise_left_shift, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + ceil, + conj, + cos, + cosh, + divide, + equal, + exp, + expm1, + floor, + floor_divide, + greater, + greater_equal, + imag, + isfinite, + isinf, + isnan, + less, + less_equal, + log, + log1p, + log2, + log10, + logaddexp, + logical_and, + logical_not, + logical_or, + logical_xor, + multiply, + negative, + not_equal, + positive, + pow, + real, + remainder, + round, + sign, + sin, + sinh, + sqrt, + square, + subtract, + tan, + tanh, + trunc, +) + +__all__ = [ + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "conj", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "multiply", + "negative", + "not_equal", + "positive", + "pow", + "real", + "remainder", + "round", + "sign", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", +] diff --git a/src/ragged/v202212/_indexing.py b/src/ragged/v202212/_indexing.py new file mode 100644 index 0000000..f1553c2 --- /dev/null +++ b/src/ragged/v202212/_indexing.py @@ -0,0 +1,11 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/indexing_functions.html +""" + +from __future__ import annotations + +from ..common._indexing import take + +__all__ = ["take"] diff --git a/src/ragged/v202212/_linalg.py b/src/ragged/v202212/_linalg.py new file mode 100644 index 0000000..5dbe366 --- /dev/null +++ b/src/ragged/v202212/_linalg.py @@ -0,0 +1,11 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/linear_algebra_functions.html +""" + +from __future__ import annotations + +from ..common._linalg import matmul, matrix_transpose, tensordot, vecdot + +__all__ = ["matmul", "matrix_transpose", "tensordot", "vecdot"] diff --git a/src/ragged/v202212/_manipulation.py b/src/ragged/v202212/_manipulation.py new file mode 100644 index 0000000..17caf9a --- /dev/null +++ b/src/ragged/v202212/_manipulation.py @@ -0,0 +1,33 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/manipulation_functions.html +""" + +from __future__ import annotations + +from ..common._manipulation import ( + broadcast_arrays, + broadcast_to, + concat, + expand_dims, + flip, + permute_dims, + reshape, + roll, + squeeze, + stack, +) + +__all__ = [ + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "permute_dims", + "reshape", + "roll", + "squeeze", + "stack", +] diff --git a/src/ragged/v202212/_obj.py b/src/ragged/v202212/_obj.py index d9b59ac..3a4e2b2 100644 --- a/src/ragged/v202212/_obj.py +++ b/src/ragged/v202212/_obj.py @@ -1,5 +1,9 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE +""" +https://data-apis.org/array-api/2022.12/API_specification/array_object.html +""" + from __future__ import annotations from ..common._obj import array as common_array diff --git a/src/ragged/v202212/_search.py b/src/ragged/v202212/_search.py new file mode 100644 index 0000000..3d2d700 --- /dev/null +++ b/src/ragged/v202212/_search.py @@ -0,0 +1,16 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/searching_functions.html +""" + +from __future__ import annotations + +from ..common._search import argmax, argmin, nonzero, where + +__all__ = [ + "argmax", + "argmin", + "nonzero", + "where", +] diff --git a/src/ragged/v202212/_set.py b/src/ragged/v202212/_set.py new file mode 100644 index 0000000..bb1ea0b --- /dev/null +++ b/src/ragged/v202212/_set.py @@ -0,0 +1,11 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/set_functions.html +""" + +from __future__ import annotations + +from ..common._set import unique_all, unique_counts, unique_inverse, unique_values + +__all__ = ["unique_all", "unique_counts", "unique_inverse", "unique_values"] diff --git a/src/ragged/v202212/_sorting.py b/src/ragged/v202212/_sorting.py new file mode 100644 index 0000000..b6702d5 --- /dev/null +++ b/src/ragged/v202212/_sorting.py @@ -0,0 +1,11 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/sorting_functions.html +""" + +from __future__ import annotations + +from ..common._sorting import argsort, sort + +__all__ = ["argsort", "sort"] diff --git a/src/ragged/v202212/_statistical.py b/src/ragged/v202212/_statistical.py new file mode 100644 index 0000000..1c2a625 --- /dev/null +++ b/src/ragged/v202212/_statistical.py @@ -0,0 +1,19 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/statistical_functions.html +""" + +from __future__ import annotations + +from ..common._statistical import ( # pylint: disable=W0622 + max, + mean, + min, + prod, + std, + sum, + var, +) + +__all__ = ["max", "mean", "min", "prod", "std", "sum", "var"] diff --git a/src/ragged/v202212/_utility.py b/src/ragged/v202212/_utility.py new file mode 100644 index 0000000..1a120cb --- /dev/null +++ b/src/ragged/v202212/_utility.py @@ -0,0 +1,11 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +""" +https://data-apis.org/array-api/2022.12/API_specification/utility_functions.html +""" + +from __future__ import annotations + +from ..common._utility import all, any # pylint: disable=W0622 + +__all__ = ["all", "any"] diff --git a/tests/test_0001_initial_array_object.py b/tests/test_0001_initial_array_object.py index d2138dd..2cdd1ef 100644 --- a/tests/test_0001_initial_array_object.py +++ b/tests/test_0001_initial_array_object.py @@ -6,5 +6,123 @@ def test(): + assert ragged.__array_api_version__ == "2022.12" + a = ragged.array([[1, 2], [3]]) assert a is not None + + assert ragged.arange is not None + assert ragged.asarray is not None + assert ragged.empty is not None + assert ragged.empty_like is not None + assert ragged.eye is not None + assert ragged.from_dlpack is not None + assert ragged.full is not None + assert ragged.full_like is not None + assert ragged.linspace is not None + assert ragged.meshgrid is not None + assert ragged.ones is not None + assert ragged.ones_like is not None + assert ragged.tril is not None + assert ragged.triu is not None + assert ragged.zeros is not None + assert ragged.zeros_like is not None + assert ragged.astype is not None + assert ragged.can_cast is not None + assert ragged.finfo is not None + assert ragged.iinfo is not None + assert ragged.isdtype is not None + assert ragged.result_type is not None + assert ragged.abs is not None + assert ragged.acos is not None + assert ragged.acosh is not None + assert ragged.add is not None + assert ragged.asin is not None + assert ragged.asinh is not None + assert ragged.atan is not None + assert ragged.atan2 is not None + assert ragged.atanh is not None + assert ragged.bitwise_and is not None + assert ragged.bitwise_invert is not None + assert ragged.bitwise_left_shift is not None + assert ragged.bitwise_or is not None + assert ragged.bitwise_right_shift is not None + assert ragged.bitwise_xor is not None + assert ragged.ceil is not None + assert ragged.conj is not None + assert ragged.cos is not None + assert ragged.cosh is not None + assert ragged.divide is not None + assert ragged.equal is not None + assert ragged.exp is not None + assert ragged.expm1 is not None + assert ragged.floor is not None + assert ragged.floor_divide is not None + assert ragged.greater is not None + assert ragged.greater_equal is not None + assert ragged.imag is not None + assert ragged.isfinite is not None + assert ragged.isinf is not None + assert ragged.isnan is not None + assert ragged.less is not None + assert ragged.less_equal is not None + assert ragged.log is not None + assert ragged.log1p is not None + assert ragged.log2 is not None + assert ragged.log10 is not None + assert ragged.logaddexp is not None + assert ragged.logical_and is not None + assert ragged.logical_not is not None + assert ragged.logical_or is not None + assert ragged.logical_xor is not None + assert ragged.multiply is not None + assert ragged.negative is not None + assert ragged.not_equal is not None + assert ragged.positive is not None + assert ragged.pow is not None + assert ragged.real is not None + assert ragged.remainder is not None + assert ragged.round is not None + assert ragged.sign is not None + assert ragged.sin is not None + assert ragged.sinh is not None + assert ragged.square is not None + assert ragged.sqrt is not None + assert ragged.subtract is not None + assert ragged.tan is not None + assert ragged.tanh is not None + assert ragged.trunc is not None + assert ragged.take is not None + assert ragged.matmul is not None + assert ragged.matrix_transpose is not None + assert ragged.tensordot is not None + assert ragged.vecdot is not None + assert ragged.broadcast_arrays is not None + assert ragged.broadcast_to is not None + assert ragged.concat is not None + assert ragged.expand_dims is not None + assert ragged.flip is not None + assert ragged.permute_dims is not None + assert ragged.reshape is not None + assert ragged.roll is not None + assert ragged.squeeze is not None + assert ragged.stack is not None + assert ragged.argmax is not None + assert ragged.argmin is not None + assert ragged.nonzero is not None + assert ragged.where is not None + assert ragged.unique_all is not None + assert ragged.unique_counts is not None + assert ragged.unique_inverse is not None + assert ragged.unique_values is not None + assert ragged.argsort is not None + assert ragged.sort is not None + assert ragged.max is not None + assert ragged.mean is not None + assert ragged.min is not None + assert ragged.prod is not None + assert ragged.std is not None + assert ragged.sum is not None + assert ragged.var is not None + assert ragged.all is not None + assert ragged.any is not None