From a427b806b666be4f57ce85410e8ceb0b4bad24de Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 4 Mar 2018 18:09:01 -0800 Subject: [PATCH 1/9] Support __array_ufunc__ for xarray objects. This means NumPy ufuncs are now supported directly on xarray.Dataset objects, and opens the door to supporting computation on new data types, such as sparse arrays or arrays with units. Fixes GH1617 --- asv_bench/benchmarks/rolling.py | 5 +- doc/api.rst | 7 + doc/computation.rst | 18 +-- doc/gallery/control_plot_colorbar.py | 3 +- doc/whats-new.rst | 12 ++ xarray/core/common.py | 58 ++++++- xarray/core/dask_array_ops.py | 5 +- xarray/core/dataarray.py | 4 +- xarray/core/dataset.py | 6 +- xarray/core/groupby.py | 5 +- xarray/core/npcompat.py | 2 +- xarray/core/variable.py | 4 +- xarray/tests/test_nputils.py | 4 +- xarray/tests/test_ufuncs.py | 234 ++++++++++++++++++++------- xarray/ufuncs.py | 7 + 15 files changed, 281 insertions(+), 93 deletions(-) diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 79d06019c00..52814ad3481 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -1,9 +1,8 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import numpy as np import pandas as pd + import xarray as xr from . import parameterized, randn, requires_dask diff --git a/doc/api.rst b/doc/api.rst index ae4803e5e62..d10d93a3021 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -357,6 +357,13 @@ Reshaping and reorganizing Universal functions =================== +.. warning:: + + With recent versions of numpy, dask and xarray, NumPy ufuncs are now + supported directly on all xarray and dask objects. This obviates the need + for the ``xarray.ufuncs`` module, which should not be used for new code + unless compatibility with versions of NumPy prior to v1.13 is required. 
+ This functions are copied from NumPy, but extended to work on NumPy arrays, dask arrays and all xarray objects. You can find them in the ``xarray.ufuncs`` module: diff --git a/doc/computation.rst b/doc/computation.rst index bd0343b214d..589df8eac36 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -341,21 +341,15 @@ Datasets support most of the same methods found on data arrays: ds.mean(dim='x') abs(ds) -Unfortunately, we currently do not support NumPy ufuncs for datasets [1]_. -:py:meth:`~xarray.Dataset.apply` works around this -limitation, by applying the given function to each variable in the dataset: +Datasets also support NumPy ufuncs (requires NumPy v1.13 or newer), or +alternatively you can use :py:meth:`~xarray.Dataset.apply` to apply a function +to each variable in a dataset: .. ipython:: python + np.sin(ds) ds.apply(np.sin) -You can also use the wrapped functions in the ``xarray.ufuncs`` module: - -.. ipython:: python - - import xarray.ufuncs as xu - xu.sin(ds) - Datasets also use looping over variables for *broadcasting* in binary arithmetic. You can do arithmetic between any ``DataArray`` and a dataset: @@ -373,10 +367,6 @@ Arithmetic between two datasets matches data variables of the same name: Similarly to index based alignment, the result has the intersection of all matching data variables. -.. [1] This was previously due to a limitation of NumPy, but with NumPy 1.13 - we should be able to support this by leveraging ``__array_ufunc__`` - (:issue:`1617`). - .. _comput.wrapping-custom: Wrapping custom computation diff --git a/doc/gallery/control_plot_colorbar.py b/doc/gallery/control_plot_colorbar.py index a09d825f8f0..5802a57cf31 100644 --- a/doc/gallery/control_plot_colorbar.py +++ b/doc/gallery/control_plot_colorbar.py @@ -7,9 +7,10 @@ Use ``cbar_kwargs`` keyword to specify the number of ticks. The ``spacing`` kwarg can be used to draw proportional ticks. 
""" -import xarray as xr import matplotlib.pyplot as plt +import xarray as xr + # Load the data air_temp = xr.tutorial.load_dataset('air_temperature') air2d = air_temp.air.isel(time=500) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ab667ceba3f..a470228ec52 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,18 @@ Documentation Enhancements ~~~~~~~~~~~~ +- Implemented NumPy's ``__array_ufunc__`` protocol for all xarray objects + (:issue:`1617`). This enables using NumPy ufuncs directly on + ``xarray.Dataset`` objects with recent versions of NumPy (v1.13 and newer): + + .. ipython:: python + + ds = xr.Dataset({'a': 1}) + np.sin(ds) + + This obviates the need for the ``xarray.ufuncs`` module, which will be + deprecated in the future when xarray drops support for older versions of + NumPy. By `Stephan Hoyer <https://github.com/shoyer>`_. - Improve :py:func:`~xarray.DataArray.rolling` logic. :py:func:`~xarray.DataArrayRolling` object now supports :py:func:`~xarray.DataArrayRolling.construct` method that returns a view diff --git a/xarray/core/common.py b/xarray/core/common.py index 85ac0bf9364..fcd0d566577 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,12 +1,16 @@ from __future__ import absolute_import, division, print_function +import numbers import warnings import numpy as np import pandas as pd from . import dtypes, formatting, ops -from .pycompat import OrderedDict, basestring, dask_array_type, suppress +from .options import OPTIONS +from .pycompat import ( + OrderedDict, basestring, bytes_type, dask_array_type, suppress, + unicode_type) from .utils import Frozen, SortedKeysDict, not_implemented @@ -235,7 +239,57 @@ def get_squeeze_dims(xarray_obj, dim, axis=None): return dim -class BaseDataObject(AttrAccessMixin): +class SupportsArithmetic(object): + """Base class for Dataset, DataArray, Variable and GroupBy.""" + + # TODO: implement special methods for arithmetic here rather than injecting + # them in xarray/core/ops.py. 
Ideally, do so by inheriting from + # numpy.lib.mixins.NDArrayOperatorsMixin. + + # TODO: allow extending this with some sort of registration system + _HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number, bytes_type, + unicode_type) + dask_array_type + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + from .computation import apply_ufunc + + # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin. + out = kwargs.get('out', ()) + for x in inputs + out: + if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)): + return NotImplemented + + if ufunc.signature is not None: + raise NotImplementedError( + '{} not supported: xarray objects do not directly implement ' + 'generalized ufuncs. Instead, use xarray.apply_ufunc.' + .format(ufunc)) + + if method != '__call__': + # TODO: support other methods, e.g., reduce and accumulate. + raise NotImplementedError( + '{} method for ufunc {} is not implemented on xarray objects, ' + 'which currently only support the __call__ method.' 
+ .format(method, ufunc)) + + if any(isinstance(o, SupportsArithmetic) for o in out): + raise NotImplementedError( + 'xarray objects are not yet supported in the `out` argument ' + 'for ufuncs.') + + join = dataset_join = OPTIONS['arithmetic_join'] + + return apply_ufunc(ufunc, *inputs, + input_core_dims=((),) * ufunc.nin, + output_core_dims=((),) * ufunc.nout, + join=join, + dataset_join=dataset_join, + dataset_fill_value=np.nan, + kwargs=kwargs, + dask='allowed') + + +class DataWithCoords(SupportsArithmetic, AttrAccessMixin): """Shared base class for Dataset and DataArray.""" def squeeze(self, dim=None, drop=False, axis=None): diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 5524efb4803..4bd3766ced9 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,8 +1,7 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import numpy as np + from . import nputils try: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8c0360df8a9..e3b9a3fa39a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -10,7 +10,7 @@ from ..plot.plot import _PlotMethods from .accessors import DatetimeAccessor from .alignment import align, reindex_like_indexers -from .common import AbstractArray, BaseDataObject +from .common import AbstractArray, DataWithCoords from .coordinates import ( DataArrayCoordinates, Indexes, LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers) @@ -117,7 +117,7 @@ def __setitem__(self, key, value): _THIS_ARRAY = utils.ReprObject('') -class DataArray(AbstractArray, BaseDataObject): +class DataArray(AbstractArray, DataWithCoords): """N-dimensional array with labeled coordinates and dimensions. 
DataArray provides a wrapper around numpy ndarrays that uses labeled diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2a2c4e382ce..03bc8fd6325 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -17,7 +17,7 @@ rolling, utils) from .. import conventions from .alignment import align -from .common import BaseDataObject, ImplementsDatasetReduce +from .common import DataWithCoords, ImplementsDatasetReduce from .coordinates import ( DatasetCoordinates, Indexes, LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers) @@ -298,7 +298,7 @@ def __getitem__(self, key): return self.dataset.sel(**key) -class Dataset(Mapping, ImplementsDatasetReduce, BaseDataObject, +class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords, formatting.ReprMixin): """A multi-dimensional, in memory, array database. @@ -2362,7 +2362,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += array.count(dims) + count += np.asarray(array.count(dims)) size += np.prod([self.dims[d] for d in dims]) if thresh is not None: diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b722a01ec46..83845331268 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -7,7 +7,8 @@ from . import dtypes, duck_array_ops, nputils, ops from .combine import concat -from .common import ImplementsArrayReduce, ImplementsDatasetReduce +from .common import ( + ImplementsArrayReduce, ImplementsDatasetReduce, SupportsArithmetic) from .pycompat import integer_types, range, zip from .utils import hashable, maybe_wrap_array, peek_at, safe_cast_to_index from .variable import IndexVariable, Variable, as_variable @@ -151,7 +152,7 @@ def _unique_and_monotonic(group): return index.is_unique and index.is_monotonic -class GroupBy(object): +class GroupBy(SupportsArithmetic): """A object that implements the split-apply-combine pattern. 
Modeled after `pandas.GroupBy`. The `GroupBy` object can be iterated over diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 8f1f3821f96..af722924aae 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -1,8 +1,8 @@ from __future__ import absolute_import, division, print_function -import numpy as np from distutils.version import LooseVersion +import numpy as np if LooseVersion(np.__version__) >= LooseVersion('1.12'): as_strided = np.lib.stride_tricks.as_strided diff --git a/xarray/core/variable.py b/xarray/core/variable.py index bb4285fba0a..c706a3eed05 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -216,8 +216,8 @@ def _as_array_or_item(data): return data -class Variable(common.AbstractArray, utils.NdimSizeLenMixin): - +class Variable(common.AbstractArray, common.SupportsArithmetic, + utils.NdimSizeLenMixin): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully described outside the context of its parent Dataset (if you want such a diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py index d3ef02a039c..d3ad87d0d28 100644 --- a/xarray/tests/test_nputils.py +++ b/xarray/tests/test_nputils.py @@ -1,8 +1,8 @@ import numpy as np from numpy.testing import assert_array_equal -from xarray.core.nputils import (NumpyVIndexAdapter, _is_contiguous, - rolling_window) +from xarray.core.nputils import ( + NumpyVIndexAdapter, _is_contiguous, rolling_window) def test_is_contiguous(): diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 64a246953fe..91ec1142950 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -1,67 +1,185 @@ from __future__ import absolute_import, division, print_function +from distutils.version import LooseVersion import pickle import numpy as np +import pytest import xarray as xr import xarray.ufuncs as xu -from . 
import TestCase, assert_array_equal, assert_identical, raises_regex - - -class TestOps(TestCase): - def assert_identical(self, a, b): - assert type(a) is type(b) or (float(a) == float(b)) # noqa - try: - assert a.identical(b), (a, b) - except AttributeError: - assert_array_equal(a, b) - - def test_unary(self): - args = [0, - np.zeros(2), - xr.Variable(['x'], [0, 0]), - xr.DataArray([0, 0], dims='x'), - xr.Dataset({'y': ('x', [0, 0])})] - for a in args: - self.assert_identical(a + 1, xu.cos(a)) - - def test_binary(self): - args = [0, - np.zeros(2), - xr.Variable(['x'], [0, 0]), - xr.DataArray([0, 0], dims='x'), - xr.Dataset({'y': ('x', [0, 0])})] - for n, t1 in enumerate(args): - for t2 in args[n:]: - self.assert_identical(t2 + 1, xu.maximum(t1, t2 + 1)) - self.assert_identical(t2 + 1, xu.maximum(t2, t1 + 1)) - self.assert_identical(t2 + 1, xu.maximum(t1 + 1, t2)) - self.assert_identical(t2 + 1, xu.maximum(t2 + 1, t1)) - - def test_groupby(self): - ds = xr.Dataset({'a': ('x', [0, 0, 0])}, {'c': ('x', [0, 0, 1])}) - ds_grouped = ds.groupby('c') - group_mean = ds_grouped.mean('x') - arr_grouped = ds['a'].groupby('c') - - assert_identical(ds, xu.maximum(ds_grouped, group_mean)) - assert_identical(ds, xu.maximum(group_mean, ds_grouped)) - - assert_identical(ds, xu.maximum(arr_grouped, group_mean)) - assert_identical(ds, xu.maximum(group_mean, arr_grouped)) - - assert_identical(ds, xu.maximum(ds_grouped, group_mean['a'])) - assert_identical(ds, xu.maximum(group_mean['a'], ds_grouped)) - - assert_identical(ds.a, xu.maximum(arr_grouped, group_mean.a)) - assert_identical(ds.a, xu.maximum(group_mean.a, arr_grouped)) - - with raises_regex(TypeError, 'only support binary ops'): - xu.maximum(ds.a.variable, ds_grouped) - - def test_pickle(self): - a = 1.0 - cos_pickled = pickle.loads(pickle.dumps(xu.cos)) - self.assert_identical(cos_pickled(a), xu.cos(a)) +from . 
import ( + assert_array_equal, assert_identical as assert_identical_, mock, + raises_regex, +) + + +requires_numpy113 = pytest.mark.skipif(LooseVersion(np.__version__) < '1.13', + reason='numpy 1.13 or newer required') + + +def assert_identical(a, b): + assert type(a) is type(b) or (float(a) == float(b)) # noqa + if isinstance(a, (xr.DataArray, xr.Dataset, xr.Variable)): + assert_identical_(a, b) + else: + assert_array_equal(a, b) + + +@requires_numpy113 +def test_unary(): + args = [0, + np.zeros(2), + xr.Variable(['x'], [0, 0]), + xr.DataArray([0, 0], dims='x'), + xr.Dataset({'y': ('x', [0, 0])})] + for a in args: + assert_identical(a + 1, np.cos(a)) + + +@requires_numpy113 +def test_binary(): + args = [0, + np.zeros(2), + xr.Variable(['x'], [0, 0]), + xr.DataArray([0, 0], dims='x'), + xr.Dataset({'y': ('x', [0, 0])})] + for n, t1 in enumerate(args): + for t2 in args[n:]: + assert_identical(t2 + 1, np.maximum(t1, t2 + 1)) + assert_identical(t2 + 1, np.maximum(t2, t1 + 1)) + assert_identical(t2 + 1, np.maximum(t1 + 1, t2)) + assert_identical(t2 + 1, np.maximum(t2 + 1, t1)) + + +@requires_numpy113 +def test_binary_out(): + args = [1, + np.ones(2), + xr.Variable(['x'], [1, 1]), + xr.DataArray([1, 1], dims='x'), + xr.Dataset({'y': ('x', [1, 1])})] + for arg in args: + actual_mantissa, actual_exponent = np.frexp(arg) + assert_identical(actual_mantissa, 0.5 * arg) + assert_identical(actual_exponent, arg) + + +@requires_numpy113 +def test_groupby(): + ds = xr.Dataset({'a': ('x', [0, 0, 0])}, {'c': ('x', [0, 0, 1])}) + ds_grouped = ds.groupby('c') + group_mean = ds_grouped.mean('x') + arr_grouped = ds['a'].groupby('c') + + assert_identical(ds, np.maximum(ds_grouped, group_mean)) + assert_identical(ds, np.maximum(group_mean, ds_grouped)) + + assert_identical(ds, np.maximum(arr_grouped, group_mean)) + assert_identical(ds, np.maximum(group_mean, arr_grouped)) + + assert_identical(ds, np.maximum(ds_grouped, group_mean['a'])) + assert_identical(ds, np.maximum(group_mean['a'], 
ds_grouped)) + + assert_identical(ds.a, np.maximum(arr_grouped, group_mean.a)) + assert_identical(ds.a, np.maximum(group_mean.a, arr_grouped)) + + with raises_regex(ValueError, 'mismatched lengths for dimension'): + np.maximum(ds.a.variable, ds_grouped) + + +@requires_numpy113 +def test_alignment(): + ds1 = xr.Dataset({'a': ('x', [1, 2])}, {'x': [0, 1]}) + ds2 = xr.Dataset({'a': ('x', [2, 3]), 'b': 4}, {'x': [1, 2]}) + + actual = np.add(ds1, ds2) + expected = xr.Dataset({'a': ('x', [4])}, {'x': [1]}) + assert_identical_(actual, expected) + + with xr.set_options(arithmetic_join='outer'): + actual = np.add(ds1, ds2) + expected = xr.Dataset({'a': ('x', [np.nan, 4, np.nan]), 'b': np.nan}, + coords={'x': [0, 1, 2]}) + assert_identical_(actual, expected) + + +@requires_numpy113 +def test_kwargs(): + x = xr.DataArray(0) + result = np.add(x, 1, dtype=np.float64) + assert result.dtype == np.float64 + + +@requires_numpy113 +def test_xarray_defers_to_unrecognized_type(): + + class Other(object): + def __array_ufunc__(self, *args, **kwargs): + return 'other' + + xarray_obj = xr.DataArray([1, 2, 3]) + other = Other() + assert np.maximum(xarray_obj, other) == 'other' + assert np.sin(xarray_obj, out=other) == 'other' + + +@requires_numpy113 +def test_xarray_handles_dask(): + da = pytest.importorskip('dask.array') + x = xr.DataArray(np.ones((2, 2)), dims=['x', 'y']) + y = da.ones((2, 2), chunks=(2, 2)) + result = np.add(x, y) + assert result.chunks == ((2,), (2,)) + assert isinstance(result, xr.DataArray) + + +@requires_numpy113 +def test_dask_defers_to_xarray(): + da = pytest.importorskip('dask.array') + x = xr.DataArray(np.ones((2, 2)), dims=['x', 'y']) + y = da.ones((2, 2), chunks=(2, 2)) + result = np.add(y, x) + assert result.chunks == ((2,), (2,)) + assert isinstance(result, xr.DataArray) + + +@requires_numpy113 +def test_gufunc_methods(): + xarray_obj = xr.DataArray([1, 2, 3]) + with raises_regex(NotImplementedError, 'reduce method'): + np.add.reduce(xarray_obj, 1) + + 
+@requires_numpy113 +def test_out(): + xarray_obj = xr.DataArray([1, 2, 3]) + + # xarray out arguments should raise + with raises_regex(NotImplementedError, '`out` argument'): + np.add(xarray_obj, 1, out=xarray_obj) + + # but non-xarray should be OK + other = np.zeros((3,)) + np.add(other, xarray_obj, out=other) + assert_identical(other, np.array([1, 2, 3])) + + +@requires_numpy113 +def test_gufuncs(): + xarray_obj = xr.DataArray([1, 2, 3]) + fake_gufunc = mock.Mock(signature='(n)->()', autospec=np.sin) + with raises_regex(NotImplementedError, 'generalized ufuncs'): + xarray_obj.__array_ufunc__(fake_gufunc, '__call__', xarray_obj) + + +def test_xarray_ufuncs_deprecation(): + with pytest.warns(PendingDeprecationWarning, match='xarray.ufuncs'): + xu.cos(xr.DataArray([0, 1])) + + +def test_xarray_ufuncs_pickle(): + a = 1.0 + cos_pickled = pickle.loads(pickle.dumps(xu.cos)) + assert_identical(cos_pickled(a), xu.cos(a)) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index f7f17aedc2b..d9fcc1eac5d 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -15,6 +15,8 @@ """ from __future__ import absolute_import, division, print_function +import warnings as _warnings + import numpy as _np from .core.dataarray import DataArray as _DataArray @@ -42,6 +44,11 @@ def __init__(self, name): self._name = name def __call__(self, *args, **kwargs): + _warnings.warn( + 'xarray.ufuncs will be deprecated when xarray no longer supports ' + 'versions of numpy older than v1.13. 
Instead, use numpy ufuncs ' + 'directly.', PendingDeprecationWarning, stacklevel=2) + new_args = args f = _dask_or_eager_func(self._name, n_array_args=len(args)) if len(args) > 2 or len(args) == 0: From 661c5b49e7405650a4c42a15f5b5f937685c1698 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 4 Mar 2018 18:40:24 -0800 Subject: [PATCH 2/9] add TODO note on xarray objects in out argument --- xarray/core/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/common.py b/xarray/core/common.py index fcd0d566577..d16fd6fdc82 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -273,6 +273,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): .format(method, ufunc)) if any(isinstance(o, SupportsArithmetic) for o in out): + # TODO: implement this with logic like _inplace_binary_op. This + # will be necessary to use NDArrayOperatorsMixin. raise NotImplementedError( 'xarray objects are not yet supported in the `out` argument ' 'for ufuncs.') From 0ebbcb65d9b17868c221084ce4906272904b829b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 4 Mar 2018 18:44:35 -0800 Subject: [PATCH 3/9] Satisfy stickler for __eq__ overload --- xarray/core/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index d16fd6fdc82..a41564af27a 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -809,7 +809,7 @@ def __exit__(self, exc_type, exc_value, traceback): # methods are defined and don't warn on these operations __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ - __or__ = __div__ = __eq__ = __ne__ = not_implemented + __or__ = __div__ = __eq__ = __ne__ = not_implemented # noqa def full_like(other, fill_value, dtype=None): From 561ac77e71b2ff0fe2cf4c3c32d31f1428d51305 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 4 Mar 2018 18:46:03 -0800 Subject: [PATCH 4/9] Move 
dummy arithmetic implementations to SupportsArithmetic --- xarray/core/common.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index a41564af27a..4ab1bd34717 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -290,6 +290,12 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): kwargs=kwargs, dask='allowed') + # this has no runtime function - these are listed so IDEs know these + # methods are defined and don't warn on these operations + __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ + __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ + __or__ = __div__ = __eq__ = __ne__ = not_implemented # noqa + class DataWithCoords(SupportsArithmetic, AttrAccessMixin): """Shared base class for Dataset and DataArray.""" @@ -805,12 +811,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() - # this has no runtime function - these are listed so IDEs know these - # methods are defined and don't warn on these operations - __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ - __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ - __or__ = __div__ = __eq__ = __ne__ = not_implemented # noqa - def full_like(other, fill_value, dtype=None): """Return a new object with the same shape and type as a given object. 
From 52c750dd38f17b5e8d36ed1cfd6e2627cf36a87c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 4 Mar 2018 18:48:17 -0800 Subject: [PATCH 5/9] Try again to disable flake8 warning --- xarray/core/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 4ab1bd34717..b316220730c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -239,7 +239,7 @@ def get_squeeze_dims(xarray_obj, dim, axis=None): return dim -class SupportsArithmetic(object): +class SupportsArithmetic(object): # noqa: W1641 """Base class for Dataset, DataArray, Variable and GroupBy.""" # TODO: implement special methods for arithmetic here rather than injecting @@ -294,7 +294,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # methods are defined and don't warn on these operations __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ - __or__ = __div__ = __eq__ = __ne__ = not_implemented # noqa + __or__ = __div__ = __eq__ = __ne__ = not_implemented class DataWithCoords(SupportsArithmetic, AttrAccessMixin): From 03697868e43183a9e47264cde78f609e23ba98de Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 4 Mar 2018 19:47:02 -0800 Subject: [PATCH 6/9] Disable py3k tool on stickler-ci --- .stickler.yml | 1 - xarray/core/common.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.stickler.yml b/.stickler.yml index db8f5f254e9..79d8b7fb717 100644 --- a/.stickler.yml +++ b/.stickler.yml @@ -6,7 +6,6 @@ linters: # stickler doesn't support 'exclude' for flake8 properly, so we disable it # below with files.ignore: # https://github.com/markstory/lint-review/issues/184 - py3k: files: ignore: - doc/**/*.py diff --git a/xarray/core/common.py b/xarray/core/common.py index b316220730c..74c239fbb5d 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -239,7 +239,7 @@ def 
get_squeeze_dims(xarray_obj, dim, axis=None): return dim -class SupportsArithmetic(object): # noqa: W1641 +class SupportsArithmetic(object): """Base class for Dataset, DataArray, Variable and GroupBy.""" # TODO: implement special methods for arithmetic here rather than injecting From 4e6ac28d5e330b201c520ed90d34466aab18990d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 4 Mar 2018 20:44:56 -0800 Subject: [PATCH 7/9] Move arithmetic to its own file. --- xarray/core/arithmetic.py | 71 +++++++++++++++++++++++++++++++++++++++ xarray/core/common.py | 59 +------------------------------- xarray/core/groupby.py | 4 +-- xarray/core/variable.py | 5 +-- 4 files changed, 77 insertions(+), 62 deletions(-) create mode 100644 xarray/core/arithmetic.py diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py new file mode 100644 index 00000000000..3988d1abe2e --- /dev/null +++ b/xarray/core/arithmetic.py @@ -0,0 +1,71 @@ +"""Base classes implementing arithmetic for xarray objects.""" +from __future__ import absolute_import, division, print_function + +import numbers + +import numpy as np + +from .options import OPTIONS +from .pycompat import bytes_type, dask_array_type, unicode_type +from .utils import not_implemented + + +class SupportsArithmetic(object): + """Base class for xarray types that support arithmetic. + + Used by Dataset, DataArray, Variable and GroupBy. + """ + + # TODO: implement special methods for arithmetic here rather than injecting + # them in xarray/core/ops.py. Ideally, do so by inheriting from + # numpy.lib.mixins.NDArrayOperatorsMixin. + + # TODO: allow extending this with some sort of registration system + _HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number, bytes_type, + unicode_type) + dask_array_type + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + from .computation import apply_ufunc + + # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin. 
+ out = kwargs.get('out', ()) + for x in inputs + out: + if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)): + return NotImplemented + + if ufunc.signature is not None: + raise NotImplementedError( + '{} not supported: xarray objects do not directly implement ' + 'generalized ufuncs. Instead, use xarray.apply_ufunc.' + .format(ufunc)) + + if method != '__call__': + # TODO: support other methods, e.g., reduce and accumulate. + raise NotImplementedError( + '{} method for ufunc {} is not implemented on xarray objects, ' + 'which currently only support the __call__ method.' + .format(method, ufunc)) + + if any(isinstance(o, SupportsArithmetic) for o in out): + # TODO: implement this with logic like _inplace_binary_op. This + # will be necessary to use NDArrayOperatorsMixin. + raise NotImplementedError( + 'xarray objects are not yet supported in the `out` argument ' + 'for ufuncs.') + + join = dataset_join = OPTIONS['arithmetic_join'] + + return apply_ufunc(ufunc, *inputs, + input_core_dims=((),) * ufunc.nin, + output_core_dims=((),) * ufunc.nout, + join=join, + dataset_join=dataset_join, + dataset_fill_value=np.nan, + kwargs=kwargs, + dask='allowed') + + # this has no runtime function - these are listed so IDEs know these + # methods are defined and don't warn on these operations + __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ + __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ + __or__ = __div__ = __eq__ = __ne__ = not_implemented diff --git a/xarray/core/common.py b/xarray/core/common.py index 74c239fbb5d..4bfff853f53 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -7,6 +7,7 @@ import pandas as pd from . 
import dtypes, formatting, ops +from .arithmetic import SupportsArithmetic from .options import OPTIONS from .pycompat import ( OrderedDict, basestring, bytes_type, dask_array_type, suppress, @@ -239,64 +240,6 @@ def get_squeeze_dims(xarray_obj, dim, axis=None): return dim -class SupportsArithmetic(object): - """Base class for Dataset, DataArray, Variable and GroupBy.""" - - # TODO: implement special methods for arithmetic here rather than injecting - # them in xarray/core/ops.py. Ideally, do so by inheriting from - # numpy.lib.mixins.NDArrayOperatorsMixin. - - # TODO: allow extending this with some sort of registration system - _HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number, bytes_type, - unicode_type) + dask_array_type - - def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): - from .computation import apply_ufunc - - # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin. - out = kwargs.get('out', ()) - for x in inputs + out: - if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)): - return NotImplemented - - if ufunc.signature is not None: - raise NotImplementedError( - '{} not supported: xarray objects do not directly implement ' - 'generalized ufuncs. Instead, use xarray.apply_ufunc.' - .format(ufunc)) - - if method != '__call__': - # TODO: support other methods, e.g., reduce and accumulate. - raise NotImplementedError( - '{} method for ufunc {} is not implemented on xarray objects, ' - 'which currently only support the __call__ method.' - .format(method, ufunc)) - - if any(isinstance(o, SupportsArithmetic) for o in out): - # TODO: implement this with logic like _inplace_binary_op. This - # will be necessary to use NDArrayOperatorsMixin. 
- raise NotImplementedError( - 'xarray objects are not yet supported in the `out` argument ' - 'for ufuncs.') - - join = dataset_join = OPTIONS['arithmetic_join'] - - return apply_ufunc(ufunc, *inputs, - input_core_dims=((),) * ufunc.nin, - output_core_dims=((),) * ufunc.nout, - join=join, - dataset_join=dataset_join, - dataset_fill_value=np.nan, - kwargs=kwargs, - dask='allowed') - - # this has no runtime function - these are listed so IDEs know these - # methods are defined and don't warn on these operations - __lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ - __truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ - __or__ = __div__ = __eq__ = __ne__ = not_implemented - - class DataWithCoords(SupportsArithmetic, AttrAccessMixin): """Shared base class for Dataset and DataArray.""" diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 83845331268..7068f8e6cae 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -6,9 +6,9 @@ import pandas as pd from . import dtypes, duck_array_ops, nputils, ops +from .arithmetic import SupportsArithmetic from .combine import concat -from .common import ( - ImplementsArrayReduce, ImplementsDatasetReduce, SupportsArithmetic) +from .common import ImplementsArrayReduce, ImplementsDatasetReduce from .pycompat import integer_types, range, zip from .utils import hashable, maybe_wrap_array, peek_at, safe_cast_to_index from .variable import IndexVariable, Variable, as_variable diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c706a3eed05..c0c3accea34 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -10,7 +10,8 @@ import xarray as xr # only for Dataset and DataArray -from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils +from . 
import ( + arithmetic, common, dtypes, duck_array_ops, indexing, nputils, ops, utils,) from .indexing import ( BasicIndexer, OuterIndexer, PandasIndexAdapter, VectorizedIndexer, as_indexable) @@ -216,7 +217,7 @@ def _as_array_or_item(data): return data -class Variable(common.AbstractArray, common.SupportsArithmetic, +class Variable(common.AbstractArray, arithmetic.SupportsArithmetic, utils.NdimSizeLenMixin): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully From b6bed5bb24c5260a539fe71fe520326d73889bd8 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 8 Mar 2018 20:46:23 -0800 Subject: [PATCH 8/9] Remove unused imports --- xarray/core/common.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 4bfff853f53..337c1c51415 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function -import numbers import warnings import numpy as np @@ -8,11 +7,8 @@ from . 
import dtypes, formatting, ops from .arithmetic import SupportsArithmetic -from .options import OPTIONS -from .pycompat import ( - OrderedDict, basestring, bytes_type, dask_array_type, suppress, - unicode_type) -from .utils import Frozen, SortedKeysDict, not_implemented +from .pycompat import OrderedDict, basestring, dask_array_type, suppress +from .utils import Frozen, SortedKeysDict class ImplementsArrayReduce(object): From 259a109ca0ec3fbc4fc8330b08502f61018491d0 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 8 Mar 2018 21:09:34 -0800 Subject: [PATCH 9/9] Add note on backwards incompatible changes from apply_ufunc --- doc/whats-new.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d178491ef81..2ce7801bed6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -32,6 +32,18 @@ v0.10.2 (unreleased) The minor release includes a number of bug-fixes and backwards compatible enhancements. +Backwards incompatible changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The addition of ``__array_ufunc__`` for xarray objects (see below) means that + NumPy `ufunc methods`_ (e.g., ``np.add.reduce``) that previously worked on + ``xarray.DataArray`` objects by converting them into NumPy arrays will now + raise ``NotImplementedError`` instead. In all cases, the work-around is + simple: convert your objects explicitly into NumPy arrays before calling the + ufunc (e.g., with ``.values``). + +.. _ufunc methods: https://docs.scipy.org/doc/numpy/reference/ufuncs.html#methods + Documentation ~~~~~~~~~~~~~