Merge pull request #87 from normal-computing/laplace_ggn
Add Laplace GGN methods
SamDuffield authored May 8, 2024
2 parents 2b0874a + 34a5d95 commit cbf0b36
Showing 13 changed files with 584 additions and 14 deletions.
4 changes: 2 additions & 2 deletions docs/api/ekf/diag_fisher.md
@@ -1,7 +1,7 @@
 ---
-title: EKF Diag Fisher
+title: EKF Diagonal Fisher
 ---
 
-# EKF Diag Fisher
+# EKF Diagonal Fisher
 
 ::: posteriors.ekf.diag_fisher
11 changes: 8 additions & 3 deletions docs/api/index.md
@@ -9,11 +9,16 @@ Natural gradient descent equivalence following [Ollivier, 2019](https://arxiv.or
 ### Laplace approximation
 - [`laplace.dense_fisher`](laplace/dense_fisher.md) calculates the empirical Fisher
 information matrix and uses it to approximate the posterior precision, i.e. a [Laplace
-approximation](https://arxiv.org/abs/2106.14806), without modification to parameters.
+approximation](https://arxiv.org/abs/2106.14806).
+- [`laplace.dense_ggn`](laplace/dense_ggn.md) calculates the Generalised
+Gauss-Newton matrix which is equivalent to the non-empirical Fisher in most
+neural network settings - see [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).
 - [`laplace.diag_fisher`](laplace/diag_fisher.md) same as `laplace.dense_fisher` but
-uses the diagonal empirical Fisher information matrix instead.
+uses the diagonal of the empirical Fisher information matrix instead.
+- [`laplace.diag_ggn`](laplace/diag_ggn.md) same as `laplace.dense_ggn` but
+uses the diagonal of the Generalised Gauss-Newton matrix instead.
 
-Comprehensive details on Laplace approximations can be found in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806).
+All Laplace transforms leave the parameters unmodified. Comprehensive details on Laplace approximations can be found in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806).
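
Editor's sketch, not part of this commit: a minimal end-to-end use of the GGN transforms documented above. The linear model, Gaussian `outer_log_likelihood`, toy batch, and `init_prec=1.0` prior precision are all illustrative assumptions:

```python
import torch
import posteriors

model = torch.nn.Linear(10, 1)
params = dict(model.named_parameters())

def forward(params, batch):
    # GGN transforms expect forward to return (output, aux)
    x, _ = batch
    return torch.func.functional_call(model, params, x), torch.tensor([])

def outer_log_likelihood(z, batch):
    # Gaussian likelihood on the model output, summed over the batch
    _, y = batch
    return torch.distributions.Normal(z, 1.0, validate_args=False).log_prob(y).sum()

transform = posteriors.laplace.dense_ggn.build(
    forward, outer_log_likelihood, init_prec=1.0  # prior precision keeps prec non-singular
)
state = transform.init(params)
for batch in [(torch.randn(8, 10), torch.randn(8, 1))]:  # stand-in for a dataloader
    state = transform.update(state, batch)

samples = posteriors.laplace.dense_ggn.sample(state, torch.Size([5]))
```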


### Stochastic gradient Markov chain Monte Carlo (SGMCMC)
7 changes: 7 additions & 0 deletions docs/api/laplace/dense_ggn.md
@@ -0,0 +1,7 @@
---
title: Laplace Dense GGN
---

# Laplace Dense GGN

::: posteriors.laplace.dense_ggn
4 changes: 2 additions & 2 deletions docs/api/laplace/diag_fisher.md
@@ -1,7 +1,7 @@
 ---
-title: Laplace Diag Fisher
+title: Laplace Diagonal Fisher
 ---
 
-# Laplace Diag Fisher
+# Laplace Diagonal Fisher
 
 ::: posteriors.laplace.diag_fisher
7 changes: 7 additions & 0 deletions docs/api/laplace/diag_ggn.md
@@ -0,0 +1,7 @@
---
title: Laplace Diagonal GGN
---

# Laplace Diagonal GGN

::: posteriors.laplace.diag_ggn
8 changes: 5 additions & 3 deletions mkdocs.yml
@@ -68,18 +68,20 @@ nav:
 - API:
   - api/index.md
   - EKF:
-    - Diag Fisher: api/ekf/diag_fisher.md
+    - Diagonal Fisher: api/ekf/diag_fisher.md
   - Laplace:
     - Dense Fisher: api/laplace/dense_fisher.md
-    - Diag Fisher: api/laplace/diag_fisher.md
+    - Dense GGN: api/laplace/dense_ggn.md
+    - Diagonal Fisher: api/laplace/diag_fisher.md
+    - Diagonal GGN: api/laplace/diag_ggn.md
   - SGMCMC:
     - api/sgmcmc/sgld.md
     - api/sgmcmc/sghmc.md
   - VI:
     - Diag: api/vi/diag.md
   - api/optim.md
   - TorchOpt: api/torchopt.md
-  - api/tree_utils.md
+  - Tree Utils: api/tree_utils.md
   - api/types.md
   - api/utils.md

2 changes: 2 additions & 0 deletions posteriors/laplace/__init__.py
@@ -1,2 +1,4 @@
 from posteriors.laplace import dense_fisher
 from posteriors.laplace import diag_fisher
+from posteriors.laplace import dense_ggn
+from posteriors.laplace import diag_ggn
8 changes: 4 additions & 4 deletions posteriors/laplace/dense_fisher.py
@@ -5,7 +5,7 @@
 from optree import tree_map
 from optree.integration.torch import tree_ravel
 
-from posteriors.types import TensorTree, Transform, LogProbFn, Tensor, TransformState
+from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
 from posteriors.tree_utils import tree_size
 from posteriors.utils import (
     per_samplify,
@@ -18,7 +18,7 @@
 def build(
     log_posterior: LogProbFn,
     per_sample: bool = False,
-    init_prec: Tensor | float = 0.0,
+    init_prec: torch.Tensor | float = 0.0,
 ) -> Transform:
     """Builds a transform for dense empirical Fisher information
     Laplace approximation.
@@ -67,13 +67,13 @@ class DenseLaplaceState(TransformState):
"""

params: TensorTree
prec: Tensor
prec: torch.Tensor
aux: Any = None


def init(
params: TensorTree,
init_prec: Tensor | float = 0.0,
init_prec: torch.Tensor | float = 0.0,
) -> DenseLaplaceState:
"""Initialise Normal distribution over parameters
with a dense precision matrix.
178 changes: 178 additions & 0 deletions posteriors/laplace/dense_ggn.py
@@ -0,0 +1,178 @@
from functools import partial
from typing import Any
import torch
from optree import tree_map
from dataclasses import dataclass
from optree.integration.torch import tree_ravel

from posteriors.types import (
    TensorTree,
    Transform,
    ForwardFn,
    OuterLogProbFn,
    TransformState,
)
from posteriors.utils import (
    tree_size,
    ggn,
    is_scalar,
    CatchAuxError,
)


def build(
    forward: ForwardFn,
    outer_log_likelihood: OuterLogProbFn,
    init_prec: TensorTree | float = 0.0,
) -> Transform:
    """Builds a transform for a Generalized Gauss-Newton (GGN)
    Laplace approximation.

    Equivalent to the (non-empirical) Fisher information matrix when
    the `outer_log_likelihood` is exponential family with natural parameter equal to
    the output from `forward`.

    `forward` should output auxiliary information (or `torch.tensor([])`),
    `outer_log_likelihood` should not.

    The GGN is defined as
    $$
    G(θ) = J_f(θ) H_l(z) J_f(θ)^T
    $$
    where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
    is a loss (negative log-likelihood) that maps the output of $f$ to a scalar output.

    More info on Fisher and GGN matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and
    their use within a Laplace approximation in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806).

    Args:
        forward: Function that takes parameters and input batch and
            returns a forward value (e.g. logits), not reduced over the batch,
            as well as auxiliary information.
        outer_log_likelihood: A function that takes the output of `forward` and batch
            then returns the log likelihood of the model output,
            with no auxiliary information.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.

    Returns:
        GGN Laplace approximation transform instance.
    """
    init_fn = partial(init, init_prec=init_prec)
    update_fn = partial(
        update, forward=forward, outer_log_likelihood=outer_log_likelihood
    )
    return Transform(init_fn, update_fn)
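
# Editor's sketch, not part of this commit: the GGN definition in the
# docstring above can be reproduced with torch.func primitives. Note that
# torch.func.jacrev lays the Jacobian out as (output dim, parameter dim),
# so the docstring's J_f H_l J_f^T reads J.T @ H @ J below. `_A`, `_p`,
# `_f` and `_l` are toy stand-ins, not library API.
_A = torch.randn(3, 5)
_p = torch.randn(5)
_f = lambda p: torch.tanh(_A @ p)  # toy forward map f: R^5 -> R^3
_l = lambda z: 0.5 * (z**2).sum()  # toy loss l (a negative log-likelihood)

_J = torch.func.jacrev(_f)(_p)  # J_f(p), shape (3, 5)
_H = torch.func.hessian(_l)(_f(_p))  # H_l(z) at z = f(p), shape (3, 3)
_G = _J.T @ _H @ _J  # dense GGN, shape (5, 5)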


@dataclass
class DenseLaplaceState(TransformState):
    """State encoding a Normal distribution over parameters,
    with a dense precision matrix.

    Args:
        params: Mean of the Normal distribution.
        prec: Precision matrix of the Normal distribution.
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    prec: torch.Tensor
    aux: Any = None


def init(
    params: TensorTree,
    init_prec: torch.Tensor | float = 0.0,
) -> DenseLaplaceState:
    """Initialise Normal distribution over parameters
    with a dense precision matrix.

    Args:
        params: Mean of the Normal distribution.
        init_prec: Initial precision matrix.
            If it is a float, it is defined as an identity matrix
            scaled by that float.

    Returns:
        Initial DenseLaplaceState.
    """

    if is_scalar(init_prec):
        num_params = tree_size(params)
        init_prec = init_prec * torch.eye(num_params, requires_grad=False)

    return DenseLaplaceState(params, init_prec)


def update(
    state: DenseLaplaceState,
    batch: Any,
    forward: ForwardFn,
    outer_log_likelihood: OuterLogProbFn,
    inplace: bool = False,
) -> DenseLaplaceState:
    """Adds GGN matrix over given batch.

    Args:
        state: Current state.
        batch: Input data to model.
        forward: Function that takes parameters and input batch and
            returns a forward value (e.g. logits), not reduced over the batch,
            as well as auxiliary information.
        outer_log_likelihood: A function that takes the output of `forward` and batch
            then returns the log likelihood of the model output,
            with no auxiliary information.
        inplace: If True, then the state is updated in place, otherwise a new state
            is returned.

    Returns:
        Updated DenseLaplaceState.
    """

    def outer_loss(z, batch):
        return -outer_log_likelihood(z, batch)

    with torch.no_grad(), CatchAuxError():
        ggn_batch, aux = ggn(
            partial(forward, batch=batch),
            partial(outer_loss, batch=batch),
            forward_has_aux=True,
            loss_has_aux=False,
            normalize=False,
        )(state.params)

    if inplace:
        state.prec += ggn_batch
        state.aux = aux
        return state
    else:
        return DenseLaplaceState(state.params, state.prec + ggn_batch, aux)
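
# Editor's note, not part of this commit: `update` accumulates, so after
# visiting batches b_1, ..., b_N the state holds
#     prec = init_prec + sum_k GGN(b_k).
# A full-dataset GGN Laplace fit is therefore a single pass over the data,
# e.g. (assuming `transform`, `params` and `dataloader` exist):
#
#     state = transform.init(params)
#     for batch in dataloader:
#         state = transform.update(state, batch, inplace=True)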


def sample(
    state: DenseLaplaceState,
    sample_shape: torch.Size = torch.Size([]),
) -> TensorTree:
    """Sample from Normal distribution over parameters.

    Args:
        state: State encoding mean and precision matrix.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from the Normal distribution.
    """
    samples = torch.distributions.MultivariateNormal(
        loc=torch.zeros(state.prec.shape[0], device=state.prec.device),
        precision_matrix=state.prec,
        validate_args=False,
    ).sample(sample_shape)
    samples = samples.flatten(end_dim=-2)  # ensure samples is 2D
    mean_flat, unravel_func = tree_ravel(state.params)
    samples += mean_flat
    samples = torch.vmap(unravel_func)(samples)
    samples = tree_map(lambda x: x.reshape(sample_shape + x.shape[-1:]), samples)
    return samples
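
# Editor's note, not part of this commit: given a fitted `state`, e.g. from
# the update loop sketched above, posterior draws are obtained with
#
#     draws = sample(state, torch.Size([100]))
#
# which returns a TensorTree with the same structure as `state.params`.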