diff --git a/src/ragged/_spec_searching_functions.py b/src/ragged/_spec_searching_functions.py index d5e717c..10399ae 100644 --- a/src/ragged/_spec_searching_functions.py +++ b/src/ragged/_spec_searching_functions.py @@ -6,7 +6,19 @@ from __future__ import annotations -from ._spec_array_object import array +import awkward as ak +import numpy as np + +from ._spec_array_object import _box, _unbox, array + + +def _remove_optiontype(x: ak.contents.Content) -> ak.contents.Content: + if x.is_list: + return x.copy(content=_remove_optiontype(x.content)) + elif x.is_option: + return x.content + else: + return x def argmax(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> array: @@ -34,10 +46,21 @@ def argmax(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> a https://data-apis.org/array-api/latest/API_specification/generated/array_api.argmax.html """ - assert x, "TODO" - assert axis, "TODO" - assert keepdims, "TODO" - assert False, "TODO 124" + out = np.argmax(*_unbox(x), axis=axis, keepdims=keepdims) + + if out is None: + msg = "cannot compute argmax of an array with no data" + raise ValueError(msg) + + if isinstance(out, ak.Array): + if ak.any(ak.is_none(out, axis=-1)): + msg = f"cannot compute argmax at axis={axis} because some lists at this depth have zero length" + raise ValueError(msg) + out = ak.Array( + _remove_optiontype(out.layout), behavior=out.behavior, attrs=out.attrs + ) + + return _box(type(x), out) def argmin(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> array: @@ -65,10 +88,21 @@ def argmin(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> a https://data-apis.org/array-api/latest/API_specification/generated/array_api.argmin.html """ - assert x, "TODO" - assert axis, "TODO" - assert keepdims, "TODO" - assert False, "TODO 125" + out = np.argmin(*_unbox(x), axis=axis, keepdims=keepdims) + + if out is None: + msg = "cannot compute argmin of an array with no data" + raise ValueError(msg) 
+ + if isinstance(out, ak.Array): + if ak.any(ak.is_none(out, axis=-1)): + msg = f"cannot compute argmin at axis={axis} because some lists at this depth have zero length" + raise ValueError(msg) + out = ak.Array( + _remove_optiontype(out.layout), behavior=out.behavior, attrs=out.attrs + ) + + return _box(type(x), out) def nonzero(x: array, /) -> tuple[array, ...]: diff --git a/src/ragged/_spec_statistical_functions.py b/src/ragged/_spec_statistical_functions.py index d087411..1d20973 100644 --- a/src/ragged/_spec_statistical_functions.py +++ b/src/ragged/_spec_statistical_functions.py @@ -17,7 +17,7 @@ def _regularize_axis( axis: None | int | tuple[int, ...], ndim: int -) -> None | int | tuple[int, ...]: +) -> None | tuple[int, ...]: if axis is None: return axis elif isinstance(axis, numbers.Integral): @@ -96,13 +96,20 @@ def max( # pylint: disable=W0622 if isinstance(axis, tuple): (out,) = _unbox(x) for axis_item in axis[::-1]: - out = ak.max(out, axis=axis_item, keepdims=keepdims, mask_identity=False) + if isinstance(out, ak.Array): + out = ak.max( + out, axis=axis_item, keepdims=keepdims, mask_identity=False + ) + else: + out = np.max(out, axis=axis_item, keepdims=keepdims) return _box(type(x), out) else: - return _box( - type(x), - ak.max(*_unbox(x), axis=axis, keepdims=keepdims, mask_identity=False), - ) + (tmp,) = _unbox(x) + if isinstance(tmp, ak.Array): + out = ak.max(tmp, axis=axis, keepdims=keepdims, mask_identity=False) + else: + out = np.max(tmp, axis=axis, keepdims=keepdims) + return _box(type(x), out) def mean( @@ -133,13 +140,13 @@ def mean( axis = _regularize_axis(axis, x.ndim) if isinstance(axis, tuple): - sumwx = ak.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims) + sumwx = np.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims) sumw = ak.count(*_unbox(x), axis=axis[-1], keepdims=keepdims) for axis_item in axis[-2::-1]: - sumwx = ak.sum(sumwx, axis=axis_item, keepdims=keepdims) - sumw = ak.sum(sumw, axis=axis_item, keepdims=keepdims) + sumwx = 
np.sum(sumwx, axis=axis_item, keepdims=keepdims) + sumw = np.sum(sumw, axis=axis_item, keepdims=keepdims) else: - sumwx = ak.sum(*_unbox(x), axis=axis, keepdims=keepdims) + sumwx = np.sum(*_unbox(x), axis=axis, keepdims=keepdims) sumw = ak.count(*_unbox(x), axis=axis, keepdims=keepdims) with np.errstate(invalid="ignore", divide="ignore"): @@ -176,13 +183,20 @@ def min( # pylint: disable=W0622 if isinstance(axis, tuple): (out,) = _unbox(x) for axis_item in axis[::-1]: - out = ak.min(out, axis=axis_item, keepdims=keepdims, mask_identity=False) + if isinstance(out, ak.Array): + out = ak.min( + out, axis=axis_item, keepdims=keepdims, mask_identity=False + ) + else: + out = np.min(out, axis=axis_item, keepdims=keepdims) return _box(type(x), out) else: - return _box( - type(x), - ak.min(*_unbox(x), axis=axis, keepdims=keepdims, mask_identity=False), - ) + (tmp,) = _unbox(x) + if isinstance(tmp, ak.Array): + out = ak.min(tmp, axis=axis, keepdims=keepdims, mask_identity=False) + else: + out = np.min(tmp, axis=axis, keepdims=keepdims) + return _box(type(x), out) def prod( @@ -194,46 +208,46 @@ def prod( keepdims: bool = False, ) -> array: """ - Calculates the product of input array `x` elements. - - Args: - x: Input array. - axis: Axis or axes along which products are computed. By default, the - product is computed over the entire array. If a tuple of integers, - products are computed over multiple axes. - dtype: Data type of the returned array. If `None`, - - - if the default data type corresponding to the data type "kind" - (integer, real-valued floating-point, or complex floating-point) - of `x` has a smaller range of values than the data type of `x` - (e.g., `x` has data type `int64` and the default data type is - `int32`, or `x` has data type `uint64` and the default data type - is `int64`), the returned array has the same data type as `x`. - - if `x` has a real-valued floating-point data type, the returned - array has the default real-valued floating-point data type. 
- - if `x` has a complex floating-point data type, the returned array - has data type `np.complex128`. - - if `x` has a signed integer data type (e.g., `int16`), the - returned array has data type `np.int64`. - - if `x` has an unsigned integer data type (e.g., `uint16`), the - returned array has data type `np.uint64`. - - If the data type (either specified or resolved) differs from the - data type of `x`, the input array will be cast to the specified - data type before computing the product. - - keepdims: If `True`, the reduced axes (dimensions) are included in the - result as singleton dimensions, and, accordingly, the result is - broadcastable with the input array. Otherwise, if `False`, the - reduced axes (dimensions) are not included in the result. - - Returns: - If the product was computed over the entire array, a zero-dimensional - array containing the product; otherwise, a non-zero-dimensional array - containing the products. The returned array has a data type as - described by the `dtype` parameter above. - - https://data-apis.org/array-api/latest/API_specification/generated/array_api.prod.html + Calculates the product of input array `x` elements. + + Args: + x: Input array. + axis: Axis or axes along which products are computed. By default, the + product is computed over the entire array. If a tuple of integers, + products are computed over multiple axes. + dtype: Data type of the returned array. If `None`, + + - if the default data type corresponding to the data type "kind" + (integer, real-valued floating-point, or complex floating-point) + of `x` has a smaller range of values than the data type of `x` + (e.g., `x` has data type `int64` and the default data type is + `int32`, or `x` has data type `uint64` and the default data type + is `int64`), the returned array has the same data type as `x`. + - if `x` has a real-valued floating-point data type, the returned + array has the default real-valued floating-point data type. 
+ - if `x` has a complex floating-point data type, the returned array + has data type `np.complex128`. + - if `x` has a signed integer data type (e.g., `int16`), the + returned array has data type `np.int64`. + - if `x` has an unsigned integer data type (e.g., `uint16`), the + returned array has data type `np.uint64`. + + If the data type (either specified or resolved) differs from the + data type of `x`, the input array will be cast to the specified + data type before computing the product. + + keepdims: If `True`, the reduced axes (dimensions) are included in the + result as singleton dimensions, and, accordingly, the result is + broadcastable with the input array. Otherwise, if `False`, the + reduced axes (dimensions) are not included in the result. + + Returns: + If the product was computed over the entire array, a zero-dimensional + array containing the product; otherwise, a non-zero-dimensional array + containing the products. The returned array has a data type as + described by the `dtype` parameter above. 
+ + https://data-apis.org/array-api/latest/API_specification/generated/array_api.prod.html """ axis = _regularize_axis(axis, x.ndim) @@ -243,10 +257,10 @@ def prod( if isinstance(axis, tuple): (out,) = _unbox(arr) for axis_item in axis[::-1]: - out = ak.prod(out, axis=axis_item, keepdims=keepdims) + out = np.prod(out, axis=axis_item, keepdims=keepdims) return _box(type(x), out) else: - return _box(type(x), ak.prod(*_unbox(arr), axis=axis, keepdims=keepdims)) + return _box(type(x), np.prod(*_unbox(arr), axis=axis, keepdims=keepdims)) def std( @@ -357,10 +371,10 @@ def sum( # pylint: disable=W0622 if isinstance(axis, tuple): (out,) = _unbox(arr) for axis_item in axis[::-1]: - out = ak.sum(out, axis=axis_item, keepdims=keepdims) + out = np.sum(out, axis=axis_item, keepdims=keepdims) return _box(type(x), out) else: - return _box(type(x), ak.sum(*_unbox(arr), axis=axis, keepdims=keepdims)) + return _box(type(x), np.sum(*_unbox(arr), axis=axis, keepdims=keepdims)) def var( @@ -408,16 +422,16 @@ def var( axis = _regularize_axis(axis, x.ndim) if isinstance(axis, tuple): - sumwxx = ak.sum(np.square(*_unbox(x)), axis=axis[-1], keepdims=keepdims) - sumwx = ak.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims) + sumwxx = np.sum(np.square(*_unbox(x)), axis=axis[-1], keepdims=keepdims) + sumwx = np.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims) sumw = ak.count(*_unbox(x), axis=axis[-1], keepdims=keepdims) for axis_item in axis[-2::-1]: - sumwxx = ak.sum(sumwxx, axis=axis_item, keepdims=keepdims) - sumwx = ak.sum(sumwx, axis=axis_item, keepdims=keepdims) - sumw = ak.sum(sumw, axis=axis_item, keepdims=keepdims) + sumwxx = np.sum(sumwxx, axis=axis_item, keepdims=keepdims) + sumwx = np.sum(sumwx, axis=axis_item, keepdims=keepdims) + sumw = np.sum(sumw, axis=axis_item, keepdims=keepdims) else: - sumwxx = ak.sum(np.square(*_unbox(x)), axis=axis, keepdims=keepdims) - sumwx = ak.sum(*_unbox(x), axis=axis, keepdims=keepdims) + sumwxx = np.sum(np.square(*_unbox(x)), axis=axis, 
keepdims=keepdims) + sumwx = np.sum(*_unbox(x), axis=axis, keepdims=keepdims) sumw = ak.count(*_unbox(x), axis=axis, keepdims=keepdims) with np.errstate(invalid="ignore", divide="ignore"): diff --git a/src/ragged/_spec_utility_functions.py b/src/ragged/_spec_utility_functions.py index dacc778..05aed13 100644 --- a/src/ragged/_spec_utility_functions.py +++ b/src/ragged/_spec_utility_functions.py @@ -6,7 +6,7 @@ from __future__ import annotations -import awkward as ak +import numpy as np from ._spec_array_object import _box, _unbox, array from ._spec_statistical_functions import _regularize_axis @@ -51,10 +51,10 @@ def all( # pylint: disable=W0622 if isinstance(axis, tuple): (out,) = _unbox(x) for axis_item in axis[::-1]: - out = ak.all(out, axis=axis_item, keepdims=keepdims) + out = np.all(out, axis=axis_item, keepdims=keepdims) return _box(type(x), out) else: - return _box(type(x), ak.all(*_unbox(x), axis=axis, keepdims=keepdims)) + return _box(type(x), np.all(*_unbox(x), axis=axis, keepdims=keepdims)) def any( # pylint: disable=W0622 @@ -96,7 +96,7 @@ def any( # pylint: disable=W0622 if isinstance(axis, tuple): (out,) = _unbox(x) for axis_item in axis[::-1]: - out = ak.any(out, axis=axis_item, keepdims=keepdims) + out = np.any(out, axis=axis_item, keepdims=keepdims) return _box(type(x), out) else: - return _box(type(x), ak.any(*_unbox(x), axis=axis, keepdims=keepdims)) + return _box(type(x), np.any(*_unbox(x), axis=axis, keepdims=keepdims)) diff --git a/tests/test_spec_searching_functions.py b/tests/test_spec_searching_functions.py index 4410eb6..b9ecf0d 100644 --- a/tests/test_spec_searching_functions.py +++ b/tests/test_spec_searching_functions.py @@ -6,6 +6,8 @@ from __future__ import annotations +import pytest + import ragged @@ -14,3 +16,45 @@ def test_existence(): assert ragged.argmin is not None assert ragged.nonzero is not None assert ragged.where is not None + + +def test_argmax(): + data = ragged.array( + [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], 
[6.6, 7.7, 8.8, 9.9]]] + ) + assert ragged.argmax(data, axis=None).tolist() == 9 + assert ( + ragged.argmax(data, axis=0).tolist() + == ragged.argmax(data, axis=-3).tolist() + == [[1, 1, 0], [1], [0, 0, 0, 0]] + ) + assert ( + ragged.argmax(data, axis=1).tolist() # type: ignore[comparison-overlap] + == ragged.argmax(data, axis=-2).tolist() + == [[0, 0, 0], [], [2, 2, 2, 2]] + ) + with pytest.raises(ValueError, match=".*axis.*"): + ragged.argmax(data, axis=2) + with pytest.raises(ValueError, match=".*axis.*"): + ragged.argmax(data, axis=-1) + + +def test_argmin(): + data = ragged.array( + [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] + ) + assert ragged.argmin(data, axis=None).tolist() == 0 + assert ( + ragged.argmin(data, axis=0).tolist() + == ragged.argmin(data, axis=-3).tolist() + == [[0, 0, 0], [1], [0, 0, 0, 0]] + ) + assert ( + ragged.argmin(data, axis=1).tolist() # type: ignore[comparison-overlap] + == ragged.argmin(data, axis=-2).tolist() + == [[0, 0, 0], [], [0, 0, 2, 2]] + ) + with pytest.raises(ValueError, match=".*axis.*"): + ragged.argmin(data, axis=2) + with pytest.raises(ValueError, match=".*axis.*"): + ragged.argmin(data, axis=-1)