Skip to content

Commit 8d64f60

Browse files
committed
Merge branch 'main' into backend-indexing
* main: array api-related upstream-dev failures (#8854) (fix): equality check against singleton `PandasExtensionArray` (#9032)
2 parents 45ceac6 + 026aa7c commit 8d64f60

12 files changed

+291
-119
lines changed

xarray/core/dtypes.py

+78-21
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from typing import Any
55

66
import numpy as np
7+
from pandas.api.types import is_extension_array_dtype
78

8-
from xarray.core import utils
9+
from xarray.core import npcompat, utils
910

1011
# Use as a sentinel value to indicate a dtype appropriate NA value.
1112
NA = utils.ReprObject("<NA>")
@@ -60,22 +61,22 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
6061
# N.B. these casting rules should match pandas
6162
dtype_: np.typing.DTypeLike
6263
fill_value: Any
63-
if np.issubdtype(dtype, np.floating):
64+
if isdtype(dtype, "real floating"):
6465
dtype_ = dtype
6566
fill_value = np.nan
66-
elif np.issubdtype(dtype, np.timedelta64):
67+
elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.timedelta64):
6768
# See https://github.com/numpy/numpy/issues/10685
6869
# np.timedelta64 is a subclass of np.integer
6970
# Check np.timedelta64 before np.integer
7071
fill_value = np.timedelta64("NaT")
7172
dtype_ = dtype
72-
elif np.issubdtype(dtype, np.integer):
73+
elif isdtype(dtype, "integral"):
7374
dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
7475
fill_value = np.nan
75-
elif np.issubdtype(dtype, np.complexfloating):
76+
elif isdtype(dtype, "complex floating"):
7677
dtype_ = dtype
7778
fill_value = np.nan + np.nan * 1j
78-
elif np.issubdtype(dtype, np.datetime64):
79+
elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64):
7980
dtype_ = dtype
8081
fill_value = np.datetime64("NaT")
8182
else:
@@ -118,16 +119,16 @@ def get_pos_infinity(dtype, max_for_int=False):
118119
-------
119120
fill_value : positive infinity value corresponding to this dtype.
120121
"""
121-
if issubclass(dtype.type, np.floating):
122+
if isdtype(dtype, "real floating"):
122123
return np.inf
123124

124-
if issubclass(dtype.type, np.integer):
125+
if isdtype(dtype, "integral"):
125126
if max_for_int:
126127
return np.iinfo(dtype).max
127128
else:
128129
return np.inf
129130

130-
if issubclass(dtype.type, np.complexfloating):
131+
if isdtype(dtype, "complex floating"):
131132
return np.inf + 1j * np.inf
132133

133134
return INF
@@ -146,24 +147,66 @@ def get_neg_infinity(dtype, min_for_int=False):
146147
-------
147148
fill_value : negative infinity value corresponding to this dtype.
148149
"""
149-
if issubclass(dtype.type, np.floating):
150+
if isdtype(dtype, "real floating"):
150151
return -np.inf
151152

152-
if issubclass(dtype.type, np.integer):
153+
if isdtype(dtype, "integral"):
153154
if min_for_int:
154155
return np.iinfo(dtype).min
155156
else:
156157
return -np.inf
157158

158-
if issubclass(dtype.type, np.complexfloating):
159+
if isdtype(dtype, "complex floating"):
159160
return -np.inf - 1j * np.inf
160161

161162
return NINF
162163

163164

164-
def is_datetime_like(dtype):
165+
def is_datetime_like(dtype) -> bool:
165166
"""Check if a dtype is a subclass of the numpy datetime types"""
166-
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
167+
return _is_numpy_subdtype(dtype, (np.datetime64, np.timedelta64))
168+
169+
170+
def is_object(dtype) -> bool:
    """Return True when ``dtype`` is a numpy object dtype.

    Anything that is not an ``np.dtype`` instance yields False, because
    the check is delegated to ``_is_numpy_subdtype``.
    """
    matches_object = _is_numpy_subdtype(dtype, object)
    return matches_object
173+
174+
175+
def is_string(dtype) -> bool:
    """Return True when ``dtype`` is a numpy string dtype.

    A dtype matches when it is a sub-dtype of ``np.str_`` or
    ``np.character``; anything that is not an ``np.dtype`` instance
    yields False, because the check is delegated to
    ``_is_numpy_subdtype``.
    """
    string_kinds = (np.str_, np.character)
    return _is_numpy_subdtype(dtype, string_kinds)
178+
179+
180+
def _is_numpy_subdtype(dtype, kind) -> bool:
    """Check whether ``dtype`` is a numpy dtype matching any given kind.

    Parameters
    ----------
    dtype : object
        Candidate dtype. Anything that is not an ``np.dtype`` instance
        never matches.
    kind : type or tuple of type
        A single type, or a tuple of types, each of which is passed as
        the second argument to ``np.issubdtype``.

    Returns
    -------
    bool
        True if ``dtype`` is an ``np.dtype`` and a sub-dtype of at least
        one entry of ``kind``, otherwise False.
    """
    if not isinstance(dtype, np.dtype):
        return False

    candidates = kind if isinstance(kind, tuple) else (kind,)
    # use a distinct loop variable so the ``kind`` parameter is not shadowed
    return any(np.issubdtype(dtype, candidate) for candidate in candidates)
186+
187+
188+
def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
    """Compatibility wrapper for isdtype() from the array API standard.

    Unlike xp.isdtype(), kind must be a string or a tuple of strings.

    Parameters
    ----------
    dtype : object
        The dtype to classify. ``np.dtype`` instances are handled by
        ``npcompat.isdtype``; pandas extension array dtypes never match.
    kind : str or tuple of str
        Kind name(s) to test against, e.g. ``"integral"`` or
        ``("real floating", "complex floating")``.
    xp : module, optional
        Array namespace whose ``isdtype`` is used for dtypes that are
        neither numpy dtypes nor pandas extension dtypes. Defaults to
        numpy.

    Returns
    -------
    bool

    Raises
    ------
    TypeError
        If ``kind`` is neither a string nor a tuple of strings.
    """
    # TODO(shoyer): remove this wrapper when Xarray requires
    # numpy>=2 and pandas extensions arrays are implemented in
    # Xarray via the array API
    if not isinstance(kind, str) and not (
        isinstance(kind, tuple) and all(isinstance(k, str) for k in kind)
    ):
        raise TypeError(f"kind must be a string or a tuple of strings: {kind!r}")

    if isinstance(dtype, np.dtype):
        return npcompat.isdtype(dtype, kind)
    if is_extension_array_dtype(dtype):
        # we never want to match pandas extension array dtypes
        return False
    if xp is None or xp is np:
        # np.isdtype only exists for numpy>=2; npcompat.isdtype is that same
        # function when available and a fallback implementation otherwise
        return npcompat.isdtype(dtype, kind)
    return xp.isdtype(dtype, kind)
167210

168211

169212
def result_type(
@@ -184,12 +227,26 @@ def result_type(
184227
-------
185228
numpy.dtype for the result.
186229
"""
187-
types = {np.result_type(t).type for t in arrays_and_dtypes}
230+
from xarray.core.duck_array_ops import get_array_namespace
231+
232+
# TODO(shoyer): consider moving this logic into get_array_namespace()
233+
# or another helper function.
234+
namespaces = {get_array_namespace(t) for t in arrays_and_dtypes}
235+
non_numpy = namespaces - {np}
236+
if non_numpy:
237+
[xp] = non_numpy
238+
else:
239+
xp = np
240+
241+
types = {xp.result_type(t) for t in arrays_and_dtypes}
188242

189-
for left, right in PROMOTE_TO_OBJECT:
190-
if any(issubclass(t, left) for t in types) and any(
191-
issubclass(t, right) for t in types
192-
):
193-
return np.dtype(object)
243+
if any(isinstance(t, np.dtype) for t in types):
244+
# only check if there's numpy dtypes – the array API does not
245+
# define the types we're checking for
246+
for left, right in PROMOTE_TO_OBJECT:
247+
if any(np.issubdtype(t, left) for t in types) and any(
248+
np.issubdtype(t, right) for t in types
249+
):
250+
return xp.dtype(object)
194251

195-
return np.result_type(*arrays_and_dtypes)
252+
return xp.result_type(*arrays_and_dtypes)

xarray/core/duck_array_ops.py

+36-15
Original file line numberDiff line numberDiff line change
@@ -142,17 +142,25 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
142142

143143
def isnull(data):
144144
data = asarray(data)
145-
scalar_type = data.dtype.type
146-
if issubclass(scalar_type, (np.datetime64, np.timedelta64)):
145+
146+
xp = get_array_namespace(data)
147+
scalar_type = data.dtype
148+
if dtypes.is_datetime_like(scalar_type):
147149
# datetime types use NaT for null
148150
# note: must check timedelta64 before integers, because currently
149151
# timedelta64 inherits from np.integer
150152
return isnat(data)
151-
elif issubclass(scalar_type, np.inexact):
153+
elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
152154
# float types use NaN for null
153155
xp = get_array_namespace(data)
154156
return xp.isnan(data)
155-
elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
157+
elif dtypes.isdtype(scalar_type, ("bool", "integral"), xp=xp) or (
158+
isinstance(scalar_type, np.dtype)
159+
and (
160+
np.issubdtype(scalar_type, np.character)
161+
or np.issubdtype(scalar_type, np.void)
162+
)
163+
):
156164
# these types cannot represent missing values
157165
return full_like(data, dtype=bool, fill_value=False)
158166
else:
@@ -406,13 +414,22 @@ def f(values, axis=None, skipna=None, **kwargs):
406414
if invariant_0d and axis == ():
407415
return values
408416

409-
values = asarray(values)
417+
xp = get_array_namespace(values)
418+
values = asarray(values, xp=xp)
410419

411-
if coerce_strings and values.dtype.kind in "SU":
420+
if coerce_strings and dtypes.is_string(values.dtype):
412421
values = astype(values, object)
413422

414423
func = None
415-
if skipna or (skipna is None and values.dtype.kind in "cfO"):
424+
if skipna or (
425+
skipna is None
426+
and (
427+
dtypes.isdtype(
428+
values.dtype, ("complex floating", "real floating"), xp=xp
429+
)
430+
or dtypes.is_object(values.dtype)
431+
)
432+
):
416433
nanname = "nan" + name
417434
func = getattr(nanops, nanname)
418435
else:
@@ -477,8 +494,8 @@ def _datetime_nanmin(array):
477494
- numpy nanmin() doesn't work on datetime64 (all versions at the moment of writing)
478495
- dask min() does not work on datetime64 (all versions at the moment of writing)
479496
"""
480-
assert array.dtype.kind in "mM"
481497
dtype = array.dtype
498+
assert dtypes.is_datetime_like(dtype)
482499
# (NaT).astype(float) does not produce NaN...
483500
array = where(pandas_isnull(array), np.nan, array.astype(float))
484501
array = min(array, skipna=True)
@@ -515,7 +532,7 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
515532
"""
516533
# Set offset to minimum if not given
517534
if offset is None:
518-
if array.dtype.kind in "Mm":
535+
if dtypes.is_datetime_like(array.dtype):
519536
offset = _datetime_nanmin(array)
520537
else:
521538
offset = min(array)
@@ -527,7 +544,7 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
527544
# This map_blocks call is for backwards compatibility.
528545
# dask == 2021.04.1 does not support subtracting object arrays
529546
# which is required for cftime
530-
if is_duck_dask_array(array) and np.issubdtype(array.dtype, object):
547+
if is_duck_dask_array(array) and dtypes.is_object(array.dtype):
531548
array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta)
532549
else:
533550
array = array - offset
@@ -537,11 +554,11 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
537554
array = np.array(array)
538555

539556
# Convert timedelta objects to float by first converting to microseconds.
540-
if array.dtype.kind in "O":
557+
if dtypes.is_object(array.dtype):
541558
return py_timedelta_to_float(array, datetime_unit or "ns").astype(dtype)
542559

543560
# Convert np.NaT to np.nan
544-
elif array.dtype.kind in "mM":
561+
elif dtypes.is_datetime_like(array.dtype):
545562
# Convert to specified timedelta units.
546563
if datetime_unit:
547564
array = array / np.timedelta64(1, datetime_unit)
@@ -641,7 +658,7 @@ def mean(array, axis=None, skipna=None, **kwargs):
641658
from xarray.core.common import _contains_cftime_datetimes
642659

643660
array = asarray(array)
644-
if array.dtype.kind in "Mm":
661+
if dtypes.is_datetime_like(array.dtype):
645662
offset = _datetime_nanmin(array)
646663

647664
# xarray always uses np.datetime64[ns] for np.datetime64 data
@@ -689,7 +706,9 @@ def cumsum(array, axis=None, **kwargs):
689706

690707
def first(values, axis, skipna=None):
691708
"""Return the first non-NA elements in this array along the given axis"""
692-
if (skipna or skipna is None) and values.dtype.kind not in "iSU":
709+
if (skipna or skipna is None) and not (
710+
dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
711+
):
693712
# only bother for dtypes that can hold NaN
694713
if is_chunked_array(values):
695714
return chunked_nanfirst(values, axis)
@@ -700,7 +719,9 @@ def first(values, axis, skipna=None):
700719

701720
def last(values, axis, skipna=None):
702721
"""Return the last non-NA elements in this array along the given axis"""
703-
if (skipna or skipna is None) and values.dtype.kind not in "iSU":
722+
if (skipna or skipna is None) and not (
723+
dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
724+
):
704725
# only bother for dtypes that can hold NaN
705726
if is_chunked_array(values):
706727
return chunked_nanlast(values, axis)

xarray/core/extension_array.py

-2
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ def __setitem__(self, key, val):
123123
self.array[key] = val
124124

125125
def __eq__(self, other):
126-
if np.isscalar(other):
127-
other = type(self)(type(self.array)([other]))
128126
if isinstance(other, PandasExtensionArray):
129127
return self.array == other.array
130128
return self.array == other

xarray/core/npcompat.py

+30
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,33 @@
2828
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
2929
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
3030
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
32+
try:
    # requires numpy>=2.0
    from numpy import isdtype  # type: ignore[attr-defined,unused-ignore]
except ImportError:
    import numpy as np

    # array API kind names mapped onto numpy's abstract scalar types
    dtype_kinds = {
        "bool": np.bool_,
        "signed integer": np.signedinteger,
        "unsigned integer": np.unsignedinteger,
        "integral": np.integer,
        "real floating": np.floating,
        "complex floating": np.complexfloating,
        "numeric": np.number,
    }

    def isdtype(dtype, kind):
        """Fallback for ``numpy.isdtype`` on numpy versions that lack it.

        ``kind`` is a kind name or a tuple of kind names from
        ``dtype_kinds``; a ValueError is raised if any name is unknown.
        """
        requested = kind if isinstance(kind, tuple) else (kind,)

        unrecognized = [name for name in requested if name not in dtype_kinds]
        if unrecognized:
            raise ValueError(f"unknown dtype kinds: {unrecognized}")

        # every name was validated above, so plain indexing is safe
        scalar_types = tuple(dtype_kinds[name] for name in requested)
        if isinstance(dtype, np.generic):
            # numpy scalar instance: compare by instance relationship
            return isinstance(dtype, scalar_types)
        return any(np.issubdtype(dtype, scalar_type) for scalar_type in scalar_types)

0 commit comments

Comments
 (0)