Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: create_diagonal: support broadcasting #137

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array
from ._utils._helpers import asarrays
from ._utils._helpers import asarrays, ndindex
from ._utils._typing import Array

__all__ = [
Expand Down Expand Up @@ -172,7 +172,7 @@ def create_diagonal(
Parameters
----------
x : array
A 1-D array.
An array having shape (*broadcast_dims, k).
offset : int, optional
Offset from the leading diagonal (default is ``0``).
Use positive ints for diagonals above the leading diagonal,
Expand All @@ -183,7 +183,8 @@ def create_diagonal(
Returns
-------
array
A 2-D array with `x` on the diagonal (offset by `offset`).
An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
on the diagonal (offset by `offset`).

Examples
--------
Expand All @@ -206,18 +207,21 @@ def create_diagonal(
if xp is None:
xp = array_namespace(x)

if x.ndim != 1:
err_msg = "`x` must be 1-dimensional."
if x.ndim == 0:
err_msg = "`x` must be at least 1-dimensional."
raise ValueError(err_msg)
n = x.shape[0] + abs(offset)
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))

start = offset if offset >= 0 else abs(offset) * n
stop = min(n * (n - offset), diag.shape[0])
step = n + 1
diag = at(diag)[start:stop:step].set(x)

return xp.reshape(diag, (n, n))
pre = x.shape[:-1]
n = x.shape[-1] + abs(offset)
diag = xp.zeros((*pre, n**2), dtype=x.dtype, device=_compat.device(x))

target_slice = slice(
offset if offset >= 0 else abs(offset) * n,
min(n * (n - offset), diag.shape[-1]),
n + 1,
)
for index in ndindex(*pre):
diag = at(diag)[(*index, target_slice)].set(x[(*index, slice(None))])
return xp.reshape(diag, (*pre, n, n))


def expand_dims(
Expand Down
24 changes: 24 additions & 0 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
from __future__ import annotations

from collections.abc import Generator
from types import ModuleType
from typing import cast

Expand Down Expand Up @@ -175,3 +176,26 @@ def asarrays(
xa, xb = xp.asarray(a), xp.asarray(b)

return (xb, xa) if swap else (xa, xb)


def ndindex(*x: int) -> Generator[tuple[int, ...]]:
"""
Generate all N-dimensional indices for a given array shape.

Given the shape of an array, an ndindex instance iterates over the N-dimensional
index of the array. At each iteration a tuple of indices is returned, the last
dimension is iterated over first.

This has an identical API to numpy.ndindex.

Parameters
----------
*x : int
The shape of the array.
"""
if not x:
yield ()
return
for i in ndindex(*x[:-1]):
for j in range(x[-1]):
yield *i, j
13 changes: 11 additions & 2 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from array_api_extra._lib import Backend
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function

Expand Down Expand Up @@ -193,9 +194,17 @@ def test_0d(self, xp: ModuleType):
with pytest.raises(ValueError, match="1-dimensional"):
create_diagonal(xp.asarray(1))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
def test_2d(self, xp: ModuleType):
with pytest.raises(ValueError, match="1-dimensional"):
create_diagonal(xp.asarray([[1]]))
result = create_diagonal(xp.asarray([[1]]))
xp_assert_equal(result, xp.asarray([[[1]]]))
b = xp.zeros((3, 2, 4, 5), dtype=xp.int64)
for i in ndindex(*b.shape):
b = at(b)[i].set(hash(i))
c = create_diagonal(b)
zero = xp.zeros((), dtype=xp.int64)
for i in ndindex(*c.shape):
xp_assert_equal(c[i], b[i[:-1]] if i[-2] == i[-1] else zero)

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
def test_device(self, xp: ModuleType, device: Device):
Expand Down