Skip to content

Commit 5e50c0d

Browse files
fujiisoupJoe Hamman
authored and
Joe Hamman
committed
Restored dim order in DataArray.rolling().reduce() (#1277)
Restoring dim order in rolling.reduce().
1 parent 1cafb14 commit 5e50c0d

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ Enhancements
2424

2525
Bug fixes
2626
~~~~~~~~~
27+
- ``rolling`` now keeps its original dimension order (:issue:`1125`).
28+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
29+
2730

2831
.. _whats-new.0.9.1:
2932

xarray/core/rolling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ def reduce(self, func, **kwargs):
177177
counts = concat([window.count(dim=self.dim) for _, window in self],
178178
dim=concat_dim)
179179
result = concat(windows, dim=concat_dim)
180+
# restore dim order
181+
result = result.transpose(*self.obj.dims)
180182

181183
result = result.where(counts >= self._min_periods)
182184

xarray/tests/test_dataarray.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2408,8 +2408,8 @@ def test_combine_first(self):
24082408
def da(request):
24092409
if request.param == 1:
24102410
times = pd.date_range('2000-01-01', freq='1D', periods=21)
2411-
values = np.random.random((21, 4))
2412-
da = DataArray(values, dims=('time', 'x'))
2411+
values = np.random.random((3, 21, 4))
2412+
da = DataArray(values, dims=('a', 'time', 'x'))
24132413
da['time'] = times
24142414
return da
24152415

@@ -2434,7 +2434,7 @@ def test_rolling_properties(da):
24342434

24352435
rolling_obj = da.rolling(time=4)
24362436

2437-
assert rolling_obj._axis_num == 0
2437+
assert rolling_obj._axis_num == 1
24382438

24392439
# catching invalid args
24402440
with pytest.raises(ValueError) as exception:
@@ -2464,7 +2464,8 @@ def test_rolling_wrapped_bottleneck(da, name, center, min_periods):
24642464

24652465
func_name = 'move_{0}'.format(name)
24662466
actual = getattr(rolling_obj, name)()
2467-
expected = getattr(bn, func_name)(da.values, window=7, axis=0, min_count=min_periods)
2467+
expected = getattr(bn, func_name)(da.values, window=7, axis=1,
2468+
min_count=min_periods)
24682469
assert_array_equal(actual.values, expected)
24692470

24702471
# Test center
@@ -2517,3 +2518,4 @@ def test_rolling_reduce(da, center, min_periods, window, name):
25172518
actual = rolling_obj.reduce(getattr(np, 'nan%s' % name))
25182519
expected = getattr(rolling_obj, name)()
25192520
assert_allclose(actual, expected)
2521+
assert da.dims == expected.dims

0 commit comments

Comments
 (0)