Refactor mgrad_gaussian (#628)
* Refactor mgrad_gaussian

* fix formatting

* Add a svd_from_cov helper function
junpenglao authored Dec 12, 2023
1 parent fac1d5e commit 029b981
Showing 3 changed files with 135 additions and 47 deletions.
155 changes: 120 additions & 35 deletions blackjax/mcmc/marginal_latent_gaussian.py
@@ -22,7 +22,7 @@
 from blackjax.mcmc.proposal import static_binomial_sampling
 from blackjax.types import Array, PRNGKey
 
-__all__ = ["MarginalState", "MarginalInfo", "init_and_kernel", "mgrad_gaussian"]
+__all__ = ["MarginalState", "MarginalInfo", "init", "build_kernel", "mgrad_gaussian"]
 
 
 # [TODO](https://github.com/blackjax-devs/blackjax/issues/237)
@@ -50,6 +50,40 @@ class MarginalState(NamedTuple):
     U_grad_x: Array
 
 
+class CovarianceSVD(NamedTuple):
+    """Singular Value Decomposition of the covariance matrix.
+
+    U
+        Unitary array of the covariance matrix.
+    Gamma
+        Singular values of the covariance matrix.
+    U_t
+        Transpose of the unitary array of the covariance matrix.
+
+    """
+
+    U: Array
+    Gamma: Array
+    U_t: Array
+
+
+def svd_from_covariance(covariance: Array) -> CovarianceSVD:
+    """Compute the singular value decomposition of the covariance matrix.
+
+    Parameters
+    ----------
+    covariance
+        The covariance matrix.
+
+    Returns
+    -------
+    A ``CovarianceSVD`` object.
+
+    """
+    U, Gamma, U_t = jnp.linalg.svd(covariance, hermitian=True)
+    return CovarianceSVD(U, Gamma, U_t)
+
+
 class MarginalInfo(NamedTuple):
     """Additional information on the RMH chain.
@@ -72,28 +106,66 @@ class MarginalInfo(NamedTuple):
     proposal: MarginalState
 
 
-def init_and_kernel(logdensity_fn, covariance, mean=None):
-    """Build the marginal version of the auxiliary gradient-based sampler
+def generate_mean_shifted_logprob(logdensity_fn, mean, covariance):
+    """Generate a log-density function that is shifted by a constant.
+
+    Parameters
+    ----------
+    logdensity_fn
+        The original log-density function.
+    mean
+        The mean of the prior Gaussian density.
+    covariance
+        The covariance of the prior Gaussian density.
+
+    Returns
+    -------
+    A log-density function that is shifted by a constant.
+
+    """
+    shift = linalg.solve(covariance, mean, assume_a="pos")
+
+    def shifted_logdensity_fn(x):
+        return logdensity_fn(x) + jnp.dot(x, shift)
+
+    return shifted_logdensity_fn
+
+
+def init(position, logdensity_fn, U_t):
+    """Initialize the marginal version of the auxiliary gradient-based sampler.
+
+    Parameters
+    ----------
+    position
+        The initial position of the chain.
+    logdensity_fn
+        The logarithm of the likelihood function for the latent Gaussian model.
+    U_t
+        The unitary array of the covariance matrix.
+
+    """
+    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
+    return MarginalState(
+        position, logdensity, logdensity_grad, U_t @ position, U_t @ logdensity_grad
+    )
+
+
+def build_kernel(cov_svd: CovarianceSVD):
+    """Build the marginal version of the auxiliary gradient-based sampler.
+
+    Parameters
+    ----------
+    cov_svd
+        The singular value decomposition of the covariance matrix.
 
     Returns
     -------
     A kernel that takes a rng_key and a Pytree that contains the current state
     of the chain and that returns a new state of the chain along with
     information about the transition.
-    An init function.
 
     """
-    U, Gamma, U_t = jnp.linalg.svd(covariance, hermitian=True)
+    U, Gamma, U_t = cov_svd
 
-    if mean is not None:
-        shift = linalg.solve(covariance, mean, assume_a="pos")
-        val_and_grad = jax.value_and_grad(
-            lambda x: logdensity_fn(x) + jnp.dot(x, shift)
-        )
-    else:
-        val_and_grad = jax.value_and_grad(logdensity_fn)
-
-    def step(key: PRNGKey, state: MarginalState, delta):
+    def kernel(key: PRNGKey, state: MarginalState, logdensity_fn, delta):
         y_key, u_key = jax.random.split(key, 2)
 
         position, logdensity, logdensity_grad, U_x, U_grad_x = state
@@ -111,7 +183,7 @@ def step(key: PRNGKey, state: MarginalState, delta):
         y = U @ temp
 
         # Bookkeeping
-        log_p_y, grad_y = val_and_grad(y)
+        log_p_y, grad_y = jax.value_and_grad(logdensity_fn)(y)
         U_y = U_t @ y
         U_grad_y = U_t @ grad_y
 
@@ -131,39 +203,34 @@ def step(key: PRNGKey, state: MarginalState, delta):
         info = MarginalInfo(p_accept, do_accept, proposed_state)
         return accepted_state, info
 
-    def init(position):
-        logdensity, logdensity_grad = val_and_grad(position)
-        return MarginalState(
-            position, logdensity, logdensity_grad, U_t @ position, U_t @ logdensity_grad
-        )
-
-    return init, step
+    return kernel
 
 
 class mgrad_gaussian:
     """Implements the marginal sampler for latent Gaussian model of :cite:p:`titsias2018auxiliary`.
 
     It uses a first order approximation to the log_likelihood of a model with Gaussian prior.
-    Interestingly, the only parameter that needs calibrating is the "step size" delta, which can be done very efficiently.
+    Interestingly, the only parameter that needs calibrating is the "step size" delta,
+    which can be done very efficiently.
     Calibrating it to have an acceptance rate of roughly 50% is a good starting point.
 
     Examples
    --------
-    A new marginal latent Gaussian MCMC kernel for a model q(x) ∝ exp(f(x)) N(x; m, C) can be initialized and
-    used for a given "step size" delta with the following code:
+    A new marginal latent Gaussian MCMC kernel for a model q(x) ∝ exp(f(x)) N(x; m, C)
+    can be initialized and used for a given "step size" delta with the following code:
 
     .. code::
 
-        mgrad_gaussian = blackjax.mgrad_gaussian(f, C, use_inverse=False, mean=m)
+        mgrad_gaussian = blackjax.mgrad_gaussian(f, C, mean=m, step_size=delta)
 
         state = mgrad_gaussian.init(zeros)  # Starting at the mean of the prior
-        new_state, info = mgrad_gaussian.step(rng_key, state, delta)
+        new_state, info = mgrad_gaussian.step(rng_key, state)
 
     We can JIT-compile the step function for better performance
 
     .. code::
 
         step = jax.jit(mgrad_gaussian.step)
-        new_state, info = step(rng_key, state, delta)
+        new_state, info = step(rng_key, state)
 
     Parameters
     ----------
@@ -180,22 +247,40 @@ class mgrad_gaussian:
     """
 
+    init = staticmethod(init)
+    build_kernel = staticmethod(build_kernel)
+
     def __new__(  # type: ignore[misc]
         cls,
         logdensity_fn: Callable,
-        covariance: Array,
+        covariance: Optional[Array] = None,
         mean: Optional[Array] = None,
+        cov_svd: Optional[CovarianceSVD] = None,
+        step_size: float = 1.0,
     ) -> SamplingAlgorithm:
-        init, kernel = init_and_kernel(logdensity_fn, covariance, mean)
+        if cov_svd is None:
+            if covariance is None:
+                raise ValueError("Either covariance or cov_svd must be provided.")
+            cov_svd = svd_from_covariance(covariance)
+
+        U, Gamma, U_t = cov_svd
+
+        if mean is not None:
+            logdensity_fn = generate_mean_shifted_logprob(
+                logdensity_fn, mean, covariance
+            )
+
+        kernel = cls.build_kernel(cov_svd)
 
         def init_fn(position: Array):
-            return init(position)
+            return init(position, logdensity_fn, U_t)
 
-        def step_fn(rng_key: PRNGKey, state, delta: float):
+        def step_fn(rng_key: PRNGKey, state):
             return kernel(
                 rng_key,
                 state,
-                delta,
+                logdensity_fn,
+                step_size,
             )
 
-        return SamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]
+        return SamplingAlgorithm(init_fn, step_fn)
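
The refactor above replaces the single init_and_kernel entry point with three pieces: svd_from_covariance factorizes the covariance once, build_kernel closes over that factorization, and init and the returned kernel now take logdensity_fn (and the step size delta) as explicit arguments. A minimal sketch of how the pieces fit together, assuming blackjax at this commit; the target density, covariance, step size, and chain length are illustrative:

.. code::

    import jax
    import jax.numpy as jnp

    from blackjax.mcmc.marginal_latent_gaussian import (
        build_kernel,
        init,
        svd_from_covariance,
    )

    D = 10
    C = jnp.eye(D)  # illustrative prior covariance
    logdensity_fn = lambda x: -0.5 * jnp.sum((x - 1.0) ** 2)

    cov_svd = svd_from_covariance(C)  # factorize once, up front
    kernel = build_kernel(cov_svd)    # the kernel closes over the SVD only

    # logdensity_fn and the step size delta are now per-call arguments
    step = jax.jit(lambda key, state: kernel(key, state, logdensity_fn, 0.5))

    state = init(jnp.zeros((D,)), logdensity_fn, cov_svd.U_t)
    key = jax.random.PRNGKey(0)
    for _ in range(100):
        key, subkey = jax.random.split(key)
        state, info = step(subkey, state)
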
17 changes: 12 additions & 5 deletions tests/mcmc/test_latent_gaussian.py
@@ -6,7 +6,12 @@
 import numpy as np
 from absl.testing import absltest, parameterized
 
-from blackjax.mcmc.marginal_latent_gaussian import init_and_kernel
+from blackjax.mcmc.marginal_latent_gaussian import (
+    build_kernel,
+    generate_mean_shifted_logprob,
+    init,
+    svd_from_covariance,
+)
 
 
 class GaussianTest(chex.TestCase):
@@ -26,14 +31,16 @@ def test_gaussian(self, seed, mean):
 
         obs = jax.random.normal(key4, (D,))
         log_pdf = lambda x: stats.multivariate_normal.logpdf(x, obs, R)
+        if prior_mean is not None:
+            log_pdf = generate_mean_shifted_logprob(log_pdf, prior_mean, C)
 
         DELTA = 50.0
 
-        init, step = init_and_kernel(log_pdf, C, mean=prior_mean)
-        step = jax.jit(step)
+        cov_svd = svd_from_covariance(C)
+        _step = build_kernel(cov_svd)
+        step = jax.jit(lambda key, state, delta: _step(key, state, log_pdf, delta))
 
         init_x = np.zeros((D,))
-        init_state = init(init_x)
+        init_state = init(init_x, log_pdf, cov_svd.U_t)
 
         keys = jax.random.split(key5, n_samples)
 
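
As the test now shows, the prior mean is folded into the log-density up front by generate_mean_shifted_logprob rather than handled inside the kernel. This rests on the identity log N(x; m, C) = log N(x; 0, C) + x·C⁻¹m + const, so the target q(x) ∝ exp(f(x)) N(x; m, C) is equivalent, up to a constant, to a zero-mean problem with the shifted f. A small numerical check of that identity; this is a sketch, and f, m, and C below are illustrative:

.. code::

    import jax.numpy as jnp
    from jax.scipy import stats

    from blackjax.mcmc.marginal_latent_gaussian import generate_mean_shifted_logprob

    f = lambda x: -0.5 * jnp.sum(x**2)  # illustrative log-likelihood
    m = jnp.array([1.0, -2.0])
    C = jnp.array([[2.0, 0.5], [0.5, 1.0]])

    shifted_f = generate_mean_shifted_logprob(f, m, C)

    # The original target and its zero-mean counterpart should agree up to
    # a constant, so differences of log-densities must match.
    target = lambda x: f(x) + stats.multivariate_normal.logpdf(x, m, C)
    shifted = lambda x: shifted_f(x) + stats.multivariate_normal.logpdf(
        x, jnp.zeros(2), C
    )

    x0, x1 = jnp.array([0.3, 0.7]), jnp.array([-1.2, 0.4])
    assert jnp.allclose(
        target(x0) - target(x1), shifted(x0) - shifted(x1), atol=1e-5
    )
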
10 changes: 3 additions & 7 deletions tests/mcmc/test_sampling.py
@@ -513,13 +513,9 @@ def test_latent_gaussian(self):
         from blackjax import mgrad_gaussian
 
         inference_algorithm = mgrad_gaussian(
-            lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C
-        )
-        inference_algorithm = inference_algorithm._replace(
-            step=functools.partial(
-                inference_algorithm.step,
-                delta=self.delta,
-            )
+            lambda x: -0.5 * jnp.sum((x - 1.0) ** 2),
+            covariance=self.C,
+            step_size=self.delta,
         )
 
         initial_state = inference_algorithm.init(jnp.zeros((1,)))
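
Since the constructor now also accepts a precomputed cov_svd, the cubic-cost factorization can be computed once and shared by several samplers instead of being redone per instance. A hypothetical sketch of that pattern; the density, covariance, and step sizes are illustrative:

.. code::

    import jax.numpy as jnp

    import blackjax
    from blackjax.mcmc.marginal_latent_gaussian import svd_from_covariance

    f = lambda x: -0.5 * jnp.sum((x - 1.0) ** 2)
    C = jnp.eye(5)

    cov_svd = svd_from_covariance(C)  # factorize once...
    # ...and reuse the decomposition for samplers with different step sizes
    algo_a = blackjax.mgrad_gaussian(f, cov_svd=cov_svd, step_size=0.1)
    algo_b = blackjax.mgrad_gaussian(f, cov_svd=cov_svd, step_size=1.0)

    state = algo_a.init(jnp.zeros((5,)))
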
