From 7fc7c7beb85f21482035895d9aa33938d36bd9a6 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 18 Nov 2020 22:24:24 +0100 Subject: [PATCH 1/3] rolling_exp: keep_attrs and typing --- doc/whats-new.rst | 4 ++++ xarray/core/rolling_exp.py | 36 ++++++++++++++++++++++++++++----- xarray/tests/test_dataarray.py | 29 ++++++++++++++++++++++++++ xarray/tests/test_dataset.py | 37 ++++++++++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9460fc08478..268755da1d5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -45,6 +45,10 @@ New Features By `Michal Baumgartner `_. - :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` are now executing value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`). By `Julius Busecke `_. +- Added the ``keep_attrs`` keyword to :py:meth:`~xarray.DataArray.rolling_exp.mean` and :py:meth:`~xarray.Dataset.rolling_exp.mean`. + The attributes are now kept per default. + By `Mathias Hauser `_. + Bug fixes ~~~~~~~~~ diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index b80a4d313d9..0ae85a870e8 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,8 +1,17 @@ +from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Optional, TypeVar + import numpy as np +from .options import _get_keep_attrs from .pdcompat import count_not_none from .pycompat import is_duck_dask_array +if TYPE_CHECKING: + from .dataarray import DataArray # noqa: F401 + from .dataset import Dataset # noqa: F401 + +T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") + def _get_alpha(com=None, span=None, halflife=None, alpha=None): # pandas defines in terms of com (converting to alpha in the algo) @@ -56,7 +65,7 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -class RollingExp: +class RollingExp(Generic[T_DSorDA]): """ Exponentially-weighted moving window object. Similar to EWM in pandas @@ -78,16 +87,28 @@ class RollingExp: RollingExp : type of input argument """ - def __init__(self, obj, windows, window_type="span"): - self.obj = obj + def __init__( + self, + obj: T_DSorDA, + windows: Mapping[Hashable, int], + window_type: str = "span", + ): + self.obj: T_DSorDA = obj dim, window = next(iter(windows.items())) self.dim = dim self.alpha = _get_alpha(**{window_type: window}) - def mean(self): + def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA: """ Exponentially weighted moving average + Parameters + ---------- + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + Examples -------- >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") @@ -97,4 +118,9 @@ def mean(self): Dimensions without coordinates: x """ - return self.obj.reduce(move_exp_nanmean, dim=self.dim, alpha=self.alpha) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + return self.obj.reduce( + move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs + ) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e944c020503..b085c3bd354 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6927,6 +6927,35 @@ def test_rolling_exp(da, dim, window_type, window): assert_allclose(expected.variable, result.variable) +@requires_numbagg +def test_rolling_exp_keep_attrs(da): + + attrs = {"attrs": "da"} + da.attrs = attrs + + # attrs are kept per default + result = da.rolling_exp(time=10).mean() + assert result.attrs == attrs + + # discard attrs + result = da.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + + # test discard attrs using global option + with set_options(keep_attrs=False): + result = da.rolling_exp(time=10).mean() + assert result.attrs == {} + + # keyword takes precedence over global option + with set_options(keep_attrs=False): + result = da.rolling_exp(time=10).mean(keep_attrs=True) + assert result.attrs == attrs + + with set_options(keep_attrs=True): + result = da.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + + def test_no_dict(): d = DataArray() with pytest.raises(AttributeError): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 61e80557142..b248cb6dd0f 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6145,6 +6145,43 @@ def test_rolling_exp(ds): assert isinstance(result, Dataset) +@requires_numbagg +def test_rolling_exp_keep_attrs(ds): + + attrs_global = {"attrs": "global"} + attrs_z1 = {"attr": "z1"} + + ds.attrs = attrs_global + ds.z1.attrs = attrs_z1 + + # attrs are kept per default + result = ds.rolling_exp(time=10).mean() + assert result.attrs == attrs_global + assert result.z1.attrs == attrs_z1 + + # discard attrs + result = ds.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + assert result.z1.attrs == {} + + # test discard attrs using global option + with set_options(keep_attrs=False): + result = ds.rolling_exp(time=10).mean() + assert result.attrs == {} + assert result.z1.attrs == {} + + # keyword takes precedence over global option + with set_options(keep_attrs=False): + result = ds.rolling_exp(time=10).mean(keep_attrs=True) + assert result.attrs == attrs_global + assert result.z1.attrs == attrs_z1 + + with set_options(keep_attrs=True): + result = ds.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + assert result.z1.attrs == {} + + @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) From a7b2389dafedc9d477b40da94710879d7b7b642c Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 18 Nov 2020 22:30:53 +0100 Subject: [PATCH 2/3] Update doc/whats-new.rst --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 268755da1d5..f9f7f0193e8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -46,7 +46,7 @@ New Features - :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` are now executing value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`). By `Julius Busecke `_. - Added the ``keep_attrs`` keyword to :py:meth:`~xarray.DataArray.rolling_exp.mean` and :py:meth:`~xarray.Dataset.rolling_exp.mean`. - The attributes are now kept per default. + The attributes are now kept per default (:pull:`4592`). By `Mathias Hauser `_. From 2276992026db7686ae181c106227d3e0c613c6c9 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 18 Nov 2020 22:56:30 +0100 Subject: [PATCH 3/3] update whats-new --- doc/whats-new.rst | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f9f7f0193e8..8e9b40b2182 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -45,10 +45,8 @@ New Features By `Michal Baumgartner `_. - :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` are now executing value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`). By `Julius Busecke `_. -- Added the ``keep_attrs`` keyword to :py:meth:`~xarray.DataArray.rolling_exp.mean` and :py:meth:`~xarray.Dataset.rolling_exp.mean`. - The attributes are now kept per default (:pull:`4592`). - By `Mathias Hauser `_. - +- Added the ``keep_attrs`` keyword to ``rolling_exp.mean()``; it now keeps attributes + per default. By `Mathias Hauser `_ (:pull:`4592`). Bug fixes ~~~~~~~~~