diff --git a/src/quaxed/lax/__init__.py b/src/quaxed/lax/__init__.py index 383a4a7..a5e0871 100644 --- a/src/quaxed/lax/__init__.py +++ b/src/quaxed/lax/__init__.py @@ -172,6 +172,19 @@ "with_sharding_constraint", # ----- Linear Algebra Operators ----- "linalg", + # ----- Argument classes ----- + "ConvDimensionNumbers", + "ConvGeneralDilatedDimensionNumbers", + "DotAlgorithm", + "DotAlgorithmPreset", + "FftType", + "GatherDimensionNumbers", + "GatherScatterMode", + "Precision", + "PrecisionLike", + "RandomAlgorithm", + "RoundingMethod", + "ScatterDimensionNumbers", ] @@ -184,6 +197,23 @@ from . import linalg +# Explicit imports that don't need to be quaxified +# isort: split +from jax.lax import ( + ConvDimensionNumbers, + ConvGeneralDilatedDimensionNumbers, + DotAlgorithm, + DotAlgorithmPreset, + FftType, + GatherDimensionNumbers, + GatherScatterMode, + Precision, + PrecisionLike, + RandomAlgorithm, + RoundingMethod, + ScatterDimensionNumbers, +) + def __dir__() -> list[str]: """List the module contents."""