Merge pull request #75 from normal-computing/ggn
Add `ggn` and `ggnvp`
SamDuffield authored Apr 24, 2024
2 parents 8e66a50 + 42c0830 commit ecfeff7
Showing 3 changed files with 532 additions and 16 deletions.
2 changes: 2 additions & 0 deletions posteriors/__init__.py
@@ -10,6 +10,8 @@
from posteriors.utils import hvp
from posteriors.utils import fvp
from posteriors.utils import empirical_fisher
from posteriors.utils import ggnvp
from posteriors.utils import ggn
from posteriors.utils import cg
from posteriors.utils import diag_normal_log_prob
from posteriors.utils import diag_normal_sample
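Both new utilities are re-exported at the package root, alongside the existing curvature helpers. A minimal sketch of the resulting import surface (assuming this revision of `posteriors` is installed):

```python
from posteriors import fvp, empirical_fisher, ggn, ggnvp

# ggn and ggnvp are the two names added by this PR
assert callable(ggn) and callable(ggnvp)
```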
302 changes: 290 additions & 12 deletions posteriors/utils.py
@@ -3,7 +3,7 @@
from functools import partial
import contextlib
import torch
from torch.func import grad, jvp, vjp, functional_call, jacrev
from torch.func import grad, jvp, vjp, functional_call, jacrev, jacfwd
from torch.distributions import Normal
from optree import tree_map, tree_map_, tree_reduce, tree_flatten, tree_leaves
from optree.integration.torch import tree_ravel
@@ -107,7 +107,7 @@ def hvp(
where H(primals) is the Hessian of f evaluated at primals.
Taken from [jacobians_hessians.html](https://pytorch.org/functorch/nightly/notebooks/jacobians_hessians.html).
Follows API from [`torch.func.jvp`](https://pytorch.org/docs/stable/generated/torch.func.jvp.html)
Follows API from [`torch.func.jvp`](https://pytorch.org/docs/stable/generated/torch.func.jvp.html).
Args:
f: A function with scalar output.
@@ -128,7 +128,7 @@ def fvp(
primals: tuple,
tangents: tuple,
has_aux: bool = False,
normalize: bool = True,
normalize: bool = False,
) -> Tuple[float, TensorTree] | Tuple[float, TensorTree, Any]:
"""Empirical Fisher vector product.
@@ -138,10 +138,11 @@
The empirical Fisher is defined as:
$$
F(θ) = \\sum_i ∇_θ f_θ(x_i, y_i) ∇_θ f_θ(x_i, y_i)^T
F(θ) = J_f(θ)^T J_f(θ)
$$
where typically $f_θ(x_i, y_i)$ is the log likelihood $\\log p(y_i | x_i,θ)$ of a
model with parameters $θ$ given inputs $x_i$ and labels $y_i$.
where typically $f_θ$ is the per-sample log likelihood (with elements
$\\log p(y_i | x_i, θ)$ for a model with `primals` $θ$ given inputs $x_i$ and
labels $y_i$).
If `normalize=True`, then $F(θ)$ is divided by the number of outputs from f
(i.e. batchsize).
@@ -151,6 +152,31 @@
More info on empirical Fisher matrices can be found in
[Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).
Examples:
```python
from functools import partial
from optree import tree_map
import torch
from posteriors import fvp
# Load model that outputs logits
# Load batch = {'inputs': ..., 'labels': ...}
def log_likelihood_per_sample(params, batch):
output = torch.func.functional_call(model, params, batch["inputs"])
return -torch.nn.functional.cross_entropy(
output, batch["labels"], reduction="none"
)
params = dict(model.named_parameters())
v = tree_map(lambda x: torch.randn_like(x), params)
fvp_result = fvp(
partial(log_likelihood_per_sample, batch=batch),
(params,),
(v,)
)
```
Args:
f: A function with tensor output.
Typically this is the [per-sample log likelihood of a model](https://pytorch.org/tutorials/intermediate/per_sample_grads.html).
@@ -180,31 +206,53 @@ def empirical_fisher(
f: Callable,
argnums: int | Sequence[int] = 0,
has_aux: bool = False,
normalize: bool = True,
normalize: bool = False,
) -> Callable:
"""
Constructs function to compute the empirical Fisher information matrix of a function
f with respect to its parameters, defined as (unnormalized):
$$
F(θ) = \\sum_i ∇_θ f_θ(x_i, y_i) ∇_θ f_θ(x_i, y_i)^T
F(θ) = J_f(θ)^T J_f(θ)
$$
where typically $f_θ(x_i, y_i)$ is the log likelihood $\\log p(y_i | x_i,θ)$ of a
model with parameters $θ$ given inputs $x_i$ and labels $y_i$.
where typically $f_θ$ is the per-sample log likelihood (with elements
$\\log p(y_i | x_i, θ)$ for a model with `primals` $θ$ given inputs $x_i$ and
labels $y_i$).
If `normalize=True`, then $F(θ)$ is divided by the number of outputs from f
(i.e. batchsize).
The empirical Fisher will be provided as a square tensor with respect to the
ravelled parameters, as given by
`flat_params, params_unravel = optree.tree_ravel(params)`.
Follows API from [`torch.func.jacrev`](https://pytorch.org/functorch/stable/generated/functorch.jacrev.html).
More info on empirical Fisher matrices can be found in
[Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).
Examples:
```python
import torch
from posteriors import empirical_fisher, per_samplify
# Load model that outputs logits
# Load batch = {'inputs': ..., 'labels': ...}
def log_likelihood(params, batch):
output = torch.func.functional_call(model, params, batch['inputs'])
return -torch.nn.functional.cross_entropy(output, batch['labels'])
log_likelihood_per_sample = per_samplify(log_likelihood)
params = dict(model.named_parameters())
ef_result = empirical_fisher(log_likelihood_per_sample)(params, batch)
```
Args:
f: A Python function that takes one or more arguments, one of which must be a
Tensor, and returns one or more Tensors.
Typically this is the [per-sample log likelihood of a model](https://pytorch.org/tutorials/intermediate/per_sample_grads.html).
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to. Defaults to 0.
positional argument(s) to differentiate with respect to.
has_aux: Whether f returns auxiliary information.
normalize: Whether to normalize, divide by the dimension of the output from f.
@@ -213,8 +261,16 @@ def empirical_fisher(
If has_aux is True, then the function instead returns a tuple of (F, aux).
"""

def f_to_flat(*args, **kwargs):
# Ravel the (possibly tree-structured) output of f into a single flat
# tensor, so that jacrev below gives a Jacobian flat in the output dimension
f_out = f(*args, **kwargs)
f_out_val = f_out[0] if has_aux else f_out
f_out_val = tree_ravel(f_out_val)[0]
return (f_out_val, f_out[1]) if has_aux else f_out_val

def fisher(*args, **kwargs):
jac_output = jacrev(f, argnums=argnums, has_aux=has_aux)(*args, **kwargs)
jac_output = jacrev(f_to_flat, argnums=argnums, has_aux=has_aux)(
*args, **kwargs
)
jac = jac_output[0] if has_aux else jac_output

# Convert Jacobian to tensor, flat in parameter dimension
@@ -230,6 +286,228 @@ def fisher(*args, **kwargs):
return fisher
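A side effect of the new `f_to_flat` wrapper is that `f` may now return an arbitrary tensor tree rather than a single tensor, since its output is ravelled before `jacrev` is applied. A small illustrative sketch (the toy function, shapes and printed output are assumptions for the example, not part of the commit):

```python
import torch

from posteriors import empirical_fisher

def f(params, batch):
    # Hypothetical per-sample function with a dict (tensor tree) output
    out = batch * params["w"]
    return {"a": out, "b": out**2}

params = {"w": torch.tensor([1.0, 2.0])}
batch = torch.tensor([0.5, -1.0])

F = empirical_fisher(f)(params, batch)
print(F.shape)  # torch.Size([2, 2]), square in the ravelled parameters
```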


def ggnvp(
forward: Callable,
loss: Callable,
primals: tuple,
tangents: tuple,
forward_has_aux: bool = False,
loss_has_aux: bool = False,
normalize: bool = False,
) -> (
Tuple[float, TensorTree]
| Tuple[float, TensorTree, Any]
| Tuple[float, TensorTree, Any, Any]
):
"""Generalised Gauss-Newton vector product.
Equivalent to the (non-empirical) Fisher vector product when `loss` is the negative
log likelihood of an exponential family distribution as a function of its natural
parameter.
Defined as
$$
G(θ) = J_f(θ)^T H_l(z) J_f(θ)
$$
where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
is a loss function with scalar output.
Thus $J_f(θ)$ is the Jacobian of the forward function $f$ evaluated
at `primals` $θ$, with dimensions `(dz, dθ)`.
And $H_l(z)$ is the Hessian of the loss function $l$ evaluated at `z = f(θ)`, with
dimensions `(dz, dz)`.
Follows API from [`torch.func.jvp`](https://pytorch.org/docs/stable/generated/torch.func.jvp.html).
More info on Fisher and GGN matrices can be found in
[Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).
Examples:
```python
from functools import partial
from optree import tree_map
import torch
from posteriors import ggnvp
# Load model that outputs logits
# Load batch = {'inputs': ..., 'labels': ...}
def forward(params, inputs):
return torch.func.functional_call(model, params, inputs)
def loss(logits, labels):
return torch.nn.functional.cross_entropy(logits, labels)
params = dict(model.named_parameters())
v = tree_map(lambda x: torch.randn_like(x), params)
ggnvp_result = ggnvp(
partial(forward, inputs=batch['inputs']),
partial(loss, labels=batch['labels']),
(params,),
(v,),
)
```
Args:
forward: A function with tensor output.
loss: A function that maps the output of forward to a scalar output.
primals: Tuple of e.g. tensor or dict with tensor values at which to
evaluate forward.
tangents: Tuple matching structure of primals.
forward_has_aux: Whether forward returns auxiliary information.
loss_has_aux: Whether loss returns auxiliary information.
normalize: Whether to normalize, divide by the first dimension of the output
from forward.
Returns:
Returns a (output, ggnvp_out) tuple, where output is a tuple of
`(forward(primals), grad(loss)(forward(primals)))`.
If forward_has_aux or loss_has_aux is True, then instead returns a
(output, ggnvp_out, aux) or
(output, ggnvp_out, forward_aux, loss_aux) tuple accordingly.
"""

# Forward-mode pass: z = forward(θ) and Jv = J_f(θ) v in a single sweep
jvp_output = jvp(forward, primals, tangents, has_aux=forward_has_aux)
z = jvp_output[0]
Jv = jvp_output[1]
# Hessian-vector product of the loss at z: HJv = H_l(z) Jv
HJv_output = hvp(loss, (z,), (Jv,), has_aux=loss_has_aux)
HJv = HJv_output[1]

if normalize:
# Normalize by the leading dimension of the forward output (i.e. batchsize)
output_dim = tree_flatten(jvp_output[0])[0][0].shape[0]
HJv = tree_map(lambda x: x / output_dim, HJv)

# Reverse-mode pass pulls HJv back to parameter space: JTHJv = J_f(θ)^T HJv
forward_vjp = vjp(forward, *primals, has_aux=forward_has_aux)[1]
JTHJv = forward_vjp(HJv)[0]

return (jvp_output[0], HJv_output[0]), JTHJv, *jvp_output[2:], *HJv_output[2:]
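A direct way to sanity-check this product is to assemble the dense GGN on a tiny problem and compare. The sketch below is illustrative only; the linear model, batch shapes and tolerance are made up, and it assumes this revision of `posteriors`:

```python
import torch
from torch.func import functional_call, hessian, jacrev
from optree import tree_map
from optree.integration.torch import tree_ravel

from posteriors import ggnvp

torch.manual_seed(0)
model = torch.nn.Linear(3, 2)  # toy model with dθ = 8, dz = 5 * 2 = 10
params = dict(model.named_parameters())
inputs, labels = torch.randn(5, 3), torch.randint(2, (5,))

def forward(p):
    return functional_call(model, p, inputs)

def loss(logits):
    return torch.nn.functional.cross_entropy(logits, labels)

v = tree_map(lambda x: torch.randn_like(x), params)
_, JTHJv = ggnvp(forward, loss, (params,), (v,))

# Dense reference: G v = J^T H (J v) with J of shape (dz, dθ), H of (dz, dz)
flat_params, unravel = tree_ravel(params)
J = jacrev(lambda fp: forward(unravel(fp)).flatten())(flat_params)
H = hessian(lambda z: loss(z.reshape(5, 2)))(forward(params).flatten())
dense = J.T @ (H @ (J @ tree_ravel(v)[0]))

assert torch.allclose(tree_ravel(JTHJv)[0], dense, atol=1e-5)
```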


def ggn(
forward: Callable,
loss: Callable,
argnums: int | Sequence[int] = 0,
forward_has_aux: bool = False,
loss_has_aux: bool = False,
normalize: bool = False,
) -> Callable:
"""
Constructs function to compute the Generalised Gauss-Newton matrix.
Equivalent to the (non-empirical) Fisher information matrix when `loss` is the
negative log likelihood of an exponential family distribution as a function of
its natural parameter.
Defined as
$$
G(θ) = J_f(θ)^T H_l(z) J_f(θ)
$$
where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
is a loss function with scalar output.
Thus $J_f(θ)$ is the Jacobian of the forward function $f$ evaluated
at parameters $θ$. And $H_l(z)$ is the Hessian of the loss function $l$ evaluated
at $z = f(θ)$.
Requires the output of `forward` to be a tensor and therefore `loss` to take a
tensor as input, although both may also return `aux` outputs.
If `normalize=True`, then $G(θ)$ is divided by the size of the leading dimension of
outputs from `forward` (i.e. batchsize).
The GGN will be provided as a square tensor with respect to the
ravelled parameters, as given by
`flat_params, params_unravel = optree.tree_ravel(params)`.
Follows API from [`torch.func.jacrev`](https://pytorch.org/functorch/stable/generated/functorch.jacrev.html).
More info on Fisher and GGN matrices can be found in
[Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).
Examples:
```python
from functools import partial
import torch
from posteriors import ggn
# Load model that outputs logits
# Load batch = {'inputs': ..., 'labels': ...}
def forward(params, inputs):
return torch.func.functional_call(model, params, inputs)
def loss(logits, labels):
return torch.nn.functional.cross_entropy(logits, labels)
params = dict(model.named_parameters())
ggn_result = ggn(
partial(forward, inputs=batch['inputs']),
partial(loss, labels=batch['labels']),
)(params)
```
Args:
forward: A function with tensor output.
loss: A function that maps the output of forward to a scalar output.
Takes a single input and returns a scalar (and possibly aux).
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate `forward` with respect to.
forward_has_aux: Whether forward returns auxiliary information.
loss_has_aux: Whether loss returns auxiliary information.
normalize: Whether to normalize, divide by the first dimension of the output
from forward.
Returns:
A function with the same arguments as forward that returns the tensor GGN.
If forward_has_aux or loss_has_aux is True, then the function instead returns
a tuple of (G, aux) accordingly.
"""
assert argnums == 0, "Only argnums=0 is supported for now."

def internal_ggn(params):
flat_params, params_unravel = tree_ravel(params)

def flat_params_to_forward(fps):
return forward(params_unravel(fps))

jac_output = jacrev(
flat_params_to_forward, argnums=argnums, has_aux=forward_has_aux
)(flat_params)
jac = jac_output[0] if forward_has_aux else jac_output # (..., dθ)
jac = torch.stack(tree_leaves(jac))[
0
] # convert to tensor (assumes jac has tensor output)
rescale = 1 / jac.shape[0] if normalize else 1 # maybe normalize by batchsize
jac = jac.flatten(end_dim=-2) # (d, dθ)

z = forward(params)
z = z[0] if forward_has_aux else z

hess_output = jacfwd(jacrev(loss, has_aux=loss_has_aux), has_aux=loss_has_aux)(
z
)
hess = hess_output[0] if loss_has_aux else hess_output
hess = torch.stack(tree_leaves(hess))[
0
] # convert to tensor (assumes loss has tensor input)
z_ndim = hess.ndim // 2
hess = hess.flatten(start_dim=z_ndim).flatten(
end_dim=-z_ndim
) # flatten to square tensor

# Collect aux outputs
aux = []
if forward_has_aux:
aux.append(jac_output[1])
if loss_has_aux:
aux.append(loss(z)[1])

if aux:
return jac.T @ (hess @ jac) * rescale, *aux
else:
return jac.T @ (hess @ jac) * rescale

return internal_ggn
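Since `ggn` materialises the full matrix over the ravelled parameters, it can be cross-checked against the matrix-free `ggnvp`: multiplying the dense GGN by a ravelled tangent should reproduce the product. A minimal sketch reusing the toy `forward`, `loss`, `params` and `v` from the `ggnvp` sketch above (again assumptions, not part of the commit):

```python
import torch
from optree.integration.torch import tree_ravel

from posteriors import ggn, ggnvp

G = ggn(forward, loss)(params)  # dense (dθ, dθ) GGN matrix
_, JTHJv = ggnvp(forward, loss, (params,), (v,))

assert torch.allclose(G @ tree_ravel(v)[0], tree_ravel(JTHJv)[0], atol=1e-5)
```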


def _vdot_real_part(x: Tensor, y: Tensor) -> float:
"""Vector dot-product guaranteed to have a real valued result despite
possibly complex input. Thus neglects the real-imaginary cross-terms.