Skip to content

Commit 4bb9d9c

Browse files
authored
Refactor index vs. coordinate variable(s) (#5636)
* split index / coordinate variable(s) - Pass Variable objects to xarray.Index constructor - The index should create IndexVariable objects (`coords` attribute) - PandasIndex: IndexVariable wraps PandasIndexingAdapter wraps pd.Index * one PandasIndexingAdapter subclass for multiindex * fastpath Index init + from_pandas_index classmethods * use classmethod constructors instead * add Index.copy and Index.__getitem__ methods * wip: clean-up Revert some changes made in #5102 + additional (temporary) fixes. * clean-up * add PandasIndex and PandasMultiIndex tests * remove unused import * doc: update what's new * use xindexes in map_blocks + temp fix Dataset constructor doesn't accept xarray indexes yet. Create new coordinates from the underlying pandas indexes. * update what's new with #5670 * typo
1 parent 08b3e80 commit 4bb9d9c

17 files changed

+608
-282
lines changed

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ Documentation
4343
Internal Changes
4444
~~~~~~~~~~~~~~~~
4545

46+
- Explicit indexes refactor: avoid ``len(index)`` in ``map_blocks`` (:pull:`5670`).
47+
By `Deepak Cherian <https://github.com/dcherian>`_.
48+
- Explicit indexes refactor: decouple ``xarray.Index`` from ``xarray.Variable`` (:pull:`5636`).
49+
By `Benoit Bovy <https://github.com/benbovy>`_.
4650
- Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`)
4751
By `Jimmy Westling <https://github.com/illviljan>`_.
4852

xarray/core/alignment.py

+30-16
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pandas as pd
1919

2020
from . import dtypes
21-
from .indexes import Index, PandasIndex, get_indexer_nd, wrap_pandas_index
21+
from .indexes import Index, PandasIndex, get_indexer_nd
2222
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index
2323
from .variable import IndexVariable, Variable
2424

@@ -53,7 +53,10 @@ def _get_joiner(join, index_cls):
5353
def _override_indexes(objects, all_indexes, exclude):
5454
for dim, dim_indexes in all_indexes.items():
5555
if dim not in exclude:
56-
lengths = {index.size for index in dim_indexes}
56+
lengths = {
57+
getattr(index, "size", index.to_pandas_index().size)
58+
for index in dim_indexes
59+
}
5760
if len(lengths) != 1:
5861
raise ValueError(
5962
f"Indexes along dimension {dim!r} don't have the same length."
@@ -300,16 +303,14 @@ def align(
300303
joined_indexes = {}
301304
for dim, matching_indexes in all_indexes.items():
302305
if dim in indexes:
303-
# TODO: benbovy - flexible indexes. maybe move this logic in util func
304-
if isinstance(indexes[dim], Index):
305-
index = indexes[dim]
306-
else:
307-
index = PandasIndex(safe_cast_to_index(indexes[dim]))
306+
index, _ = PandasIndex.from_pandas_index(
307+
safe_cast_to_index(indexes[dim]), dim
308+
)
308309
if (
309310
any(not index.equals(other) for other in matching_indexes)
310311
or dim in unlabeled_dim_sizes
311312
):
312-
joined_indexes[dim] = index
313+
joined_indexes[dim] = indexes[dim]
313314
else:
314315
if (
315316
any(
@@ -323,17 +324,18 @@ def align(
323324
joiner = _get_joiner(join, type(matching_indexes[0]))
324325
index = joiner(matching_indexes)
325326
# make sure str coords are not cast to object
326-
index = maybe_coerce_to_str(index, all_coords[dim])
327+
index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim])
327328
joined_indexes[dim] = index
328329
else:
329330
index = all_coords[dim][0]
330331

331332
if dim in unlabeled_dim_sizes:
332333
unlabeled_sizes = unlabeled_dim_sizes[dim]
333-
# TODO: benbovy - flexible indexes: expose a size property for xarray.Index?
334-
# Some indexes may not have a defined size (e.g., built from multiple coords of
335-
# different sizes)
336-
labeled_size = index.size
334+
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
335+
if isinstance(index, PandasIndex):
336+
labeled_size = index.to_pandas_index().size
337+
else:
338+
labeled_size = index.size
337339
if len(unlabeled_sizes | {labeled_size}) > 1:
338340
raise ValueError(
339341
f"arguments without labels along dimension {dim!r} cannot be "
@@ -350,7 +352,14 @@ def align(
350352

351353
result = []
352354
for obj in objects:
353-
valid_indexers = {k: v for k, v in joined_indexes.items() if k in obj.dims}
355+
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
356+
valid_indexers = {}
357+
for k, index in joined_indexes.items():
358+
if k in obj.dims:
359+
if isinstance(index, Index):
360+
valid_indexers[k] = index.to_pandas_index()
361+
else:
362+
valid_indexers[k] = index
354363
if not valid_indexers:
355364
# fast path for no reindexing necessary
356365
new_obj = obj.copy(deep=copy)
@@ -471,7 +480,11 @@ def reindex_like_indexers(
471480
ValueError
472481
If any dimensions without labels have different sizes.
473482
"""
474-
indexers = {k: v for k, v in other.xindexes.items() if k in target.dims}
483+
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
484+
# this doesn't yet support indexes other than pd.Index
485+
indexers = {
486+
k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims
487+
}
475488

476489
for dim in other.dims:
477490
if dim not in indexers and dim in target.dims:
@@ -560,7 +573,8 @@ def reindex_variables(
560573
"from that to be indexed along {:s}".format(str(indexer.dims), dim)
561574
)
562575

563-
target = new_indexes[dim] = wrap_pandas_index(safe_cast_to_index(indexers[dim]))
576+
target = safe_cast_to_index(indexers[dim])
577+
new_indexes[dim] = PandasIndex(target, dim)
564578

565579
if dim in indexes:
566580
# TODO (benbovy - flexible indexes): support other indexes than pd.Index?

xarray/core/combine.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,8 @@ def _infer_concat_order_from_coords(datasets):
7777
"inferring concatenation order"
7878
)
7979

80-
# TODO (benbovy, flexible indexes): all indexes should be Pandas.Index
81-
# get pd.Index objects from Index objects
82-
indexes = [index.array for index in indexes]
80+
# TODO (benbovy, flexible indexes): support flexible indexes?
81+
indexes = [index.to_pandas_index() for index in indexes]
8382

8483
# If dimension coordinate values are same on every dataset then
8584
# should be leaving this dimension alone (it's just a "bystander")

xarray/core/dataarray.py

+5-17
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,7 @@
5151
)
5252
from .dataset import Dataset, split_indexes
5353
from .formatting import format_item
54-
from .indexes import (
55-
Index,
56-
Indexes,
57-
default_indexes,
58-
propagate_indexes,
59-
wrap_pandas_index,
60-
)
54+
from .indexes import Index, Indexes, default_indexes, propagate_indexes
6155
from .indexing import is_fancy_indexer
6256
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
6357
from .options import OPTIONS, _get_keep_attrs
@@ -473,15 +467,14 @@ def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
473467
return self
474468
coords = self._coords.copy()
475469
for name, idx in indexes.items():
476-
coords[name] = IndexVariable(name, idx)
470+
coords[name] = IndexVariable(name, idx.to_pandas_index())
477471
obj = self._replace(coords=coords)
478472

479473
# switch from dimension to level names, if necessary
480474
dim_names: Dict[Any, str] = {}
481475
for dim, idx in indexes.items():
482-
# TODO: benbovy - flexible indexes: update when MultiIndex has its own class
483-
pd_idx = idx.array
484-
if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim:
476+
pd_idx = idx.to_pandas_index()
477+
if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim:
485478
dim_names[dim] = idx.name
486479
if dim_names:
487480
obj = obj.rename(dim_names)
@@ -1046,12 +1039,7 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
10461039
if self._indexes is None:
10471040
indexes = self._indexes
10481041
else:
1049-
# TODO: benbovy: flexible indexes: support all xarray indexes (not just pandas.Index)
1050-
# xarray Index needs a copy method.
1051-
indexes = {
1052-
k: wrap_pandas_index(v.to_pandas_index().copy(deep=deep))
1053-
for k, v in self._indexes.items()
1054-
}
1042+
indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()}
10551043
return self._replace(variable, coords, indexes=indexes)
10561044

10571045
def __copy__(self) -> "DataArray":

xarray/core/dataset.py

+32-20
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
propagate_indexes,
7272
remove_unused_levels_categories,
7373
roll_index,
74-
wrap_pandas_index,
7574
)
7675
from .indexing import is_fancy_indexer
7776
from .merge import (
@@ -1184,7 +1183,7 @@ def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> "Dataset":
11841183
variables = self._variables.copy()
11851184
new_indexes = dict(self.xindexes)
11861185
for name, idx in indexes.items():
1187-
variables[name] = IndexVariable(name, idx)
1186+
variables[name] = IndexVariable(name, idx.to_pandas_index())
11881187
new_indexes[name] = idx
11891188
obj = self._replace(variables, indexes=new_indexes)
11901189

@@ -2474,6 +2473,10 @@ def sel(
24742473
pos_indexers, new_indexes = remap_label_indexers(
24752474
self, indexers=indexers, method=method, tolerance=tolerance
24762475
)
2476+
# TODO: benbovy - flexible indexes: also use variables returned by Index.query
2477+
# (temporary dirty fix).
2478+
new_indexes = {k: v[0] for k, v in new_indexes.items()}
2479+
24772480
result = self.isel(indexers=pos_indexers, drop=drop)
24782481
return result._overwrite_indexes(new_indexes)
24792482

@@ -3297,20 +3300,21 @@ def _rename_dims(self, name_dict):
32973300
return {name_dict.get(k, k): v for k, v in self.dims.items()}
32983301

32993302
def _rename_indexes(self, name_dict, dims_set):
3303+
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645
33003304
if self._indexes is None:
33013305
return None
33023306
indexes = {}
3303-
for k, v in self.xindexes.items():
3304-
# TODO: benbovy - flexible indexes: make it compatible with any xarray Index
3305-
index = v.to_pandas_index()
3307+
for k, v in self.indexes.items():
33063308
new_name = name_dict.get(k, k)
33073309
if new_name not in dims_set:
33083310
continue
3309-
if isinstance(index, pd.MultiIndex):
3310-
new_names = [name_dict.get(k, k) for k in index.names]
3311-
indexes[new_name] = PandasMultiIndex(index.rename(names=new_names))
3311+
if isinstance(v, pd.MultiIndex):
3312+
new_names = [name_dict.get(k, k) for k in v.names]
3313+
indexes[new_name] = PandasMultiIndex(
3314+
v.rename(names=new_names), new_name
3315+
)
33123316
else:
3313-
indexes[new_name] = PandasIndex(index.rename(new_name))
3317+
indexes[new_name] = PandasIndex(v.rename(new_name), new_name)
33143318
return indexes
33153319

33163320
def _rename_all(self, name_dict, dims_dict):
@@ -3539,7 +3543,10 @@ def swap_dims(
35393543
if new_index.nlevels == 1:
35403544
# make sure index name matches dimension name
35413545
new_index = new_index.rename(k)
3542-
indexes[k] = wrap_pandas_index(new_index)
3546+
if isinstance(new_index, pd.MultiIndex):
3547+
indexes[k] = PandasMultiIndex(new_index, k)
3548+
else:
3549+
indexes[k] = PandasIndex(new_index, k)
35433550
else:
35443551
var = v.to_base_variable()
35453552
var.dims = dims
@@ -3812,7 +3819,7 @@ def reorder_levels(
38123819
raise ValueError(f"coordinate {dim} has no MultiIndex")
38133820
new_index = index.reorder_levels(order)
38143821
variables[dim] = IndexVariable(coord.dims, new_index)
3815-
indexes[dim] = PandasMultiIndex(new_index)
3822+
indexes[dim] = PandasMultiIndex(new_index, dim)
38163823

38173824
return self._replace(variables, indexes=indexes)
38183825

@@ -3840,7 +3847,7 @@ def _stack_once(self, dims, new_dim):
38403847
coord_names = set(self._coord_names) - set(dims) | {new_dim}
38413848

38423849
indexes = {k: v for k, v in self.xindexes.items() if k not in dims}
3843-
indexes[new_dim] = wrap_pandas_index(idx)
3850+
indexes[new_dim] = PandasMultiIndex(idx, new_dim)
38443851

38453852
return self._replace_with_new_dims(
38463853
variables, coord_names=coord_names, indexes=indexes
@@ -4029,8 +4036,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
40294036
variables[name] = var
40304037

40314038
for name, lev in zip(index.names, index.levels):
4032-
variables[name] = IndexVariable(name, lev)
4033-
indexes[name] = PandasIndex(lev)
4039+
idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
4040+
variables[name] = idx_vars[name]
4041+
indexes[name] = idx
40344042

40354043
coord_names = set(self._coord_names) - {dim} | set(index.names)
40364044

@@ -4068,8 +4076,9 @@ def _unstack_full_reindex(
40684076
variables[name] = var
40694077

40704078
for name, lev in zip(new_dim_names, index.levels):
4071-
variables[name] = IndexVariable(name, lev)
4072-
indexes[name] = PandasIndex(lev)
4079+
idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
4080+
variables[name] = idx_vars[name]
4081+
indexes[name] = idx
40734082

40744083
coord_names = set(self._coord_names) - {dim} | set(new_dim_names)
40754084

@@ -5839,10 +5848,13 @@ def diff(self, dim, n=1, label="upper"):
58395848

58405849
indexes = dict(self.xindexes)
58415850
if dim in indexes:
5842-
# TODO: benbovy - flexible indexes: check slicing of xarray indexes?
5843-
# or only allow this for pandas indexes?
5844-
index = indexes[dim].to_pandas_index()
5845-
indexes[dim] = PandasIndex(index[kwargs_new[dim]])
5851+
if isinstance(indexes[dim], PandasIndex):
5852+
# maybe optimize? (pandas index already indexed above with var.isel)
5853+
new_index = indexes[dim].index[kwargs_new[dim]]
5854+
if isinstance(new_index, pd.MultiIndex):
5855+
indexes[dim] = PandasMultiIndex(new_index, dim)
5856+
else:
5857+
indexes[dim] = PandasIndex(new_index, dim)
58465858

58475859
difference = self._replace_with_new_dims(variables, indexes=indexes)
58485860

0 commit comments

Comments
 (0)