Skip to content

Commit e47eb92

Browse files
authored
introduce .vindex property for Explicitly Indexed Arrays (#8780)
1 parent 0ec1912 commit e47eb92

File tree

3 files changed

+47
-6
lines changed

3 files changed

+47
-6
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ New Features
2626
- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`)
2727
By `Anderson Banihirwe <https://github.com/andersy005>`_.
2828

29+
- Add the ``.vindex`` property to Explicitly Indexed Arrays for vectorized indexing functionality. (:issue:`8238`, :pull:`8780`)
30+
By `Anderson Banihirwe <https://github.com/andersy005>`_.
2931

3032
Breaking changes
3133
~~~~~~~~~~~~~~~~

xarray/core/indexing.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,10 +488,17 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
488488
def _oindex_get(self, key):
489489
raise NotImplementedError("This method should be overridden")
490490

491+
def _vindex_get(self, key):
492+
raise NotImplementedError("This method should be overridden")
493+
491494
@property
492495
def oindex(self):
493496
return IndexCallable(self._oindex_get)
494497

498+
@property
499+
def vindex(self):
500+
return IndexCallable(self._vindex_get)
501+
495502

496503
class ImplicitToExplicitIndexingAdapter(NDArrayMixin):
497504
"""Wrap an array, converting tuples into the indicated explicit indexer."""
@@ -585,6 +592,10 @@ def transpose(self, order):
585592
def _oindex_get(self, indexer):
586593
return type(self)(self.array, self._updated_key(indexer))
587594

595+
def _vindex_get(self, indexer):
596+
array = LazilyVectorizedIndexedArray(self.array, self.key)
597+
return array[indexer]
598+
588599
def __getitem__(self, indexer):
589600
if isinstance(indexer, VectorizedIndexer):
590601
array = LazilyVectorizedIndexedArray(self.array, self.key)
@@ -644,6 +655,12 @@ def get_duck_array(self):
644655
def _updated_key(self, new_key):
645656
return _combine_indexers(self.key, self.shape, new_key)
646657

658+
def _oindex_get(self, indexer):
659+
return type(self)(self.array, self._updated_key(indexer))
660+
661+
def _vindex_get(self, indexer):
662+
return type(self)(self.array, self._updated_key(indexer))
663+
647664
def __getitem__(self, indexer):
648665
# If the indexed array becomes a scalar, return LazilyIndexedArray
649666
if all(isinstance(ind, integer_types) for ind in indexer.tuple):
@@ -691,6 +708,9 @@ def get_duck_array(self):
691708
def _oindex_get(self, key):
692709
return type(self)(_wrap_numpy_scalars(self.array[key]))
693710

711+
def _vindex_get(self, key):
712+
return type(self)(_wrap_numpy_scalars(self.array[key]))
713+
694714
def __getitem__(self, key):
695715
return type(self)(_wrap_numpy_scalars(self.array[key]))
696716

@@ -727,6 +747,9 @@ def get_duck_array(self):
727747
def _oindex_get(self, key):
728748
return type(self)(_wrap_numpy_scalars(self.array[key]))
729749

750+
def _vindex_get(self, key):
751+
return type(self)(_wrap_numpy_scalars(self.array[key]))
752+
730753
def __getitem__(self, key):
731754
return type(self)(_wrap_numpy_scalars(self.array[key]))
732755

@@ -1364,8 +1387,12 @@ def transpose(self, order):
13641387
return self.array.transpose(order)
13651388

13661389
def _oindex_get(self, key):
1367-
array, key = self._indexing_array_and_key(key)
1368-
return array[key]
1390+
key = _outer_to_numpy_indexer(key, self.array.shape)
1391+
return self.array[key]
1392+
1393+
def _vindex_get(self, key):
1394+
array = NumpyVIndexAdapter(self.array)
1395+
return array[key.tuple]
13691396

13701397
def __getitem__(self, key):
13711398
array, key = self._indexing_array_and_key(key)
@@ -1419,6 +1446,9 @@ def _oindex_get(self, key):
14191446
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
14201447
return value
14211448

1449+
def _vindex_get(self, key):
1450+
raise TypeError("Vectorized indexing is not supported")
1451+
14221452
def __getitem__(self, key):
14231453
if isinstance(key, BasicIndexer):
14241454
return self.array[key.tuple]
@@ -1465,11 +1495,14 @@ def _oindex_get(self, key):
14651495
value = value[(slice(None),) * axis + (subkey,)]
14661496
return value
14671497

1498+
def _vindex_get(self, key):
1499+
return self.array.vindex[key.tuple]
1500+
14681501
def __getitem__(self, key):
14691502
if isinstance(key, BasicIndexer):
14701503
return self.array[key.tuple]
14711504
elif isinstance(key, VectorizedIndexer):
1472-
return self.array.vindex[key.tuple]
1505+
return self.vindex[key]
14731506
else:
14741507
assert isinstance(key, OuterIndexer)
14751508
return self.oindex[key]
@@ -1551,6 +1584,9 @@ def _convert_scalar(self, item):
15511584
def _oindex_get(self, key):
15521585
return self.__getitem__(key)
15531586

1587+
def _vindex_get(self, key):
1588+
return self.__getitem__(key)
1589+
15541590
def __getitem__(
15551591
self, indexer
15561592
) -> (

xarray/core/variable.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -759,10 +759,10 @@ def __getitem__(self, key) -> Self:
759759
dims, indexer, new_order = self._broadcast_indexes(key)
760760
indexable = as_indexable(self._data)
761761

762-
if isinstance(indexer, BasicIndexer):
763-
data = indexable[indexer]
764-
elif isinstance(indexer, OuterIndexer):
762+
if isinstance(indexer, OuterIndexer):
765763
data = indexable.oindex[indexer]
764+
elif isinstance(indexer, VectorizedIndexer):
765+
data = indexable.vindex[indexer]
766766
else:
767767
data = indexable[indexer]
768768
if new_order:
@@ -801,6 +801,9 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA):
801801

802802
if isinstance(indexer, OuterIndexer):
803803
data = indexable.oindex[indexer]
804+
805+
elif isinstance(indexer, VectorizedIndexer):
806+
data = indexable.vindex[indexer]
804807
else:
805808
data = indexable[actual_indexer]
806809
mask = indexing.create_mask(indexer, self.shape, data)

0 commit comments

Comments
 (0)