Skip to content

Commit 0bf38c2

Browse files
Add getitem to array protocol (#8406)
* Update _typing.py * Update _typing.py * Update test_namedarray.py * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update _typing.py * Update _typing.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 562f2f8 commit 0bf38c2

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

xarray/namedarray/_typing.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Default(Enum):
2929
_T = TypeVar("_T")
3030
_T_co = TypeVar("_T_co", covariant=True)
3131

32-
32+
_dtype = np.dtype
3333
_DType = TypeVar("_DType", bound=np.dtype[Any])
3434
_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any])
3535
# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic`
@@ -69,9 +69,16 @@ def dtype(self) -> _DType_co:
6969
_Dims = tuple[_Dim, ...]
7070

7171
_DimsLike = Union[str, Iterable[_Dim]]
72-
_AttrsLike = Union[Mapping[Any, Any], None]
7372

74-
_dtype = np.dtype
73+
# https://data-apis.org/array-api/latest/API_specification/indexing.html
74+
# TODO: np.array_api was bugged and didn't allow (None,), but should!
75+
# https://github.com/numpy/numpy/pull/25022
76+
# https://github.com/data-apis/array-api/pull/674
77+
_IndexKey = Union[int, slice, "ellipsis"]
78+
_IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...]
79+
_IndexKeyLike = Union[_IndexKey, _IndexKeys]
80+
81+
_AttrsLike = Union[Mapping[Any, Any], None]
7582

7683

7784
class _SupportsReal(Protocol[_T_co]):
@@ -113,6 +120,25 @@ class _arrayfunction(
113120
Corresponds to np.ndarray.
114121
"""
115122

123+
@overload
124+
def __getitem__(
125+
self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], /
126+
) -> _arrayfunction[Any, _DType_co]:
127+
...
128+
129+
@overload
130+
def __getitem__(self, key: _IndexKeyLike, /) -> Any:
131+
...
132+
133+
def __getitem__(
134+
self,
135+
key: _IndexKeyLike
136+
| _arrayfunction[Any, Any]
137+
| tuple[_arrayfunction[Any, Any], ...],
138+
/,
139+
) -> _arrayfunction[Any, _DType_co] | Any:
140+
...
141+
116142
@overload
117143
def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]:
118144
...
@@ -165,6 +191,14 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType
165191
Corresponds to np.ndarray.
166192
"""
167193

194+
def __getitem__(
195+
self,
196+
key: _IndexKeyLike
197+
| Any, # TODO: Any should be _arrayapi[Any, _dtype[np.integer]]
198+
/,
199+
) -> _arrayapi[Any, Any]:
200+
...
201+
168202
def __array_namespace__(self) -> ModuleType:
169203
...
170204

xarray/tests/test_namedarray.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
_AttrsLike,
2929
_DimsLike,
3030
_DType,
31+
_IndexKeyLike,
3132
_Shape,
3233
duckarray,
3334
)
@@ -58,6 +59,19 @@ class CustomArrayIndexable(
5859
ExplicitlyIndexed,
5960
Generic[_ShapeType_co, _DType_co],
6061
):
62+
def __getitem__(
63+
self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], /
64+
) -> CustomArrayIndexable[Any, _DType_co]:
65+
if isinstance(key, CustomArrayIndexable):
66+
if isinstance(key.array, type(self.array)):
67+
# TODO: key.array is duckarray here, can it be narrowed down further?
68+
# an _arrayapi cannot be used on a _arrayfunction for example.
69+
return type(self)(array=self.array[key.array]) # type: ignore[index]
70+
else:
71+
raise TypeError("key must have the same array type as self")
72+
else:
73+
return type(self)(array=self.array[key])
74+
6175
def __array_namespace__(self) -> ModuleType:
6276
return np
6377

0 commit comments

Comments
 (0)