4
4
from typing import Any
5
5
6
6
import numpy as np
7
+ from pandas .api .types import is_extension_array_dtype
7
8
8
- from xarray .core import utils
9
+ from xarray .core import npcompat , utils
9
10
10
11
# Use as a sentinel value to indicate a dtype appropriate NA value.
11
12
NA = utils .ReprObject ("<NA>" )
@@ -60,22 +61,22 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
60
61
# N.B. these casting rules should match pandas
61
62
dtype_ : np .typing .DTypeLike
62
63
fill_value : Any
63
- if np . issubdtype (dtype , np . floating ):
64
+ if isdtype (dtype , "real floating" ):
64
65
dtype_ = dtype
65
66
fill_value = np .nan
66
- elif np .issubdtype (dtype , np .timedelta64 ):
67
+ elif isinstance ( dtype , np . dtype ) and np .issubdtype (dtype , np .timedelta64 ):
67
68
# See https://github.com/numpy/numpy/issues/10685
68
69
# np.timedelta64 is a subclass of np.integer
69
70
# Check np.timedelta64 before np.integer
70
71
fill_value = np .timedelta64 ("NaT" )
71
72
dtype_ = dtype
72
- elif np . issubdtype (dtype , np . integer ):
73
+ elif isdtype (dtype , "integral" ):
73
74
dtype_ = np .float32 if dtype .itemsize <= 2 else np .float64
74
75
fill_value = np .nan
75
- elif np . issubdtype (dtype , np . complexfloating ):
76
+ elif isdtype (dtype , "complex floating" ):
76
77
dtype_ = dtype
77
78
fill_value = np .nan + np .nan * 1j
78
- elif np .issubdtype (dtype , np .datetime64 ):
79
+ elif isinstance ( dtype , np . dtype ) and np .issubdtype (dtype , np .datetime64 ):
79
80
dtype_ = dtype
80
81
fill_value = np .datetime64 ("NaT" )
81
82
else :
@@ -118,16 +119,16 @@ def get_pos_infinity(dtype, max_for_int=False):
118
119
-------
119
120
fill_value : positive infinity value corresponding to this dtype.
120
121
"""
121
- if issubclass (dtype . type , np . floating ):
122
+ if isdtype (dtype , "real floating" ):
122
123
return np .inf
123
124
124
- if issubclass (dtype . type , np . integer ):
125
+ if isdtype (dtype , "integral" ):
125
126
if max_for_int :
126
127
return np .iinfo (dtype ).max
127
128
else :
128
129
return np .inf
129
130
130
- if issubclass (dtype . type , np . complexfloating ):
131
+ if isdtype (dtype , "complex floating" ):
131
132
return np .inf + 1j * np .inf
132
133
133
134
return INF
@@ -146,24 +147,66 @@ def get_neg_infinity(dtype, min_for_int=False):
146
147
-------
147
148
fill_value : positive infinity value corresponding to this dtype.
148
149
"""
149
- if issubclass (dtype . type , np . floating ):
150
+ if isdtype (dtype , "real floating" ):
150
151
return - np .inf
151
152
152
- if issubclass (dtype . type , np . integer ):
153
+ if isdtype (dtype , "integral" ):
153
154
if min_for_int :
154
155
return np .iinfo (dtype ).min
155
156
else :
156
157
return - np .inf
157
158
158
- if issubclass (dtype . type , np . complexfloating ):
159
+ if isdtype (dtype , "complex floating" ):
159
160
return - np .inf - 1j * np .inf
160
161
161
162
return NINF
162
163
163
164
164
- def is_datetime_like (dtype ):
165
+ def is_datetime_like (dtype ) -> bool :
165
166
"""Check if a dtype is a subclass of the numpy datetime types"""
166
- return np .issubdtype (dtype , np .datetime64 ) or np .issubdtype (dtype , np .timedelta64 )
167
+ return _is_numpy_subdtype (dtype , (np .datetime64 , np .timedelta64 ))
168
+
169
+
170
+ def is_object (dtype ) -> bool :
171
+ """Check if a dtype is object"""
172
+ return _is_numpy_subdtype (dtype , object )
173
+
174
+
175
+ def is_string (dtype ) -> bool :
176
+ """Check if a dtype is a string dtype"""
177
+ return _is_numpy_subdtype (dtype , (np .str_ , np .character ))
178
+
179
+
180
+ def _is_numpy_subdtype (dtype , kind ) -> bool :
181
+ if not isinstance (dtype , np .dtype ):
182
+ return False
183
+
184
+ kinds = kind if isinstance (kind , tuple ) else (kind ,)
185
+ return any (np .issubdtype (dtype , kind ) for kind in kinds )
186
+
187
+
188
+ def isdtype (dtype , kind : str | tuple [str , ...], xp = None ) -> bool :
189
+ """Compatibility wrapper for isdtype() from the array API standard.
190
+
191
+ Unlike xp.isdtype(), kind must be a string.
192
+ """
193
+ # TODO(shoyer): remove this wrapper when Xarray requires
194
+ # numpy>=2 and pandas extensions arrays are implemented in
195
+ # Xarray via the array API
196
+ if not isinstance (kind , str ) and not (
197
+ isinstance (kind , tuple ) and all (isinstance (k , str ) for k in kind )
198
+ ):
199
+ raise TypeError (f"kind must be a string or a tuple of strings: { repr (kind )} " )
200
+
201
+ if isinstance (dtype , np .dtype ):
202
+ return npcompat .isdtype (dtype , kind )
203
+ elif is_extension_array_dtype (dtype ):
204
+ # we never want to match pandas extension array dtypes
205
+ return False
206
+ else :
207
+ if xp is None :
208
+ xp = np
209
+ return xp .isdtype (dtype , kind )
167
210
168
211
169
212
def result_type (
@@ -184,12 +227,26 @@ def result_type(
184
227
-------
185
228
numpy.dtype for the result.
186
229
"""
187
- types = {np .result_type (t ).type for t in arrays_and_dtypes }
230
+ from xarray .core .duck_array_ops import get_array_namespace
231
+
232
+ # TODO(shoyer): consider moving this logic into get_array_namespace()
233
+ # or another helper function.
234
+ namespaces = {get_array_namespace (t ) for t in arrays_and_dtypes }
235
+ non_numpy = namespaces - {np }
236
+ if non_numpy :
237
+ [xp ] = non_numpy
238
+ else :
239
+ xp = np
240
+
241
+ types = {xp .result_type (t ) for t in arrays_and_dtypes }
188
242
189
- for left , right in PROMOTE_TO_OBJECT :
190
- if any (issubclass (t , left ) for t in types ) and any (
191
- issubclass (t , right ) for t in types
192
- ):
193
- return np .dtype (object )
243
+ if any (isinstance (t , np .dtype ) for t in types ):
244
+ # only check if there's numpy dtypes – the array API does not
245
+ # define the types we're checking for
246
+ for left , right in PROMOTE_TO_OBJECT :
247
+ if any (np .issubdtype (t , left ) for t in types ) and any (
248
+ np .issubdtype (t , right ) for t in types
249
+ ):
250
+ return xp .dtype (object )
194
251
195
- return np .result_type (* arrays_and_dtypes )
252
+ return xp .result_type (* arrays_and_dtypes )
0 commit comments