Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to not auto-create index during expand_dims #8960

Merged
merged 18 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ New Features
for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray`
then, such as broadcasting.
By `Ilan Gold <https://github.com/ilan-gold>`_.
- Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg
``create_index=False``. (:pull:`8960`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
8 changes: 7 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2557,6 +2557,7 @@ def expand_dims(
self,
dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
axis: None | int | Sequence[int] = None,
create_index: bool = True,
**dim_kwargs: Any,
) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
Expand All @@ -2566,6 +2567,9 @@ def expand_dims(
If dim is already a scalar coordinate, it will be promoted to a 1D
coordinate consisting of a single value.

The automatic creation of indexes to back new 1D coordinate variables
is controlled by the create_index kwarg.

Parameters
----------
dim : Hashable, sequence of Hashable, dict, or None, optional
Expand All @@ -2581,6 +2585,8 @@ def expand_dims(
multiple axes are inserted. In this case, dim arguments should be
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
create_index : bool, default is True
Whether to create new PandasIndex objects for any new 1D coordinate variables.
**dim_kwargs : int or sequence or ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
Expand Down Expand Up @@ -2644,7 +2650,7 @@ def expand_dims(
dim = {dim: 1}

dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims")
ds = self._to_temp_dataset().expand_dims(dim, axis)
ds = self._to_temp_dataset().expand_dims(dim, axis, create_index=create_index)
return self._from_temp_dataset(ds)

def set_index(
Expand Down
39 changes: 31 additions & 8 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4497,6 +4497,7 @@ def expand_dims(
self,
dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
axis: None | int | Sequence[int] = None,
create_index: bool = True,
**dim_kwargs: Any,
) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
Expand All @@ -4506,6 +4507,9 @@ def expand_dims(
If dim is already a scalar coordinate, it will be promoted to a 1D
coordinate consisting of a single value.

The automatic creation of indexes to back new 1D coordinate variables
is controlled by the create_index kwarg.

Parameters
----------
dim : hashable, sequence of hashable, mapping, or None
Expand All @@ -4521,6 +4525,8 @@ def expand_dims(
multiple axes are inserted. In this case, dim arguments should be
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
create_index : bool, default is True
Whether to create new PandasIndex objects for any new 1D coordinate variables.
**dim_kwargs : int or sequence or ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
Expand Down Expand Up @@ -4640,9 +4646,14 @@ def expand_dims(
# save the coordinates to the variables dict, and set the
# value within the dim dict to the length of the iterable
# for later use.
index = PandasIndex(v, k)
indexes[k] = index
variables.update(index.create_variables())

if create_index:
index = PandasIndex(v, k)
indexes[k] = index
name_and_new_1d_var = index.create_variables()
else:
name_and_new_1d_var = {k: Variable(data=v, dims=k)}
variables.update(name_and_new_1d_var)
coord_names.add(k)
dim[k] = variables[k].size
elif isinstance(v, int):
Expand Down Expand Up @@ -4678,11 +4689,23 @@ def expand_dims(
variables[k] = v.set_dims(dict(all_dims))
else:
if k not in variables:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
index, index_vars = create_default_index_implicit(v.set_dims(k))
indexes[k] = index
variables.update(index_vars)
if k in coord_names and create_index:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
index, index_vars = create_default_index_implicit(v.set_dims(k))
indexes[k] = index
variables.update(index_vars)
else:
if create_index:
warnings.warn(
f"No index created for dimension {k} because variable {k} is not a coordinate. "
f"To create an index for {k}, please first call `.set_coords('{k}')` on this object.",
UserWarning,
)

# create 1D variable without creating a new index
new_1d_var = v.set_dims(k)
variables.update({k: new_1d_var})

return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
Expand Down
46 changes: 46 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3430,6 +3430,52 @@ def test_expand_dims_kwargs_python36plus(self) -> None:
)
assert_identical(other_way_expected, other_way)

@pytest.mark.parametrize("create_index_flag", [True, False])
def test_expand_dims_create_index_data_variable(self, create_index_flag):
    """Expanding along a scalar *data* variable must never yield an index.

    When ``create_index`` is True, a ``UserWarning`` matching
    "No index created" is expected; when it is False the call is made
    without asserting on warnings. In both cases the result must carry
    no indexes.
    """
    scalar_ds = Dataset({"x": 0})

    # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959
    index_free = Dataset({"x": ("x", [0])}).drop_indexes("x").reset_coords("x")

    if not create_index_flag:
        result = scalar_ds.expand_dims("x", create_index=create_index_flag)
    else:
        with pytest.warns(UserWarning, match="No index created"):
            result = scalar_ds.expand_dims("x", create_index=create_index_flag)

    assert_identical(result, index_free, check_default_indexes=False)
    assert result.indexes == {}

def test_expand_dims_create_index_coordinate_variable(self):
    """Scalar coordinates gain an index only when ``create_index`` is True.

    The default call is compared against a normally constructed dataset;
    the ``create_index=False`` call is compared against the same data with
    its index dropped, and must expose no indexes.
    """
    scalar_coord_ds = Dataset(coords={"x": 0})

    # Default behaviour (create_index=True).
    with_index = scalar_coord_ds.expand_dims("x")
    assert_identical(with_index, Dataset({"x": ("x", [0])}))

    # Opt-out behaviour.
    without_index = scalar_coord_ds.expand_dims("x", create_index=False)

    # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959
    expected_without_index = Dataset(coords={"x": ("x", [0])}).drop_indexes("x")

    assert_identical(
        without_index, expected_without_index, check_default_indexes=False
    )
    assert without_index.indexes == {}

def test_expand_dims_create_index_from_iterable(self):
    """Expanding with an iterable of labels honours ``create_index``.

    With the default, the new 1D coordinate is compared against a normally
    constructed dataset. With ``create_index=False`` the same coordinate
    values must be present but the result must expose no indexes.
    """
    ds = Dataset(coords={"x": 0})
    expanded = ds.expand_dims(x=[0, 1])
    expected = Dataset({"x": ("x", [0, 1])})
    assert_identical(expanded, expected)

    expanded_no_index = ds.expand_dims(x=[0, 1], create_index=False)

    # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959
    expected = Dataset(coords={"x": ("x", [0, 1])}).drop_indexes("x")

    # Compare the index-free result (not the indexed ``expanded`` from above),
    # otherwise the create_index=False output is never checked for content.
    assert_identical(expanded_no_index, expected, check_default_indexes=False)
    assert expanded_no_index.indexes == {}

def test_expand_dims_non_nanosecond_conversion(self) -> None:
# Regression test for https://github.com/pydata/xarray/issues/7493#issuecomment-1953091000
with pytest.warns(UserWarning, match="non-nanosecond precision"):
Expand Down
Loading