@@ -88,8 +88,8 @@ def get_coord_variables(arg):
88
88
return output
89
89
90
90
91
- def apply_dataarray (args , func , signature = None , join = 'inner' ,
92
- kwargs = None , new_coords = None , combine_names = None ):
91
+ def apply_dataarray_ufunc (args , func , signature = None , join = 'inner' ,
92
+ kwargs = None , new_coords = None , combine_names = None ):
93
93
if signature is None :
94
94
signature = _default_signature (len (args ))
95
95
@@ -137,8 +137,9 @@ def collect_dict_values(objects, keys, fill_value=None)
137
137
return result_values
138
138
139
139
140
- def apply_dataset (args , func , signature = None , join = 'inner' , fill_value = None ,
141
- kwargs = None , new_coords = None , result_attrs = None ):
140
+ def apply_dataset_ufunc (args , func , signature = None , join = 'inner' ,
141
+ fill_value = None , kwargs = None , new_coords = None ,
142
+ result_attrs = None ):
142
143
if kwargs is None :
143
144
kwargs = {}
144
145
@@ -194,6 +195,65 @@ def make_dataset(data_vars, coord_vars, attrs):
194
195
return make_dataset (data_vars , coord_vars , attrs )
195
196
196
197
198
+
199
+
200
+ def _iter_over_selections (obj , dim , values ):
201
+ """Iterate over selections of an xarray object in the provided order.
202
+ """
203
+ dummy = None
204
+ for value in values :
205
+ try :
206
+ obj_sel = obj .sel (** {dim : values })
207
+ except KeyError :
208
+ if dim not in obj .dims :
209
+ raise ValueError ('incompatible dimensions for a grouped '
210
+ 'binary operation: the group variable %r '
211
+ 'is not a dimension on the other argument'
212
+ % dim )
213
+ if dummy is None :
214
+ dummy = _dummy_copy (obj )
215
+ obj_sel = dummy
216
+ yield obj_sel
217
+
218
+
219
+ def apply_groupby_ufunc (args , func ):
220
+ groupbys = [arg for arg in args if isinstance (GroupBy )]
221
+ if not groupbys :
222
+ raise ValueError ('must have at least one groupby to iterate over' )
223
+ first_groupby = groups [0 ]
224
+ if any (not first_groupby .unique_coord .equals (gb .unique_coord )
225
+ for gb in groupbys [1 :]):
226
+ raise ValueError ('can only perform operations over multiple groupbys '
227
+ 'at once if they have all the same unique coordinate' )
228
+
229
+ grouped_dim = first_groupby .group .name
230
+ unique_values = first_groupby .unique_coord .values
231
+
232
+ iterators = []
233
+ for arg in args :
234
+ if isinstance (arg , GroupBy ):
235
+ iterator = (value for _ , value in arg )
236
+ elif hasattr (arg , 'dims' ) and group_name in arg .dims :
237
+ if isinstance (arg , Variable ):
238
+ raise ValueError (
239
+ 'groupby operations cannot be performed with '
240
+ 'xarray.Variable objects that share a dimension with '
241
+ 'the grouped dimension' )
242
+ iterator = _iter_over_selections (arg , grouped_dim , unique_vlaues )
243
+ else :
244
+ iterator = itertools .repeat (arg )
245
+ iterators .append (iterator )
246
+
247
+ applied = (func (* zipped_args ) for zipped_args in zip (iterators ))
248
+ applied_example , applied = peek_at (applied )
249
+ combine = first_groupby ._combined
250
+ if isinstance (applied_example , tuple ):
251
+ combined = tuple (combine (output ) for output in zip (* applied ))
252
+ else :
253
+ combined = combine (applied )
254
+ return combined
255
+
256
+
197
257
def _calculate_unified_dim_sizes (variables ):
198
258
dim_sizes = OrderedDict ()
199
259
@@ -340,11 +400,19 @@ def apply_ufunc(args, func=None, signature=None, join='inner',
340
400
apply_variable_ufunc , func = func , dask_array = dask_array ,
341
401
combine_attrs = combine_variable_attrs , kwargs = kwargs )
342
402
343
- if any (is_dict_like (a ) for a in args ):
344
- return apply_dataset (args , variables_ufunc , join = join ,
345
- combine_attrs = combine_dataset_attrs )
403
+ if any (isinstance (a , GroupBy ) for a in args ):
404
+ partial_apply_ufunc = functools .partial (
405
+ apply_ufunc , func = func , signature = signature , join = join ,
406
+ dask_array = dask_array , kwargs = kwargs ,
407
+ combine_dataset_attrs = combine_dataset_attrs ,
408
+ combine_variable_attrs = combine_variable_attrs ,
409
+ dtype = None )
410
+ return apply_groupby_ufunc (args , partial_apply_ufunc )
411
+ elif any (is_dict_like (a ) for a in args ):
412
+ return apply_dataset_ufunc (args , variables_ufunc , join = join ,
413
+ combine_attrs = combine_dataset_attrs )
346
414
elif any (isinstance (a , DataArray ) for a in args ):
347
- return apply_dataarray (args , variables_ufunc , join = join )
415
+ return apply_dataarray_ufunc (args , variables_ufunc , join = join )
348
416
elif any (isinstance (a , Variable ) for a in args ):
349
417
return variables_ufunc (args )
350
418
elif dask_array == 'auto' and any (
0 commit comments