ENH: cache helper functions #308

Merged (7 commits) on Apr 21, 2025

@crusaderky crusaderky commented Apr 15, 2025

Speed up helper functions through caching.

>>> import sys
>>> import array_api_compat.numpy as np
>>> import array_api_compat.torch as torch
>>> from array_api_compat import (
...     is_numpy_namespace, is_numpy_array, is_torch_array,
...     is_array_api_obj, is_lazy_array, is_writeable_array,
... )
>>> a = np.asarray(1)
>>> b = torch.asarray(1)

>>> %timeit is_numpy_namespace(np)
BEFORE 333 ns ± 1.83 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 62 ns ± 0.227 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_numpy_namespace(sys)
BEFORE 334 ns ± 4.79 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 69.3 ns ± 1.18 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_numpy_array(a)
BEFORE 382 ns ± 2.01 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 281 ns ± 4.03 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

>>> %timeit is_numpy_array(1)
BEFORE 272 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 175 ns ± 5.22 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_numpy_array([1])
BEFORE 288 ns ± 1.16 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 213 ns ± 7.33 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

>>> %timeit is_torch_array(b)
BEFORE 214 ns ± 0.244 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 126 ns ± 1.05 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_torch_array(1)
BEFORE 249 ns ± 4.49 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 121 ns ± 0.724 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_array_api_obj(a)
BEFORE 423 ns ± 1.35 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 99.5 ns ± 1.47 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_array_api_obj(1)
BEFORE 773 ns ± 2.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 142 ns ± 1.52 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_lazy_array(a)
BEFORE 437 ns ± 7.46 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 381 ns ± 5.56 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

>>> %timeit is_lazy_array(1)
BEFORE 2.12 μs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
AFTER 624 ns ± 8.06 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

>>> %timeit is_writeable_array(a)
BEFORE 491 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 183 ns ± 0.446 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_writeable_array(b)
BEFORE 1.15 μs ± 5.57 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 185 ns ± 0.766 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_writeable_array(1)
BEFORE 1.52 μs ± 13.7 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 189 ns ± 5.21 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

Note: is_numpy_array and is_jax_array remain slower than the equivalent functions for other backends because they also have to reclassify JAX zero-gradient arrays.

@Copilot Copilot AI review requested due to automatic review settings April 15, 2025 16:27

@Copilot Copilot AI left a comment

Copilot reviewed 1 out of 1 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (1)

array_api_compat/common/_helpers.py:944

  • [nitpick] Consider adding a comment to explain why 'cache' is included in _all_ignore to improve clarity for future maintainers.
_all_ignore = ['cache', 'sys', 'math', 'inspect', 'warnings']

from typing import Optional, Union, Any

from ._typing import Array, Device, Namespace


def _is_jax_zero_gradient_array(x: object) -> bool:
@cache

Contributor Author

This in theory could lead to a memory leak for a user that somehow dynamically defines and then forgets a lot of classes. I don't think it's something to worry about in real life?

Member


Can we limit the cache size, just so we don't even have to think about this possibility?

Contributor Author


Yes. @lru_cache(100) costs just 4 ns more than @cache as long as it doesn't need to evict anything (which it won't 99% of the time). Amended.
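
To illustrate the trade-off discussed in this thread, here is a small sketch, under the assumption of a hypothetical per-class check named `has_dtype_attr`. It reproduces the pathological case (many dynamically defined, then forgotten classes) and shows that a bounded `lru_cache` never holds more than `maxsize` entries, whereas an unbounded `@cache` would keep a reference to every class it has seen.

```python
from functools import lru_cache


@lru_cache(100)
def has_dtype_attr(cls: type) -> bool:
    # Hypothetical per-class check; stands in for any helper cached on type(x).
    return hasattr(cls, "dtype")


# Dynamically define many throwaway classes, as in the pathological case.
for i in range(1_000):
    throwaway = type(f"Throwaway{i}", (), {})
    has_dtype_attr(throwaway)

info = has_dtype_attr.cache_info()
# Every call was a miss, but old entries were evicted to stay within maxsize.
print(info.misses, info.maxsize, info.currsize)
```

In ordinary use only a handful of array classes are ever seen, so nothing is evicted and the bounded cache behaves like the unbounded one.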

        dtype = x.dtype  # type: ignore[attr-defined]
    except AttributeError:
        return False
    cls = cast(Hashable, type(dtype))

Member


Can't say I'm happy to see cryptic things from typing doing something at runtime, but OK, am ready to believe it's somehow useful.

Contributor Author


cast is a noop at runtime.
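
A quick way to check this: `typing.cast` returns its second argument unchanged and exists only to inform static type checkers. In the sketch below, `FakeDType` is a hypothetical stand-in for a dtype class.

```python
from collections.abc import Hashable
from typing import cast


class FakeDType:
    """Hypothetical stand-in for a dtype class."""


# cast() performs no conversion or check; it hands back the same object.
cls = cast(Hashable, type(FakeDType()))
print(cls is FakeDType)

# Even a "wrong" cast does nothing at runtime.
value = cast(int, "not an int")
print(type(value).__name__)
```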

@ev-br ev-br merged commit 52e01be into data-apis:main Apr 21, 2025
23 checks passed

ev-br commented Apr 21, 2025

Merged, thanks @crusaderky

@crusaderky crusaderky deleted the cache_helpers branch April 21, 2025 18:57