Commit
Merge pull request #89 from normal-computing/sgnht

Add SGNHT

SamDuffield authored May 14, 2024
2 parents d557486 + 1cc9f8e commit 8202a93
Showing 6 changed files with 276 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/api/index.md
@@ -27,6 +27,9 @@ All Laplace transforms leave the parameters unmodified. Comprehensive details on
- [`sgmcmc.sghmc`](sgmcmc/sghmc.md) implements the stochastic gradient Hamiltonian
Monte Carlo (SGHMC) algorithm from [Chen et al, 2014](https://arxiv.org/abs/1402.4102)
(without momenta resampling).
- [`sgmcmc.sgnht`](sgmcmc/sgnht.md) implements the stochastic gradient Nosé-Hoover
thermostat (SGNHT) algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf)
(SGHMC with an adaptive friction coefficient; see the dynamics sketched below).

For an overview and unifying framework for SGMCMC methods, see [Ma et al, 2015](https://arxiv.org/abs/1506.04696).

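For intuition, the SGNHT update added in this PR is a discretisation of the continuous-time Nosé-Hoover thermostat dynamics from Ding et al, 2014. A sketch in the notation of the docstrings below (friction $α$, temperature $T$, parameter dimension $d$):

\begin{align}
dθ &= m \, dt \\
dm &= \nabla \log p(θ) \, dt - ξ m \, dt + \sqrt{2 α T} \, dW \\
dξ &= (m^T m / d - T) \, dt
\end{align}

The thermostat variable $ξ$ raises the friction whenever the mean kinetic energy $m^T m / d$ exceeds $T$ and lowers it otherwise, which lets SGNHT absorb unknown stochastic-gradient noise without estimating it explicitly.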
3 changes: 3 additions & 0 deletions docs/api/sgmcmc/sgnht.md
@@ -0,0 +1,3 @@
# SGNHT

::: posteriors.sgmcmc.sgnht
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -77,6 +77,7 @@ nav:
- SGMCMC:
- api/sgmcmc/sgld.md
- api/sgmcmc/sghmc.md
- api/sgmcmc/sgnht.md
- VI:
- Diag: api/vi/diag.md
- api/optim.md
1 change: 1 addition & 0 deletions posteriors/sgmcmc/__init__.py
@@ -1,2 +1,3 @@
from posteriors.sgmcmc import sgld
from posteriors.sgmcmc import sghmc
from posteriors.sgmcmc import sgnht
183 changes: 183 additions & 0 deletions posteriors/sgmcmc/sgnht.py
@@ -0,0 +1,183 @@
from typing import Any
from functools import partial
import torch
from torch.func import grad_and_value
from optree import tree_map
from optree.integration.torch import tree_ravel
from dataclasses import dataclass

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.tree_utils import flexi_tree_map
from posteriors.utils import is_scalar, CatchAuxError


def build(
log_posterior: LogProbFn,
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
temperature: float = 1.0,
momenta: TensorTree | float | None = None,
    xi: float | None = None,
) -> Transform:
"""Builds SGNHT transform.
Algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):
\\begin{align}
θ_{t+1} &= θ_t + ε m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
+ N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
\\end{align}
for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.
Targets $p_T(θ, m, ξ) \\propto \\exp( (\\log p(θ) - \\frac12 m^Tm + \\frac{d}{2}(ξ - α)^2) / T)$.
The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
to ensure robust scaling for a large amount of data and variable batch size.
Args:
log_posterior: Function that takes parameters and input batch and
returns the log posterior value (which can be unnormalised)
as well as auxiliary information, e.g. from the model call.
lr: Learning rate.
alpha: Friction coefficient.
beta: Gradient noise coefficient (estimated variance).
temperature: Temperature of the joint parameter + momenta distribution.
momenta: Initial momenta. Can be tree like params or scalar.
Defaults to random iid samples from N(0, 1).
xi: Initial value for scalar thermostat ξ. Defaults to `alpha`.
Returns:
SGNHT transform instance.
"""
    init_fn = partial(init, momenta=momenta, xi=alpha if xi is None else xi)
update_fn = partial(
update,
log_posterior=log_posterior,
lr=lr,
alpha=alpha,
beta=beta,
temperature=temperature,
)
return Transform(init_fn, update_fn)


@dataclass
class SGNHTState(TransformState):
"""State encoding params and momenta for SGNHT.
Args:
params: Parameters.
momenta: Momenta for each parameter.
log_posterior: Log posterior evaluation.
aux: Auxiliary information from the log_posterior call.
"""

params: TensorTree
momenta: TensorTree
xi: float
    log_posterior: torch.Tensor | None = None
aux: Any = None


def init(
params: TensorTree, momenta: TensorTree | float | None = None, xi: float = 0.01
) -> SGNHTState:
"""Initialise momenta for SGNHT.
Args:
params: Parameters for which to initialise.
momenta: Initial momenta. Can be tree like params or scalar.
Defaults to random iid samples from N(0, 1).
xi: Initial value for scalar thermostat ξ.
Returns:
Initial SGNHTState containing momenta.
"""
if momenta is None:
momenta = tree_map(
lambda x: torch.randn_like(x, requires_grad=x.requires_grad),
params,
)
elif is_scalar(momenta):
momenta = tree_map(
lambda x: torch.full_like(x, momenta, requires_grad=x.requires_grad),
params,
)

return SGNHTState(params, momenta, xi)


def update(
state: SGNHTState,
batch: Any,
log_posterior: LogProbFn,
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
temperature: float = 1.0,
inplace: bool = False,
) -> SGNHTState:
"""Updates parameters, momenta and xi for SGNHT.
Update rule from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):
\\begin{align}
θ_{t+1} &= θ_t + ε m_t \\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
+ N(0, ε T (2 α - ε β T) \\mathbb{I})\\
ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
\\end{align}
for learning rate $\\epsilon$ and temperature $T$
Args:
state: SGNHTState containing params, momenta and xi.
batch: Data batch to be send to log_posterior.
log_posterior: Function that takes parameters and input batch and
returns the log posterior value (which can be unnormalised)
as well as auxiliary information, e.g. from the model call.
lr: Learning rate.
alpha: Friction coefficient.
beta: Gradient noise coefficient (estimated variance).
temperature: Temperature of the joint parameter + momenta distribution.
inplace: Whether to modify state in place.
Returns:
Updated SGNHTState
(which are pointers to the inputted state tensors if inplace=True).
"""
with torch.no_grad(), CatchAuxError():
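        # Gradients and value of the log posterior at the current params, plus aux info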
grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
state.params, batch
)

def transform_params(p, m):
return p + lr * m

def transform_momenta(m, g):
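        # Gradient step, thermostat friction -lr * xi * m, and Gaussian noise with
        # variance lr * temperature * (2 * alpha - lr * temperature * beta)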
return (
m
+ lr * g
- lr * state.xi * m
+ (temperature * lr * (2 * alpha - temperature * lr * beta)) ** 0.5
* torch.randn_like(m)
)

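    # Thermostat update: drives the mean kinetic energy m^T m / d toward temperature T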
m_flat, _ = tree_ravel(state.momenta)
xi_new = state.xi + lr * (torch.mean(m_flat**2) - temperature)

params = flexi_tree_map(
transform_params, state.params, state.momenta, inplace=inplace
)
momenta = flexi_tree_map(transform_momenta, state.momenta, grads, inplace=inplace)

if inplace:
state.xi = xi_new
state.log_posterior = log_post.detach()
state.aux = aux
return state
return SGNHTState(params, momenta, xi_new, log_post.detach(), aux)
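For orientation, a minimal sketch of driving the new transform end to end (a hypothetical toy example, not part of this commit; the standard-normal log_posterior below is illustrative and follows the (params, batch) -> (log_prob, aux) signature that posteriors assumes):

import torch
from posteriors.sgmcmc import sgnht

def log_posterior(params, batch):
    # Hypothetical toy target: standard normal over params["w"]; batch is unused
    return -0.5 * (params["w"] ** 2).sum(), torch.tensor([])

params = {"w": torch.zeros(3)}
sampler = sgnht.build(log_posterior, lr=1e-2, alpha=1.0)

state = sampler.init(params)
samples = []
for _ in range(1000):
    state = sampler.update(state, None)  # batch is unused by the toy target
    samples.append(state.params["w"].clone())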
85 changes: 85 additions & 0 deletions tests/sgmcmc/test_sgnht.py
@@ -0,0 +1,85 @@
from functools import partial
import torch
from optree import tree_map, tree_all
from optree.integration.torch import tree_ravel

from posteriors.sgmcmc import sgnht

from tests.scenarios import batch_normal_log_prob


def test_sgnht():
torch.manual_seed(42)
target_mean = {"a": torch.randn(2, 1) + 10, "b": torch.randn(1, 1) + 10}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

target_mean_flat = tree_ravel(target_mean)[0]
target_cov = torch.diag(tree_ravel(target_sds)[0] ** 2)

batch = torch.arange(10).reshape(-1, 1)

batch_normal_log_prob_spec = partial(
batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
)

n_steps = 10000
lr = 1e-2
alpha = 1.0
beta = 0.0

params = tree_map(lambda x: torch.zeros_like(x), target_mean)
init_params_copy = tree_map(lambda x: x.clone(), params)

sampler = sgnht.build(batch_normal_log_prob_spec, lr=lr, alpha=alpha, beta=beta)

# Test inplace = False
sgnht_state = sampler.init(params)
log_posts = []
all_params = tree_map(lambda x: x.unsqueeze(0), params)

for _ in range(n_steps):
sgnht_state = sampler.update(sgnht_state, batch, inplace=False)

all_params = tree_map(
lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, sgnht_state.params
)

log_posts.append(sgnht_state.log_posterior.item())

burnin = 5000
all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
sampled_mean = all_params_flat[burnin:].mean(0)
sampled_cov = torch.cov(all_params_flat[burnin:].T)

assert log_posts[-1] > log_posts[0]
assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
    assert tree_all(
        tree_map(lambda x, y: torch.all(x == y), params, init_params_copy)
    )  # Check that the parameters are not updated when inplace=False

# Test inplace = True
sgnht_state = sampler.init(params, momenta=0.0)
log_posts = []
all_params = tree_map(lambda x: x.unsqueeze(0), params)

for _ in range(n_steps):
sgnht_state = sampler.update(sgnht_state, batch, inplace=True)

all_params = tree_map(
lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, sgnht_state.params
)

log_posts.append(sgnht_state.log_posterior.item())

burnin = 5000
all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
sampled_mean = all_params_flat[burnin:].mean(0)
sampled_cov = torch.cov(all_params_flat[burnin:].T)

assert log_posts[-1] > log_posts[0]
assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
    assert tree_all(
        tree_map(lambda x, y: torch.all(x != y), params, init_params_copy)
    )  # Check that the parameters are updated when inplace=True

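A possible follow-on sanity check (hypothetical, not part of the committed test): the ξ update drives the mean kinetic energy toward the temperature, so after the loop above the momenta should carry an empirical kinetic temperature near temperature = 1.0:

m_flat, _ = tree_ravel(sgnht_state.momenta)
kinetic_temp = torch.mean(m_flat**2)  # the ξ update pushes this toward temperature = 1.0
print(kinetic_temp.item(), sgnht_state.xi)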