diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a383cb..a9960a4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.7 + rev: v0.4.7 hooks: - id: ruff # linter types_or: [ python, pyi, jupyter ] @@ -8,7 +8,7 @@ repos: - id: ruff-format # formatter types_or: [ python, pyi, jupyter ] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.315 + rev: v1.1.365 hooks: - id: pyright additional_dependencies: ["equinox", "pytest", "jax", "jaxtyping", "plum-dispatch"] diff --git a/docs/examples/redispatch.ipynb b/docs/examples/redispatch.ipynb index 996ea0a..75bef2d 100644 --- a/docs/examples/redispatch.ipynb +++ b/docs/examples/redispatch.ipynb @@ -354,8 +354,9 @@ "outputs": [], "source": [ "@quax.register(jax.lax.dot_general_p)\n", - "def _(x: LoraArray, y: SomeKindOfSparseVector, *, dimension_numbers, **params):\n", - " ... # some implementation here" + "def _(\n", + " x: LoraArray, y: SomeKindOfSparseVector, *, dimension_numbers, **params\n", + "): ... # some implementation here" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 699c3fd..0f5ac06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,13 +35,13 @@ include = ["quax/*"] [tool.pytest.ini_options] addopts = "--jaxtyping-packages=quax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" -[tool.ruff] +[tool.ruff.lint] select = ["E", "F", "I001"] ignore = ["E402", "E721", "E731", "E741", "F722"] ignore-init-module-imports = true fixable = ["I001", "F401"] -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true lines-after-imports = 2 extra-standard-library = ["typing_extensions"] diff --git a/quax/examples/zero/_core.py b/quax/examples/zero/_core.py index fe4c442..407141a 100644 --- a/quax/examples/zero/_core.py +++ b/quax/examples/zero/_core.py @@ -171,7 +171,7 @@ def _(lhs: Zero, rhs: Zero, **kwargs) -> Zero: def _integer_pow(x: Zero, *, y: int) -> Union[Array, Zero]: # Zero is a special case, because 0^0 = 1. if y == 0: - return jnp.ones(x.shape, x.dtype) + return jnp.ones(x.shape, x.dtype) # pyright: ignore # Otherwise, we can just return a zero. # Inf and NaN are not integers, so we don't need to worry about them.