From 52ae85b992895ec4b5a27d2b17962fc925a3ab71 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 7 May 2024 15:24:17 +0100 Subject: [PATCH 1/4] Draft diag_ggn --- posteriors/laplace/diag_ggn.py | 166 +++++++++++++++++++++++++++++++++ posteriors/types.py | 1 + 2 files changed, 167 insertions(+) create mode 100644 posteriors/laplace/diag_ggn.py diff --git a/posteriors/laplace/diag_ggn.py b/posteriors/laplace/diag_ggn.py new file mode 100644 index 00000000..4ca3d12e --- /dev/null +++ b/posteriors/laplace/diag_ggn.py @@ -0,0 +1,166 @@ +from functools import partial +from typing import Any +import torch +from optree import tree_map +from dataclasses import dataclass + +from posteriors.types import ( + TensorTree, + Transform, + ForwardFn, + OuterLogProbFn, + TransformState, +) +from posteriors.tree_utils import flexi_tree_map +from posteriors.utils import ( + diag_normal_sample, + diag_ggn, + is_scalar, + CatchAuxError, +) + + +def build( + forward: ForwardFn, + outer_log_likelihood: OuterLogProbFn, + init_prec_diag: TensorTree | float = 0.0, +) -> Transform: + """Builds a transform for diagonal Generalized Gauss-Newton (GGN) + Laplace approximation. + + Equivalent to the diagonal of the (non-emprical) Fisher information matrix when + the `outer_log_likelihood` is exponential family with natural parameter equal to + the output from `forward`. + + `forward` should output auxiliary information (or `torch.tensor([])`), + `outer_log_likelihood` should not. + + The GGN is defined as + $$ + G(θ) = J_f(θ) H_l(z) J_f(θ)^T + $$ + where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$ + is a loss function with scalar output. + + More info on Fisher and GGN matrices can be found in + [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and + their use within a Laplace approximation in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806). + + Args: + forward: Function that takes parameters and input batch and + returns a forward value (e.g. logits), not reduced over the batch, + as well as auxiliary information. + outer_log_likelihood: A function that takes the output of `forward` and batch + then returns the log likelihood of the model output, + with no auxiliary information. + init_prec_diag: Initial diagonal precision matrix. + Can be tree like params or scalar. + + Returns: + Diagonal GGN Laplace approximation transform instance. + """ + init_fn = partial(init, init_prec_diag=init_prec_diag) + update_fn = partial( + update, forward=forward, outer_log_likelihood=outer_log_likelihood + ) + return Transform(init_fn, update_fn) + + +@dataclass +class DiagLaplaceState(TransformState): + """State encoding a diagonal Normal distribution over parameters. + + Args: + params: Mean of the Normal distribution. + prec_diag: Diagonal of the precision matrix of the Normal distribution. + aux: Auxiliary information from the log_posterior call. + """ + + params: TensorTree + prec_diag: TensorTree + aux: Any = None + + +def init( + params: TensorTree, + init_prec_diag: TensorTree | float = 0.0, +) -> DiagLaplaceState: + """Initialise diagonal Normal distribution over parameters. + + Args: + params: Mean of the Normal distribution. + init_prec_diag: Initial diagonal precision matrix. + Can be tree like params or scalar. + + Returns: + Initial DiagLaplaceState. 
+ """ + if is_scalar(init_prec_diag): + init_prec_diag = tree_map( + lambda x: torch.full_like(x, init_prec_diag, requires_grad=x.requires_grad), + params, + ) + + return DiagLaplaceState(params, init_prec_diag) + + +def update( + state: DiagLaplaceState, + batch: Any, + forward: ForwardFn, + outer_log_likelihood: OuterLogProbFn, + inplace: bool = False, +) -> DiagLaplaceState: + """Adds diagonal GGN matrix of covariance summed over given batch. + + Args: + state: Current state. + batch: Input data to model. + forward: Function that takes parameters and input batch and + returns a forward value (e.g. logits), not reduced over the batch, + as well as auxiliary information. + outer_log_likelihood: A function that takes the output of `forward` and batch + then returns the log likelihood of the model output, + with no auxiliary information. + inplace: If True, then the state is updated in place, otherwise a new state + is returned. + + Returns: + Updated DiagLaplaceState. + """ + with torch.no_grad(), CatchAuxError(): + diag_ggn_batch, aux = diag_ggn( + forward, + outer_log_likelihood, + forward_has_aux=True, + loss_has_aux=False, + normalize=False, + )(state.params, batch) + + def update_func(x, y): + return x + y + + prec_diag = flexi_tree_map( + update_func, state.prec_diag, diag_ggn_batch, inplace=inplace + ) + + if inplace: + state.aux = aux + return state + return DiagLaplaceState(state.params, prec_diag, aux) + + +def sample( + state: DiagLaplaceState, sample_shape: torch.Size = torch.Size([]) +) -> TensorTree: + """Sample from diagonal Normal distribution over parameters. + + Args: + state: State encoding mean and diagonal precision. + sample_shape: Shape of the desired samples. + + Returns: + Sample(s) from Normal distribution. + """ + sd_diag = tree_map(lambda x: x.sqrt().reciprocal(), state.prec_diag) + return diag_normal_sample(state.params, sd_diag, sample_shape=sample_shape) diff --git a/posteriors/types.py b/posteriors/types.py index 80b224e0..5674ccfe 100644 --- a/posteriors/types.py +++ b/posteriors/types.py @@ -9,6 +9,7 @@ LogProbFn = Callable[[TensorTree, TensorTree], Tuple[float, TensorTree]] ForwardFn = Callable[[TensorTree, TensorTree], Tuple[Tensor, TensorTree]] +OuterLogProbFn = Callable[[TensorTree, TensorTree], float] namespace = registry.__GLOBAL_NAMESPACE From ae8d5e9ee73374412c74a65d6c3aa54f73b63964 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 8 May 2024 12:28:59 +0100 Subject: [PATCH 2/4] Test diag_ggn --- posteriors/laplace/diag_ggn.py | 10 ++-- tests/laplace/test_diag_ggn.py | 97 ++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 5 deletions(-) create mode 100644 tests/laplace/test_diag_ggn.py diff --git a/posteriors/laplace/diag_ggn.py b/posteriors/laplace/diag_ggn.py index 4ca3d12e..c48be9a3 100644 --- a/posteriors/laplace/diag_ggn.py +++ b/posteriors/laplace/diag_ggn.py @@ -28,7 +28,7 @@ def build( """Builds a transform for diagonal Generalized Gauss-Newton (GGN) Laplace approximation. - Equivalent to the diagonal of the (non-emprical) Fisher information matrix when + Equivalent to the diagonal of the (non-empirical) Fisher information matrix when the `outer_log_likelihood` is exponential family with natural parameter equal to the output from `forward`. 
@@ -130,15 +130,15 @@ def update( """ with torch.no_grad(), CatchAuxError(): diag_ggn_batch, aux = diag_ggn( - forward, - outer_log_likelihood, + partial(forward, batch=batch), + partial(outer_log_likelihood, batch=batch), forward_has_aux=True, loss_has_aux=False, normalize=False, - )(state.params, batch) + )(state.params) def update_func(x, y): - return x + y + return x - y prec_diag = flexi_tree_map( update_func, state.prec_diag, diag_ggn_batch, inplace=inplace diff --git a/tests/laplace/test_diag_ggn.py b/tests/laplace/test_diag_ggn.py new file mode 100644 index 00000000..5961e1c8 --- /dev/null +++ b/tests/laplace/test_diag_ggn.py @@ -0,0 +1,97 @@ +from functools import partial +import torch +from torch.distributions import Normal +from torch.utils.data import DataLoader, TensorDataset +from torch.func import functional_call +from optree import tree_map +from optree.integration.torch import tree_ravel + +from posteriors.laplace import diag_ggn + +from tests.scenarios import TestModel + + +def normal_log_likelihood(y_pred, batch): + y = batch[1] + return ( + Normal(y_pred, 1, validate_args=False).log_prob(y).sum() + ) # validate args introduces control flows not yet supported in torch.func.vmap + + +def forward_m(params, batch, model): + y_pred = functional_call(model, params, batch[0]) + return y_pred, torch.tensor([]) + + +def test_diag_ggn_vmap(): + torch.manual_seed(42) + model = TestModel() + + xs = torch.randn(100, 10) + ys = model(xs) + + dataloader = DataLoader( + TensorDataset(xs, ys), + batch_size=20, + ) + + forward = partial(forward_m, model=model) + + params = dict(model.named_parameters()) + + # Test inplace = False + transform = diag_ggn.build(forward, normal_log_likelihood) + laplace_state = transform.init(params) + laplace_state_prec_diag_init = tree_map(lambda x: x, laplace_state.prec_diag) + for batch in dataloader: + laplace_state = transform.update(laplace_state, batch, inplace=False) + + flat_params, unravel_fn = tree_ravel(params) + + expected = tree_map(lambda x: torch.zeros_like(x), params) + for x, y in zip(xs, ys): + with torch.no_grad(): + z = forward(params, (x, y))[0] + J = torch.func.jacrev(lambda fp: forward(unravel_fn(fp), (x, y)))( + flat_params + )[0] + H = torch.func.hessian(lambda zt: normal_log_likelihood(zt, (x, y)))(z) + G = J.T @ H @ J + expected = tree_map(lambda x, y: x - y, expected, unravel_fn(torch.diag(G))) + + for key in expected: + assert torch.allclose(expected[key], laplace_state.prec_diag[key], atol=1e-5) + assert not torch.allclose( + laplace_state.prec_diag[key], laplace_state_prec_diag_init[key] + ) + + # Also check full batch + laplace_state_fb = transform.init(params) + laplace_state_fb = transform.update(laplace_state_fb, (xs, ys)) + + for key in expected: + assert torch.allclose(expected[key], laplace_state_fb.prec_diag[key], atol=1e-5) + + # Test inplace = True + laplace_state = transform.init(params) + laplace_state_prec_diag_init = tree_map(lambda x: x, laplace_state.prec_diag) + for batch in dataloader: + laplace_state = transform.update(laplace_state, batch, inplace=True) + + for key in expected: + assert torch.allclose(expected[key], laplace_state.prec_diag[key], atol=1e-5) + assert torch.allclose( + laplace_state.prec_diag[key], laplace_state_prec_diag_init[key] + ) + + # Test sample + mean_copy = tree_map(lambda x: x.clone(), laplace_state.params) + samples = diag_ggn.sample(laplace_state, (1000,)) + samples_mean = tree_map(lambda x: x.mean(dim=0), samples) + samples_sd = tree_map(lambda x: x.std(dim=0), samples) + for 
key in samples_mean: + assert torch.allclose(samples_mean[key], laplace_state.params[key], atol=1e-1) + assert torch.allclose( + samples_sd[key], laplace_state.prec_diag[key] ** -0.5, atol=1e-1 + ) + assert torch.allclose(mean_copy[key], laplace_state.params[key]) From cb47119e7c1a6f79e59f215ee71330feb3a671ce Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 8 May 2024 13:06:12 +0100 Subject: [PATCH 3/4] Add dense GGN --- docs/api/index.md | 11 +- docs/api/laplace/dense_ggn.md | 7 ++ docs/api/laplace/diag_ggn.md | 7 ++ mkdocs.yml | 2 + posteriors/laplace/__init__.py | 2 + posteriors/laplace/dense_fisher.py | 8 +- posteriors/laplace/dense_ggn.py | 174 +++++++++++++++++++++++++++++ posteriors/laplace/diag_ggn.py | 4 +- tests/laplace/test_dense_ggn.py | 101 +++++++++++++++++ 9 files changed, 307 insertions(+), 9 deletions(-) create mode 100644 docs/api/laplace/dense_ggn.md create mode 100644 docs/api/laplace/diag_ggn.md create mode 100644 posteriors/laplace/dense_ggn.py create mode 100644 tests/laplace/test_dense_ggn.py diff --git a/docs/api/index.md b/docs/api/index.md index 8e8ddcbc..e45f9b69 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -9,11 +9,16 @@ Natural gradient descent equivalence following [Ollivier, 2019](https://arxiv.or ### Laplace approximation - [`laplace.dense_fisher`](laplace/dense_fisher.md) calculates the empirical Fisher information matrix and uses it to approximate the posterior precision, i.e. a [Laplace -approximation](https://arxiv.org/abs/2106.14806), without modification to parameters. +approximation](https://arxiv.org/abs/2106.14806). +- [`laplace.dense_ggn`](laplace/dense_ggn.md) calculates the Generalised +Gauss-Newton matrix which is equivalent to the non-empirical Fisher in most +neural network settings. - [`laplace.diag_fisher`](laplace/diag_fisher.md) same as `laplace.dense_fisher` but -uses the diagonal empirical Fisher information matrix instead. +uses the diagonal of the empirical Fisher information matrix instead. +- [`laplace.diag_ggn`](laplace/diag_ggn.md) same as `laplace.dense_ggn` but +uses the diagonal of the Generalised Gauss-Newton matrix instead. -Comprehensive details on Laplace approximations can be found in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806). +All Laplace transforms leave the parameters unmodified. Comprehensive details on Laplace approximations can be found in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806). 
### Stochastic gradient Markov chain Monte Carlo (SGMCMC) diff --git a/docs/api/laplace/dense_ggn.md b/docs/api/laplace/dense_ggn.md new file mode 100644 index 00000000..fe631f6d --- /dev/null +++ b/docs/api/laplace/dense_ggn.md @@ -0,0 +1,7 @@ +--- +title: Laplace Dense GGN +--- + +# Laplace Dense GGN + +::: posteriors.laplace.dense_ggn \ No newline at end of file diff --git a/docs/api/laplace/diag_ggn.md b/docs/api/laplace/diag_ggn.md new file mode 100644 index 00000000..1eefd1db --- /dev/null +++ b/docs/api/laplace/diag_ggn.md @@ -0,0 +1,7 @@ +--- +title: Laplace Diag GGN +--- + +# Laplace Diag GGN + +::: posteriors.laplace.diag_ggn \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index e99db0eb..f0df04f9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -71,7 +71,9 @@ nav: - Diag Fisher: api/ekf/diag_fisher.md - Laplace: - Dense Fisher: api/laplace/dense_fisher.md + - Dense GGN: api/laplace/dense_ggn.md - Diag Fisher: api/laplace/diag_fisher.md + - Diag GGN: api/laplace/diag_ggn.md - SGMCMC: - api/sgmcmc/sgld.md - api/sgmcmc/sghmc.md diff --git a/posteriors/laplace/__init__.py b/posteriors/laplace/__init__.py index 3a8a9beb..1b77e64e 100644 --- a/posteriors/laplace/__init__.py +++ b/posteriors/laplace/__init__.py @@ -1,2 +1,4 @@ from posteriors.laplace import dense_fisher from posteriors.laplace import diag_fisher +from posteriors.laplace import dense_ggn +from posteriors.laplace import diag_ggn diff --git a/posteriors/laplace/dense_fisher.py b/posteriors/laplace/dense_fisher.py index 7a598a4b..b15f39b2 100644 --- a/posteriors/laplace/dense_fisher.py +++ b/posteriors/laplace/dense_fisher.py @@ -5,7 +5,7 @@ from optree import tree_map from optree.integration.torch import tree_ravel -from posteriors.types import TensorTree, Transform, LogProbFn, Tensor, TransformState +from posteriors.types import TensorTree, Transform, LogProbFn, TransformState from posteriors.tree_utils import tree_size from posteriors.utils import ( per_samplify, @@ -18,7 +18,7 @@ def build( log_posterior: LogProbFn, per_sample: bool = False, - init_prec: Tensor | float = 0.0, + init_prec: torch.Tensor | float = 0.0, ) -> Transform: """Builds a transform for dense empirical Fisher information Laplace approximation. @@ -67,13 +67,13 @@ class DenseLaplaceState(TransformState): """ params: TensorTree - prec: Tensor + prec: torch.Tensor aux: Any = None def init( params: TensorTree, - init_prec: Tensor | float = 0.0, + init_prec: torch.Tensor | float = 0.0, ) -> DenseLaplaceState: """Initialise Normal distribution over parameters with a dense precision matrix. diff --git a/posteriors/laplace/dense_ggn.py b/posteriors/laplace/dense_ggn.py new file mode 100644 index 00000000..77a5023f --- /dev/null +++ b/posteriors/laplace/dense_ggn.py @@ -0,0 +1,174 @@ +from functools import partial +from typing import Any +import torch +from optree import tree_map +from dataclasses import dataclass +from optree.integration.torch import tree_ravel + +from posteriors.types import ( + TensorTree, + Transform, + ForwardFn, + OuterLogProbFn, + TransformState, +) +from posteriors.utils import ( + tree_size, + ggn, + is_scalar, + CatchAuxError, +) + + +def build( + forward: ForwardFn, + outer_log_likelihood: OuterLogProbFn, + init_prec: TensorTree | float = 0.0, +) -> Transform: + """Builds a transform for a Generalized Gauss-Newton (GGN) + Laplace approximation. 
+ + Equivalent to the (non-empirical) Fisher information matrix when + the `outer_log_likelihood` is exponential family with natural parameter equal to + the output from `forward`. + + `forward` should output auxiliary information (or `torch.tensor([])`), + `outer_log_likelihood` should not. + + The GGN is defined as + $$ + G(θ) = J_f(θ) H_l(z) J_f(θ)^T + $$ + where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$ + is a negative outer log-likelihood with scalar output. + + More info on Fisher and GGN matrices can be found in + [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and + their use within a Laplace approximation in [Daxberger et al, 2021](https://arxiv.org/abs/2106.14806). + + Args: + forward: Function that takes parameters and input batch and + returns a forward value (e.g. logits), not reduced over the batch, + as well as auxiliary information. + outer_log_likelihood: A function that takes the output of `forward` and batch + then returns the log likelihood of the model output, + with no auxiliary information. + init_prec: Initial precision matrix. + If it is a float, it is defined as an identity matrix + scaled by that float. + + Returns: + GGN Laplace approximation transform instance. + """ + init_fn = partial(init, init_prec=init_prec) + update_fn = partial( + update, forward=forward, outer_log_likelihood=outer_log_likelihood + ) + return Transform(init_fn, update_fn) + + +@dataclass +class DenseLaplaceState(TransformState): + """State encoding a Normal distribution over parameters, + with a dense precision matrix + + Args: + params: Mean of the Normal distribution. + prec: Precision matrix of the Normal distribution. + aux: Auxiliary information from the log_posterior call. + """ + + params: TensorTree + prec: torch.Tensor + aux: Any = None + + +def init( + params: TensorTree, + init_prec: torch.Tensor | float = 0.0, +) -> DenseLaplaceState: + """Initialise Normal distribution over parameters + with a dense precision matrix. + + Args: + params: Mean of the Normal distribution. + init_prec: Initial precision matrix. + If it is a float, it is defined as an identity matrix + scaled by that float. + + Returns: + Initial DenseLaplaceState. + """ + + if is_scalar(init_prec): + num_params = tree_size(params) + init_prec = init_prec * torch.eye(num_params, requires_grad=False) + + return DenseLaplaceState(params, init_prec) + + +def update( + state: DenseLaplaceState, + batch: Any, + forward: ForwardFn, + outer_log_likelihood: OuterLogProbFn, + inplace: bool = False, +) -> DenseLaplaceState: + """Adds GGN matrix over given batch. + + Args: + state: Current state. + batch: Input data to model. + forward: Function that takes parameters and input batch and + returns a forward value (e.g. logits), not reduced over the batch, + as well as auxiliary information. + outer_log_likelihood: A function that takes the output of `forward` and batch + then returns the log likelihood of the model output, + with no auxiliary information. + inplace: If True, then the state is updated in place, otherwise a new state + is returned. + + Returns: + Updated DenseLaplaceState. 
+ """ + with torch.no_grad(), CatchAuxError(): + ggn_batch, aux = ggn( + partial(forward, batch=batch), + partial(outer_log_likelihood, batch=batch), + forward_has_aux=True, + loss_has_aux=False, + normalize=False, + )(state.params) + + if inplace: + state.prec -= ggn_batch + state.aux = aux + return state + else: + return DenseLaplaceState(state.params, state.prec - ggn_batch, aux) + + +def sample( + state: DenseLaplaceState, + sample_shape: torch.Size = torch.Size([]), +) -> TensorTree: + """Sample from Normal distribution over parameters. + + Args: + state: State encoding mean and precision matrix. + sample_shape: Shape of the desired samples. + + Returns: + Sample(s) from the Normal distribution. + """ + samples = torch.distributions.MultivariateNormal( + loc=torch.zeros(state.prec.shape[0], device=state.prec.device), + precision_matrix=state.prec, + validate_args=False, + ).sample(sample_shape) + samples = samples.flatten(end_dim=-2) # ensure samples is 2D + mean_flat, unravel_func = tree_ravel(state.params) + samples += mean_flat + samples = torch.vmap(unravel_func)(samples) + samples = tree_map(lambda x: x.reshape(sample_shape + x.shape[-1:]), samples) + return samples diff --git a/posteriors/laplace/diag_ggn.py b/posteriors/laplace/diag_ggn.py index c48be9a3..7f75407b 100644 --- a/posteriors/laplace/diag_ggn.py +++ b/posteriors/laplace/diag_ggn.py @@ -25,7 +25,7 @@ def build( outer_log_likelihood: OuterLogProbFn, init_prec_diag: TensorTree | float = 0.0, ) -> Transform: - """Builds a transform for diagonal Generalized Gauss-Newton (GGN) + """Builds a transform for a diagonal Generalized Gauss-Newton (GGN) Laplace approximation. Equivalent to the diagonal of the (non-empirical) Fisher information matrix when @@ -40,7 +40,7 @@ def build( G(θ) = J_f(θ) H_l(z) J_f(θ)^T $$ where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$ - is a loss function with scalar output. + is a negative outer log-likelihood with scalar output. 
More info on Fisher and GGN matrices can be found in [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and diff --git a/tests/laplace/test_dense_ggn.py b/tests/laplace/test_dense_ggn.py new file mode 100644 index 00000000..d7b3909f --- /dev/null +++ b/tests/laplace/test_dense_ggn.py @@ -0,0 +1,101 @@ +from functools import partial +import torch +from torch.distributions import Normal +from torch.utils.data import DataLoader, TensorDataset +from torch.func import functional_call +from optree import tree_map +from optree.integration.torch import tree_ravel + +from posteriors.laplace import dense_ggn + +from tests.scenarios import TestModel + + +def normal_log_likelihood(y_pred, batch): + y = batch[1] + return ( + Normal(y_pred, 1, validate_args=False).log_prob(y).sum() + ) # validate args introduces control flows not yet supported in torch.func.vmap + + +def forward_m(params, batch, model): + y_pred = functional_call(model, params, batch[0]) + return y_pred, torch.tensor([]) + + +def test_ggn_vmap(): + torch.manual_seed(42) + model = TestModel() + + xs = torch.randn(100, 10) + ys = model(xs) + + dataloader = DataLoader( + TensorDataset(xs, ys), + batch_size=20, + ) + + forward = partial(forward_m, model=model) + + params = dict(model.named_parameters()) + + # Test inplace = False + transform = dense_ggn.build(forward, normal_log_likelihood) + laplace_state = transform.init(params) + laplace_state_prec_init = laplace_state.prec + for batch in dataloader: + laplace_state = transform.update(laplace_state, batch, inplace=False) + + flat_params, unravel_fn = tree_ravel(params) + + expected = torch.zeros((flat_params.shape[0], flat_params.shape[0])) + for x, y in zip(xs, ys): + with torch.no_grad(): + z = forward(params, (x, y))[0] + J = torch.func.jacrev(lambda fp: forward(unravel_fn(fp), (x, y)))( + flat_params + )[0] + H = torch.func.hessian(lambda zt: normal_log_likelihood(zt, (x, y)))(z) + G = J.T @ H @ J + expected -= G + + assert torch.allclose(expected, laplace_state.prec, atol=1e-5) + assert not torch.allclose(laplace_state.prec, laplace_state_prec_init) + + # Also check full batch + laplace_state_fb = transform.init(params) + laplace_state_fb = transform.update(laplace_state_fb, (xs, ys)) + + assert torch.allclose(expected, laplace_state_fb.prec, atol=1e-5) + + # Test inplace = True + laplace_state = transform.init(params) + laplace_state_prec_init = laplace_state.prec + for batch in dataloader: + laplace_state = transform.update(laplace_state, batch, inplace=True) + + assert torch.allclose(expected, laplace_state.prec, atol=1e-5) + assert torch.allclose(laplace_state.prec, laplace_state_prec_init) + + # Test sampling + num_samples = 10000 + laplace_state.prec = laplace_state.prec + 0.1 * torch.eye( + flat_params.shape[0] + ) # regularize to ensure PSD and reduce variance + + mean_copy = tree_map(lambda x: x.clone(), laplace_state.params) + sd_flat = torch.diag(torch.linalg.inv(laplace_state.prec)).sqrt() + + samples = dense_ggn.sample(laplace_state, (num_samples,)) + + samples_mean = tree_map(lambda x: x.mean(dim=0), samples) + samples_sd = tree_map(lambda x: x.std(dim=0), samples) + samples_sd_flat = tree_ravel(samples_sd)[0] + + for key in samples_mean: + assert samples[key].shape[0] == num_samples + assert samples[key].shape[1:] == samples_mean[key].shape + assert torch.allclose(samples_mean[key], laplace_state.params[key], atol=1e-1) + assert torch.allclose(mean_copy[key], laplace_state.params[key]) + + assert torch.allclose(sd_flat, samples_sd_flat, atol=1e-1) 
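
The tests above double as a template for calling the new transforms. Below is a minimal usage sketch distilled from `tests/laplace/test_diag_ggn.py` / `tests/laplace/test_dense_ggn.py`; the `torch.nn.Linear` model and the synthetic `xs`/`ys` are illustrative stand-ins, not anything defined in this series.

```python
import torch
from torch.distributions import Normal
from torch.func import functional_call

from posteriors.laplace import diag_ggn

model = torch.nn.Linear(10, 1)          # any torch.nn.Module works here
xs, ys = torch.randn(100, 10), torch.randn(100, 1)
params = dict(model.named_parameters())

def forward(p, batch):
    # Model output plus (empty) auxiliary information, as the transform expects.
    return functional_call(model, p, batch[0]), torch.tensor([])

def outer_log_likelihood(y_pred, batch):
    # Gaussian outer log-likelihood of the model output, reduced to a scalar.
    return Normal(y_pred, 1, validate_args=False).log_prob(batch[1]).sum()

transform = diag_ggn.build(forward, outer_log_likelihood)
state = transform.init(params)
state = transform.update(state, (xs, ys))      # accumulate the diagonal GGN
samples = diag_ggn.sample(state, (1000,))      # draw from the Laplace posterior
```

`dense_ggn` is called the same way; the only difference is that its state carries a full precision matrix (`state.prec`) rather than a per-parameter diagonal (`state.prec_diag`).
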
From 34a5d9571e661f52136987b7f16ecb9f84a9c526 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 8 May 2024 14:31:14 +0100 Subject: [PATCH 4/4] Clarify sign of GGN and diag to diagonal in docs --- docs/api/ekf/diag_fisher.md | 4 ++-- docs/api/index.md | 2 +- docs/api/laplace/diag_fisher.md | 4 ++-- docs/api/laplace/diag_ggn.md | 4 ++-- mkdocs.yml | 8 ++++---- posteriors/laplace/dense_ggn.py | 12 ++++++++---- posteriors/laplace/diag_ggn.py | 10 +++++++--- 7 files changed, 26 insertions(+), 18 deletions(-) diff --git a/docs/api/ekf/diag_fisher.md b/docs/api/ekf/diag_fisher.md index 31050cb7..84dcad41 100644 --- a/docs/api/ekf/diag_fisher.md +++ b/docs/api/ekf/diag_fisher.md @@ -1,7 +1,7 @@ --- -title: EKF Diag Fisher +title: EKF Diagonal Fisher --- -# EKF Diag Fisher +# EKF Diagonal Fisher ::: posteriors.ekf.diag_fisher \ No newline at end of file diff --git a/docs/api/index.md b/docs/api/index.md index e45f9b69..56785e1e 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -12,7 +12,7 @@ information matrix and uses it to approximate the posterior precision, i.e. a [L approximation](https://arxiv.org/abs/2106.14806). - [`laplace.dense_ggn`](laplace/dense_ggn.md) calculates the Generalised Gauss-Newton matrix which is equivalent to the non-empirical Fisher in most -neural network settings. +neural network settings - see [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf). - [`laplace.diag_fisher`](laplace/diag_fisher.md) same as `laplace.dense_fisher` but uses the diagonal of the empirical Fisher information matrix instead. - [`laplace.diag_ggn`](laplace/diag_ggn.md) same as `laplace.dense_ggn` but diff --git a/docs/api/laplace/diag_fisher.md b/docs/api/laplace/diag_fisher.md index 3c3c135b..617ce0da 100644 --- a/docs/api/laplace/diag_fisher.md +++ b/docs/api/laplace/diag_fisher.md @@ -1,7 +1,7 @@ --- -title: Laplace Diag Fisher +title: Laplace Diagonal Fisher --- -# Laplace Diag Fisher +# Laplace Diagonal Fisher ::: posteriors.laplace.diag_fisher \ No newline at end of file diff --git a/docs/api/laplace/diag_ggn.md b/docs/api/laplace/diag_ggn.md index 1eefd1db..152a6bbc 100644 --- a/docs/api/laplace/diag_ggn.md +++ b/docs/api/laplace/diag_ggn.md @@ -1,7 +1,7 @@ --- -title: Laplace Diag GGN +title: Laplace Diagonal GGN --- -# Laplace Diag GGN +# Laplace Diagonal GGN ::: posteriors.laplace.diag_ggn \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index f0df04f9..3839563c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -68,12 +68,12 @@ nav: - API: - api/index.md - EKF: - - Diag Fisher: api/ekf/diag_fisher.md + - Diagonal Fisher: api/ekf/diag_fisher.md - Laplace: - Dense Fisher: api/laplace/dense_fisher.md - Dense GGN: api/laplace/dense_ggn.md - - Diag Fisher: api/laplace/diag_fisher.md - - Diag GGN: api/laplace/diag_ggn.md + - Diagonal Fisher: api/laplace/diag_fisher.md + - Diagonal GGN: api/laplace/diag_ggn.md - SGMCMC: - api/sgmcmc/sgld.md - api/sgmcmc/sghmc.md @@ -81,7 +81,7 @@ nav: - Diag: api/vi/diag.md - api/optim.md - TorchOpt: api/torchopt.md - - api/tree_utils.md + - Tree Utils: api/tree_utils.md - api/types.md - api/utils.md diff --git a/posteriors/laplace/dense_ggn.py b/posteriors/laplace/dense_ggn.py index 77a5023f..16002d7b 100644 --- a/posteriors/laplace/dense_ggn.py +++ b/posteriors/laplace/dense_ggn.py @@ -40,7 +40,7 @@ def build( G(θ) = J_f(θ) H_l(z) J_f(θ)^T $$ where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$ - is a negative outer log-likelihood with scalar output. 
+ is a loss (negative log-likelihood) that maps the output of $f$ to a scalar output. More info on Fisher and GGN matrices can be found in [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and @@ -131,21 +131,25 @@ def update( Returns: Updated DenseLaplaceState. """ + + def outer_loss(z, batch): + return -outer_log_likelihood(z, batch) + with torch.no_grad(), CatchAuxError(): ggn_batch, aux = ggn( partial(forward, batch=batch), - partial(outer_log_likelihood, batch=batch), + partial(outer_loss, batch=batch), forward_has_aux=True, loss_has_aux=False, normalize=False, )(state.params) if inplace: - state.prec -= ggn_batch + state.prec += ggn_batch state.aux = aux return state else: - return DenseLaplaceState(state.params, state.prec - ggn_batch, aux) + return DenseLaplaceState(state.params, state.prec + ggn_batch, aux) def sample( diff --git a/posteriors/laplace/diag_ggn.py b/posteriors/laplace/diag_ggn.py index 7f75407b..26f5a116 100644 --- a/posteriors/laplace/diag_ggn.py +++ b/posteriors/laplace/diag_ggn.py @@ -40,7 +40,7 @@ def build( G(θ) = J_f(θ) H_l(z) J_f(θ)^T $$ where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$ - is a negative outer log-likelihood with scalar output. + is a loss (negative log-likelihood) that maps the output of $f$ to a scalar output. More info on Fisher and GGN matrices can be found in [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and @@ -128,17 +128,21 @@ def update( Returns: Updated DiagLaplaceState. """ + + def outer_loss(z, batch): + return -outer_log_likelihood(z, batch) + with torch.no_grad(), CatchAuxError(): diag_ggn_batch, aux = diag_ggn( partial(forward, batch=batch), - partial(outer_log_likelihood, batch=batch), + partial(outer_loss, batch=batch), forward_has_aux=True, loss_has_aux=False, normalize=False, )(state.params) def update_func(x, y): - return x - y + return x + y prec_diag = flexi_tree_map( update_func, state.prec_diag, diag_ggn_batch, inplace=inplace
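
A brief note on the sign convention this last patch settles (a clarification alongside the diff, not part of it): `ggn` and `diag_ggn` are now handed the negated outer log-likelihood via `outer_loss`, so the accumulated curvature is positive semi-definite and is simply added to `state.prec` / `state.prec_diag`. For the Gaussian outer log-likelihood used in the tests, $l(z) = -\log \mathcal{N}(y \mid z, 1) = \tfrac{1}{2}\|y - z\|^2 + \text{const}$ gives $H_l(z) = I$, so each per-sample contribution reduces to
$$
G(θ) = J_f(θ) J_f(θ)^T,
$$
numerically the same quantity that patches 2 and 3 obtained by subtracting the GGN of the raw log-likelihood; the negation has just been moved into the loss so the state update reads as a straightforward addition.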