Skip to content

Commit

Permalink
feat: deprecate array_api for jax v0.4.32+ (#79)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Sep 17, 2024
1 parent 4ffe7f2 commit e247b49
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 15 deletions.
24 changes: 24 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Doctest configuration."""

import platform
from doctest import ELLIPSIS, NORMALIZE_WHITESPACE

from sybil import Sybil
from sybil.parsers.rest import DocTestParser, PythonCodeBlockParser, SkipParser

# TODO: stop skipping doctests on Windows when there is uniform support for
# numpy 2.0+ scalar repr. On windows it is printed as 1.0 instead of
# `np.float64(1.0)`.
parsers = (
[DocTestParser(optionflags=ELLIPSIS | NORMALIZE_WHITESPACE)]
if platform.system() != "Windows"
else []
) + [
PythonCodeBlockParser(),
SkipParser(),
]

pytest_collect_file = Sybil(
parsers=parsers,
patterns=["*.rst", "*.py"],
).pytest()
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ test = [
"pytest >=6",
"pytest-cov >=3",
"pytest-github-actions-annotate-failures", # only applies to GH Actions
"sybil >= 7.1.0",
]
docs = [
"griffe < 1.0", # For Python structure signatures"
Expand Down Expand Up @@ -82,6 +83,7 @@ filterwarnings = [
"ignore:ast\\.Str is deprecated and will be removed in Python 3.14:DeprecationWarning",
# jax
"ignore:jax\\.core\\.pp_eqn_rules is deprecated:DeprecationWarning",
"ignore:jax\\.experimental\\.array_api import is no longer required as of JAX v0\\.4\\.32"
]
log_cli_level = "INFO"
testpaths = ["tests"]
Expand Down
11 changes: 9 additions & 2 deletions src/quaxed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,27 @@

# pylint: disable=redefined-builtin

__all__ = ["__version__", "array_api", "lax", "scipy"]
__all__ = ["__version__", "lax", "scipy"]

import sys
from typing import Any

import plum
from jaxtyping import ArrayLike

from . import _jax, array_api, lax, scipy
from . import _jax, lax, scipy
from ._jax import *
from ._setup import JAX_VERSION
from ._version import version as __version__

__all__ += _jax.__all__

if JAX_VERSION < (0, 4, 32):
from . import array_api

__all__ += ["array_api"]


# Simplify the display of ArrayLike
plum.activate_union_aliases()
plum.set_union_alias(ArrayLike, "ArrayLike")
Expand Down
5 changes: 5 additions & 0 deletions src/quaxed/_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Setup file for the Quaxed package."""

from importlib.metadata import version

JAX_VERSION: tuple[int, ...] = tuple(map(int, version("jax").split(".")))
6 changes: 2 additions & 4 deletions src/quaxed/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from importlib.metadata import version
"""Utility functions for quaxed."""

from typing import TypeVar

import quax
Expand All @@ -9,6 +10,3 @@
def quaxify(func: T, *, filter_spec: bool | tuple[bool, ...] = True) -> T:
"""Quaxify, but makes mypy happy."""
return quax.quaxify(func, filter_spec=filter_spec)


JAX_VERSION: tuple[int, ...] = tuple(map(int, version("jax").split(".")))
3 changes: 2 additions & 1 deletion src/quaxed/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from jax.experimental import array_api
from jaxtyping import ArrayLike

from quaxed._setup import JAX_VERSION
from quaxed._types import DType
from quaxed._utils import JAX_VERSION, quaxify
from quaxed._utils import quaxify

if JAX_VERSION < (0, 4, 31):
from jax.experimental.array_api._data_type_functions import FInfo
Expand Down
26 changes: 19 additions & 7 deletions src/quaxed/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
"""Quaxed :mod:`jax.numpy`."""

# pylint: disable=redefined-builtin

from typing import Any

from jaxtyping import install_import_hook

with install_import_hook("quaxed.numpy", None):
from . import _core, _creation_functions, _dispatch, _higher_order
from ._core import * # TODO: make this lazy
from ._creation_functions import *
from ._dispatch import *
from ._higher_order import *
from . import _core, _creation_functions, _dispatch, _higher_order
from ._creation_functions import *
from ._dispatch import *
from ._higher_order import *

__all__: list[str] = []
__all__ += _core.__all__
__all__ += _higher_order.__all__
__all__ += _creation_functions.__all__
__all__ += _dispatch.__all__


# TODO: consolidate with ``_core.__getattr__``.
def __getattr__(name: str) -> Any:
if name in __all__:
return getattr(_core, name)

msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)


# TODO: figure out how to install this import hook, with the __getattr__.
install_import_hook("quaxed.numpy", None)
1 change: 1 addition & 0 deletions src/quaxed/numpy/_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ruff: noqa: F822
"""Quaxed :mod:`jax.numpy`."""
# pylint: disable=undefined-all-variable

__all__ = [
# modules
Expand Down
1 change: 1 addition & 0 deletions src/quaxed/numpy/_higher_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def vectorize( # noqa: C901
routines using :func:`vectorize`:
>>> from functools import partial
>>> import quaxed.numpy as jnp
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
Expand Down
3 changes: 2 additions & 1 deletion tests/numpy/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,7 +1672,8 @@ def test_round(x1):

def test_round_(x1):
"""Test `quaxed.numpy.round_`."""
assert jnp.all(qnp.round_(x1) == jnp.round_(x1))
with pytest.deprecated_call():
assert jnp.all(qnp.round_(x1) == jnp.round_(x1))


@pytest.mark.skip("TODO")
Expand Down

0 comments on commit e247b49

Please sign in to comment.