Skip to content

Commit d8786da

Browse files
committed
Rewrite _build_and_check_signature
1 parent c27e25e commit d8786da

File tree

1 file changed

+51
-81
lines changed

1 file changed

+51
-81
lines changed

xarray/core/computation.py

+51-81
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from collections import NamedTuple
1+
import itertools
2+
from collections import namedtuple
23

34

45
def result_name(objects):
@@ -13,19 +14,25 @@ def result_name(objects):
1314
return name
1415

1516

16-
def apply_dataarray(func, args, join='inner', kwargs=None):
17+
def apply_dataarray(func, args, join='inner', gufunc_signature=None,
18+
kwargs=None, combine_names=None):
1719
if kwargs is None:
1820
kwargs = {}
1921

22+
if combine_names is None:
23+
combine_names = result_name
24+
2025
args = deep_align(*args, join=join, copy=False, raise_on_invalid=False)
2126

2227
coord_variables = [getattr(getattr(a, 'coords', {}), 'variables')
2328
for a in args]
2429
coords = merge_coords_without_align(coord_variables)
25-
name = result_name(args)
30+
name = combine_names(args)
2631

2732
data_vars = [getattr(a, 'variable') for a in args]
28-
variable = func(*data_vars, **kwargs)
33+
variables = func(*data_vars, **kwargs)
34+
35+
# TODO handle gufunc_signature
2936

3037
return DataArray(variable, coords, name=name, fastpath=True)
3138

@@ -111,76 +118,35 @@ def _as_sequence(arg, cls):
111118
return cls(arg)
112119

113120

114-
_ElemwiseSignature = NamedTuple(
115-
'_ElemwiseSignature', 'broadcast_dims, core_dims, output_dims, axis')
121+
_ElemwiseSignature = namedtuple(
122+
'_ElemwiseSignature', 'broadcast_dims, output_dims')
116123

124+
class GUFuncSignature(object):
125+
def __init__(self, inputs, outputs):
126+
self.inputs = inputs
127+
self.outputs = outputs
117128

118-
def _build_and_check_signature(variables, core_dims=None, axis_dims=None,
119-
drop_dims=None, new_dims=None):
120-
# All input dimension arguments are checked to appear on at least one input:
121-
# - core_dims are not broadcast over, and moved to the right with order
122-
# preserved.
123-
# - axis_dims is used to generate an integer or tuples of integers `axis`
124-
# keyword argument, which corresponds to the position of the given
125-
# dimension on the inputs. If `axis_dims` have overlap with `core_dims`,
126-
# no non-axis dimensions may appear in `core_dims` before an axis
127-
# dimension.
128-
# - drop_dims are input dimensions that should be dropped from the output.
129-
#
130-
# All output dimensions arguments are checked not to appear on any inputs:
131-
# - new_dims are new dimensions that should be added to the output array, in
132-
# order to the right of dimensions that are not dropped.
129+
@classmethod
130+
def from_string(cls, string):
131+
raise NotImplementedError
133132

134-
if core_dims is None and drop_dims is None and axis_dims is None:
135-
# broadcast everything
136-
dims = tuple(_calculate_unified_dim_sizes(variables))
137-
return _ElemwiseSignature(dims, (), dims, None)
138133

139-
core_dims = () if core_dims is None else _as_sequence(core_dims, tuple)
140-
drop_dims = set() if drop_dims is None else _as_sequence(drop_dims, set)
141-
new_dims = () if new_dims is None else _as_sequence(new_dims, tuple)
142-
143-
axis_is_scalar = axis_dims is not None and is_scalar(axis_dims)
144-
axis_dims = set() if axis_dims is None else _as_sequence(axis_dims, set)
134+
def _build_and_check_signature(variables, gufunc_signature):
135+
# core_dims are not broadcast over, and moved to the right with order
136+
# preserved.
145137

146138
dim_sizes = _calculate_unified_dim_sizes(variables)
139+
140+
if gufunc_signature is None:
141+
# broadcast everything, one output
142+
dims = tuple(size_dims)
143+
return _ElemwiseSignature(dims, [dims])
144+
145+
core_dims = set(itertools.chain.from_iterable(
146+
itertools.chain(gufunc_signature.inputs, gufunc_signature.outputs)))
147147
broadcast_dims = tuple(d for d in dim_sizes if d not in core_dims)
148-
all_input_dims = set(dim_sizes)
149-
150-
invalid = set(core_dims) - all_input_dims
151-
if invalid:
152-
raise ValueError('some `core_dims` not found on any input variables: '
153-
'%r' % list(invalid))
154-
155-
invalid = drop_dims - all_input_dims
156-
if invalid:
157-
raise ValueError('some `drop_dims` not found on any input variables: '
158-
'%r' % list(invalid))
159-
160-
invalid = axis_dims - all_input_dims
161-
if invalid:
162-
raise ValueError('some `axis_dims` not found on any input variables: '
163-
'%r' % list(invalid))
164-
axis = tuple(broadcast_dims.index(d) for d in axis_dims)
165-
n_remaining_axes = len(axis_dims) - len(axis)
166-
if n_remaining_axes > 0:
167-
valid_core_dims_for_axis = core_dims[:remaining_axes]
168-
if not set(valid_core_dims_for_axis) <= axis_dims:
169-
raise ValueError('axis dimensions %r have overlap with core '
170-
'dimensions %r, but do not appear at the start'
171-
% (axis_dims, core_dims))
172-
axis += tuple(range(len(axis), n_remaining_axes + len(axis)))
173-
if axis_is_scalar:
174-
axis, = axis
175-
176-
invalid = set(new_dims) ^ all_input_dims
177-
if invalid:
178-
raise ValueError('some `new_dims` are found on input variables: '
179-
'%r' % list(invalid))
180-
181-
output_dims = tuple(d for d in dim_sizes if d not in drop_dims) + new_dims
182-
183-
return _ElemwiseSignature(broadcast_dims, core_dims, output_dims, axis)
148+
output_dims = [broadcast_dims + out for out in gufunc_signature.outputs]
149+
return _ElemwiseSignature(broadcast_dims, output_dims)
184150

185151

186152
def _broadcast_variable_data_to(variable, broadcast_dims, allow_dask=True):
@@ -208,8 +174,7 @@ def _broadcast_variable_data_to(variable, broadcast_dims, allow_dask=True):
208174
return data
209175

210176

211-
def apply_variable_ufunc(func, args, allow_dask=True, core_dims=None,
212-
axis_dims=None, drop_dims=None, new_dims=None,
177+
def apply_variable_ufunc(func, args, allow_dask=True, gufunc_signature=None,
213178
combine_attrs=None, kwargs=None):
214179

215180
if kwargs is None:
@@ -218,29 +183,34 @@ def apply_variable_ufunc(func, args, allow_dask=True, core_dims=None,
218183
if combine_attrs is None:
219184
combine_attrs = lambda func, attrs: None
220185

221-
result_attrs = combine_attrs(func, [getattr(a, 'attrs', {}) for a in args])
222-
223-
sig = _build_and_check_dims_signature(
224-
variables, core_dims, axis_dims, drop_dims, new_dims)
186+
sig = _build_and_check_signature(variables, gufunc_signature)
225187

226-
if sig.axis:
227-
if 'axis' in kwargs:
228-
raise ValueError('axis is already set in kwargs')
229-
kwargs = dict(kwargs)
230-
kwargs['axis'] = sig.axis
188+
n_out = len(sig.output_dims)
189+
input_attrs = [getattr(a, 'attrs', {}) for a in args]
190+
result_attrs = [combine_attrs(input_attrs, func, n) for n in range(n_out)]
231191

232192
list_of_data = []
233193
for arg in args:
234194
if isinstance(arg, Variable):
235195
data = _broadcast_variable_data_to(arg, sig.broadcast_dims,
236196
allow_dask=allow_dask)
237-
list_of_data.append(data)
238197
else:
239-
list_of_data.append(arg)
198+
data = arg
199+
list_of_data.append(data)
240200

241201
result_data = func(*list_of_data, **kwargs)
242202

243-
return Variable(sig.output_dims, result_data, result_attrs)
203+
if n_out > 1:
204+
output = []
205+
for dims, data, attrs in zip(
206+
sig.output_dims, result_data, result_attrs):
207+
output.append(Variable(dims, data, attrs))
208+
return tuple(output)
209+
else:
210+
dims, = sig.output_dims
211+
data, = result_data
212+
attrs = result_attrs
213+
return Variable(dims, data, attrs)
244214

245215

246216
def apply_ufunc(func, args, join='inner', allow_dask=True, kwargs=None,

0 commit comments

Comments
 (0)