Skip to content

Commit 24c357f

Browse files
authored
Skip mean over empty axis (#5207)
* Skip mean over empty axis Avoids changing the datatype if the data does not have the requested axis. * Improvements based on feedback * Better testing * Clarify comment * Handle other functions as well, like sum, min, max
1 parent b2351cb commit 24c357f

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

doc/whats-new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ Breaking changes
8585
as positional, all others need to be passed as keyword arguments. This is part of the
8686
refactor to support external backends (:issue:`4309`, :pull:`4989`).
8787
By `Alessandro Amici <https://github.com/alexamici>`_.
88+
- Functions that are identities for 0d data return the unchanged data
89+
if axis is empty. This ensures that Datasets where some variables do
90+
not have the averaged dimensions are not accidentally changed
91+
(:issue:`4885`, :pull:`5207`). By `David Schwörer
92+
<https://github.com/dschwoerer>`_
8893

8994
Deprecations
9095
~~~~~~~~~~~~

xarray/core/duck_array_ops.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -310,13 +310,21 @@ def _ignore_warnings_if(condition):
310310
yield
311311

312312

313-
def _create_nan_agg_method(name, dask_module=dask_array, coerce_strings=False):
313+
def _create_nan_agg_method(
314+
name, dask_module=dask_array, coerce_strings=False, invariant_0d=False
315+
):
314316
from . import nanops
315317

316318
def f(values, axis=None, skipna=None, **kwargs):
317319
if kwargs.pop("out", None) is not None:
318320
raise TypeError(f"`out` is not valid for {name}")
319321

322+
# The data is invariant in the case of 0d data, so do not
323+
# change the data (and dtype)
324+
# See https://github.com/pydata/xarray/issues/4885
325+
if invariant_0d and axis == ():
326+
return values
327+
320328
values = asarray(values)
321329

322330
if coerce_strings and values.dtype.kind in "SU":
@@ -354,28 +362,30 @@ def f(values, axis=None, skipna=None, **kwargs):
354362
# See ops.inject_reduce_methods
355363
argmax = _create_nan_agg_method("argmax", coerce_strings=True)
356364
argmin = _create_nan_agg_method("argmin", coerce_strings=True)
357-
max = _create_nan_agg_method("max", coerce_strings=True)
358-
min = _create_nan_agg_method("min", coerce_strings=True)
359-
sum = _create_nan_agg_method("sum")
365+
max = _create_nan_agg_method("max", coerce_strings=True, invariant_0d=True)
366+
min = _create_nan_agg_method("min", coerce_strings=True, invariant_0d=True)
367+
sum = _create_nan_agg_method("sum", invariant_0d=True)
360368
sum.numeric_only = True
361369
sum.available_min_count = True
362370
std = _create_nan_agg_method("std")
363371
std.numeric_only = True
364372
var = _create_nan_agg_method("var")
365373
var.numeric_only = True
366-
median = _create_nan_agg_method("median", dask_module=dask_array_compat)
374+
median = _create_nan_agg_method(
375+
"median", dask_module=dask_array_compat, invariant_0d=True
376+
)
367377
median.numeric_only = True
368-
prod = _create_nan_agg_method("prod")
378+
prod = _create_nan_agg_method("prod", invariant_0d=True)
369379
prod.numeric_only = True
370380
prod.available_min_count = True
371-
cumprod_1d = _create_nan_agg_method("cumprod")
381+
cumprod_1d = _create_nan_agg_method("cumprod", invariant_0d=True)
372382
cumprod_1d.numeric_only = True
373-
cumsum_1d = _create_nan_agg_method("cumsum")
383+
cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True)
374384
cumsum_1d.numeric_only = True
375385
unravel_index = _dask_or_eager_func("unravel_index")
376386

377387

378-
_mean = _create_nan_agg_method("mean")
388+
_mean = _create_nan_agg_method("mean", invariant_0d=True)
379389

380390

381391
def _datetime_nanmin(array):

xarray/tests/test_duck_array_ops.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
where,
2727
)
2828
from xarray.core.pycompat import dask_array_type
29-
from xarray.testing import assert_allclose, assert_equal
29+
from xarray.testing import assert_allclose, assert_equal, assert_identical
3030

3131
from . import (
3232
arm_xfail,
@@ -373,6 +373,17 @@ def test_cftime_datetime_mean_dask_error():
373373
da.mean()
374374

375375

376+
def test_empty_axis_dtype():
377+
ds = Dataset()
378+
ds["pos"] = [1, 2, 3]
379+
ds["data"] = ("pos", "time"), [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
380+
ds["var"] = "pos", [2, 3, 4]
381+
assert_identical(ds.mean(dim="time")["var"], ds["var"])
382+
assert_identical(ds.max(dim="time")["var"], ds["var"])
383+
assert_identical(ds.min(dim="time")["var"], ds["var"])
384+
assert_identical(ds.sum(dim="time")["var"], ds["var"])
385+
386+
376387
@pytest.mark.parametrize("dim_num", [1, 2])
377388
@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_])
378389
@pytest.mark.parametrize("dask", [False, True])

0 commit comments

Comments
 (0)