From a2a74837f3826ec3c39e421449118b1151b70c1c Mon Sep 17 00:00:00 2001 From: nstarman Date: Fri, 22 Dec 2023 13:09:12 -0800 Subject: [PATCH] initial commit Signed-off-by: nstarman --- .copier-answers.yml | 12 + .git_archival.txt | 4 + .gitattributes | 1 + .github/CONTRIBUTING.md | 101 +++++ .github/dependabot.yml | 7 + .github/matchers/pylint.json | 32 ++ .github/workflows/cd.yml | 60 +++ .github/workflows/ci.yml | 69 ++++ .gitignore | 158 +++++++ .pre-commit-config.yaml | 91 +++++ .readthedocs.yml | 18 + LICENSE | 29 ++ README.md | 27 ++ docs/conf.py | 45 ++ docs/index.md | 17 + noxfile.py | 114 ++++++ pyproject.toml | 159 ++++++++ src/array_api_jax_compat/__init__.py | 64 +++ src/array_api_jax_compat/_constants.py | 5 + .../_creation_functions.py | 163 ++++++++ .../_data_type_functions.py | 48 +++ src/array_api_jax_compat/_dispatch.py | 7 + .../_elementwise_functions.py | 386 ++++++++++++++++++ .../_indexing_functions.py | 8 + .../_linear_algebra_functions.py | 35 ++ .../_manipulation_functions.py | 101 +++++ .../_searching_functions.py | 27 ++ src/array_api_jax_compat/_set_functions.py | 27 ++ .../_sorting_functions.py | 31 ++ .../_statistical_functions.py | 101 +++++ src/array_api_jax_compat/_types.py | 33 ++ .../_utility_functions.py | 30 ++ src/array_api_jax_compat/_utils.py | 10 + src/array_api_jax_compat/_version.pyi | 2 + src/array_api_jax_compat/fft.py | 170 ++++++++ src/array_api_jax_compat/linalg.py | 176 ++++++++ src/array_api_jax_compat/py.typed | 0 tests/test_package.py | 10 + 38 files changed, 2378 insertions(+) create mode 100644 .copier-answers.yml create mode 100644 .git_archival.txt create mode 100644 .gitattributes create mode 100644 .github/CONTRIBUTING.md create mode 100644 .github/dependabot.yml create mode 100644 .github/matchers/pylint.json create mode 100644 .github/workflows/cd.yml create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 .readthedocs.yml create mode 100644 LICENSE create mode 100644 README.md create mode 100644 docs/conf.py create mode 100644 docs/index.md create mode 100644 noxfile.py create mode 100644 pyproject.toml create mode 100644 src/array_api_jax_compat/__init__.py create mode 100644 src/array_api_jax_compat/_constants.py create mode 100644 src/array_api_jax_compat/_creation_functions.py create mode 100644 src/array_api_jax_compat/_data_type_functions.py create mode 100644 src/array_api_jax_compat/_dispatch.py create mode 100644 src/array_api_jax_compat/_elementwise_functions.py create mode 100644 src/array_api_jax_compat/_indexing_functions.py create mode 100644 src/array_api_jax_compat/_linear_algebra_functions.py create mode 100644 src/array_api_jax_compat/_manipulation_functions.py create mode 100644 src/array_api_jax_compat/_searching_functions.py create mode 100644 src/array_api_jax_compat/_set_functions.py create mode 100644 src/array_api_jax_compat/_sorting_functions.py create mode 100644 src/array_api_jax_compat/_statistical_functions.py create mode 100644 src/array_api_jax_compat/_types.py create mode 100644 src/array_api_jax_compat/_utility_functions.py create mode 100644 src/array_api_jax_compat/_utils.py create mode 100644 src/array_api_jax_compat/_version.pyi create mode 100644 src/array_api_jax_compat/fft.py create mode 100644 src/array_api_jax_compat/linalg.py create mode 100644 src/array_api_jax_compat/py.typed create mode 100644 tests/test_package.py diff --git a/.copier-answers.yml b/.copier-answers.yml new file mode 100644 index 0000000..92a123d --- /dev/null +++ b/.copier-answers.yml @@ -0,0 +1,12 @@ +# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY +_commit: 2023.11.17 +_src_path: gh:scientific-python/cookie +backend: hatch +email: nstarman@users.noreply.github.com +full_name: Nathaniel Starkman +license: BSD +org: GalacticDynamics +project_name: array-api-jax-compat +project_short_description: Array-API JAX compatibility +url: https://github.com/GalacticDynamics/array-api-jax-compat +vcs: true diff --git a/.git_archival.txt b/.git_archival.txt new file mode 100644 index 0000000..8fb235d --- /dev/null +++ b/.git_archival.txt @@ -0,0 +1,4 @@ +node: $Format:%H$ +node-date: $Format:%cI$ +describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ +ref-names: $Format:%D$ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..00a7b00 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +.git_archival.txt export-subst diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..1d05f23 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,101 @@ +See the [Scientific Python Developer Guide][spc-dev-intro] for a detailed +description of best practices for developing scientific packages. + +[spc-dev-intro]: https://learn.scientific-python.org/development/ + +# Quick development + +The fastest way to start with development is to use nox. If you don't have nox, +you can use `pipx run nox` to run it without installing, or `pipx install nox`. +If you don't have pipx (pip for applications), then you can install with +`pip install pipx` (the only case were installing an application with regular +pip is reasonable). If you use macOS, then pipx and nox are both in brew, use +`brew install pipx nox`. + +To use, run `nox`. This will lint and test using every installed version of +Python on your system, skipping ones that are not installed. You can also run +specific jobs: + +```console +$ nox -s lint # Lint only +$ nox -s tests # Python tests +$ nox -s docs -- serve # Build and serve the docs +$ nox -s build # Make an SDist and wheel +``` + +Nox handles everything for you, including setting up an temporary virtual +environment for each run. + +# Setting up a development environment manually + +You can set up a development environment by running: + +```bash +python3 -m venv .venv +source ./.venv/bin/activate +pip install -v -e .[dev] +``` + +If you have the +[Python Launcher for Unix](https://github.com/brettcannon/python-launcher), you +can instead do: + +```bash +py -m venv .venv +py -m install -v -e .[dev] +``` + +# Post setup + +You should prepare pre-commit, which will help you by checking that commits pass +required checks: + +```bash +pip install pre-commit # or brew install pre-commit on macOS +pre-commit install # Will install a pre-commit hook into the git repo +``` + +You can also/alternatively run `pre-commit run` (changes only) or +`pre-commit run --all-files` to check even without installing the hook. + +# Testing + +Use pytest to run the unit checks: + +```bash +pytest +``` + +# Coverage + +Use pytest-cov to generate coverage reports: + +```bash +pytest --cov=array-api-jax-compat +``` + +# Building docs + +You can build the docs using: + +```bash +nox -s docs +``` + +You can see a preview with: + +```bash +nox -s docs -- serve +``` + +# Pre-commit + +This project uses pre-commit for all style checking. While you can run it with +nox, this is such an important tool that it deserves to be installed on its own. +Install pre-commit and run: + +```bash +pre-commit run -a +``` + +to check all files. diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..6fddca0 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: + # Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/matchers/pylint.json b/.github/matchers/pylint.json new file mode 100644 index 0000000..e3a6bd1 --- /dev/null +++ b/.github/matchers/pylint.json @@ -0,0 +1,32 @@ +{ + "problemMatcher": [ + { + "severity": "warning", + "pattern": [ + { + "regexp": "^([^:]+):(\\d+):(\\d+): ([A-DF-Z]\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", + "file": 1, + "line": 2, + "column": 3, + "code": 4, + "message": 5 + } + ], + "owner": "pylint-warning" + }, + { + "severity": "error", + "pattern": [ + { + "regexp": "^([^:]+):(\\d+):(\\d+): (E\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", + "file": 1, + "line": 2, + "column": 3, + "code": 4, + "message": 5 + } + ], + "owner": "pylint-error" + } + ] +} diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml new file mode 100644 index 0000000..5100345 --- /dev/null +++ b/.github/workflows/cd.yml @@ -0,0 +1,60 @@ +name: CD + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + release: + types: + - published + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + FORCE_COLOR: 3 + +jobs: + dist: + name: Distribution build + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Build sdist and wheel + run: pipx run build + + - uses: actions/upload-artifact@v3 + with: + path: dist + + - name: Check products + run: pipx run twine check dist/* + + publish: + needs: [dist] + name: Publish to PyPI + environment: pypi + permissions: + id-token: write + runs-on: ubuntu-latest + if: github.event_name == 'release' && github.event.action == 'published' + + steps: + - uses: actions/download-artifact@v3 + with: + name: artifact + path: dist + + - uses: pypa/gh-action-pypi-publish@release/v1 + if: github.event_name == 'release' && github.event.action == 'published' + with: + # Remember to tell (test-)pypi about this repo before publishing + # Remove this line to publish to PyPI + repository-url: https://test.pypi.org/legacy/ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..b075b96 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,69 @@ +name: CI + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + FORCE_COLOR: 3 + +jobs: + pre-commit: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: "3.x" + - uses: pre-commit/action@v3.0.0 + with: + extra_args: --hook-stage manual --all-files + - name: Run PyLint + run: | + echo "::add-matcher::$GITHUB_WORKSPACE/.github/matchers/pylint.json" + pipx run nox -s pylint + + checks: + name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} + runs-on: ${{ matrix.runs-on }} + needs: [pre-commit] + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.12"] + runs-on: [ubuntu-latest, macos-latest, windows-latest] + + include: + - python-version: pypy-3.10 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + allow-prereleases: true + + - name: Install package + run: python -m pip install .[test] + + - name: Test package + run: >- + python -m pytest -ra --cov --cov-report=xml --cov-report=term + --durations=20 + + - name: Upload coverage report + uses: codecov/codecov-action@v3.1.4 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..25cf9a4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,158 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# setuptools_scm +src/*/_version.py + + +# ruff +.ruff_cache/ + +# OS specific stuff +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Common editor files +*~ +*.swp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..4de60c7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,91 @@ +ci: + autoupdate_commit_msg: "chore: update pre-commit hooks" + autofix_commit_msg: "style: pre-commit fixes" + +repos: + - repo: https://github.com/psf/black-pre-commit-mirror + rev: "23.11.0" + hooks: + - id: black-jupyter + + - repo: https://github.com/adamchainz/blacken-docs + rev: "1.16.0" + hooks: + - id: blacken-docs + additional_dependencies: [black==23.*] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.5.0" + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + - id: name-tests-test + args: ["--pytest-test-first"] + - id: requirements-txt-fixer + - id: trailing-whitespace + + - repo: https://github.com/pre-commit/pygrep-hooks + rev: "v1.10.0" + hooks: + - id: rst-backticks + - id: rst-directive-colons + - id: rst-inline-touching-normal + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: "v3.1.0" + hooks: + - id: prettier + types_or: [yaml, markdown, html, css, scss, javascript, json] + args: [--prose-wrap=always] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.1.5" + hooks: + - id: ruff + args: ["--fix", "--show-fixes"] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.7.0" + hooks: + - id: mypy + files: src|tests + args: [] + additional_dependencies: + - numpy + - pytest + + # - repo: https://github.com/codespell-project/codespell + # rev: "v2.2.6" + # hooks: + # - id: codespell + + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: "v0.9.0.6" + hooks: + - id: shellcheck + + - repo: local + hooks: + - id: disallow-caps + name: Disallow improper capitalization + language: pygrep + entry: PyBind|Numpy|Cmake|CCache|Github|PyTest + exclude: .pre-commit-config.yaml + + - repo: https://github.com/abravalheri/validate-pyproject + rev: v0.15 + hooks: + - id: validate-pyproject + + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.27.0 + hooks: + - id: check-dependabot + - id: check-github-workflows + - id: check-readthedocs diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000..7e49657 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,18 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" +sphinx: + configuration: docs/conf.py + +python: + install: + - method: pip + path: . + extra_requirements: + - docs diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..48f9efb --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2023, Nathaniel Starkman. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the vector package developers nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..916e778 --- /dev/null +++ b/README.md @@ -0,0 +1,27 @@ +# array-api-jax-compat + +[![Actions Status][actions-badge]][actions-link] +[![Documentation Status][rtd-badge]][rtd-link] + +[![PyPI version][pypi-version]][pypi-link] +[![Conda-Forge][conda-badge]][conda-link] +[![PyPI platforms][pypi-platforms]][pypi-link] + +[![GitHub Discussion][github-discussions-badge]][github-discussions-link] + + + + +[actions-badge]: https://github.com/GalacticDynamics/array-api-jax-compat/workflows/CI/badge.svg +[actions-link]: https://github.com/GalacticDynamics/array-api-jax-compat/actions +[conda-badge]: https://img.shields.io/conda/vn/conda-forge/array-api-jax-compat +[conda-link]: https://github.com/conda-forge/array-api-jax-compat-feedstock +[github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github +[github-discussions-link]: https://github.com/GalacticDynamics/array-api-jax-compat/discussions +[pypi-link]: https://pypi.org/project/array-api-jax-compat/ +[pypi-platforms]: https://img.shields.io/pypi/pyversions/array-api-jax-compat +[pypi-version]: https://img.shields.io/pypi/v/array-api-jax-compat +[rtd-badge]: https://readthedocs.org/projects/array-api-jax-compat/badge/?version=latest +[rtd-link]: https://array-api-jax-compat.readthedocs.io/en/latest/?badge=latest + + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..d5a7dcf --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,45 @@ +"""Sphinx configuration.""" + +import importlib.metadata + +project = "array-api-jax-compat" +copyright = "2023, Nathaniel Starkman" +author = "Nathaniel Starkman" +version = release = importlib.metadata.version("array_api_jax_compat") + +extensions = [ + "myst_parser", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx_autodoc_typehints", + "sphinx_copybutton", +] + +source_suffix = [".rst", ".md"] +exclude_patterns = [ + "_build", + "**.ipynb_checkpoints", + "Thumbs.db", + ".DS_Store", + ".env", + ".venv", +] + +html_theme = "furo" + +myst_enable_extensions = [ + "colon_fence", +] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), +} + +nitpick_ignore = [ + ("py:class", "_io.StringIO"), + ("py:class", "_io.BytesIO"), +] + +always_document_param_types = True diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..f869c98 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,17 @@ +# array-api-jax-compat + +```{toctree} +:maxdepth: 2 +:hidden: + +``` + +```{include} ../README.md +:start-after: +``` + +## Indices and tables + +- {ref}`genindex` +- {ref}`modindex` +- {ref}`search` diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 0000000..a86dec8 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,114 @@ +"""Nox setup.""" + +import argparse +import shutil +from pathlib import Path + +import nox + +DIR = Path(__file__).parent.resolve() + +nox.options.sessions = ["lint", "pylint", "tests"] + + +@nox.session +def lint(session: nox.Session) -> None: + """Run the linter.""" + session.install("pre-commit") + session.run( + "pre-commit", + "run", + "--all-files", + "--show-diff-on-failure", + *session.posargs, + ) + + +@nox.session +def pylint(session: nox.Session) -> None: + """Run PyLint.""" + # This needs to be installed into the package environment, and is slower + # than a pre-commit check + session.install(".", "pylint") + session.run("pylint", "array_api_jax_compat", *session.posargs) + + +@nox.session +def tests(session: nox.Session) -> None: + """Run the unit and regular tests.""" + session.install(".[test]") + session.run("pytest", *session.posargs) + + +@nox.session(reuse_venv=True) +def docs(session: nox.Session) -> None: + """Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.""" + parser = argparse.ArgumentParser() + parser.add_argument("--serve", action="store_true", help="Serve after building") + parser.add_argument( + "-b", + dest="builder", + default="html", + help="Build target (default: html)", + ) + args, posargs = parser.parse_known_args(session.posargs) + + if args.builder != "html" and args.serve: + session.error("Must not specify non-HTML builder with --serve") + + extra_installs = ["sphinx-autobuild"] if args.serve else [] + + session.install("-e.[docs]", *extra_installs) + session.chdir("docs") + + if args.builder == "linkcheck": + session.run( + "sphinx-build", + "-b", + "linkcheck", + ".", + "_build/linkcheck", + *posargs, + ) + return + + shared_args = ( + "-n", # nitpicky mode + "-T", # full tracebacks + f"-b={args.builder}", + ".", + f"_build/{args.builder}", + *posargs, + ) + + if args.serve: + session.run("sphinx-autobuild", *shared_args) + else: + session.run("sphinx-build", "--keep-going", *shared_args) + + +@nox.session +def build_api_docs(session: nox.Session) -> None: + """Build (regenerate) API docs.""" + session.install("sphinx") + session.chdir("docs") + session.run( + "sphinx-apidoc", + "-o", + "api/", + "--module-first", + "--no-toc", + "--force", + "../src/array_api_jax_compat", + ) + + +@nox.session +def build(session: nox.Session) -> None: + """Build an SDist and wheel.""" + build_path = DIR.joinpath("build") + if build_path.exists(): + shutil.rmtree(build_path) + + session.install("build") + session.run("python", "-m", "build") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0c914d0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,159 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + + +[project] +name = "array-api-jax-compat" +authors = [ + { name = "Nathaniel Starkman", email = "nstarman@users.noreply.github.com" }, +] +description = "Array-API JAX compatibility" +readme = "README.md" +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 1 - Planning", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Typing :: Typed", +] +dynamic = ["version"] +dependencies = [ + "numpy", + "plum-dispatch", + "quax", +] + +[project.optional-dependencies] +test = [ + "pytest >=6", + "pytest-cov >=3", +] +dev = [ + "pytest >=6", + "pytest-cov >=3", +] +docs = [ + "sphinx>=7.0", + "myst_parser>=0.13", + "sphinx_copybutton", + "sphinx_autodoc_typehints", + "furo>=2023.08.17", +] + +[project.urls] +Homepage = "https://github.com/GalacticDynamics/array-api-jax-compat" +"Bug Tracker" = "https://github.com/GalacticDynamics/array-api-jax-compat/issues" +Discussions = "https://github.com/GalacticDynamics/array-api-jax-compat/discussions" +Changelog = "https://github.com/GalacticDynamics/array-api-jax-compat/releases" + + +[tool.hatch] +version.source = "vcs" +build.hooks.vcs.version-file = "src/array_api_jax_compat/_version.py" + +[tool.hatch.env.default] +features = ["test"] +scripts.test = "pytest {args}" + + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] +xfail_strict = true +filterwarnings = [ + "error", +] +log_cli_level = "INFO" +testpaths = [ + "tests", +] + + +[tool.coverage] +run.source = ["array_api_jax_compat"] +port.exclude_lines = [ + 'pragma: no cover', + '\.\.\.', + 'if typing.TYPE_CHECKING:', +] + +[tool.mypy] +files = ["src", "tests"] +python_version = "3.10" +warn_unused_configs = true +strict = true +enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +warn_unreachable = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +warn_return_any = false +plugins = [ + "numpy.typing.mypy_plugin", +] + + [[tool.mypy.overrides]] + module = "array_api_jax_compat.*" + disallow_untyped_defs = true + disallow_incomplete_defs = true + + [[tool.mypy.overrides]] + module = [ + "jax.*", + "plum.*", + "quax.*", + ] + ignore_missing_imports = true + + + +[tool.ruff] +src = ["src"] + +[tool.ruff.lint] +extend-select = ["ALL"] +ignore = [ + "A001", # Variable is shadowing a Python builtin + "A002", # Argument is shadowing a Python builtin + "ANN101", # Missing type annotation for self in method + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed # TODO + "ARG001", # Unused function argument + "D103", # Missing docstring in public function # TODO + "D203", # one-blank-line-before-class + "D213", # Multi-line docstring summary should start at the second line + "ERA001", # Found commented-out code + "FIX002", # Line contains TODO, consider resolving the issue + "PYI041", # Use `float` instead of `int | float` + "TD002", # Missing author in TODO; try: `# TODO(): . + "TD003", # Missing issue link on the line following this TODO +] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["INP001", "S101", "T20"] +"__init__.py" = ["F403"] +"noxfile.py" = ["T20"] +"docs/conf.py" = ["INP001"] +"scratch/**" = ["ANN", "D", "FBT", "INP"] + + +[tool.pylint] +py-version = "3.10" +ignore-paths = [".*/_version.py"] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "design", + "fixme", + "line-too-long", + "missing-module-docstring", + "wrong-import-position", +] diff --git a/src/array_api_jax_compat/__init__.py b/src/array_api_jax_compat/__init__.py new file mode 100644 index 0000000..58e3db7 --- /dev/null +++ b/src/array_api_jax_compat/__init__.py @@ -0,0 +1,64 @@ +"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved. + +array-api-jax-compat: Array-API JAX compatibility +""" + + +from __future__ import annotations + +from typing import Any + +from . import ( + _constants, + _creation_functions, + _data_type_functions, + _elementwise_functions, + _indexing_functions, + _linear_algebra_functions, + _manipulation_functions, + _searching_functions, + _set_functions, + _sorting_functions, + _statistical_functions, + _utility_functions, + fft, + linalg, +) +from ._constants import * +from ._creation_functions import * +from ._data_type_functions import * +from ._elementwise_functions import * +from ._indexing_functions import * +from ._linear_algebra_functions import * +from ._manipulation_functions import * +from ._searching_functions import * +from ._set_functions import * +from ._sorting_functions import * +from ._statistical_functions import * +from ._utility_functions import * +from ._version import version as __version__ + +__all__ = ["__version__", "fft", "linalg"] +__all__ += _constants.__all__ +__all__ += _creation_functions.__all__ +__all__ += _data_type_functions.__all__ +__all__ += _elementwise_functions.__all__ +__all__ += _indexing_functions.__all__ +__all__ += _linear_algebra_functions.__all__ +__all__ += _manipulation_functions.__all__ +__all__ += _searching_functions.__all__ +__all__ += _set_functions.__all__ +__all__ += _sorting_functions.__all__ +__all__ += _statistical_functions.__all__ +__all__ += _utility_functions.__all__ + + +def __getattr__(name: str) -> Any: # TODO: fuller annotation + """Forward all other attribute accesses to Quaxified JAX.""" + import jax + from quax import quaxify + + # TODO: detect if the attribute is a function or a module. + # If it is a function, quaxify it. If it is a module, return a proxy object + # that quaxifies all of its attributes. + return quaxify(getattr(jax, name)) diff --git a/src/array_api_jax_compat/_constants.py b/src/array_api_jax_compat/_constants.py new file mode 100644 index 0000000..6fa113a --- /dev/null +++ b/src/array_api_jax_compat/_constants.py @@ -0,0 +1,5 @@ +"""JAX-compatible constants.""" + +__all__ = ["e", "inf", "nan", "newaxis", "pi"] + +from jax.numpy import e, inf, nan, newaxis, pi diff --git a/src/array_api_jax_compat/_creation_functions.py b/src/array_api_jax_compat/_creation_functions.py new file mode 100644 index 0000000..a01d457 --- /dev/null +++ b/src/array_api_jax_compat/_creation_functions.py @@ -0,0 +1,163 @@ +"""Array API creation functions.""" + +__all__ = [ + # "arange", + "asarray", + # "empty", + "empty_like", + # "eye", + # "from_dlpack", + # "full", + "full_like", + # "linspace", + "meshgrid", + # "ones", + "ones_like", + "tril", + "triu", + # "zeros", + "zeros_like", +] + + +from functools import partial +from typing import Any, TypeVar + +import jax +import jax.numpy as jnp +from jax import Device +from quax import Value + +from ._dispatch import dispatcher +from ._types import DType, NestedSequence, SupportsBufferProtocol +from ._utils import quaxify + +T = TypeVar("T") + +# ============================================================================= + + +@partial(jax.jit, static_argnames=("dtype", "device", "copy")) +@quaxify +def asarray( + obj: Value + | bool + | int + | float + | complex + | NestedSequence[Any] + | SupportsBufferProtocol, + /, + *, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, +) -> Value: + out = jnp.asarray(obj, dtype=dtype) + return jax.device_put(out, device=device) + # TODO: jax.lax.cond is not yet supported by Quax. + # out = jax.lax.cond(bool(copy), lambda x: jax.lax.copy_p.bind(x), lambda x: x, out) + + +# ============================================================================= + + +# @partial(jax.jit, static_argnames=("dtype", "device")) +# @quaxify # TODO: quaxify won't work here because of how the function is defined. +@dispatcher # type: ignore[misc] +def empty_like( + x: jax.Array | jax.core.Tracer | Value, + /, + *, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array | jax.core.Tracer | Value: + out = jnp.empty_like(x, dtype=dtype) + return jax.device_put(out, device=device) + + +# ============================================================================= + + +# @partial(jax.jit, static_argnames=("dtype", "device")) +# @quaxify # TODO: quaxify won't work here because of how the function is defined. +@dispatcher # type: ignore[misc] +def full_like( + x: jax.Array | jax.core.Tracer | Value, + /, + fill_value: bool | int | float | complex, + *, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array | jax.core.Tracer | Value: + out = jnp.full_like(x, fill_value, dtype=dtype) + return jax.device_put(out, device=device) + + +# ============================================================================= + + +@quaxify +def meshgrid(*arrays: Value, indexing: str = "xy") -> list[Value]: + return jnp.meshgrid(*arrays, indexing=indexing) + + +# ============================================================================= + + +# @partial(jax.jit, static_argnames=("dtype", "device")) +# @quaxify # TODO: quaxify won't work here because of how the function is defined. +@dispatcher # type: ignore[misc] +def ones_like( + x: jax.Array | jax.core.Tracer | Value, + /, + *, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array | jax.core.Tracer | Value: + out = jnp.ones_like(x, dtype=dtype) + return jax.device_put(out, device=device) + + +# ============================================================================= + + +# @partial(jax.jit, static_argnames=("k",)) +@quaxify +def tril(x: Value, /, *, k: int = 0) -> Value: + return jnp.tril(x, k=k) + + +# ============================================================================= + + +# @partial(jax.jit, static_argnames=("k",)) +@quaxify +def triu(x: Value, /, *, k: int = 0) -> Value: + return jnp.triu(x, k=k) + + +# ============================================================================= + + +# @partial(jax.jit, static_argnames=("dtype", "device")) +# @quaxify +@dispatcher # type: ignore[misc] +def zeros_like( + x: jax.Array | jax.core.Tracer | Value, + /, + *, + dtype: DType | None = None, + device: Device | None = None, +) -> Value | jax.core.Tracer | jax.Array: + out = jnp.zeros_like(x, dtype=dtype) + return jax.device_put(out, device=device) + + +# @dispatcher +# def zeros_like( +# x: quax.zero.Zero, /, *, dtype: DType | None = None, device: Device | None = None +# ) -> jnp.ndarray: +# out = jnp.zeros_like(x, dtype=dtype) +# out = jax.device_put(out, device=device) +# return out diff --git a/src/array_api_jax_compat/_data_type_functions.py b/src/array_api_jax_compat/_data_type_functions.py new file mode 100644 index 0000000..ccd040a --- /dev/null +++ b/src/array_api_jax_compat/_data_type_functions.py @@ -0,0 +1,48 @@ +__all__ = ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"] + + +import jax +from jax import Device +from quax import Value + +from ._types import DType +from ._utils import quaxify + + +@quaxify +def astype( + x: Value, + dtype: DType, + /, + *, + copy: bool = True, + device: Device | None = None, +) -> Value: + # TODO: copy is not yet supported + out = jax.lax.convert_element_type(x, dtype) + return jax.device_put(out, device=device) + + +@quaxify +def can_cast(from_: DType | Value, to: DType, /) -> bool: + return jax.numpy.can_cast(from_, to) + + +@quaxify +def finfo(type: DType | Value, /) -> jax.numpy.finfo: + return jax.numpy.finfo(type) + + +@quaxify +def iinfo(type: DType | Value, /) -> jax.numpy.iinfo: + return jax.numpy.iinfo(type) + + +@quaxify +def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool: + raise NotImplementedError + + +@quaxify +def result_type(*arrays_and_dtypes: Value | DType) -> DType: + return jax.numpy.result_type(*arrays_and_dtypes) diff --git a/src/array_api_jax_compat/_dispatch.py b/src/array_api_jax_compat/_dispatch.py new file mode 100644 index 0000000..ebde228 --- /dev/null +++ b/src/array_api_jax_compat/_dispatch.py @@ -0,0 +1,7 @@ +"""Dispatching.""" + +__all__: list[str] = [] + +import plum + +dispatcher = plum.Dispatcher() diff --git a/src/array_api_jax_compat/_elementwise_functions.py b/src/array_api_jax_compat/_elementwise_functions.py new file mode 100644 index 0000000..d7e2e04 --- /dev/null +++ b/src/array_api_jax_compat/_elementwise_functions.py @@ -0,0 +1,386 @@ +__all__ = [ + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "conj", + "copysign", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "multiply", + "negative", + "not_equal", + "positive", + "pow", + "real", + "remainder", + "round", + "sign", + "signbit", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", +] + + +import jax.numpy as jnp +from quax import Value + +from ._utils import quaxify + + +@quaxify +def abs(x: Value, /) -> Value: + return jnp.abs(x) + + +@quaxify +def acos(x: Value, /) -> Value: + return jnp.arccos(x) + + +@quaxify +def acosh(x: Value, /) -> Value: + return jnp.arccosh(x) + + +@quaxify +def add(x1: Value, x2: Value, /) -> Value: + return jnp.add(x1, x2) + + +@quaxify +def asin(x: Value, /) -> Value: + return jnp.arcsin(x) + + +@quaxify +def asinh(x: Value, /) -> Value: + return jnp.arcsinh(x) + + +@quaxify +def atan(x: Value, /) -> Value: + return jnp.arctan(x) + + +@quaxify +def atan2(x1: Value, x2: Value, /) -> Value: + return jnp.arctan2(x1, x2) + + +@quaxify +def atanh(x: Value, /) -> Value: + return jnp.arctanh(x) + + +@quaxify +def bitwise_and(x1: Value, x2: Value, /) -> Value: + return jnp.bitwise_and(x1, x2) + + +@quaxify +def bitwise_left_shift(x1: Value, x2: Value, /) -> Value: + return jnp.left_shift(x1, x2) + + +@quaxify +def bitwise_invert(x: Value, /) -> Value: + return jnp.bitwise_not(x) + + +@quaxify +def bitwise_or(x1: Value, x2: Value, /) -> Value: + return jnp.bitwise_or(x1, x2) + + +@quaxify +def bitwise_right_shift(x1: Value, x2: Value, /) -> Value: + return jnp.right_shift(x1, x2) + + +@quaxify +def bitwise_xor(x1: Value, x2: Value, /) -> Value: + return jnp.bitwise_xor(x1, x2) + + +@quaxify +def ceil(x: Value, /) -> Value: + return jnp.ceil(x) + + +@quaxify +def conj(x: Value, /) -> Value: + return jnp.conj(x) + + +@quaxify +def copysign(x1: Value, x2: Value, /) -> Value: + return jnp.copysign(x1, x2) + + +@quaxify +def cos(x: Value, /) -> Value: + return jnp.cos(x) + + +@quaxify +def cosh(x: Value, /) -> Value: + return jnp.cosh(x) + + +@quaxify +def divide(x1: Value, x2: Value, /) -> Value: + return jnp.divide(x1, x2) + + +@quaxify +def equal(x1: Value, x2: Value, /) -> Value: + return jnp.equal(x1, x2) + + +@quaxify +def exp(x: Value, /) -> Value: + return jnp.exp(x) + + +@quaxify +def expm1(x: Value, /) -> Value: + return jnp.expm1(x) + + +@quaxify +def floor(x: Value, /) -> Value: + return jnp.floor(x) + + +@quaxify +def floor_divide(x1: Value, x2: Value, /) -> Value: + return jnp.floor_divide(x1, x2) + + +@quaxify +def greater(x1: Value, x2: Value, /) -> Value: + return jnp.greater(x1, x2) + + +@quaxify +def greater_equal(x1: Value, x2: Value, /) -> Value: + return jnp.greater_equal(x1, x2) + + +@quaxify +def imag(x: Value, /) -> Value: + return jnp.imag(x) + + +@quaxify +def isfinite(x: Value, /) -> Value: + return jnp.isfinite(x) + + +@quaxify +def isinf(x: Value, /) -> Value: + return jnp.isinf(x) + + +@quaxify +def isnan(x: Value, /) -> Value: + return jnp.isnan(x) + + +@quaxify +def less(x1: Value, x2: Value, /) -> Value: + return jnp.less(x1, x2) + + +@quaxify +def less_equal(x1: Value, x2: Value, /) -> Value: + return jnp.less_equal(x1, x2) + + +@quaxify +def log(x: Value, /) -> Value: + return jnp.log(x) + + +@quaxify +def log1p(x: Value, /) -> Value: + return jnp.log1p(x) + + +@quaxify +def log2(x: Value, /) -> Value: + return jnp.log2(x) + + +@quaxify +def log10(x: Value, /) -> Value: + return jnp.log10(x) + + +@quaxify +def logaddexp(x1: Value, x2: Value, /) -> Value: + return jnp.logaddexp(x1, x2) + + +@quaxify +def logical_and(x1: Value, x2: Value, /) -> Value: + return jnp.logical_and(x1, x2) + + +@quaxify +def logical_not(x: Value, /) -> Value: + return jnp.logical_not(x) + + +@quaxify +def logical_or(x1: Value, x2: Value, /) -> Value: + return jnp.logical_or(x1, x2) + + +@quaxify +def logical_xor(x1: Value, x2: Value, /) -> Value: + return jnp.logical_xor(x1, x2) + + +@quaxify +def maximum(x1: Value, x2: Value, /) -> Value: + return jnp.maximum(x1, x2) + + +@quaxify +def minimum(x1: Value, x2: Value, /) -> Value: + return jnp.minimum(x1, x2) + + +@quaxify +def multiply(x1: Value, x2: Value, /) -> Value: + return jnp.multiply(x1, x2) + + +@quaxify +def negative(x: Value, /) -> Value: + return jnp.negative(x) + + +@quaxify +def not_equal(x1: Value, x2: Value, /) -> Value: + return jnp.not_equal(x1, x2) + + +@quaxify +def positive(x: Value, /) -> Value: + return jnp.positive(x) + + +@quaxify +def pow(x1: Value, x2: Value, /) -> Value: + return jnp.power(x1, x2) + + +@quaxify +def real(x: Value, /) -> Value: + return jnp.real(x) + + +@quaxify +def remainder(x1: Value, x2: Value, /) -> Value: + return jnp.remainder(x1, x2) + + +@quaxify +def round(x: Value, /) -> Value: + return jnp.round(x) + + +@quaxify +def sign(x: Value, /) -> Value: + return jnp.sign(x) + + +@quaxify +def signbit(x: Value, /) -> Value: + return jnp.signbit(x) + + +@quaxify +def sin(x: Value, /) -> Value: + return jnp.sin(x) + + +@quaxify +def sinh(x: Value, /) -> Value: + return jnp.sinh(x) + + +@quaxify +def square(x: Value, /) -> Value: + return jnp.square(x) + + +@quaxify +def sqrt(x: Value, /) -> Value: + return jnp.sqrt(x) + + +@quaxify +def subtract(x1: Value, x2: Value, /) -> Value: + return jnp.subtract(x1, x2) + + +@quaxify +def tan(x: Value, /) -> Value: + return jnp.tan(x) + + +@quaxify +def tanh(x: Value, /) -> Value: + return jnp.tanh(x) + + +@quaxify +def trunc(x: Value, /) -> Value: + return jnp.trunc(x) diff --git a/src/array_api_jax_compat/_indexing_functions.py b/src/array_api_jax_compat/_indexing_functions.py new file mode 100644 index 0000000..f05b37d --- /dev/null +++ b/src/array_api_jax_compat/_indexing_functions.py @@ -0,0 +1,8 @@ +__all__ = ["take"] + +import jax.numpy as jnp +from quax import Value + + +def take(x: Value, indices: Value, /, *, axis: int | None = None) -> Value: + return jnp.take(x, indices, axis=axis) diff --git a/src/array_api_jax_compat/_linear_algebra_functions.py b/src/array_api_jax_compat/_linear_algebra_functions.py new file mode 100644 index 0000000..d081399 --- /dev/null +++ b/src/array_api_jax_compat/_linear_algebra_functions.py @@ -0,0 +1,35 @@ +__all__ = ["matmul", "matrix_transpose", "tensordot", "vecdot"] + + +from collections.abc import Sequence + +import jax.numpy as jnp +from quax import Value + +from ._utils import quaxify + + +@quaxify +def matmul(x1: Value, x2: Value, /) -> Value: + return jnp.matmul(x1, x2) + + +@quaxify +def matrix_transpose(x: Value, /) -> Value: + return jnp.transpose(x) + + +@quaxify +def tensordot( + x1: Value, + x2: Value, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, +) -> Value: + return jnp.tensordot(x1, x2, axes=axes) + + +@quaxify +def vecdot(x1: Value, x2: Value, /, *, axis: int = -1) -> Value: + return jnp.dot(x1, x2, axis=axis) diff --git a/src/array_api_jax_compat/_manipulation_functions.py b/src/array_api_jax_compat/_manipulation_functions.py new file mode 100644 index 0000000..5d5c9e7 --- /dev/null +++ b/src/array_api_jax_compat/_manipulation_functions.py @@ -0,0 +1,101 @@ +__all__ = [ + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "moveaxis", + "permute_dims", + "reshape", + "roll", + "squeeze", + "stack", + "tile", + "unstack", +] + +import jax.numpy as jnp +from quax import Value + +from ._utils import quaxify + + +@quaxify +def broadcast_arrays(*arrays: Value) -> list[Value]: + return jnp.broadcast_arrays(*arrays) + + +@quaxify +def broadcast_to(x: Value, /, shape: tuple[int, ...]) -> Value: + return jnp.broadcast_to(x, shape) + + +@quaxify +def concat( + arrays: tuple[Value, ...] | list[Value], + /, + *, + axis: int | None = 0, +) -> Value: + return jnp.concatenate(arrays, axis=axis) + + +@quaxify +def expand_dims(x: Value, /, *, axis: int = 0) -> Value: + return jnp.expand_dims(x, axis=axis) + + +@quaxify +def flip(x: Value, /, *, axis: int | tuple[int, ...] | None = None) -> Value: + return jnp.flip(x, axis=axis) + + +@quaxify +def moveaxis( + x: Value, + source: int | tuple[int, ...], + destination: int | tuple[int, ...], + /, +) -> Value: + return jnp.moveaxis(x, source, destination) + + +@quaxify +def permute_dims(x: Value, /, axes: tuple[int, ...]) -> Value: + return jnp.transpose(x, axes) + + +@quaxify +def reshape(x: Value, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Value: + return jnp.reshape(x, shape, order="C" if copy else "K") + + +@quaxify +def roll( + x: Value, + /, + shift: int | tuple[int, ...], + *, + axis: int | tuple[int, ...] | None = None, +) -> Value: + return jnp.roll(x, shift, axis=axis) + + +@quaxify +def squeeze(x: Value, /, axis: int | tuple[int, ...]) -> Value: + return jnp.squeeze(x, axis=axis) + + +@quaxify +def stack(arrays: tuple[Value, ...] | list[Value], /, *, axis: int = 0) -> Value: + return jnp.stack(arrays, axis=axis) + + +@quaxify +def tile(x: Value, repetitions: tuple[int, ...], /) -> Value: + return jnp.tile(x, repetitions) + + +@quaxify +def unstack(x: Value, /, *, axis: int = 0) -> tuple[Value, ...]: + return jnp.split(x, axis=axis) diff --git a/src/array_api_jax_compat/_searching_functions.py b/src/array_api_jax_compat/_searching_functions.py new file mode 100644 index 0000000..ff26aea --- /dev/null +++ b/src/array_api_jax_compat/_searching_functions.py @@ -0,0 +1,27 @@ +__all__ = ["argmax", "argmin", "nonzero", "where"] + + +import jax.numpy as jnp +from quax import Value + +from ._utils import quaxify + + +@quaxify +def argmax(x: Value, /, *, axis: int | None = None, keepdims: bool = False) -> Value: + return jnp.argmax(x, axis=axis, keepdims=keepdims) + + +@quaxify +def argmin(x: Value, /, *, axis: int | None = None, keepdims: bool = False) -> Value: + return jnp.argmin(x, axis=axis, keepdims=keepdims) + + +@quaxify +def nonzero(x: Value, /) -> tuple[Value, ...]: + return jnp.nonzero(x) + + +@quaxify +def where(condition: Value, x1: Value, x2: Value, /) -> Value: + return jnp.where(condition, x1, x2) diff --git a/src/array_api_jax_compat/_set_functions.py b/src/array_api_jax_compat/_set_functions.py new file mode 100644 index 0000000..8e46b08 --- /dev/null +++ b/src/array_api_jax_compat/_set_functions.py @@ -0,0 +1,27 @@ +__all__ = ["unique_all", "unique_counts", "unique_inverse", "unique_values"] + + +import jax.numpy as jnp +from quax import Value + +from ._utils import quaxify + + +@quaxify +def unique_all(x: Value, /) -> tuple[Value, Value, Value, Value]: + return jnp.unique(x, return_counts=True, return_index=True, return_inverse=True) + + +@quaxify +def unique_counts(x: Value, /) -> tuple[Value, Value]: + return jnp.unique(x, return_counts=True) + + +@quaxify +def unique_inverse(x: Value, /) -> tuple[Value, Value]: + return jnp.unique(x, return_inverse=True) + + +@quaxify +def unique_values(x: Value, /) -> Value: + return jnp.unique(x) diff --git a/src/array_api_jax_compat/_sorting_functions.py b/src/array_api_jax_compat/_sorting_functions.py new file mode 100644 index 0000000..cc2d69b --- /dev/null +++ b/src/array_api_jax_compat/_sorting_functions.py @@ -0,0 +1,31 @@ +__all__ = ["argsort", "sort"] + + +import jax.numpy as jnp +from quax import Value + +from ._utils import quaxify + + +@quaxify +def argsort( + x: Value, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, +) -> Value: + return jnp.argsort(x, axis=axis, descending=descending, stable=stable) + + +@quaxify +def sort( + x: Value, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, +) -> Value: + return jnp.sort(x, axis=axis, descending=descending, stable=stable) diff --git a/src/array_api_jax_compat/_statistical_functions.py b/src/array_api_jax_compat/_statistical_functions.py new file mode 100644 index 0000000..982d25f --- /dev/null +++ b/src/array_api_jax_compat/_statistical_functions.py @@ -0,0 +1,101 @@ +__all__ = ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] + + +import jax.numpy as jnp +from quax import Value + +from ._types import DType +from ._utils import quaxify + + +@quaxify +def cumulative_sum( + x: Value, + /, + *, + axis: int | None = None, + dtype: DType | None = None, + include_initial: bool = False, +) -> Value: + return jnp.cumsum(x, axis=axis, dtype=dtype) + + +@quaxify +def max( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Value: + return jnp.max(x, axis=axis, keepdims=keepdims) + + +@quaxify +def mean( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Value: + return jnp.mean(x, axis=axis, keepdims=keepdims) + + +@quaxify +def min( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Value: + return jnp.min(x, axis=axis, keepdims=keepdims) + + +@quaxify +def prod( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, + keepdims: bool = False, +) -> Value: + return jnp.prod(x, axis=axis, dtype=dtype, keepdims=keepdims) + + +@quaxify +def std( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + correction: int | float = 0.0, + keepdims: bool = False, +) -> Value: + return jnp.std(x, axis=axis, correction=correction, keepdims=keepdims) + + +@quaxify +def sum( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, + keepdims: bool = False, +) -> Value: + return jnp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) + + +@quaxify +def var( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + correction: int | float = 0.0, + keepdims: bool = False, +) -> Value: + return jnp.var(x, axis=axis, correction=correction, keepdims=keepdims) diff --git a/src/array_api_jax_compat/_types.py b/src/array_api_jax_compat/_types.py new file mode 100644 index 0000000..1dd431c --- /dev/null +++ b/src/array_api_jax_compat/_types.py @@ -0,0 +1,33 @@ +"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved. + +array-api-jax-compat: Array-API JAX compatibility +""" + + +__all__: list[str] = [] + +from typing import Any, Protocol, TypeVar, runtime_checkable + +import numpy as np + + +@runtime_checkable +class DType(Protocol): + """The dtype of an array.""" + + dtype: np.dtype[Any] + + +class SupportsBufferProtocol(Protocol): + ... # TODO: add whatever defines the buffer protocol support + + +_T_co = TypeVar("_T_co", covariant=True) + + +class NestedSequence(Protocol[_T_co]): + def __getitem__(self, key: int, /) -> "_T_co | NestedSequence[_T_co]": + ... + + def __len__(self, /) -> int: + ... diff --git a/src/array_api_jax_compat/_utility_functions.py b/src/array_api_jax_compat/_utility_functions.py new file mode 100644 index 0000000..8c70fbf --- /dev/null +++ b/src/array_api_jax_compat/_utility_functions.py @@ -0,0 +1,30 @@ +"""Utility functions.""" + +__all__ = ["all", "any"] + +import jax.numpy as jnp +from quax import Value + +from ._utils import quaxify + + +@quaxify +def all( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Value: + return jnp.all(x, axis=axis, keepdims=keepdims) + + +@quaxify +def any( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Value: + return jnp.any(x, axis=axis, keepdims=keepdims) diff --git a/src/array_api_jax_compat/_utils.py b/src/array_api_jax_compat/_utils.py new file mode 100644 index 0000000..747e248 --- /dev/null +++ b/src/array_api_jax_compat/_utils.py @@ -0,0 +1,10 @@ +from typing import TypeVar + +import quax + +T = TypeVar("T") + + +def quaxify(func: T) -> T: + """Quaxify, but makes mypy happy.""" + return quax.quaxify(func) diff --git a/src/array_api_jax_compat/_version.pyi b/src/array_api_jax_compat/_version.pyi new file mode 100644 index 0000000..5bb2b22 --- /dev/null +++ b/src/array_api_jax_compat/_version.pyi @@ -0,0 +1,2 @@ +version: str +version_tuple: tuple[int, int, int] | tuple[int, int, int, str, str] diff --git a/src/array_api_jax_compat/fft.py b/src/array_api_jax_compat/fft.py new file mode 100644 index 0000000..2ab496e --- /dev/null +++ b/src/array_api_jax_compat/fft.py @@ -0,0 +1,170 @@ +"""FFT functions.""" + +__all__ = [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", +] + +from collections.abc import Sequence +from typing import Literal + +import jax +import jax.numpy as jnp +from jax import Device +from quax import Value + +from ._utils import quaxify + + +@quaxify +def fft( + x: Value, + /, + *, + n: int | None = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.fft(x, n=n, axis=axis, norm=norm) + + +@quaxify +def ifft( + x: Value, + /, + *, + n: int | None = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.ifft(x, n=n, axis=axis, norm=norm) + + +@quaxify +def fftn( + x: Value, + /, + *, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.fftn(x, s=s, axes=axes, norm=norm) + + +@quaxify +def ifftn( + x: Value, + /, + *, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.ifftn(x, s=s, axes=axes, norm=norm) + + +@quaxify +def rfft( + x: Value, + /, + *, + n: int | None = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.rfft(x, n=n, axis=axis, norm=norm) + + +@quaxify +def irfft( + x: Value, + /, + *, + n: int | None = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.irfft(x, n=n, axis=axis, norm=norm) + + +@quaxify +def rfftn( + x: Value, + /, + *, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.rfftn(x, s=s, axes=axes, norm=norm) + + +@quaxify +def irfftn( + x: Value, + /, + *, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.irfftn(x, s=s, axes=axes, norm=norm) + + +@quaxify +def hfft( + x: Value, + /, + *, + n: int | None = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.hfft(x, n=n, axis=axis, norm=norm) + + +@quaxify +def ihfft( + x: Value, + /, + *, + n: int | None = None, + axis: int = -1, + norm: Literal["backward", "ortho", "forward"] = "backward", +) -> Value: + return jnp.fft.ihfft(x, n=n, axis=axis, norm=norm) + + +@quaxify +def fftfreq(n: int, /, *, d: float = 1.0, device: Device | None = None) -> Value: + out = jnp.fft.fftfreq(n, d=d) + return jax.device_put(out, device=device) + + +@quaxify +def rfftfreq(n: int, /, *, d: float = 1.0, device: Device | None = None) -> Value: + out = jnp.fft.rfftfreq(n, d=d) + return jax.device_put(out, device=device) + + +@quaxify +def fftshift(x: Value, /, *, axes: int | Sequence[int] | None = None) -> Value: + return jnp.fft.fftshift(x, axes=axes) + + +@quaxify +def ifftshift(x: Value, /, *, axes: int | Sequence[int] | None = None) -> Value: + return jnp.fft.ifftshift(x, axes=axes) diff --git a/src/array_api_jax_compat/linalg.py b/src/array_api_jax_compat/linalg.py new file mode 100644 index 0000000..8487260 --- /dev/null +++ b/src/array_api_jax_compat/linalg.py @@ -0,0 +1,176 @@ +"""Linear algebra functions.""" + +__all__ = [ + "cholesky", + "cross", + "det", + "diagonal", + "eigh", + "eigvalsh", + "inv", + "matmul", + "matrix_norm", + "matrix_power", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "solve", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", +] + + +from collections.abc import Sequence +from typing import Literal + +import jax.numpy as jnp +from quax import Value + +from ._types import DType +from ._utils import quaxify + + +@quaxify +def cholesky(x: Value, /, *, upper: bool = False) -> Value: + return jnp.linalg.cholesky(x, upper=upper) + + +@quaxify +def cross(x1: Value, x2: Value, /, *, axis: int = -1) -> Value: + return jnp.cross(x1, x2, axis=axis) + + +@quaxify +def det(x: Value, /) -> Value: + return jnp.linalg.det(x) + + +@quaxify +def diagonal(x: Value, /, *, offset: int = 0) -> Value: + return jnp.diagonal(x, offset=offset) + + +@quaxify +def eigh(x: Value, /) -> tuple[Value]: + return jnp.linalg.eigh(x) + + +@quaxify +def eigvalsh(x: Value, /) -> Value: + return jnp.linalg.eigvalsh(x) + + +@quaxify +def inv(x: Value, /) -> Value: + return jnp.linalg.inv(x) + + +@quaxify +def matmul(x1: Value, x2: Value, /) -> Value: + return jnp.matmul(x1, x2) + + +@quaxify +def matrix_norm( + x: Value, + /, + *, + keepdims: bool = False, + ord: int | float | Literal["fro", "nuc"] | None = "fro", +) -> Value: + return jnp.linalg.norm(x, keepdims=keepdims, ord=ord) + + +@quaxify +def matrix_power(x: Value, n: int, /) -> Value: + return jnp.linalg.matrix_power(x, n) + + +@quaxify +def matrix_rank(x: Value, /, *, rtol: float | Value | None = None) -> Value: + return jnp.linalg.matrix_rank(x, rtol=rtol) + + +@quaxify +def matrix_transpose(x: Value, /) -> Value: + return jnp.transpose(x) + + +@quaxify +def outer(x1: Value, x2: Value, /) -> Value: + return jnp.outer(x1, x2) + + +@quaxify +def pinv(x: Value, /, *, rtol: float | Value | None = None) -> Value: + return jnp.linalg.pinv(x, rtol=rtol) + + +@quaxify +def qr( + x: Value, + /, + *, + mode: Literal["reduced", "complete"] = "reduced", +) -> tuple[Value, Value]: + return jnp.linalg.qr(x, mode=mode) + + +@quaxify +def slogdet(x: Value, /) -> tuple[Value, Value]: + return jnp.linalg.slogdet(x) + + +@quaxify +def solve(x1: Value, x2: Value, /) -> Value: + return jnp.linalg.solve(x1, x2) + + +@quaxify +def svd(x: Value, /, *, full_matrices: bool = True) -> tuple[Value, Value, Value]: + return jnp.linalg.svd(x, full_matrices=full_matrices) + + +@quaxify +def svdvals(x: Value, /) -> Value: + return jnp.linalg.svdvals(x) + + +@quaxify +def tensordot( + x1: Value, + x2: Value, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, +) -> Value: + return jnp.tensordot(x1, x2, axes=axes) + + +@quaxify +def trace(x: Value, /, *, offset: int = 0, dtype: DType | None = None) -> Value: + return jnp.trace(x, offset=offset, dtype=dtype) + + +@quaxify +def vecdot(x1: Value, x2: Value, /, *, axis: int | None = None) -> Value: + return jnp.dot(x1, x2, axis=axis) + + +@quaxify +def vector_norm( + x: Value, + /, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, + ord: int | float = 2, +) -> Value: + return jnp.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord) diff --git a/src/array_api_jax_compat/py.typed b/src/array_api_jax_compat/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 0000000..9842c3e --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,10 @@ +"""Test the package itself.""" + + +import importlib.metadata + +import array_api_jax_compat as m + + +def test_version() -> None: + assert importlib.metadata.version("array_api_jax_compat") == m.__version__