Skip to content

Commit

Permalink
Make create_diagonal support broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Feb 9, 2025
1 parent a71bd2e commit 91b2096
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 24 deletions.
44 changes: 22 additions & 22 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array
from ._utils._helpers import asarrays
from ._utils._helpers import asarrays, ndindex
from ._utils._typing import Array

__all__ = [
Expand All @@ -29,8 +29,7 @@


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.
"""Recursively expand the dimension of an array to at least `ndim`.
Parameters
----------
Expand Down Expand Up @@ -72,8 +71,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array


def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Estimate a covariance matrix.
"""Estimate a covariance matrix.
Covariance indicates the level to which two variables vary together.
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
Expand Down Expand Up @@ -166,13 +164,12 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
def create_diagonal(
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
) -> Array:
"""
Construct a diagonal array.
"""Construct a diagonal array.
Parameters
----------
x : array
A 1-D array.
An array having shape (*broadcast_dims, k).
offset : int, optional
Offset from the leading diagonal (default is ``0``).
Use positive ints for diagonals above the leading diagonal,
Expand All @@ -183,7 +180,8 @@ def create_diagonal(
Returns
-------
array
A 2-D array with `x` on the diagonal (offset by `offset`).
An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
on the diagonal (offset by `offset`).
Examples
--------
Expand All @@ -206,25 +204,27 @@ def create_diagonal(
if xp is None:
xp = array_namespace(x)

if x.ndim != 1:
err_msg = "`x` must be 1-dimensional."
if x.ndim == 0:
err_msg = "`x` must be at least 1-dimensional."
raise ValueError(err_msg)
n = x.shape[0] + abs(offset)
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))

start = offset if offset >= 0 else abs(offset) * n
stop = min(n * (n - offset), diag.shape[0])
step = n + 1
diag = at(diag)[start:stop:step].set(x)

return xp.reshape(diag, (n, n))
pre = x.shape[:-1]
n = x.shape[-1] + abs(offset)
diag = xp.zeros((*pre, n**2), dtype=x.dtype, device=_compat.device(x))

target_slice = slice(
offset if offset >= 0 else abs(offset) * n,
min(n * (n - offset), diag.shape[-1]),
n + 1,
)
for index in ndindex(*pre):
diag = at(diag)[(*index, target_slice)].set(x[*index, :])
return xp.reshape(diag, (*pre, n, n))


def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
) -> Array:
"""
Expand the shape of an array.
"""Expand the shape of an array.
Insert (a) new axis/axes that will appear at the position(s) specified by
`axis` in the expanded array shape.
Expand Down
23 changes: 23 additions & 0 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
from __future__ import annotations

from collections.abc import Generator
from types import ModuleType
from typing import cast

Expand Down Expand Up @@ -175,3 +176,25 @@ def asarrays(
xa, xb = xp.asarray(a), xp.asarray(b)

return (xb, xa) if swap else (xa, xb)


def ndindex(*x: int) -> Generator[tuple[int, ...]]:
"""Generate all N-dimensional indices for a given array shape.
Given the shape of an array, an ndindex instance iterates over the N-dimensional
index of the array. At each iteration a tuple of indices is returned, the last
dimension is iterated over first.
This has an identical API to numpy.ndindex.
Parameters
----------
*x : int
The shape of the array.
"""
if not x:
yield ()
return
for i in ndindex(*x[:-1]):
for j in range(x[-1]):
yield *i, j
13 changes: 11 additions & 2 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from array_api_extra._lib import Backend
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function

Expand Down Expand Up @@ -193,9 +194,17 @@ def test_0d(self, xp: ModuleType):
with pytest.raises(ValueError, match="1-dimensional"):
create_diagonal(xp.asarray(1))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
def test_2d(self, xp: ModuleType):
with pytest.raises(ValueError, match="1-dimensional"):
create_diagonal(xp.asarray([[1]]))
result = create_diagonal(xp.asarray([[1]]))
xp_assert_equal(result, xp.asarray([[[1]]]))
b = xp.zeros((3, 2, 4, 5), dtype=xp.int64)
for i in ndindex(*b.shape):
b = at(b)[*i].set(hash(i))
c = create_diagonal(b)
zero = xp.zeros((), dtype=xp.int64)
for i in ndindex(*c.shape):
xp_assert_equal(c[*i], b[*(i[:-1])] if i[-2] == i[-1] else zero)

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
def test_device(self, xp: ModuleType, device: Device):
Expand Down

0 comments on commit 91b2096

Please sign in to comment.