Skip to content

Commit

Permalink
Include first draft of guidelines of design for developer docs w/ ske…
Browse files Browse the repository at this point in the history
…letons for sampling and approximate inference algorithms
  • Loading branch information
albcab authored and rlouf committed Mar 10, 2023
1 parent 1908453 commit 66ea6ad
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 0 deletions.
127 changes: 127 additions & 0 deletions docs/developer/approximate_inf_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple

import jax
from optax import GradientTransformation

#import basic compoments that are already implemented
#or that you have implemented with a general structure
from blackjax.base import VIAlgorithm
from blackjax.types import PRNGKey, PyTree

__all__ = ["ApproxInfState", "ApproxInfInfo", "init", "sample", "step", "approx_inf_algorithm"]


class ApproxInfState(NamedTuple):
"""State of the approximate inference algorithm.
Give an overview of the variables needed at each step and for sampling.
"""
...

class ApproxInfInfo(NamedTuple):
"""Additional information on the algorithm transition.
Given an overview of the collected values at each step of the approximation.
"""
...


def init(position: PyTree, logdensity_fn: Callable, *args, **kwargs):
#build an inital state
state = ApproxInfState(...)
return state


def step(
rng_key: PRNGKey,
state: ApproxInfInfo,
logdensity_fn: Callable,
optimizer: GradientTransformation,
*args,
**kwargs,
) -> Tuple[ApproxInfState, ApproxInfInfo]:
"""Approximate the target density using the some approximation.
Parameters
----------
List and describe its parameters.
"""
#extract the previous parameters from the state
params = ...
#generate pseudorandom keys
key_other, key_update = jax.random.split(rng_key, 2)
#update the parameters and build a new state
new_state = ApproxInfState(...)
info = ApproxInfInfo(...)

return new_state, info


def sample(rng_key: PRNGKey, state: ApproxInfState, num_samples: int = 1):
"""Sample from the approximation."""
#the sample should be a PyTree of the same structure as the `position` in the init function
samples = ...
return samples


class approx_inf_algorithm:
"""Implements the (basic) user interface for the approximate inference method.
Describe in detail the inner mechanism of the method and its use.
Example
-------
Illustrate the use of the algorithm.
Parameters
----------
List and describe its parameters.
Returns
-------
A ``VIAlgorithm``.
"""
init = staticmethod(init)
step = staticmethod(step)
sample = staticmethod(sample)

def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
optimizer: GradientTransformation,
*args,
**kwargs,
) -> VIAlgorithm:

def init_fn(position: PyTree):
return cls.init(position, optimizer, ...)

def step_fn(rng_key: PRNGKey, state):
return cls.step(
rng_key,
state,
logdensity_fn,
optimizer,
...,
)

def sample_fn(rng_key: PRNGKey, state, num_samples):
return cls.sample(rng_key, state, num_samples)

return VIAlgorithm(init_fn, step_fn, sample_fn)


#other functions that help make `init`,` `step` and/or `sample` easier to read and understand
41 changes: 41 additions & 0 deletions docs/developer/guidelines.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Developer Guidelines

## Style
In its broadest sense, an algorithm that belongs in the blackjax library should approximate integrals on a probability space. An introduction to probability theory is outside the scope of this document, but the Monte Carlo method is ever-present and important to understand. In simple terms, we want to approximate an integral with a sum. To do this, generate samples with probabilities defined by a density (continuous variable) or measure (discrete variable) function. The idea is to sample more from areas with higher probability but also from areas with low probability, just at a lower rate. You can also approximate the target density directly, using an approximation that is easier to handle, then do inference, i.e. solve integrals, with the approximation directly and use importance sampling to correct its bias.

In the following section, we’ll explain blackjax’s design of different algorithms for Monte Carlo integration. Keep in mind some basic principles:

Leverage JAX's unique strengths: functional programming and composable function-transformation approach.
Write small and general functions, compose them to create complex methods, reuse the same building blocks for similar algorithms.
Consider compatibility with the broader JAX ecosystem (Flax, Optax, GPJax).
Write code that is easy to read and understand.
Write code that is well documented, describe in detail the inner mechanism of the algorithm and its use.

## Core implementation
There are three types of sampling algorithms blackjax currently supports: Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), and Stochastic Gradient MCMC (SGMCMC); and one type of approximate inference algorithm: Variational Inference (VI). Additionally, blackjax supports adaptation algorithms that efficiently tune the hyperparameters of sampling algorithms, usually aimed at reducing autocorrelation between sequential samples.

Basic components are functions, which do specific tasks but are generally applicable, used to build all inference algorithms. When implementing a new inference algorithm, you should first break it down to its basic components, then find and use all that are already implemented *before* writing your own. A recurrent example is the Metropolis-Hastings step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In blackjax, this common accept/reject step done with two functions: first the Hastings ratio is calculated by creating a proposal using `mcmc.proposal.proposal_generator`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`.

Because JAX operates on pure functions, inference algorithms always return a NamedTuple containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of blackjax, so it must be done in a way that abstracts the uninteresting bookkeeping from the end user but allows her to access important variables at each step. The algorithms should also return a NamedTuple with important information of each iteration.

The user-facing interface of a **sampling algorithm** should work like this:
```python
import blackjax
sampling_algorithm = blackjax.sampling_algorithm(logdensity_fn, *args, **kwargs)
state = sampling_algorithm.init(initial_position)
new_state, info = sampling_algorithm.step(rng_key, state)
```
Achieve this by building from the basic skeleton of a sampling algorithm (here)[https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/sampling_algorithm.py]. Only the `sampling_algorithm` class and the `init` and `build_kernel` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary.

The user-facing interface of an **approximate inference algorithm** should work like this:
```python
import blackjax
approx_inf_algorithm = blackjax.approx_inf_algorithm(logdensity_fn, optimizer, *args, **kwargs)
state = approx_inf_algorithm.init(initial_position)
new_state, info = approx_inf_algorithm.step(rng_key, state)
#user is able to build the approximate distribution using the state, or generate samples:
position_samples = approx_inf_algorithm.sample(rng_key, state, num_samples)
```
Achieve this by building from the basic skeleton of an approximate inference algorithm (here)[https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py]. Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary.

Well documented code is essential for a useful library. Start by decomposing your algorithm into basic components, finding those that are already implemented, then implement your own and build the high-level API from basic components.
149 changes: 149 additions & 0 deletions docs/developer/sampling_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple

import jax

#import basic compoments that are already implemented
#or that you have implemented with a general structure
#for example, if you do a Metropolis-Hastings accept/reject step:
import blackjax.mcmc.proposal as proposal
from blackjax.base import MCMCSamplingAlgorithm
from blackjax.types import PRNGKey, PyTree

__all__ = ["SamplingAlgoState", "SamplingAlgoInfo", "init", "build_kernel", "sampling_algorithm"]


class SamplingAlgoState(NamedTuple):
"""State of the sampling algorithm.
Give an overview of the variables needed at each iteration of the model.
"""
...

class SamplingAlgoInfo(NamedTuple):
"""Additional information on the algorithm transition.
Given an overview of the collected values at each iteration of the model.
"""
...


def init(position: PyTree, logdensity_fn: Callable, *args, **kwargs):
#build an inital state
state = SamplingAlgoState(...)
return state


def build_kernel(*args, **kwargs):
"""Build a HMC kernel.
Parameters
----------
List and describe its parameters.
Returns
-------
Describe the kernel that is returned.
"""

def kernel(
rng_key: PRNGKey,
state: SamplingAlgoState,
logdensity_fn: Callable,
*args,
**kwargs,
) -> Tuple[SamplingAlgoState, SamplingAlgoInfo]:
"""Generate a new sample with the sampling kernel."""

#build everything you'll need
proposal_generator = sampling_algorithm_proposal(...)

#generate pseudorandom keys
key_other, key_proposal = jax.random.split(rng_key, 2)

#generate the proposal with all its parts
proposal, info = proposal_generator(key_proposal, ...)
proposal = SamplingAlgoState(...)

return proposal, info

return kernel


class sampling_algorithm:
"""Implements the (basic) user interface for the sampling kernel.
Describe in detail the inner mechanism of the algorithm and its use.
Example
-------
Illustrate the use of the algorithm.
Parameters
----------
List and describe its parameters.
Returns
-------
A ``MCMCSamplingAlgorithm``.
"""
init = staticmethod(init)
build_kernel = staticmethod(build_kernel)

def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
*args,
**kwargs,
) -> MCMCSamplingAlgorithm:
kernel = cls.build_kernel(...)

def init_fn(position: PyTree):
return cls.init(position, logdensity_fn, ...)

def step_fn(rng_key: PRNGKey, state):
return kernel(
rng_key,
state,
logdensity_fn,
...,
)

return MCMCSamplingAlgorithm(init_fn, step_fn)


#and other functions that help make `init` and/or `build_kernel` easier to read and understand
def sampling_algorithm_proposal(*args, **kwags) -> Callable:
"""Title
Description
Parameters
----------
List and describe its parameters.
Returns
-------
Describe what is returned.
"""
...

def generate(*args, **kwargs):
"""Generate a new chain state."""
sampled_state, info = ...

return sampled_state, info

return generate
9 changes: 9 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,12 @@ hidden:
---
Bibliography<bib.rst>
```

```{toctree}
---
maxdepth: 1
caption: DEVELOPER DOCUMENTATION
hidden:
---
Guidelines<developer/principles.md>
```

0 comments on commit 66ea6ad

Please sign in to comment.