diff --git a/docs/examples/howto_metropolis_within_gibbs.md b/docs/examples/howto_metropolis_within_gibbs.md index 723eb464a..0153294f1 100644 --- a/docs/examples/howto_metropolis_within_gibbs.md +++ b/docs/examples/howto_metropolis_within_gibbs.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.14.4 + jupytext_version: 1.14.7 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -67,7 +67,7 @@ In this case the conditional distributions $p(\xx \mid \yy)$ and $p(\yy \mid \xx 1. Maintain separate MCMC kernels to update each component of $p(\xx, \yy)$ while holding the other fixed. 2. Apply the kernel updates correctly. -The issue with (2) is that each kernel update for a given MCMC `Algorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RMHState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RMHState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.mcmc.rmh.build_kernel()`. However, after updating $\yy$ (via HMC), we must manually update `RMHState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$. +The issue with (2) is that each kernel update for a given MCMC `SamplingAlgorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RWState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RWState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.rmh.build_kernel()`. 
However, after updating $\yy$ (via HMC), we must manually update `RWState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$. A general way of performing this manual update is to use the `blackjax.mcmc.algorithm.init()` function of the given component's MCMC algorithm to update the `AlgorithmState`. This function has arguments `position` and `logdensity_fn`. For example, with the HMC component, after obtaining $\xx_t$ but before drawing $\yy_t$, the `position` would be $\yy_{t-1}$ and the `logdensity_fn` function would be $\log p(\xx_t, \cdot )$. @@ -77,12 +77,23 @@ Using this approach, we are now ready to implement the Gibbs sampling kernel ```{code-cell} ipython3 # MCMC initializers for each set of parameters -mwg_init_x = blackjax.mcmc.rmh.init -mwg_init_y = blackjax.mcmc.hmc.init +mwg_init_x = blackjax.rmh.init +mwg_init_y = blackjax.hmc.init # MCMC updaters -mwg_step_fn_x = blackjax.mcmc.rmh.build_kernel() -mwg_step_fn_y = blackjax.mcmc.hmc.build_kernel() # default integrator, etc. +def build_normal_rmh_kernel(rng_key, state, logdensity_fn, sigma): + """RMH kernel with a normal proposal generator + """ + kernel = blackjax.rmh.build_kernel()( + rng_key, + state, + logdensity_fn, + transition_generator=blackjax.mcmc.random_walk.normal(sigma) + ) + return kernel + +mwg_step_fn_x = build_normal_rmh_kernel +mwg_step_fn_y = blackjax.hmc.build_kernel() # default integrator, etc. def mwg_kernel(rng_key, state, parameters): @@ -239,7 +250,7 @@ def mwg_kernel_general(rng_key, state, logdensity_fn, step_fn, init, parameters) each element of which is an MCMC stepping function on the corresponding component. init Dictionary with the same keys as ``state``, - each elemtn of chi is an MCMC initializer corresponding to the stepping functions in `step_fn`. + each element of which is an MCMC initializer corresponding to the stepping functions in `step_fn`.
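The state-refresh pattern this hunk describes can be sketched without BlackJAX itself. Below, `ToyState` and `toy_init` are hypothetical stand-ins for `RWState` and `blackjax.rmh.init`, and the joint log-density of two independent standard normals is an assumed toy model:

```python
from typing import Callable, NamedTuple


class ToyState(NamedTuple):
    # Hypothetical stand-in for RWState: a position plus its cached log-density.
    position: float
    log_probability: float


def toy_init(position: float, logdensity_fn: Callable[[float], float]) -> ToyState:
    # Mirrors the role described for the algorithm's init() function: rebuild
    # the state so the cached log-density is evaluated at `position` under the
    # *current* conditional distribution.
    return ToyState(position, logdensity_fn(position))


# Toy joint log-density log p(x, y): two independent standard normals
# (up to an additive constant).
def joint_logdensity(x, y):
    return -0.5 * (x**2 + y**2)


# At the start of step t, the x-component's state caches log p(x_{t-1}, y_{t-1}).
x_prev, y_prev = 1.0, 0.5
state_x = ToyState(x_prev, joint_logdensity(x_prev, y_prev))

# After the y-component moves to y_t, that cached value is stale, so we
# refresh it via the init function with the new conditional log p(., y_t).
y_new = -0.3
state_x = toy_init(state_x.position, lambda x: joint_logdensity(x, y_new))

print(state_x.log_probability)  # now caches log p(x_{t-1}, y_t)
```

The same call shape carries over to the real initializers: `position` stays at the component's current value while `logdensity_fn` closes over the other component's freshly updated value.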
parameters Dictionary with the same keys as ``state``, each of which is a dictionary of parameters to the MCMC algorithm for the corresponding component. @@ -331,6 +342,6 @@ positions_general = sampling_loop_general( ## Developer Notes -- The update method above (using `blackjax.mcmc.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RMHState.log_density` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could save ourselves the cost of calculating the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code. +- The update method above (using `blackjax.mcmc.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example, for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RWState.log_probability` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could avoid computing the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code. - The general MWG kernel prototyped above should be adequate for problems with a small number of components.
However, the for-loop over the components of `state` gets unrolled by the JAX JIT compiler (as discussed [here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#structured-control-flow-primitives)), which can cause long compilation times when the number of components is large. To mitigate this problem, the for-loop could be replaced by a `lax.scan()` primitive. For the sake of simplicity, this approach is not fully developed here.
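As a rough illustration of that replacement (not the full heterogeneous-kernel machinery, which the note leaves undeveloped), the sketch below assumes the per-component updates share a common structure, so a single traced `lax.scan` body can stand in for the unrolled Python loop; `update_one` and `sweep` are hypothetical names:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-component update with a common structure: each scan step
# touches one coordinate of the state, standing in for one Gibbs block.
def update_one(state, idx):
    state = state.at[idx].add(0.1)
    return state, None  # (carry, per-step output)


n_components = 100
state0 = jnp.zeros(n_components)


@jax.jit
def sweep(state):
    # A Python for-loop over the components would be unrolled into
    # n_components copies of the update at trace time; lax.scan instead
    # traces the body exactly once, keeping compile times flat.
    state, _ = jax.lax.scan(update_one, state, jnp.arange(n_components))
    return state


print(sweep(state0).shape)  # (100,)
```

The catch, which is why the MWG kernel above sticks with a plain loop, is that `lax.scan` requires every step to have the same input/output pytree structure, whereas the Gibbs components here use different kernels and state types.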