diff --git a/src/ragged/common/__init__.py b/src/ragged/common/__init__.py index aee0255..84a1624 100644 --- a/src/ragged/common/__init__.py +++ b/src/ragged/common/__init__.py @@ -20,6 +20,12 @@ full_like, linspace, meshgrid, + ones, + ones_like, + tril, + triu, + zeros, + zeros_like, ) from ._obj import array @@ -35,6 +41,12 @@ "full_like", "linspace", "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", # _obj "array", ] diff --git a/src/ragged/common/_creation.py b/src/ragged/common/_creation.py index 88550dd..1be7fd8 100644 --- a/src/ragged/common/_creation.py +++ b/src/ragged/common/_creation.py @@ -391,3 +391,159 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> list[array]: assert arrays, "TODO" assert indexing, "TODO" assert False, "TODO" + + +def ones( + shape: int | tuple[int, ...], + *, + dtype: None | Dtype = None, + device: None | Device = None, +) -> array: + """ + Returns a new array having a specified `shape` and filled with ones. + + Args: + shape: Output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is `np.float64`. + device: Device on which to place the created array. + + Returns: + An array containing ones. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones.html + """ + + assert shape, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def ones_like( + x: array, /, *, dtype: None | Dtype = None, device: None | Device = None +) -> array: + """ + Returns a new array filled with ones and having the same `shape` as an + input array `x`. + + Args: + x: Input array from which to derive the output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is inferred from `x`. + device: Device on which to place the created array. If `device` is + `None`, the output array device is inferred from `x`. + + Returns: + An array having the same shape as x and filled with ones. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.ones_like.html + """ + + assert x, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def tril(x: array, /, *, k: int = 0) -> array: + """ + Returns the lower triangular part of a matrix (or a stack of matrices) `x`. + + Args: + x: Input array having shape `(..., M, N)` and whose innermost two + dimensions form `M` by `N` matrices. + `k`: Diagonal above which to zero elements. If `k = 0`, the diagonal is + the main diagonal. If `k < 0`, the diagonal is below the main + diagonal. If `k > 0`, the diagonal is above the main diagonal. + + Returns: + An array containing the lower triangular part(s). The returned array + has the same shape and data type as `x`. All elements above the + specified diagonal `k` are zero. The returned array is allocated on the + same device as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.tril.html + """ + + assert x, "TODO" + assert k, "TODO" + assert False, "TODO" + + +def triu(x: array, /, *, k: int = 0) -> array: + """ + Returns the upper triangular part of a matrix (or a stack of matrices) `x`. + + Args: + x: Input array having shape `(..., M, N)` and whose innermost two + dimensions form `M` by `N` matrices. + k: Diagonal below which to zero elements. If `k = 0`, the diagonal is + the main diagonal. If `k < 0`, the diagonal is below the main + diagonal. If `k > 0`, the diagonal is above the main diagonal. + + Returns: + An array containing the upper triangular part(s). The returned array + has the same shape and data type as `x`. All elements below the + specified diagonal `k` are zero. The returned array is allocated on the + same device as `x`. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.triu.html + """ + + assert x, "TODO" + assert k, "TODO" + assert False, "TODO" + + +def zeros( + shape: int | tuple[int, ...], + *, + dtype: None | Dtype = None, + device: None | Device = None, +) -> array: + """ + Returns a new array having a specified shape and filled with zeros. + + Args: + shape: Output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is `np.float64`. + device: Device on which to place the created array. + + Returns: + An array containing zeros. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros.html + """ + + assert shape, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" + + +def zeros_like( + x: array, /, *, dtype: None | Dtype = None, device: None | Device = None +) -> array: + """ + Returns a new array filled with zeros and having the same `shape` as an + input array `x`. + + Args: + x: Input array from which to derive the output array shape. + dtype: Output array data type. If `dtype` is `None`, the output array + data type is inferred from `x`. + device: Device on which to place the created array. If `device` is + `None`, the output array device is inferred from `x`. + + Returns: + An array having the same shape as `x` and filled with zeros. + + https://data-apis.org/array-api/latest/API_specification/generated/array_api.zeros_like.html + """ + + assert x, "TODO" + assert dtype, "TODO" + assert device, "TODO" + assert False, "TODO" diff --git a/src/ragged/v202212/__init__.py b/src/ragged/v202212/__init__.py index 37849d5..5906977 100644 --- a/src/ragged/v202212/__init__.py +++ b/src/ragged/v202212/__init__.py @@ -22,6 +22,12 @@ full_like, linspace, meshgrid, + ones, + ones_like, + tril, + triu, + zeros, + zeros_like, ) from ._obj import array @@ -37,6 +43,12 @@ "full_like", "linspace", "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", # _obj "array", ] diff --git a/src/ragged/v202212/_creation.py b/src/ragged/v202212/_creation.py index 6105f9d..2148440 100644 --- a/src/ragged/v202212/_creation.py +++ b/src/ragged/v202212/_creation.py @@ -17,6 +17,12 @@ full_like, linspace, meshgrid, + ones, + ones_like, + tril, + triu, + zeros, + zeros_like, ) __all__ = [ @@ -30,4 +36,10 @@ "full_like", "linspace", "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", ]