Skip to content

Commit 9c0ea9c

Browse files
committed
Cache is_array_api_object
1 parent eb7e95e commit 9c0ea9c

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

array_api_compat/common/_helpers.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,11 @@ def is_array_api_obj(x: object) -> bool:
229229
is_dask_array
230230
is_jax_array
231231
"""
232-
if hasattr(x, '__array_namespace__'):
233-
return True
232+
return hasattr(x, '__array_namespace__') or _is_array_api_cls(type(x))
234233

235-
cls = type(x)
234+
235+
@cache
236+
def _is_array_api_cls(cls: type) -> bool:
236237
return (
237238
# TODO: drop support for numpy<2 which didn't have __array_namespace__
238239
_issubclass_fast(cls, "numpy", "ndarray")

0 commit comments

Comments
 (0)