Skip to content

Commit 4b8fbef

Browse files
authored
Merge pull request #127 from ev-br/1D_array_indices
fancy indexing with ints and integer arrays
2 parents 5ccb0e7 + 6664e6d commit 4b8fbef

File tree

2 files changed

+95
-35
lines changed

2 files changed

+95
-35
lines changed

array_api_strict/_array_object.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
327327

328328
# Note: A large fraction of allowed indices are disallowed here (see the
329329
# docstring below)
330-
def _validate_index(self, key):
330+
def _validate_index(self, key, op="getitem"):
331331
"""
332332
Validate an index according to the array API.
333333
@@ -390,11 +390,16 @@ def _validate_index(self, key):
390390
"zero-dimensional integer arrays and boolean arrays "
391391
"are specified in the Array API."
392392
)
393+
if op == "setitem":
394+
if isinstance(i, Array) and i.dtype in _integer_dtypes:
395+
raise IndexError("Fancy indexing __setitem__ is not supported.")
393396

394397
nonexpanding_key = []
395398
single_axes = []
396399
n_ellipsis = 0
397400
key_has_mask = False
401+
key_has_index_array = False
402+
key_has_slices = False
398403
for i in _key:
399404
if i is not None:
400405
nonexpanding_key.append(i)
@@ -403,13 +408,17 @@ def _validate_index(self, key):
403408
if isinstance(i, Array):
404409
if i.dtype in _boolean_dtypes:
405410
key_has_mask = True
411+
elif i.dtype in _integer_dtypes:
412+
key_has_index_array = True
406413
single_axes.append(i)
407414
else:
408415
# i must not be an array here, to avoid elementwise equals
409416
if i == Ellipsis:
410417
n_ellipsis += 1
411418
else:
412419
single_axes.append(i)
420+
if isinstance(i, slice):
421+
key_has_slices = True
413422

414423
n_single_axes = len(single_axes)
415424
if n_ellipsis > 1:
@@ -427,6 +436,12 @@ def _validate_index(self, key):
427436
"specified in the Array API."
428437
)
429438

439+
if (key_has_index_array and (n_ellipsis > 0 or key_has_slices or key_has_mask)):
440+
raise IndexError(
441+
"Integer index arrays are only allowed with integer indices; "
442+
f"got {key}."
443+
)
444+
430445
if n_ellipsis == 0:
431446
indexed_shape = self.shape
432447
else:
@@ -483,14 +498,9 @@ def _validate_index(self, key):
483498
"Array API when the array is the sole index."
484499
)
485500
if not get_array_api_strict_flags()['boolean_indexing']:
486-
raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict")
487-
488-
elif i.dtype in _integer_dtypes and i.ndim != 0:
489-
raise IndexError(
490-
f"Single-axes index {i} is a non-zero-dimensional "
491-
"integer array, but advanced integer indexing is not "
492-
"specified in the Array API."
493-
)
501+
raise RuntimeError(
502+
"The boolean_indexing flag has been disabled for array-api-strict"
503+
)
494504
elif isinstance(i, tuple):
495505
raise IndexError(
496506
f"Single-axes index {i} is a tuple, but nested tuple "
@@ -902,7 +912,7 @@ def __setitem__(
902912
"""
903913
# Note: Only indices required by the spec are allowed. See the
904914
# docstring of _validate_index
905-
self._validate_index(key)
915+
self._validate_index(key, op="setitem")
906916
if isinstance(key, Array):
907917
# Indexing self._array with array_api_strict arrays can be erroneous
908918
key = key._array

array_api_strict/tests/test_array_object.py

+75-25
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pytest
77

8-
from .. import ones, asarray, result_type, all, equal
8+
from .. import ones, arange, reshape, asarray, result_type, all, equal
99
from .._array_object import Array, CPU_DEVICE, Device
1010
from .._dtypes import (
1111
_all_dtypes,
@@ -45,35 +45,46 @@ def test_validate_index():
4545
a = ones((3, 4))
4646

4747
# Out of bounds slices are not allowed
48-
assert_raises(IndexError, lambda: a[:4])
49-
assert_raises(IndexError, lambda: a[:-4])
50-
assert_raises(IndexError, lambda: a[:3:-1])
51-
assert_raises(IndexError, lambda: a[:-5:-1])
52-
assert_raises(IndexError, lambda: a[4:])
53-
assert_raises(IndexError, lambda: a[-4:])
54-
assert_raises(IndexError, lambda: a[4::-1])
55-
assert_raises(IndexError, lambda: a[-4::-1])
56-
57-
assert_raises(IndexError, lambda: a[...,:5])
58-
assert_raises(IndexError, lambda: a[...,:-5])
59-
assert_raises(IndexError, lambda: a[...,:5:-1])
60-
assert_raises(IndexError, lambda: a[...,:-6:-1])
61-
assert_raises(IndexError, lambda: a[...,5:])
62-
assert_raises(IndexError, lambda: a[...,-5:])
63-
assert_raises(IndexError, lambda: a[...,5::-1])
64-
assert_raises(IndexError, lambda: a[...,-5::-1])
48+
assert_raises(IndexError, lambda: a[:4, 0])
49+
assert_raises(IndexError, lambda: a[:-4, 0])
50+
assert_raises(IndexError, lambda: a[:3:-1]) # XXX raises for a wrong reason
51+
assert_raises(IndexError, lambda: a[:-5:-1, 0])
52+
assert_raises(IndexError, lambda: a[4:, 0])
53+
assert_raises(IndexError, lambda: a[-4:, 0])
54+
assert_raises(IndexError, lambda: a[4::-1, 0])
55+
assert_raises(IndexError, lambda: a[-4::-1, 0])
56+
57+
assert_raises(IndexError, lambda: a[..., :5])
58+
assert_raises(IndexError, lambda: a[..., :-5])
59+
assert_raises(IndexError, lambda: a[..., :5:-1])
60+
assert_raises(IndexError, lambda: a[..., :-6:-1])
61+
assert_raises(IndexError, lambda: a[..., 5:])
62+
assert_raises(IndexError, lambda: a[..., -5:])
63+
assert_raises(IndexError, lambda: a[..., 5::-1])
64+
assert_raises(IndexError, lambda: a[..., -5::-1])
6565

6666
# Boolean indices cannot be part of a larger tuple index
67-
assert_raises(IndexError, lambda: a[a[:,0]==1,0])
68-
assert_raises(IndexError, lambda: a[a[:,0]==1,...])
69-
assert_raises(IndexError, lambda: a[..., a[0]==1])
67+
assert_raises(IndexError, lambda: a[a[:, 0] == 1, 0])
68+
assert_raises(IndexError, lambda: a[a[:, 0] == 1, ...])
69+
assert_raises(IndexError, lambda: a[..., a[0] == 1])
7070
assert_raises(IndexError, lambda: a[[True, True, True]])
7171
assert_raises(IndexError, lambda: a[(True, True, True),])
7272

73-
# Integer array indices are not allowed (except for 0-D)
74-
idx = asarray([[0, 1]])
75-
assert_raises(IndexError, lambda: a[idx])
76-
assert_raises(IndexError, lambda: a[idx,])
73+
# Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed
74+
idx = asarray([0, 1])
75+
assert_raises(IndexError, lambda: a[..., idx])
76+
assert_raises(IndexError, lambda: a[:, idx])
77+
assert_raises(IndexError, lambda: a[asarray([True, True]), idx])
78+
79+
# 1D integer array indices must have the same length
80+
idx1 = asarray([0, 1])
81+
idx2 = asarray([0, 1, 1])
82+
assert_raises(IndexError, lambda: a[idx1, idx2])
83+
84+
# Non-integer array indices are not allowed
85+
assert_raises(IndexError, lambda: a[ones(2), 0])
86+
87+
# Array-likes (lists, tuples) are not allowed as indices
7788
assert_raises(IndexError, lambda: a[[0, 1]])
7889
assert_raises(IndexError, lambda: a[(0, 1), (0, 1)])
7990
assert_raises(IndexError, lambda: a[[0, 1]])
@@ -87,6 +98,45 @@ def test_validate_index():
8798
assert_raises(IndexError, lambda: a[0,])
8899
assert_raises(IndexError, lambda: a[0])
89100
assert_raises(IndexError, lambda: a[:])
101+
assert_raises(IndexError, lambda: a[idx])
102+
103+
104+
def test_indexing_arrays():
105+
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
106+
107+
# 1D array
108+
a = arange(5)
109+
idx = asarray([1, 0, 1, 2, -1])
110+
a_idx = a[idx]
111+
112+
a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
113+
assert all(a_idx == a_idx_loop)
114+
115+
# setitem with arrays is not allowed
116+
with assert_raises(IndexError):
117+
a[idx] = 42
118+
119+
# mixed array and integer indexing
120+
a = reshape(arange(3*4), (3, 4))
121+
idx = asarray([1, 0, 1, 2, -1])
122+
a_idx = a[idx, 1]
123+
124+
a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
125+
assert all(a_idx == a_idx_loop)
126+
127+
# index with two arrays
128+
a_idx = a[idx, idx]
129+
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
130+
assert all(a_idx == a_idx_loop)
131+
132+
# setitem with arrays is not allowed
133+
with assert_raises(IndexError):
134+
a[idx, idx] = 42
135+
136+
# smoke test indexing with ndim > 1 arrays
137+
idx = idx[..., None]
138+
a[idx, idx]
139+
90140

91141
def test_promoted_scalar_inherits_device():
92142
device1 = Device("device1")

0 commit comments

Comments
 (0)