11
11
import math
12
12
import inspect
13
13
import warnings
14
+ from functools import cache
14
15
from typing import Optional , Union , Any
15
16
16
17
from ._typing import Array , Device , Namespace
17
18
18
19
19
- def _is_jax_zero_gradient_array (x : object ) -> bool :
20
+ @cache
21
+ def _issubclass_fast (cls : type , modname : str , clsname : str ) -> bool :
22
+ try :
23
+ mod = sys .modules [modname ]
24
+ except KeyError :
25
+ return False
26
+ parent_cls = getattr (mod , clsname )
27
+ return issubclass (cls , parent_cls )
28
+
29
+
30
+ def _is_jax_zero_gradient_array (x : Array ) -> bool :
20
31
"""Return True if `x` is a zero-gradient array.
21
32
22
33
These arrays are a design quirk of Jax that may one day be removed.
23
34
See https://github.com/google/jax/issues/20620.
24
35
"""
25
- if 'numpy' not in sys .modules or 'jax' not in sys .modules :
36
+ # Fast exit
37
+ try :
38
+ dtype = x .dtype
39
+ except AttributeError :
40
+ return False
41
+ if not _issubclass_fast (type (dtype ), "numpy.dtypes" , "VoidDType" ):
26
42
return False
27
43
28
- import numpy as np
29
- import jax
44
+ if "jax" not in sys . modules :
45
+ return False
30
46
31
- return isinstance (x , np .ndarray ) and x .dtype == jax .float0
47
+ import jax
48
+ # jax.float0 is a np.dtype([('float0', 'V')])
49
+ return dtype == jax .float0
32
50
33
51
34
52
def is_numpy_array (x : object ) -> bool :
@@ -52,15 +70,12 @@ def is_numpy_array(x: object) -> bool:
52
70
is_jax_array
53
71
is_pydata_sparse_array
54
72
"""
55
- # Avoid importing NumPy if it isn't already
56
- if 'numpy' not in sys .modules :
57
- return False
58
-
59
- import numpy as np
60
-
61
73
# TODO: Should we reject ndarray subclasses?
62
- return (isinstance (x , (np .ndarray , np .generic ))
63
- and not _is_jax_zero_gradient_array (x ))
74
+ cls = type (x )
75
+ return (
76
+ _issubclass_fast (cls , "numpy" , "ndarray" )
77
+ or _issubclass_fast (cls , "numpy" , "generic" )
78
+ ) and not _is_jax_zero_gradient_array (x )
64
79
65
80
66
81
def is_cupy_array (x : object ) -> bool :
@@ -84,14 +99,7 @@ def is_cupy_array(x: object) -> bool:
84
99
is_jax_array
85
100
is_pydata_sparse_array
86
101
"""
87
- # Avoid importing CuPy if it isn't already
88
- if 'cupy' not in sys .modules :
89
- return False
90
-
91
- import cupy as cp
92
-
93
- # TODO: Should we reject ndarray subclasses?
94
- return isinstance (x , cp .ndarray )
102
+ return _issubclass_fast (type (x ), "cupy" , "ndarray" )
95
103
96
104
97
105
def is_torch_array (x : object ) -> bool :
@@ -112,14 +120,7 @@ def is_torch_array(x: object) -> bool:
112
120
is_jax_array
113
121
is_pydata_sparse_array
114
122
"""
115
- # Avoid importing torch if it isn't already
116
- if 'torch' not in sys .modules :
117
- return False
118
-
119
- import torch
120
-
121
- # TODO: Should we reject ndarray subclasses?
122
- return isinstance (x , torch .Tensor )
123
+ return _issubclass_fast (type (x ), "torch" , "Tensor" )
123
124
124
125
125
126
def is_ndonnx_array (x : object ) -> bool :
@@ -141,13 +142,7 @@ def is_ndonnx_array(x: object) -> bool:
141
142
is_jax_array
142
143
is_pydata_sparse_array
143
144
"""
144
- # Avoid importing torch if it isn't already
145
- if 'ndonnx' not in sys .modules :
146
- return False
147
-
148
- import ndonnx as ndx
149
-
150
- return isinstance (x , ndx .Array )
145
+ return _issubclass_fast (type (x ), "ndonnx" , "Array" )
151
146
152
147
153
148
def is_dask_array (x : object ) -> bool :
@@ -169,13 +164,7 @@ def is_dask_array(x: object) -> bool:
169
164
is_jax_array
170
165
is_pydata_sparse_array
171
166
"""
172
- # Avoid importing dask if it isn't already
173
- if 'dask.array' not in sys .modules :
174
- return False
175
-
176
- import dask .array
177
-
178
- return isinstance (x , dask .array .Array )
167
+ return _issubclass_fast (type (x ), "dask.array" , "Array" )
179
168
180
169
181
170
def is_jax_array (x : object ) -> bool :
@@ -198,13 +187,7 @@ def is_jax_array(x: object) -> bool:
198
187
is_dask_array
199
188
is_pydata_sparse_array
200
189
"""
201
- # Avoid importing jax if it isn't already
202
- if 'jax' not in sys .modules :
203
- return False
204
-
205
- import jax
206
-
207
- return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
190
+ return _issubclass_fast (type (x ), "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
208
191
209
192
210
193
def is_pydata_sparse_array (x ) -> bool :
@@ -227,14 +210,8 @@ def is_pydata_sparse_array(x) -> bool:
227
210
is_dask_array
228
211
is_jax_array
229
212
"""
230
- # Avoid importing jax if it isn't already
231
- if 'sparse' not in sys .modules :
232
- return False
233
-
234
- import sparse
235
-
236
213
# TODO: Account for other backends.
237
- return isinstance ( x , sparse . SparseArray )
214
+ return _issubclass_fast ( type ( x ), " sparse" , " SparseArray" )
238
215
239
216
240
217
def is_array_api_obj (x : object ) -> bool :
@@ -252,20 +229,30 @@ def is_array_api_obj(x: object) -> bool:
252
229
is_dask_array
253
230
is_jax_array
254
231
"""
255
- return is_numpy_array (x ) \
256
- or is_cupy_array (x ) \
257
- or is_torch_array (x ) \
258
- or is_dask_array (x ) \
259
- or is_jax_array (x ) \
260
- or is_pydata_sparse_array (x ) \
261
- or hasattr (x , '__array_namespace__' )
232
+ return hasattr (x , '__array_namespace__' ) or _is_array_api_cls (type (x ))
233
+
234
+
235
+ @cache
236
+ def _is_array_api_cls (cls : type ) -> bool :
237
+ return (
238
+ # TODO: drop support for numpy<2 which didn't have __array_namespace__
239
+ _issubclass_fast (cls , "numpy" , "ndarray" )
240
+ or _issubclass_fast (cls , "numpy" , "generic" )
241
+ or _issubclass_fast (cls , "cupy" , "ndarray" )
242
+ or _issubclass_fast (cls , "torch" , "Tensor" )
243
+ or _issubclass_fast (cls , "dask.array" , "Array" )
244
+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
245
+ # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
246
+ or _issubclass_fast (cls , "jax" , "Array" )
247
+ )
262
248
263
249
264
250
def _compat_module_name () -> str :
265
251
assert __name__ .endswith ('.common._helpers' )
266
252
return __name__ .removesuffix ('.common._helpers' )
267
253
268
254
255
+ @cache
269
256
def is_numpy_namespace (xp : Namespace ) -> bool :
270
257
"""
271
258
Returns True if `xp` is a NumPy namespace.
@@ -287,6 +274,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
287
274
return xp .__name__ in {'numpy' , _compat_module_name () + '.numpy' }
288
275
289
276
277
+ @cache
290
278
def is_cupy_namespace (xp : Namespace ) -> bool :
291
279
"""
292
280
Returns True if `xp` is a CuPy namespace.
@@ -308,6 +296,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
308
296
return xp .__name__ in {'cupy' , _compat_module_name () + '.cupy' }
309
297
310
298
299
+ @cache
311
300
def is_torch_namespace (xp : Namespace ) -> bool :
312
301
"""
313
302
Returns True if `xp` is a PyTorch namespace.
@@ -348,6 +337,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
348
337
return xp .__name__ == 'ndonnx'
349
338
350
339
340
+ @cache
351
341
def is_dask_namespace (xp : Namespace ) -> bool :
352
342
"""
353
343
Returns True if `xp` is a Dask namespace.
@@ -952,4 +942,4 @@ def is_lazy_array(x: object) -> bool:
952
942
"to_device" ,
953
943
]
954
944
955
- _all_ignore = ['sys' , 'math' , 'inspect' , 'warnings' ]
945
+ _all_ignore = ['cache' , ' sys' , 'math' , 'inspect' , 'warnings' ]
0 commit comments