Skip to content

Commit

Permalink
simplify changes
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Sep 15, 2023
1 parent 458aea6 commit 4ee52cc
Showing 1 changed file with 6 additions and 17 deletions.
23 changes: 6 additions & 17 deletions docs/examples/howto_metropolis_within_gibbs.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ In this case the conditional distributions $p(\xx \mid \yy)$ and $p(\yy \mid \xx
1. Maintain separate MCMC kernels to update each component of $p(\xx, \yy)$ while holding the other fixed.
2. Apply the kernel updates correctly.

The issue with (2) is that each kernel update for a given MCMC `SamplingAlgorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RWState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RWState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.rmh.build_kernel()`. However, after updating $\yy$ (via HMC), we must manually update `RWState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$.
The issue with (2) is that each kernel update for a given MCMC `Algorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RWState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RWState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.rmh.build_kernel()`. However, after updating $\yy$ (via HMC), we must manually update `RWState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$.

A general way of performing this manual update is to use the `blackjax.mcmc.algorithm.init()` function of the given component's MCMC algorithm to update the `AlgorithmState`. This function has arguments `position` and `logdensity_fn`. For example with the HMC component, after obtaining $\xx_t$ but before drawing $\yy_t$, the `position` would be $\yy_{t-1}$ and the `logdensity_fn` function would be $\log p(\xx_t, \cdot )$.
A general way of performing this manual update is to use the `blackjax.algorithm.init()` function of the given component's MCMC algorithm to update the `AlgorithmState`. This function has arguments `position` and `logdensity_fn`. For example with the HMC component, after obtaining $\xx_t$ but before drawing $\yy_t$, the `position` would be $\yy_{t-1}$ and the `logdensity_fn` function would be $\log p(\xx_t, \cdot )$.

Using this approach, we now are now ready to implement the Gibbs sampling kernel in the code below.

Expand All @@ -81,18 +81,7 @@ mwg_init_x = blackjax.rmh.init
mwg_init_y = blackjax.hmc.init
# MCMC updaters
def build_normal_rmh_kernel(rng_key, state, logdensity_fn, sigma):
""" RMH kernel with normal proposal generator
"""
kernel = blackjax.rmh.build_kernel()(
rng_key,
state,
logdensity_fn,
transition_generator=blackjax.mcmc.random_walk.normal(sigma)
)
return kernel
mwg_step_fn_x = build_normal_rmh_kernel
mwg_step_fn_x = blackjax.rmh.build_kernel()
mwg_step_fn_y = blackjax.hmc.build_kernel() # default integrator, etc.
Expand Down Expand Up @@ -162,7 +151,7 @@ def mwg_kernel(rng_key, state, parameters):
```{code-cell} ipython3
parameters = {
"x": {
"sigma": .2 * jnp.eye(2)
"transition_generator": blackjax.mcmc.random_walk.normal(.2 * jnp.eye(2))
},
"y": {
"inverse_mass_matrix": jnp.array([1., 1.]),
Expand Down Expand Up @@ -250,7 +239,7 @@ def mwg_kernel_general(rng_key, state, logdensity_fn, step_fn, init, parameters)
each element of which is an MCMC stepping functions on the corresponding component.
init
Dictionary with the same keys as ``state``,
each element of chi is an MCMC initializer corresponding to the stepping functions in `step_fn`.
each elemtn of chi is an MCMC initializer corresponding to the stepping functions in `step_fn`.
parameters
Dictionary with the same keys as ``state``, each of which is a dictionary of parameters to
the MCMC algorithm for the corresponding component.
Expand Down Expand Up @@ -342,6 +331,6 @@ positions_general = sampling_loop_general(

## Developer Notes

- The update method above (using `blackjax.mcmc.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RWState.log_density` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could save ourselves the cost of calculating the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code.
- The update method above (using `blackjax.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RWState.log_density` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could save ourselves the cost of calculating the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code.

- The general MWG kernel prototyped above should be adequate for problems with a small number of components. However, the for-loop over the components of `state` gets unrolled by the JAX JIT compiler (as discussed [here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#structured-control-flow-primitives)), which can cause long compilation times when the number of components is large. To mitigate this problem, the for-loop could be replaced by a `lax.scan()` primitive. For the sake of simplicity this approach is not fully developed here.

0 comments on commit 4ee52cc

Please sign in to comment.