Skip to content

Commit

Permalink
Update Metropolis-within-Gibbs example
Browse files Browse the repository at this point in the history
Update the Metropolis-within-Gibbs example markdown notebook to be
compatible with API changes.

Minor changes to text to keep consistency with code.
  • Loading branch information
Tommy Hentschel committed Jul 13, 2023
1 parent cf94b27 commit 98e4c1a
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions docs/examples/howto_metropolis_within_gibbs.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.4
jupytext_version: 1.14.7
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down Expand Up @@ -67,7 +67,7 @@ In this case the conditional distributions $p(\xx \mid \yy)$ and $p(\yy \mid \xx
1. Maintain separate MCMC kernels to update each component of $p(\xx, \yy)$ while holding the other fixed.
2. Apply the kernel updates correctly.

The issue with (2) is that each kernel update for a given MCMC `Algorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RMHState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RMHState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.mcmc.rmh.build_kernel()`. However, after updating $\yy$ (via HMC), we must manually update `RMHState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$.
The issue with (2) is that each kernel update for a given MCMC `SamplingAlgorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RWState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RWState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.rmh.build_kernel()`. However, after updating $\yy$ (via HMC), we must manually update `RWState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$.

A general way of performing this manual update is to use the `blackjax.mcmc.algorithm.init()` function of the given component's MCMC algorithm to update the `AlgorithmState`. This function has arguments `position` and `logdensity_fn`. For example with the HMC component, after obtaining $\xx_t$ but before drawing $\yy_t$, the `position` would be $\yy_{t-1}$ and the `logdensity_fn` function would be $\log p(\xx_t, \cdot )$.

Expand All @@ -77,12 +77,23 @@ Using this approach, we now are now ready to implement the Gibbs sampling kernel

```{code-cell} ipython3
# MCMC initializers for each set of paramters
mwg_init_x = blackjax.mcmc.rmh.init
mwg_init_y = blackjax.mcmc.hmc.init
mwg_init_x = blackjax.rmh.init
mwg_init_y = blackjax.hmc.init
# MCMC updaters
mwg_step_fn_x = blackjax.mcmc.rmh.build_kernel()
mwg_step_fn_y = blackjax.mcmc.hmc.build_kernel() # default integrator, etc.
def build_normal_rmh_kernel(rng_key, state, logdensity_fn, sigma):
""" RMH kernel with normal proposal generator
"""
kernel = blackjax.rmh.build_kernel()(
rng_key,
state,
logdensity_fn,
transition_generator=blackjax.mcmc.random_walk.normal(sigma)
)
return kernel
mwg_step_fn_x = build_normal_rmh_kernel
mwg_step_fn_y = blackjax.hmc.build_kernel() # default integrator, etc.
def mwg_kernel(rng_key, state, parameters):
Expand Down Expand Up @@ -239,7 +250,7 @@ def mwg_kernel_general(rng_key, state, logdensity_fn, step_fn, init, parameters)
each element of which is an MCMC stepping functions on the corresponding component.
init
Dictionary with the same keys as ``state``,
each elemtn of chi is an MCMC initializer corresponding to the stepping functions in `step_fn`.
each element of chi is an MCMC initializer corresponding to the stepping functions in `step_fn`.
parameters
Dictionary with the same keys as ``state``, each of which is a dictionary of parameters to
the MCMC algorithm for the corresponding component.
Expand Down Expand Up @@ -331,6 +342,6 @@ positions_general = sampling_loop_general(

## Developer Notes

- The update method above (using `blackjax.mcmc.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RMHState.log_density` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could save ourselves the cost of calculating the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code.
- The update method above (using `blackjax.mcmc.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RWState.log_density` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could save ourselves the cost of calculating the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code.

- The general MWG kernel prototyped above should be adequate for problems with a small number of components. However, the for-loop over the components of `state` gets unrolled by the JAX JIT compiler (as discussed [here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#structured-control-flow-primitives)), which can cause long compilation times when the number of components is large. To mitigate this problem, the for-loop could be replaced by a `lax.scan()` primitive. For the sake of simplicity this approach is not fully developed here.

0 comments on commit 98e4c1a

Please sign in to comment.