Skip to content

Commit

Permalink
ones, ones_like, tril, triu, zeros, zeros_like
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 27, 2023
1 parent a416042 commit 59c94b6
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/ragged/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
full_like,
linspace,
meshgrid,
ones,
ones_like,
tril,
triu,
zeros,
zeros_like,
)
from ._obj import array

Expand All @@ -35,6 +41,12 @@
"full_like",
"linspace",
"meshgrid",
"ones",
"ones_like",
"tril",
"triu",
"zeros",
"zeros_like",
# _obj
"array",
]
156 changes: 156 additions & 0 deletions src/ragged/common/_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,159 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]:
assert arrays, "TODO"
assert indexing, "TODO"
assert False, "TODO"


def ones(
shape: int | tuple[int, ...],
*,
dtype: None | Dtype = None,
device: None | Device = None,
) -> array:
"""
Returns a new array having a specified `shape` and filled with ones.
Args:
shape: Output array shape.
dtype: Output array data type. If `dtype` is `None`, the output array
data type is `np.float64`.
device: Device on which to place the created array.
Returns:
An array containing ones.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones.html
"""

assert shape, "TODO"
assert dtype, "TODO"
assert device, "TODO"
assert False, "TODO"


def ones_like(
x: array, /, *, dtype: None | Dtype = None, device: None | Device = None
) -> array:
"""
Returns a new array filled with ones and having the same `shape` as an
input array `x`.
Args:
x: Input array from which to derive the output array shape.
dtype: Output array data type. If `dtype` is `None`, the output array
data type is inferred from `x`.
device: Device on which to place the created array. If `device` is
`None`, the output array device is inferred from `x`.
Returns:
An array having the same shape as x and filled with ones.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones_like.html
"""

assert x, "TODO"
assert dtype, "TODO"
assert device, "TODO"
assert False, "TODO"


def tril(x: array, /, *, k: int = 0) -> array:
"""
Returns the lower triangular part of a matrix (or a stack of matrices) `x`.
Args:
x: Input array having shape `(..., M, N)` and whose innermost two
dimensions form `M` by `N` matrices.
`k`: Diagonal above which to zero elements. If `k = 0`, the diagonal is
the main diagonal. If `k < 0`, the diagonal is below the main
diagonal. If `k > 0`, the diagonal is above the main diagonal.
Returns:
An array containing the lower triangular part(s). The returned array
has the same shape and data type as `x`. All elements above the
specified diagonal `k` are zero. The returned array is allocated on the
same device as `x`.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.tril.html
"""

assert x, "TODO"
assert k, "TODO"
assert False, "TODO"


def triu(x: array, /, *, k: int = 0) -> array:
"""
Returns the upper triangular part of a matrix (or a stack of matrices) `x`.
Args:
x: Input array having shape `(..., M, N)` and whose innermost two
dimensions form `M` by `N` matrices.
k: Diagonal below which to zero elements. If `k = 0`, the diagonal is
the main diagonal. If `k < 0`, the diagonal is below the main
diagonal. If `k > 0`, the diagonal is above the main diagonal.
Returns:
An array containing the upper triangular part(s). The returned array
has the same shape and data type as `x`. All elements below the
specified diagonal `k` are zero. The returned array is allocated on the
same device as `x`.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.triu.html
"""

assert x, "TODO"
assert k, "TODO"
assert False, "TODO"


def zeros(
shape: int | tuple[int, ...],
*,
dtype: None | Dtype = None,
device: None | Device = None,
) -> array:
"""
Returns a new array having a specified shape and filled with zeros.
Args:
shape: Output array shape.
dtype: Output array data type. If `dtype` is `None`, the output array
data type is `np.float64`.
device: Device on which to place the created array.
Returns:
An array containing zeros.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros.html
"""

assert shape, "TODO"
assert dtype, "TODO"
assert device, "TODO"
assert False, "TODO"


def zeros_like(
x: array, /, *, dtype: None | Dtype = None, device: None | Device = None
) -> array:
"""
Returns a new array filled with zeros and having the same `shape` as an
input array `x`.
Args:
x: Input array from which to derive the output array shape.
dtype: Output array data type. If `dtype` is `None`, the output array
data type is inferred from `x`.
device: Device on which to place the created array. If `device` is
`None`, the output array device is inferred from `x`.
Returns:
An array having the same shape as `x` and filled with zeros.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros_like.html
"""

assert x, "TODO"
assert dtype, "TODO"
assert device, "TODO"
assert False, "TODO"
12 changes: 12 additions & 0 deletions src/ragged/v202212/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
full_like,
linspace,
meshgrid,
ones,
ones_like,
tril,
triu,
zeros,
zeros_like,
)
from ._obj import array

Expand All @@ -37,6 +43,12 @@
"full_like",
"linspace",
"meshgrid",
"ones",
"ones_like",
"tril",
"triu",
"zeros",
"zeros_like",
# _obj
"array",
]
12 changes: 12 additions & 0 deletions src/ragged/v202212/_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
full_like,
linspace,
meshgrid,
ones,
ones_like,
tril,
triu,
zeros,
zeros_like,
)

__all__ = [
Expand All @@ -30,4 +36,10 @@
"full_like",
"linspace",
"meshgrid",
"ones",
"ones_like",
"tril",
"triu",
"zeros",
"zeros_like",
]

0 comments on commit 59c94b6

Please sign in to comment.