From 4ee52cc0c8da456df3429deac3340d92502d468a Mon Sep 17 00:00:00 2001
From: Junpeng Lao
Date: Fri, 15 Sep 2023 10:51:00 +0200
Subject: [PATCH] simplify changes

---
 .../examples/howto_metropolis_within_gibbs.md | 23 +++++--------------
 1 file changed, 6 insertions(+), 17 deletions(-)

diff --git a/docs/examples/howto_metropolis_within_gibbs.md b/docs/examples/howto_metropolis_within_gibbs.md
index 0819ceb9d..64eafff96 100644
--- a/docs/examples/howto_metropolis_within_gibbs.md
+++ b/docs/examples/howto_metropolis_within_gibbs.md
@@ -67,9 +67,9 @@ In this case the conditional distributions $p(\xx \mid \yy)$ and $p(\yy \mid \xx
 1. Maintain separate MCMC kernels to update each component of $p(\xx, \yy)$ while holding the other fixed.
 2. Apply the kernel updates correctly.
 
-The issue with (2) is that each kernel update for a given MCMC `SamplingAlgorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RWState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RWState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.rmh.build_kernel()`. However, after updating $\yy$ (via HMC), we must manually update `RWState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$.
+The issue with (2) is that each kernel update for a given MCMC `Algorithm` in BlackJAX refers to an algorithm-specific `AlgorithmState`. For example, `RWState` is a `typing.NamedTuple` class containing elements `position` and `log_probability`. In our MWG sampling problem at the beginning of step $t$, `RWState.log_probability` will consist of $\log p(\xx_{t-1}, \yy_{t-1})$. After updating $\xx$, it will consist of $\log p(\xx_{t}, \yy_{t-1})$. This happens automatically when we call `blackjax.rmh.build_kernel()`. However, after updating $\yy$ (via HMC), we must manually update `RWState.log_probability` to consist of $\log p(\xx_{t}, \yy_{t})$.
 
-A general way of performing this manual update is to use the `blackjax.mcmc.algorithm.init()` function of the given component's MCMC algorithm to update the `AlgorithmState`. This function has arguments `position` and `logdensity_fn`. For example with the HMC component, after obtaining $\xx_t$ but before drawing $\yy_t$, the `position` would be $\yy_{t-1}$ and the `logdensity_fn` function would be $\log p(\xx_t, \cdot )$.
+A general way of performing this manual update is to use the `blackjax.algorithm.init()` function of the given component's MCMC algorithm to update the `AlgorithmState`. This function has arguments `position` and `logdensity_fn`. For example with the HMC component, after obtaining $\xx_t$ but before drawing $\yy_t$, the `position` would be $\yy_{t-1}$ and the `logdensity_fn` function would be $\log p(\xx_t, \cdot )$.
 
 Using this approach, we now are now ready to implement the Gibbs sampling kernel in the code below.
 
@@ -81,18 +81,7 @@ mwg_init_x = blackjax.rmh.init
 mwg_init_y = blackjax.hmc.init
 
 # MCMC updaters
-def build_normal_rmh_kernel(rng_key, state, logdensity_fn, sigma):
-    """ RMH kernel with normal proposal generator
-    """
-    kernel = blackjax.rmh.build_kernel()(
-        rng_key,
-        state,
-        logdensity_fn,
-        transition_generator=blackjax.mcmc.random_walk.normal(sigma)
-    )
-    return kernel
-
-mwg_step_fn_x = build_normal_rmh_kernel
+mwg_step_fn_x = blackjax.rmh.build_kernel()
 mwg_step_fn_y = blackjax.hmc.build_kernel() # default integrator, etc.
@@ -162,7 +151,7 @@ def mwg_kernel(rng_key, state, parameters):
 ```{code-cell} ipython3
 parameters = {
     "x": {
-        "sigma": .2 * jnp.eye(2)
+        "transition_generator": blackjax.mcmc.random_walk.normal(.2 * jnp.eye(2))
     },
     "y": {
         "inverse_mass_matrix": jnp.array([1., 1.]),
@@ -250,7 +239,7 @@ def mwg_kernel_general(rng_key, state, logdensity_fn, step_fn, init, parameters)
         each element of which is an MCMC stepping functions on the corresponding component.
     init
         Dictionary with the same keys as ``state``,
-        each element of chi is an MCMC initializer corresponding to the stepping functions in `step_fn`.
+        each elemtn of chi is an MCMC initializer corresponding to the stepping functions in `step_fn`.
     parameters
         Dictionary with the same keys as ``state``,
         each of which is a dictionary of parameters to the MCMC algorithm for the corresponding component.
@@ -342,6 +331,6 @@ positions_general = sampling_loop_general(
 ## Developer Notes
 
-- The update method above (using `blackjax.mcmc.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RWState.log_density` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could save ourselves the cost of calculating the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code.
+- The update method above (using `blackjax.algorithm.init()`) should work out-of-the-box for most (if not all) MCMC algorithms in BlackJAX. However, it is not optimally efficient. For example for the RMH update, after obtaining $\yy_{t-1}$ but before drawing $\xx_t$, the method above would calculate `RWState.log_density` to be $\log p(\xx_{t-1}, \yy_{t-1})$. But we've already calculated this value from the previous HMC update of $\yy_{t-1} \sim p(\yy \mid \xx_{t-1})$. So, we could save ourselves the cost of calculating the log-density twice, at the expense of a deeper understanding of the low-level components of the algorithms at hand and less generalizable code.
 
 - The general MWG kernel prototyped above should be adequate for problems with a small number of components. However, the for-loop over the components of `state` gets unrolled by the JAX JIT compiler (as discussed [here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#structured-control-flow-primitives)), which can cause long compilation times when the number of components is large. To mitigate this problem, the for-loop could be replaced by a `lax.scan()` primitive. For the sake of simplicity this approach is not fully developed here.
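
Reviewer note: the "re-run `init()` before each component update" pattern that the patched document relies on can be illustrated without BlackJAX. The block below is only a minimal sketch under stated assumptions: `State`, `init`, `rw_step`, `joint_logdensity`, and `mwg_step` are hypothetical stand-ins written in plain Python, not BlackJAX functions, and both components are updated with a random-walk Metropolis step (the document itself uses HMC for $\yy$). The toy target is a bivariate normal with correlation 0.8, chosen only so the result is easy to check.

```python
import math
import random
from typing import Callable, NamedTuple


class State(NamedTuple):
    """Hypothetical stand-in for an algorithm state: position plus cached log-density."""
    position: float
    logdensity: float


def init(position: float, logdensity_fn: Callable[[float], float]) -> State:
    """Build a state whose cached log-density matches the given conditional."""
    return State(position, logdensity_fn(position))


def rw_step(rng: random.Random, state: State,
            logdensity_fn: Callable[[float], float], sigma: float) -> State:
    """One random-walk Metropolis step that trusts the cached log-density."""
    proposal = state.position + sigma * rng.gauss(0.0, 1.0)
    proposal_logdensity = logdensity_fn(proposal)
    log_accept = proposal_logdensity - state.logdensity
    # 1 - random() lies in (0, 1], so the logarithm is always defined.
    if math.log(1.0 - rng.random()) < log_accept:
        return State(proposal, proposal_logdensity)
    return state


def joint_logdensity(x: float, y: float) -> float:
    """Toy target: unnormalized bivariate normal, unit variances, correlation 0.8."""
    rho = 0.8
    return -(x * x - 2.0 * rho * x * y + y * y) / (2.0 * (1.0 - rho * rho))


def mwg_step(rng: random.Random, x: float, y: float, sigma: float):
    """One Metropolis-within-Gibbs sweep: update x given y, then y given x."""
    def logdensity_x(x_new: float) -> float:
        return joint_logdensity(x_new, y)

    # Rebuild the state so the cache holds log p(x_{t-1}, y_{t-1}).
    # This re-evaluates a value the previous y-update already had:
    # the redundancy mentioned in the first developer note.
    state_x = init(x, logdensity_x)
    x = rw_step(rng, state_x, logdensity_x, sigma).position

    def logdensity_y(y_new: float) -> float:
        return joint_logdensity(x, y_new)

    # Rebuild against log p(x_t, .) before stepping y; without this,
    # the cached value would still refer to the old x.
    state_y = init(y, logdensity_y)
    y = rw_step(rng, state_y, logdensity_y, sigma).position
    return x, y


rng = random.Random(0)
x, y = 0.0, 0.0
samples = []
for _ in range(20000):
    x, y = mwg_step(rng, x, y, sigma=1.0)
    samples.append((x, y))
```

Rebuilding the state before every component update costs one extra log-density evaluation per update, but it keeps each kernel's cached value consistent with the current value of the other component, which is the trade-off the first developer note describes.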