Skip to content

Commit 4be6653

Browse files
authored
rolling_exp: keep_attrs and typing (#4592)
* rolling_exp: keep_attrs and typing * Update doc/whats-new.rst * update whats-new
1 parent 19c2626 commit 4be6653

File tree

4 files changed

+99
-5
lines changed

4 files changed

+99
-5
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ New Features
5353
By `Michal Baumgartner <https://github.com/m1so>`_.
5454
- :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` are now executing value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`).
5555
By `Julius Busecke <https://github.com/jbusecke>`_.
56+
- Added the ``keep_attrs`` keyword to ``rolling_exp.mean()``; it now keeps attributes
57+
per default. By `Mathias Hauser <https://github.com/mathause>`_ (:pull:`4592`).
5658

5759
Bug fixes
5860
~~~~~~~~~

xarray/core/rolling_exp.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1+
from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Optional, TypeVar
2+
13
import numpy as np
24

5+
from .options import _get_keep_attrs
36
from .pdcompat import count_not_none
47
from .pycompat import is_duck_dask_array
58

9+
if TYPE_CHECKING:
10+
from .dataarray import DataArray # noqa: F401
11+
from .dataset import Dataset # noqa: F401
12+
13+
T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")
14+
615

716
def _get_alpha(com=None, span=None, halflife=None, alpha=None):
817
# pandas defines in terms of com (converting to alpha in the algo)
@@ -56,7 +65,7 @@ def _get_center_of_mass(comass, span, halflife, alpha):
5665
return float(comass)
5766

5867

59-
class RollingExp:
68+
class RollingExp(Generic[T_DSorDA]):
6069
"""
6170
Exponentially-weighted moving window object.
6271
Similar to EWM in pandas
@@ -78,16 +87,28 @@ class RollingExp:
7887
RollingExp : type of input argument
7988
"""
8089

81-
def __init__(self, obj, windows, window_type="span"):
82-
self.obj = obj
90+
def __init__(
91+
self,
92+
obj: T_DSorDA,
93+
windows: Mapping[Hashable, int],
94+
window_type: str = "span",
95+
):
96+
self.obj: T_DSorDA = obj
8397
dim, window = next(iter(windows.items()))
8498
self.dim = dim
8599
self.alpha = _get_alpha(**{window_type: window})
86100

87-
def mean(self):
101+
def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA:
88102
"""
89103
Exponentially weighted moving average
90104
105+
Parameters
106+
----------
107+
keep_attrs : bool, default: None
108+
If True, the attributes (``attrs``) will be copied from the original
109+
object to the new one. If False, the new object will be returned
110+
without attributes. If None uses the global default.
111+
91112
Examples
92113
--------
93114
>>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x")
@@ -97,4 +118,9 @@ def mean(self):
97118
Dimensions without coordinates: x
98119
"""
99120

100-
return self.obj.reduce(move_exp_nanmean, dim=self.dim, alpha=self.alpha)
121+
if keep_attrs is None:
122+
keep_attrs = _get_keep_attrs(default=True)
123+
124+
return self.obj.reduce(
125+
move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs
126+
)

xarray/tests/test_dataarray.py

+29
Original file line numberDiff line numberDiff line change
@@ -6931,6 +6931,35 @@ def test_rolling_exp(da, dim, window_type, window):
69316931
assert_allclose(expected.variable, result.variable)
69326932

69336933

6934+
@requires_numbagg
6935+
def test_rolling_exp_keep_attrs(da):
6936+
6937+
attrs = {"attrs": "da"}
6938+
da.attrs = attrs
6939+
6940+
# attrs are kept per default
6941+
result = da.rolling_exp(time=10).mean()
6942+
assert result.attrs == attrs
6943+
6944+
# discard attrs
6945+
result = da.rolling_exp(time=10).mean(keep_attrs=False)
6946+
assert result.attrs == {}
6947+
6948+
# test discard attrs using global option
6949+
with set_options(keep_attrs=False):
6950+
result = da.rolling_exp(time=10).mean()
6951+
assert result.attrs == {}
6952+
6953+
# keyword takes precedence over global option
6954+
with set_options(keep_attrs=False):
6955+
result = da.rolling_exp(time=10).mean(keep_attrs=True)
6956+
assert result.attrs == attrs
6957+
6958+
with set_options(keep_attrs=True):
6959+
result = da.rolling_exp(time=10).mean(keep_attrs=False)
6960+
assert result.attrs == {}
6961+
6962+
69346963
def test_no_dict():
69356964
d = DataArray()
69366965
with pytest.raises(AttributeError):

xarray/tests/test_dataset.py

+37
Original file line numberDiff line numberDiff line change
@@ -6150,6 +6150,43 @@ def test_rolling_exp(ds):
61506150
assert isinstance(result, Dataset)
61516151

61526152

6153+
@requires_numbagg
6154+
def test_rolling_exp_keep_attrs(ds):
6155+
6156+
attrs_global = {"attrs": "global"}
6157+
attrs_z1 = {"attr": "z1"}
6158+
6159+
ds.attrs = attrs_global
6160+
ds.z1.attrs = attrs_z1
6161+
6162+
# attrs are kept per default
6163+
result = ds.rolling_exp(time=10).mean()
6164+
assert result.attrs == attrs_global
6165+
assert result.z1.attrs == attrs_z1
6166+
6167+
# discard attrs
6168+
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
6169+
assert result.attrs == {}
6170+
assert result.z1.attrs == {}
6171+
6172+
# test discard attrs using global option
6173+
with set_options(keep_attrs=False):
6174+
result = ds.rolling_exp(time=10).mean()
6175+
assert result.attrs == {}
6176+
assert result.z1.attrs == {}
6177+
6178+
# keyword takes precedence over global option
6179+
with set_options(keep_attrs=False):
6180+
result = ds.rolling_exp(time=10).mean(keep_attrs=True)
6181+
assert result.attrs == attrs_global
6182+
assert result.z1.attrs == attrs_z1
6183+
6184+
with set_options(keep_attrs=True):
6185+
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
6186+
assert result.attrs == {}
6187+
assert result.z1.attrs == {}
6188+
6189+
61536190
@pytest.mark.parametrize("center", (True, False))
61546191
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
61556192
@pytest.mark.parametrize("window", (1, 2, 3, 4))

0 commit comments

Comments
 (0)