Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: misc #73

Merged
merged 5 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,9 @@ If you found this library to be useful and want to support the development and
maintenance of lower-level utility libraries for the scientific community,
please consider citing this work.

<!-- SPHINX-START -->

<!-- prettier-ignore-start -->
[actions-badge]: https://github.com/GalacticDynamics/quaxed/workflows/CI/badge.svg
[actions-link]: https://github.com/GalacticDynamics/quaxed/actions
<!-- [github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github
[github-discussions-link]: https://github.com/GalacticDynamics/quaxed/discussions -->
[pypi-link]: https://pypi.org/project/quaxed/
[pypi-platforms]: https://img.shields.io/pypi/pyversions/quaxed
[pypi-version]: https://img.shields.io/pypi/v/quaxed
Expand Down
20 changes: 9 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,19 @@ test = [
"pytest-github-actions-annotate-failures", # only applies to GH Actions
]
docs = [
"mkdocs==1.3.0", # Main documentation generator.
"mkdocs-material==7.3.6", # Theme
"griffe < 1.0", # For Python structure signatures"
"mkdocs==1.6.0", # Main documentation generator.
"mkdocs-material==9.5", # Theme
"mkdocs_include_exclude_files==0.0.1", # Tweak which files are included/excluded
"mkdocstrings[python] >= 0.18", # Autogenerate documentation from docstrings.
"mknotebooks==0.7.1", # Turn Jupyter Lab notebooks into webpages.
"nbconvert==6.5.0",
"pygments==2.14.0",
"pymdown-extensions==9.4", # Markdown extensions e.g. to handle LaTeX.
"mknotebooks==0.8", # Turn Jupyter Lab notebooks into webpages.
"nbconvert==7.16",
"pygments==2.16",
"pymdown-extensions==10.2", # Markdown extensions e.g. to handle LaTeX.
"pytkdocs_tweaks==0.0.8", # Tweaks mkdocstrings to improve various aspects
"jinja2==3.0.3" # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings.
]
dev = [
"quaxed[test]",
"quaxed[docs]",
"jinja2==3.1"
]
dev = ["quaxed[test,docs]"]

[project.urls]
Homepage = "https://github.com/GalacticDynamics/quaxed"
Expand Down
13 changes: 3 additions & 10 deletions src/quaxed/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


from collections.abc import Sequence
from typing import Literal
from typing import Any, Literal

import jax.numpy as jnp
from jax.experimental import array_api
Expand Down Expand Up @@ -167,12 +167,5 @@


@quaxify
def vector_norm(
x: ArrayLike,
/,
*,
axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
ord: int | float = 2, # pylint: disable=redefined-builtin
) -> Value:
return array_api.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
def vector_norm(x: ArrayLike, /, **kwargs: Any) -> Value:
return array_api.linalg.vector_norm(x, **kwargs)

Check warning on line 171 in src/quaxed/array_api/linalg.py

View check run for this annotation

Codecov / codecov/patch

src/quaxed/array_api/linalg.py#L171

Added line #L171 was not covered by tests
2 changes: 2 additions & 0 deletions src/quaxed/numpy/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ def __dir__() -> list[str]:
"euler_gamma",
"flexible",
"floating",
"float32",
"float64",
"generic",
"index_exp",
"indices",
Expand Down
22 changes: 3 additions & 19 deletions tests/myarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import jax
import jax.experimental.array_api as jax_xp
from jax import Device, lax
from jax._src.lax.lax import DotDimensionNumbers, PrecisionLike
from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode
from jax._src.typing import DTypeLike, Shape
from jax._src.typing import Shape
from jaxtyping import ArrayLike
from quax import ArrayValue, register

Expand Down Expand Up @@ -424,23 +423,8 @@ def _div_p(x: MyArray, y: ArrayLike) -> MyArray:


@register(lax.dot_general_p) # TODO: implement
def _dot_general_p(
lhs: MyArray,
rhs: MyArray,
*,
dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
) -> MyArray:
return MyArray(
lax.dot_general_p.bind(
lhs.array,
rhs.array,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
),
)
def _dot_general_p(lhs: MyArray, rhs: MyArray, **kwargs: Any) -> MyArray:
return MyArray(lax.dot_general_p.bind(lhs.array, rhs.array, **kwargs))


# ==============================================================================
Expand Down