Dense Fisher EKF #96

Merged · 7 commits · Jun 3, 2024
7 changes: 7 additions & 0 deletions docs/api/ekf/dense_fisher.md
@@ -0,0 +1,7 @@
---
title: EKF Dense Fisher
---

# EKF Dense Fisher

::: posteriors.ekf.dense_fisher
6 changes: 4 additions & 2 deletions docs/api/index.md
@@ -1,10 +1,12 @@
 # API

 ### Extended Kalman filter (EKF)
-- [`ekf.diag_fisher`](ekf/diag_fisher.md) applies an online Bayesian update based
-on a Taylor approximation of the log-likelihood. Uses the diagonal empirical Fisher
+- [`ekf.dense_fisher`](ekf/dense_fisher.md) applies an online Bayesian update based
+on a Taylor approximation of the log-likelihood. Uses the empirical Fisher
 information matrix as a positive-definite alternative to the Hessian.
 Natural gradient descent equivalence following [Ollivier, 2019](https://arxiv.org/abs/1703.00209).
+- [`ekf.diag_fisher`](ekf/diag_fisher.md) same as `ekf.dense_fisher` but
+uses the diagonal of the empirical Fisher information matrix instead.

 ### Laplace approximation
 - [`laplace.dense_fisher`](laplace/dense_fisher.md) calculates the empirical Fisher
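The EKF entries above are transforms with a build/init/update interface. Below is a minimal, hedged usage sketch of `ekf.dense_fisher` on a toy linear-Gaussian model; the model, the parameter names (`w`, `b`) and the learning rate are illustrative assumptions, not part of this PR.

```python
import torch
from posteriors import ekf

# Toy linear-Gaussian likelihood: y ~ N(x @ w + b, 1), written per sample.
def log_likelihood(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    # Return a vector of per-sample log-likelihoods plus (empty) auxiliary info.
    return torch.distributions.Normal(pred, 1.0).log_prob(y), torch.tensor([])

transform = ekf.dense_fisher.build(log_likelihood, lr=1e-2, per_sample=True)
state = transform.init({"w": torch.zeros(3), "b": torch.zeros(1)})

x, y = torch.randn(32, 3), torch.randn(32)
state = transform.update(state, (x, y))       # one online Bayesian update
print(state.log_likelihood, state.cov.shape)  # scalar tensor, torch.Size([4, 4])
```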
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -68,6 +68,7 @@ nav:
  - API:
    - api/index.md
    - EKF:
      - Dense Fisher: api/ekf/dense_fisher.md
      - Diagonal Fisher: api/ekf/diag_fisher.md
    - Laplace:
      - Dense Fisher: api/laplace/dense_fisher.md
1 change: 1 addition & 0 deletions posteriors/ekf/__init__.py
@@ -1 +1,2 @@
from posteriors.ekf import diag_fisher
from posteriors.ekf import dense_fisher
202 changes: 202 additions & 0 deletions posteriors/ekf/dense_fisher.py
@@ -0,0 +1,202 @@
from typing import Any
from functools import partial
import torch
from torch.func import grad_and_value
from dataclasses import dataclass
from optree.integration.torch import tree_ravel

from posteriors.tree_utils import tree_size

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.utils import (
    per_samplify,
    empirical_fisher,
    is_scalar,
    CatchAuxError,
)


def build(
    log_likelihood: LogProbFn,
    lr: float,
    transition_cov: torch.Tensor | float = 0.0,
    per_sample: bool = False,
    init_cov: torch.Tensor | float = 1.0,
) -> Transform:
    """Builds a transform to implement an extended Kalman Filter update.

    EKF applies an online update to a Gaussian posterior over the parameters.

    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) + ε g(μ)^T (θ - μ) + \\frac12 ε (θ - μ)^T F(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at $μ$ and $F(μ)$ is the
    empirical Fisher information matrix at $μ$ for data $y$.

    For more information on extended Kalman filtering, as well as its equivalence
    to (online) natural gradient descent, see [Ollivier, 2019](https://arxiv.org/abs/1703.00209).

    Args:
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_cov: Covariance of the transition noise, used to additively
            inflate the covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then log_likelihood
            is assumed to return a scalar log likelihood for the whole batch; in this
            case torch.func.vmap will be called, which is typically slower than
            writing log_likelihood to be per sample directly.
        init_cov: Initial covariance of the Normal distribution. Can be a torch.Tensor or scalar.

    Returns:
        EKF transform instance.
    """
    init_fn = partial(init, init_cov=init_cov)
    update_fn = partial(
        update,
        log_likelihood=log_likelihood,
        lr=lr,
        transition_cov=transition_cov,
        per_sample=per_sample,
    )
    return Transform(init_fn, update_fn)


@dataclass
class EKFDenseState(TransformState):
    """State encoding a Normal distribution over parameters.

    Args:
        params: Mean of the Normal distribution.
        cov: Covariance matrix of the Normal distribution.
        log_likelihood: Log likelihood of the data given the parameters.
        aux: Auxiliary information from the log_likelihood call.
    """

    params: TensorTree
    cov: torch.Tensor
    log_likelihood: float = 0
    aux: Any = None


def init(
    params: TensorTree,
    init_cov: torch.Tensor | float = 1.0,
) -> EKFDenseState:
    """Initialise Multivariate Normal distribution over parameters.

    Args:
        params: Initial mean of the Normal distribution.
        init_cov: Initial covariance matrix of the Multivariate Normal distribution.
            If it is a float, it is defined as an identity matrix scaled by that float.

    Returns:
        Initial EKFDenseState.
    """
    if is_scalar(init_cov):
        num_params = tree_size(params)
        init_cov = init_cov * torch.eye(num_params, requires_grad=False)

    return EKFDenseState(params, init_cov)


def update(
    state: EKFDenseState,
    batch: Any,
    log_likelihood: LogProbFn,
    lr: float,
    transition_cov: torch.Tensor | float = 0.0,
    per_sample: bool = False,
    inplace: bool = False,
) -> EKFDenseState:
    """Applies an extended Kalman Filter update to the Multivariate Normal distribution.

    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) + ε g(μ)^T (θ - μ) + \\frac12 ε (θ - μ)^T F(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at $μ$ and $F(μ)$ is the
    empirical Fisher information matrix at $μ$ for data $y$.

    Args:
        state: Current state.
        batch: Input data to log_likelihood.
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_cov: Covariance of the transition noise, used to additively
            inflate the covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then log_likelihood
            is assumed to return a scalar log likelihood for the whole batch; in this
            case torch.func.vmap will be called, which is typically slower than
            writing log_likelihood to be per sample directly.
        inplace: Whether to update the state parameters in-place.

    Returns:
        Updated EKFDenseState.
    """
    if not per_sample:
        log_likelihood = per_samplify(log_likelihood)

    with torch.no_grad(), CatchAuxError():

        def log_likelihood_reduced(params, batch):
            per_samp_log_lik, internal_aux = log_likelihood(params, batch)
            return per_samp_log_lik.mean(), internal_aux

        grad, (log_liks, aux) = grad_and_value(log_likelihood_reduced, has_aux=True)(
            state.params, batch
        )
        fisher, _ = empirical_fisher(
            lambda p: log_likelihood(p, batch), has_aux=True, normalize=True
        )(state.params)

        predict_cov = state.cov + transition_cov
        predict_cov_inv = torch.cholesky_inverse(torch.linalg.cholesky(predict_cov))
        update_cov_inv = predict_cov_inv - lr * fisher
        update_cov = torch.cholesky_inverse(torch.linalg.cholesky(update_cov_inv))

        mu_raveled, mu_unravel_f = tree_ravel(state.params)
        update_mean = mu_raveled + lr * update_cov @ tree_ravel(grad)[0]
        update_mean = mu_unravel_f(update_mean)

    if inplace:
        state.params = update_mean
        state.cov = update_cov
        state.log_likelihood = log_liks.mean().detach()
        state.aux = aux
        return state
    return EKFDenseState(update_mean, update_cov, log_liks.mean().detach(), aux)


def sample(
    state: EKFDenseState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
    """Sample from the Multivariate Normal distribution over parameters.

    Args:
        state: State encoding mean and covariance.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from the Multivariate Normal distribution.
    """
    mean_flat, unravel_func = tree_ravel(state.params)

    samples = torch.distributions.MultivariateNormal(
        loc=mean_flat,
        covariance_matrix=state.cov,
        validate_args=False,
    ).sample(sample_shape)

    samples = torch.vmap(unravel_func)(samples)
    return samples
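For readers cross-checking the linearization in the `update` docstring, here is a hedged sketch, in plain torch on a raveled parameter vector, of the closed-form Gaussian update that `update` performs. The outer-product form and batch averaging of the Fisher term are assumptions about `posteriors.utils.empirical_fisher` with `normalize=True`, not guarantees, and the helper name `dense_ekf_step` is illustrative only.

```python
import torch

def dense_ekf_step(mu, cov, per_sample_grads, lr, transition_cov=0.0):
    """One dense-Fisher EKF step on a raveled (flat) parameter vector.

    mu: (d,) prior mean; cov: (d, d) prior covariance;
    per_sample_grads: (n, d) per-sample gradients of the log-likelihood at mu.
    """
    n = per_sample_grads.shape[0]
    g = per_sample_grads.mean(0)                        # g(mu)
    fisher = per_sample_grads.T @ per_sample_grads / n  # empirical F(mu), assumed normalization

    predict_cov = cov + transition_cov                  # additively inflate, as in update()
    predict_prec = torch.cholesky_inverse(torch.linalg.cholesky(predict_cov))
    update_prec = predict_prec - lr * fisher            # mirrors update_cov_inv above
    update_cov = torch.cholesky_inverse(torch.linalg.cholesky(update_prec))
    update_mu = mu + lr * update_cov @ g                # mirrors the mean update above
    return update_mu, update_cov

# Tiny smoke test with random inputs (illustrative only).
d, n = 4, 16
mu, cov = torch.zeros(d), torch.eye(d)
grads = 0.1 * torch.randn(n, d)
new_mu, new_cov = dense_ekf_step(mu, cov, grads, lr=1e-1)
```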
63 changes: 63 additions & 0 deletions tests/ekf/test_dense_fisher.py
@@ -0,0 +1,63 @@
import torch
from optree import tree_map
from torch.distributions import MultivariateNormal
from optree.integration.torch import tree_ravel
from posteriors.tree_utils import tree_size
from posteriors import ekf


def test_ekf_dense():
    torch.manual_seed(42)
    target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
    num_params = tree_size(target_mean)
    A = torch.randn(num_params, num_params)
    target_cov = torch.mm(A.t(), A)

    dist = MultivariateNormal(tree_ravel(target_mean)[0], covariance_matrix=target_cov)

    def log_prob(p, b):
        return dist.log_prob(tree_ravel(p)[0]).sum(), torch.Tensor([])

    init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
    batch = torch.arange(3).reshape(-1, 1)
    n_steps = 1000
    transform = ekf.dense_fisher.build(log_prob, lr=1e-1)

    # Test inplace = False
    state = transform.init(init_mean)
    log_liks = []
    for _ in range(n_steps):
        state = transform.update(state, batch, inplace=False)
        log_liks.append(state.log_likelihood.item())

    assert log_liks[0] < log_liks[-1]

    for key in state.params:
        assert torch.allclose(state.params[key], target_mean[key], atol=1e-1)
        assert not torch.allclose(state.params[key], init_mean[key])

    # Test inplace = True
    state = transform.init(init_mean)
    log_liks = []
    for _ in range(n_steps):
        state = transform.update(state, batch, inplace=True)
        log_liks.append(state.log_likelihood.item())

    for key in state.params:
        assert torch.allclose(state.params[key], target_mean[key], atol=1e-1)
        assert not torch.allclose(state.params[key], init_mean[key])

    # Test sample
    num_samples = 1000
    samples = ekf.dense_fisher.sample(state, (num_samples,))

    flat_samples = torch.vmap(lambda s: tree_ravel(s)[0])(samples)
    samples_cov = torch.cov(flat_samples.T)

    mean_copy = tree_map(lambda x: x.clone(), state.params)
    samples_mean = tree_map(lambda x: x.mean(dim=0), samples)

    assert torch.allclose(samples_cov, state.cov, atol=1e-1)
    for key in samples_mean:
        assert torch.allclose(samples_mean[key], state.params[key], atol=1e-1)
        assert not torch.allclose(samples_mean[key], mean_copy[key])
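As a follow-up to the sampling checks in the test above, a self-contained sketch showing how `init` and `sample` can be combined for Monte Carlo predictions; the parameter tree, `init_cov` value and test inputs are illustrative assumptions, not taken from this PR.

```python
import torch
from optree.integration.torch import tree_ravel
from posteriors import ekf

# Build a state directly: an isotropic Gaussian over a small parameter tree.
params = {"w": torch.zeros(3), "b": torch.zeros(1)}
state = ekf.dense_fisher.init(params, init_cov=0.5)

draws = ekf.dense_fisher.sample(state, (2000,))  # dict of tensors with leading sample dim
print(draws["w"].shape, draws["b"].shape)        # torch.Size([2000, 3]) torch.Size([2000, 1])

# The raveled draws should have covariance close to state.cov (0.5 * I here).
flat = torch.vmap(lambda s: tree_ravel(s)[0])(draws)
print(torch.cov(flat.T).diag())                  # roughly 0.5 in each coordinate

# Propagate parameter draws through a model for Monte Carlo predictions.
x_test = torch.randn(5, 3)
preds = torch.vmap(lambda p: x_test @ p["w"] + p["b"])(draws)  # (2000, 5)
print(preds.mean(0), preds.std(0))               # crude predictive mean and spread
```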