Skip to content

Commit

Permalink
feat: lax (#67)
Browse files Browse the repository at this point in the history
* feat: lax.dot
* feat: full lax
* tests: add

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Aug 2, 2024
1 parent 77537a5 commit efbb6d0
Show file tree
Hide file tree
Showing 12 changed files with 1,907 additions and 42 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ repos:
- pytest
exclude: |
(?x)^(
src/quaxed/lax/__init__.py|
src/quaxed/lax/linalg.py|
src/quaxed/numpy/__init__.py|
src/quaxed/operator.py|
src/quaxed/scipy/special.py
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ ignore = [
]

[tool.ruff.lint.per-file-ignores]
"tests/**" = ["ANN", "INP001", "PLR0913", "S101", "T20", "TID252"]
"tests/**" = ["ANN", "INP001", "PLR0913", "PLR2004", "S101", "T20", "TID252"]
"__init__.py" = ["F403"]
"noxfile.py" = ["T20"]
"docs/conf.py" = ["INP001"]
Expand Down
4 changes: 2 additions & 2 deletions src/quaxed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

# pylint: disable=redefined-builtin

__all__ = ["__version__", "array_api", "scipy"]
__all__ = ["__version__", "array_api", "lax", "scipy"]

import sys
from typing import Any

import plum
from jaxtyping import ArrayLike

from . import _jax, array_api, scipy
from . import _jax, array_api, lax, scipy
from ._jax import *
from ._version import version as __version__

Expand Down
37 changes: 0 additions & 37 deletions src/quaxed/lax.py

This file was deleted.

201 changes: 201 additions & 0 deletions src/quaxed/lax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Quaxed :mod:`jax.lax`."""
# pylint: disable=undefined-all-variable

__all__ = [
# ----- Operators -----
"abs",
"acos",
"acosh",
"add",
# "after_all",
"approx_max_k",
"approx_min_k",
"argmax",
"argmin",
"asin",
"asinh",
"atan",
"atan2",
"atanh",
"batch_matmul",
"bessel_i0e",
"bessel_i1e",
"betainc",
"bitcast_convert_type",
"bitwise_and",
"bitwise_not",
"bitwise_or",
"bitwise_xor",
"population_count",
"broadcast",
"broadcast_in_dim",
"broadcast_shapes",
"broadcast_to_rank",
"broadcasted_iota",
"cbrt",
"ceil",
"clamp",
"clz",
"collapse",
"complex",
"concatenate",
"conj",
"conv",
"convert_element_type",
"conv_dimension_numbers",
"conv_general_dilated",
"conv_general_dilated_local",
"conv_general_dilated_patches",
"conv_transpose",
"conv_with_general_padding",
"cos",
"cosh",
"cumlogsumexp",
"cummax",
"cummin",
"cumprod",
"cumsum",
"digamma",
"div",
"dot",
"dot_general",
"dynamic_index_in_dim",
"dynamic_slice",
"dynamic_slice_in_dim",
"dynamic_update_index_in_dim",
"dynamic_update_slice",
"dynamic_update_slice_in_dim",
"eq",
"erf",
"erfc",
"erf_inv",
"exp",
"expand_dims",
"expm1",
"fft",
"floor",
"full",
"full_like",
"gather",
"ge",
"gt",
"igamma",
"igammac",
"imag",
"index_in_dim",
"index_take",
"integer_pow",
"iota",
"is_finite",
"le",
"lgamma",
"log",
"log1p",
"logistic",
"lt",
"max",
"min",
"mul" "ne",
"neg",
"nextafter",
"pad",
"polygamma",
"population_count",
"pow",
"random_gamma_grad",
"real",
"reciprocal",
"reduce",
"reduce_precision",
"reduce_window",
"rem",
"reshape",
"rev",
"rng_bit_generator",
"rng_uniform",
"round",
"rsqrt",
"scatter",
"scatter_add",
"scatter_apply",
"scatter_max",
"scatter_min",
"scatter_mul",
"shift_left",
"shift_right_arithmetic",
"shift_right_logical",
"sign",
"sin",
"sinh",
"slice",
"slice_in_dim",
"sort",
"sort_key_val",
"sqrt",
"square",
"squeeze",
"sub",
"tan",
"tanh",
"top_k",
"transpose",
"zeros_like_array",
"zeta",
# ----- Control Flow Operators -----
"associative_scan",
"cond",
"fori_loop",
"map",
"scan",
"select",
"select_n",
"switch",
"while_loop",
# ----- Custom Gradient Operators -----
"stop_gradient",
"custom_linear_solve",
"custom_root",
# ----- Parallel Operators -----
"all_gather",
"all_to_all",
"psum",
"psum_scatter",
"pmax",
"pmin",
"pmean",
"ppermute",
"pshuffle",
"pswapaxes",
"axis_index",
# ----- Sharding-related Operators -----
"with_sharding_constraint",
# ----- Linear Algebra Operators -----
"linalg",
]


import sys
from collections.abc import Callable
from typing import Any

from jax import lax
from quax import quaxify

from . import linalg


def __dir__() -> list[str]:
"""List the operators."""
return sorted(__all__)


# TODO: return type hint signature
def __getattr__(name: str) -> Callable[..., Any]:
"""Get the operator."""
# Quaxify the operator
out = quaxify(getattr(lax, name))

# Cache the function in this module
setattr(sys.modules[__name__], name, out)

return out
Loading

0 comments on commit efbb6d0

Please sign in to comment.