From c44cda11f9d0104f65ee44436d26972a6612061c Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 29 Feb 2024 18:42:30 +0100 Subject: [PATCH 1/5] fix lints :) --- README.md | 10 ++- docs/index.rst | 25 ++++++-- docs/news.rst | 3 +- docs/surjectors.rst | 13 +++- pyproject.toml | 1 + surjectors/__init__.py | 18 ++++-- .../bijectors/affine_masked_autoregressive.py | 45 +++++++++++++ .../_src/bijectors/affine_masked_coupling.py | 52 +++++++++++++++ .../_src/bijectors/masked_autoregressive.py | 12 +--- .../_src/bijectors/rq_masked_coupling.py | 63 +++++++++++++++++++ ..._masked_autoregressive_inference_funnel.py | 2 +- 11 files changed, 215 insertions(+), 29 deletions(-) create mode 100644 surjectors/_src/bijectors/affine_masked_autoregressive.py create mode 100644 surjectors/_src/bijectors/affine_masked_coupling.py create mode 100644 surjectors/_src/bijectors/rq_masked_coupling.py diff --git a/README.md b/README.md index 591e27f..db21dea 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,12 @@ Surjectors is a light-weight library for density estimation using inference and generative surjective normalizing flows, i.e., flows can that reduce or increase dimensionality. -Surjectors builds on Distrax and Haiku and is fully compatible with both of them. - Surjectors makes use of -- Haiku`s module system for neural networks, -- Distrax for probability distributions and some base bijectors, -- Optax for gradient-based optimization, -- JAX for autodiff and XLA computation. +- [Haiku](https://github.com/deepmind/dm-haiku)`s module system for neural networks, +- [Distrax](https://github.com/deepmind/distrax) for probability distributions and some base bijectors, +- [Optax](https://github.com/deepmind/optax) for gradient-based optimization, +- [JAX](https://github.com/google/jax) for autodiff and XLA computation. 
## Examples diff --git a/docs/index.rst b/docs/index.rst index 122beeb..2e9c0c6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,10 +15,10 @@ Surjectors builds on Distrax and Haiku and is fully compatible with both of them Surjectors makes use of -- Haiku`s module system for neural networks, -- Distrax for probability distributions and some base bijectors, -- Optax for gradient-based optimization, -- JAX for autodiff and XLA computation. +- `Haiku's <https://github.com/deepmind/dm-haiku>`_ module system for neural networks, +- `Distrax <https://github.com/deepmind/distrax>`_ for probability distributions and some base bijectors, +- `Optax <https://github.com/deepmind/optax>`_ for gradient-based optimization, +- `JAX <https://github.com/google/jax>`_ for autodiff and XLA computation. Example ------- @@ -90,6 +90,23 @@ In order to contribute: 4) test it by calling :code:`hatch run test` on the (Unix) command line, 5) submit a PR 🙂 +Citing Surjectors +----------------- + +.. code-block:: latex + + @article{dirmeier2024surjectors, + author = {Simon Dirmeier}, + title = {Surjectors: surjection layers for density estimation with normalizing flows}, + year = {2024}, + journal = {Journal of Open Source Software}, + publisher = {The Open Journal}, + volume = {9}, + number = {94}, + pages = {6188}, + doi = {10.21105/joss.06188} + } + License ------- diff --git a/docs/news.rst b/docs/news.rst index cd7524a..1542a62 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -7,4 +7,5 @@ Latest news on the development of `Surjectors`. .. note:: - No news so far :). + - 29.02.2024. Surjectors has been accepted for publication in the Journal of Open Source Software 🎉🎉🎉. + Find the paper here `10.21105/joss.06188 <https://doi.org/10.21105/joss.06188>`_. diff --git a/docs/surjectors.rst b/docs/surjectors.rst index ca670cc..6fa65a9 100644 --- a/docs/surjectors.rst +++ b/docs/surjectors.rst @@ -54,7 +54,10 @@ Bijective layers .. autosummary:: MaskedAutoregressive + AffineMaskedAutoregressive MaskedCoupling + AffineMaskedCoupling + RationalQuadraticSplineMaskedCoupling Permutation Autoregressive bijections @@ -63,12 +66,21 @@ Autoregressive bijections .. 
autoclass:: MaskedAutoregressive :members: __init__ +.. autoclass:: AffineMaskedAutoregressive + :members: __init__ + Coupling bijections ~~~~~~~~~~~~~~~~~~~ .. autoclass:: MaskedCoupling :members: __init__ +.. autoclass:: AffineMaskedCoupling + :members: __init__ + +.. autoclass:: RationalQuadraticSplineMaskedCoupling + :members: __init__ + Other bijections ~~~~~~~~~~~~~~~~ @@ -98,7 +110,6 @@ Coupling inference surjections .. autoclass:: MaskedCouplingInferenceFunnel :members: __init__ - .. autoclass:: AffineMaskedCouplingInferenceFunnel :members: __init__ diff --git a/pyproject.toml b/pyproject.toml index 7ce6352..1278d5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,3 +101,4 @@ skips = ["B101"] [tool.pydocstyle] convention= 'google' match = '^surjectors/.*/((?!_test).)*\.py' +add-ignore = ["D107"] diff --git a/surjectors/__init__.py b/surjectors/__init__.py index b69fc0f..8b82fe4 100644 --- a/surjectors/__init__.py +++ b/surjectors/__init__.py @@ -2,25 +2,30 @@ surjectors: Surjection layers for density estimation with normalizing flows """ -__version__ = "0.3.1" +__version__ = "0.3.2" +from surjectors._src.bijectors.affine_masked_autoregressive import ( + AffineMaskedAutoregressive, +) +from surjectors._src.bijectors.affine_masked_coupling import ( + AffineMaskedCoupling, +) from surjectors._src.bijectors.lu_linear import LULinear from surjectors._src.bijectors.masked_autoregressive import MaskedAutoregressive from surjectors._src.bijectors.masked_coupling import MaskedCoupling from surjectors._src.bijectors.permutation import Permutation +from surjectors._src.bijectors.rq_masked_coupling import ( + RationalQuadraticSplineMaskedCoupling, +) from surjectors._src.distributions.transformed_distribution import ( TransformedDistribution, ) from surjectors._src.surjectors.affine_masked_autoregressive_inference_funnel import ( # noqa: E501 AffineMaskedAutoregressiveInferenceFunnel, ) -from surjectors._src.surjectors.affine_masked_coupling_generative_funnel 
import ( # noqa: E501 - AffineMaskedCouplingGenerativeFunnel, -) from surjectors._src.surjectors.affine_masked_coupling_inference_funnel import ( AffineMaskedCouplingInferenceFunnel, ) -from surjectors._src.surjectors.augment import Augment from surjectors._src.surjectors.chain import Chain from surjectors._src.surjectors.masked_autoregressive_inference_funnel import ( # noqa: E501 MaskedAutoregressiveInferenceFunnel, @@ -41,11 +46,14 @@ "Chain", "Permutation", "TransformedDistribution", + "AffineMaskedAutoregressive", "MaskedAutoregressive", "MaskedAutoregressiveInferenceFunnel", "AffineMaskedAutoregressiveInferenceFunnel", "RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel", "MaskedCoupling", + "AffineMaskedCoupling", + "RationalQuadraticSplineMaskedCoupling", "MaskedCouplingInferenceFunnel", "AffineMaskedCouplingInferenceFunnel", "RationalQuadraticSplineMaskedCouplingInferenceFunnel", diff --git a/surjectors/_src/bijectors/affine_masked_autoregressive.py b/surjectors/_src/bijectors/affine_masked_autoregressive.py new file mode 100644 index 0000000..fbdd9dd --- /dev/null +++ b/surjectors/_src/bijectors/affine_masked_autoregressive.py @@ -0,0 +1,45 @@ +import distrax +from jax import numpy as jnp + +from surjectors._src.bijectors.masked_autoregressive import MaskedAutoregressive +from surjectors._src.conditioners.nn.made import MADE +from surjectors.util import unstack + + +# pylint: disable=too-many-arguments,arguments-renamed +class AffineMaskedAutoregressive(MaskedAutoregressive): + """An affine masked autoregressive layer. + + Args: + conditioner: a MADE network + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: tthe number of array dimensions the bijector + operates on + + References: + .. [1] Papamakarios, George, et al. "Masked Autoregressive Flow for + Density Estimation". Advances in Neural Information Processing + Systems, 2017. 
+ + Examples: + >>> import distrax + >>> from surjectors import AffineMaskedAutoregressive + >>> + >>> layer = AffineMaskedAutoregressive( + >>> conditioner=MADE(10, [8, 8], 2), + >>> ) + """ + + def __init__( + self, + conditioner: MADE, + event_ndims: int = 1, + inner_event_ndims: int = 0, + ): + def bijector_fn(params): + means, log_scales = unstack(params, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + super().__init__( + conditioner, bijector_fn, event_ndims, inner_event_ndims + ) diff --git a/surjectors/_src/bijectors/affine_masked_coupling.py b/surjectors/_src/bijectors/affine_masked_coupling.py new file mode 100644 index 0000000..fe0ccec --- /dev/null +++ b/surjectors/_src/bijectors/affine_masked_coupling.py @@ -0,0 +1,52 @@ +from typing import Callable, Optional + +import distrax +from jax import numpy as jnp + +from surjectors._src.bijectors.masked_coupling import MaskedCoupling +from surjectors._src.distributions.transformed_distribution import Array + + +# pylint: disable=too-many-arguments, arguments-renamed,too-many-ancestors +class AffineMaskedCoupling(MaskedCoupling): + """An affine masked coupling layer. + + Args: + mask: a boolean mask of length n_dim. A value + of True indicates that the corresponding input remains unchanged + conditioner: a function that computes the parameters of the inner + bijector + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: the number of array dimensions the inner bijector + operates on + + References: + .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. 
+ + Examples: + >>> import distrax + >>> from surjectors import AffineMaskedCoupling + >>> from surjectors.nn import make_mlp + >>> from surjectors.util import make_alternating_binary_mask + >>> + >>> layer = MaskedCoupling( + >>> mask=make_alternating_binary_mask(10, True), + >>> conditioner=make_mlp([8, 8, 10 * 2]), + >>> ) + """ + + def __init__( + self, + mask: Array, + conditioner: Callable, + event_ndims: Optional[int] = None, + inner_event_ndims: int = 0, + ): + def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + super().__init__( + mask, conditioner, _bijector_fn, event_ndims, inner_event_ndims + ) diff --git a/surjectors/_src/bijectors/masked_autoregressive.py b/surjectors/_src/bijectors/masked_autoregressive.py index 20e39a5..16fa233 100644 --- a/surjectors/_src/bijectors/masked_autoregressive.py +++ b/surjectors/_src/bijectors/masked_autoregressive.py @@ -7,7 +7,7 @@ from surjectors._src.conditioners.nn.made import MADE -# pylint: disable=too-many-arguments, arguments-renamed +# pylint: disable=too-many-arguments,arguments-renamed class MaskedAutoregressive(Bijector): """A masked autoregressive layer. @@ -38,16 +38,6 @@ def __init__( event_ndims: int = 1, inner_event_ndims: int = 0, ): - """Construct a masked autoregressive layer. 
- - Args: - conditioner: a MADE network - bijector_fn: a callable that returns the inner bijector that will - be used to transform the input - event_ndims: the number of array dimensions the bijector operates on - inner_event_ndims: tthe number of array dimensions the bijector - operates on - """ if event_ndims is not None and event_ndims < inner_event_ndims: raise ValueError( f"'event_ndims={event_ndims}' should be at least as" diff --git a/surjectors/_src/bijectors/rq_masked_coupling.py b/surjectors/_src/bijectors/rq_masked_coupling.py new file mode 100644 index 0000000..7c93947 --- /dev/null +++ b/surjectors/_src/bijectors/rq_masked_coupling.py @@ -0,0 +1,63 @@ +from typing import Callable, Optional + +import distrax + +from surjectors._src.bijectors.masked_coupling import MaskedCoupling +from surjectors._src.distributions.transformed_distribution import Array + + +# pylint: disable=too-many-arguments, arguments-renamed,too-many-ancestors +class RationalQuadraticSplineMaskedCoupling(MaskedCoupling): + """A rational quadratic spline masked coupling layer. + + References: + .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. + .. [2] Durkan, Conor, et al. "Neural Spline Flows". + Advances in Neural Information Processing Systems, 2019. + + Examples: + >>> import distrax + >>> from surjectors import AffineMaskedCoupling + >>> from surjectors.nn import make_mlp + >>> from surjectors.util import make_alternating_binary_mask + >>> + >>> layer = MaskedCoupling( + >>> mask=make_alternating_binary_mask(10, True), + >>> conditioner=make_mlp([8, 8, 10 * 2]), + >>> ) + """ + + def __init__( + self, + mask: Array, + conditioner: Callable, + range_min: float, + range_max: float, + event_ndims: Optional[int] = None, + inner_event_ndims: int = 0, + ): + """Construct a rational quadratic spline masked coupling layer. + + Args: + mask: a boolean mask of length n_dim. 
A value + of True indicates that the corresponding input remains unchanged + conditioner: a function that computes the parameters of the inner + bijector + range_min: minimum range of the spline + range_max: maximum range of the spline + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: the number of array dimensions the inner bijector + operates on + """ + self.range_min = range_min + self.range_max = range_max + + def _bijector_fn(params: Array): + return distrax.RationalQuadraticSpline( + params, self.range_min, self.range_max + ) + + super().__init__( + mask, conditioner, _bijector_fn, event_ndims, inner_event_ndims + ) diff --git a/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py b/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py index 1b98792..b01eab2 100644 --- a/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py +++ b/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py @@ -68,7 +68,7 @@ def __init__( range_min: minimum range of the spline range_max: maximum range of the spline """ - warnings.warn("class has not been tested. use at own risk") + warnings.warn("class has not been tested properly. 
use at own risk") self.range_min = range_min self.range_max = range_max From e08488249832f8d32fc3e0aebcdb2a4db1360e88 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 29 Feb 2024 20:22:52 +0100 Subject: [PATCH 2/5] fix lints :) --- .pre-commit-config.yaml | 17 ++--- docs/conf.py | 12 ++++ docs/surjectors.rst | 38 ++++++----- pyproject.toml | 22 ++++--- surjectors/__init__.py | 4 +- .../affine_masked_autoregressive_test.py | 53 +++++++++++++++ .../bijectors/affine_masked_coupling_test.py | 55 ++++++++++++++++ surjectors/_src/bijectors/lu_linear.py | 12 ++-- .../_src/bijectors/masked_autoregressive.py | 8 +++ surjectors/_src/bijectors/masked_coupling.py | 30 ++++----- surjectors/_src/bijectors/permutation.py | 36 ++-------- .../_src/bijectors/rq_masked_coupling.py | 6 +- .../_src/bijectors/rq_masked_coupling_test.py | 65 +++++++++++++++++++ surjectors/_src/conditioners/mlp.py | 1 + surjectors/_src/conditioners/nn/made.py | 4 +- .../distributions/transformed_distribution.py | 10 ++- ..._masked_autoregressive_inference_funnel.py | 29 ++++----- ...ffine_masked_coupling_generative_funnel.py | 15 ++--- ...affine_masked_coupling_inference_funnel.py | 27 ++++---- surjectors/_src/surjectors/augment.py | 13 ++-- surjectors/_src/surjectors/chain.py | 15 ++--- .../masked_autoregressive_inference_funnel.py | 30 ++++----- .../masked_coupling_inference_funnel.py | 28 ++++---- surjectors/_src/surjectors/mlp.py | 18 +++-- ..._masked_autoregressive_inference_funnel.py | 36 +++++----- .../rq_masked_coupling_inference_funnel.py | 34 +++++----- ...q_masked_coupling_inference_funnel_test.py | 2 +- surjectors/_src/surjectors/slice.py | 18 +++-- surjectors/_src/surjectors/surjector.py | 2 +- surjectors/util.py | 2 + 30 files changed, 392 insertions(+), 250 deletions(-) create mode 100644 surjectors/_src/bijectors/affine_masked_autoregressive_test.py create mode 100644 surjectors/_src/bijectors/affine_masked_coupling_test.py create mode 100644 
surjectors/_src/bijectors/rq_masked_coupling_test.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e47831..1d6a88e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,15 +43,6 @@ repos: additional_dependencies: ["toml"] files: "(surjectors|examples)" -- repo: https://github.com/PyCQA/flake8 - rev: 5.0.1 - hooks: - - id: flake8 - additional_dependencies: [ - flake8-typing-imports==1.14.0, - flake8-pyproject==1.1.0.post0 - ] - - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.910-1 hooks: @@ -59,8 +50,8 @@ repos: args: ["--ignore-missing-imports"] files: "(surjectors|examples)" -- repo: https://github.com/pycqa/pydocstyle - rev: 6.1.1 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.4 hooks: - - id: pydocstyle - additional_dependencies: ["toml"] + - id: ruff + - id: ruff-format \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 44b14bd..00db2a4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -42,6 +42,8 @@ "examples/*py" ] +autodoc_typehints = "both" + html_theme = "sphinx_book_theme" html_theme_options = { @@ -52,3 +54,13 @@ } html_title = "Surjectors 🚀" + + +def skip(app, what, name, obj, would_skip, options): + if name == "__init__": + return True + return would_skip + + +def setup(app): + app.connect("autodoc-skip-member", skip) \ No newline at end of file diff --git a/docs/surjectors.rst b/docs/surjectors.rst index 6fa65a9..2db2170 100644 --- a/docs/surjectors.rst +++ b/docs/surjectors.rst @@ -29,6 +29,10 @@ Hence, every normalizing flow can be composed by defining these three components >>> flow = Chain([Slice(10, decoder_fn(10)), LULinear(5)]) >>> pushforward = TransformedDistribution(base_distribution, flow) +Regardless of how the chain of transformations (called :code:`flow` above) is defined, +each pushforward has access to four methods :code:`sample`, :code:`sample_and_log_prob`, :code:`log_prob`, and :code:`inverse_and_log_prob`. 
+ +The exact method declarations can be found in the API below. General ------- @@ -41,13 +45,13 @@ TransformedDistribution ~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: TransformedDistribution - :members: + :members: log_prob, sample, inverse_and_log_prob, sample_and_log_prob Chain ~~~~~ .. autoclass:: Chain - :members: __init__ + :members: Bijective layers ---------------- @@ -64,28 +68,28 @@ Autoregressive bijections ~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MaskedAutoregressive - :members: __init__ + :members: .. autoclass:: AffineMaskedAutoregressive - :members: __init__ + :members: Coupling bijections ~~~~~~~~~~~~~~~~~~~ .. autoclass:: MaskedCoupling - :members: __init__ + :members: .. autoclass:: AffineMaskedCoupling - :members: __init__ + :members: .. autoclass:: RationalQuadraticSplineMaskedCoupling - :members: __init__ + :members: Other bijections ~~~~~~~~~~~~~~~~ .. autoclass:: Permutation - :members: __init__ + :members: Inference surjection layers --------------------------- @@ -108,34 +112,34 @@ Coupling inference surjections ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MaskedCouplingInferenceFunnel - :members: __init__ + :members: .. autoclass:: AffineMaskedCouplingInferenceFunnel - :members: __init__ + :members: .. autoclass:: RationalQuadraticSplineMaskedCouplingInferenceFunnel - :members: __init__ + :members: Autoregressive inference surjections ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MaskedAutoregressiveInferenceFunnel - :members: __init__ + :members: .. autoclass:: AffineMaskedAutoregressiveInferenceFunnel - :members: __init__ + :members: .. autoclass:: RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel - :members: __init__ + :members: Other inference surjections ~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LULinear - :members: __init__ + :members: .. autoclass:: MLPInferenceFunnel - :members: __init__ + :members: .. 
autoclass:: Slice - :members: __init__ + :members: diff --git a/pyproject.toml b/pyproject.toml index 1278d5d..6626ef7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,13 +83,6 @@ profile = "black" line_length = 80 include_trailing_comma = true -[tool.flake8] -max-line-length = 80 -extend-ignore = ["E203", "W503"] -per-file-ignores = [ - '__init__.py:F401', -] - [tool.pylint.messages_control] disable = """ invalid-name,missing-module-docstring,R0801 @@ -98,7 +91,16 @@ invalid-name,missing-module-docstring,R0801 [tool.bandit] skips = ["B101"] -[tool.pydocstyle] +[tool.ruff] +line-length = 80 +exclude = ["*_test.py", "docs/**", "examples/**"] + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F"] +extend-select = [ + "UP", # pyupgrade + "D", # pydocstyle +] + +[tool.ruff.lint.pydocstyle] convention= 'google' -match = '^surjectors/.*/((?!_test).)*\.py' -add-ignore = ["D107"] diff --git a/surjectors/__init__.py b/surjectors/__init__.py index 8b82fe4..1105a10 100644 --- a/surjectors/__init__.py +++ b/surjectors/__init__.py @@ -1,6 +1,4 @@ -""" -surjectors: Surjection layers for density estimation with normalizing flows -""" +"""surjectors: Surjection layers for density estimation with normalizing flows.""" __version__ = "0.3.2" diff --git a/surjectors/_src/bijectors/affine_masked_autoregressive_test.py b/surjectors/_src/bijectors/affine_masked_autoregressive_test.py new file mode 100644 index 0000000..863e1bb --- /dev/null +++ b/surjectors/_src/bijectors/affine_masked_autoregressive_test.py @@ -0,0 +1,53 @@ +# pylint: skip-file + +import distrax +import haiku as hk +from jax import numpy as jnp +from jax import random + +from surjectors import AffineMaskedAutoregressive, TransformedDistribution +from surjectors.nn import MADE + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def make_bijector(n_dimension): 
+ def _transformation_fn(n_dimension): + bij = AffineMaskedAutoregressive( + MADE(n_dimension, [8, 8], 2), + ) + return bij + + def _flow(method, **kwargs): + td = TransformedDistribution( + _base_distribution_fn(n_dimension), _transformation_fn(n_dimension) + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def test_affine_masked_autoregressive(): + n_dimension = 4 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y) + _ = flow.apply(params, None, method="log_prob", y=y) + + +def test_conditional_affine_masked_autoregressive(): + n_dimension = 4 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + x = random.normal(random.PRNGKey(1), shape=(10, 2)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y, x=x) + _ = flow.apply(params, None, method="log_prob", y=y, x=x) diff --git a/surjectors/_src/bijectors/affine_masked_coupling_test.py b/surjectors/_src/bijectors/affine_masked_coupling_test.py new file mode 100644 index 0000000..f53d045 --- /dev/null +++ b/surjectors/_src/bijectors/affine_masked_coupling_test.py @@ -0,0 +1,55 @@ +# pylint: skip-file + +import distrax +import haiku as hk +from jax import numpy as jnp +from jax import random + +from surjectors import AffineMaskedCoupling, TransformedDistribution +from surjectors.nn import make_mlp +from surjectors.util import make_alternating_binary_mask + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def make_bijector(n_dimension): + def _transformation_fn(n_dimension): + bij = AffineMaskedCoupling( + make_alternating_binary_mask(n_dimension, 0 % 2 == 0), + make_mlp([8, 8, n_dimension * 2]), + ) + return bij + + def _flow(method, **kwargs): + td = 
TransformedDistribution( + _base_distribution_fn(n_dimension), _transformation_fn(n_dimension) + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def test_affine_masked_coupling(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y) + _ = flow.apply(params, None, method="log_prob", y=y) + + +def test_conditional_affine_masked_coupling(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + x = random.normal(random.PRNGKey(1), shape=(10, 2)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y, x=x) + _ = flow.apply(params, None, method="log_prob", y=y, x=x) diff --git a/surjectors/_src/bijectors/lu_linear.py b/surjectors/_src/bijectors/lu_linear.py index 01e1dd8..f8dc1e9 100644 --- a/surjectors/_src/bijectors/lu_linear.py +++ b/surjectors/_src/bijectors/lu_linear.py @@ -10,6 +10,11 @@ class LULinear(Bijector, hk.Module): """An bijection based on the LU composition. + Args: + n_dimension: number of dimensions to keep + with_bias: use a bias term or not + dtype: parameter dtype + References: .. [1] Oliva, Junier, et al. "Transformation Autoregressive Networks". Proceedings of the 35th International Conference on @@ -21,13 +26,6 @@ class LULinear(Bijector, hk.Module): """ def __init__(self, n_dimension, with_bias=False, dtype=jnp.float32): - """Constructs a LULinear layer. 
- - Args: - n_dimension: number of dimensions to keep - with_bias: use a bias term or not - dtype: parameter dtype - """ super().__init__() if with_bias: raise NotImplementedError() diff --git a/surjectors/_src/bijectors/masked_autoregressive.py b/surjectors/_src/bijectors/masked_autoregressive.py index 16fa233..59d1e87 100644 --- a/surjectors/_src/bijectors/masked_autoregressive.py +++ b/surjectors/_src/bijectors/masked_autoregressive.py @@ -11,6 +11,14 @@ class MaskedAutoregressive(Bijector): """A masked autoregressive layer. + Args: + conditioner: a MADE network + bijector_fn: a callable that returns the inner bijector that will + be used to transform the input + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: the number of array dimensions the bijector + operates on + References: .. [1] Papamakarios, George, et al. "Masked Autoregressive Flow for Density Estimation". Advances in Neural Information Processing diff --git a/surjectors/_src/bijectors/masked_coupling.py b/surjectors/_src/bijectors/masked_coupling.py index 8480df3..cbb2f99 100644 --- a/surjectors/_src/bijectors/masked_coupling.py +++ b/surjectors/_src/bijectors/masked_coupling.py @@ -12,9 +12,16 @@ class MaskedCoupling(Bijector, distrax.MaskedCoupling): """A masked coupling layer. - References: - .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP". - International Conference on Learning Representations, 2017. + Args: + mask: a boolean mask of length n_dim. 
A value + of True indicates that the corresponding input remains unchanged + conditioner: a function that computes the parameters of the inner + bijector + bijector_fn: a callable that returns the inner bijector that will be + used to transform the input + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: the number of array dimensions the inner bijector + operates on Examples: >>> import distrax @@ -31,6 +38,10 @@ class MaskedCoupling(Bijector, distrax.MaskedCoupling): >>> bijector_fn=bijector_fn, >>> conditioner=make_mlp([8, 8, 10 * 2]), >>> ) + + References: + .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. """ def __init__( @@ -41,19 +52,6 @@ def __init__( event_ndims: Optional[int] = None, inner_event_ndims: int = 0, ): - """Construct a masked coupling layer. - - Args: - mask: a boolean mask of length n_dim. A value - of True indicates that the corresponding input remains unchanged - conditioner: a function that computes the parameters of the inner - bijector - bijector_fn: a callable that returns the inner bijector that will be - used to transform the input - event_ndims: the number of array dimensions the bijector operates on - inner_event_ndims: the number of array dimensions the inner bijector - operates on - """ super().__init__( mask, conditioner, bijector_fn, event_ndims, inner_event_ndims ) diff --git a/surjectors/_src/bijectors/permutation.py b/surjectors/_src/bijectors/permutation.py index 629cf78..1302f36 100644 --- a/surjectors/_src/bijectors/permutation.py +++ b/surjectors/_src/bijectors/permutation.py @@ -6,6 +6,11 @@ class Permutation(distrax.Bijector): """Permute the dimensions of a vector. 
+ Args: + permutation: a vector of integer indexes representing the order of + the elements + event_ndims_in: number of input event dimensions + Examples: >>> from surjectors import Permutation >>> from jax import numpy as jnp @@ -15,40 +20,13 @@ class Permutation(distrax.Bijector): """ def __init__(self, permutation, event_ndims_in: int): - """Construct a permutation layer. - - Args: - permutation: a vector of integer indexes representing the order of - the elements - event_ndims_in: number of input event dimensions - """ super().__init__(event_ndims_in) self.permutation = permutation - def forward_and_log_det(self, z): - """Compute the forward transformation and its Jacobian determinant. - - Args: - z: event for which the forward transform and likelihood contribution - is computed - - Returns: - tuple of two arrays of floats. The first one is the forward - transformation, the second one its likelihood contribution - """ + def _forward_and_likelihood_contribution(self, z): return z[..., self.permutation], jnp.full(jnp.shape(z)[:-1], 0.0) - def inverse_and_log_det(self, y): - """Compute the inverse transformation and its Jacobian determinant. - - Args: - y: event for which the inverse and likelihood contribution is - computed - - Returns: - tuple of two arrays of floats. 
The first one is the inverse - transformation, the second one its likelihood contribution - """ + def _inverse_and_likelihood_contribution(self, y): size = self.permutation.size permutation_inv = ( jnp.zeros(size, dtype=jnp.result_type(int)) diff --git a/surjectors/_src/bijectors/rq_masked_coupling.py b/surjectors/_src/bijectors/rq_masked_coupling.py index 7c93947..6fa4db4 100644 --- a/surjectors/_src/bijectors/rq_masked_coupling.py +++ b/surjectors/_src/bijectors/rq_masked_coupling.py @@ -18,13 +18,15 @@ class RationalQuadraticSplineMaskedCoupling(MaskedCoupling): Examples: >>> import distrax - >>> from surjectors import AffineMaskedCoupling + >>> from surjectors import RationalQuadraticSplineMaskedCoupling >>> from surjectors.nn import make_mlp >>> from surjectors.util import make_alternating_binary_mask >>> - >>> layer = MaskedCoupling( + >>> layer = RationalQuadraticSplineMaskedCoupling( >>> mask=make_alternating_binary_mask(10, True), >>> conditioner=make_mlp([8, 8, 10 * 2]), + >>> range_min=-1.0, + >>> range_max=1.0 >>> ) """ diff --git a/surjectors/_src/bijectors/rq_masked_coupling_test.py b/surjectors/_src/bijectors/rq_masked_coupling_test.py new file mode 100644 index 0000000..b8ad3bf --- /dev/null +++ b/surjectors/_src/bijectors/rq_masked_coupling_test.py @@ -0,0 +1,65 @@ +# pylint: skip-file + +import distrax +import haiku as hk +from jax import numpy as jnp +from jax import random + +from surjectors import ( + RationalQuadraticSplineMaskedCoupling, + TransformedDistribution, +) +from surjectors.nn import make_mlp +from surjectors.util import make_alternating_binary_mask + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def make_bijector(n_dimension): + def _transformation_fn(n_dimension): + bij = RationalQuadraticSplineMaskedCoupling( + make_alternating_binary_mask(n_dimension, 0 % 2 == 0), + 
hk.Sequential( + [ + make_mlp([8, 8, n_dimension * 10]), + hk.Reshape((n_dimension, 10)), + ] + ), + -1.0, + 1.0, + ) + return bij + + def _flow(method, **kwargs): + td = TransformedDistribution( + _base_distribution_fn(n_dimension), _transformation_fn(n_dimension) + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def test_rq_masked_coupling(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y) + _ = flow.apply(params, None, method="log_prob", y=y) + + +def test_conditional_rq_masked_coupling(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + x = random.normal(random.PRNGKey(1), shape=(10, 2)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y, x=x) + _ = flow.apply(params, None, method="log_prob", y=y, x=x) diff --git a/surjectors/_src/conditioners/mlp.py b/surjectors/_src/conditioners/mlp.py index 63db8e9..9d24f1b 100644 --- a/surjectors/_src/conditioners/mlp.py +++ b/surjectors/_src/conditioners/mlp.py @@ -3,6 +3,7 @@ from jax import numpy as jnp +# type: ignore[B008] def make_mlp( dims, activation=jax.nn.gelu, diff --git a/surjectors/_src/conditioners/nn/made.py b/surjectors/_src/conditioners/nn/made.py index 4904e01..e2f46ea 100644 --- a/surjectors/_src/conditioners/nn/made.py +++ b/surjectors/_src/conditioners/nn/made.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import haiku as hk import jax @@ -26,7 +26,7 @@ class MADE(hk.Module): def __init__( self, input_size: int, - hidden_layer_sizes: Union[List[int], Tuple[int]], + hidden_layer_sizes: Union[list[int], tuple[int]], n_params: int, w_init: Optional[hk.initializers.Initializer] = None, b_init: Optional[hk.initializers.Initializer] = None, diff --git 
a/surjectors/_src/distributions/transformed_distribution.py b/surjectors/_src/distributions/transformed_distribution.py index 37cfb5e..2df71db 100644 --- a/surjectors/_src/distributions/transformed_distribution.py +++ b/surjectors/_src/distributions/transformed_distribution.py @@ -12,6 +12,10 @@ class TransformedDistribution: Can be used to define a pushforward measure. + Args: + base_distribution: a distribution object + transform: some transformation + Examples: >>> import distrax >>> from jax import numpy as jnp @@ -28,12 +32,6 @@ class TransformedDistribution: """ def __init__(self, base_distribution: Distribution, transform: Surjector): - """Constructs a TransformedDistribution. - - Args: - base_distribution: a distribution object - transform: some transformation - """ self.base_distribution = base_distribution self.transform = transform diff --git a/surjectors/_src/surjectors/affine_masked_autoregressive_inference_funnel.py b/surjectors/_src/surjectors/affine_masked_autoregressive_inference_funnel.py index 203ebaa..8e4de9c 100644 --- a/surjectors/_src/surjectors/affine_masked_autoregressive_inference_funnel.py +++ b/surjectors/_src/surjectors/affine_masked_autoregressive_inference_funnel.py @@ -21,13 +21,11 @@ class AffineMaskedAutoregressiveInferenceFunnel( transformation from data to latent space using a masking mechanism as in MaskedAutoegressive. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Papamakarios, George, et al. "Masked Autoregressive Flow for - Density Estimation". Advances in Neural Information Processing - Systems, 2017. 
+ Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a MADE neural network Examples: >>> import distrax @@ -49,18 +47,17 @@ class AffineMaskedAutoregressiveInferenceFunnel( >>> decoder=decoder_fn(10), >>> conditioner=MADE(10, [8, 8], 2), >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Papamakarios, George, et al. "Masked Autoregressive Flow for + Density Estimation". Advances in Neural Information Processing + Systems, 2017. """ def __init__(self, n_keep: int, decoder: Callable, conditioner: MADE): - """Constructs a AffineMaskedAutoregressiveInferenceFunnel layer. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a MADE neural network - """ - def bijector_fn(params: Array): shift, log_scale = unstack(params, axis=-1) return distrax.ScalarAffine(shift, jnp.exp(log_scale)) diff --git a/surjectors/_src/surjectors/affine_masked_coupling_generative_funnel.py b/surjectors/_src/surjectors/affine_masked_coupling_generative_funnel.py index 784bea5..d0b8894 100644 --- a/surjectors/_src/surjectors/affine_masked_coupling_generative_funnel.py +++ b/surjectors/_src/surjectors/affine_masked_coupling_generative_funnel.py @@ -8,16 +8,15 @@ class AffineMaskedCouplingGenerativeFunnel(Surjector): - """A Generative funnel layer using masked affine coupling.""" + """A Generative funnel layer using masked affine coupling. - def __init__(self, n_keep, encoder, conditioner): - """Construct a AffineMaskedCouplingGenerativeFunnel layer. 
+ Args: + n_keep: number of dimensions to keep + encoder: callable + conditioner: callable + """ - Args: - n_keep: number of dimensions to keep - encoder: callable - conditioner: callable - """ + def __init__(self, n_keep, encoder, conditioner): self.n_keep = n_keep self.encoder = encoder self.conditioner = conditioner diff --git a/surjectors/_src/surjectors/affine_masked_coupling_inference_funnel.py b/surjectors/_src/surjectors/affine_masked_coupling_inference_funnel.py index 6b04305..0a74524 100644 --- a/surjectors/_src/surjectors/affine_masked_coupling_inference_funnel.py +++ b/surjectors/_src/surjectors/affine_masked_coupling_inference_funnel.py @@ -12,12 +12,11 @@ class AffineMaskedCouplingInferenceFunnel(MaskedCouplingInferenceFunnel): """A masked coupling inference funnel that uses an affine transformation. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Dinh, Laurent, et al. "Density estimation using RealNVP". - International Conference on Learning Representations, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a conditioning neural network Examples: >>> import distrax @@ -39,18 +38,16 @@ class AffineMaskedCouplingInferenceFunnel(MaskedCouplingInferenceFunnel): >>> decoder=decoder_fn(10), >>> conditioner=make_mlp([4, 4, 10 * 2])(z), >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. 
""" def __init__(self, n_keep: int, decoder: Callable, conditioner: Callable): - """Constructs a AffineMaskedCouplingInferenceFunnel layer. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a conditioning neural network - """ - def bijector_fn(params: Array): shift, log_scale = jnp.split(params, 2, axis=-1) return distrax.ScalarAffine(shift, jnp.exp(log_scale)) diff --git a/surjectors/_src/surjectors/augment.py b/surjectors/_src/surjectors/augment.py index f1ce7ad..9526e0f 100644 --- a/surjectors/_src/surjectors/augment.py +++ b/surjectors/_src/surjectors/augment.py @@ -5,15 +5,14 @@ class Augment(Surjector): - """Augment generative funnel.""" + """Augment generative funnel. - def __init__(self, n_keep, encoder): - """Construct an augmentation layer. + Args: + n_keep: number of dimensions to keep + encoder: encoder callable + """ - Args: - n_keep: number of dimensions to keep - encoder: encoder callable - """ + def __init__(self, n_keep, encoder): self.n_keep = n_keep self.encoder = encoder diff --git a/surjectors/_src/surjectors/chain.py b/surjectors/_src/surjectors/chain.py index 0814ac7..98e9166 100644 --- a/surjectors/_src/surjectors/chain.py +++ b/surjectors/_src/surjectors/chain.py @@ -1,14 +1,17 @@ -from typing import List - from surjectors._src._transform import Transform from surjectors._src.surjectors.surjector import Surjector +# type: ignore[B009] class Chain(Surjector): """Chain of normalizing flows. Can be used to concatenate several normalizing flows together. + Args: + transforms: a list of transformations, such as bijections or + surjections + Examples: >>> from surjectors import Slice, Chain >>> a = Slice(10) @@ -16,13 +19,7 @@ class Chain(Surjector): >>> ab = Chain([a, b]) """ - def __init__(self, transforms: List[Transform]): - """Constructs a Chain. 
- - Args: - transforms: a list of transformations, such as bijections or - surjections - """ + def __init__(self, transforms: list[Transform]): self._transforms = transforms def _inverse_and_likelihood_contribution(self, y, x=None, **kwargs): diff --git a/surjectors/_src/surjectors/masked_autoregressive_inference_funnel.py b/surjectors/_src/surjectors/masked_autoregressive_inference_funnel.py index e23f46f..8044c36 100644 --- a/surjectors/_src/surjectors/masked_autoregressive_inference_funnel.py +++ b/surjectors/_src/surjectors/masked_autoregressive_inference_funnel.py @@ -17,13 +17,12 @@ class MaskedAutoregressiveInferenceFunnel(Surjector): comparison to AffineMaskedAutoregressiveInferenceFunnel and RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Papamakarios, George, et al. "Masked Autoregressive Flow for - Density Estimation". Advances in Neural Information Processing - Systems, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a MADE neural network + bijector_fn: an inner bijector function to be used Examples: >>> import distrax @@ -50,6 +49,14 @@ class MaskedAutoregressiveInferenceFunnel(Surjector): >>> conditioner=MADE(10, [8, 8], 2), >>> bijector_fn=bijector_fn >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Papamakarios, George, et al. "Masked Autoregressive Flow for + Density Estimation". Advances in Neural Information Processing + Systems, 2017. 
""" def __init__( @@ -59,15 +66,6 @@ def __init__( conditioner: MADE, bijector_fn: Callable, ): - """Constructs a MaskedAutoregressiveInferenceFunnel layer. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a MADE neural network - bijector_fn: an inner bijector function to be used - """ self.n_keep = n_keep self.decoder = decoder self.conditioner = conditioner diff --git a/surjectors/_src/surjectors/masked_coupling_inference_funnel.py b/surjectors/_src/surjectors/masked_coupling_inference_funnel.py index 42c67e1..7e071e1 100644 --- a/surjectors/_src/surjectors/masked_coupling_inference_funnel.py +++ b/surjectors/_src/surjectors/masked_coupling_inference_funnel.py @@ -17,12 +17,12 @@ class MaskedCouplingInferenceFunnel(Surjector): comparison to ASffineMaskedCouplingInferenceFunnel and RationalQuadraticSplineMaskedCouplingInferenceFunnel. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Dinh, Laurent, et al. "Density estimation using RealNVP". - International Conference on Learning Representations, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a conditioning neural network + bijector_fn: an inner bijector function to be used Examples: >>> import distrax @@ -48,6 +48,13 @@ class MaskedCouplingInferenceFunnel(Surjector): >>> conditioner=make_mlp([4, 4, 10 * 2]), >>> bijector_fn=bijector_fn >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Dinh, Laurent, et al. "Density estimation using RealNVP". 
+ International Conference on Learning Representations, 2017. """ def __init__( @@ -57,15 +64,6 @@ def __init__( conditioner: Callable, bijector_fn: Callable, ): - """Construct a MaskedCouplingInferenceFunnel layer. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a conditioning neural network - bijector_fn: an inner bijector function to be used - """ self.n_keep = n_keep self.decoder = decoder self.conditioner = conditioner diff --git a/surjectors/_src/surjectors/mlp.py b/surjectors/_src/surjectors/mlp.py index 67c4274..719f4ea 100644 --- a/surjectors/_src/surjectors/mlp.py +++ b/surjectors/_src/surjectors/mlp.py @@ -10,10 +10,9 @@ class MLPInferenceFunnel(Surjector, hk.Module): """A multilayer perceptron inference funnel. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. + Args: + n_keep: number of dimensions to keep + decoder: a conditional probability function Examples: >>> import distrax @@ -31,15 +30,14 @@ class MLPInferenceFunnel(Surjector, hk.Module): >>> >>> decoder = decoder_fn(5) >>> a = MLPInferenceFunnel(10, decoder) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. """ def __init__(self, n_keep: int, decoder: Callable): - """Constructs a MLPInferenceFunnel layer. 
- - Args: - n_keep: number of dimensions to keep - decoder: a conditional probability function - """ super().__init__() self._r = LULinear(n_keep, False) self._w_prime = hk.Linear(n_keep, True) diff --git a/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py b/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py index b01eab2..2020265 100644 --- a/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py +++ b/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py @@ -17,15 +17,13 @@ class RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel( ): """A masked autoregressive inference funnel that uses RQ-NSFs. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Durkan, Conor, et al. "Neural Spline Flows". - Advances in Neural Information Processing Systems, 2019. - .. [3] Papamakarios, George, et al. "Masked Autoregressive Flow for - Density Estimation". Advances in Neural Information Processing - Systems, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a conditioning neural network + range_min: minimum range of the spline + range_max: maximum range of the spline Examples: >>> import distrax @@ -48,6 +46,16 @@ class RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel( >>> decoder=decoder_fn(10), >>> conditioner=MADE(10, [8, 8], 2), >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Durkan, Conor, et al. "Neural Spline Flows". + Advances in Neural Information Processing Systems, 2019. + .. [3] Papamakarios, George, et al. 
"Masked Autoregressive Flow for + Density Estimation". Advances in Neural Information Processing + Systems, 2017. """ def __init__( @@ -58,16 +66,6 @@ def __init__( range_min: float, range_max: float, ): - """Constructs a RQ-NSF inference funnel. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a conditioning neural network - range_min: minimum range of the spline - range_max: maximum range of the spline - """ warnings.warn("class has not been tested properly. use at own risk") self.range_min = range_min self.range_max = range_max diff --git a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py index b4a53d3..6e025df 100644 --- a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py +++ b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py @@ -12,14 +12,13 @@ class RationalQuadraticSplineMaskedCouplingInferenceFunnel( ): """A masked coupling inference funnel that uses a rational quatratic spline. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Durkan, Conor, et al. "Neural Spline Flows". - Advances in Neural Information Processing Systems, 2019. - .. [3] Dinh, Laurent, et al. "Density estimation using RealNVP". - International Conference on Learning Representations, 2017. 
+ Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a conditioning neural network + range_min: minimum range of the spline + range_max: maximum range of the spline Examples: >>> import distrax @@ -42,19 +41,18 @@ class RationalQuadraticSplineMaskedCouplingInferenceFunnel( >>> decoder=decoder_fn(10), >>> conditioner=make_mlp([4, 4, 10 * 2])(z), >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Durkan, Conor, et al. "Neural Spline Flows". + Advances in Neural Information Processing Systems, 2019. + .. [3] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. """ def __init__(self, n_keep, decoder, conditioner, range_min, range_max): - """Construct a RationalQuadraticSplineMaskedCouplingInferenceFunnel. 
- - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a conditioning neural network - range_min: minimum range of the spline - range_max: maximum range of the spline - """ self.range_min = range_min self.range_max = range_max diff --git a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel_test.py b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel_test.py index bba233c..e4d1e10 100644 --- a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel_test.py +++ b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel_test.py @@ -1,4 +1,4 @@ -# pylint: skip-file +# type: ignore import distrax import haiku as hk diff --git a/surjectors/_src/surjectors/slice.py b/surjectors/_src/surjectors/slice.py index 5cf9763..6d3b024 100644 --- a/surjectors/_src/surjectors/slice.py +++ b/surjectors/_src/surjectors/slice.py @@ -9,10 +9,9 @@ class Slice(Surjector): """A slice funnel. - References: - .. [1] Nielsen, Didrik, et al. "SurVAE Flows: Surjections to Bridge the - Gap between VAEs and Flows". Advances in Neural Information - Processing Systems, 2020. + Args: + n_keep: number if dimensions to keep + decoder: callable Examples: >>> import distrax @@ -29,15 +28,14 @@ class Slice(Surjector): >>> return _fn >>> >>> layer = Slice(10, decoder_fn(10)) + + References: + .. [1] Nielsen, Didrik, et al. "SurVAE Flows: Surjections to Bridge the + Gap between VAEs and Flows". Advances in Neural Information + Processing Systems, 2020. """ def __init__(self, n_keep: int, decoder: Callable): - """Constructs a slice layer. 
- - Args: - n_keep: number if dimensions to keep - decoder: callable - """ self.n_keep = n_keep self.decoder = decoder diff --git a/surjectors/_src/surjectors/surjector.py b/surjectors/_src/surjectors/surjector.py index 0a20a17..9c8709e 100644 --- a/surjectors/_src/surjectors/surjector.py +++ b/surjectors/_src/surjectors/surjector.py @@ -10,7 +10,7 @@ class Surjector(Transform, ABC): """A surjective transformation.""" def __call__(self, method, **kwargs): - """Call the Surjector. + """Call the surjector. Depending on "method", computes - inverse, diff --git a/surjectors/util.py b/surjectors/util.py index 78055c9..828ffcf 100644 --- a/surjectors/util.py +++ b/surjectors/util.py @@ -1,3 +1,5 @@ +"""Utility functions.""" + from collections import namedtuple import numpy as np From 4c69671022cd6e399080cc21ff0eb6eaff03a6e5 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 29 Feb 2024 20:36:23 +0100 Subject: [PATCH 3/5] fix lints :) --- .pre-commit-config.yaml | 20 -------------- pyproject.toml | 27 +++---------------- surjectors/_src/bijectors/masked_coupling.py | 2 +- .../_src/bijectors/rq_masked_coupling.py | 2 +- surjectors/_src/conditioners/nn/made.py | 2 +- surjectors/_src/conditioners/transformer.py | 2 +- ..._masked_autoregressive_inference_funnel.py | 2 +- .../rq_masked_coupling_inference_funnel.py | 1 + 8 files changed, 9 insertions(+), 49 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d6a88e..44310c2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,26 +12,6 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace -- repo: https://github.com/asottile/pyupgrade - rev: v2.29.1 - hooks: - - id: pyupgrade - args: [--py38-plus] - -- repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - args: ["--config=pyproject.toml"] - files: "(surjectors|examples)" - -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--settings-path=pyproject.toml"] 
- files: "(surjectors|examples)" - - repo: https://github.com/pycqa/bandit rev: 1.7.1 hooks: diff --git a/pyproject.toml b/pyproject.toml index 6626ef7..eeee310 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,35 +49,15 @@ dependencies = [ [tool.hatch.envs.test] dependencies = [ - "pylint>=2.15.10", + "ruff>=0.3.0", "pytest>=7.2.0", "pytest-cov>=4.0.0" ] [tool.hatch.envs.test.scripts] -lint = 'pylint surjectors' +lint = 'ruff check surjectors' test = 'pytest -v --cov=./surjectors --cov-report=xml surjectors' -[tool.black] -line-length = 80 -extend-ignore = "E203" -target-version = ['py310'] -exclude = ''' -/( - \.eggs - | \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist -)/ -''' - - [tool.isort] profile = "black" line_length = 80 @@ -98,8 +78,7 @@ exclude = ["*_test.py", "docs/**", "examples/**"] [tool.ruff.lint] select = ["E4", "E7", "E9", "F"] extend-select = [ - "UP", # pyupgrade - "D", # pydocstyle + "UP", "D", "I", "PL", "S" ] [tool.ruff.lint.pydocstyle] diff --git a/surjectors/_src/bijectors/masked_coupling.py b/surjectors/_src/bijectors/masked_coupling.py index cbb2f99..73fdd3a 100644 --- a/surjectors/_src/bijectors/masked_coupling.py +++ b/surjectors/_src/bijectors/masked_coupling.py @@ -8,7 +8,7 @@ from surjectors._src.distributions.transformed_distribution import Array -# pylint: disable=too-many-arguments, arguments-renamed +# ruff: noqa: PLR0913 class MaskedCoupling(Bijector, distrax.MaskedCoupling): """A masked coupling layer. 
diff --git a/surjectors/_src/bijectors/rq_masked_coupling.py b/surjectors/_src/bijectors/rq_masked_coupling.py index 6fa4db4..7bccd25 100644 --- a/surjectors/_src/bijectors/rq_masked_coupling.py +++ b/surjectors/_src/bijectors/rq_masked_coupling.py @@ -6,7 +6,7 @@ from surjectors._src.distributions.transformed_distribution import Array -# pylint: disable=too-many-arguments, arguments-renamed,too-many-ancestors +# ruff: noqa: PLR0913 class RationalQuadraticSplineMaskedCoupling(MaskedCoupling): """A rational quadratic spline masked coupling layer. diff --git a/surjectors/_src/conditioners/nn/made.py b/surjectors/_src/conditioners/nn/made.py index e2f46ea..0dc29aa 100644 --- a/surjectors/_src/conditioners/nn/made.py +++ b/surjectors/_src/conditioners/nn/made.py @@ -11,7 +11,7 @@ from surjectors._src.conditioners.nn.masked_linear import MaskedLinear -# pylint: disable=too-many-arguments, arguments-renamed +# ruff: noqa: PLR0913 class MADE(hk.Module): """Masked Autoregressive Density Estimator. 
diff --git a/surjectors/_src/conditioners/transformer.py b/surjectors/_src/conditioners/transformer.py index 37bd068..1967e27 100644 --- a/surjectors/_src/conditioners/transformer.py +++ b/surjectors/_src/conditioners/transformer.py @@ -64,7 +64,7 @@ def __call__(self, inputs, *, is_training=True): return hk.Linear(self.output_size)(h) -# pylint: disable=too-many-arguments +# ruff: noqa: PLR0913 def make_transformer( output_size, num_heads=4, diff --git a/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py b/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py index 2020265..79aa4f2 100644 --- a/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py +++ b/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py @@ -11,7 +11,7 @@ ) -# pylint: disable=too-many-arguments, arguments-renamed +# ruff: noqa: PLR0913 class RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel( MaskedAutoregressiveInferenceFunnel ): diff --git a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py index 6e025df..e61a789 100644 --- a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py +++ b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py @@ -7,6 +7,7 @@ ) +# ruff: noqa: PLR0913 class RationalQuadraticSplineMaskedCouplingInferenceFunnel( MaskedCouplingInferenceFunnel ): From a56effbef5bda4ce2828d604f038a98755c93604 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 29 Feb 2024 20:38:01 +0100 Subject: [PATCH 4/5] fix lints :) --- .github/workflows/ci.yaml | 6 +++--- .pre-commit-config.yaml | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ca131d7..d018580 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: 
[3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -41,7 +41,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -61,7 +61,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 44310c2..815e351 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,8 @@ repos: files: "(surjectors|examples)" - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.4 + rev: v0.3.0 hooks: - id: ruff + args: [ --fix ] - id: ruff-format \ No newline at end of file From ba4f40b8840c49f0b0ea908671365d9926a7ea6c Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 29 Feb 2024 21:01:45 +0100 Subject: [PATCH 5/5] move from planning to alpha --- pyproject.toml | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eeee310..00d0471 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = "Apache-2.0" homepage = "https://github.com/dirmeier/surjectors" keywords = ["normalizing flows", "surjections", "density estimation"] classifiers = [ - "Development Status :: 1 - Planning", + "Development Status :: 3 - Alpha", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", @@ -58,16 +58,6 @@ dependencies = [ lint = 'ruff check surjectors' test = 'pytest -v --cov=./surjectors --cov-report=xml surjectors' -[tool.isort] -profile = "black" -line_length = 80 -include_trailing_comma = true - -[tool.pylint.messages_control] -disable = """ -invalid-name,missing-module-docstring,R0801 -""" - [tool.bandit] skips = ["B101"]