Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add low-rank-modified metric #684

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 125 additions & 6 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.

"""
from typing import Callable, NamedTuple, Optional, Protocol, Union

from typing import Any, Callable, NamedTuple, Optional, Protocol, Union

import jax.numpy as jnp
import jax.scipy as jscipy
Expand All @@ -38,14 +39,18 @@
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_gaussian_noise

__all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"]
__all__ = [
"default_metric",
"gaussian_euclidean",
"gaussian_riemannian",
"gaussian_euclidean_low_rank",
]


class KineticEnergy(Protocol):
def __call__(
self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
...
) -> float: ...


class CheckTurning(Protocol):
Expand All @@ -56,14 +61,14 @@ def __call__(
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
...
) -> bool: ...


class Metric(NamedTuple):
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
kinetic_energy: KineticEnergy
check_turning: CheckTurning
data: Any = None


MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]]
Expand Down Expand Up @@ -208,6 +213,120 @@ def is_turning(
return Metric(momentum_generator, kinetic_energy, is_turning)


def gaussian_euclidean_low_rank(
diagonal_scale_std: Array,
eigenvectors: Array,
eigenvalues: Array,
) -> Metric:
r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum
:cite:p:`betancourt2013general`.

The gaussian euclidean metric is a euclidean metric further characterized
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The gaussian euclidean metric is a euclidean metric further characterized
The Gaussian Euclidean metric is a Euclidean metric further characterized

by setting the conditional probability density :math:`\pi(momentum|position)`
to follow a standard gaussian distribution. A Newtonian hamiltonian
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
to follow a standard gaussian distribution. A Newtonian hamiltonian
to follow a standard Gaussian distribution. A Newtonian Hamiltonian

dynamics is assumed.

This uses the mass matrix $(D^{-1}(V(\Sigma - I)V^T + I)D^{-1})^{-1}$.

Parameters
----------
diagonal_scale_std
The diagonal $D^{-1}$. This should for instance correspond to the standard deviation
of the posterior.
eigenvectors
An arbitrary number of eigenvectors
eigenvalues
The corresponding eigenvalues

Returns
-------
momentum_generator
A function that generates a value for the momentum at random.
kinetic_energy
A function that returns the kinetic energy given the momentum.
is_turning
A function that determines whether a trajectory is turning back on
itself given the values of the momentum along the trajectory.

"""
(ndim,) = jnp.shape(diagonal_scale_std)
(ndim_, n_eigs) = jnp.shape(eigenvectors)
if ndim != ndim_:
raise ValueError("Shape mismatch in metric.")

(n_eigs_,) = jnp.shape(eigenvalues)
if n_eigs != n_eigs_:
raise ValueError("Shape mismatch in metric.")

# Compute (V(\Sigma - I)V^T + I)x
def inner_matrix_mult(vals, vecs, x):
projected = x @ vecs
scaled = (vals - 1) * projected
projected_back = vecs @ scaled
return projected_back + x

def inv_mass_matrix_mult(x):
scaled = x * diagonal_scale_std
product = inner_matrix_mult(eigenvalues, eigenvectors, scaled)
return product * diagonal_scale_std

def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree:
unit_draws = generate_gaussian_noise(rng_key, position)
sqrt_vals = jnp.sqrt(jnp.reciprocal(eigenvalues))
sqrt_inv_diag = jnp.sqrt(jnp.reciprocal(diagonal_scale_std))
return inner_matrix_mult(sqrt_vals, eigenvectors, unit_draws) * sqrt_inv_diag

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
del position
momentum, _ = ravel_pytree(momentum)
velocity = inv_mass_matrix_mult(momentum)
kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum)
return kinetic_energy_val

def is_turning(
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
"""Generalized U-turn criterion :cite:p:`betancourt2013generalizing,nuts_uturn`.

Parameters
----------
momentum_left
Momentum of the leftmost point of the trajectory.
momentum_right
Momentum of the rightmost point of the trajectory.
momentum_sum
Sum of the momenta along the trajectory.

"""
del position_left, position_right

m_left, _ = ravel_pytree(momentum_left)
m_right, _ = ravel_pytree(momentum_right)
m_sum, _ = ravel_pytree(momentum_sum)

velocity_left = inv_mass_matrix_mult(m_left)
velocity_right = inv_mass_matrix_mult(m_right)

# rho = m_sum
rho = m_sum - (m_right + m_left) / 2
turning_at_left = jnp.dot(velocity_left, rho) <= 0
turning_at_right = jnp.dot(velocity_right, rho) <= 0
return turning_at_left | turning_at_right

return Metric(
momentum_generator,
kinetic_energy,
is_turning,
data=(diagonal_scale_std, eigenvalues, eigenvectors),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you want to store these as properties so it is easier for tuning later on? In any case I suggest removing it here, and add it in the subsequent PR when tuning is introduced (so we can discuss whether it is necessary)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that was mostly for debugging and I forgot to take it out. Sorry :-)

)


def gaussian_riemannian(
mass_matrix_fn: Callable,
) -> Metric:
Expand Down
Loading