diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ca131d7..d018580 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -41,7 +41,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -61,7 +61,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e47831..815e351 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,26 +12,6 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace -- repo: https://github.com/asottile/pyupgrade - rev: v2.29.1 - hooks: - - id: pyupgrade - args: [--py38-plus] - -- repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - args: ["--config=pyproject.toml"] - files: "(surjectors|examples)" - -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--settings-path=pyproject.toml"] - files: "(surjectors|examples)" - - repo: https://github.com/pycqa/bandit rev: 1.7.1 hooks: @@ -43,15 +23,6 @@ repos: additional_dependencies: ["toml"] files: "(surjectors|examples)" -- repo: https://github.com/PyCQA/flake8 - rev: 5.0.1 - hooks: - - id: flake8 - additional_dependencies: [ - flake8-typing-imports==1.14.0, - flake8-pyproject==1.1.0.post0 - ] - - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.910-1 hooks: @@ -59,8 +30,9 @@ repos: args: ["--ignore-missing-imports"] files: "(surjectors|examples)" -- repo: https://github.com/pycqa/pydocstyle - rev: 6.1.1 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.0 hooks: - - id: pydocstyle - additional_dependencies: ["toml"] + - id: ruff + args: [ --fix ] + - id: ruff-format \ No newline at end of file diff --git a/README.md b/README.md index 591e27f..db21dea 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,12 @@ Surjectors is a light-weight library for density estimation using inference and generative surjective normalizing flows, i.e., flows can that reduce or increase dimensionality. -Surjectors builds on Distrax and Haiku and is fully compatible with both of them. - Surjectors makes use of -- Haiku`s module system for neural networks, -- Distrax for probability distributions and some base bijectors, -- Optax for gradient-based optimization, -- JAX for autodiff and XLA computation. +- [Haiku](https://github.com/deepmind/dm-haiku)`s module system for neural networks, +- [Distrax](https://github.com/deepmind/distrax) for probability distributions and some base bijectors, +- [Optax](https://github.com/deepmind/optax) for gradient-based optimization, +- [JAX](https://github.com/google/jax) for autodiff and XLA computation. ## Examples diff --git a/docs/conf.py b/docs/conf.py index 44b14bd..00db2a4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -42,6 +42,8 @@ "examples/*py" ] +autodoc_typehints = "both" + html_theme = "sphinx_book_theme" html_theme_options = { @@ -52,3 +54,13 @@ } html_title = "Surjectors 🚀" + + +def skip(app, what, name, obj, would_skip, options): + if name == "__init__": + return True + return would_skip + + +def setup(app): + app.connect("autodoc-skip-member", skip) \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 122beeb..2e9c0c6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,10 +15,10 @@ Surjectors builds on Distrax and Haiku and is fully compatible with both of them Surjectors makes use of -- Haiku`s module system for neural networks, -- Distrax for probability distributions and some base bijectors, -- Optax for gradient-based optimization, -- JAX for autodiff and XLA computation. +- `Haiku's `_ module system for neural networks, +- `Distrax `_ for probability distributions and some base bijectors, +- `Optax `_ for gradient-based optimization, +- `JAX `_ for autodiff and XLA computation. Example ------- @@ -90,6 +90,23 @@ In order to contribute: 4) test it by calling :code:`hatch run test` on the (Unix) command line, 5) submit a PR 🙂 +Citing Surjectors +----------------- + +.. code-block:: latex + + @article{dirmeier2024surjectors, + author = {Simon Dirmeier}, + title = {Surjectors: surjection layers for density estimation with normalizing flows}, + year = {2024}, + journal = {Journal of Open Source Software}, + publisher = {The Open Journal}, + volume = {9}, + number = {94}, + pages = {6188}, + doi = {10.21105/joss.06188} + } + License ------- diff --git a/docs/news.rst b/docs/news.rst index cd7524a..1542a62 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -7,4 +7,5 @@ Latest news on the development of `Surjectors`. .. note:: - No news so far :). + - 29.02.2024. Surjectors has been accepted for publication in the Journal of Open Source Software 🎉🎉🎉. + Find the paper here `10.21105/joss.06188 `_. diff --git a/docs/surjectors.rst b/docs/surjectors.rst index ca670cc..2db2170 100644 --- a/docs/surjectors.rst +++ b/docs/surjectors.rst @@ -29,6 +29,10 @@ Hence, every normalizing flow can be composed by defining these three components >>> flow = Chain([Slice(10, decoder_fn(10)), LULinear(5)]) >>> pushforward = TransformedDistribution(base_distribution, flow) +Regardless of how the chain of transformations (called :code:`flow` above) is defined, +each pushforward has access to four methods :code:`sample`, :code:`sample_and_log_prob`:code:`log_prob`, and :code:`inverse_and_log_prob`. + +The exact method declarations can be found in the API below. General ------- @@ -41,39 +45,51 @@ TransformedDistribution ~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: TransformedDistribution - :members: + :members: log_prob, sample, inverse_and_log_prob, sample_and_log_prob Chain ~~~~~ .. autoclass:: Chain - :members: __init__ + :members: Bijective layers ---------------- .. autosummary:: MaskedAutoregressive + AffineMaskedAutoregressive MaskedCoupling + AffineMaskedCoupling + RationalQuadraticSplineMaskedCoupling Permutation Autoregressive bijections ~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MaskedAutoregressive - :members: __init__ + :members: + +.. autoclass:: AffineMaskedAutoregressive + :members: Coupling bijections ~~~~~~~~~~~~~~~~~~~ .. autoclass:: MaskedCoupling - :members: __init__ + :members: + +.. autoclass:: AffineMaskedCoupling + :members: + +.. autoclass:: RationalQuadraticSplineMaskedCoupling + :members: Other bijections ~~~~~~~~~~~~~~~~ .. autoclass:: Permutation - :members: __init__ + :members: Inference surjection layers --------------------------- @@ -96,35 +112,34 @@ Coupling inference surjections ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MaskedCouplingInferenceFunnel - :members: __init__ - + :members: .. autoclass:: AffineMaskedCouplingInferenceFunnel - :members: __init__ + :members: .. autoclass:: RationalQuadraticSplineMaskedCouplingInferenceFunnel - :members: __init__ + :members: Autoregressive inference surjections ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MaskedAutoregressiveInferenceFunnel - :members: __init__ + :members: .. autoclass:: AffineMaskedAutoregressiveInferenceFunnel - :members: __init__ + :members: .. autoclass:: RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel - :members: __init__ + :members: Other inference surjections ~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LULinear - :members: __init__ + :members: .. autoclass:: MLPInferenceFunnel - :members: __init__ + :members: .. autoclass:: Slice - :members: __init__ + :members: diff --git a/pyproject.toml b/pyproject.toml index 7ce6352..00d0471 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = "Apache-2.0" homepage = "https://github.com/dirmeier/surjectors" keywords = ["normalizing flows", "surjections", "density estimation"] classifiers = [ - "Development Status :: 1 - Planning", + "Development Status :: 3 - Alpha", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", @@ -49,55 +49,27 @@ dependencies = [ [tool.hatch.envs.test] dependencies = [ - "pylint>=2.15.10", + "ruff>=0.3.0", "pytest>=7.2.0", "pytest-cov>=4.0.0" ] [tool.hatch.envs.test.scripts] -lint = 'pylint surjectors' +lint = 'ruff check surjectors' test = 'pytest -v --cov=./surjectors --cov-report=xml surjectors' -[tool.black] -line-length = 80 -extend-ignore = "E203" -target-version = ['py310'] -exclude = ''' -/( - \.eggs - | \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist -)/ -''' - +[tool.bandit] +skips = ["B101"] -[tool.isort] -profile = "black" -line_length = 80 -include_trailing_comma = true +[tool.ruff] +line-length = 80 +exclude = ["*_test.py", "docs/**", "examples/**"] -[tool.flake8] -max-line-length = 80 -extend-ignore = ["E203", "W503"] -per-file-ignores = [ - '__init__.py:F401', +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F"] +extend-select = [ + "UP", "D", "I", "PL", "S" ] -[tool.pylint.messages_control] -disable = """ -invalid-name,missing-module-docstring,R0801 -""" - -[tool.bandit] -skips = ["B101"] - -[tool.pydocstyle] +[tool.ruff.lint.pydocstyle] convention= 'google' -match = '^surjectors/.*/((?!_test).)*\.py' diff --git a/surjectors/__init__.py b/surjectors/__init__.py index b69fc0f..1105a10 100644 --- a/surjectors/__init__.py +++ b/surjectors/__init__.py @@ -1,26 +1,29 @@ -""" -surjectors: Surjection layers for density estimation with normalizing flows -""" +"""surjectors: Surjection layers for density estimation with normalizing flows.""" -__version__ = "0.3.1" +__version__ = "0.3.2" +from surjectors._src.bijectors.affine_masked_autoregressive import ( + AffineMaskedAutoregressive, +) +from surjectors._src.bijectors.affine_masked_coupling import ( + AffineMaskedCoupling, +) from surjectors._src.bijectors.lu_linear import LULinear from surjectors._src.bijectors.masked_autoregressive import MaskedAutoregressive from surjectors._src.bijectors.masked_coupling import MaskedCoupling from surjectors._src.bijectors.permutation import Permutation +from surjectors._src.bijectors.rq_masked_coupling import ( + RationalQuadraticSplineMaskedCoupling, +) from surjectors._src.distributions.transformed_distribution import ( TransformedDistribution, ) from surjectors._src.surjectors.affine_masked_autoregressive_inference_funnel import ( # noqa: E501 AffineMaskedAutoregressiveInferenceFunnel, ) -from surjectors._src.surjectors.affine_masked_coupling_generative_funnel import ( # noqa: E501 - AffineMaskedCouplingGenerativeFunnel, -) from surjectors._src.surjectors.affine_masked_coupling_inference_funnel import ( AffineMaskedCouplingInferenceFunnel, ) -from surjectors._src.surjectors.augment import Augment from surjectors._src.surjectors.chain import Chain from surjectors._src.surjectors.masked_autoregressive_inference_funnel import ( # noqa: E501 MaskedAutoregressiveInferenceFunnel, @@ -41,11 +44,14 @@ "Chain", "Permutation", "TransformedDistribution", + "AffineMaskedAutoregressive", "MaskedAutoregressive", "MaskedAutoregressiveInferenceFunnel", "AffineMaskedAutoregressiveInferenceFunnel", "RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel", "MaskedCoupling", + "AffineMaskedCoupling", + "RationalQuadraticSplineMaskedCoupling", "MaskedCouplingInferenceFunnel", "AffineMaskedCouplingInferenceFunnel", "RationalQuadraticSplineMaskedCouplingInferenceFunnel", diff --git a/surjectors/_src/bijectors/affine_masked_autoregressive.py b/surjectors/_src/bijectors/affine_masked_autoregressive.py new file mode 100644 index 0000000..fbdd9dd --- /dev/null +++ b/surjectors/_src/bijectors/affine_masked_autoregressive.py @@ -0,0 +1,45 @@ +import distrax +from jax import numpy as jnp + +from surjectors._src.bijectors.masked_autoregressive import MaskedAutoregressive +from surjectors._src.conditioners.nn.made import MADE +from surjectors.util import unstack + + +# pylint: disable=too-many-arguments,arguments-renamed +class AffineMaskedAutoregressive(MaskedAutoregressive): + """An affine masked autoregressive layer. + + Args: + conditioner: a MADE network + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: tthe number of array dimensions the bijector + operates on + + References: + .. [1] Papamakarios, George, et al. "Masked Autoregressive Flow for + Density Estimation". Advances in Neural Information Processing + Systems, 2017. + + Examples: + >>> import distrax + >>> from surjectors import AffineMaskedAutoregressive + >>> + >>> layer = AffineMaskedAutoregressive( + >>> conditioner=MADE(10, [8, 8], 2), + >>> ) + """ + + def __init__( + self, + conditioner: MADE, + event_ndims: int = 1, + inner_event_ndims: int = 0, + ): + def bijector_fn(params): + means, log_scales = unstack(params, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + super().__init__( + conditioner, bijector_fn, event_ndims, inner_event_ndims + ) diff --git a/surjectors/_src/bijectors/affine_masked_autoregressive_test.py b/surjectors/_src/bijectors/affine_masked_autoregressive_test.py new file mode 100644 index 0000000..863e1bb --- /dev/null +++ b/surjectors/_src/bijectors/affine_masked_autoregressive_test.py @@ -0,0 +1,53 @@ +# pylint: skip-file + +import distrax +import haiku as hk +from jax import numpy as jnp +from jax import random + +from surjectors import AffineMaskedAutoregressive, TransformedDistribution +from surjectors.nn import MADE + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def make_bijector(n_dimension): + def _transformation_fn(n_dimension): + bij = AffineMaskedAutoregressive( + MADE(n_dimension, [8, 8], 2), + ) + return bij + + def _flow(method, **kwargs): + td = TransformedDistribution( + _base_distribution_fn(n_dimension), _transformation_fn(n_dimension) + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def test_affine_masked_autoregressive(): + n_dimension = 4 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y) + _ = flow.apply(params, None, method="log_prob", y=y) + + +def test_conditional_affine_masked_autoregressive(): + n_dimension = 4 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + x = random.normal(random.PRNGKey(1), shape=(10, 2)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y, x=x) + _ = flow.apply(params, None, method="log_prob", y=y, x=x) diff --git a/surjectors/_src/bijectors/affine_masked_coupling.py b/surjectors/_src/bijectors/affine_masked_coupling.py new file mode 100644 index 0000000..fe0ccec --- /dev/null +++ b/surjectors/_src/bijectors/affine_masked_coupling.py @@ -0,0 +1,52 @@ +from typing import Callable, Optional + +import distrax +from jax import numpy as jnp + +from surjectors._src.bijectors.masked_coupling import MaskedCoupling +from surjectors._src.distributions.transformed_distribution import Array + + +# pylint: disable=too-many-arguments, arguments-renamed,too-many-ancestors +class AffineMaskedCoupling(MaskedCoupling): + """An affine masked coupling layer. + + Args: + mask: a boolean mask of length n_dim. A value + of True indicates that the corresponding input remains unchanged + conditioner: a function that computes the parameters of the inner + bijector + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: the number of array dimensions the inner bijector + operates on + + References: + .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. + + Examples: + >>> import distrax + >>> from surjectors import AffineMaskedCoupling + >>> from surjectors.nn import make_mlp + >>> from surjectors.util import make_alternating_binary_mask + >>> + >>> layer = MaskedCoupling( + >>> mask=make_alternating_binary_mask(10, True), + >>> conditioner=make_mlp([8, 8, 10 * 2]), + >>> ) + """ + + def __init__( + self, + mask: Array, + conditioner: Callable, + event_ndims: Optional[int] = None, + inner_event_ndims: int = 0, + ): + def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + super().__init__( + mask, conditioner, _bijector_fn, event_ndims, inner_event_ndims + ) diff --git a/surjectors/_src/bijectors/affine_masked_coupling_test.py b/surjectors/_src/bijectors/affine_masked_coupling_test.py new file mode 100644 index 0000000..f53d045 --- /dev/null +++ b/surjectors/_src/bijectors/affine_masked_coupling_test.py @@ -0,0 +1,55 @@ +# pylint: skip-file + +import distrax +import haiku as hk +from jax import numpy as jnp +from jax import random + +from surjectors import AffineMaskedCoupling, TransformedDistribution +from surjectors.nn import make_mlp +from surjectors.util import make_alternating_binary_mask + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def make_bijector(n_dimension): + def _transformation_fn(n_dimension): + bij = AffineMaskedCoupling( + make_alternating_binary_mask(n_dimension, 0 % 2 == 0), + make_mlp([8, 8, n_dimension * 2]), + ) + return bij + + def _flow(method, **kwargs): + td = TransformedDistribution( + _base_distribution_fn(n_dimension), _transformation_fn(n_dimension) + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def test_affine_masked_coupling(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y) + _ = flow.apply(params, None, method="log_prob", y=y) + + +def test_conditional_affine_masked_coupling(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + x = random.normal(random.PRNGKey(1), shape=(10, 2)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y, x=x) + _ = flow.apply(params, None, method="log_prob", y=y, x=x) diff --git a/surjectors/_src/bijectors/lu_linear.py b/surjectors/_src/bijectors/lu_linear.py index 01e1dd8..f8dc1e9 100644 --- a/surjectors/_src/bijectors/lu_linear.py +++ b/surjectors/_src/bijectors/lu_linear.py @@ -10,6 +10,11 @@ class LULinear(Bijector, hk.Module): """An bijection based on the LU composition. + Args: + n_dimension: number of dimensions to keep + with_bias: use a bias term or not + dtype: parameter dtype + References: .. [1] Oliva, Junier, et al. "Transformation Autoregressive Networks". Proceedings of the 35th International Conference on @@ -21,13 +26,6 @@ class LULinear(Bijector, hk.Module): """ def __init__(self, n_dimension, with_bias=False, dtype=jnp.float32): - """Constructs a LULinear layer. - - Args: - n_dimension: number of dimensions to keep - with_bias: use a bias term or not - dtype: parameter dtype - """ super().__init__() if with_bias: raise NotImplementedError() diff --git a/surjectors/_src/bijectors/masked_autoregressive.py b/surjectors/_src/bijectors/masked_autoregressive.py index 20e39a5..59d1e87 100644 --- a/surjectors/_src/bijectors/masked_autoregressive.py +++ b/surjectors/_src/bijectors/masked_autoregressive.py @@ -7,10 +7,18 @@ from surjectors._src.conditioners.nn.made import MADE -# pylint: disable=too-many-arguments, arguments-renamed +# pylint: disable=too-many-arguments,arguments-renamed class MaskedAutoregressive(Bijector): """A masked autoregressive layer. + Args: + conditioner: a MADE network + bijector_fn: a callable that returns the inner bijector that will + be used to transform the input + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: tthe number of array dimensions the bijector + operates on + References: .. [1] Papamakarios, George, et al. "Masked Autoregressive Flow for Density Estimation". Advances in Neural Information Processing @@ -38,16 +46,6 @@ def __init__( event_ndims: int = 1, inner_event_ndims: int = 0, ): - """Construct a masked autoregressive layer. - - Args: - conditioner: a MADE network - bijector_fn: a callable that returns the inner bijector that will - be used to transform the input - event_ndims: the number of array dimensions the bijector operates on - inner_event_ndims: tthe number of array dimensions the bijector - operates on - """ if event_ndims is not None and event_ndims < inner_event_ndims: raise ValueError( f"'event_ndims={event_ndims}' should be at least as" diff --git a/surjectors/_src/bijectors/masked_coupling.py b/surjectors/_src/bijectors/masked_coupling.py index 8480df3..73fdd3a 100644 --- a/surjectors/_src/bijectors/masked_coupling.py +++ b/surjectors/_src/bijectors/masked_coupling.py @@ -8,13 +8,20 @@ from surjectors._src.distributions.transformed_distribution import Array -# pylint: disable=too-many-arguments, arguments-renamed +# ruff: noqa: PLR0913 class MaskedCoupling(Bijector, distrax.MaskedCoupling): """A masked coupling layer. - References: - .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP". - International Conference on Learning Representations, 2017. + Args: + mask: a boolean mask of length n_dim. A value + of True indicates that the corresponding input remains unchanged + conditioner: a function that computes the parameters of the inner + bijector + bijector_fn: a callable that returns the inner bijector that will be + used to transform the input + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: the number of array dimensions the inner bijector + operates on Examples: >>> import distrax @@ -31,6 +38,10 @@ class MaskedCoupling(Bijector, distrax.MaskedCoupling): >>> bijector_fn=bijector_fn, >>> conditioner=make_mlp([8, 8, 10 * 2]), >>> ) + + References: + .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. """ def __init__( @@ -41,19 +52,6 @@ def __init__( event_ndims: Optional[int] = None, inner_event_ndims: int = 0, ): - """Construct a masked coupling layer. - - Args: - mask: a boolean mask of length n_dim. A value - of True indicates that the corresponding input remains unchanged - conditioner: a function that computes the parameters of the inner - bijector - bijector_fn: a callable that returns the inner bijector that will be - used to transform the input - event_ndims: the number of array dimensions the bijector operates on - inner_event_ndims: the number of array dimensions the inner bijector - operates on - """ super().__init__( mask, conditioner, bijector_fn, event_ndims, inner_event_ndims ) diff --git a/surjectors/_src/bijectors/permutation.py b/surjectors/_src/bijectors/permutation.py index 629cf78..1302f36 100644 --- a/surjectors/_src/bijectors/permutation.py +++ b/surjectors/_src/bijectors/permutation.py @@ -6,6 +6,11 @@ class Permutation(distrax.Bijector): """Permute the dimensions of a vector. + Args: + permutation: a vector of integer indexes representing the order of + the elements + event_ndims_in: number of input event dimensions + Examples: >>> from surjectors import Permutation >>> from jax import numpy as jnp @@ -15,40 +20,13 @@ class Permutation(distrax.Bijector): """ def __init__(self, permutation, event_ndims_in: int): - """Construct a permutation layer. - - Args: - permutation: a vector of integer indexes representing the order of - the elements - event_ndims_in: number of input event dimensions - """ super().__init__(event_ndims_in) self.permutation = permutation - def forward_and_log_det(self, z): - """Compute the forward transformation and its Jacobian determinant. - - Args: - z: event for which the forward transform and likelihood contribution - is computed - - Returns: - tuple of two arrays of floats. The first one is the forward - transformation, the second one its likelihood contribution - """ + def _forward_and_likelihood_contribution(self, z): return z[..., self.permutation], jnp.full(jnp.shape(z)[:-1], 0.0) - def inverse_and_log_det(self, y): - """Compute the inverse transformation and its Jacobian determinant. - - Args: - y: event for which the inverse and likelihood contribution is - computed - - Returns: - tuple of two arrays of floats. The first one is the inverse - transformation, the second one its likelihood contribution - """ + def _inverse_and_likelihood_contribution(self, y): size = self.permutation.size permutation_inv = ( jnp.zeros(size, dtype=jnp.result_type(int)) diff --git a/surjectors/_src/bijectors/rq_masked_coupling.py b/surjectors/_src/bijectors/rq_masked_coupling.py new file mode 100644 index 0000000..7bccd25 --- /dev/null +++ b/surjectors/_src/bijectors/rq_masked_coupling.py @@ -0,0 +1,65 @@ +from typing import Callable, Optional + +import distrax + +from surjectors._src.bijectors.masked_coupling import MaskedCoupling +from surjectors._src.distributions.transformed_distribution import Array + + +# ruff: noqa: PLR0913 +class RationalQuadraticSplineMaskedCoupling(MaskedCoupling): + """A rational quadratic spline masked coupling layer. + + References: + .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. + .. [2] Durkan, Conor, et al. "Neural Spline Flows". + Advances in Neural Information Processing Systems, 2019. + + Examples: + >>> import distrax + >>> from surjectors import RationalQuadraticSplineMaskedCoupling + >>> from surjectors.nn import make_mlp + >>> from surjectors.util import make_alternating_binary_mask + >>> + >>> layer = RationalQuadraticSplineMaskedCoupling( + >>> mask=make_alternating_binary_mask(10, True), + >>> conditioner=make_mlp([8, 8, 10 * 2]), + >>> range_min=-1.0, + >>> range_max=1.0 + >>> ) + """ + + def __init__( + self, + mask: Array, + conditioner: Callable, + range_min: float, + range_max: float, + event_ndims: Optional[int] = None, + inner_event_ndims: int = 0, + ): + """Construct a rational quadratic spline masked coupling layer. + + Args: + mask: a boolean mask of length n_dim. A value + of True indicates that the corresponding input remains unchanged + conditioner: a function that computes the parameters of the inner + bijector + range_min: minimum range of the spline + range_max: maximum range of the spline + event_ndims: the number of array dimensions the bijector operates on + inner_event_ndims: the number of array dimensions the inner bijector + operates on + """ + self.range_min = range_min + self.range_max = range_max + + def _bijector_fn(params: Array): + return distrax.RationalQuadraticSpline( + params, self.range_min, self.range_max + ) + + super().__init__( + mask, conditioner, _bijector_fn, event_ndims, inner_event_ndims + ) diff --git a/surjectors/_src/bijectors/rq_masked_coupling_test.py b/surjectors/_src/bijectors/rq_masked_coupling_test.py new file mode 100644 index 0000000..b8ad3bf --- /dev/null +++ b/surjectors/_src/bijectors/rq_masked_coupling_test.py @@ -0,0 +1,65 @@ +# pylint: skip-file + +import distrax +import haiku as hk +from jax import numpy as jnp +from jax import random + +from surjectors import ( + RationalQuadraticSplineMaskedCoupling, + TransformedDistribution, +) +from surjectors.nn import make_mlp +from surjectors.util import make_alternating_binary_mask + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def make_bijector(n_dimension): + def _transformation_fn(n_dimension): + bij = RationalQuadraticSplineMaskedCoupling( + make_alternating_binary_mask(n_dimension, 0 % 2 == 0), + hk.Sequential( + [ + make_mlp([8, 8, n_dimension * 10]), + hk.Reshape((n_dimension, 10)), + ] + ), + -1.0, + 1.0, + ) + return bij + + def _flow(method, **kwargs): + td = TransformedDistribution( + _base_distribution_fn(n_dimension), _transformation_fn(n_dimension) + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def test_rq_masked_coupling(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y) + _ = flow.apply(params, None, method="log_prob", y=y) + + +def test_conditional_rq_masked_coupling(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + x = random.normal(random.PRNGKey(1), shape=(10, 2)) + + flow = make_bijector(n_dimension) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y, x=x) + _ = flow.apply(params, None, method="log_prob", y=y, x=x) diff --git a/surjectors/_src/conditioners/mlp.py b/surjectors/_src/conditioners/mlp.py index 63db8e9..9d24f1b 100644 --- a/surjectors/_src/conditioners/mlp.py +++ b/surjectors/_src/conditioners/mlp.py @@ -3,6 +3,7 @@ from jax import numpy as jnp +# type: ignore[B008] def make_mlp( dims, activation=jax.nn.gelu, diff --git a/surjectors/_src/conditioners/nn/made.py b/surjectors/_src/conditioners/nn/made.py index 4904e01..0dc29aa 100644 --- a/surjectors/_src/conditioners/nn/made.py +++ b/surjectors/_src/conditioners/nn/made.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import haiku as hk import jax @@ -11,7 +11,7 @@ from surjectors._src.conditioners.nn.masked_linear import MaskedLinear -# pylint: disable=too-many-arguments, arguments-renamed +# ruff: noqa: PLR0913 class MADE(hk.Module): """Masked Autoregressive Density Estimator. @@ -26,7 +26,7 @@ class MADE(hk.Module): def __init__( self, input_size: int, - hidden_layer_sizes: Union[List[int], Tuple[int]], + hidden_layer_sizes: Union[list[int], tuple[int]], n_params: int, w_init: Optional[hk.initializers.Initializer] = None, b_init: Optional[hk.initializers.Initializer] = None, diff --git a/surjectors/_src/conditioners/transformer.py b/surjectors/_src/conditioners/transformer.py index 37bd068..1967e27 100644 --- a/surjectors/_src/conditioners/transformer.py +++ b/surjectors/_src/conditioners/transformer.py @@ -64,7 +64,7 @@ def __call__(self, inputs, *, is_training=True): return hk.Linear(self.output_size)(h) -# pylint: disable=too-many-arguments +# ruff: noqa: PLR0913 def make_transformer( output_size, num_heads=4, diff --git a/surjectors/_src/distributions/transformed_distribution.py b/surjectors/_src/distributions/transformed_distribution.py index 37cfb5e..2df71db 100644 --- a/surjectors/_src/distributions/transformed_distribution.py +++ b/surjectors/_src/distributions/transformed_distribution.py @@ -12,6 +12,10 @@ class TransformedDistribution: Can be used to define a pushforward measure. + Args: + base_distribution: a distribution object + transform: some transformation + Examples: >>> import distrax >>> from jax import numpy as jnp @@ -28,12 +32,6 @@ class TransformedDistribution: """ def __init__(self, base_distribution: Distribution, transform: Surjector): - """Constructs a TransformedDistribution. - - Args: - base_distribution: a distribution object - transform: some transformation - """ self.base_distribution = base_distribution self.transform = transform diff --git a/surjectors/_src/surjectors/affine_masked_autoregressive_inference_funnel.py b/surjectors/_src/surjectors/affine_masked_autoregressive_inference_funnel.py index 203ebaa..8e4de9c 100644 --- a/surjectors/_src/surjectors/affine_masked_autoregressive_inference_funnel.py +++ b/surjectors/_src/surjectors/affine_masked_autoregressive_inference_funnel.py @@ -21,13 +21,11 @@ class AffineMaskedAutoregressiveInferenceFunnel( transformation from data to latent space using a masking mechanism as in MaskedAutoegressive. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Papamakarios, George, et al. "Masked Autoregressive Flow for - Density Estimation". Advances in Neural Information Processing - Systems, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a MADE neural network Examples: >>> import distrax @@ -49,18 +47,17 @@ class AffineMaskedAutoregressiveInferenceFunnel( >>> decoder=decoder_fn(10), >>> conditioner=MADE(10, [8, 8], 2), >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Papamakarios, George, et al. "Masked Autoregressive Flow for + Density Estimation". Advances in Neural Information Processing + Systems, 2017. """ def __init__(self, n_keep: int, decoder: Callable, conditioner: MADE): - """Constructs a AffineMaskedAutoregressiveInferenceFunnel layer. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a MADE neural network - """ - def bijector_fn(params: Array): shift, log_scale = unstack(params, axis=-1) return distrax.ScalarAffine(shift, jnp.exp(log_scale)) diff --git a/surjectors/_src/surjectors/affine_masked_coupling_generative_funnel.py b/surjectors/_src/surjectors/affine_masked_coupling_generative_funnel.py index 784bea5..d0b8894 100644 --- a/surjectors/_src/surjectors/affine_masked_coupling_generative_funnel.py +++ b/surjectors/_src/surjectors/affine_masked_coupling_generative_funnel.py @@ -8,16 +8,15 @@ class AffineMaskedCouplingGenerativeFunnel(Surjector): - """A Generative funnel layer using masked affine coupling.""" + """A Generative funnel layer using masked affine coupling. - def __init__(self, n_keep, encoder, conditioner): - """Construct a AffineMaskedCouplingGenerativeFunnel layer. + Args: + n_keep: number of dimensions to keep + encoder: callable + conditioner: callable + """ - Args: - n_keep: number of dimensions to keep - encoder: callable - conditioner: callable - """ + def __init__(self, n_keep, encoder, conditioner): self.n_keep = n_keep self.encoder = encoder self.conditioner = conditioner diff --git a/surjectors/_src/surjectors/affine_masked_coupling_inference_funnel.py b/surjectors/_src/surjectors/affine_masked_coupling_inference_funnel.py index 6b04305..0a74524 100644 --- a/surjectors/_src/surjectors/affine_masked_coupling_inference_funnel.py +++ b/surjectors/_src/surjectors/affine_masked_coupling_inference_funnel.py @@ -12,12 +12,11 @@ class AffineMaskedCouplingInferenceFunnel(MaskedCouplingInferenceFunnel): """A masked coupling inference funnel that uses an affine transformation. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Dinh, Laurent, et al. "Density estimation using RealNVP". - International Conference on Learning Representations, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a conditioning neural network Examples: >>> import distrax @@ -39,18 +38,16 @@ class AffineMaskedCouplingInferenceFunnel(MaskedCouplingInferenceFunnel): >>> decoder=decoder_fn(10), >>> conditioner=make_mlp([4, 4, 10 * 2])(z), >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. """ def __init__(self, n_keep: int, decoder: Callable, conditioner: Callable): - """Constructs a AffineMaskedCouplingInferenceFunnel layer. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a conditioning neural network - """ - def bijector_fn(params: Array): shift, log_scale = jnp.split(params, 2, axis=-1) return distrax.ScalarAffine(shift, jnp.exp(log_scale)) diff --git a/surjectors/_src/surjectors/augment.py b/surjectors/_src/surjectors/augment.py index f1ce7ad..9526e0f 100644 --- a/surjectors/_src/surjectors/augment.py +++ b/surjectors/_src/surjectors/augment.py @@ -5,15 +5,14 @@ class Augment(Surjector): - """Augment generative funnel.""" + """Augment generative funnel. - def __init__(self, n_keep, encoder): - """Construct an augmentation layer. + Args: + n_keep: number of dimensions to keep + encoder: encoder callable + """ - Args: - n_keep: number of dimensions to keep - encoder: encoder callable - """ + def __init__(self, n_keep, encoder): self.n_keep = n_keep self.encoder = encoder diff --git a/surjectors/_src/surjectors/chain.py b/surjectors/_src/surjectors/chain.py index 0814ac7..98e9166 100644 --- a/surjectors/_src/surjectors/chain.py +++ b/surjectors/_src/surjectors/chain.py @@ -1,14 +1,17 @@ -from typing import List - from surjectors._src._transform import Transform from surjectors._src.surjectors.surjector import Surjector +# type: ignore[B009] class Chain(Surjector): """Chain of normalizing flows. Can be used to concatenate several normalizing flows together. + Args: + transforms: a list of transformations, such as bijections or + surjections + Examples: >>> from surjectors import Slice, Chain >>> a = Slice(10) @@ -16,13 +19,7 @@ class Chain(Surjector): >>> ab = Chain([a, b]) """ - def __init__(self, transforms: List[Transform]): - """Constructs a Chain. - - Args: - transforms: a list of transformations, such as bijections or - surjections - """ + def __init__(self, transforms: list[Transform]): self._transforms = transforms def _inverse_and_likelihood_contribution(self, y, x=None, **kwargs): diff --git a/surjectors/_src/surjectors/masked_autoregressive_inference_funnel.py b/surjectors/_src/surjectors/masked_autoregressive_inference_funnel.py index e23f46f..8044c36 100644 --- a/surjectors/_src/surjectors/masked_autoregressive_inference_funnel.py +++ b/surjectors/_src/surjectors/masked_autoregressive_inference_funnel.py @@ -17,13 +17,12 @@ class MaskedAutoregressiveInferenceFunnel(Surjector): comparison to AffineMaskedAutoregressiveInferenceFunnel and RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Papamakarios, George, et al. "Masked Autoregressive Flow for - Density Estimation". Advances in Neural Information Processing - Systems, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a MADE neural network + bijector_fn: an inner bijector function to be used Examples: >>> import distrax @@ -50,6 +49,14 @@ class MaskedAutoregressiveInferenceFunnel(Surjector): >>> conditioner=MADE(10, [8, 8], 2), >>> bijector_fn=bijector_fn >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Papamakarios, George, et al. "Masked Autoregressive Flow for + Density Estimation". Advances in Neural Information Processing + Systems, 2017. """ def __init__( @@ -59,15 +66,6 @@ def __init__( conditioner: MADE, bijector_fn: Callable, ): - """Constructs a MaskedAutoregressiveInferenceFunnel layer. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a MADE neural network - bijector_fn: an inner bijector function to be used - """ self.n_keep = n_keep self.decoder = decoder self.conditioner = conditioner diff --git a/surjectors/_src/surjectors/masked_coupling_inference_funnel.py b/surjectors/_src/surjectors/masked_coupling_inference_funnel.py index 42c67e1..7e071e1 100644 --- a/surjectors/_src/surjectors/masked_coupling_inference_funnel.py +++ b/surjectors/_src/surjectors/masked_coupling_inference_funnel.py @@ -17,12 +17,12 @@ class MaskedCouplingInferenceFunnel(Surjector): comparison to ASffineMaskedCouplingInferenceFunnel and RationalQuadraticSplineMaskedCouplingInferenceFunnel. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Dinh, Laurent, et al. "Density estimation using RealNVP". - International Conference on Learning Representations, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a conditioning neural network + bijector_fn: an inner bijector function to be used Examples: >>> import distrax @@ -48,6 +48,13 @@ class MaskedCouplingInferenceFunnel(Surjector): >>> conditioner=make_mlp([4, 4, 10 * 2]), >>> bijector_fn=bijector_fn >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. """ def __init__( @@ -57,15 +64,6 @@ def __init__( conditioner: Callable, bijector_fn: Callable, ): - """Construct a MaskedCouplingInferenceFunnel layer. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a conditioning neural network - bijector_fn: an inner bijector function to be used - """ self.n_keep = n_keep self.decoder = decoder self.conditioner = conditioner diff --git a/surjectors/_src/surjectors/mlp.py b/surjectors/_src/surjectors/mlp.py index 67c4274..719f4ea 100644 --- a/surjectors/_src/surjectors/mlp.py +++ b/surjectors/_src/surjectors/mlp.py @@ -10,10 +10,9 @@ class MLPInferenceFunnel(Surjector, hk.Module): """A multilayer perceptron inference funnel. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. + Args: + n_keep: number of dimensions to keep + decoder: a conditional probability function Examples: >>> import distrax @@ -31,15 +30,14 @@ class MLPInferenceFunnel(Surjector, hk.Module): >>> >>> decoder = decoder_fn(5) >>> a = MLPInferenceFunnel(10, decoder) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. """ def __init__(self, n_keep: int, decoder: Callable): - """Constructs a MLPInferenceFunnel layer. - - Args: - n_keep: number of dimensions to keep - decoder: a conditional probability function - """ super().__init__() self._r = LULinear(n_keep, False) self._w_prime = hk.Linear(n_keep, True) diff --git a/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py b/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py index 1b98792..79aa4f2 100644 --- a/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py +++ b/surjectors/_src/surjectors/rq_masked_autoregressive_inference_funnel.py @@ -11,21 +11,19 @@ ) -# pylint: disable=too-many-arguments, arguments-renamed +# ruff: noqa: PLR0913 class RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel( MaskedAutoregressiveInferenceFunnel ): """A masked autoregressive inference funnel that uses RQ-NSFs. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Durkan, Conor, et al. "Neural Spline Flows". - Advances in Neural Information Processing Systems, 2019. - .. [3] Papamakarios, George, et al. "Masked Autoregressive Flow for - Density Estimation". Advances in Neural Information Processing - Systems, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a conditioning neural network + range_min: minimum range of the spline + range_max: maximum range of the spline Examples: >>> import distrax @@ -48,6 +46,16 @@ class RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel( >>> decoder=decoder_fn(10), >>> conditioner=MADE(10, [8, 8], 2), >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Durkan, Conor, et al. "Neural Spline Flows". + Advances in Neural Information Processing Systems, 2019. + .. [3] Papamakarios, George, et al. "Masked Autoregressive Flow for + Density Estimation". Advances in Neural Information Processing + Systems, 2017. """ def __init__( @@ -58,17 +66,7 @@ def __init__( range_min: float, range_max: float, ): - """Constructs a RQ-NSF inference funnel. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a conditioning neural network - range_min: minimum range of the spline - range_max: maximum range of the spline - """ - warnings.warn("class has not been tested. use at own risk") + warnings.warn("class has not been tested properly. use at own risk") self.range_min = range_min self.range_max = range_max diff --git a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py index b4a53d3..e61a789 100644 --- a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py +++ b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel.py @@ -7,19 +7,19 @@ ) +# ruff: noqa: PLR0913 class RationalQuadraticSplineMaskedCouplingInferenceFunnel( MaskedCouplingInferenceFunnel ): """A masked coupling inference funnel that uses a rational quatratic spline. - References: - .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood - with dimensionality reduction". Workshop on Bayesian Deep Learning, - Advances in Neural Information Processing Systems, 2021. - .. [2] Durkan, Conor, et al. "Neural Spline Flows". - Advances in Neural Information Processing Systems, 2019. - .. [3] Dinh, Laurent, et al. "Density estimation using RealNVP". - International Conference on Learning Representations, 2017. + Args: + n_keep: number of dimensions to keep + decoder: a callable that returns a conditional probabiltiy + distribution when called + conditioner: a conditioning neural network + range_min: minimum range of the spline + range_max: maximum range of the spline Examples: >>> import distrax @@ -42,19 +42,18 @@ class RationalQuadraticSplineMaskedCouplingInferenceFunnel( >>> decoder=decoder_fn(10), >>> conditioner=make_mlp([4, 4, 10 * 2])(z), >>> ) + + References: + .. [1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood + with dimensionality reduction". Workshop on Bayesian Deep Learning, + Advances in Neural Information Processing Systems, 2021. + .. [2] Durkan, Conor, et al. "Neural Spline Flows". + Advances in Neural Information Processing Systems, 2019. + .. [3] Dinh, Laurent, et al. "Density estimation using RealNVP". + International Conference on Learning Representations, 2017. """ def __init__(self, n_keep, decoder, conditioner, range_min, range_max): - """Construct a RationalQuadraticSplineMaskedCouplingInferenceFunnel. - - Args: - n_keep: number of dimensions to keep - decoder: a callable that returns a conditional probabiltiy - distribution when called - conditioner: a conditioning neural network - range_min: minimum range of the spline - range_max: maximum range of the spline - """ self.range_min = range_min self.range_max = range_max diff --git a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel_test.py b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel_test.py index bba233c..e4d1e10 100644 --- a/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel_test.py +++ b/surjectors/_src/surjectors/rq_masked_coupling_inference_funnel_test.py @@ -1,4 +1,4 @@ -# pylint: skip-file +# type: ignore import distrax import haiku as hk diff --git a/surjectors/_src/surjectors/slice.py b/surjectors/_src/surjectors/slice.py index 5cf9763..6d3b024 100644 --- a/surjectors/_src/surjectors/slice.py +++ b/surjectors/_src/surjectors/slice.py @@ -9,10 +9,9 @@ class Slice(Surjector): """A slice funnel. - References: - .. [1] Nielsen, Didrik, et al. "SurVAE Flows: Surjections to Bridge the - Gap between VAEs and Flows". Advances in Neural Information - Processing Systems, 2020. + Args: + n_keep: number if dimensions to keep + decoder: callable Examples: >>> import distrax @@ -29,15 +28,14 @@ class Slice(Surjector): >>> return _fn >>> >>> layer = Slice(10, decoder_fn(10)) + + References: + .. [1] Nielsen, Didrik, et al. "SurVAE Flows: Surjections to Bridge the + Gap between VAEs and Flows". Advances in Neural Information + Processing Systems, 2020. """ def __init__(self, n_keep: int, decoder: Callable): - """Constructs a slice layer. - - Args: - n_keep: number if dimensions to keep - decoder: callable - """ self.n_keep = n_keep self.decoder = decoder diff --git a/surjectors/_src/surjectors/surjector.py b/surjectors/_src/surjectors/surjector.py index 0a20a17..9c8709e 100644 --- a/surjectors/_src/surjectors/surjector.py +++ b/surjectors/_src/surjectors/surjector.py @@ -10,7 +10,7 @@ class Surjector(Transform, ABC): """A surjective transformation.""" def __call__(self, method, **kwargs): - """Call the Surjector. + """Call the surjector. Depending on "method", computes - inverse, diff --git a/surjectors/util.py b/surjectors/util.py index 78055c9..828ffcf 100644 --- a/surjectors/util.py +++ b/surjectors/util.py @@ -1,3 +1,5 @@ +"""Utility functions.""" + from collections import namedtuple import numpy as np