Skip to content

Commit 807001e

Browse files
committed
multidimensional groupby
1 parent 4fdf6d4 commit 807001e

File tree

3 files changed

+40
-11
lines changed

3 files changed

+40
-11
lines changed

xarray/core/groupby.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,28 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
102102
from .dataset import as_dataset
103103

104104
if group.ndim != 1:
105-
# TODO: remove this limitation?
106-
raise ValueError('`group` must be 1 dimensional')
105+
# try to stack the dims of the group into a single dim
106+
# TODO: figure out how to exclude dimensions from the stacking
107+
# (e.g. group over space dims but leave time dim intact)
108+
orig_dims = group.dims
109+
stacked_dim_name = 'stacked_' + '_'.join(orig_dims)
110+
# the copy is necessary here
111+
group = group.stack(**{stacked_dim_name: orig_dims}).copy()
112+
# without it, an error is raised deep in pandas
113+
########################
114+
# xarray/core/groupby.py
115+
# ---> 31 inverse, values = pd.factorize(ar, sort=True)
116+
# pandas/core/algorithms.pyc in factorize(values, sort, order, na_sentinel, size_hint)
117+
# --> 196 labels = table.get_labels(vals, uniques, 0, na_sentinel, True)
118+
# pandas/hashtable.pyx in pandas.hashtable.Float64HashTable.get_labels (pandas/hashtable.c:10302)()
119+
# pandas/hashtable.so in View.MemoryView.memoryview_cwrapper (pandas/hashtable.c:29882)()
120+
# pandas/hashtable.so in View.MemoryView.memoryview.__cinit__ (pandas/hashtable.c:26251)()
121+
# ValueError: buffer source array is read-only
122+
#######################
123+
# seems related to
124+
# https://github.com/pydata/pandas/issues/10043
125+
# https://github.com/pydata/pandas/pull/10070
126+
obj = obj.stack(**{stacked_dim_name: orig_dims})
107127
if getattr(group, 'name', None) is None:
108128
raise ValueError('`group` must have a name')
109129
if not hasattr(group, 'dims'):

xarray/test/test_dataarray.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,17 @@ def test_groupby_first_and_last(self):
12441244
expected = array # should be a no-op
12451245
self.assertDataArrayIdentical(expected, actual)
12461246

1247+
def test_groupby_multidim(self):
1248+
array = DataArray([[0,1],[2,3]],
1249+
coords={'lon': (['ny','nx'], [[30,40],[40,50]] ),
1250+
'lat': (['ny','nx'], [[10,10],[20,20]] ),},
1251+
dims=['ny','nx'])
1252+
for dim, expected_sum in [
1253+
('lon', DataArray([0, 3, 3], coords={'lon': [30,40,50]})),
1254+
('lat', DataArray([1,5], coords={'lat': [10,20]}))]:
1255+
actual_sum = array.groupby(dim).sum()
1256+
self.assertDataArrayIdentical(expected_sum, actual_sum)
1257+
12471258
def make_rolling_example_array(self):
12481259
times = pd.date_range('2000-01-01', freq='1D', periods=21)
12491260
values = np.random.random((21, 4))
@@ -1792,29 +1803,29 @@ def test_full_like(self):
17921803
actual = _full_like(DataArray([1, 2, 3]), fill_value=np.nan)
17931804
self.assertEqual(actual.dtype, np.float)
17941805
np.testing.assert_equal(actual.values, np.nan)
1795-
1806+
17961807
def test_dot(self):
17971808
x = np.linspace(-3, 3, 6)
17981809
y = np.linspace(-3, 3, 5)
1799-
z = range(4)
1810+
z = range(4)
18001811
da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
18011812
da = DataArray(da_vals, coords=[x, y, z], dims=['x', 'y', 'z'])
1802-
1813+
18031814
dm_vals = range(4)
18041815
dm = DataArray(dm_vals, coords=[z], dims=['z'])
1805-
1816+
18061817
# nd dot 1d
18071818
actual = da.dot(dm)
18081819
expected_vals = np.tensordot(da_vals, dm_vals, [2, 0])
18091820
expected = DataArray(expected_vals, coords=[x, y], dims=['x', 'y'])
18101821
self.assertDataArrayEqual(expected, actual)
1811-
1822+
18121823
# all shared dims
18131824
actual = da.dot(da)
18141825
expected_vals = np.tensordot(da_vals, da_vals, axes=([0, 1, 2], [0, 1, 2]))
18151826
expected = DataArray(expected_vals)
18161827
self.assertDataArrayEqual(expected, actual)
1817-
1828+
18181829
# multiple shared dims
18191830
dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4))
18201831
j = np.linspace(-3, 3, 20)
@@ -1823,7 +1834,7 @@ def test_dot(self):
18231834
expected_vals = np.tensordot(da_vals, dm_vals, axes=([1, 2], [1, 2]))
18241835
expected = DataArray(expected_vals, coords=[x, j], dims=['x', 'j'])
18251836
self.assertDataArrayEqual(expected, actual)
1826-
1837+
18271838
with self.assertRaises(NotImplementedError):
18281839
da.dot(dm.to_dataset(name='dm'))
18291840
with self.assertRaises(TypeError):

xarray/test/test_dataset.py

-2
Original file line numberDiff line numberDiff line change
@@ -1545,8 +1545,6 @@ def test_groupby_iter(self):
15451545

15461546
def test_groupby_errors(self):
15471547
data = create_test_data()
1548-
with self.assertRaisesRegexp(ValueError, 'must be 1 dimensional'):
1549-
data.groupby('var1')
15501548
with self.assertRaisesRegexp(ValueError, 'must have a name'):
15511549
data.groupby(np.arange(10))
15521550
with self.assertRaisesRegexp(ValueError, 'length does not match'):

0 commit comments

Comments
 (0)