Add SGNHT #89

Merged · 1 commit · May 14, 2024
3 changes: 3 additions & 0 deletions docs/api/index.md
@@ -27,6 +27,9 @@ All Laplace transforms leave the parameters unmodified. Comprehensive details on
- [`sgmcmc.sghmc`](sgmcmc/sghmc.md) implements the stochastic gradient Hamiltonian
Monte Carlo (SGHMC) algorithm from [Chen et al, 2014](https://arxiv.org/abs/1402.4102)
(without momenta resampling).
- [`sgmcmc.sgnht`](sgmcmc/sgnht.md) implements the stochastic gradient Nosé-Hoover
thermostat (SGNHT) algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf)
(SGHMC with an adaptive friction coefficient).

For an overview and unifying framework for SGMCMC methods, see [Ma et al, 2015](https://arxiv.org/abs/1506.04696).
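
As a quick orientation (not part of this diff), all three modules share the same transform pattern; a minimal sketch, with `log_posterior`, `params` and `batches` assumed defined as on the linked pages:

```python
from posteriors.sgmcmc import sgnht

transform = sgnht.build(log_posterior, lr=1e-2)  # sgld and sghmc follow the same pattern
state = transform.init(params)
for batch in batches:
    state = transform.update(state, batch)  # state.params traces the Markov chain
```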

3 changes: 3 additions & 0 deletions docs/api/sgmcmc/sgnht.md
@@ -0,0 +1,3 @@
# SGNHT

::: posteriors.sgmcmc.sgnht
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -77,6 +77,7 @@ nav:
- SGMCMC:
- api/sgmcmc/sgld.md
- api/sgmcmc/sghmc.md
- api/sgmcmc/sgnht.md
- VI:
- Diag: api/vi/diag.md
- api/optim.md
1 change: 1 addition & 0 deletions posteriors/sgmcmc/__init__.py
@@ -1,2 +1,3 @@
from posteriors.sgmcmc import sgld
from posteriors.sgmcmc import sghmc
from posteriors.sgmcmc import sgnht
183 changes: 183 additions & 0 deletions posteriors/sgmcmc/sgnht.py
@@ -0,0 +1,183 @@
from typing import Any
from functools import partial
import torch
from torch.func import grad_and_value
from optree import tree_map
from optree.integration.torch import tree_ravel
from dataclasses import dataclass

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.tree_utils import flexi_tree_map
from posteriors.utils import is_scalar, CatchAuxError


def build(
log_posterior: LogProbFn,
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
temperature: float = 1.0,
momenta: TensorTree | float | None = None,
xi: float | None = None,
) -> Transform:
"""Builds SGNHT transform.

Algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):

\\begin{align}
θ_{t+1} &= θ_t + ε m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
+ N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
\\end{align}

for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.

Targets $p_T(θ, m, ξ) \\propto \\exp( (\\log p(θ) - \\frac12 m^Tm - \\frac{d}{2}(ξ - α)^2) / T)$.

The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
to ensure robust scaling for a large amount of data and variable batch size.

Args:
log_posterior: Function that takes parameters and input batch and
returns the log posterior value (which can be unnormalised)
as well as auxiliary information, e.g. from the model call.
lr: Learning rate.
alpha: Friction coefficient.
beta: Gradient noise coefficient (estimated variance).
temperature: Temperature of the joint parameter + momenta distribution.
momenta: Initial momenta. Can be tree like params or scalar.
Defaults to random iid samples from N(0, 1).
xi: Initial value for scalar thermostat ξ. Defaults to `alpha`.

Returns:
SGNHT transform instance.
"""
init_fn = partial(init, momenta=momenta, xi=xi if xi is not None else alpha)
update_fn = partial(
update,
log_posterior=log_posterior,
lr=lr,
alpha=alpha,
beta=beta,
temperature=temperature,
)
return Transform(init_fn, update_fn)


@dataclass
class SGNHTState(TransformState):
"""State encoding params and momenta for SGNHT.

Args:
params: Parameters.
momenta: Momenta for each parameter.
xi: Scalar thermostat value.
log_posterior: Log posterior evaluation.
aux: Auxiliary information from the log_posterior call.
"""

params: TensorTree
momenta: TensorTree
xi: float
log_posterior: torch.Tensor = None
aux: Any = None


def init(
params: TensorTree, momenta: TensorTree | float | None = None, xi: float = 0.01
) -> SGNHTState:
"""Initialise momenta for SGNHT.

Args:
params: Parameters for which to initialise.
momenta: Initial momenta. Can be tree like params or scalar.
Defaults to random iid samples from N(0, 1).
xi: Initial value for scalar thermostat ξ.

Returns:
Initial SGNHTState containing momenta.
"""
if momenta is None:
momenta = tree_map(
lambda x: torch.randn_like(x, requires_grad=x.requires_grad),
params,
)
elif is_scalar(momenta):
momenta = tree_map(
lambda x: torch.full_like(x, momenta, requires_grad=x.requires_grad),
params,
)

return SGNHTState(params, momenta, xi)


def update(
state: SGNHTState,
batch: Any,
log_posterior: LogProbFn,
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
temperature: float = 1.0,
inplace: bool = False,
) -> SGNHTState:
"""Updates parameters, momenta and xi for SGNHT.

Update rule from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):

\\begin{align}
θ_{t+1} &= θ_t + ε m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
+ N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
\\end{align}

for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.

Args:
state: SGNHTState containing params, momenta and xi.
batch: Data batch to be sent to log_posterior.
log_posterior: Function that takes parameters and input batch and
returns the log posterior value (which can be unnormalised)
as well as auxiliary information, e.g. from the model call.
lr: Learning rate.
alpha: Friction coefficient.
beta: Gradient noise coefficient (estimated variance).
temperature: Temperature of the joint parameter + momenta distribution.
inplace: Whether to modify state in place.

Returns:
Updated SGNHTState
(which are pointers to the inputted state tensors if inplace=True).
"""
with torch.no_grad(), CatchAuxError():
grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
state.params, batch
)

def transform_params(p, m):
return p + lr * m

def transform_momenta(m, g):
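# m_{t+1} = m_t + ε ∇log p(θ_t, batch) - ε ξ_t m_t + N(0, ε T (2α - ε β T) I),
# matching the update rule in the docstring above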
return (
m
+ lr * g
- lr * state.xi * m
+ (temperature * lr * (2 * alpha - temperature * lr * beta)) ** 0.5
* torch.randn_like(m)
)

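# The thermostat update uses the current momenta m_t, so compute it before
# the (possibly in-place) momenta transform below; mean(m**2) equals m^T m / d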
m_flat, _ = tree_ravel(state.momenta)
xi_new = state.xi + lr * (torch.mean(m_flat**2) - temperature)

params = flexi_tree_map(
transform_params, state.params, state.momenta, inplace=inplace
)
momenta = flexi_tree_map(transform_momenta, state.momenta, grads, inplace=inplace)

if inplace:
state.xi = xi_new
state.log_posterior = log_post.detach()
state.aux = aux
return state
return SGNHTState(params, momenta, xi_new, log_post.detach(), aux)
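
To illustrate the "constructed in tandem" recommendation in the docstring, a hedged sketch under one common convention (the toy model, per-sample normalisation and `temperature = 1 / num_data` choice are assumptions for illustration, not part of this diff):

```python
import torch
from posteriors.sgmcmc import sgnht

# Hypothetical toy setup: estimate the mean of a 1-D Gaussian
torch.manual_seed(0)
num_data = 1000
data = torch.randn(num_data, 1) + 3.0

def log_posterior(params, batch):
    # Per-sample-normalised log posterior: mean batch log-likelihood plus
    # the log prior scaled down by the dataset size
    log_lik = torch.distributions.Normal(params["mu"], 1.0).log_prob(batch).mean()
    log_prior = torch.distributions.Normal(0.0, 10.0).log_prob(params["mu"]).sum()
    return log_lik + log_prior / num_data, torch.tensor([])

# With this normalisation, temperature = 1 / num_data targets the untempered posterior
transform = sgnht.build(log_posterior, lr=1e-2, alpha=1.0, temperature=1 / num_data)
state = transform.init({"mu": torch.zeros(1)})

for _ in range(2000):
    batch = data[torch.randint(num_data, (32,))]
    state = transform.update(state, batch)

print(state.params["mu"])  # should wander near the data mean (about 3)
```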
85 changes: 85 additions & 0 deletions tests/sgmcmc/test_sgnht.py
@@ -0,0 +1,85 @@
from functools import partial
import torch
from optree import tree_map
from optree.integration.torch import tree_ravel

from posteriors.sgmcmc import sgnht

from tests.scenarios import batch_normal_log_prob


def test_sgnht():
torch.manual_seed(42)
target_mean = {"a": torch.randn(2, 1) + 10, "b": torch.randn(1, 1) + 10}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

target_mean_flat = tree_ravel(target_mean)[0]
target_cov = torch.diag(tree_ravel(target_sds)[0] ** 2)

batch = torch.arange(10).reshape(-1, 1)

batch_normal_log_prob_spec = partial(
batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
)

n_steps = 10000
lr = 1e-2
alpha = 1.0
beta = 0.0

params = tree_map(lambda x: torch.zeros_like(x), target_mean)
init_params_copy = tree_map(lambda x: x.clone(), params)

sampler = sgnht.build(batch_normal_log_prob_spec, lr=lr, alpha=alpha, beta=beta)

# Test inplace = False
sgnht_state = sampler.init(params)
log_posts = []
all_params = tree_map(lambda x: x.unsqueeze(0), params)

for _ in range(n_steps):
sgnht_state = sampler.update(sgnht_state, batch, inplace=False)

all_params = tree_map(
lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, sgnht_state.params
)

log_posts.append(sgnht_state.log_posterior.item())

burnin = 5000
all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
sampled_mean = all_params_flat[burnin:].mean(0)
sampled_cov = torch.cov(all_params_flat[burnin:].T)

assert log_posts[-1] > log_posts[0]
assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
assert torch.equal(
tree_ravel(params)[0], tree_ravel(init_params_copy)[0]
)  # Check that the parameters are not updated

# Test inplace = True
sgnht_state = sampler.init(params, momenta=0.0)
log_posts = []
all_params = tree_map(lambda x: x.unsqueeze(0), params)

for _ in range(n_steps):
sgnht_state = sampler.update(sgnht_state, batch, inplace=True)

all_params = tree_map(
lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, sgnht_state.params
)

log_posts.append(sgnht_state.log_posterior.item())

burnin = 5000
all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
sampled_mean = all_params_flat[burnin:].mean(0)
sampled_cov = torch.cov(all_params_flat[burnin:].T)

assert log_posts[-1] > log_posts[0]
assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
assert torch.all(
tree_ravel(params)[0] != tree_ravel(init_params_copy)[0]
)  # Check that the parameters are updated in place