From e247b49cabb57a118ef50a4ecdbac9af96eb246b Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Mon, 16 Sep 2024 23:07:01 -0400 Subject: [PATCH] feat: deprecate array_api for jax v0.4.32+ (#79) Signed-off-by: nstarman --- conftest.py | 24 ++++++++++++++++++ pyproject.toml | 2 ++ src/quaxed/__init__.py | 11 +++++++-- src/quaxed/_setup.py | 5 ++++ src/quaxed/_utils.py | 6 ++--- src/quaxed/array_api/_data_type_functions.py | 3 ++- src/quaxed/numpy/__init__.py | 26 ++++++++++++++------ src/quaxed/numpy/_core.py | 1 + src/quaxed/numpy/_higher_order.py | 1 + tests/numpy/test_jax.py | 3 ++- 10 files changed, 67 insertions(+), 15 deletions(-) create mode 100644 conftest.py create mode 100644 src/quaxed/_setup.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..677ccdd --- /dev/null +++ b/conftest.py @@ -0,0 +1,24 @@ +"""Doctest configuration.""" + +import platform +from doctest import ELLIPSIS, NORMALIZE_WHITESPACE + +from sybil import Sybil +from sybil.parsers.rest import DocTestParser, PythonCodeBlockParser, SkipParser + +# TODO: stop skipping doctests on Windows when there is uniform support for +# numpy 2.0+ scalar repr. On windows it is printed as 1.0 instead of +# `np.float64(1.0)`. +parsers = ( + [DocTestParser(optionflags=ELLIPSIS | NORMALIZE_WHITESPACE)] + if platform.system() != "Windows" + else [] +) + [ + PythonCodeBlockParser(), + SkipParser(), +] + +pytest_collect_file = Sybil( + parsers=parsers, + patterns=["*.rst", "*.py"], +).pytest() diff --git a/pyproject.toml b/pyproject.toml index 5839daf..304b68c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ test = [ "pytest >=6", "pytest-cov >=3", "pytest-github-actions-annotate-failures", # only applies to GH Actions + "sybil >= 7.1.0", ] docs = [ "griffe < 1.0", # For Python structure signatures" @@ -82,6 +83,7 @@ filterwarnings = [ "ignore:ast\\.Str is deprecated and will be removed in Python 3.14:DeprecationWarning", # jax "ignore:jax\\.core\\.pp_eqn_rules is deprecated:DeprecationWarning", + "ignore:jax\\.experimental\\.array_api import is no longer required as of JAX v0\\.4\\.32" ] log_cli_level = "INFO" testpaths = ["tests"] diff --git a/src/quaxed/__init__.py b/src/quaxed/__init__.py index 27c1745..2f6a9cb 100644 --- a/src/quaxed/__init__.py +++ b/src/quaxed/__init__.py @@ -5,7 +5,7 @@ # pylint: disable=redefined-builtin -__all__ = ["__version__", "array_api", "lax", "scipy"] +__all__ = ["__version__", "lax", "scipy"] import sys from typing import Any @@ -13,12 +13,19 @@ import plum from jaxtyping import ArrayLike -from . import _jax, array_api, lax, scipy +from . import _jax, lax, scipy from ._jax import * +from ._setup import JAX_VERSION from ._version import version as __version__ __all__ += _jax.__all__ +if JAX_VERSION < (0, 4, 32): + from . import array_api + + __all__ += ["array_api"] + + # Simplify the display of ArrayLike plum.activate_union_aliases() plum.set_union_alias(ArrayLike, "ArrayLike") diff --git a/src/quaxed/_setup.py b/src/quaxed/_setup.py new file mode 100644 index 0000000..1e609e6 --- /dev/null +++ b/src/quaxed/_setup.py @@ -0,0 +1,5 @@ +"""Setup file for the Quaxed package.""" + +from importlib.metadata import version + +JAX_VERSION: tuple[int, ...] = tuple(map(int, version("jax").split("."))) diff --git a/src/quaxed/_utils.py b/src/quaxed/_utils.py index f0009fd..8482902 100644 --- a/src/quaxed/_utils.py +++ b/src/quaxed/_utils.py @@ -1,4 +1,5 @@ -from importlib.metadata import version +"""Utility functions for quaxed.""" + from typing import TypeVar import quax @@ -9,6 +10,3 @@ def quaxify(func: T, *, filter_spec: bool | tuple[bool, ...] = True) -> T: """Quaxify, but makes mypy happy.""" return quax.quaxify(func, filter_spec=filter_spec) - - -JAX_VERSION: tuple[int, ...] = tuple(map(int, version("jax").split("."))) diff --git a/src/quaxed/array_api/_data_type_functions.py b/src/quaxed/array_api/_data_type_functions.py index d459163..53fb12e 100644 --- a/src/quaxed/array_api/_data_type_functions.py +++ b/src/quaxed/array_api/_data_type_functions.py @@ -4,8 +4,9 @@ from jax.experimental import array_api from jaxtyping import ArrayLike +from quaxed._setup import JAX_VERSION from quaxed._types import DType -from quaxed._utils import JAX_VERSION, quaxify +from quaxed._utils import quaxify if JAX_VERSION < (0, 4, 31): from jax.experimental.array_api._data_type_functions import FInfo diff --git a/src/quaxed/numpy/__init__.py b/src/quaxed/numpy/__init__.py index ba3391e..a0065cd 100644 --- a/src/quaxed/numpy/__init__.py +++ b/src/quaxed/numpy/__init__.py @@ -1,18 +1,30 @@ """Quaxed :mod:`jax.numpy`.""" - # pylint: disable=redefined-builtin +from typing import Any + from jaxtyping import install_import_hook -with install_import_hook("quaxed.numpy", None): - from . import _core, _creation_functions, _dispatch, _higher_order - from ._core import * # TODO: make this lazy - from ._creation_functions import * - from ._dispatch import * - from ._higher_order import * +from . import _core, _creation_functions, _dispatch, _higher_order +from ._creation_functions import * +from ._dispatch import * +from ._higher_order import * __all__: list[str] = [] __all__ += _core.__all__ __all__ += _higher_order.__all__ __all__ += _creation_functions.__all__ __all__ += _dispatch.__all__ + + +# TODO: consolidate with ``_core.__getattr__``. +def __getattr__(name: str) -> Any: + if name in __all__: + return getattr(_core, name) + + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) + + +# TODO: figure out how to install this import hook, with the __getattr__. +install_import_hook("quaxed.numpy", None) diff --git a/src/quaxed/numpy/_core.py b/src/quaxed/numpy/_core.py index 04beb5a..ec32f61 100644 --- a/src/quaxed/numpy/_core.py +++ b/src/quaxed/numpy/_core.py @@ -1,5 +1,6 @@ # ruff: noqa: F822 """Quaxed :mod:`jax.numpy`.""" +# pylint: disable=undefined-all-variable __all__ = [ # modules diff --git a/src/quaxed/numpy/_higher_order.py b/src/quaxed/numpy/_higher_order.py index 99f7162..bf4f3b1 100644 --- a/src/quaxed/numpy/_higher_order.py +++ b/src/quaxed/numpy/_higher_order.py @@ -60,6 +60,7 @@ def vectorize( # noqa: C901 routines using :func:`vectorize`: >>> from functools import partial + >>> import quaxed.numpy as jnp >>> @partial(jnp.vectorize, signature='(k),(k)->(k)') ... def cross_product(a, b): diff --git a/tests/numpy/test_jax.py b/tests/numpy/test_jax.py index 259c80b..bd11951 100644 --- a/tests/numpy/test_jax.py +++ b/tests/numpy/test_jax.py @@ -1672,7 +1672,8 @@ def test_round(x1): def test_round_(x1): """Test `quaxed.numpy.round_`.""" - assert jnp.all(qnp.round_(x1) == jnp.round_(x1)) + with pytest.deprecated_call(): + assert jnp.all(qnp.round_(x1) == jnp.round_(x1)) @pytest.mark.skip("TODO")