diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index ff7af02abfc..8d6a1d3ed8c 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -16,7 +16,7 @@ from .common import ImplementsArrayReduce, ImplementsDatasetReduce from .ops import IncludeCumMethods, IncludeNumpySameMethods, IncludeReduceMethods from .options import OPTIONS, _get_keep_attrs -from .pycompat import dask_array_type +from .pycompat import is_duck_array class SupportsArithmetic: @@ -33,12 +33,11 @@ class SupportsArithmetic: # TODO: allow extending this with some sort of registration system _HANDLED_TYPES = ( - np.ndarray, np.generic, numbers.Number, bytes, str, - ) + dask_array_type + ) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): from .computation import apply_ufunc @@ -46,7 +45,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin. out = kwargs.get("out", ()) for x in inputs + out: - if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)): + if not is_duck_array(x) and not isinstance( + x, self._HANDLED_TYPES + (SupportsArithmetic,) + ): return NotImplemented if ufunc.signature is not None: diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py deleted file mode 100644 index b53947e88eb..00000000000 --- a/xarray/core/dask_array_compat.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -import warnings - -import numpy as np - -try: - import dask.array as da -except ImportError: - da = None # type: ignore - - -def _validate_pad_output_shape(input_shape, pad_width, output_shape): - """Validates the output shape of dask.array.pad, raising a RuntimeError if they do not match. - In the current versions of dask (2.2/2.4), dask.array.pad with mode='reflect' sometimes returns - an invalid shape. - """ - isint = lambda i: isinstance(i, int) - - if isint(pad_width): - pass - elif len(pad_width) == 2 and all(map(isint, pad_width)): - pad_width = sum(pad_width) - elif ( - len(pad_width) == len(input_shape) - and all(map(lambda x: len(x) == 2, pad_width)) - and all(isint(i) for p in pad_width for i in p) - ): - pad_width = np.sum(pad_width, axis=1) - else: - # unreachable: dask.array.pad should already have thrown an error - raise ValueError("Invalid value for `pad_width`") - - if not np.array_equal(np.array(input_shape) + pad_width, output_shape): - raise RuntimeError( - "There seems to be something wrong with the shape of the output of dask.array.pad, " - "try upgrading Dask, use a different pad mode e.g. mode='constant' or first convert " - "your DataArray/Dataset to one backed by a numpy array by calling the `compute()` method." - "See: https://github.com/dask/dask/issues/5303" - ) - - -def pad(array, pad_width, mode="constant", **kwargs): - padded = da.pad(array, pad_width, mode=mode, **kwargs) - # workaround for inconsistency between numpy and dask: https://github.com/dask/dask/issues/5303 - if mode == "mean" and issubclass(array.dtype.type, np.integer): - warnings.warn( - 'dask.array.pad(mode="mean") converts integers to floats. xarray converts ' - "these floats back to integers to keep the interface consistent. There is a chance that " - "this introduces rounding errors. If you wish to keep the values as floats, first change " - "the dtype to a float before calling pad.", - UserWarning, - ) - return da.round(padded).astype(array.dtype) - _validate_pad_output_shape(array.shape, pad_width, padded.shape) - return padded - - -if da is not None: - sliding_window_view = da.lib.stride_tricks.sliding_window_view -else: - sliding_window_view = None diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 2bf05abb96a..4f42f497a69 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -23,9 +23,9 @@ from numpy import take, tensordot, transpose, unravel_index # noqa from numpy import where as _where -from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils +from . import dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast -from .pycompat import cupy_array_type, dask_array_type, is_duck_dask_array +from .pycompat import cupy_array_type, is_duck_dask_array from .utils import is_duck_array try: @@ -113,7 +113,7 @@ def isnull(data): return zeros_like(data, dtype=bool) else: # at this point, array should have dtype=object - if isinstance(data, (np.ndarray, dask_array_type)): + if isinstance(data, np.ndarray): return pandas_isnull(data) else: # Not reachable yet, but intended for use with other duck array @@ -631,7 +631,9 @@ def sliding_window_view(array, window_shape, axis): The rolling dimension will be placed at the last dimension. """ if is_duck_dask_array(array): - return dask_array_compat.sliding_window_view(array, window_shape, axis) + import dask.array as da + + return da.lib.stride_tricks.sliding_window_view(array, window_shape, axis) else: return npcompat.sliding_window_view(array, window_shape, axis) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 920fd5a094e..4c71658b577 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -6,15 +6,6 @@ from . import dtypes, nputils, utils from .duck_array_ops import count, fillna, isnull, where, where_method -from .pycompat import dask_array_type - -try: - import dask.array as dask_array - - from . import dask_array_compat -except ImportError: - dask_array = None # type: ignore[assignment] - dask_array_compat = None # type: ignore[assignment] def _maybe_null_out(result, axis, mask, min_count=1): @@ -65,16 +56,14 @@ def nanmin(a, axis=None, out=None): if a.dtype.kind == "O": return _nan_minmax_object("min", dtypes.get_pos_infinity(a.dtype), a, axis) - module = dask_array if isinstance(a, dask_array_type) else nputils - return module.nanmin(a, axis=axis) + return nputils.nanmin(a, axis=axis) def nanmax(a, axis=None, out=None): if a.dtype.kind == "O": return _nan_minmax_object("max", dtypes.get_neg_infinity(a.dtype), a, axis) - module = dask_array if isinstance(a, dask_array_type) else nputils - return module.nanmax(a, axis=axis) + return nputils.nanmax(a, axis=axis) def nanargmin(a, axis=None): @@ -82,8 +71,7 @@ def nanargmin(a, axis=None): fill_value = dtypes.get_pos_infinity(a.dtype) return _nan_argminmax_object("argmin", fill_value, a, axis=axis) - module = dask_array if isinstance(a, dask_array_type) else nputils - return module.nanargmin(a, axis=axis) + return nputils.nanargmin(a, axis=axis) def nanargmax(a, axis=None): @@ -91,8 +79,7 @@ def nanargmax(a, axis=None): fill_value = dtypes.get_neg_infinity(a.dtype) return _nan_argminmax_object("argmax", fill_value, a, axis=axis) - module = dask_array if isinstance(a, dask_array_type) else nputils - return module.nanargmax(a, axis=axis) + return nputils.nanargmax(a, axis=axis) def nansum(a, axis=None, dtype=None, out=None, min_count=None): @@ -128,8 +115,6 @@ def nanmean(a, axis=None, dtype=None, out=None): warnings.filterwarnings( "ignore", r"Mean of empty slice", category=RuntimeWarning ) - if isinstance(a, dask_array_type): - return dask_array.nanmean(a, axis=axis, dtype=dtype) return np.nanmean(a, axis=axis, dtype=dtype) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c650464bbcb..cfe1db6a24e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -39,7 +39,6 @@ from .pycompat import ( DuckArrayModule, cupy_array_type, - dask_array_type, integer_types, is_duck_dask_array, sparse_array_type, @@ -59,12 +58,8 @@ ) NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( - ( - indexing.ExplicitlyIndexed, - pd.Index, - ) - + dask_array_type - + cupy_array_type + indexing.ExplicitlyIndexed, + pd.Index, ) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) @@ -1150,7 +1145,7 @@ def to_numpy(self) -> np.ndarray: data = self.data # TODO first attempt to call .to_numpy() once some libraries implement it - if isinstance(data, dask_array_type): + if hasattr(data, "chunks"): data = data.compute() if isinstance(data, cupy_array_type): data = data.get()