Skip to content

Commit

Permalink
argmax, argmin; done with reducers
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Jan 2, 2024
1 parent 0cac682 commit a8401b6
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 80 deletions.
52 changes: 43 additions & 9 deletions src/ragged/_spec_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,19 @@

from __future__ import annotations

from ._spec_array_object import array
import awkward as ak
import numpy as np

from ._spec_array_object import _box, _unbox, array


def _remove_optiontype(x: ak.contents.Content) -> ak.contents.Content:
    """
    Strip the first option-type layer found beneath any list-type layers.

    Args:
        x: Layout node to process.

    Returns:
        If `x` (after descending through zero or more list-type nodes) reaches
        an option-type node, that node is replaced by its `content` and every
        list-type node above it is rebuilt with `copy(content=...)`. If the
        first non-list node is not option-type, it is kept as-is (the list
        layers above it are still copied). Nothing below the first non-list
        node is visited.

    This function does not itself check for missing values; it only unwraps
    the option node.
    """
    # Descend through the list-type layers, remembering each one so that it
    # can be rebuilt around the (possibly unwrapped) inner node.
    list_layers = []
    node = x
    while node.is_list:
        list_layers.append(node)
        node = node.content

    # Only the first non-list node is examined for option-type.
    if node.is_option:
        node = node.content

    # Rebuild from the innermost list layer outward.
    for layer in reversed(list_layers):
        node = layer.copy(content=node)
    return node


def argmax(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> array:
Expand Down Expand Up @@ -34,10 +46,21 @@ def argmax(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> a
https://data-apis.org/array-api/latest/API_specification/generated/array_api.argmax.html
"""

assert x, "TODO"
assert axis, "TODO"
assert keepdims, "TODO"
assert False, "TODO 124"
out = np.argmax(*_unbox(x), axis=axis, keepdims=keepdims)

if out is None:
msg = "cannot compute argmax of an array with no data"
raise ValueError(msg)

if isinstance(out, ak.Array):
if ak.any(ak.is_none(out, axis=-1)):
msg = f"cannot compute argmax at axis={axis} because some lists at this depth have zero length"
raise ValueError(msg)
out = ak.Array(
_remove_optiontype(out.layout), behavior=out.behavior, attrs=out.attrs
)

return _box(type(x), out)


def argmin(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> array:
Expand Down Expand Up @@ -65,10 +88,21 @@ def argmin(x: array, /, *, axis: None | int = None, keepdims: bool = False) -> a
https://data-apis.org/array-api/latest/API_specification/generated/array_api.argmin.html
"""

assert x, "TODO"
assert axis, "TODO"
assert keepdims, "TODO"
assert False, "TODO 125"
out = np.argmin(*_unbox(x), axis=axis, keepdims=keepdims)

if out is None:
msg = "cannot compute argmin of an array with no data"
raise ValueError(msg)

if isinstance(out, ak.Array):
if ak.any(ak.is_none(out, axis=-1)):
msg = f"cannot compute argmin at axis={axis} because some lists at this depth have zero length"
raise ValueError(msg)
out = ak.Array(
_remove_optiontype(out.layout), behavior=out.behavior, attrs=out.attrs
)

return _box(type(x), out)


def nonzero(x: array, /) -> tuple[array, ...]:
Expand Down
146 changes: 80 additions & 66 deletions src/ragged/_spec_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def _regularize_axis(
axis: None | int | tuple[int, ...], ndim: int
) -> None | int | tuple[int, ...]:
) -> None | tuple[int, ...]:
if axis is None:
return axis
elif isinstance(axis, numbers.Integral):
Expand Down Expand Up @@ -96,13 +96,20 @@ def max( # pylint: disable=W0622
if isinstance(axis, tuple):
(out,) = _unbox(x)
for axis_item in axis[::-1]:
out = ak.max(out, axis=axis_item, keepdims=keepdims, mask_identity=False)
if isinstance(out, ak.Array):
out = ak.max(
out, axis=axis_item, keepdims=keepdims, mask_identity=False
)
else:
out = np.max(out, axis=axis_item, keepdims=keepdims)
return _box(type(x), out)
else:
return _box(
type(x),
ak.max(*_unbox(x), axis=axis, keepdims=keepdims, mask_identity=False),
)
(tmp,) = _unbox(x)
if isinstance(tmp, ak.Array):
out = ak.max(tmp, axis=axis, keepdims=keepdims, mask_identity=False)
else:
out = np.max(tmp, axis=axis, keepdims=keepdims)
return _box(type(x), out)


def mean(
Expand Down Expand Up @@ -133,13 +140,13 @@ def mean(
axis = _regularize_axis(axis, x.ndim)

if isinstance(axis, tuple):
sumwx = ak.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims)
sumwx = np.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims)
sumw = ak.count(*_unbox(x), axis=axis[-1], keepdims=keepdims)
for axis_item in axis[-2::-1]:
sumwx = ak.sum(sumwx, axis=axis_item, keepdims=keepdims)
sumw = ak.sum(sumw, axis=axis_item, keepdims=keepdims)
sumwx = np.sum(sumwx, axis=axis_item, keepdims=keepdims)
sumw = np.sum(sumw, axis=axis_item, keepdims=keepdims)
else:
sumwx = ak.sum(*_unbox(x), axis=axis, keepdims=keepdims)
sumwx = np.sum(*_unbox(x), axis=axis, keepdims=keepdims)
sumw = ak.count(*_unbox(x), axis=axis, keepdims=keepdims)

with np.errstate(invalid="ignore", divide="ignore"):
Expand Down Expand Up @@ -176,13 +183,20 @@ def min( # pylint: disable=W0622
if isinstance(axis, tuple):
(out,) = _unbox(x)
for axis_item in axis[::-1]:
out = ak.min(out, axis=axis_item, keepdims=keepdims, mask_identity=False)
if isinstance(out, ak.Array):
out = ak.min(
out, axis=axis_item, keepdims=keepdims, mask_identity=False
)
else:
out = np.min(out, axis=axis_item, keepdims=keepdims)
return _box(type(x), out)
else:
return _box(
type(x),
ak.min(*_unbox(x), axis=axis, keepdims=keepdims, mask_identity=False),
)
(tmp,) = _unbox(x)
if isinstance(tmp, ak.Array):
out = ak.min(tmp, axis=axis, keepdims=keepdims, mask_identity=False)
else:
out = np.min(tmp, axis=axis, keepdims=keepdims)
return _box(type(x), out)


def prod(
Expand All @@ -194,46 +208,46 @@ def prod(
keepdims: bool = False,
) -> array:
"""
Calculates the product of input array `x` elements.
Args:
x: Input array.
axis: Axis or axes along which products are computed. By default, the
product is computed over the entire array. If a tuple of integers,
products are computed over multiple axes.
dtype: Data type of the returned array. If `None`,
- if the default data type corresponding to the data type "kind"
(integer, real-valued floating-point, or complex floating-point)
of `x` has a smaller range of values than the data type of `x`
(e.g., `x` has data type `int64` and the default data type is
`int32`, or `x` has data type `uint64` and the default data type
is `int64`), the returned array has the same data type as `x`.
- if `x` has a real-valued floating-point data type, the returned
array has the default real-valued floating-point data type.
- if `x` has a complex floating-point data type, the returned array
has data type `np.complex128`.
- if `x` has a signed integer data type (e.g., `int16`), the
returned array has data type `np.int64`.
- if `x` has an unsigned integer data type (e.g., `uint16`), the
returned array has data type `np.uint64`.
If the data type (either specified or resolved) differs from the
data type of `x`, the input array will be cast to the specified
data type before computing the product.
keepdims: If `True`, the reduced axes (dimensions) are included in the
result as singleton dimensions, and, accordingly, the result is
broadcastable with the input array. Otherwise, if `False`, the
reduced axes (dimensions) are not included in the result.
Returns:
If the product was computed over the entire array, a zero-dimensional
array containing the product; otherwise, a non-zero-dimensional array
containing the products. The returned array has a data type as
described by the `dtype` parameter above.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.prod.html
Calculates the product of input array `x` elements.
Args:
x: Input array.
axis: Axis or axes along which products are computed. By default, the
product is computed over the entire array. If a tuple of integers,
products are computed over multiple axes.
dtype: Data type of the returned array. If `None`,
- if the default data type corresponding to the data type "kind"
(integer, real-valued floating-point, or complex floating-point)
of `x` has a smaller range of values than the data type of `x`
(e.g., `x` has data type `int64` and the default data type is
`int32`, or `x` has data type `uint64` and the default data type
is `int64`), the returned array has the same data type as `x`.
- if `x` has a real-valued floating-point data type, the returned
array has the default real-valued floating-point data type.
- if `x` has a complex floating-point data type, the returned array
has data type `np.complex128`.
- if `x` has a signed integer data type (e.g., `int16`), the
returned array has data type `np.int64`.
- if `x` has an unsigned integer data type (e.g., `uint16`), the
returned array has data type `np.uint64`.
If the data type (either specified or resolved) differs from the
data type of `x`, the input array will be cast to the specified
data type before computing the product.
keepdims: If `True`, the reduced axes (dimensions) are included in the
result as singleton dimensions, and, accordingly, the result is
broadcastable with the input array. Otherwise, if `False`, the
reduced axes (dimensions) are not included in the result.
Returns:
If the product was computed over the entire array, a zero-dimensional
array containing the product; otherwise, a non-zero-dimensional array
containing the products. The returned array has a data type as
described by the `dtype` parameter above.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.prod.html
"""

axis = _regularize_axis(axis, x.ndim)
Expand All @@ -243,10 +257,10 @@ def prod(
if isinstance(axis, tuple):
(out,) = _unbox(arr)
for axis_item in axis[::-1]:
out = ak.prod(out, axis=axis_item, keepdims=keepdims)
out = np.prod(out, axis=axis_item, keepdims=keepdims)
return _box(type(x), out)
else:
return _box(type(x), ak.prod(*_unbox(arr), axis=axis, keepdims=keepdims))
return _box(type(x), np.prod(*_unbox(arr), axis=axis, keepdims=keepdims))


def std(
Expand Down Expand Up @@ -357,10 +371,10 @@ def sum( # pylint: disable=W0622
if isinstance(axis, tuple):
(out,) = _unbox(arr)
for axis_item in axis[::-1]:
out = ak.sum(out, axis=axis_item, keepdims=keepdims)
out = np.sum(out, axis=axis_item, keepdims=keepdims)
return _box(type(x), out)
else:
return _box(type(x), ak.sum(*_unbox(arr), axis=axis, keepdims=keepdims))
return _box(type(x), np.sum(*_unbox(arr), axis=axis, keepdims=keepdims))


def var(
Expand Down Expand Up @@ -408,16 +422,16 @@ def var(
axis = _regularize_axis(axis, x.ndim)

if isinstance(axis, tuple):
sumwxx = ak.sum(np.square(*_unbox(x)), axis=axis[-1], keepdims=keepdims)
sumwx = ak.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims)
sumwxx = np.sum(np.square(*_unbox(x)), axis=axis[-1], keepdims=keepdims)
sumwx = np.sum(*_unbox(x), axis=axis[-1], keepdims=keepdims)
sumw = ak.count(*_unbox(x), axis=axis[-1], keepdims=keepdims)
for axis_item in axis[-2::-1]:
sumwxx = ak.sum(sumwxx, axis=axis_item, keepdims=keepdims)
sumwx = ak.sum(sumwx, axis=axis_item, keepdims=keepdims)
sumw = ak.sum(sumw, axis=axis_item, keepdims=keepdims)
sumwxx = np.sum(sumwxx, axis=axis_item, keepdims=keepdims)
sumwx = np.sum(sumwx, axis=axis_item, keepdims=keepdims)
sumw = np.sum(sumw, axis=axis_item, keepdims=keepdims)
else:
sumwxx = ak.sum(np.square(*_unbox(x)), axis=axis, keepdims=keepdims)
sumwx = ak.sum(*_unbox(x), axis=axis, keepdims=keepdims)
sumwxx = np.sum(np.square(*_unbox(x)), axis=axis, keepdims=keepdims)
sumwx = np.sum(*_unbox(x), axis=axis, keepdims=keepdims)
sumw = ak.count(*_unbox(x), axis=axis, keepdims=keepdims)

with np.errstate(invalid="ignore", divide="ignore"):
Expand Down
10 changes: 5 additions & 5 deletions src/ragged/_spec_utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

import awkward as ak
import numpy as np

from ._spec_array_object import _box, _unbox, array
from ._spec_statistical_functions import _regularize_axis
Expand Down Expand Up @@ -51,10 +51,10 @@ def all( # pylint: disable=W0622
if isinstance(axis, tuple):
(out,) = _unbox(x)
for axis_item in axis[::-1]:
out = ak.all(out, axis=axis_item, keepdims=keepdims)
out = np.all(out, axis=axis_item, keepdims=keepdims)
return _box(type(x), out)
else:
return _box(type(x), ak.all(*_unbox(x), axis=axis, keepdims=keepdims))
return _box(type(x), np.all(*_unbox(x), axis=axis, keepdims=keepdims))


def any( # pylint: disable=W0622
Expand Down Expand Up @@ -96,7 +96,7 @@ def any( # pylint: disable=W0622
if isinstance(axis, tuple):
(out,) = _unbox(x)
for axis_item in axis[::-1]:
out = ak.any(out, axis=axis_item, keepdims=keepdims)
out = np.any(out, axis=axis_item, keepdims=keepdims)
return _box(type(x), out)
else:
return _box(type(x), ak.any(*_unbox(x), axis=axis, keepdims=keepdims))
return _box(type(x), np.any(*_unbox(x), axis=axis, keepdims=keepdims))
44 changes: 44 additions & 0 deletions tests/test_spec_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

import pytest

import ragged


Expand All @@ -14,3 +16,45 @@ def test_existence():
assert ragged.argmin is not None
assert ragged.nonzero is not None
assert ragged.where is not None


def test_argmax():
    """
    Check `ragged.argmax` on a three-level ragged array.

    Covers the flattened reduction (`axis=None`), the outermost and middle
    axes addressed by both positive and negative indices, and the error
    raised when reducing over the innermost axis, which contains an empty
    list in this data.
    """
    data = ragged.array(
        [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]]
    )

    # Whole-array reduction yields a single position.
    assert ragged.argmax(data, axis=None).tolist() == 9

    # Outermost axis: positive and negative spellings must agree.
    outer_pos = ragged.argmax(data, axis=0).tolist()
    outer_neg = ragged.argmax(data, axis=-3).tolist()
    assert outer_pos == outer_neg == [[1, 1, 0], [1], [0, 0, 0, 0]]

    # Middle axis: positive and negative spellings must agree.
    middle_pos = ragged.argmax(data, axis=1).tolist()
    middle_neg = ragged.argmax(data, axis=-2).tolist()
    assert middle_pos == middle_neg == [[0, 0, 0], [], [2, 2, 2, 2]]  # type: ignore[comparison-overlap]

    # Innermost axis has a zero-length list, so argmax is undefined there.
    for bad_axis in (2, -1):
        with pytest.raises(ValueError, match=".*axis.*"):
            ragged.argmax(data, axis=bad_axis)


def test_argmin():
    """
    Check `ragged.argmin` on a three-level ragged array.

    Covers the flattened reduction (`axis=None`), the outermost and middle
    axes addressed by both positive and negative indices, and the error
    raised when reducing over the innermost axis, which contains an empty
    list in this data.
    """
    data = ragged.array(
        [[[0, 1.1, 2.2], []], [], [[3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]]]
    )

    # Whole-array reduction yields a single position.
    assert ragged.argmin(data, axis=None).tolist() == 0

    # Outermost axis: positive and negative spellings must agree.
    outer_pos = ragged.argmin(data, axis=0).tolist()
    outer_neg = ragged.argmin(data, axis=-3).tolist()
    assert outer_pos == outer_neg == [[0, 0, 0], [1], [0, 0, 0, 0]]

    # Middle axis: positive and negative spellings must agree.
    middle_pos = ragged.argmin(data, axis=1).tolist()
    middle_neg = ragged.argmin(data, axis=-2).tolist()
    assert middle_pos == middle_neg == [[0, 0, 0], [], [0, 0, 2, 2]]  # type: ignore[comparison-overlap]

    # Innermost axis has a zero-length list, so argmin is undefined there.
    for bad_axis in (2, -1):
        with pytest.raises(ValueError, match=".*axis.*"):
            ragged.argmin(data, axis=bad_axis)

0 comments on commit a8401b6

Please sign in to comment.