Skip to content

Commit

Permalink
Move to ruff instead of alternatives (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier authored Feb 29, 2024
1 parent 8a40cbd commit 68c2f8a
Show file tree
Hide file tree
Showing 37 changed files with 611 additions and 332 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -41,7 +41,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -61,7 +61,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
38 changes: 5 additions & 33 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,6 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace

- repo: https://github.com/asottile/pyupgrade
rev: v2.29.1
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
args: ["--config=pyproject.toml"]
files: "(surjectors|examples)"

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--settings-path=pyproject.toml"]
files: "(surjectors|examples)"

- repo: https://github.com/pycqa/bandit
rev: 1.7.1
hooks:
Expand All @@ -43,24 +23,16 @@ repos:
additional_dependencies: ["toml"]
files: "(surjectors|examples)"

- repo: https://github.com/PyCQA/flake8
rev: 5.0.1
hooks:
- id: flake8
additional_dependencies: [
flake8-typing-imports==1.14.0,
flake8-pyproject==1.1.0.post0
]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910-1
hooks:
- id: mypy
args: ["--ignore-missing-imports"]
files: "(surjectors|examples)"

- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: pydocstyle
additional_dependencies: ["toml"]
- id: ruff
args: [ --fix ]
- id: ruff-format
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,12 @@

Surjectors is a light-weight library for density estimation using
inference and generative surjective normalizing flows, i.e., flows can that reduce or increase dimensionality.
Surjectors builds on Distrax and Haiku and is fully compatible with both of them.

Surjectors makes use of

- Haiku`s module system for neural networks,
- Distrax for probability distributions and some base bijectors,
- Optax for gradient-based optimization,
- JAX for autodiff and XLA computation.
- [Haiku](https://github.com/deepmind/dm-haiku)`s module system for neural networks,
- [Distrax](https://github.com/deepmind/distrax) for probability distributions and some base bijectors,
- [Optax](https://github.com/deepmind/optax) for gradient-based optimization,
- [JAX](https://github.com/google/jax) for autodiff and XLA computation.

## Examples

Expand Down
12 changes: 12 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
"examples/*py"
]

autodoc_typehints = "both"

html_theme = "sphinx_book_theme"

html_theme_options = {
Expand All @@ -52,3 +54,13 @@
}

html_title = "Surjectors 🚀"


def skip(app, what, name, obj, would_skip, options):
if name == "__init__":
return True
return would_skip


def setup(app):
app.connect("autodoc-skip-member", skip)
25 changes: 21 additions & 4 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ Surjectors builds on Distrax and Haiku and is fully compatible with both of them

Surjectors makes use of

- Haiku`s module system for neural networks,
- Distrax for probability distributions and some base bijectors,
- Optax for gradient-based optimization,
- JAX for autodiff and XLA computation.
- `Haiku's <https://github.com/deepmind/dm-haiku>`_ module system for neural networks,
- `Distrax <https://github.com/deepmind/distrax>`_ for probability distributions and some base bijectors,
- `Optax <https://github.com/deepmind/optax>`_ for gradient-based optimization,
- `JAX <https://github.com/google/jax>`_ for autodiff and XLA computation.

Example
-------
Expand Down Expand Up @@ -90,6 +90,23 @@ In order to contribute:
4) test it by calling :code:`hatch run test` on the (Unix) command line,
5) submit a PR 🙂

Citing Surjectors
-----------------

.. code-block:: latex

@article{dirmeier2024surjectors,
author = {Simon Dirmeier},
title = {Surjectors: surjection layers for density estimation with normalizing flows},
year = {2024},
journal = {Journal of Open Source Software},
publisher = {The Open Journal},
volume = {9},
number = {94},
pages = {6188},
doi = {10.21105/joss.06188}
}

License
-------

Expand Down
3 changes: 2 additions & 1 deletion docs/news.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ Latest news on the development of `Surjectors`.

.. note::

No news so far :).
- 29.02.2024. Surjectors has been accepted for publication in the Journal of Open Source Software 🎉🎉🎉.
Find the paper here `10.21105/joss.06188 <https://joss.theoj.org/papers/10.21105/joss.06188>`_.
45 changes: 30 additions & 15 deletions docs/surjectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ Hence, every normalizing flow can be composed by defining these three components
>>> flow = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
>>> pushforward = TransformedDistribution(base_distribution, flow)

Regardless of how the chain of transformations (called :code:`flow` above) is defined,
each pushforward has access to four methods :code:`sample`, :code:`sample_and_log_prob`:code:`log_prob`, and :code:`inverse_and_log_prob`.

The exact method declarations can be found in the API below.

General
-------
Expand All @@ -41,39 +45,51 @@ TransformedDistribution
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: TransformedDistribution
:members:
:members: log_prob, sample, inverse_and_log_prob, sample_and_log_prob

Chain
~~~~~

.. autoclass:: Chain
:members: __init__
:members:

Bijective layers
----------------

.. autosummary::
MaskedAutoregressive
AffineMaskedAutoregressive
MaskedCoupling
AffineMaskedCoupling
RationalQuadraticSplineMaskedCoupling
Permutation

Autoregressive bijections
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedAutoregressive
:members: __init__
:members:

.. autoclass:: AffineMaskedAutoregressive
:members:

Coupling bijections
~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedCoupling
:members: __init__
:members:

.. autoclass:: AffineMaskedCoupling
:members:

.. autoclass:: RationalQuadraticSplineMaskedCoupling
:members:

Other bijections
~~~~~~~~~~~~~~~~

.. autoclass:: Permutation
:members: __init__
:members:

Inference surjection layers
---------------------------
Expand All @@ -96,35 +112,34 @@ Coupling inference surjections
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedCouplingInferenceFunnel
:members: __init__

:members:

.. autoclass:: AffineMaskedCouplingInferenceFunnel
:members: __init__
:members:

.. autoclass:: RationalQuadraticSplineMaskedCouplingInferenceFunnel
:members: __init__
:members:

Autoregressive inference surjections
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedAutoregressiveInferenceFunnel
:members: __init__
:members:

.. autoclass:: AffineMaskedAutoregressiveInferenceFunnel
:members: __init__
:members:

.. autoclass:: RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel
:members: __init__
:members:

Other inference surjections
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: LULinear
:members: __init__
:members:

.. autoclass:: MLPInferenceFunnel
:members: __init__
:members:

.. autoclass:: Slice
:members: __init__
:members:
54 changes: 13 additions & 41 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "Apache-2.0"
homepage = "https://github.com/dirmeier/surjectors"
keywords = ["normalizing flows", "surjections", "density estimation"]
classifiers = [
"Development Status :: 1 - Planning",
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
Expand Down Expand Up @@ -49,55 +49,27 @@ dependencies = [

[tool.hatch.envs.test]
dependencies = [
"pylint>=2.15.10",
"ruff>=0.3.0",
"pytest>=7.2.0",
"pytest-cov>=4.0.0"
]

[tool.hatch.envs.test.scripts]
lint = 'pylint surjectors'
lint = 'ruff check surjectors'
test = 'pytest -v --cov=./surjectors --cov-report=xml surjectors'

[tool.black]
line-length = 80
extend-ignore = "E203"
target-version = ['py310']
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

[tool.bandit]
skips = ["B101"]

[tool.isort]
profile = "black"
line_length = 80
include_trailing_comma = true
[tool.ruff]
line-length = 80
exclude = ["*_test.py", "docs/**", "examples/**"]

[tool.flake8]
max-line-length = 80
extend-ignore = ["E203", "W503"]
per-file-ignores = [
'__init__.py:F401',
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F"]
extend-select = [
"UP", "D", "I", "PL", "S"
]

[tool.pylint.messages_control]
disable = """
invalid-name,missing-module-docstring,R0801
"""

[tool.bandit]
skips = ["B101"]

[tool.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention= 'google'
match = '^surjectors/.*/((?!_test).)*\.py'
Loading

0 comments on commit 68c2f8a

Please sign in to comment.