From 47610a661bf6742928f3fe77482b68a1eba07ab0 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Tue, 12 Nov 2024 23:34:29 -0500 Subject: [PATCH 01/12] initial namespace-aware implementation --- xarray/__init__.py | 3 +- xarray/tests/test_dask.py | 12 ++++ xarray/tests/test_sparse.py | 8 +++ xarray/tests/test_ufuncs.py | 92 ++++++++++++++++++++++++++ xarray/ufuncs.py | 125 ++++++++++++++++++++++++++++++++++++ 5 files changed, 239 insertions(+), 1 deletion(-) create mode 100644 xarray/ufuncs.py diff --git a/xarray/__init__.py b/xarray/__init__.py index e474cee85ad..634f67a61a2 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,6 +1,6 @@ from importlib.metadata import version as _version -from xarray import groupers, testing, tutorial +from xarray import groupers, testing, tutorial, ufuncs from xarray.backends.api import ( load_dataarray, load_dataset, @@ -69,6 +69,7 @@ "groupers", "testing", "tutorial", + "ufuncs", # Top-level functions "align", "apply_ufunc", diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 26e42dd692a..54ae80a1d9d 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -11,6 +11,7 @@ import pytest import xarray as xr +import xarray.ufuncs as xu from xarray import DataArray, Dataset, Variable from xarray.core import duck_array_ops from xarray.core.duck_array_ops import lazy_array_equiv @@ -274,6 +275,17 @@ def test_bivariate_ufunc(self): self.assertLazyAndAllClose(np.maximum(u, 0), np.maximum(v, 0)) self.assertLazyAndAllClose(np.maximum(u, 0), np.maximum(0, v)) + def test_univariate_xufunc(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndAllClose(np.sin(u), xu.sin(v)) + + def test_bivariate_xufunc(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0)) + self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v)) + def test_compute(self): u = self.eager_var v = self.lazy_var diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index f0a97fc7e69..a69e370572b 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -9,6 +9,7 @@ import pytest import xarray as xr +import xarray.ufuncs as xu from xarray import DataArray, Variable from xarray.namedarray.pycompat import array_type from xarray.tests import assert_equal, assert_identical, requires_dask @@ -294,6 +295,13 @@ def test_bivariate_ufunc(self): assert_sparse_equal(np.maximum(self.data, 0), np.maximum(self.var, 0).data) assert_sparse_equal(np.maximum(self.data, 0), np.maximum(0, self.var).data) + def test_univariate_xufunc(self): + assert_sparse_equal(xu.sin(self.var).data, np.sin(self.data)) + + def test_bivariate_xufunc(self): + assert_sparse_equal(xu.multiply(self.var, 0).data, np.multiply(self.data, 0)) + assert_sparse_equal(xu.multiply(0, self.var).data, np.multiply(0, self.data)) + def test_repr(self): expected = dedent( """\ diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 20e064e2013..ecfd82c6512 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -1,9 +1,12 @@ from __future__ import annotations +import pickle + import numpy as np import pytest import xarray as xr +import xarray.ufuncs as xu from xarray.tests import assert_allclose, assert_array_equal, mock from xarray.tests import assert_identical as assert_identical_ @@ -155,3 +158,92 @@ def test_gufuncs(): fake_gufunc = mock.Mock(signature="(n)->()", autospec=np.sin) with pytest.raises(NotImplementedError, match=r"generalized ufuncs"): xarray_obj.__array_ufunc__(fake_gufunc, "__call__", xarray_obj) + + +class DuckArray: + # Minimal array class that implements a few of its own ufuncs, and otherwise + # dispatches to numpy but returns a DuckArray + def __init__(self, data): + self.data = data + self.shape = data.shape + self.ndim = data.ndim + self.dtype = data.dtype + + def __array_namespace__(self): + return DuckArray + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + inputs = [i.data if isinstance(i, DuckArray) else i for i in inputs] + result = ufunc(*inputs, **kwargs) + return DuckArray(result) + + @staticmethod + def sin(x): + return DuckArray(np.sin(x.data)) + + @staticmethod + def add(x, y): + return DuckArray(x.data + y.data) + + +class DuckArray2(DuckArray): + def __array_namespace__(self): + return DuckArray2 + + +class TestXarrayUfuncs: + @pytest.fixture(autouse=True) + def setUp(self): + self.x = xr.DataArray(np.array([1, 2, 3])) + self.xd = xr.DataArray(DuckArray(np.array([1, 2, 3]))) + self.xd2 = xr.DataArray(DuckArray2(np.array([1, 2, 3]))) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("name", xu.__all__) + def test_ufuncs(self, name, request): + np_func = getattr(np, name) + xu_func = getattr(xu, name) + if hasattr(np_func, "nin") and np_func.nin == 2: + args = (self.x, self.x) + else: + args = (self.x,) + + actual = np_func(*args) + expected = xu_func(*args) + + assert_identical(actual, expected) + + def test_ufunc_pickle(self): + a = 1.0 + cos_pickled = pickle.loads(pickle.dumps(xu.cos)) + assert_identical(cos_pickled(a), xu.cos(a)) + + def test_ufunc_scalar(self): + actual = xu.sin(1) + assert isinstance(actual, float) + + def test_ufunc_duck_array_dataarray(self): + actual = xu.sin(self.xd) + assert isinstance(actual.data, DuckArray) + + def test_ufunc_duck_array_variable(self): + actual = xu.sin(self.xd.variable) + assert isinstance(actual.data, DuckArray) + + def test_ufunc_duck_array_dataset(self): + ds = xr.Dataset({"a": self.xd}) + actual = xu.sin(ds) + assert isinstance(actual.a.data, DuckArray) + + def test_ufunc_numpy_fallback(self): + with pytest.warns(UserWarning, match=r"Function cos not found in DuckArray"): + actual = xu.cos(self.xd) + assert isinstance(actual.data, DuckArray) + + def test_ufunc_mixed_arrays_compatible(self): + actual = xu.add(self.xd, self.x) + assert isinstance(actual.data, DuckArray) + + def test_ufunc_mixed_arrays_incompatible(self): + with pytest.raises(ValueError, match=r"Mixed array types"): + xu.add(self.xd, self.xd2) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py new file mode 100644 index 00000000000..b9f3b848259 --- /dev/null +++ b/xarray/ufuncs.py @@ -0,0 +1,125 @@ +"""xarray specific universal functions.""" + +import textwrap +import warnings + +import numpy as np + +import xarray as xr +from xarray.core.groupby import GroupBy +from xarray.namedarray.pycompat import array_type + + +def _walk_array_namespaces(obj, namespaces): + if isinstance(obj, xr.DataTree): + for node in obj.subtree: + _walk_array_namespaces(node.dataset, namespaces) + elif isinstance(obj, xr.Dataset): + for name in obj.data_vars: + _walk_array_namespaces(obj[name], namespaces) + elif isinstance(obj, GroupBy): + _walk_array_namespaces(next(iter(obj))[1], namespaces) + elif isinstance(obj, xr.DataArray | xr.Variable): + _walk_array_namespaces(obj.data, namespaces) + elif isinstance(obj, array_type("dask")): + _walk_array_namespaces(obj._meta, namespaces) + else: + namespace = getattr(obj, "__array_namespace__", None) + if namespace is not None: + namespaces.add(namespace()) + + return namespaces + + +class _UFuncDispatcher: + """Wrapper for dispatching ufuncs.""" + + def __init__(self, name): + self._name = name + + def __call__(self, *args, **kwargs): + xps = set() + for arg in args: + _walk_array_namespaces(arg, xps) + + xps.discard(np) + if len(xps) > 1: + names = [module.__name__ for module in xps] + raise ValueError( + f"Mixed array types {names} are not supported by xarray.ufuncs" + ) + + xp = next(iter(xps)) if len(xps) else np + func = getattr(xp, self._name, None) + + if func is None: + warnings.warn( + f"Function {self._name} not found in {xp.__name__}, falling back to numpy", + stacklevel=2, + ) + func = getattr(np, self._name) + + return xr.apply_ufunc(func, *args, dask="parallelized", **kwargs) + + +def _skip_signature(doc, name): + if not isinstance(doc, str): + return doc + + if doc.startswith(name): + signature_end = doc.find("\n\n") + doc = doc[signature_end + 2 :] + + return doc + + +def _remove_unused_reference_labels(doc): + if not isinstance(doc, str): + return doc + + max_references = 5 + for num in range(max_references): + label = f".. [{num}]" + reference = f"[{num}]_" + index = f"{num}. " + + if label not in doc or reference in doc: + continue + + doc = doc.replace(label, index) + + return doc + + +def _dedent(doc): + if not isinstance(doc, str): + return doc + + return textwrap.dedent(doc) + + +def _create_op(name): + func = _UFuncDispatcher(name) + func.__name__ = name + doc = getattr(np, name).__doc__ + + doc = _remove_unused_reference_labels(_skip_signature(_dedent(doc), name)) + + func.__doc__ = ( + f"xarray specific variant of numpy.{name}. Handles " + "xarray objects by dispatching to the appropriate " + "function for the underlying array type.\n\n" + f"Documentation from numpy:\n\n{doc}" + ) + return func + + +# Auto generate from the public numpy ufuncs +np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} +excluded_ufuncs = {"divmod", "frexp", "isnat", "matmul", "modf", "vecdot"} +additional_ufuncs = {"isreal"} # "angle", "iscomplex" +__all__ = sorted(np_ufuncs - excluded_ufuncs | additional_ufuncs) + + +for name in __all__: + globals()[name] = _create_op(name) From 6d9bfbedd9b013391469f18402bff6b3bf22b680 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 13 Nov 2024 11:56:42 -0500 Subject: [PATCH 02/12] use np subclass, test duck dask arrays --- xarray/tests/test_ufuncs.py | 39 +++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index ecfd82c6512..89bfc5c2445 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -7,7 +7,7 @@ import xarray as xr import xarray.ufuncs as xu -from xarray.tests import assert_allclose, assert_array_equal, mock +from xarray.tests import assert_allclose, assert_array_equal, mock, requires_dask from xarray.tests import assert_identical as assert_identical_ @@ -160,30 +160,23 @@ def test_gufuncs(): xarray_obj.__array_ufunc__(fake_gufunc, "__call__", xarray_obj) -class DuckArray: - # Minimal array class that implements a few of its own ufuncs, and otherwise - # dispatches to numpy but returns a DuckArray - def __init__(self, data): - self.data = data - self.shape = data.shape - self.ndim = data.ndim - self.dtype = data.dtype +class DuckArray(np.ndarray): + # Minimal subclassed duck array with its own self-contained namespace, + # which implements a few ufuncs + def __new__(cls, array): + obj = np.asarray(array).view(cls) + return obj def __array_namespace__(self): return DuckArray - def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): - inputs = [i.data if isinstance(i, DuckArray) else i for i in inputs] - result = ufunc(*inputs, **kwargs) - return DuckArray(result) - @staticmethod def sin(x): - return DuckArray(np.sin(x.data)) + return np.sin(x) @staticmethod def add(x, y): - return DuckArray(x.data + y.data) + return x + y class DuckArray2(DuckArray): @@ -194,9 +187,9 @@ def __array_namespace__(self): class TestXarrayUfuncs: @pytest.fixture(autouse=True) def setUp(self): - self.x = xr.DataArray(np.array([1, 2, 3])) - self.xd = xr.DataArray(DuckArray(np.array([1, 2, 3]))) - self.xd2 = xr.DataArray(DuckArray2(np.array([1, 2, 3]))) + self.x = xr.DataArray([1, 2, 3]) + self.xd = xr.DataArray(DuckArray([1, 2, 3])) + self.xd2 = xr.DataArray(DuckArray2([1, 2, 3])) @pytest.mark.filterwarnings("ignore::RuntimeWarning") @pytest.mark.parametrize("name", xu.__all__) @@ -235,6 +228,14 @@ def test_ufunc_duck_array_dataset(self): actual = xu.sin(ds) assert isinstance(actual.a.data, DuckArray) + @requires_dask + def test_ufunc_duck_dask(self): + import dask.array as da + + x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3])))) + actual = xu.sin(x) + assert isinstance(actual.data._meta, DuckArray) + def test_ufunc_numpy_fallback(self): with pytest.warns(UserWarning, match=r"Function cos not found in DuckArray"): actual = xu.cos(self.xd) From 3799b4c88219fa7685e92f5e9c1bc183e797038f Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 16 Nov 2024 22:22:01 -0500 Subject: [PATCH 03/12] remove dask special casing and numpy fallback --- xarray/tests/test_ufuncs.py | 19 ++++++++++------ xarray/ufuncs.py | 43 +++++++++++++++---------------------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 89bfc5c2445..a9c128aab68 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -1,6 +1,7 @@ from __future__ import annotations import pickle +from unittest.mock import patch import numpy as np import pytest @@ -201,8 +202,8 @@ def test_ufuncs(self, name, request): else: args = (self.x,) - actual = np_func(*args) - expected = xu_func(*args) + expected = np_func(*args) + actual = xu_func(*args) assert_identical(actual, expected) @@ -236,10 +237,16 @@ def test_ufunc_duck_dask(self): actual = xu.sin(x) assert isinstance(actual.data._meta, DuckArray) - def test_ufunc_numpy_fallback(self): - with pytest.warns(UserWarning, match=r"Function cos not found in DuckArray"): - actual = xu.cos(self.xd) - assert isinstance(actual.data, DuckArray) + @requires_dask + @pytest.mark.xfail(reason="dask ufuncs currently dispatch to numpy") + def test_ufunc_duck_dask_no_array_ufunc(self): + import dask.array as da + + # dask ufuncs currently only preserve duck arrays that implement __array_ufunc__ + with patch.object(DuckArray, "__array_ufunc__", new=None, create=True): + x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3])))) + actual = xu.sin(x) + assert isinstance(actual.data._meta, DuckArray) def test_ufunc_mixed_arrays_compatible(self): actual = xu.add(self.xd, self.x) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index b9f3b848259..331ebf2fba2 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -1,17 +1,16 @@ """xarray specific universal functions.""" import textwrap -import warnings import numpy as np import xarray as xr from xarray.core.groupby import GroupBy -from xarray.namedarray.pycompat import array_type def _walk_array_namespaces(obj, namespaces): if isinstance(obj, xr.DataTree): + # TODO: DataTree doesn't actually support ufuncs yet for node in obj.subtree: _walk_array_namespaces(node.dataset, namespaces) elif isinstance(obj, xr.Dataset): @@ -21,8 +20,6 @@ def _walk_array_namespaces(obj, namespaces): _walk_array_namespaces(next(iter(obj))[1], namespaces) elif isinstance(obj, xr.DataArray | xr.Variable): _walk_array_namespaces(obj.data, namespaces) - elif isinstance(obj, array_type("dask")): - _walk_array_namespaces(obj._meta, namespaces) else: namespace = getattr(obj, "__array_namespace__", None) if namespace is not None: @@ -31,6 +28,19 @@ def _walk_array_namespaces(obj, namespaces): return namespaces +def get_array_namespace(*args): + xps = set() + for arg in args: + _walk_array_namespaces(arg, xps) + + xps.discard(np) + if len(xps) > 1: + names = [module.__name__ for module in xps] + raise ValueError(f"Mixed array types {names} are not supported.") + + return next(iter(xps)) if len(xps) else np + + class _UFuncDispatcher: """Wrapper for dispatching ufuncs.""" @@ -38,28 +48,9 @@ def __init__(self, name): self._name = name def __call__(self, *args, **kwargs): - xps = set() - for arg in args: - _walk_array_namespaces(arg, xps) - - xps.discard(np) - if len(xps) > 1: - names = [module.__name__ for module in xps] - raise ValueError( - f"Mixed array types {names} are not supported by xarray.ufuncs" - ) - - xp = next(iter(xps)) if len(xps) else np - func = getattr(xp, self._name, None) - - if func is None: - warnings.warn( - f"Function {self._name} not found in {xp.__name__}, falling back to numpy", - stacklevel=2, - ) - func = getattr(np, self._name) - - return xr.apply_ufunc(func, *args, dask="parallelized", **kwargs) + xp = get_array_namespace(*args) + func = getattr(xp, self._name) + return xr.apply_ufunc(func, *args, dask="allowed", **kwargs) def _skip_signature(doc, name): From 0f929bdf49f67f97acc48c773080e0a38762592e Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 16 Nov 2024 22:42:34 -0500 Subject: [PATCH 04/12] add isnat --- xarray/tests/test_ufuncs.py | 6 +++++- xarray/ufuncs.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index a9c128aab68..a862cf73386 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -191,13 +191,17 @@ def setUp(self): self.x = xr.DataArray([1, 2, 3]) self.xd = xr.DataArray(DuckArray([1, 2, 3])) self.xd2 = xr.DataArray(DuckArray2([1, 2, 3])) + self.xt = xr.DataArray(np.datetime64("2021-01-01", "ns")) @pytest.mark.filterwarnings("ignore::RuntimeWarning") @pytest.mark.parametrize("name", xu.__all__) def test_ufuncs(self, name, request): np_func = getattr(np, name) xu_func = getattr(xu, name) - if hasattr(np_func, "nin") and np_func.nin == 2: + + if name == "isnat": + args = (self.xt,) + elif hasattr(np_func, "nin") and np_func.nin == 2: args = (self.x, self.x) else: args = (self.x,) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index 331ebf2fba2..49f71280c9c 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -107,7 +107,7 @@ def _create_op(name): # Auto generate from the public numpy ufuncs np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} -excluded_ufuncs = {"divmod", "frexp", "isnat", "matmul", "modf", "vecdot"} +excluded_ufuncs = {"divmod", "frexp", "matmul", "modf", "vecdot"} additional_ufuncs = {"isreal"} # "angle", "iscomplex" __all__ = sorted(np_ufuncs - excluded_ufuncs | additional_ufuncs) From b7c58cf7ec8f38dc57c2638fc43024afab1c2c5e Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 16 Nov 2024 23:42:10 -0500 Subject: [PATCH 05/12] hard code the supported ufuncs --- xarray/tests/test_ufuncs.py | 5 +- xarray/ufuncs.py | 233 ++++++++++++++++++++++++++++++++++-- 2 files changed, 228 insertions(+), 10 deletions(-) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index a862cf73386..27e8b46c5fa 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -209,7 +209,10 @@ def test_ufuncs(self, name, request): expected = np_func(*args) actual = xu_func(*args) - assert_identical(actual, expected) + if name in ["angle", "iscomplex"]: + np.testing.assert_equal(expected, actual.values) + else: + assert_identical(actual, expected) def test_ufunc_pickle(self): a = 1.0 diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index 49f71280c9c..d5eb2060a78 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -105,12 +105,227 @@ def _create_op(name): return func -# Auto generate from the public numpy ufuncs -np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} -excluded_ufuncs = {"divmod", "frexp", "matmul", "modf", "vecdot"} -additional_ufuncs = {"isreal"} # "angle", "iscomplex" -__all__ = sorted(np_ufuncs - excluded_ufuncs | additional_ufuncs) - - -for name in __all__: - globals()[name] = _create_op(name) +# These can be auto-generated from the public numpy ufuncs: +# {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} + +# Ufuncs that use core dimensions or product multiple output arrays are +# not currently supported, and left commented below. + +# unary +abs = _create_op("abs") +absolute = _create_op("absolute") +acos = _create_op("acos") +acosh = _create_op("acosh") +arccos = _create_op("arccos") +arccosh = _create_op("arccosh") +arcsin = _create_op("arcsin") +arcsinh = _create_op("arcsinh") +arctan = _create_op("arctan") +arctanh = _create_op("arctanh") +asin = _create_op("asin") +asinh = _create_op("asinh") +atan = _create_op("atan") +atanh = _create_op("atanh") +bitwise_count = _create_op("bitwise_count") +bitwise_invert = _create_op("bitwise_invert") +bitwise_not = _create_op("bitwise_not") +cbrt = _create_op("cbrt") +ceil = _create_op("ceil") +conj = _create_op("conj") +conjugate = _create_op("conjugate") +cos = _create_op("cos") +cosh = _create_op("cosh") +deg2rad = _create_op("deg2rad") +degrees = _create_op("degrees") +exp = _create_op("exp") +exp2 = _create_op("exp2") +expm1 = _create_op("expm1") +fabs = _create_op("fabs") +floor = _create_op("floor") +# frexp +invert = _create_op("invert") +isfinite = _create_op("isfinite") +isinf = _create_op("isinf") +isnan = _create_op("isnan") +isnat = _create_op("isnat") +log = _create_op("log") +log10 = _create_op("log10") +log1p = _create_op("log1p") +log2 = _create_op("log2") +logical_not = _create_op("logical_not") +# modf +negative = _create_op("negative") +positive = _create_op("positive") +rad2deg = _create_op("rad2deg") +radians = _create_op("radians") +reciprocal = _create_op("reciprocal") +rint = _create_op("rint") +sign = _create_op("sign") +signbit = _create_op("signbit") +sin = _create_op("sin") +sinh = _create_op("sinh") +spacing = _create_op("spacing") +sqrt = _create_op("sqrt") +square = _create_op("square") +tan = _create_op("tan") +tanh = _create_op("tanh") +trunc = _create_op("trunc") + +# binary +add = _create_op("add") +arctan2 = _create_op("arctan2") +atan2 = _create_op("atan2") +bitwise_and = _create_op("bitwise_and") +bitwise_left_shift = _create_op("bitwise_left_shift") +bitwise_or = _create_op("bitwise_or") +bitwise_right_shift = _create_op("bitwise_right_shift") +bitwise_xor = _create_op("bitwise_xor") +copysign = _create_op("copysign") +divide = _create_op("divide") +# divmod +equal = _create_op("equal") +float_power = _create_op("float_power") +floor_divide = _create_op("floor_divide") +fmax = _create_op("fmax") +fmin = _create_op("fmin") +fmod = _create_op("fmod") +gcd = _create_op("gcd") +greater = _create_op("greater") +greater_equal = _create_op("greater_equal") +heaviside = _create_op("heaviside") +hypot = _create_op("hypot") +lcm = _create_op("lcm") +ldexp = _create_op("ldexp") +left_shift = _create_op("left_shift") +less = _create_op("less") +less_equal = _create_op("less_equal") +logaddexp = _create_op("logaddexp") +logaddexp2 = _create_op("logaddexp2") +logical_and = _create_op("logical_and") +logical_or = _create_op("logical_or") +logical_xor = _create_op("logical_xor") +# matmul +maximum = _create_op("maximum") +minimum = _create_op("minimum") +mod = _create_op("mod") +multiply = _create_op("multiply") +nextafter = _create_op("nextafter") +not_equal = _create_op("not_equal") +pow = _create_op("pow") +power = _create_op("power") +remainder = _create_op("remainder") +right_shift = _create_op("right_shift") +subtract = _create_op("subtract") +true_divide = _create_op("true_divide") +# vecdot + +# elementwise non-ufunc +angle = _create_op("angle") +isreal = _create_op("isreal") +iscomplex = _create_op("iscomplex") + + +__all__ = [ + "abs", + "absolute", + "acos", + "acosh", + "arccos", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctanh", + "asin", + "asinh", + "atan", + "atanh", + "bitwise_count", + "bitwise_invert", + "bitwise_not", + "cbrt", + "ceil", + "conj", + "conjugate", + "cos", + "cosh", + "deg2rad", + "degrees", + "exp", + "exp2", + "expm1", + "fabs", + "floor", + "invert", + "isfinite", + "isinf", + "isnan", + "isnat", + "log", + "log10", + "log1p", + "log2", + "logical_not", + "negative", + "positive", + "rad2deg", + "radians", + "reciprocal", + "rint", + "sign", + "signbit", + "sin", + "sinh", + "spacing", + "sqrt", + "square", + "tan", + "tanh", + "trunc", + "add", + "arctan2", + "atan2", + "bitwise_and", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "copysign", + "divide", + "equal", + "float_power", + "floor_divide", + "fmax", + "fmin", + "fmod", + "gcd", + "greater", + "greater_equal", + "heaviside", + "hypot", + "lcm", + "ldexp", + "left_shift", + "less", + "less_equal", + "logaddexp", + "logaddexp2", + "logical_and", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "mod", + "multiply", + "nextafter", + "not_equal", + "pow", + "power", + "remainder", + "right_shift", + "subtract", + "true_divide", + "angle", + "isreal", + "iscomplex", +] From 6a55ea6710d7e91c2d2b22d27768a5f76414f2eb Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sun, 17 Nov 2024 09:35:25 -0500 Subject: [PATCH 06/12] handle np versions, separate unary/binary path --- xarray/tests/test_ufuncs.py | 5 ++- xarray/ufuncs.py | 61 +++++++++++++++++++++++++++++-------- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 27e8b46c5fa..580d3a8b94f 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -196,8 +196,11 @@ def setUp(self): @pytest.mark.filterwarnings("ignore::RuntimeWarning") @pytest.mark.parametrize("name", xu.__all__) def test_ufuncs(self, name, request): - np_func = getattr(np, name) xu_func = getattr(xu, name) + if isinstance(xu_func, xu._UnavailableUfunc): + pytest.xfail(f"Ufunc {name} is not available in numpy {np.__version__}.") + + np_func = getattr(np, name) if name == "isnat": args = (self.xt,) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index d5eb2060a78..65240e042a7 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -41,16 +41,40 @@ def get_array_namespace(*args): return next(iter(xps)) if len(xps) else np -class _UFuncDispatcher: - """Wrapper for dispatching ufuncs.""" +class _UnaryUfunc: + """Wrapper for dispatching unary ufuncs.""" def __init__(self, name): self._name = name - def __call__(self, *args, **kwargs): - xp = get_array_namespace(*args) + def __call__(self, x, **kwargs): + xp = get_array_namespace(x) func = getattr(xp, self._name) - return xr.apply_ufunc(func, *args, dask="allowed", **kwargs) + return xr.apply_ufunc(func, x, dask="allowed", **kwargs) + + +class _BinaryUfunc: + """Wrapper for dispatching binary ufuncs.""" + + def __init__(self, name): + self._name = name + + def __call__(self, x, y, **kwargs): + xp = get_array_namespace(x, y) + func = getattr(xp, self._name) + return xr.apply_ufunc(func, x, y, dask="allowed", **kwargs) + + +class _UnavailableUfunc: + """Wrapper for unimplemented ufuncs in older numpy versions.""" + + def __init__(self, name): + self._name = name + + def __call__(self, *args, **kwargs): + raise NotImplementedError( + f"Ufunc {self._name} is not available in numpy {np.__version__}." + ) def _skip_signature(doc, name): @@ -90,7 +114,18 @@ def _dedent(doc): def _create_op(name): - func = _UFuncDispatcher(name) + if not hasattr(np, name): + # handle older numpy versions with missing array api standard aliases + if np.lib.NumpyVersion(np.__version__) < "2.0.0": + return _UnavailableUfunc(name) + raise ValueError(f"'{name}' is not a valid numpy function") + + np_func = getattr(np, name) + if hasattr(np_func, "nin") and np_func.nin == 2: + func = _BinaryUfunc(name) + else: + func = _UnaryUfunc(name) + func.__name__ = name doc = getattr(np, name).__doc__ @@ -111,7 +146,7 @@ def _create_op(name): # Ufuncs that use core dimensions or product multiple output arrays are # not currently supported, and left commented below. -# unary +# UNARY abs = _create_op("abs") absolute = _create_op("absolute") acos = _create_op("acos") @@ -142,7 +177,7 @@ def _create_op(name): expm1 = _create_op("expm1") fabs = _create_op("fabs") floor = _create_op("floor") -# frexp +# frexp = _create_op("frexp") invert = _create_op("invert") isfinite = _create_op("isfinite") isinf = _create_op("isinf") @@ -153,7 +188,7 @@ def _create_op(name): log1p = _create_op("log1p") log2 = _create_op("log2") logical_not = _create_op("logical_not") -# modf +# modf = _create_op("modf") negative = _create_op("negative") positive = _create_op("positive") rad2deg = _create_op("rad2deg") @@ -171,7 +206,7 @@ def _create_op(name): tanh = _create_op("tanh") trunc = _create_op("trunc") -# binary +# BINARY add = _create_op("add") arctan2 = _create_op("arctan2") atan2 = _create_op("atan2") @@ -182,7 +217,7 @@ def _create_op(name): bitwise_xor = _create_op("bitwise_xor") copysign = _create_op("copysign") divide = _create_op("divide") -# divmod +# divmod = _create_op("divmod") equal = _create_op("equal") float_power = _create_op("float_power") floor_divide = _create_op("floor_divide") @@ -204,7 +239,7 @@ def _create_op(name): logical_and = _create_op("logical_and") logical_or = _create_op("logical_or") logical_xor = _create_op("logical_xor") -# matmul +# matmul = _create_op("matmul") maximum = _create_op("maximum") minimum = _create_op("minimum") mod = _create_op("mod") @@ -217,7 +252,7 @@ def _create_op(name): right_shift = _create_op("right_shift") subtract = _create_op("subtract") true_divide = _create_op("true_divide") -# vecdot +# vecdot = _create_op("vecdot") # elementwise non-ufunc angle = _create_op("angle") From 900ca17ca68136cb77d6a2f9d9931f75e517f6e1 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 11:47:20 -0500 Subject: [PATCH 07/12] explicit unary/binary creators --- xarray/tests/test_ufuncs.py | 2 +- xarray/ufuncs.py | 307 +++++++++++++++++------------------- 2 files changed, 149 insertions(+), 160 deletions(-) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 580d3a8b94f..5fdf12b9031 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -197,7 +197,7 @@ def setUp(self): @pytest.mark.parametrize("name", xu.__all__) def test_ufuncs(self, name, request): xu_func = getattr(xu, name) - if isinstance(xu_func, xu._UnavailableUfunc): + if not xu_func._available: pytest.xfail(f"Ufunc {name} is not available in numpy {np.__version__}.") np_func = getattr(np, name) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index 65240e042a7..5021c33b13d 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -1,6 +1,7 @@ """xarray specific universal functions.""" import textwrap +from abc import ABC, abstractmethod import numpy as np @@ -41,42 +42,57 @@ def get_array_namespace(*args): return next(iter(xps)) if len(xps) else np -class _UnaryUfunc: - """Wrapper for dispatching unary ufuncs.""" - +class _ufunc_wrapper(ABC): def __init__(self, name): - self._name = name + self.__name__ = name + self._setup() + + @abstractmethod + def __call__(self, *args, **kwargs): + raise NotImplementedError + + def _setup(self): + if hasattr(np, self.__name__): + self._available = True + self._create_doc() + else: + # some aliases are missing in older numpy versions + if np.lib.NumpyVersion(np.__version__) < "2.0.0": + self._available = False + else: + raise ValueError(f"'{self.__name__}' is not a valid numpy function") + + def _create_doc(self): + doc = getattr(np, self.__name__).__doc__ + doc = _remove_unused_reference_labels( + _skip_signature(_dedent(doc), self.__name__) + ) + self.__doc__ = ( + f"xarray specific variant of numpy.{__name__}. Handles " + "xarray objects by dispatching to the appropriate " + "function for the underlying array type.\n\n" + f"Documentation from numpy:\n\n{doc}" + ) + - def __call__(self, x, **kwargs): +class _unary_ufunc(_ufunc_wrapper): + """Wrapper for dispatching unary ufuncs.""" + + def __call__(self, x, /, **kwargs): xp = get_array_namespace(x) - func = getattr(xp, self._name) + func = getattr(xp, self.__name__) return xr.apply_ufunc(func, x, dask="allowed", **kwargs) -class _BinaryUfunc: +class _binary_ufunc(_ufunc_wrapper): """Wrapper for dispatching binary ufuncs.""" - def __init__(self, name): - self._name = name - - def __call__(self, x, y, **kwargs): + def __call__(self, x, y, /, **kwargs): xp = get_array_namespace(x, y) - func = getattr(xp, self._name) + func = getattr(xp, self.__name__) return xr.apply_ufunc(func, x, y, dask="allowed", **kwargs) -class _UnavailableUfunc: - """Wrapper for unimplemented ufuncs in older numpy versions.""" - - def __init__(self, name): - self._name = name - - def __call__(self, *args, **kwargs): - raise NotImplementedError( - f"Ufunc {self._name} is not available in numpy {np.__version__}." - ) - - def _skip_signature(doc, name): if not isinstance(doc, str): return doc @@ -113,151 +129,124 @@ def _dedent(doc): return textwrap.dedent(doc) -def _create_op(name): - if not hasattr(np, name): - # handle older numpy versions with missing array api standard aliases - if np.lib.NumpyVersion(np.__version__) < "2.0.0": - return _UnavailableUfunc(name) - raise ValueError(f"'{name}' is not a valid numpy function") - - np_func = getattr(np, name) - if hasattr(np_func, "nin") and np_func.nin == 2: - func = _BinaryUfunc(name) - else: - func = _UnaryUfunc(name) - - func.__name__ = name - doc = getattr(np, name).__doc__ - - doc = _remove_unused_reference_labels(_skip_signature(_dedent(doc), name)) - - func.__doc__ = ( - f"xarray specific variant of numpy.{name}. Handles " - "xarray objects by dispatching to the appropriate " - "function for the underlying array type.\n\n" - f"Documentation from numpy:\n\n{doc}" - ) - return func - - # These can be auto-generated from the public numpy ufuncs: # {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} # Ufuncs that use core dimensions or product multiple output arrays are -# not currently supported, and left commented below. +# not currently supported, and left commented out below. # UNARY -abs = _create_op("abs") -absolute = _create_op("absolute") -acos = _create_op("acos") -acosh = _create_op("acosh") -arccos = _create_op("arccos") -arccosh = _create_op("arccosh") -arcsin = _create_op("arcsin") -arcsinh = _create_op("arcsinh") -arctan = _create_op("arctan") -arctanh = _create_op("arctanh") -asin = _create_op("asin") -asinh = _create_op("asinh") -atan = _create_op("atan") -atanh = _create_op("atanh") -bitwise_count = _create_op("bitwise_count") -bitwise_invert = _create_op("bitwise_invert") -bitwise_not = _create_op("bitwise_not") -cbrt = _create_op("cbrt") -ceil = _create_op("ceil") -conj = _create_op("conj") -conjugate = _create_op("conjugate") -cos = _create_op("cos") -cosh = _create_op("cosh") -deg2rad = _create_op("deg2rad") -degrees = _create_op("degrees") -exp = _create_op("exp") -exp2 = _create_op("exp2") -expm1 = _create_op("expm1") -fabs = _create_op("fabs") -floor = _create_op("floor") -# frexp = _create_op("frexp") -invert = _create_op("invert") -isfinite = _create_op("isfinite") -isinf = _create_op("isinf") -isnan = _create_op("isnan") -isnat = _create_op("isnat") -log = _create_op("log") -log10 = _create_op("log10") -log1p = _create_op("log1p") -log2 = _create_op("log2") -logical_not = _create_op("logical_not") -# modf = _create_op("modf") -negative = _create_op("negative") -positive = _create_op("positive") -rad2deg = _create_op("rad2deg") -radians = _create_op("radians") -reciprocal = _create_op("reciprocal") -rint = _create_op("rint") -sign = _create_op("sign") -signbit = _create_op("signbit") -sin = _create_op("sin") -sinh = _create_op("sinh") -spacing = _create_op("spacing") -sqrt = _create_op("sqrt") -square = _create_op("square") -tan = _create_op("tan") -tanh = _create_op("tanh") -trunc = _create_op("trunc") +abs = _unary_ufunc("abs") +absolute = _unary_ufunc("absolute") +acos = _unary_ufunc("acos") +acosh = _unary_ufunc("acosh") +arccos = _unary_ufunc("arccos") +arccosh = _unary_ufunc("arccosh") +arcsin = _unary_ufunc("arcsin") +arcsinh = _unary_ufunc("arcsinh") +arctan = _unary_ufunc("arctan") +arctanh = _unary_ufunc("arctanh") +asin = _unary_ufunc("asin") +asinh = _unary_ufunc("asinh") +atan = _unary_ufunc("atan") +atanh = _unary_ufunc("atanh") +bitwise_count = _unary_ufunc("bitwise_count") +bitwise_invert = _unary_ufunc("bitwise_invert") +bitwise_not = _unary_ufunc("bitwise_not") +cbrt = _unary_ufunc("cbrt") +ceil = _unary_ufunc("ceil") +conj = _unary_ufunc("conj") +conjugate = _unary_ufunc("conjugate") +cos = _unary_ufunc("cos") +cosh = _unary_ufunc("cosh") +deg2rad = _unary_ufunc("deg2rad") +degrees = _unary_ufunc("degrees") +exp = _unary_ufunc("exp") +exp2 = _unary_ufunc("exp2") +expm1 = _unary_ufunc("expm1") +fabs = _unary_ufunc("fabs") +floor = _unary_ufunc("floor") +# frexp = _unary_ufunc("frexp") +invert = _unary_ufunc("invert") +isfinite = _unary_ufunc("isfinite") +isinf = _unary_ufunc("isinf") +isnan = _unary_ufunc("isnan") +isnat = _unary_ufunc("isnat") +log = _unary_ufunc("log") +log10 = _unary_ufunc("log10") +log1p = _unary_ufunc("log1p") +log2 = _unary_ufunc("log2") +logical_not = _unary_ufunc("logical_not") +# modf = _unary_ufunc("modf") +negative = _unary_ufunc("negative") +positive = _unary_ufunc("positive") +rad2deg = _unary_ufunc("rad2deg") +radians = _unary_ufunc("radians") +reciprocal = _unary_ufunc("reciprocal") +rint = _unary_ufunc("rint") +sign = _unary_ufunc("sign") +signbit = _unary_ufunc("signbit") +sin = _unary_ufunc("sin") +sinh = _unary_ufunc("sinh") +spacing = _unary_ufunc("spacing") +sqrt = _unary_ufunc("sqrt") +square = _unary_ufunc("square") +tan = _unary_ufunc("tan") +tanh = _unary_ufunc("tanh") +trunc = _unary_ufunc("trunc") # BINARY -add = _create_op("add") -arctan2 = _create_op("arctan2") -atan2 = _create_op("atan2") -bitwise_and = _create_op("bitwise_and") -bitwise_left_shift = _create_op("bitwise_left_shift") -bitwise_or = _create_op("bitwise_or") -bitwise_right_shift = _create_op("bitwise_right_shift") -bitwise_xor = _create_op("bitwise_xor") -copysign = _create_op("copysign") -divide = _create_op("divide") -# divmod = _create_op("divmod") -equal = _create_op("equal") -float_power = _create_op("float_power") -floor_divide = _create_op("floor_divide") -fmax = _create_op("fmax") -fmin = _create_op("fmin") -fmod = _create_op("fmod") -gcd = _create_op("gcd") -greater = _create_op("greater") -greater_equal = _create_op("greater_equal") -heaviside = _create_op("heaviside") -hypot = _create_op("hypot") -lcm = _create_op("lcm") -ldexp = _create_op("ldexp") -left_shift = _create_op("left_shift") -less = _create_op("less") -less_equal = _create_op("less_equal") -logaddexp = _create_op("logaddexp") -logaddexp2 = _create_op("logaddexp2") -logical_and = _create_op("logical_and") -logical_or = _create_op("logical_or") -logical_xor = _create_op("logical_xor") -# matmul = _create_op("matmul") -maximum = _create_op("maximum") -minimum = _create_op("minimum") -mod = _create_op("mod") -multiply = _create_op("multiply") -nextafter = _create_op("nextafter") -not_equal = _create_op("not_equal") -pow = _create_op("pow") -power = _create_op("power") -remainder = _create_op("remainder") -right_shift = _create_op("right_shift") -subtract = _create_op("subtract") -true_divide = _create_op("true_divide") -# vecdot = _create_op("vecdot") +add = _binary_ufunc("add") +arctan2 = _binary_ufunc("arctan2") +atan2 = _binary_ufunc("atan2") +bitwise_and = _binary_ufunc("bitwise_and") +bitwise_left_shift = _binary_ufunc("bitwise_left_shift") +bitwise_or = _binary_ufunc("bitwise_or") +bitwise_right_shift = _binary_ufunc("bitwise_right_shift") +bitwise_xor = _binary_ufunc("bitwise_xor") +copysign = _binary_ufunc("copysign") +divide = _binary_ufunc("divide") +# divmod = _binary_ufunc("divmod") +equal = _binary_ufunc("equal") +float_power = _binary_ufunc("float_power") +floor_divide = _binary_ufunc("floor_divide") +fmax = _binary_ufunc("fmax") +fmin = _binary_ufunc("fmin") +fmod = _binary_ufunc("fmod") +gcd = _binary_ufunc("gcd") +greater = _binary_ufunc("greater") +greater_equal = _binary_ufunc("greater_equal") +heaviside = _binary_ufunc("heaviside") +hypot = _binary_ufunc("hypot") +lcm = _binary_ufunc("lcm") +ldexp = _binary_ufunc("ldexp") +left_shift = _binary_ufunc("left_shift") +less = _binary_ufunc("less") +less_equal = _binary_ufunc("less_equal") +logaddexp = _binary_ufunc("logaddexp") +logaddexp2 = _binary_ufunc("logaddexp2") +logical_and = _binary_ufunc("logical_and") +logical_or = _binary_ufunc("logical_or") +logical_xor = _binary_ufunc("logical_xor") +# matmul = _binary_ufunc("matmul") +maximum = _binary_ufunc("maximum") +minimum = _binary_ufunc("minimum") +mod = _binary_ufunc("mod") +multiply = _binary_ufunc("multiply") +nextafter = _binary_ufunc("nextafter") +not_equal = _binary_ufunc("not_equal") +pow = _binary_ufunc("pow") +power = _binary_ufunc("power") +remainder = _binary_ufunc("remainder") +right_shift = _binary_ufunc("right_shift") +subtract = _binary_ufunc("subtract") +true_divide = _binary_ufunc("true_divide") +# vecdot = _binary_ufunc("vecdot") # elementwise non-ufunc -angle = _create_op("angle") -isreal = _create_op("isreal") -iscomplex = _create_op("iscomplex") +angle = _unary_ufunc("angle") +isreal = _unary_ufunc("isreal") +iscomplex = _unary_ufunc("iscomplex") __all__ = [ From 2da600dadca57306eeb9a11b5d860279fa4164b9 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 13:12:53 -0500 Subject: [PATCH 08/12] add to api docs --- doc/api.rst | 114 +++++++++++++++++++++++++++++++++++++++++++++++ xarray/ufuncs.py | 9 ++-- 2 files changed, 119 insertions(+), 4 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 63427447d53..0c30ddc4c20 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -894,6 +894,120 @@ Methods copied from :py:class:`numpy.ndarray` objects, here applying to the data .. DataTree.sortby .. DataTree.broadcast_like +Universal functions +=================== + +These functions are equivalent to their NumPy versions, but for xarray +objects backed by non-NumPy array types (e.g. ``cupy``, ``sparse``, or ``jax``), +they will ensure that the computation is dispatched to the appropriate +backend. You can find them in the ``xarray.ufuncs`` module: + +.. autosummary:: + :toctree: generated/ + + ufuncs.abs + ufuncs.absolute + ufuncs.acos + ufuncs.acosh + ufuncs.arccos + ufuncs.arccosh + ufuncs.arcsin + ufuncs.arcsinh + ufuncs.arctan + ufuncs.arctanh + ufuncs.asin + ufuncs.asinh + ufuncs.atan + ufuncs.atanh + ufuncs.bitwise_count + ufuncs.bitwise_invert + ufuncs.bitwise_not + ufuncs.cbrt + ufuncs.ceil + ufuncs.conj + ufuncs.conjugate + ufuncs.cos + ufuncs.cosh + ufuncs.deg2rad + ufuncs.degrees + ufuncs.exp + ufuncs.exp2 + ufuncs.expm1 + ufuncs.fabs + ufuncs.floor + ufuncs.invert + ufuncs.isfinite + ufuncs.isinf + ufuncs.isnan + ufuncs.isnat + ufuncs.log + ufuncs.log10 + ufuncs.log1p + ufuncs.log2 + ufuncs.logical_not + ufuncs.negative + ufuncs.positive + ufuncs.rad2deg + ufuncs.radians + ufuncs.reciprocal + ufuncs.rint + ufuncs.sign + ufuncs.signbit + ufuncs.sin + ufuncs.sinh + ufuncs.spacing + ufuncs.sqrt + ufuncs.square + ufuncs.tan + ufuncs.tanh + ufuncs.trunc + ufuncs.add + ufuncs.arctan2 + ufuncs.atan2 + ufuncs.bitwise_and + ufuncs.bitwise_left_shift + ufuncs.bitwise_or + ufuncs.bitwise_right_shift + ufuncs.bitwise_xor + ufuncs.copysign + ufuncs.divide + ufuncs.equal + ufuncs.float_power + ufuncs.floor_divide + ufuncs.fmax + ufuncs.fmin + ufuncs.fmod + ufuncs.gcd + ufuncs.greater + ufuncs.greater_equal + ufuncs.heaviside + ufuncs.hypot + ufuncs.lcm + ufuncs.ldexp + ufuncs.left_shift + ufuncs.less + ufuncs.less_equal + ufuncs.logaddexp + ufuncs.logaddexp2 + ufuncs.logical_and + ufuncs.logical_or + ufuncs.logical_xor + ufuncs.maximum + ufuncs.minimum + ufuncs.mod + ufuncs.multiply + ufuncs.nextafter + ufuncs.not_equal + ufuncs.pow + ufuncs.power + ufuncs.remainder + ufuncs.right_shift + ufuncs.subtract + ufuncs.true_divide + ufuncs.angle + ufuncs.isreal + ufuncs.iscomplex + IO / Conversion =============== diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index 5021c33b13d..eab25a48a4a 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -68,8 +68,8 @@ def _create_doc(self): _skip_signature(_dedent(doc), self.__name__) ) self.__doc__ = ( - f"xarray specific variant of numpy.{__name__}. Handles " - "xarray objects by dispatching to the appropriate " + f"xarray specific variant of :py:func:`numpy.{self.__name__}`. " + "Handles xarray objects by dispatching to the appropriate " "function for the underlying array type.\n\n" f"Documentation from numpy:\n\n{doc}" ) @@ -97,6 +97,7 @@ def _skip_signature(doc, name): if not isinstance(doc, str): return doc + # TODO: this fails to remove the signature for aliased functions if doc.startswith(name): signature_end = doc.find("\n\n") doc = doc[signature_end + 2 :] @@ -132,8 +133,8 @@ def _dedent(doc): # These can be auto-generated from the public numpy ufuncs: # {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} -# Ufuncs that use core dimensions or product multiple output arrays are -# not currently supported, and left commented out below. +# Generalized ufuncs that use core dimensions or produce multiple output +# arrays are not currently supported, and left commented out below. # UNARY abs = _unary_ufunc("abs") From 22e678c80f80ab4732aa846cc9a961ffd77de8ce Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 13:18:30 -0500 Subject: [PATCH 09/12] add whats new --- doc/whats-new.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ee826e6e56f..7beb4a209f3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,10 @@ New Features - Optimize :py:meth:`DataArray.polyfit` and :py:meth:`Dataset.polyfit` with dask, when used with arrays with more than two dimensions. (:issue:`5629`). By `Deepak Cherian `_. +- Re-implement the :py:mod:`ufuncs` module, which now dynamically dispatches to the + underlying array's backend. Provides better support for certain wrapped array types + like `jax.numpy.ndarray`. (:issue:`7848`, :pull:`9776`). + By `Sam Levang `_. Breaking changes ~~~~~~~~~~~~~~~~ From 21dd468004034a7acc8cee72b7d7ca5c50236e3d Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 13:42:50 -0500 Subject: [PATCH 10/12] move numpy version check to tests --- xarray/tests/test_ufuncs.py | 7 +++---- xarray/ufuncs.py | 14 ++------------ 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 5fdf12b9031..61cd88e30ac 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -197,10 +197,9 @@ def setUp(self): @pytest.mark.parametrize("name", xu.__all__) def test_ufuncs(self, name, request): xu_func = getattr(xu, name) - if not xu_func._available: - pytest.xfail(f"Ufunc {name} is not available in numpy {np.__version__}.") - - np_func = getattr(np, name) + np_func = getattr(np, name, None) + if np_func is None and np.lib.NumpyVersion(np.__version__) < "2.0.0": + pytest.skip(f"Ufunc {name} is not available in numpy {np.__version__}.") if name == "isnat": args = (self.xt,) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index eab25a48a4a..9f55d03eddf 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -45,23 +45,13 @@ def get_array_namespace(*args): class _ufunc_wrapper(ABC): def __init__(self, name): self.__name__ = name - self._setup() + if hasattr(np, name): + self._create_doc() @abstractmethod def __call__(self, *args, **kwargs): raise NotImplementedError - def _setup(self): - if hasattr(np, self.__name__): - self._available = True - self._create_doc() - else: - # some aliases are missing in older numpy versions - if np.lib.NumpyVersion(np.__version__) < "2.0.0": - self._available = False - else: - raise ValueError(f"'{self.__name__}' is not a valid numpy function") - def _create_doc(self): doc = getattr(np, self.__name__).__doc__ doc = _remove_unused_reference_labels( From 127b5919b7a8cb405b9050dd70b6be2d7d2f9a5e Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 13:47:57 -0500 Subject: [PATCH 11/12] fix docs for aliased np funcs --- xarray/ufuncs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index 9f55d03eddf..cedece4c68f 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -87,8 +87,10 @@ def _skip_signature(doc, name): if not isinstance(doc, str): return doc - # TODO: this fails to remove the signature for aliased functions - if doc.startswith(name): + # numpy creates some functions as aliases and copies the docstring exactly, + # so check the actual name to handle this case + np_name = getattr(np, name).__name__ + if doc.startswith(np_name): signature_end = doc.find("\n\n") doc = doc[signature_end + 2 :] From 81d6c074231213e711b972b1662d8a67a32d8067 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 13:53:33 -0500 Subject: [PATCH 12/12] fix whats new --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7beb4a209f3..d447ac21266 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,7 +40,7 @@ New Features (:issue:`5629`). By `Deepak Cherian `_. - Re-implement the :py:mod:`ufuncs` module, which now dynamically dispatches to the underlying array's backend. Provides better support for certain wrapped array types - like `jax.numpy.ndarray`. (:issue:`7848`, :pull:`9776`). + like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`). By `Sam Levang `_. Breaking changes