Skip to content

Commit 79272c3

Browse files
andersy005dcherian
andauthored
Implement setitem syntax for .oindex and .vindex properties (#8845)
* Implement setitem syntax for `.oindex` and `.vindex` properties * Apply suggestions from code review Co-authored-by: Deepak Cherian <[email protected]> * use getter and setter properties instead of func_get and func_set methods * delete unnecessary _indexing_array_and_key method * Add tests for IndexCallable class * fix bug/unnecessary code introduced in #8790 * add unit tests --------- Co-authored-by: Deepak Cherian <[email protected]>
1 parent c6c01b1 commit 79272c3

File tree

3 files changed

+174
-67
lines changed

3 files changed

+174
-67
lines changed

xarray/core/indexing.py

Lines changed: 114 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -326,18 +326,23 @@ def as_integer_slice(value):
326326

327327

328328
class IndexCallable:
329-
"""Provide getitem syntax for a callable object."""
329+
"""Provide getitem and setitem syntax for callable objects."""
330330

331-
__slots__ = ("func",)
331+
__slots__ = ("getter", "setter")
332332

333-
def __init__(self, func):
334-
self.func = func
333+
def __init__(self, getter, setter=None):
334+
self.getter = getter
335+
self.setter = setter
335336

336337
def __getitem__(self, key):
337-
return self.func(key)
338+
return self.getter(key)
338339

339340
def __setitem__(self, key, value):
340-
raise NotImplementedError
341+
if self.setter is None:
342+
raise NotImplementedError(
343+
"Setting values is not supported for this indexer."
344+
)
345+
self.setter(key, value)
341346

342347

343348
class BasicIndexer(ExplicitIndexer):
@@ -486,10 +491,24 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
486491
return np.asarray(self.get_duck_array(), dtype=dtype)
487492

488493
def _oindex_get(self, key):
489-
raise NotImplementedError("This method should be overridden")
494+
raise NotImplementedError(
495+
f"{self.__class__.__name__}._oindex_get method should be overridden"
496+
)
490497

491498
def _vindex_get(self, key):
492-
raise NotImplementedError("This method should be overridden")
499+
raise NotImplementedError(
500+
f"{self.__class__.__name__}._vindex_get method should be overridden"
501+
)
502+
503+
def _oindex_set(self, key, value):
504+
raise NotImplementedError(
505+
f"{self.__class__.__name__}._oindex_set method should be overridden"
506+
)
507+
508+
def _vindex_set(self, key, value):
509+
raise NotImplementedError(
510+
f"{self.__class__.__name__}._vindex_set method should be overridden"
511+
)
493512

494513
def _check_and_raise_if_non_basic_indexer(self, key):
495514
if isinstance(key, (VectorizedIndexer, OuterIndexer)):
@@ -500,11 +519,11 @@ def _check_and_raise_if_non_basic_indexer(self, key):
500519

501520
@property
502521
def oindex(self):
503-
return IndexCallable(self._oindex_get)
522+
return IndexCallable(self._oindex_get, self._oindex_set)
504523

505524
@property
506525
def vindex(self):
507-
return IndexCallable(self._vindex_get)
526+
return IndexCallable(self._vindex_get, self._vindex_set)
508527

509528

510529
class ImplicitToExplicitIndexingAdapter(NDArrayMixin):
@@ -616,12 +635,18 @@ def __getitem__(self, indexer):
616635
self._check_and_raise_if_non_basic_indexer(indexer)
617636
return type(self)(self.array, self._updated_key(indexer))
618637

638+
def _vindex_set(self, key, value):
639+
raise NotImplementedError(
640+
"Lazy item assignment with the vectorized indexer is not yet "
641+
"implemented. Load your data first by .load() or compute()."
642+
)
643+
644+
def _oindex_set(self, key, value):
645+
full_key = self._updated_key(key)
646+
self.array.oindex[full_key] = value
647+
619648
def __setitem__(self, key, value):
620-
if isinstance(key, VectorizedIndexer):
621-
raise NotImplementedError(
622-
"Lazy item assignment with the vectorized indexer is not yet "
623-
"implemented. Load your data first by .load() or compute()."
624-
)
649+
self._check_and_raise_if_non_basic_indexer(key)
625650
full_key = self._updated_key(key)
626651
self.array[full_key] = value
627652

@@ -657,7 +682,6 @@ def shape(self) -> tuple[int, ...]:
657682
return np.broadcast(*self.key.tuple).shape
658683

659684
def get_duck_array(self):
660-
661685
if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
662686
array = apply_indexer(self.array, self.key)
663687
else:
@@ -739,8 +763,18 @@ def __getitem__(self, key):
739763
def transpose(self, order):
740764
return self.array.transpose(order)
741765

766+
def _vindex_set(self, key, value):
767+
self._ensure_copied()
768+
self.array.vindex[key] = value
769+
770+
def _oindex_set(self, key, value):
771+
self._ensure_copied()
772+
self.array.oindex[key] = value
773+
742774
def __setitem__(self, key, value):
775+
self._check_and_raise_if_non_basic_indexer(key)
743776
self._ensure_copied()
777+
744778
self.array[key] = value
745779

746780
def __deepcopy__(self, memo):
@@ -779,7 +813,14 @@ def __getitem__(self, key):
779813
def transpose(self, order):
780814
return self.array.transpose(order)
781815

816+
def _vindex_set(self, key, value):
817+
self.array.vindex[key] = value
818+
819+
def _oindex_set(self, key, value):
820+
self.array.oindex[key] = value
821+
782822
def __setitem__(self, key, value):
823+
self._check_and_raise_if_non_basic_indexer(key)
783824
self.array[key] = value
784825

785826

@@ -950,6 +991,16 @@ def apply_indexer(indexable, indexer):
950991
return indexable[indexer]
951992

952993

994+
def set_with_indexer(indexable, indexer, value):
995+
"""Set values in an indexable object using an indexer."""
996+
if isinstance(indexer, VectorizedIndexer):
997+
indexable.vindex[indexer] = value
998+
elif isinstance(indexer, OuterIndexer):
999+
indexable.oindex[indexer] = value
1000+
else:
1001+
indexable[indexer] = value
1002+
1003+
9531004
def decompose_indexer(
9541005
indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport
9551006
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
@@ -1399,24 +1450,6 @@ def __init__(self, array):
13991450
)
14001451
self.array = array
14011452

1402-
def _indexing_array_and_key(self, key):
1403-
if isinstance(key, OuterIndexer):
1404-
array = self.array
1405-
key = _outer_to_numpy_indexer(key, self.array.shape)
1406-
elif isinstance(key, VectorizedIndexer):
1407-
array = NumpyVIndexAdapter(self.array)
1408-
key = key.tuple
1409-
elif isinstance(key, BasicIndexer):
1410-
array = self.array
1411-
# We want 0d slices rather than scalars. This is achieved by
1412-
# appending an ellipsis (see
1413-
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
1414-
key = key.tuple + (Ellipsis,)
1415-
else:
1416-
raise TypeError(f"unexpected key type: {type(key)}")
1417-
1418-
return array, key
1419-
14201453
def transpose(self, order):
14211454
return self.array.transpose(order)
14221455

@@ -1430,22 +1463,43 @@ def _vindex_get(self, key):
14301463

14311464
def __getitem__(self, key):
14321465
self._check_and_raise_if_non_basic_indexer(key)
1433-
array, key = self._indexing_array_and_key(key)
1466+
1467+
array = self.array
1468+
# We want 0d slices rather than scalars. This is achieved by
1469+
# appending an ellipsis (see
1470+
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
1471+
key = key.tuple + (Ellipsis,)
14341472
return array[key]
14351473

1436-
def __setitem__(self, key, value):
1437-
array, key = self._indexing_array_and_key(key)
1474+
def _safe_setitem(self, array, key, value):
14381475
try:
14391476
array[key] = value
1440-
except ValueError:
1477+
except ValueError as exc:
14411478
# More informative exception if read-only view
14421479
if not array.flags.writeable and not array.flags.owndata:
14431480
raise ValueError(
14441481
"Assignment destination is a view. "
14451482
"Do you want to .copy() array first?"
14461483
)
14471484
else:
1448-
raise
1485+
raise exc
1486+
1487+
def _oindex_set(self, key, value):
1488+
key = _outer_to_numpy_indexer(key, self.array.shape)
1489+
self._safe_setitem(self.array, key, value)
1490+
1491+
def _vindex_set(self, key, value):
1492+
array = NumpyVIndexAdapter(self.array)
1493+
self._safe_setitem(array, key.tuple, value)
1494+
1495+
def __setitem__(self, key, value):
1496+
self._check_and_raise_if_non_basic_indexer(key)
1497+
array = self.array
1498+
# We want 0d slices rather than scalars. This is achieved by
1499+
# appending an ellipsis (see
1500+
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
1501+
key = key.tuple + (Ellipsis,)
1502+
self._safe_setitem(array, key, value)
14491503

14501504

14511505
class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter):
@@ -1488,13 +1542,15 @@ def __getitem__(self, key):
14881542
self._check_and_raise_if_non_basic_indexer(key)
14891543
return self.array[key.tuple]
14901544

1545+
def _oindex_set(self, key, value):
1546+
self.array[key.tuple] = value
1547+
1548+
def _vindex_set(self, key, value):
1549+
raise TypeError("Vectorized indexing is not supported")
1550+
14911551
def __setitem__(self, key, value):
1492-
if isinstance(key, (BasicIndexer, OuterIndexer)):
1493-
self.array[key.tuple] = value
1494-
elif isinstance(key, VectorizedIndexer):
1495-
raise TypeError("Vectorized indexing is not supported")
1496-
else:
1497-
raise TypeError(f"Unrecognized indexer: {key}")
1552+
self._check_and_raise_if_non_basic_indexer(key)
1553+
self.array[key.tuple] = value
14981554

14991555
def transpose(self, order):
15001556
xp = self.array.__array_namespace__()
@@ -1530,19 +1586,20 @@ def __getitem__(self, key):
15301586
self._check_and_raise_if_non_basic_indexer(key)
15311587
return self.array[key.tuple]
15321588

1589+
def _oindex_set(self, key, value):
1590+
num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple)
1591+
if num_non_slices > 1:
1592+
raise NotImplementedError(
1593+
"xarray can't set arrays with multiple " "array indices to dask yet."
1594+
)
1595+
self.array[key.tuple] = value
1596+
1597+
def _vindex_set(self, key, value):
1598+
self.array.vindex[key.tuple] = value
1599+
15331600
def __setitem__(self, key, value):
1534-
if isinstance(key, BasicIndexer):
1535-
self.array[key.tuple] = value
1536-
elif isinstance(key, VectorizedIndexer):
1537-
self.array.vindex[key.tuple] = value
1538-
elif isinstance(key, OuterIndexer):
1539-
num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple)
1540-
if num_non_slices > 1:
1541-
raise NotImplementedError(
1542-
"xarray can't set arrays with multiple "
1543-
"array indices to dask yet."
1544-
)
1545-
self.array[key.tuple] = value
1601+
self._check_and_raise_if_non_basic_indexer(key)
1602+
self.array[key.tuple] = value
15461603

15471604
def transpose(self, order):
15481605
return self.array.transpose(order)

xarray/core/variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def __setitem__(self, key, value):
849849
value = np.moveaxis(value, new_order, range(len(new_order)))
850850

851851
indexable = as_indexable(self._data)
852-
indexable[index_tuple] = value
852+
indexing.set_with_indexer(indexable, index_tuple, value)
853853

854854
@property
855855
def encoding(self) -> dict[Any, Any]:

xarray/tests/test_indexing.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@
2323
B = IndexerMaker(indexing.BasicIndexer)
2424

2525

26+
class TestIndexCallable:
27+
def test_getitem(self):
28+
def getter(key):
29+
return key * 2
30+
31+
indexer = indexing.IndexCallable(getter)
32+
assert indexer[3] == 6
33+
assert indexer[0] == 0
34+
assert indexer[-1] == -2
35+
36+
def test_setitem(self):
37+
def getter(key):
38+
return key * 2
39+
40+
def setter(key, value):
41+
raise NotImplementedError("Setter not implemented")
42+
43+
indexer = indexing.IndexCallable(getter, setter)
44+
with pytest.raises(NotImplementedError):
45+
indexer[3] = 6
46+
47+
2648
class TestIndexers:
2749
def set_to_zero(self, x, i):
2850
x = x.copy()
@@ -361,15 +383,8 @@ def test_vectorized_lazily_indexed_array(self) -> None:
361383

362384
def check_indexing(v_eager, v_lazy, indexers):
363385
for indexer in indexers:
364-
if isinstance(indexer, indexing.VectorizedIndexer):
365-
actual = v_lazy.vindex[indexer]
366-
expected = v_eager.vindex[indexer]
367-
elif isinstance(indexer, indexing.OuterIndexer):
368-
actual = v_lazy.oindex[indexer]
369-
expected = v_eager.oindex[indexer]
370-
else:
371-
actual = v_lazy[indexer]
372-
expected = v_eager[indexer]
386+
actual = v_lazy[indexer]
387+
expected = v_eager[indexer]
373388
assert expected.shape == actual.shape
374389
assert isinstance(
375390
actual._data,
@@ -406,6 +421,41 @@ def check_indexing(v_eager, v_lazy, indexers):
406421
]
407422
check_indexing(v_eager, v_lazy, indexers)
408423

424+
def test_lazily_indexed_array_vindex_setitem(self) -> None:
425+
426+
lazy = indexing.LazilyIndexedArray(np.random.rand(10, 20, 30))
427+
428+
# vectorized indexing
429+
indexer = indexing.VectorizedIndexer(
430+
(np.array([0, 1]), np.array([0, 1]), slice(None, None, None))
431+
)
432+
with pytest.raises(
433+
NotImplementedError,
434+
match=r"Lazy item assignment with the vectorized indexer is not yet",
435+
):
436+
lazy.vindex[indexer] = 0
437+
438+
@pytest.mark.parametrize(
439+
"indexer_class, key, value",
440+
[
441+
(indexing.OuterIndexer, (0, 1, slice(None, None, None)), 10),
442+
(indexing.BasicIndexer, (0, 1, slice(None, None, None)), 10),
443+
],
444+
)
445+
def test_lazily_indexed_array_setitem(self, indexer_class, key, value) -> None:
446+
original = np.random.rand(10, 20, 30)
447+
x = indexing.NumpyIndexingAdapter(original)
448+
lazy = indexing.LazilyIndexedArray(x)
449+
450+
if indexer_class is indexing.BasicIndexer:
451+
indexer = indexer_class(key)
452+
lazy[indexer] = value
453+
elif indexer_class is indexing.OuterIndexer:
454+
indexer = indexer_class(key)
455+
lazy.oindex[indexer] = value
456+
457+
assert_array_equal(original[key], value)
458+
409459

410460
class TestCopyOnWriteArray:
411461
def test_setitem(self) -> None:

0 commit comments

Comments
 (0)