Skip to content

Commit

Permalink
feat: reducer functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Jan 2, 2024
1 parent 4f1670c commit a8585e6
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 6 deletions.
5 changes: 5 additions & 0 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ def __repr__(self) -> str:
)
return f"ragged.array([\n {prep}\n])"

def tolist(
    self,
) -> bool | int | float | complex | NestedSequence[bool | int | float | complex]:
    """
    Return this array's contents as plain Python objects.

    The conversion is delegated entirely to the ``tolist()`` method of the
    wrapped ``_impl`` object; nothing is post-processed here.

    Returns:
        Whatever ``self._impl.tolist()`` produces: per the annotation, a
        Python scalar or an arbitrarily nested sequence of scalars.
    """
    converted = self._impl.tolist()  # type: ignore[union-attr]
    return converted  # type: ignore[no-any-return]

# Attributes: https://data-apis.org/array-api/latest/API_specification/array_object.html#attributes

@property
Expand Down
60 changes: 54 additions & 6 deletions src/ragged/_spec_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,52 @@

from __future__ import annotations

from ._spec_array_object import array
import numbers

import awkward as ak
import numpy as np

from ._spec_array_object import _box, _unbox, array
from ._typing import Dtype


def _regularize_axis(
axis: None | int | tuple[int, ...], ndim: int
) -> None | int | tuple[int, ...]:
if axis is None:
return axis
elif isinstance(axis, numbers.Integral):
out = axis + ndim if axis < 0 else axis # type: ignore[operator]
if not 0 <= out < ndim:
msg = f"axis {axis} is out of bounds for an array with {ndim} dimensions"
raise ak.errors.AxisError(msg)
return out # type: ignore[no-any-return]
else:
out = []
for x in axis: # type: ignore[union-attr]
out.append(x + ndim if x < 0 else x)
if not 0 < out[-1] < ndim:
msg = f"axis {x} is out of bounds for an array with {ndim} dimensions"
return tuple(sorted(out))


def _regularize_dtype(dtype: None | Dtype, array_dtype: Dtype) -> Dtype:
if dtype is None:
if array_dtype.kind in ("b", "i"):
return np.dtype(np.int64)
elif array_dtype.kind == "u":
return np.dtype(np.uint64)
elif array_dtype.kind == "f":
return np.dtype(np.float64)
elif array_dtype.kind == "c":
return np.dtype(np.complex128)
else:
msg = f"unrecognized dtype.kind: {array_dtype.kind}"
raise AssertionError(msg)
else:
return dtype


def max( # pylint: disable=W0622
x: array, /, *, axis: None | int | tuple[int, ...] = None, keepdims: bool = False
) -> array:
Expand Down Expand Up @@ -263,11 +305,17 @@ def sum( # pylint: disable=W0622
https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html
"""

assert x, "TODO"
assert axis, "TODO"
assert dtype, "TODO"
assert keepdims, "TODO"
assert False, "TODO 139"
axis = _regularize_axis(axis, x.ndim)
dtype = _regularize_dtype(dtype, x.dtype)
arr = _box(type(x), ak.values_astype(*_unbox(x), dtype)) if x.dtype == dtype else x

if isinstance(axis, tuple):
(out,) = _unbox(arr)
for axis_item in axis[::-1]:
out = ak.sum(out, axis=axis_item, keepdims=keepdims)
return _box(type(x), out)
else:
return _box(type(x), ak.sum(*_unbox(arr), axis=axis, keepdims=keepdims))


def var(
Expand Down
56 changes: 56 additions & 0 deletions tests/test_spec_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

import pytest

import ragged


Expand All @@ -17,3 +19,57 @@ def test_existence():
assert ragged.std is not None
assert ragged.sum is not None
assert ragged.var is not None


def test_sum():
    """
    Exercise ``ragged.sum`` on a three-level ragged array.

    Covers the full reduction (``axis=None``), each single axis given both as
    a non-negative and as a negative number, and every multi-axis tuple given
    in two different spellings that must agree.
    """
    data = ragged.array(
        [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]]
    )
    approx = pytest.approx

    def summed(axis):
        return ragged.sum(data, axis=axis).tolist()

    assert summed(None) == approx(49.5)

    # (first spelling, equivalent spelling, expected result)
    cases = [
        (
            0,
            -3,
            [
                approx([3.3, 5.5, 2.2]),
                approx([5.5]),
                approx([6.6, 7.7, 8.8, 9.9]),
            ],
        ),
        (
            1,
            -2,
            [
                approx([0.0, 1.1, 2.2]),
                approx([]),
                approx([15.4, 12.1, 8.8, 9.9]),
            ],
        ),
        (
            2,
            -1,
            [
                approx([3.3, 0.0]),
                approx([]),
                approx([7.7, 5.5, 33.0]),
            ],
        ),
        ((0, 1), (1, 0), approx([15.4, 13.2, 11.0, 9.9])),
        ((0, 2), (2, 0), approx([11.0, 5.5, 33.0])),
        ((1, 2), (2, 1), approx([3.3, 0.0, 46.2])),
        ((0, 1, 2), (-1, 0, 1), approx(49.5)),
    ]

    for first_axis, second_axis, expected in cases:
        first_result = summed(first_axis)
        second_result = summed(second_axis)
        # Both spellings must give exactly the same list ...
        assert first_result == second_result
        # ... and that list must match the hand-computed sums.
        assert second_result == expected

0 comments on commit a8585e6

Please sign in to comment.