Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More tests #31

Merged
merged 5 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -41,7 +41,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -61,7 +61,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
38 changes: 5 additions & 33 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,6 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace

- repo: https://github.com/asottile/pyupgrade
rev: v2.29.1
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
args: ["--config=pyproject.toml"]
files: "(surjectors|examples)"

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--settings-path=pyproject.toml"]
files: "(surjectors|examples)"

- repo: https://github.com/pycqa/bandit
rev: 1.7.1
hooks:
Expand All @@ -43,24 +23,16 @@ repos:
additional_dependencies: ["toml"]
files: "(surjectors|examples)"

- repo: https://github.com/PyCQA/flake8
rev: 5.0.1
hooks:
- id: flake8
additional_dependencies: [
flake8-typing-imports==1.14.0,
flake8-pyproject==1.1.0.post0
]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910-1
hooks:
- id: mypy
args: ["--ignore-missing-imports"]
files: "(surjectors|examples)"

- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: pydocstyle
additional_dependencies: ["toml"]
- id: ruff
args: [ --fix ]
- id: ruff-format
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,12 @@

Surjectors is a light-weight library for density estimation using
inference and generative surjective normalizing flows, i.e., flows can that reduce or increase dimensionality.
Surjectors builds on Distrax and Haiku and is fully compatible with both of them.

Surjectors makes use of

- Haiku`s module system for neural networks,
- Distrax for probability distributions and some base bijectors,
- Optax for gradient-based optimization,
- JAX for autodiff and XLA computation.
- [Haiku](https://github.com/deepmind/dm-haiku)`s module system for neural networks,
- [Distrax](https://github.com/deepmind/distrax) for probability distributions and some base bijectors,
- [Optax](https://github.com/deepmind/optax) for gradient-based optimization,
- [JAX](https://github.com/google/jax) for autodiff and XLA computation.

## Examples

Expand Down
12 changes: 12 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
"examples/*py"
]

autodoc_typehints = "both"

html_theme = "sphinx_book_theme"

html_theme_options = {
Expand All @@ -52,3 +54,13 @@
}

html_title = "Surjectors 🚀"


def skip(app, what, name, obj, would_skip, options):
if name == "__init__":
return True
return would_skip


def setup(app):
app.connect("autodoc-skip-member", skip)
25 changes: 21 additions & 4 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ Surjectors builds on Distrax and Haiku and is fully compatible with both of them

Surjectors makes use of

- Haiku`s module system for neural networks,
- Distrax for probability distributions and some base bijectors,
- Optax for gradient-based optimization,
- JAX for autodiff and XLA computation.
- `Haiku's <https://github.com/deepmind/dm-haiku>`_ module system for neural networks,
- `Distrax <https://github.com/deepmind/distrax>`_ for probability distributions and some base bijectors,
- `Optax <https://github.com/deepmind/optax>`_ for gradient-based optimization,
- `JAX <https://github.com/google/jax>`_ for autodiff and XLA computation.

Example
-------
Expand Down Expand Up @@ -90,6 +90,23 @@ In order to contribute:
4) test it by calling :code:`hatch run test` on the (Unix) command line,
5) submit a PR 🙂

Citing Surjectors
-----------------

.. code-block:: latex

@article{dirmeier2024surjectors,
author = {Simon Dirmeier},
title = {Surjectors: surjection layers for density estimation with normalizing flows},
year = {2024},
journal = {Journal of Open Source Software},
publisher = {The Open Journal},
volume = {9},
number = {94},
pages = {6188},
doi = {10.21105/joss.06188}
}

License
-------

Expand Down
3 changes: 2 additions & 1 deletion docs/news.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ Latest news on the development of `Surjectors`.

.. note::

No news so far :).
- 29.02.2024. Surjectors has been accepted for publication in the Journal of Open Source Software 🎉🎉🎉.
Find the paper here `10.21105/joss.06188 <https://joss.theoj.org/papers/10.21105/joss.06188>`_.
45 changes: 30 additions & 15 deletions docs/surjectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ Hence, every normalizing flow can be composed by defining these three components
>>> flow = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
>>> pushforward = TransformedDistribution(base_distribution, flow)

Regardless of how the chain of transformations (called :code:`flow` above) is defined,
each pushforward has access to four methods :code:`sample`, :code:`sample_and_log_prob`:code:`log_prob`, and :code:`inverse_and_log_prob`.

The exact method declarations can be found in the API below.

General
-------
Expand All @@ -41,39 +45,51 @@ TransformedDistribution
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: TransformedDistribution
:members:
:members: log_prob, sample, inverse_and_log_prob, sample_and_log_prob

Chain
~~~~~

.. autoclass:: Chain
:members: __init__
:members:

Bijective layers
----------------

.. autosummary::
MaskedAutoregressive
AffineMaskedAutoregressive
MaskedCoupling
AffineMaskedCoupling
RationalQuadraticSplineMaskedCoupling
Permutation

Autoregressive bijections
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedAutoregressive
:members: __init__
:members:

.. autoclass:: AffineMaskedAutoregressive
:members:

Coupling bijections
~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedCoupling
:members: __init__
:members:

.. autoclass:: AffineMaskedCoupling
:members:

.. autoclass:: RationalQuadraticSplineMaskedCoupling
:members:

Other bijections
~~~~~~~~~~~~~~~~

.. autoclass:: Permutation
:members: __init__
:members:

Inference surjection layers
---------------------------
Expand All @@ -96,35 +112,34 @@ Coupling inference surjections
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedCouplingInferenceFunnel
:members: __init__

:members:

.. autoclass:: AffineMaskedCouplingInferenceFunnel
:members: __init__
:members:

.. autoclass:: RationalQuadraticSplineMaskedCouplingInferenceFunnel
:members: __init__
:members:

Autoregressive inference surjections
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedAutoregressiveInferenceFunnel
:members: __init__
:members:

.. autoclass:: AffineMaskedAutoregressiveInferenceFunnel
:members: __init__
:members:

.. autoclass:: RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel
:members: __init__
:members:

Other inference surjections
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: LULinear
:members: __init__
:members:

.. autoclass:: MLPInferenceFunnel
:members: __init__
:members:

.. autoclass:: Slice
:members: __init__
:members:
54 changes: 13 additions & 41 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "Apache-2.0"
homepage = "https://github.com/dirmeier/surjectors"
keywords = ["normalizing flows", "surjections", "density estimation"]
classifiers = [
"Development Status :: 1 - Planning",
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
Expand Down Expand Up @@ -49,55 +49,27 @@ dependencies = [

[tool.hatch.envs.test]
dependencies = [
"pylint>=2.15.10",
"ruff>=0.3.0",
"pytest>=7.2.0",
"pytest-cov>=4.0.0"
]

[tool.hatch.envs.test.scripts]
lint = 'pylint surjectors'
lint = 'ruff check surjectors'
test = 'pytest -v --cov=./surjectors --cov-report=xml surjectors'

[tool.black]
line-length = 80
extend-ignore = "E203"
target-version = ['py310']
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

[tool.bandit]
skips = ["B101"]

[tool.isort]
profile = "black"
line_length = 80
include_trailing_comma = true
[tool.ruff]
line-length = 80
exclude = ["*_test.py", "docs/**", "examples/**"]

[tool.flake8]
max-line-length = 80
extend-ignore = ["E203", "W503"]
per-file-ignores = [
'__init__.py:F401',
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F"]
extend-select = [
"UP", "D", "I", "PL", "S"
]

[tool.pylint.messages_control]
disable = """
invalid-name,missing-module-docstring,R0801
"""

[tool.bandit]
skips = ["B101"]

[tool.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention= 'google'
match = '^surjectors/.*/((?!_test).)*\.py'
Loading
Loading