@@ -29,7 +29,7 @@ class Default(Enum):
29
29
_T = TypeVar ("_T" )
30
30
_T_co = TypeVar ("_T_co" , covariant = True )
31
31
32
-
32
+ _dtype = np . dtype
33
33
_DType = TypeVar ("_DType" , bound = np .dtype [Any ])
34
34
_DType_co = TypeVar ("_DType_co" , covariant = True , bound = np .dtype [Any ])
35
35
# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic`
@@ -69,9 +69,16 @@ def dtype(self) -> _DType_co:
69
69
_Dims = tuple [_Dim , ...]
70
70
71
71
_DimsLike = Union [str , Iterable [_Dim ]]
72
- _AttrsLike = Union [Mapping [Any , Any ], None ]
73
72
74
- _dtype = np .dtype
73
+ # https://data-apis.org/array-api/latest/API_specification/indexing.html
74
+ # TODO: np.array_api was bugged and didn't allow (None,), but should!
75
+ # https://github.com/numpy/numpy/pull/25022
76
+ # https://github.com/data-apis/array-api/pull/674
77
+ _IndexKey = Union [int , slice , "ellipsis" ]
78
+ _IndexKeys = tuple [Union [_IndexKey ], ...] # tuple[Union[_IndexKey, None], ...]
79
+ _IndexKeyLike = Union [_IndexKey , _IndexKeys ]
80
+
81
+ _AttrsLike = Union [Mapping [Any , Any ], None ]
75
82
76
83
77
84
class _SupportsReal (Protocol [_T_co ]):
@@ -113,6 +120,25 @@ class _arrayfunction(
113
120
Corresponds to np.ndarray.
114
121
"""
115
122
123
+ @overload
124
+ def __getitem__ (
125
+ self , key : _arrayfunction [Any , Any ] | tuple [_arrayfunction [Any , Any ], ...], /
126
+ ) -> _arrayfunction [Any , _DType_co ]:
127
+ ...
128
+
129
+ @overload
130
+ def __getitem__ (self , key : _IndexKeyLike , / ) -> Any :
131
+ ...
132
+
133
+ def __getitem__ (
134
+ self ,
135
+ key : _IndexKeyLike
136
+ | _arrayfunction [Any , Any ]
137
+ | tuple [_arrayfunction [Any , Any ], ...],
138
+ / ,
139
+ ) -> _arrayfunction [Any , _DType_co ] | Any :
140
+ ...
141
+
116
142
@overload
117
143
def __array__ (self , dtype : None = ..., / ) -> np .ndarray [Any , _DType_co ]:
118
144
...
@@ -165,6 +191,14 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType
165
191
Corresponds to np.ndarray.
166
192
"""
167
193
194
+ def __getitem__ (
195
+ self ,
196
+ key : _IndexKeyLike
197
+ | Any , # TODO: Any should be _arrayapi[Any, _dtype[np.integer]]
198
+ / ,
199
+ ) -> _arrayapi [Any , Any ]:
200
+ ...
201
+
168
202
def __array_namespace__ (self ) -> ModuleType :
169
203
...
170
204
0 commit comments