Skip to content

Commit 6cc1144

Browse files
committed
add groupby_bins method
1 parent 488c72c commit 6cc1144

File tree

3 files changed

+59
-15
lines changed

3 files changed

+59
-15
lines changed

xarray/core/common.py

+52-11
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def pipe(self, func, *args, **kwargs):
320320
else:
321321
return func(self, *args, **kwargs)
322322

323-
def groupby(self, group, squeeze=True, bins=None):
323+
def groupby(self, group, squeeze=True):
324324
"""Returns a GroupBy object for performing grouped operations.
325325
326326
Parameters
@@ -332,26 +332,67 @@ def groupby(self, group, squeeze=True, bins=None):
332332
If "group" is a dimension of any arrays in this dataset, `squeeze`
333333
controls whether the subarrays have a dimension of length 1 along
334334
that dimension or if the dimension is squeezed out.
335-
bins : array-like, optional
336-
If `bins` is specified, the groups will be discretized into the
337-
specified bins determined by `pandas.cut` applied to the index of
338-
`group`.
339335
340336
Returns
341337
-------
342338
grouped : GroupBy
343339
A `GroupBy` object patterned after `pandas.GroupBy` that can be
344340
iterated over in the form of `(unique_value, grouped_array)` pairs.
345-
346-
See Also
347-
--------
348-
pandas.cut
349341
"""
350-
from .dataarray import DataArray
342+
if isinstance(group, basestring):
343+
group = self[group]
344+
return self.groupby_cls(self, group, squeeze=squeeze)
345+
346+
def groupby_bins(self, group, bins, right=True, labels=None, precision=3,
347+
include_lowest=False, squeeze=True):
348+
"""Returns a GroupBy object for performing grouped operations. Rather
349+
than using all unique values of `group`, the values are discretized
350+
first by applying `pandas.cut` [1]_ to `group`.
351351
352+
Parameters
353+
----------
354+
group : str, DataArray or Coordinate
355+
Array whose binned values should be used to group this array. If a
356+
string, must be the name of a variable contained in this dataset.
357+
bins : int or array of scalars
358+
If bins is an int, it defines the number of equal-width bins in the
359+
range of x. However, in this case, the range of x is extended by .1%
360+
on each side to include the min or max values of x. If bins is a
361+
sequence it defines the bin edges allowing for non-uniform bin
362+
width. No extension of the range of x is done in this case.
363+
right : boolean, optional
364+
I ndicates whether the bins include the rightmost edge or not. If
365+
right == True (the default), then the bins [1,2,3,4] indicate
366+
(1,2], (2,3], (3,4].
367+
labels : array or boolean, default None
368+
Used as labels for the resulting bins. Must be of the same length as
369+
the resulting bins. If False, string bin labels are assigned by
370+
`pandas.cut`.
371+
precision : int
372+
The precision at which to store and display the bins labels.
373+
include_lowest : bool
374+
Whether the first interval should be left-inclusive or not.
375+
squeeze : boolean, optional
376+
If "group" is a dimension of any arrays in this dataset, `squeeze`
377+
controls whether the subarrays have a dimension of length 1 along
378+
that dimension or if the dimension is squeezed out.
379+
380+
Returns
381+
-------
382+
grouped : GroupBy
383+
A `GroupBy` object patterned after `pandas.GroupBy` that can be
384+
iterated over in the form of `(unique_value, grouped_array)` pairs.
385+
386+
References
387+
----------
388+
.. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
389+
"""
352390
if isinstance(group, basestring):
353391
group = self[group]
354-
return self.groupby_cls(self, group, squeeze=squeeze, bins=bins)
392+
return self.groupby_cls(self, group, squeeze=squeeze, bins=bins,
393+
cut_kwargs={'right': right, 'labels': labels,
394+
'precision': precision,
395+
'include_lowest': include_lowest})
355396

356397
def rolling(self, min_periods=None, center=False, **windows):
357398
"""

xarray/core/groupby.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ class GroupBy(object):
8383
Dataset.groupby
8484
DataArray.groupby
8585
"""
86-
def __init__(self, obj, group, squeeze=False, grouper=None, bins=None):
86+
def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
87+
cut_kwargs={}):
8788
"""Create a GroupBy object
8889
8990
Parameters
@@ -101,6 +102,8 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None):
101102
bins : array-like, optional
102103
If `bins` is specified, the groups will be discretized into the
103104
specified bins by `pandas.cut`.
105+
cut_kwargs : dict, optional
106+
Extra keyword arguments to pass to `pandas.cut`
104107
"""
105108
from .dataset import as_dataset
106109
from .dataarray import DataArray
@@ -138,7 +141,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None):
138141
if grouper is not None and bins is not None:
139142
raise TypeError("Can't specify both `grouper` and `bins`.")
140143
if bins is not None:
141-
group = DataArray(pd.cut(group.values, bins),
144+
group = DataArray(pd.cut(group.values, bins, **cut_kwargs),
142145
group.coords, name=group.name)
143146
if grouper is not None:
144147
index = safe_cast_to_index(group)

xarray/test/test_dataarray.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ def test_groupby_bins(self):
12781278
expected = DataArray([1,5], dims='dim_0', coords={'dim_0': bin_coords})
12791279
# the problem with this is that it overwrites the dimensions of array!
12801280
#actual = array.groupby('dim_0', bins=bins).sum()
1281-
actual = array.groupby('dim_0', bins=bins).apply(
1281+
actual = array.groupby_bins('dim_0', bins).apply(
12821282
lambda x : x.sum(), shortcut=False)
12831283
self.assertDataArrayIdentical(expected, actual)
12841284
# make sure original array dims are unchanged
@@ -1290,7 +1290,7 @@ def test_groupby_bins_multidim(self):
12901290
bins = [0,15,20]
12911291
bin_coords = ['(0, 15]', '(15, 20]']
12921292
expected = DataArray([16, 40], dims='lat', coords={'lat': bin_coords})
1293-
actual = array.groupby('lat', bins=bins).apply(
1293+
actual = array.groupby_bins('lat', bins).apply(
12941294
lambda x : x.sum(), shortcut=False)
12951295
self.assertDataArrayIdentical(expected, actual)
12961296

0 commit comments

Comments
 (0)