1
- from collections import NamedTuple
1
+ import itertools
2
+ from collections import namedtuple
2
3
3
4
4
5
def result_name (objects ):
@@ -13,19 +14,25 @@ def result_name(objects):
13
14
return name
14
15
15
16
16
- def apply_dataarray (func , args , join = 'inner' , kwargs = None ):
17
+ def apply_dataarray (func , args , join = 'inner' , gufunc_signature = None ,
18
+ kwargs = None , combine_names = None ):
17
19
if kwargs is None :
18
20
kwargs = {}
19
21
22
+ if combine_names is None :
23
+ combine_names = result_name
24
+
20
25
args = deep_align (* args , join = join , copy = False , raise_on_invalid = False )
21
26
22
27
coord_variables = [getattr (getattr (a , 'coords' , {}), 'variables' )
23
28
for a in args ]
24
29
coords = merge_coords_without_align (coord_variables )
25
- name = result_name (args )
30
+ name = combine_names (args )
26
31
27
32
data_vars = [getattr (a , 'variable' ) for a in args ]
28
- variable = func (* data_vars , ** kwargs )
33
+ variables = func (* data_vars , ** kwargs )
34
+
35
+ # TODO handle gufunc_signature
29
36
30
37
return DataArray (variable , coords , name = name , fastpath = True )
31
38
@@ -111,76 +118,35 @@ def _as_sequence(arg, cls):
111
118
return cls (arg )
112
119
113
120
114
- _ElemwiseSignature = NamedTuple (
115
- '_ElemwiseSignature' , 'broadcast_dims, core_dims, output_dims, axis ' )
121
+ _ElemwiseSignature = namedtuple (
122
+ '_ElemwiseSignature' , 'broadcast_dims, output_dims' )
116
123
124
+ class GUFuncSignature (object ):
125
+ def __init__ (self , inputs , outputs ):
126
+ self .inputs = inputs
127
+ self .outputs = outputs
117
128
118
- def _build_and_check_signature (variables , core_dims = None , axis_dims = None ,
119
- drop_dims = None , new_dims = None ):
120
- # All input dimension arguments are checked to appear on at least one input:
121
- # - core_dims are not broadcast over, and moved to the right with order
122
- # preserved.
123
- # - axis_dims is used to generate an integer or tuples of integers `axis`
124
- # keyword argument, which corresponds to the position of the given
125
- # dimension on the inputs. If `axis_dims` have overlap with `core_dims`,
126
- # no non-axis dimensions may appear in `core_dims` before an axis
127
- # dimension.
128
- # - drop_dims are input dimensions that should be dropped from the output.
129
- #
130
- # All output dimensions arguments are checked not to appear on any inputs:
131
- # - new_dims are new dimensions that should be added to the output array, in
132
- # order to the right of dimensions that are not dropped.
129
+ @classmethod
130
+ def from_string (cls , string ):
131
+ raise NotImplementedError
133
132
134
- if core_dims is None and drop_dims is None and axis_dims is None :
135
- # broadcast everything
136
- dims = tuple (_calculate_unified_dim_sizes (variables ))
137
- return _ElemwiseSignature (dims , (), dims , None )
138
133
139
- core_dims = () if core_dims is None else _as_sequence (core_dims , tuple )
140
- drop_dims = set () if drop_dims is None else _as_sequence (drop_dims , set )
141
- new_dims = () if new_dims is None else _as_sequence (new_dims , tuple )
142
-
143
- axis_is_scalar = axis_dims is not None and is_scalar (axis_dims )
144
- axis_dims = set () if axis_dims is None else _as_sequence (axis_dims , set )
134
+ def _build_and_check_signature (variables , gufunc_signature ):
135
+ # core_dims are not broadcast over, and moved to the right with order
136
+ # preserved.
145
137
146
138
dim_sizes = _calculate_unified_dim_sizes (variables )
139
+
140
+ if gufunc_signature is None :
141
+ # broadcast everything, one output
142
+ dims = tuple (size_dims )
143
+ return _ElemwiseSignature (dims , [dims ])
144
+
145
+ core_dims = set (itertools .chain .from_iterable (
146
+ itertools .chain (gufunc_signature .inputs , gufunc_signature .outputs )))
147
147
broadcast_dims = tuple (d for d in dim_sizes if d not in core_dims )
148
- all_input_dims = set (dim_sizes )
149
-
150
- invalid = set (core_dims ) - all_input_dims
151
- if invalid :
152
- raise ValueError ('some `core_dims` not found on any input variables: '
153
- '%r' % list (invalid ))
154
-
155
- invalid = drop_dims - all_input_dims
156
- if invalid :
157
- raise ValueError ('some `drop_dims` not found on any input variables: '
158
- '%r' % list (invalid ))
159
-
160
- invalid = axis_dims - all_input_dims
161
- if invalid :
162
- raise ValueError ('some `axis_dims` not found on any input variables: '
163
- '%r' % list (invalid ))
164
- axis = tuple (broadcast_dims .index (d ) for d in axis_dims )
165
- n_remaining_axes = len (axis_dims ) - len (axis )
166
- if n_remaining_axes > 0 :
167
- valid_core_dims_for_axis = core_dims [:remaining_axes ]
168
- if not set (valid_core_dims_for_axis ) <= axis_dims :
169
- raise ValueError ('axis dimensions %r have overlap with core '
170
- 'dimensions %r, but do not appear at the start'
171
- % (axis_dims , core_dims ))
172
- axis += tuple (range (len (axis ), n_remaining_axes + len (axis )))
173
- if axis_is_scalar :
174
- axis , = axis
175
-
176
- invalid = set (new_dims ) ^ all_input_dims
177
- if invalid :
178
- raise ValueError ('some `new_dims` are found on input variables: '
179
- '%r' % list (invalid ))
180
-
181
- output_dims = tuple (d for d in dim_sizes if d not in drop_dims ) + new_dims
182
-
183
- return _ElemwiseSignature (broadcast_dims , core_dims , output_dims , axis )
148
+ output_dims = [broadcast_dims + out for out in gufunc_signature .outputs ]
149
+ return _ElemwiseSignature (broadcast_dims , output_dims )
184
150
185
151
186
152
def _broadcast_variable_data_to (variable , broadcast_dims , allow_dask = True ):
@@ -208,8 +174,7 @@ def _broadcast_variable_data_to(variable, broadcast_dims, allow_dask=True):
208
174
return data
209
175
210
176
211
- def apply_variable_ufunc (func , args , allow_dask = True , core_dims = None ,
212
- axis_dims = None , drop_dims = None , new_dims = None ,
177
+ def apply_variable_ufunc (func , args , allow_dask = True , gufunc_signature = None ,
213
178
combine_attrs = None , kwargs = None ):
214
179
215
180
if kwargs is None :
@@ -218,29 +183,34 @@ def apply_variable_ufunc(func, args, allow_dask=True, core_dims=None,
218
183
if combine_attrs is None :
219
184
combine_attrs = lambda func , attrs : None
220
185
221
- result_attrs = combine_attrs (func , [getattr (a , 'attrs' , {}) for a in args ])
222
-
223
- sig = _build_and_check_dims_signature (
224
- variables , core_dims , axis_dims , drop_dims , new_dims )
186
+ sig = _build_and_check_signature (variables , gufunc_signature )
225
187
226
- if sig .axis :
227
- if 'axis' in kwargs :
228
- raise ValueError ('axis is already set in kwargs' )
229
- kwargs = dict (kwargs )
230
- kwargs ['axis' ] = sig .axis
188
+ n_out = len (sig .output_dims )
189
+ input_attrs = [getattr (a , 'attrs' , {}) for a in args ]
190
+ result_attrs = [combine_attrs (input_attrs , func , n ) for n in range (n_out )]
231
191
232
192
list_of_data = []
233
193
for arg in args :
234
194
if isinstance (arg , Variable ):
235
195
data = _broadcast_variable_data_to (arg , sig .broadcast_dims ,
236
196
allow_dask = allow_dask )
237
- list_of_data .append (data )
238
197
else :
239
- list_of_data .append (arg )
198
+ data = arg
199
+ list_of_data .append (data )
240
200
241
201
result_data = func (* list_of_data , ** kwargs )
242
202
243
- return Variable (sig .output_dims , result_data , result_attrs )
203
+ if n_out > 1 :
204
+ output = []
205
+ for dims , data , attrs in zip (
206
+ sig .output_dims , result_data , result_attrs ):
207
+ output .append (Variable (dims , data , attrs ))
208
+ return tuple (output )
209
+ else :
210
+ dims , = sig .output_dims
211
+ data , = result_data
212
+ attrs = result_attrs
213
+ return Variable (dims , data , attrs )
244
214
245
215
246
216
def apply_ufunc (func , args , join = 'inner' , allow_dask = True , kwargs = None ,
0 commit comments