Skip to content

Commit

Permalink
Avoid auto creation of indexes in concat (#8872)
Browse files Browse the repository at this point in the history
* test not creating indexes on concatenation

* construct result dataset using Coordinates object with indexes passed explicitly

* remove unnecessary overwriting of indexes

* ConcatenatableArray class

* use ConcatenableArray in tests

* add regression tests

* fix by performing check

* refactor assert_valid_explicit_coords and rename dims->sizes

* Revert "add regression tests"

This reverts commit beb665a.

* Revert "fix by performing check"

This reverts commit 22f361d.

* Revert "refactor assert_valid_explicit_coords and rename dims->sizes"

This reverts commit 55166fc.

* fix failing test

* possible fix for failing groupby test

* Revert "possible fix for failing groupby test"

This reverts commit 6e9ead6.

* test expand_dims doesn't create Index

* add option to not create 1D index in expand_dims

* refactor tests to consider data variables and coordinate variables separately

* test expand_dims doesn't create Index

* add option to not create 1D index in expand_dims

* refactor tests to consider data variables and coordinate variables separately

* fix bug causing new test to fail

* test index auto-creation when iterable passed as new coordinate values

* make test for iterable pass

* added kwarg to dataarray

* whatsnew

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "refactor tests to consider data variables and coordinate variables separately"

This reverts commit ba5627e.

* Revert "add option to not create 1D index in expand_dims"

This reverts commit 95d453c.

* test that concat doesn't raise if create_1d_index=False

* make test pass by passing create_1d_index down through concat

* assert that an UnexpectedDataAccess error is raised when create_1d_index=True

* eliminate possibility of xarray internals bypassing UnexpectedDataAccess error by accessing .array

* update tests to use private versions of assertions

* create_1d_index->create_index

* Update doc/whats-new.rst

Co-authored-by: Deepak Cherian <[email protected]>

* Rename create_1d_index -> create_index

* fix ConcatenatableArray

* formatting

* whatsnew

* add new create_index kwarg to overloads

* split vars into data_vars and coord_vars in one loop

* avoid mypy error by using new variable name

* warn if create_index=True but no index created because dimension variable was a data var not a coord

* add string marks in warning message

* regression test for dtype changing in to_stacked_array

* correct doctest

* Remove outdated comment

* test we can skip creation of indexes during shape promotion

* make shape promotion test pass

* point to issue in whatsnew

* don't create dimension coordinates just to drop them at the end

* Remove ToDo about not using Coordinates object to pass indexes

Co-authored-by: Deepak Cherian <[email protected]>

* get rid of unlabeled_dims variable entirely

* move ConcatenatableArray and similar to new file

* formatting nit

Co-authored-by: Justus Magin <[email protected]>

* renamed create_index -> create_index_for_new_dim in concat

* renamed create_index -> create_index_for_new_dim in expand_dims

* fix incorrect arg name

* add example to docstring

* add example of using new kwarg to docstring of expand_dims

* add example of using new kwarg to docstring of concat

* re-nit the nit

Co-authored-by: Justus Magin <[email protected]>

* more instances of the nit

* fix docstring doctest formatting nit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Justus Magin <[email protected]>
  • Loading branch information
4 people authored May 8, 2024
1 parent 71661d5 commit 6057128
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 84 deletions.
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ New Features
- :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`)
By `Ignacio Martinez Vazquez <https://github.com/ignamv>`_.
- Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg
`create_index=False`. (:pull:`8960`)
`create_index_for_new_dim=False`. (:pull:`8960`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Avoid automatically re-creating 1D pandas indexes in :py:func:`concat()`. Also added option to avoid creating 1D indexes for
new dimension coordinates by passing the new kwarg `create_index_for_new_dim=False`. (:issue:`8871`, :pull:`8872`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Breaking changes
Expand Down
67 changes: 52 additions & 15 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from xarray.core import dtypes, utils
from xarray.core.alignment import align, reindex_variables
from xarray.core.coordinates import Coordinates
from xarray.core.duck_array_ops import lazy_array_equiv
from xarray.core.indexes import Index, PandasIndex
from xarray.core.merge import (
Expand Down Expand Up @@ -42,6 +43,7 @@ def concat(
fill_value: object = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
create_index_for_new_dim: bool = True,
) -> T_Dataset: ...


Expand All @@ -56,6 +58,7 @@ def concat(
fill_value: object = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
create_index_for_new_dim: bool = True,
) -> T_DataArray: ...


Expand All @@ -69,6 +72,7 @@ def concat(
fill_value=dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
create_index_for_new_dim: bool = True,
):
"""Concatenate xarray objects along a new or existing dimension.
Expand Down Expand Up @@ -162,6 +166,8 @@ def concat(
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
as its only parameters.
create_index_for_new_dim : bool, default: True
Whether to create a new ``PandasIndex`` object when the objects being concatenated contain scalar variables named ``dim``.
Returns
-------
Expand Down Expand Up @@ -217,6 +223,25 @@ def concat(
x (new_dim) <U1 8B 'a' 'b'
* y (y) int64 24B 10 20 30
* new_dim (new_dim) int64 16B -90 -100
# Concatenate a scalar variable along a new dimension of the same name with and without creating a new index
>>> ds = xr.Dataset(coords={"x": 0})
>>> xr.concat([ds, ds], dim="x")
<xarray.Dataset> Size: 16B
Dimensions: (x: 2)
Coordinates:
* x (x) int64 16B 0 0
Data variables:
*empty*
>>> xr.concat([ds, ds], dim="x").indexes
Indexes:
x Index([0, 0], dtype='int64', name='x')
>>> xr.concat([ds, ds], dim="x", create_index_for_new_dim=False).indexes
Indexes:
*empty*
"""
# TODO: add ignore_index arguments copied from pandas.concat
# TODO: support concatenating scalar coordinates even if the concatenated
Expand Down Expand Up @@ -245,6 +270,7 @@ def concat(
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
create_index_for_new_dim=create_index_for_new_dim,
)
elif isinstance(first_obj, Dataset):
return _dataset_concat(
Expand All @@ -257,6 +283,7 @@ def concat(
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
create_index_for_new_dim=create_index_for_new_dim,
)
else:
raise TypeError(
Expand Down Expand Up @@ -439,7 +466,7 @@ def _parse_datasets(
if dim in dims:
continue

if dim not in dim_coords:
if dim in ds.coords and dim not in dim_coords:
dim_coords[dim] = ds.coords[dim].variable
dims = dims | set(ds.dims)

Expand All @@ -456,6 +483,7 @@ def _dataset_concat(
fill_value: Any = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
create_index_for_new_dim: bool = True,
) -> T_Dataset:
"""
Concatenate a sequence of datasets along a new or existing dimension
Expand Down Expand Up @@ -489,7 +517,6 @@ def _dataset_concat(
datasets
)
dim_names = set(dim_coords)
unlabeled_dims = dim_names - coord_names

both_data_and_coords = coord_names & data_names
if both_data_and_coords:
Expand All @@ -502,15 +529,18 @@ def _dataset_concat(

# case where concat dimension is a coordinate or data_var but not a dimension
if (dim in coord_names or dim in data_names) and dim not in dim_names:
datasets = [ds.expand_dims(dim) for ds in datasets]
datasets = [
ds.expand_dims(dim, create_index_for_new_dim=create_index_for_new_dim)
for ds in datasets
]

# determine which variables to concatenate
concat_over, equals, concat_dim_lengths = _calc_concat_over(
datasets, dim, dim_names, data_vars, coords, compat
)

# determine which variables to merge, and then merge them according to compat
variables_to_merge = (coord_names | data_names) - concat_over - unlabeled_dims
variables_to_merge = (coord_names | data_names) - concat_over

result_vars = {}
result_indexes = {}
Expand Down Expand Up @@ -567,7 +597,8 @@ def get_indexes(name):
var = ds._variables[name]
if not var.dims:
data = var.set_dims(dim).values
yield PandasIndex(data, dim, coord_dtype=var.dtype)
if create_index_for_new_dim:
yield PandasIndex(data, dim, coord_dtype=var.dtype)

# create concatenation index, needed for later reindexing
file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
Expand Down Expand Up @@ -646,29 +677,33 @@ def get_indexes(name):
# preserves original variable order
result_vars[name] = result_vars.pop(name)

result = type(datasets[0])(result_vars, attrs=result_attrs)

absent_coord_names = coord_names - set(result.variables)
absent_coord_names = coord_names - set(result_vars)
if absent_coord_names:
raise ValueError(
f"Variables {absent_coord_names!r} are coordinates in some datasets but not others."
)
result = result.set_coords(coord_names)
result.encoding = result_encoding

result = result.drop_vars(unlabeled_dims, errors="ignore")
result_data_vars = {}
coord_vars = {}
for name, result_var in result_vars.items():
if name in coord_names:
coord_vars[name] = result_var
else:
result_data_vars[name] = result_var

if index is not None:
# add concat index / coordinate last to ensure that its in the final Dataset
if dim_var is not None:
index_vars = index.create_variables({dim: dim_var})
else:
index_vars = index.create_variables()
result[dim] = index_vars[dim]

coord_vars[dim] = index_vars[dim]
result_indexes[dim] = index

# TODO: add indexes at Dataset creation (when it is supported)
result = result._overwrite_indexes(result_indexes)
coords_obj = Coordinates(coord_vars, indexes=result_indexes)

result = type(datasets[0])(result_data_vars, coords=coords_obj, attrs=result_attrs)
result.encoding = result_encoding

return result

Expand All @@ -683,6 +718,7 @@ def _dataarray_concat(
fill_value: object = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
create_index_for_new_dim: bool = True,
) -> T_DataArray:
from xarray.core.dataarray import DataArray

Expand Down Expand Up @@ -719,6 +755,7 @@ def _dataarray_concat(
fill_value=fill_value,
join=join,
combine_attrs=combine_attrs,
create_index_for_new_dim=create_index_for_new_dim,
)

merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs)
Expand Down
12 changes: 7 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2558,7 +2558,7 @@ def expand_dims(
self,
dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
axis: None | int | Sequence[int] = None,
create_index: bool = True,
create_index_for_new_dim: bool = True,
**dim_kwargs: Any,
) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
Expand All @@ -2569,7 +2569,7 @@ def expand_dims(
coordinate consisting of a single value.
The automatic creation of indexes to back new 1D coordinate variables
controlled by the create_index kwarg.
controlled by the create_index_for_new_dim kwarg.
Parameters
----------
Expand All @@ -2586,8 +2586,8 @@ def expand_dims(
multiple axes are inserted. In this case, dim arguments should be
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
create_index : bool, default is True
Whether to create new PandasIndex objects for any new 1D coordinate variables.
create_index_for_new_dim : bool, default: True
Whether to create new ``PandasIndex`` objects when the object being expanded contains scalar variables with names in ``dim``.
**dim_kwargs : int or sequence or ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
Expand Down Expand Up @@ -2651,7 +2651,9 @@ def expand_dims(
dim = {dim: 1}

dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims")
ds = self._to_temp_dataset().expand_dims(dim, axis, create_index=create_index)
ds = self._to_temp_dataset().expand_dims(
dim, axis, create_index_for_new_dim=create_index_for_new_dim
)
return self._from_temp_dataset(ds)

def set_index(
Expand Down
43 changes: 35 additions & 8 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4513,7 +4513,7 @@ def expand_dims(
self,
dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
axis: None | int | Sequence[int] = None,
create_index: bool = True,
create_index_for_new_dim: bool = True,
**dim_kwargs: Any,
) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
Expand All @@ -4524,7 +4524,7 @@ def expand_dims(
coordinate consisting of a single value.
The automatic creation of indexes to back new 1D coordinate variables
controlled by the create_index kwarg.
controlled by the create_index_for_new_dim kwarg.
Parameters
----------
Expand All @@ -4541,8 +4541,8 @@ def expand_dims(
multiple axes are inserted. In this case, dim arguments should be
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
create_index : bool, default is True
Whether to create new PandasIndex objects for any new 1D coordinate variables.
create_index_for_new_dim : bool, default: True
Whether to create new ``PandasIndex`` objects when the object being expanded contains scalar variables with names in ``dim``.
**dim_kwargs : int or sequence or ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
Expand Down Expand Up @@ -4612,6 +4612,33 @@ def expand_dims(
Data variables:
temperature (y, x, time) float64 96B 0.5488 0.7152 0.6028 ... 0.7917 0.5289
# Expand a scalar variable along a new dimension of the same name with and without creating a new index
>>> ds = xr.Dataset(coords={"x": 0})
>>> ds
<xarray.Dataset> Size: 8B
Dimensions: ()
Coordinates:
x int64 8B 0
Data variables:
*empty*
>>> ds.expand_dims("x")
<xarray.Dataset> Size: 8B
Dimensions: (x: 1)
Coordinates:
* x (x) int64 8B 0
Data variables:
*empty*
>>> ds.expand_dims("x").indexes
Indexes:
x Index([0], dtype='int64', name='x')
>>> ds.expand_dims("x", create_index_for_new_dim=False).indexes
Indexes:
*empty*
See Also
--------
DataArray.expand_dims
Expand Down Expand Up @@ -4663,7 +4690,7 @@ def expand_dims(
# value within the dim dict to the length of the iterable
# for later use.

if create_index:
if create_index_for_new_dim:
index = PandasIndex(v, k)
indexes[k] = index
name_and_new_1d_var = index.create_variables()
Expand Down Expand Up @@ -4705,14 +4732,14 @@ def expand_dims(
variables[k] = v.set_dims(dict(all_dims))
else:
if k not in variables:
if k in coord_names and create_index:
if k in coord_names and create_index_for_new_dim:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
index, index_vars = create_default_index_implicit(v.set_dims(k))
indexes[k] = index
variables.update(index_vars)
else:
if create_index:
if create_index_for_new_dim:
warnings.warn(
f"No index created for dimension {k} because variable {k} is not a coordinate. "
f"To create an index for {k}, please first call `.set_coords('{k}')` on this object.",
Expand Down Expand Up @@ -5400,7 +5427,7 @@ def to_stacked_array(
[3, 4, 5, 7]])
Coordinates:
* z (z) object 32B MultiIndex
* variable (z) object 32B 'a' 'a' 'a' 'b'
* variable (z) <U1 16B 'a' 'a' 'a' 'b'
* y (z) object 32B 'u' 'v' 'w' nan
Dimensions without coordinates: x
Expand Down
Loading

0 comments on commit 6057128

Please sign in to comment.