class KineticEnergy(Protocol):
    """Signature of the kinetic-energy callables stored on a :class:`Metric`.

    The callable receives a momentum pytree and, optionally, a position
    pytree, and returns the corresponding energy.
    """

    def __call__(
        self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
    ) -> float: ...


class Metric(NamedTuple):
    """Bundle of the callables (and optional payload) that define a metric."""

    # Draws a momentum given a PRNG key and a position.
    sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
    # Evaluates the kinetic energy of a momentum.
    kinetic_energy: KineticEnergy
    # Turning criterion evaluated on the end-point momenta of a trajectory.
    check_turning: CheckTurning
    # Optional metric-specific payload; defaults to ``None`` so metrics that
    # are built from three positional callables keep working unchanged.
    data: Any = None


def gaussian_euclidean_low_rank(
    diagonal_scale_std: Array,
    eigenvectors: Array,
    eigenvalues: Array,
) -> Metric:
    r"""Euclidean metric with a "diagonal plus low-rank" inverse mass matrix.

    Hamiltonian dynamic on euclidean manifold with normally-distributed
    momentum :cite:p:`betancourt2013general`. A Newtonian hamiltonian dynamics
    is assumed.

    With :math:`S = \mathrm{diag}(\text{diagonal\_scale\_std})`, :math:`V` the
    matrix whose columns are ``eigenvectors`` and
    :math:`\Lambda = \mathrm{diag}(\text{eigenvalues})`, the inverse mass
    matrix is

    .. math::

        M^{-1} = S \left(V (\Lambda - I) V^T + I\right) S .

    It is never materialised: only matrix-vector products are evaluated, so
    the cost per product is linear in the dimension times the number of
    eigenvectors.

    The columns of :math:`V` are assumed to be orthonormal (this is not
    checked). Under that assumption the mass matrix is
    :math:`M = S^{-1} (V (\Lambda^{-1} - I) V^T + I) S^{-1}` and a momentum
    :math:`p \sim N(0, M)` is obtained from standard normal noise :math:`z` as
    :math:`p = S^{-1} (V (\Lambda^{-1/2} - I) V^T + I) z`.

    Parameters
    ----------
    diagonal_scale_std
        One-dimensional array of per-coordinate scales (the diagonal of
        :math:`S`). This should for instance correspond to the standard
        deviation of the posterior.
    eigenvectors
        Two-dimensional array of shape ``(ndim, n_eigs)`` holding an arbitrary
        number of eigenvectors as columns.
    eigenvalues
        One-dimensional array of length ``n_eigs`` with the corresponding
        eigenvalues.

    Returns
    -------
    A ``Metric`` whose fields are

    sample_momentum
        A function that generates a value for the momentum at random.
    kinetic_energy
        A function that returns the kinetic energy given the momentum.
    check_turning
        A function that determines whether a trajectory is turning back on
        itself given the values of the momentum along the trajectory.
    data
        The tuple ``(diagonal_scale_std, eigenvalues, eigenvectors)``.

    Raises
    ------
    ValueError
        If ``diagonal_scale_std`` or ``eigenvalues`` is not one-dimensional,
        if ``eigenvectors`` is not two-dimensional, or if their sizes are
        inconsistent with each other.

    """
    # Tuple unpacking doubles as a rank check: a wrong number of axes raises
    # ValueError before any size comparison happens.
    (ndim,) = jnp.shape(diagonal_scale_std)
    (ndim_, n_eigs) = jnp.shape(eigenvectors)
    if ndim != ndim_:
        raise ValueError("Shape mismatch in metric.")

    (n_eigs_,) = jnp.shape(eigenvalues)
    if n_eigs != n_eigs_:
        raise ValueError("Shape mismatch in metric.")

    # Factors of the square root of the mass matrix. They only depend on the
    # metric parameters, so compute them once rather than on every draw.
    inv_diagonal_scale = jnp.reciprocal(diagonal_scale_std)
    inv_sqrt_eigenvalues = jnp.sqrt(jnp.reciprocal(eigenvalues))

    def inner_matrix_mult(vals, vecs, x):
        """Compute ``(V (diag(vals) - I) V^T + I) x`` without forming the matrix."""
        projected = x @ vecs
        scaled = (vals - 1) * projected
        projected_back = vecs @ scaled
        return projected_back + x

    def inv_mass_matrix_mult(x):
        """Compute ``M^{-1} x = S (V (Lambda - I) V^T + I) S x``."""
        scaled = x * diagonal_scale_std
        product = inner_matrix_mult(eigenvalues, eigenvectors, scaled)
        return product * diagonal_scale_std

    def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree:
        # NOTE(review): assumes ``generate_gaussian_noise`` returns standard
        # normal noise with the same pytree structure as ``position``.
        unit_draws = generate_gaussian_noise(rng_key, position)
        # Operate on the flattened noise so that arbitrary position pytrees
        # are supported, exactly like ``kinetic_energy`` and ``is_turning``.
        flat_draws, unravel_fn = ravel_pytree(unit_draws)
        # p = S^{-1} (V (Lambda^{-1/2} - I) V^T + I) z has covariance M. The
        # diagonal factor is 1 / std (not its square root), mirroring the two
        # multiplications by ``diagonal_scale_std`` in ``inv_mass_matrix_mult``.
        flat_momentum = (
            inner_matrix_mult(inv_sqrt_eigenvalues, eigenvectors, flat_draws)
            * inv_diagonal_scale
        )
        return unravel_fn(flat_momentum)

    def kinetic_energy(
        momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
    ) -> float:
        """Return ``0.5 * p^T M^{-1} p``; ``position`` is ignored."""
        del position
        momentum, _ = ravel_pytree(momentum)
        velocity = inv_mass_matrix_mult(momentum)
        kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum)
        return kinetic_energy_val

    def is_turning(
        momentum_left: ArrayLikeTree,
        momentum_right: ArrayLikeTree,
        momentum_sum: ArrayLikeTree,
        position_left: Optional[ArrayLikeTree] = None,
        position_right: Optional[ArrayLikeTree] = None,
    ) -> bool:
        """Generalized U-turn criterion :cite:p:`betancourt2013generalizing,nuts_uturn`.

        Parameters
        ----------
        momentum_left
            Momentum of the leftmost point of the trajectory.
        momentum_right
            Momentum of the rightmost point of the trajectory.
        momentum_sum
            Sum of the momenta along the trajectory.
        position_left
            Unused by this metric.
        position_right
            Unused by this metric.

        """
        del position_left, position_right

        m_left, _ = ravel_pytree(momentum_left)
        m_right, _ = ravel_pytree(momentum_right)
        m_sum, _ = ravel_pytree(momentum_sum)

        velocity_left = inv_mass_matrix_mult(m_left)
        velocity_right = inv_mass_matrix_mult(m_right)

        # The end points only contribute half of their momentum to rho.
        rho = m_sum - (m_right + m_left) / 2
        turning_at_left = jnp.dot(velocity_left, rho) <= 0
        turning_at_right = jnp.dot(velocity_right, rho) <= 0
        return turning_at_left | turning_at_right

    return Metric(
        momentum_generator,
        kinetic_energy,
        is_turning,
        data=(diagonal_scale_std, eigenvalues, eigenvectors),
    )