Skip to content

Commit 632f081

Browse files
committed
ENH: cache helper functions
1 parent 16978e6 commit 632f081

File tree

1 file changed

+55
-65
lines changed

1 file changed

+55
-65
lines changed

array_api_compat/common/_helpers.py

+55-65
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,42 @@
1111
import math
1212
import inspect
1313
import warnings
14+
from functools import cache
1415
from typing import Optional, Union, Any
1516

1617
from ._typing import Array, Device, Namespace
1718

1819

19-
def _is_jax_zero_gradient_array(x: object) -> bool:
20+
@cache
21+
def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
22+
try:
23+
mod = sys.modules[modname]
24+
except KeyError:
25+
return False
26+
parent_cls = getattr(mod, clsname)
27+
return issubclass(cls, parent_cls)
28+
29+
30+
def _is_jax_zero_gradient_array(x: Array) -> bool:
2031
"""Return True if `x` is a zero-gradient array.
2132
2233
These arrays are a design quirk of Jax that may one day be removed.
2334
See https://github.com/google/jax/issues/20620.
2435
"""
25-
if 'numpy' not in sys.modules or 'jax' not in sys.modules:
36+
# Fast exit
37+
try:
38+
dtype = x.dtype
39+
except AttributeError:
40+
return False
41+
if not _issubclass_fast(type(dtype), "numpy.dtypes", "VoidDType"):
2642
return False
2743

28-
import numpy as np
29-
import jax
44+
if "jax" not in sys.modules:
45+
return False
3046

31-
return isinstance(x, np.ndarray) and x.dtype == jax.float0
47+
import jax
48+
# jax.float0 is a np.dtype([('float0', 'V')])
49+
return dtype == jax.float0
3250

3351

3452
def is_numpy_array(x: object) -> bool:
@@ -52,15 +70,12 @@ def is_numpy_array(x: object) -> bool:
5270
is_jax_array
5371
is_pydata_sparse_array
5472
"""
55-
# Avoid importing NumPy if it isn't already
56-
if 'numpy' not in sys.modules:
57-
return False
58-
59-
import numpy as np
60-
6173
# TODO: Should we reject ndarray subclasses?
62-
return (isinstance(x, (np.ndarray, np.generic))
63-
and not _is_jax_zero_gradient_array(x))
74+
cls = type(x)
75+
return (
76+
_issubclass_fast(cls, "numpy", "ndarray")
77+
or _issubclass_fast(cls, "numpy", "generic")
78+
) and not _is_jax_zero_gradient_array(x)
6479

6580

6681
def is_cupy_array(x: object) -> bool:
@@ -84,14 +99,7 @@ def is_cupy_array(x: object) -> bool:
8499
is_jax_array
85100
is_pydata_sparse_array
86101
"""
87-
# Avoid importing CuPy if it isn't already
88-
if 'cupy' not in sys.modules:
89-
return False
90-
91-
import cupy as cp
92-
93-
# TODO: Should we reject ndarray subclasses?
94-
return isinstance(x, cp.ndarray)
102+
return _issubclass_fast(type(x), "cupy", "ndarray")
95103

96104

97105
def is_torch_array(x: object) -> bool:
@@ -112,14 +120,7 @@ def is_torch_array(x: object) -> bool:
112120
is_jax_array
113121
is_pydata_sparse_array
114122
"""
115-
# Avoid importing torch if it isn't already
116-
if 'torch' not in sys.modules:
117-
return False
118-
119-
import torch
120-
121-
# TODO: Should we reject ndarray subclasses?
122-
return isinstance(x, torch.Tensor)
123+
return _issubclass_fast(type(x), "torch", "Tensor")
123124

124125

125126
def is_ndonnx_array(x: object) -> bool:
@@ -141,13 +142,7 @@ def is_ndonnx_array(x: object) -> bool:
141142
is_jax_array
142143
is_pydata_sparse_array
143144
"""
144-
# Avoid importing torch if it isn't already
145-
if 'ndonnx' not in sys.modules:
146-
return False
147-
148-
import ndonnx as ndx
149-
150-
return isinstance(x, ndx.Array)
145+
return _issubclass_fast(type(x), "ndonnx", "Array")
151146

152147

153148
def is_dask_array(x: object) -> bool:
@@ -169,13 +164,7 @@ def is_dask_array(x: object) -> bool:
169164
is_jax_array
170165
is_pydata_sparse_array
171166
"""
172-
# Avoid importing dask if it isn't already
173-
if 'dask.array' not in sys.modules:
174-
return False
175-
176-
import dask.array
177-
178-
return isinstance(x, dask.array.Array)
167+
return _issubclass_fast(type(x), "dask.array", "Array")
179168

180169

181170
def is_jax_array(x: object) -> bool:
@@ -198,13 +187,7 @@ def is_jax_array(x: object) -> bool:
198187
is_dask_array
199188
is_pydata_sparse_array
200189
"""
201-
# Avoid importing jax if it isn't already
202-
if 'jax' not in sys.modules:
203-
return False
204-
205-
import jax
206-
207-
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
190+
return _issubclass_fast(type(x), "jax", "Array") or _is_jax_zero_gradient_array(x)
208191

209192

210193
def is_pydata_sparse_array(x) -> bool:
@@ -227,14 +210,8 @@ def is_pydata_sparse_array(x) -> bool:
227210
is_dask_array
228211
is_jax_array
229212
"""
230-
# Avoid importing jax if it isn't already
231-
if 'sparse' not in sys.modules:
232-
return False
233-
234-
import sparse
235-
236213
# TODO: Account for other backends.
237-
return isinstance(x, sparse.SparseArray)
214+
return _issubclass_fast(type(x), "sparse", "SparseArray")
238215

239216

240217
def is_array_api_obj(x: object) -> bool:
@@ -252,20 +229,30 @@ def is_array_api_obj(x: object) -> bool:
252229
is_dask_array
253230
is_jax_array
254231
"""
255-
return is_numpy_array(x) \
256-
or is_cupy_array(x) \
257-
or is_torch_array(x) \
258-
or is_dask_array(x) \
259-
or is_jax_array(x) \
260-
or is_pydata_sparse_array(x) \
261-
or hasattr(x, '__array_namespace__')
232+
return hasattr(x, '__array_namespace__') or _is_array_api_cls(type(x))
233+
234+
235+
@cache
236+
def _is_array_api_cls(cls: type) -> bool:
237+
return (
238+
# TODO: drop support for numpy<2 which didn't have __array_namespace__
239+
_issubclass_fast(cls, "numpy", "ndarray")
240+
or _issubclass_fast(cls, "numpy", "generic")
241+
or _issubclass_fast(cls, "cupy", "ndarray")
242+
or _issubclass_fast(cls, "torch", "Tensor")
243+
or _issubclass_fast(cls, "dask.array", "Array")
244+
or _issubclass_fast(cls, "sparse", "SparseArray")
245+
# TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
246+
or _issubclass_fast(cls, "jax", "Array")
247+
)
262248

263249

264250
def _compat_module_name() -> str:
265251
assert __name__.endswith('.common._helpers')
266252
return __name__.removesuffix('.common._helpers')
267253

268254

255+
@cache
269256
def is_numpy_namespace(xp: Namespace) -> bool:
270257
"""
271258
Returns True if `xp` is a NumPy namespace.
@@ -287,6 +274,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
287274
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
288275

289276

277+
@cache
290278
def is_cupy_namespace(xp: Namespace) -> bool:
291279
"""
292280
Returns True if `xp` is a CuPy namespace.
@@ -308,6 +296,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
308296
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
309297

310298

299+
@cache
311300
def is_torch_namespace(xp: Namespace) -> bool:
312301
"""
313302
Returns True if `xp` is a PyTorch namespace.
@@ -348,6 +337,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
348337
return xp.__name__ == 'ndonnx'
349338

350339

340+
@cache
351341
def is_dask_namespace(xp: Namespace) -> bool:
352342
"""
353343
Returns True if `xp` is a Dask namespace.
@@ -952,4 +942,4 @@ def is_lazy_array(x: object) -> bool:
952942
"to_device",
953943
]
954944

955-
_all_ignore = ['sys', 'math', 'inspect', 'warnings']
945+
_all_ignore = ['cache', 'sys', 'math', 'inspect', 'warnings']

0 commit comments

Comments
 (0)