Skip to content

Commit

Permalink
Option to not auto-create index during expand_dims (#8960)
Browse files Browse the repository at this point in the history
* test expand_dims doesn't create Index

* add option to not create 1D index in expand_dims

* refactor tests to consider data variables and coordinate variables separately

* fix bug causing new test to fail

* test index auto-creation when iterable passed as new coordinate values

* make test for iterable pass

* added kwarg to dataarray

* whatsnew

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update tests to use private versions of assertions

* create_1d_index->create_index

* Update doc/whats-new.rst

Co-authored-by: Deepak Cherian <[email protected]>

* warn if create_index=True but no index created because dimension variable was a data var not a coord

* add string marks in warning message

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
3 people authored Apr 27, 2024
1 parent d7edbd7 commit 214d941
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 9 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ New Features
for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray`
then, such as broadcasting.
By `Ilan Gold <https://github.com/ilan-gold>`_.
- Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg
`create_index=False`. (:pull:`8960`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
8 changes: 7 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2557,6 +2557,7 @@ def expand_dims(
self,
dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
axis: None | int | Sequence[int] = None,
create_index: bool = True,
**dim_kwargs: Any,
) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
Expand All @@ -2566,6 +2567,9 @@ def expand_dims(
If dim is already a scalar coordinate, it will be promoted to a 1D
coordinate consisting of a single value.
The automatic creation of indexes to back new 1D coordinate variables
is controlled by the create_index kwarg.
Parameters
----------
dim : Hashable, sequence of Hashable, dict, or None, optional
Expand All @@ -2581,6 +2585,8 @@ def expand_dims(
multiple axes are inserted. In this case, dim arguments should be
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
create_index : bool, default is True
Whether to create new PandasIndex objects for any new 1D coordinate variables.
**dim_kwargs : int or sequence or ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
Expand Down Expand Up @@ -2644,7 +2650,7 @@ def expand_dims(
dim = {dim: 1}

dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims")
ds = self._to_temp_dataset().expand_dims(dim, axis)
ds = self._to_temp_dataset().expand_dims(dim, axis, create_index=create_index)
return self._from_temp_dataset(ds)

def set_index(
Expand Down
39 changes: 31 additions & 8 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4497,6 +4497,7 @@ def expand_dims(
self,
dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
axis: None | int | Sequence[int] = None,
create_index: bool = True,
**dim_kwargs: Any,
) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
Expand All @@ -4506,6 +4507,9 @@ def expand_dims(
If dim is already a scalar coordinate, it will be promoted to a 1D
coordinate consisting of a single value.
The automatic creation of indexes to back new 1D coordinate variables
is controlled by the create_index kwarg.
Parameters
----------
dim : hashable, sequence of hashable, mapping, or None
Expand All @@ -4521,6 +4525,8 @@ def expand_dims(
multiple axes are inserted. In this case, dim arguments should be
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
create_index : bool, default is True
Whether to create new PandasIndex objects for any new 1D coordinate variables.
**dim_kwargs : int or sequence or ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
Expand Down Expand Up @@ -4640,9 +4646,14 @@ def expand_dims(
# save the coordinates to the variables dict, and set the
# value within the dim dict to the length of the iterable
# for later use.
index = PandasIndex(v, k)
indexes[k] = index
variables.update(index.create_variables())

if create_index:
index = PandasIndex(v, k)
indexes[k] = index
name_and_new_1d_var = index.create_variables()
else:
name_and_new_1d_var = {k: Variable(data=v, dims=k)}
variables.update(name_and_new_1d_var)
coord_names.add(k)
dim[k] = variables[k].size
elif isinstance(v, int):
Expand Down Expand Up @@ -4678,11 +4689,23 @@ def expand_dims(
variables[k] = v.set_dims(dict(all_dims))
else:
if k not in variables:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
index, index_vars = create_default_index_implicit(v.set_dims(k))
indexes[k] = index
variables.update(index_vars)
if k in coord_names and create_index:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
index, index_vars = create_default_index_implicit(v.set_dims(k))
indexes[k] = index
variables.update(index_vars)
else:
if create_index:
warnings.warn(
f"No index created for dimension {k} because variable {k} is not a coordinate. "
f"To create an index for {k}, please first call `.set_coords('{k}')` on this object.",
UserWarning,
)

# create 1D variable without creating a new index
new_1d_var = v.set_dims(k)
variables.update({k: new_1d_var})

return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
Expand Down
46 changes: 46 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3430,6 +3430,52 @@ def test_expand_dims_kwargs_python36plus(self) -> None:
)
assert_identical(other_way_expected, other_way)

@pytest.mark.parametrize("create_index_flag", [True, False])
def test_expand_dims_create_index_data_variable(self, create_index_flag):
    """Expanding along a scalar *data variable* must never produce an index.

    With ``create_index=True`` the call is expected to emit a ``UserWarning``
    matching "No index created"; with ``create_index=False`` the call is made
    without any warning check.
    """
    scalar_ds = Dataset({"x": 0})

    if not create_index_flag:
        expanded = scalar_ds.expand_dims("x", create_index=create_index_flag)
    else:
        with pytest.warns(UserWarning, match="No index created"):
            expanded = scalar_ds.expand_dims("x", create_index=create_index_flag)

    # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959
    # Build it indirectly: make "x" a 1D coordinate, drop its index, then
    # demote it back to a data variable.
    with_indexed_coord = Dataset({"x": ("x", [0])})
    expected = with_indexed_coord.drop_indexes("x").reset_coords("x")

    assert_identical(expanded, expected, check_default_indexes=False)
    assert expanded.indexes == {}

def test_expand_dims_create_index_coordinate_variable(self):
    """A scalar coordinate gains an index on expansion only if ``create_index`` is True (the default)."""
    scalar_coord_ds = Dataset(coords={"x": 0})

    # Default call: result must be identical to a dataset built with a 1D "x".
    with_index = scalar_coord_ds.expand_dims("x")
    expected_with_index = Dataset({"x": ("x", [0])})
    assert_identical(with_index, expected_with_index)

    # Opt-out call: "x" is still a 1D coordinate, but no index backs it.
    without_index = scalar_coord_ds.expand_dims("x", create_index=False)

    # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959
    expected_without_index = Dataset(coords={"x": ("x", [0])}).drop_indexes("x")

    assert_identical(
        without_index, expected_without_index, check_default_indexes=False
    )
    assert without_index.indexes == {}

def test_expand_dims_create_index_from_iterable(self):
    """Passing an iterable of coordinate values creates an index only if ``create_index`` is True."""
    ds = Dataset(coords={"x": 0})

    # Default call: the iterable becomes a 1D coordinate "x".
    expanded = ds.expand_dims(x=[0, 1])
    expected = Dataset({"x": ("x", [0, 1])})
    assert_identical(expanded, expected)

    expanded_no_index = ds.expand_dims(x=[0, 1], create_index=False)

    # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959
    expected = Dataset(coords={"x": ("x", [0, 1])}).drop_indexes("x")

    # Compare the index-free result (not ``expanded``) so the values produced
    # by the create_index=False path are actually checked.
    assert_identical(expanded_no_index, expected, check_default_indexes=False)
    assert expanded_no_index.indexes == {}

def test_expand_dims_non_nanosecond_conversion(self) -> None:
# Regression test for https://github.com/pydata/xarray/issues/7493#issuecomment-1953091000
with pytest.warns(UserWarning, match="non-nanosecond precision"):
Expand Down

0 comments on commit 214d941

Please sign in to comment.