From a8585e6b1a95f988f3ac8bd464929ae655921177 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 2 Jan 2024 12:55:27 -0600 Subject: [PATCH] feat: reducer functions --- src/ragged/_spec_array_object.py | 5 ++ src/ragged/_spec_statistical_functions.py | 60 ++++++++++++++++++++--- tests/test_spec_statistical_functions.py | 56 +++++++++++++++++++++ 3 files changed, 115 insertions(+), 6 deletions(-) diff --git a/src/ragged/_spec_array_object.py b/src/ragged/_spec_array_object.py index 493e669..f8a8f16 100644 --- a/src/ragged/_spec_array_object.py +++ b/src/ragged/_spec_array_object.py @@ -249,6 +249,11 @@ def __repr__(self) -> str: ) return f"ragged.array([\n {prep}\n])" + def tolist( + self, + ) -> bool | int | float | complex | NestedSequence[bool | int | float | complex]: + return self._impl.tolist() # type: ignore[no-any-return,union-attr] + # Attributes: https://data-apis.org/array-api/latest/API_specification/array_object.html#attributes @property diff --git a/src/ragged/_spec_statistical_functions.py b/src/ragged/_spec_statistical_functions.py index 8a8dd0a..43ba2d1 100644 --- a/src/ragged/_spec_statistical_functions.py +++ b/src/ragged/_spec_statistical_functions.py @@ -6,10 +6,52 @@ from __future__ import annotations -from ._spec_array_object import array +import numbers + +import awkward as ak +import numpy as np + +from ._spec_array_object import _box, _unbox, array from ._typing import Dtype +def _regularize_axis( + axis: None | int | tuple[int, ...], ndim: int +) -> None | int | tuple[int, ...]: + if axis is None: + return axis + elif isinstance(axis, numbers.Integral): + out = axis + ndim if axis < 0 else axis # type: ignore[operator] + if not 0 <= out < ndim: + msg = f"axis {axis} is out of bounds for an array with {ndim} dimensions" + raise ak.errors.AxisError(msg) + return out # type: ignore[no-any-return] + else: + out = [] + for x in axis: # type: ignore[union-attr] + out.append(x + ndim if x < 0 else x) + if not 0 < out[-1] < ndim: + msg = f"axis {x} is out of bounds for an array with {ndim} dimensions" + return tuple(sorted(out)) + + +def _regularize_dtype(dtype: None | Dtype, array_dtype: Dtype) -> Dtype: + if dtype is None: + if array_dtype.kind in ("b", "i"): + return np.dtype(np.int64) + elif array_dtype.kind == "u": + return np.dtype(np.uint64) + elif array_dtype.kind == "f": + return np.dtype(np.float64) + elif array_dtype.kind == "c": + return np.dtype(np.complex128) + else: + msg = f"unrecognized dtype.kind: {array_dtype.kind}" + raise AssertionError(msg) + else: + return dtype + + def max( # pylint: disable=W0622 x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False ) -> array: @@ -263,11 +305,17 @@ def sum( # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html """ - assert x, "TODO" - assert axis, "TODO" - assert dtype, "TODO" - assert keepdims, "TODO" - assert False, "TODO 139" + axis = _regularize_axis(axis, x.ndim) + dtype = _regularize_dtype(dtype, x.dtype) + arr = _box(type(x), ak.values_astype(*_unbox(x), dtype)) if x.dtype == dtype else x + + if isinstance(axis, tuple): + (out,) = _unbox(arr) + for axis_item in axis[::-1]: + out = ak.sum(out, axis=axis_item, keepdims=keepdims) + return _box(type(x), out) + else: + return _box(type(x), ak.sum(*_unbox(arr), axis=axis, keepdims=keepdims)) def var( diff --git a/tests/test_spec_statistical_functions.py b/tests/test_spec_statistical_functions.py index 252a5f3..ff8003e 100644 --- a/tests/test_spec_statistical_functions.py +++ b/tests/test_spec_statistical_functions.py @@ -6,6 +6,8 @@ from __future__ import annotations +import pytest + import ragged @@ -17,3 +19,57 @@ def test_existence(): assert ragged.std is not None assert ragged.sum is not None assert ragged.var is not None + + +def test_sum(): + data = ragged.array( + [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]] + ) + assert ragged.sum(data, axis=None).tolist() == pytest.approx(49.5) + assert ( + ragged.sum(data, axis=0).tolist() # type: ignore[comparison-overlap] + == ragged.sum(data, axis=-3).tolist() + == [ + pytest.approx([3.3, 5.5, 2.2]), + pytest.approx([5.5]), + pytest.approx([6.6, 7.7, 8.8, 9.9]), + ] + ) + assert ( + ragged.sum(data, axis=1).tolist() # type: ignore[comparison-overlap] + == ragged.sum(data, axis=-2).tolist() + == [ + pytest.approx([0.0, 1.1, 2.2]), + pytest.approx([]), + pytest.approx([15.4, 12.1, 8.8, 9.9]), + ] + ) + assert ( + ragged.sum(data, axis=2).tolist() # type: ignore[comparison-overlap] + == ragged.sum(data, axis=-1).tolist() + == [ + pytest.approx([3.3, 0.0]), + pytest.approx([]), + pytest.approx([7.7, 5.5, 33.0]), + ] + ) + assert ( + ragged.sum(data, axis=(0, 1)).tolist() + == ragged.sum(data, axis=(1, 0)).tolist() + == pytest.approx([15.4, 13.2, 11.0, 9.9]) + ) + assert ( + ragged.sum(data, axis=(0, 2)).tolist() + == ragged.sum(data, axis=(2, 0)).tolist() + == pytest.approx([11.0, 5.5, 33.0]) + ) + assert ( + ragged.sum(data, axis=(1, 2)).tolist() + == ragged.sum(data, axis=(2, 1)).tolist() + == pytest.approx([3.3, 0.0, 46.2]) + ) + assert ( + ragged.sum(data, axis=(0, 1, 2)).tolist() + == ragged.sum(data, axis=(-1, 0, 1)).tolist() + == pytest.approx(49.5) + )