Fix quaxedfuncs #33

Merged: 6 commits, Mar 16, 2024
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -130,6 +130,9 @@ ignore = [
"docs/conf.py" = ["INP001"]
"scratch/**" = ["ANN", "D", "FBT", "INP"]

[tool.ruff.lint.pydocstyle]
convention = "numpy"


[tool.pylint]
py-version = "3.10"
16 changes: 13 additions & 3 deletions src/quaxed/_jax.py
@@ -28,6 +28,7 @@ def grad( # noqa: PLR0913
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
filter_spec: Any = True,
) -> Callable[..., Any]:
"""Quaxed version of :func:`jax.grad`."""
return quaxify(
@@ -38,7 +39,8 @@
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)
),
filter_spec=filter_spec,
)


@@ -48,9 +50,13 @@ def hessian(
*,
has_aux: bool = False,
holomorphic: bool = False,
filter_spec: Any = True,
) -> Callable[..., Any]:
"""Quaxed version of :func:`jax.hessian`."""
return quaxify(jax.hessian(fun, argnums, holomorphic=holomorphic, has_aux=has_aux))
return quaxify(
jax.hessian(fun, argnums, holomorphic=holomorphic, has_aux=has_aux),
filter_spec=filter_spec,
)


def jacfwd(
@@ -59,6 +65,10 @@ def jacfwd(
*,
has_aux: bool = False,
holomorphic: bool = False,
filter_spec: Any = True,
) -> Callable[..., Any]:
"""Quaxed version of :func:`jax.jacfwd`."""
return quaxify(jax.jacfwd(fun, argnums, holomorphic=holomorphic, has_aux=has_aux))
return quaxify(
jax.jacfwd(fun, argnums, holomorphic=holomorphic, has_aux=has_aux),
filter_spec=filter_spec,
)
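
The new ``filter_spec`` keyword is threaded straight through to ``quax.quaxify``. Below is a minimal usage sketch, not part of the diff: it assumes ``quaxed`` re-exports ``grad`` at the package level, and the ``loss`` function and values are purely illustrative.

import jax.numpy as jnp
import quaxed

def loss(x):
    return jnp.sum(x**2)

# filter_spec defaults to True and is forwarded unchanged to quax.quaxify,
# so existing callers are unaffected; passing it explicitly exposes quax's
# argument filtering through the quaxed wrappers (grad, hessian, jacfwd).
g = quaxed.grad(loss, filter_spec=True)
print(g(jnp.arange(3.0)))  # [0. 2. 4.]
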
193 changes: 188 additions & 5 deletions src/quaxed/numpy/_core.py
@@ -3,28 +3,43 @@
__all__ = [
"allclose",
"array_equal",
"asarray",
"cbrt",
"copy",
"equal",
"exp2",
"expand_dims",
"greater",
"hypot",
"isclose",
"matmul",
"moveaxis",
"squeeze",
"trace",
"vectorize",
]

from collections.abc import Callable, Iterable
import functools
from collections.abc import Callable, Collection
from typing import Any, TypeVar

import jax
import jax.numpy as jnp
from jax._src.numpy.vectorize import (
_apply_excluded,
_check_output_dims,
_parse_gufunc_signature,
_parse_input_dimensions,
)
from quax import quaxify

T = TypeVar("T")


# ============================================================================
# Helper functions


def _doc(jax_func: Callable[..., Any]) -> Callable[[T], T]:
"""Copy docstrings from JAX functions."""

@@ -37,27 +52,195 @@ def transfer_doc(func: T) -> T:


##############################################################################
# Quaxified `jax.numpy` namespace


allclose = quaxify(jnp.allclose)
array_equal = quaxify(jnp.array_equal)
asarray = quaxify(jnp.asarray)
cbrt = quaxify(jnp.cbrt)
copy = quaxify(jnp.copy)
equal = quaxify(jnp.equal)
exp2 = quaxify(jnp.exp2)
expand_dims = quaxify(jnp.expand_dims)
greater = quaxify(jnp.greater)
hypot = quaxify(jnp.hypot)
isclose = quaxify(jnp.isclose)
matmul = quaxify(jnp.matmul)
moveaxis = quaxify(jnp.moveaxis)
squeeze = quaxify(jnp.squeeze)
trace = quaxify(jnp.trace)


@_doc(jnp.vectorize)
def vectorize(
# =====================================
# `jax.numpy.vectorize`


def vectorize( # noqa: C901
pyfunc: Callable[..., Any],
*,
excluded: Iterable[int] = frozenset(),
excluded: Collection[int | str] = frozenset(),
signature: str | None = None,
) -> Callable[..., Any]:
return quaxify(jnp.vectorize(pyfunc, excluded=excluded, signature=signature))
"""Define a vectorized function with broadcasting.

This is a copy-paste from :func:`jax.numpy.vectorize`, but with the internals
replaced by their :mod:`quaxed` counterparts so that quax-friendly objects
can pass through. The only thing that isn't quaxed is `jax.vmap`, which
allows any array-like object to pass through without converting it. Note
that this behaviour is different from doing ``quaxify(jnp.vectorize)``,
since `quaxify` makes objects look like arrays rather than their actual
type, which can be problematic. This function passes the objects through
unchanged (so long as they are amenable to the reshapes and ``vmap``).

Parameters
----------
pyfunc: callable
function to vectorize.
excluded: Collection[int | str], optional
optional set of integers and/or strings identifying positional and keyword
arguments for which the function will not be vectorized. These will be
passed directly to ``pyfunc`` unmodified.
signature: str | None
optional generalized universal function signature, e.g.,
``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If
provided, ``pyfunc`` will be called with (and expected to return) arrays
with shapes given by the size of corresponding core dimensions. By
default, ``pyfunc`` is assumed to take scalar arrays as input and output.

Returns
-------
callable
Vectorized version of the given function.

Here are a few examples of how one could write vectorized linear algebra
routines using :func:`vectorize`:

>>> from functools import partial

>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
... assert a.shape == b.shape and a.ndim == b.ndim == 1
... return jnp.array([a[1] * b[2] - a[2] * b[1],
... a[2] * b[0] - a[0] * b[2],
... a[0] * b[1] - a[1] * b[0]])

>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
... return matrix @ vector

These functions are only written to handle 1D or 2D arrays (the ``assert``
statements will never be violated), but with vectorize they support
arbitrary-dimensional inputs with NumPy-style broadcasting, e.g.,

>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all
core dimensions ('n', 'm') on vectorized function with excluded=frozenset()
and signature='(n,m),(m)->(n)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)

Note that this has different semantics from ``jnp.matmul``:

>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
""" # noqa: E501
if any(not isinstance(exclude, str | int) for exclude in excluded):
msg = (
"jax.numpy.vectorize can only exclude integer or string arguments, "
f"but excluded={excluded!r}"
)
raise TypeError(msg)

if any(isinstance(e, int) and e < 0 for e in excluded):
msg = f"excluded={excluded!r} contains negative numbers"
raise ValueError(msg)

@functools.wraps(pyfunc)
def wrapped(*args: Any, **kwargs: Any) -> Any:
error_context = (
f"on vectorized function with excluded={excluded!r} and "
f"signature={signature!r}"
)
excluded_func, args, kwargs = _apply_excluded(pyfunc, excluded, args, kwargs)

if signature is not None:
input_core_dims, output_core_dims = _parse_gufunc_signature(signature)
else:
input_core_dims = [()] * len(args)
output_core_dims = None

none_args = {i for i, arg in enumerate(args) if arg is None}
if none_args:  # `any(none_args)` would wrongly skip the case where only index 0 is None
if any(input_core_dims[i] != () for i in none_args):
msg = f"Cannot pass None at locations {none_args} with {signature=}"
raise ValueError(msg)
excluded_func, args, _ = _apply_excluded(excluded_func, none_args, args, {})
input_core_dims = [
dim for i, dim in enumerate(input_core_dims) if i not in none_args
]

args = tuple(map(asarray, args))

broadcast_shape, dim_sizes = _parse_input_dimensions(
args, input_core_dims, error_context
)

checked_func = _check_output_dims(
excluded_func, dim_sizes, output_core_dims, error_context
)

# Rather than broadcasting all arguments to full broadcast shapes, prefer
# expanding dimensions using vmap. By pushing broadcasting
# into vmap, we can make use of more efficient batching rules for
# primitives where only some arguments are batched (e.g., for
# lax_linalg.triangular_solve), and avoid instantiating large broadcasted
# arrays.

squeezed_args = []
rev_filled_shapes = []

for arg, core_dims in zip(args, input_core_dims, strict=True):
noncore_shape = arg.shape[: arg.ndim - len(core_dims)]

pad_ndim = len(broadcast_shape) - len(noncore_shape)
filled_shape = pad_ndim * (1,) + noncore_shape
rev_filled_shapes.append(filled_shape[::-1])

squeeze_indices = tuple(
i for i, size in enumerate(noncore_shape) if size == 1
)
squeezed_arg = squeeze(arg, axis=squeeze_indices)
squeezed_args.append(squeezed_arg)

vectorized_func = checked_func
dims_to_expand = []
for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes, strict=True)):
in_axes = tuple(None if size == 1 else 0 for size in axis_sizes)
if all(axis is None for axis in in_axes):
dims_to_expand.append(len(broadcast_shape) - 1 - negdim)
else:
vectorized_func = jax.vmap(vectorized_func, in_axes)
result = vectorized_func(*squeezed_args)

if not dims_to_expand:
out = result
elif isinstance(result, tuple):
out = tuple(expand_dims(r, axis=dims_to_expand) for r in result)
else:
out = expand_dims(result, axis=dims_to_expand)

return out

return wrapped
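
For reference, a short usage sketch of the rewritten ``vectorize`` (not part of the diff). It assumes only that ``quaxed.numpy`` re-exports ``vectorize`` from this module and that, for plain arrays, it behaves as a drop-in replacement for ``jnp.vectorize``.

from functools import partial

import jax.numpy as jnp
import quaxed.numpy as qnp  # assumed re-export of this module

@partial(qnp.vectorize, signature="(n,m),(m)->(n)")
def matvec(matrix, vector):
    # The body only ever sees core-dimension arrays, as with jnp.vectorize.
    return matrix @ vector

# Non-core (batch) dimensions are handled with jax.vmap rather than by
# materialising fully broadcast inputs.
print(matvec(jnp.ones((4, 2, 3)), jnp.ones(3)).shape)  # (4, 2)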