Added Dense VI! (#115)
* Implemented Dense Hessian Laplace!

* Docs fix

* Fix tests

* Updated build function and documentation in accordance with feedback

* Added Dense VI!

* Fix Docs

* Fix Docs

* Fixed docs

* Fix consistency

* Fix sample test

* Remove test notebook

* Removed positive constraint on 'cholesky' diagonal, renamed to L_factor, updated documentation and variable and function names accordingly.

* Docs fix

* Removed covariance tracking, removed shape specification from new util functions
jcqcai authored Nov 5, 2024
1 parent 30cd0e0 commit 6af158b
Showing 9 changed files with 535 additions and 3 deletions.
8 changes: 5 additions & 3 deletions docs/api/index.md
@@ -39,10 +39,12 @@ For an overview and unifying framework for SGMCMC methods, see [Ma et al, 2015](


### Variational inference (VI)
- [`vi.diag`](vi/diag.md) implements a diagonal Gaussian variational distribution.
- [`vi.dense`](vi/dense.md) implements a Gaussian variational distribution.
Expects a [`torchopt`](https://github.com/metaopt/torchopt) optimizer for handling the
minimization of the NELBO. Also find `vi.diag.nelbo` for simply calculating the NELBO
with respect to a `log_posterior` and diagonal Gaussian distribution.
minimization of the NELBO. Also find `vi.dense.nelbo` for simply calculating the NELBO
with respect to a `log_posterior` and Gaussian distribution.
- [`vi.diag`](vi/diag.md) the same as `vi.dense` but restricts the Gaussian variational
distribution to a diagonal covariance matrix.

A review of variational inference can be found in [Blei et al, 2017](https://arxiv.org/abs/1601.00670).
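As a minimal sketch (not part of this commit), both transforms can be built with the same functional `torchopt` optimizer; the `log_posterior`, learning rate and zero initial mean below are hypothetical stand-ins:

```
import torch
import torchopt
from posteriors.vi import dense, diag

def log_posterior(params, batch):
    # Toy unnormalised log posterior, returns (log_prob, aux)
    return -0.5 * (params**2).sum(), torch.tensor([])

optimizer = torchopt.adam(lr=1e-2)  # lower case functional API
dense_transform = dense.build(log_posterior, optimizer)
diag_transform = diag.build(log_posterior, optimizer)

state = dense_transform.init(torch.zeros(3))
state = dense_transform.update(state, None)  # one NELBO minimisation step
```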

7 changes: 7 additions & 0 deletions docs/api/vi/dense.md
@@ -0,0 +1,7 @@
---
title: VI Dense
---

# VI Dense

::: posteriors.vi.dense
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -81,6 +81,7 @@ nav:
- api/sgmcmc/sghmc.md
- api/sgmcmc/sgnht.md
- VI:
- Dense: api/vi/dense.md
- Diag: api/vi/diag.md
- api/optim.md
- TorchOpt: api/torchopt.md
2 changes: 2 additions & 0 deletions posteriors/__init__.py
@@ -18,6 +18,8 @@
from posteriors.utils import diag_normal_sample
from posteriors.utils import per_samplify
from posteriors.utils import is_scalar
from posteriors.utils import L_from_flat
from posteriors.utils import L_to_flat

from posteriors.tree_utils import tree_size
from posteriors.tree_utils import tree_extract
35 changes: 35 additions & 0 deletions posteriors/utils.py
@@ -890,3 +890,38 @@ def is_scalar(x: Any) -> bool:
True if x is a scalar.
"""
return isinstance(x, (int, float)) or (torch.is_tensor(x) and x.numel() == 1)


def L_from_flat(L_flat: torch.Tensor) -> torch.Tensor:
"""Returns lower triangular matrix from a flat representation of its nonzero elements.
Args:
L_flat: Flat representation of nonzero lower triangular matrix elements.
Returns:
Lower triangular matrix.
"""
# Invert k = n (n + 1) / 2 to recover the matrix dimension n
# from the number of nonzero lower triangular elements k
k = torch.tensor(L_flat.shape[0], dtype=L_flat.dtype, device=L_flat.device)
n = (-1 + (1 + 8 * k).sqrt()) / 2
num_params = round(n.item())

tril_indices = torch.tril_indices(num_params, num_params)
L = torch.zeros((num_params, num_params), dtype=L_flat.dtype, device=L_flat.device)
L[tril_indices[0], tril_indices[1]] = L_flat
return L


def L_to_flat(L: torch.Tensor) -> torch.Tensor:
"""Returns flat representation of the nonzero elements of a lower triangular matrix.
Args:
L: Lower triangular matrix.
Returns:
Flat representation of the nonzero lower triangular matrix elements.
"""

num_params = L.shape[0]
tril_indices = torch.tril_indices(num_params, num_params)
L_flat = L[tril_indices[0], tril_indices[1]].clone()
return L_flat
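As a quick illustration (not part of the commit), the flat representation of an n x n lower triangular matrix holds n(n + 1)/2 elements and round-trips exactly; the example matrix is arbitrary:

```
import torch
from posteriors import L_from_flat, L_to_flat

L = torch.tensor([[1.0, 0.0, 0.0],
                  [0.2, 1.5, 0.0],
                  [0.3, -0.4, 2.0]])
L_flat = L_to_flat(L)  # 3 * (3 + 1) / 2 = 6 nonzero elements
assert torch.equal(L_from_flat(L_flat), L)
```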
1 change: 1 addition & 0 deletions posteriors/vi/__init__.py
@@ -1 +1,2 @@
from posteriors.vi import dense
from posteriors.vi import diag
283 changes: 283 additions & 0 deletions posteriors/vi/dense.py
@@ -0,0 +1,283 @@
from typing import Callable, Any, Tuple, NamedTuple
from functools import partial
import torch
from torch.func import grad_and_value, vmap
from optree import tree_map
from optree.integration.torch import tree_ravel
import torchopt

from posteriors.types import TensorTree, Transform, LogProbFn
from posteriors.tree_utils import tree_size, tree_insert_
from posteriors.utils import (
is_scalar,
CatchAuxError,
L_from_flat,
L_to_flat,
)


def build(
log_posterior: Callable[[TensorTree, Any], float],
optimizer: torchopt.base.GradientTransformation,
temperature: float = 1.0,
n_samples: int = 1,
stl: bool = True,
init_L: torch.Tensor | float = 1.0,
) -> Transform:
"""Builds a transform for variational inference with a Normal
distribution over parameters.
Find $\\mu$ and $\\Sigma$ that minimize $\\text{KL}(N(θ| \\mu, \\Sigma) || p_T(θ))$
where $p_T(θ) \\propto \\exp( \\log p(θ) / T)$ with temperature $T$.
The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
to ensure robust scaling for a large amount of data.
For more information on variational inference see [Blei et al, 2017](https://arxiv.org/abs/1601.00670).
Args:
log_posterior: Function that takes parameters and input batch and
returns the log posterior (which can be unnormalised).
optimizer: TorchOpt functional optimizer for updating the variational
parameters. Make sure to use lower case like torchopt.adam()
temperature: Temperature to rescale (divide) log_posterior.
n_samples: Number of samples to use for Monte Carlo estimate.
stl: Whether to use the stick-the-landing estimator
from [Roeder et al](https://arxiv.org/abs/1703.09194).
init_L: Initial lower triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$.
Returns:
Dense VI transform instance.
"""
init_fn = partial(init, optimizer=optimizer, init_L=init_L)
update_fn = partial(
update,
log_posterior=log_posterior,
optimizer=optimizer,
temperature=temperature,
n_samples=n_samples,
stl=stl,
)
return Transform(init_fn, update_fn)
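A hedged sketch of calling `build` with non-default hyperparameters; the `log_posterior` stub, learning rate, temperature and `init_L` values are purely illustrative:

```
import torch
import torchopt
from posteriors.vi import dense

def log_posterior(params, batch):
    # Toy unnormalised log posterior, returns (log_prob, aux)
    return -0.5 * (params**2).sum(), torch.tensor([])

transform = dense.build(
    log_posterior,
    torchopt.adam(lr=1e-3),
    temperature=1 / 1000,  # e.g. 1 / number of data points
    n_samples=5,           # Monte Carlo samples per NELBO estimate
    stl=True,              # stick-the-landing gradient estimator
    init_L=0.1,            # scaled identity; a full lower triangular matrix also works
)
state = transform.init(torch.zeros(4))
print(state.L_factor.shape)  # torch.Size([10]), i.e. 4 * (4 + 1) / 2 nonzero elements
```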


class VIDenseState(NamedTuple):
"""State encoding a diagonal Normal variational distribution over parameters.
Attributes:
params: Mean of the variational distribution.
L_factor: Flat representation of the nonzero values of the lower
triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$, where $\\Sigma$
is the covariance matrix of the variational distribution.
opt_state: TorchOpt state storing optimizer data for updating the
variational parameters.
nelbo: Negative evidence lower bound (lower is better).
aux: Auxiliary information from the log_posterior call.
"""

params: TensorTree
L_factor: torch.Tensor
opt_state: torchopt.typing.OptState
nelbo: torch.Tensor = torch.tensor([])
aux: Any = None


def init(
params: TensorTree,
optimizer: torchopt.base.GradientTransformation,
init_L: torch.Tensor | float = 1.0,
) -> VIDenseState:
"""Initialise diagonal Normal variational distribution over parameters.
optimizer.init will be called on flattened variational parameters so hyperparameters
such as learning rate need to pre-specified through TorchOpt's functional API:
```
import torchopt
optimizer = torchopt.adam(lr=1e-2)
vi_state = init(init_mean, optimizer)
```
It's assumed maximize=False for the optimizer, so that we minimize the NELBO.
Args:
params: Initial mean of the variational distribution.
optimizer: TorchOpt functional optimizer for updating the variational
parameters. Make sure to use lower case like torchopt.adam()
init_L: Initial lower triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$,
where $\\Sigma$ is the covariance matrix of the variational distribution.
Returns:
Initial VIDenseState.
"""

num_params = tree_size(params)
if is_scalar(init_L):
init_L = init_L * torch.eye(num_params, requires_grad=True)

init_L = L_to_flat(init_L)
opt_state = optimizer.init([params, init_L])
return VIDenseState(params, init_L, opt_state)


def update(
state: VIDenseState,
batch: Any,
log_posterior: LogProbFn,
optimizer: torchopt.base.GradientTransformation,
temperature: float = 1.0,
n_samples: int = 1,
stl: bool = True,
inplace: bool = False,
) -> VIDenseState:
"""Updates the variational parameters to minimize the NELBO.
Args:
state: Current state.
batch: Input data to log_posterior.
log_posterior: Function that takes parameters and input batch and
returns the log posterior (which can be unnormalised).
optimizer: TorchOpt functional optimizer for updating the variational
parameters. Make sure to use lower case like torchopt.adam()
temperature: Temperature to rescale (divide) log_posterior.
n_samples: Number of samples to use for Monte Carlo estimate.
stl: Whether to use the stick-the-landing estimator
from [Roeder et al](https://arxiv.org/abs/1703.09194).
inplace: Whether to modify state in place.
Returns:
Updated VIDenseState.
"""

def nelbo_L_factor(m, L_flat):
return nelbo(m, L_flat, batch, log_posterior, temperature, n_samples, stl)

with torch.no_grad(), CatchAuxError():
nelbo_grads, (nelbo_val, aux) = grad_and_value(
nelbo_L_factor, argnums=(0, 1), has_aux=True
)(state.params, state.L_factor)

updates, opt_state = optimizer.update(
nelbo_grads,
state.opt_state,
params=[state.params, state.L_factor],
inplace=inplace,
)
mean, L_factor = torchopt.apply_updates(
(state.params, state.L_factor), updates, inplace=inplace
)

if inplace:
tree_insert_(state.nelbo, nelbo_val.detach())
return state._replace(aux=aux)

return VIDenseState(mean, L_factor, opt_state, nelbo_val.detach(), aux)
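The module-level `update` can also be driven directly, without `build`; a minimal sketch assuming a toy `log_posterior` and a stand-in list of batches in place of a real dataloader:

```
import torch
import torchopt
from posteriors.vi import dense

def log_posterior(params, batch):
    # Toy unnormalised log posterior, returns (log_prob, aux)
    return -0.5 * (params**2).sum(), torch.tensor([])

optimizer = torchopt.adam(lr=1e-2)
state = dense.init(torch.zeros(3), optimizer)

for batch in [None] * 50:  # stand-in for a real dataloader
    state = dense.update(state, batch, log_posterior, optimizer, n_samples=3)

print(state.nelbo)  # latest Monte Carlo NELBO estimate
```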


def nelbo(
mean: dict,
L_factor: torch.Tensor,
batch: Any,
log_posterior: LogProbFn,
temperature: float = 1.0,
n_samples: int = 1,
stl: bool = True,
) -> Tuple[float, Any]:
"""Returns the negative evidence lower bound (NELBO) for a Normal
variational distribution over the parameters of a model.
Monte Carlo estimate with `n_samples` from q.
$$
\\text{NELBO} = - 𝔼_{q(θ)}[\\log p(y|x, θ) + \\log p(θ) - T \\log q(θ)]
$$
for temperature $T$.
`log_posterior` expects to take parameters and input batch and return a scalar
as well as a TensorTree of any auxiliary information:
```
log_posterior_eval, aux = log_posterior(params, batch)
```
The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
to ensure robust scaling for a large amount of data and variable batch size.
Args:
mean: Mean of the variational distribution.
L_factor: Flat representation of the nonzero values of the lower
triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$, where $\\Sigma$
is the covariance matrix of the variational distribution.
batch: Input data to log_posterior.
log_posterior: Function that takes parameters and input batch and
returns the log posterior (which can be unnormalised).
temperature: Temperature to rescale (divide) log_posterior.
n_samples: Number of samples to use for Monte Carlo estimate.
stl: Whether to use the stick-the-landing estimator
from [Roeder et al](https://arxiv.org/abs/1703.09194).
Returns:
The sampled approximate NELBO averaged over the batch.
"""

mean_flat, unravel_func = tree_ravel(mean)
L = L_from_flat(L_factor)
cov = L @ L.T
dist = torch.distributions.MultivariateNormal(
loc=mean_flat,
covariance_matrix=cov,
validate_args=False,
)

sampled_params = dist.rsample((n_samples,))
sampled_params_tree = torch.vmap(lambda s: unravel_func(s))(sampled_params)

if stl:
mean_flat = mean_flat.detach()
L = L_from_flat(L_factor.detach())
cov = L @ L.T
# Redefine distribution with detached parameters so that log q gradients
# do not flow through the variational parameters (stick-the-landing)
dist = torch.distributions.MultivariateNormal(
loc=mean_flat,
covariance_matrix=cov,
validate_args=False,
)

# Don't use vmap for single sample, since vmap doesn't work with lots of models
if n_samples == 1:
single_param = tree_map(lambda x: x[0], sampled_params_tree)
log_p, aux = log_posterior(single_param, batch)
log_q = dist.log_prob(single_param)

else:
log_p, aux = vmap(log_posterior, (0, None), (0, 0))(sampled_params_tree, batch)
log_q = dist.log_prob(sampled_params)

return -(log_p - log_q * temperature).mean(), aux
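Since `nelbo` is exposed separately, it can be evaluated for any mean and flat L factor; a small sketch with an assumed toy `log_posterior` and identity covariance:

```
import torch
from posteriors import L_to_flat
from posteriors.vi import dense

def log_posterior(params, batch):
    # Toy unnormalised log posterior, returns (log_prob, aux)
    return -0.5 * (params**2).sum(), torch.tensor([])

mean = torch.zeros(3)
L_factor = L_to_flat(torch.eye(3))  # L = I, so covariance is the identity
nelbo_val, aux = dense.nelbo(mean, L_factor, None, log_posterior, n_samples=10)
```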


def sample(
state: VIDenseState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
"""Single sample from Normal distribution over parameters.
Args:
state: State encoding mean and covariance matrix.
sample_shape: Shape of the desired samples.
Returns:
Sample(s) from Normal distribution.
"""

mean_flat, unravel_func = tree_ravel(state.params)
L = L_from_flat(state.L_factor)
cov = L @ L.T

samples = torch.distributions.MultivariateNormal(
loc=mean_flat,
covariance_matrix=cov,
validate_args=False,
).rsample(sample_shape)

samples = torch.vmap(unravel_func)(samples)
return samples
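A short usage sketch (assuming a freshly initialised state, whose covariance is the identity from the default `init_L=1.0`):

```
import torch
import torchopt
from posteriors.vi import dense

state = dense.init(torch.zeros(3), torchopt.adam(lr=1e-2))
draws = dense.sample(state, sample_shape=torch.Size([1000]))
print(draws.shape)  # torch.Size([1000, 3])
```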