Commit b430524

Support __array_ufunc__ for xarray objects. (#1962)
* Support __array_ufunc__ for xarray objects. This means NumPy ufuncs are now supported directly on xarray.Dataset objects, and opens the door to supporting computation on new data types, such as sparse arrays or arrays with units. Fixes GH1617
* add TODO note on xarray objects in out argument
* Satisfy stickler for __eq__ overload
* Move dummy arithmetic implementations to SupportsArithmetic
* Try again to disable flake8 warning
* Disable py3k tool on stickler-ci
* Move arithmetic to its own file.
* Remove unused imports
* Add note on backwards incompatible changes from apply_ufunc
1 parent 8271dff commit b430524

17 files changed, +317 -103 lines changed

.stickler.yml

-1
@@ -6,7 +6,6 @@ linters:
         # stickler doesn't support 'exclude' for flake8 properly, so we disable it
         # below with files.ignore:
         # https://github.com/markstory/lint-review/issues/184
-    py3k:
 files:
     ignore:
         - doc/**/*.py

asv_bench/benchmarks/rolling.py

+2 -3
@@ -1,9 +1,8 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
 
 import numpy as np
 import pandas as pd
+
 import xarray as xr
 
 from . import parameterized, randn, requires_dask

doc/api.rst

+7
@@ -358,6 +358,13 @@ Reshaping and reorganizing
 Universal functions
 ===================
 
+.. warning::
+
+   With recent versions of numpy, dask and xarray, NumPy ufuncs are now
+   supported directly on all xarray and dask objects. This obviates the need
+   for the ``xarray.ufuncs`` module, which should not be used for new code
+   unless compatibility with versions of NumPy prior to v1.13 is required.
+
 These functions are copied from NumPy, but extended to work on NumPy arrays,
 dask arrays and all xarray objects. You can find them in the ``xarray.ufuncs``
 module:

doc/computation.rst

+4 -14
@@ -341,21 +341,15 @@ Datasets support most of the same methods found on data arrays:
     ds.mean(dim='x')
     abs(ds)
 
-Unfortunately, we currently do not support NumPy ufuncs for datasets [1]_.
-:py:meth:`~xarray.Dataset.apply` works around this
-limitation, by applying the given function to each variable in the dataset:
+Datasets also support NumPy ufuncs (requires NumPy v1.13 or newer), or
+alternatively you can use :py:meth:`~xarray.Dataset.apply` to apply a function
+to each variable in a dataset:
 
 .. ipython:: python
 
+    np.sin(ds)
     ds.apply(np.sin)
 
-You can also use the wrapped functions in the ``xarray.ufuncs`` module:
-
-.. ipython:: python
-
-    import xarray.ufuncs as xu
-    xu.sin(ds)
-
 Datasets also use looping over variables for *broadcasting* in binary
 arithmetic. You can do arithmetic between any ``DataArray`` and a dataset:
 
@@ -373,10 +367,6 @@ Arithmetic between two datasets matches data variables of the same name:
 Similarly to index based alignment, the result has the intersection of all
 matching data variables.
 
-.. [1] This was previously due to a limitation of NumPy, but with NumPy 1.13
-       we should be able to support this by leveraging ``__array_ufunc__``
-       (:issue:`1617`).
-
 .. _comput.wrapping-custom:
 
 Wrapping custom computation
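
The doc/computation.rst change above states that a NumPy ufunc can now be applied to a whole Dataset, as an alternative to applying a function to each variable. A minimal sketch of that equivalence follows, assuming NumPy v1.13 or newer and an xarray release that includes this commit. The dataset contents and the names `direct` and `per_variable` are made up for illustration.

```python
import numpy as np
import xarray as xr

# Hypothetical two-variable dataset, used only for illustration.
ds = xr.Dataset({
    'a': ('x', np.array([0.0, 1.0, 2.0])),
    'b': ('x', np.array([3.0, 4.0, 5.0])),
})

# The ufunc call is dispatched through __array_ufunc__ and returns a Dataset.
direct = np.sin(ds)

# The same result, built by applying the ufunc to each data variable in turn.
per_variable = xr.Dataset({name: np.sin(ds[name]) for name in ds.data_vars})
```

Both objects are Datasets holding the same variables, so `direct.equals(per_variable)` is a quick way to check the equivalence.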

doc/gallery/control_plot_colorbar.py

+2 -1
@@ -7,9 +7,10 @@
 Use ``cbar_kwargs`` keyword to specify the number of ticks.
 The ``spacing`` kwarg can be used to draw proportional ticks.
 """
-import xarray as xr
 import matplotlib.pyplot as plt
 
+import xarray as xr
+
 # Load the data
 air_temp = xr.tutorial.load_dataset('air_temperature')
 air2d = air_temp.air.isel(time=500)

doc/whats-new.rst

+29 -3
@@ -32,27 +32,53 @@ v0.10.2 (unreleased)
 
 The minor release includes a number of bug-fixes and backwards compatible enhancements.
 
+Backwards incompatible changes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- The addition of ``__array_ufunc__`` for xarray objects (see below) means that
+  NumPy `ufunc methods`_ (e.g., ``np.add.reduce``) that previously worked on
+  ``xarray.DataArray`` objects by converting them into NumPy arrays will now
+  raise ``NotImplementedError`` instead. In all cases, the work-around is
+  simple: convert your objects explicitly into NumPy arrays before calling the
+  ufunc (e.g., with ``.values``).
+
+.. _ufunc methods: https://docs.scipy.org/doc/numpy/reference/ufuncs.html#methods
+
 Documentation
 ~~~~~~~~~~~~~
 
 Enhancements
 ~~~~~~~~~~~~
 
-- Addition of :py:func:`~xarray.dot`, equivalent to ``np.einsum``.
+- Added :py:func:`~xarray.dot`, equivalent to :py:func:`np.einsum`.
   Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option,
   which specifies the dimensions to sum over.
   (:issue:`1951`)
+  By `Keisuke Fujii <https://github.com/fujiisoup>`_.
+
 - Support for writing xarray datasets to netCDF files (netcdf4 backend only)
   when using the `dask.distributed <https://distributed.readthedocs.io>`_
   scheduler (:issue:`1464`).
   By `Joe Hamman <https://github.com/jhamman>`_.
 
-
-- Fixed to_netcdf when using dask distributed
 - Support lazy vectorized-indexing. After this change, flexible indexing such
   as orthogonal/vectorized indexing, becomes possible for all the backend
   arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`)
   By `Keisuke Fujii <https://github.com/fujiisoup>`_.
+
+- Implemented NumPy's ``__array_ufunc__`` protocol for all xarray objects
+  (:issue:`1617`). This enables using NumPy ufuncs directly on
+  ``xarray.Dataset`` objects with recent versions of NumPy (v1.13 and newer):
+
+  .. ipython:: python
+
+    ds = xr.Dataset({'a': 1})
+    np.sin(ds)
+
+  This obviates the need for the ``xarray.ufuncs`` module, which will be
+  deprecated in the future when xarray drops support for older versions of
+  NumPy. By `Stephan Hoyer <https://github.com/shoyer>`_.
+
 - Improve :py:func:`~xarray.DataArray.rolling` logic.
   :py:func:`~xarray.DataArrayRolling` object now supports
   :py:func:`~xarray.DataArrayRolling.construct` method that returns a view

xarray/core/arithmetic.py

+71
@@ -0,0 +1,71 @@
+"""Base classes implementing arithmetic for xarray objects."""
+from __future__ import absolute_import, division, print_function
+
+import numbers
+
+import numpy as np
+
+from .options import OPTIONS
+from .pycompat import bytes_type, dask_array_type, unicode_type
+from .utils import not_implemented
+
+
+class SupportsArithmetic(object):
+    """Base class for xarray types that support arithmetic.
+
+    Used by Dataset, DataArray, Variable and GroupBy.
+    """
+
+    # TODO: implement special methods for arithmetic here rather than injecting
+    # them in xarray/core/ops.py. Ideally, do so by inheriting from
+    # numpy.lib.mixins.NDArrayOperatorsMixin.
+
+    # TODO: allow extending this with some sort of registration system
+    _HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number, bytes_type,
+                      unicode_type) + dask_array_type
+
+    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+        from .computation import apply_ufunc
+
+        # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
+        out = kwargs.get('out', ())
+        for x in inputs + out:
+            if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)):
+                return NotImplemented
+
+        if ufunc.signature is not None:
+            raise NotImplementedError(
+                '{} not supported: xarray objects do not directly implement '
+                'generalized ufuncs. Instead, use xarray.apply_ufunc.'
+                .format(ufunc))
+
+        if method != '__call__':
+            # TODO: support other methods, e.g., reduce and accumulate.
+            raise NotImplementedError(
+                '{} method for ufunc {} is not implemented on xarray objects, '
+                'which currently only support the __call__ method.'
+                .format(method, ufunc))
+
+        if any(isinstance(o, SupportsArithmetic) for o in out):
+            # TODO: implement this with logic like _inplace_binary_op. This
+            # will be necessary to use NDArrayOperatorsMixin.
+            raise NotImplementedError(
+                'xarray objects are not yet supported in the `out` argument '
+                'for ufuncs.')
+
+        join = dataset_join = OPTIONS['arithmetic_join']
+
+        return apply_ufunc(ufunc, *inputs,
+                           input_core_dims=((),) * ufunc.nin,
+                           output_core_dims=((),) * ufunc.nout,
+                           join=join,
+                           dataset_join=dataset_join,
+                           dataset_fill_value=np.nan,
+                           kwargs=kwargs,
+                           dask='allowed')
+
+    # this has no runtime function - these are listed so IDEs know these
+    # methods are defined and don't warn on these operations
+    __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \
+        __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \
+        __or__ = __div__ = __eq__ = __ne__ = not_implemented
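
To see the dispatch pattern of `SupportsArithmetic.__array_ufunc__` in isolation, here is a NumPy-only sketch that keeps three of its guards: unknown operand types, non-`__call__` ufunc methods, and wrapped objects in `out`. `Wrapped` is a hypothetical class invented for illustration and is not part of xarray. It unwraps and rewraps values itself, whereas the method in this commit delegates to `apply_ufunc`.

```python
import numbers

import numpy as np


class Wrapped(object):
    """Hypothetical thin wrapper around an ndarray (not an xarray class)."""

    _HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number)

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        out = kwargs.get('out', ())
        # Defer to the other operands if any input or output type is unknown.
        for x in inputs + out:
            if not isinstance(x, self._HANDLED_TYPES + (Wrapped,)):
                return NotImplemented

        if method != '__call__':
            raise NotImplementedError(
                '{} method for ufunc {} is not supported'
                .format(method, ufunc))

        if any(isinstance(o, Wrapped) for o in out):
            raise NotImplementedError(
                'Wrapped objects are not supported in `out`')

        # Unwrap the inputs, apply the ufunc, and rewrap the result.
        arrays = tuple(x.data if isinstance(x, Wrapped) else x for x in inputs)
        result = ufunc(*arrays, **kwargs)
        if isinstance(result, tuple):
            return tuple(Wrapped(r) for r in result)
        return Wrapped(result)


w = Wrapped([0.0, 1.0, 2.0])
doubled = np.add(w, w)      # dispatched to Wrapped.__array_ufunc__
mixed = np.multiply(2, w)   # plain numbers are among the handled types
```

Returning `NotImplemented` rather than raising lets NumPy try the other operands' `__array_ufunc__` implementations first; NumPy raises `TypeError` itself only if every operand declines.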

xarray/core/common.py

+3 -8
@@ -6,8 +6,9 @@
 import pandas as pd
 
 from . import dtypes, formatting, ops
+from .arithmetic import SupportsArithmetic
 from .pycompat import OrderedDict, basestring, dask_array_type, suppress
-from .utils import Frozen, SortedKeysDict, not_implemented
+from .utils import Frozen, SortedKeysDict
 
 
 class ImplementsArrayReduce(object):
@@ -235,7 +236,7 @@ def get_squeeze_dims(xarray_obj, dim, axis=None):
     return dim
 
 
-class BaseDataObject(AttrAccessMixin):
+class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
     """Shared base class for Dataset and DataArray."""
 
     def squeeze(self, dim=None, drop=False, axis=None):
@@ -749,12 +750,6 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         self.close()
 
-    # this has no runtime function - these are listed so IDEs know these
-    # methods are defined and don't warn on these operations
-    __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \
-        __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \
-        __or__ = __div__ = __eq__ = __ne__ = not_implemented
-
 
 def full_like(other, fill_value, dtype=None):
     """Return a new object with the same shape and type as a given object.

xarray/core/dask_array_ops.py

+2 -3
@@ -1,8 +1,7 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
 
 import numpy as np
+
 from . import nputils
 
 try:

xarray/core/dataarray.py

+2 -2
@@ -10,7 +10,7 @@
 from ..plot.plot import _PlotMethods
 from .accessors import DatetimeAccessor
 from .alignment import align, reindex_like_indexers
-from .common import AbstractArray, BaseDataObject
+from .common import AbstractArray, DataWithCoords
 from .coordinates import (
     DataArrayCoordinates, Indexes, LevelCoordinatesSource,
     assert_coordinate_consistent, remap_label_indexers)
@@ -117,7 +117,7 @@ def __setitem__(self, key, value):
 _THIS_ARRAY = utils.ReprObject('<this-array>')
 
 
-class DataArray(AbstractArray, BaseDataObject):
+class DataArray(AbstractArray, DataWithCoords):
     """N-dimensional array with labeled coordinates and dimensions.
 
     DataArray provides a wrapper around numpy ndarrays that uses labeled

xarray/core/dataset.py

+3 -3
@@ -17,7 +17,7 @@
     rolling, utils)
 from .. import conventions
 from .alignment import align
-from .common import BaseDataObject, ImplementsDatasetReduce
+from .common import DataWithCoords, ImplementsDatasetReduce
 from .coordinates import (
     DatasetCoordinates, Indexes, LevelCoordinatesSource,
     assert_coordinate_consistent, remap_label_indexers)
@@ -298,7 +298,7 @@ def __getitem__(self, key):
         return self.dataset.sel(**key)
 
 
-class Dataset(Mapping, ImplementsDatasetReduce, BaseDataObject,
+class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords,
               formatting.ReprMixin):
     """A multi-dimensional, in memory, array database.
 
@@ -2362,7 +2362,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None):
             array = self._variables[k]
             if dim in array.dims:
                 dims = [d for d in array.dims if d != dim]
-                count += array.count(dims)
+                count += np.asarray(array.count(dims))
                 size += np.prod([self.dims[d] for d in dims])
 
         if thresh is not None:

xarray/core/groupby.py

+2 -1
@@ -6,6 +6,7 @@
 import pandas as pd
 
 from . import dtypes, duck_array_ops, nputils, ops
+from .arithmetic import SupportsArithmetic
 from .combine import concat
 from .common import ImplementsArrayReduce, ImplementsDatasetReduce
 from .pycompat import integer_types, range, zip
@@ -151,7 +152,7 @@ def _unique_and_monotonic(group):
         return index.is_unique and index.is_monotonic
 
 
-class GroupBy(object):
+class GroupBy(SupportsArithmetic):
     """An object that implements the split-apply-combine pattern.
 
     Modeled after `pandas.GroupBy`. The `GroupBy` object can be iterated over

xarray/core/npcompat.py

+1 -1
@@ -1,8 +1,8 @@
 from __future__ import absolute_import, division, print_function
 
-import numpy as np
 from distutils.version import LooseVersion
 
+import numpy as np
 
 if LooseVersion(np.__version__) >= LooseVersion('1.12'):
     as_strided = np.lib.stride_tricks.as_strided

xarray/core/variable.py

+4 -3
@@ -10,7 +10,8 @@
 
 import xarray as xr  # only for Dataset and DataArray
 
-from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils
+from . import (
+    arithmetic, common, dtypes, duck_array_ops, indexing, nputils, ops, utils,)
 from .indexing import (
     BasicIndexer, OuterIndexer, PandasIndexAdapter, VectorizedIndexer,
     as_indexable)
@@ -216,8 +217,8 @@ def _as_array_or_item(data):
     return data
 
 
-class Variable(common.AbstractArray, utils.NdimSizeLenMixin):
-
+class Variable(common.AbstractArray, arithmetic.SupportsArithmetic,
+               utils.NdimSizeLenMixin):
     """A netcdf-like variable consisting of dimensions, data and attributes
     which describe a single Array. A single Variable object is not fully
     described outside the context of its parent Dataset (if you want such a

xarray/tests/test_nputils.py

+2 -2
@@ -1,8 +1,8 @@
 import numpy as np
 from numpy.testing import assert_array_equal
 
-from xarray.core.nputils import (NumpyVIndexAdapter, _is_contiguous,
-                                 rolling_window)
+from xarray.core.nputils import (
+    NumpyVIndexAdapter, _is_contiguous, rolling_window)
 
 
 def test_is_contiguous():
