Skip to content

Add count_nonzero and cumulative_prod from 2024.12 revision draft #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,9 @@

__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"]

from ._searching_functions import argmax, argmin, nonzero, searchsorted, where
from ._searching_functions import argmax, argmin, nonzero, count_nonzero, searchsorted, where

__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"]
__all__ += ["argmax", "argmin", "nonzero", "count_nonzero", "searchsorted", "where"]

from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values

Expand All @@ -305,9 +305,9 @@

__all__ += ["argsort", "sort"]

from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
from ._statistical_functions import cumulative_sum, cumulative_prod, max, mean, min, prod, std, sum, var

__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
__all__ += ["cumulative_sum", "cumulative_prod", "max", "mean", "min", "prod", "std", "sum", "var"]

from ._utility_functions import all, any, diff

Expand Down
20 changes: 19 additions & 1 deletion array_api_strict/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Literal, Optional, Tuple
from typing import Literal, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -45,6 +45,24 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]:
raise ValueError("nonzero is not allowed on 0-dimensional arrays")
return tuple(Array._new(i, device=x.device) for i in np.nonzero(x._array))


@requires_api_version('2024.12')
def count_nonzero(
x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.count_nonzero <numpy.count_nonzero>`

See its docstring for more information.
"""
arr = np.count_nonzero(x._array, axis=axis, keepdims=keepdims)
return Array._new(np.asarray(arr), device=x.device)


@requires_api_version('2023.12')
def searchsorted(
x1: Array,
Expand Down
36 changes: 33 additions & 3 deletions array_api_strict/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ._array_object import Array
from ._dtypes import float32, complex64
from ._flags import requires_api_version, get_array_api_strict_flags
from ._creation_functions import zeros
from ._creation_functions import zeros, ones
from ._manipulation_functions import concat

from typing import TYPE_CHECKING
Expand All @@ -31,7 +31,6 @@ def cumulative_sum(
) -> Array:
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in cumulative_sum")
dt = x.dtype if dtype is None else dtype
if dtype is not None:
dtype = dtype._np_dtype

Expand All @@ -44,9 +43,40 @@ def cumulative_sum(
if include_initial:
if axis < 0:
axis += x.ndim
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis)
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device)


@requires_api_version('2024.12')
def cumulative_prod(
x: Array,
/,
*,
axis: Optional[int] = None,
dtype: Optional[Dtype] = None,
include_initial: bool = False,
) -> Array:
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in cumulative_prod")
if x.ndim == 0:
raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod")

if dtype is not None:
dtype = dtype._np_dtype

if axis is None:
if x.ndim > 1:
raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
axis = 0

# np.cumprod does not support include_initial
if include_initial:
if axis < 0:
axis += x.ndim
x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype), device=x.device)


def max(
x: Array,
/,
Expand Down
2 changes: 2 additions & 0 deletions array_api_strict/tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ def test_api_version_2023_12(func_name):
'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])),
'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)),
xp.zeros((1, 4), dtype=xp.int64)),
'count_nonzero': lambda: xp.count_nonzero(xp.arange(3)),
'cumulative_prod': lambda: xp.cumulative_prod(xp.arange(1, 5)),
}

@pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys())
Expand Down
Loading