From 1cc9f8ef66e070d0bf9d26e3ad05b2960eb41f3c Mon Sep 17 00:00:00 2001
From: Sam Duffield
Date: Tue, 14 May 2024 11:43:37 +0100
Subject: [PATCH] Add SGNHT

---
 docs/api/index.md             |   3 +
 docs/api/sgmcmc/sgnht.md      |   3 +
 mkdocs.yml                    |   1 +
 posteriors/sgmcmc/__init__.py |   1 +
 posteriors/sgmcmc/sgnht.py    | 185 ++++++++++++++++++++++++++++++++++
 tests/sgmcmc/test_sgnht.py    |  85 ++++++++++++++++
 6 files changed, 278 insertions(+)
 create mode 100644 docs/api/sgmcmc/sgnht.md
 create mode 100644 posteriors/sgmcmc/sgnht.py
 create mode 100644 tests/sgmcmc/test_sgnht.py

diff --git a/docs/api/index.md b/docs/api/index.md
index 56785e1e..94b8ece2 100644
--- a/docs/api/index.md
+++ b/docs/api/index.md
@@ -27,6 +27,9 @@ All Laplace transforms leave the parameters unmodified. Comprehensive details on
 
 - [`sgmcmc.sghmc`](sgmcmc/sghmc.md) implements the stochastic gradient Hamiltonian
 Monte Carlo (SGHMC) algorithm from [Chen et al, 2014](https://arxiv.org/abs/1402.4102)
 (without momenta resampling).
+- [`sgmcmc.sgnht`](sgmcmc/sgnht.md) implements the stochastic gradient Nosé-Hoover
+thermostat (SGNHT) algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf)
+(SGHMC with an adaptive friction coefficient).
 
 For an overview and unifying framework for SGMCMC methods, see [Ma et al, 2015](https://arxiv.org/abs/1506.04696).
diff --git a/docs/api/sgmcmc/sgnht.md b/docs/api/sgmcmc/sgnht.md
new file mode 100644
index 00000000..95131f5f
--- /dev/null
+++ b/docs/api/sgmcmc/sgnht.md
@@ -0,0 +1,3 @@
+# SGNHT
+
+::: posteriors.sgmcmc.sgnht
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 3839563c..e28518d5 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -77,6 +77,7 @@ nav:
     - SGMCMC:
      - api/sgmcmc/sgld.md
      - api/sgmcmc/sghmc.md
+     - api/sgmcmc/sgnht.md
    - VI:
      - Diag: api/vi/diag.md
    - api/optim.md
diff --git a/posteriors/sgmcmc/__init__.py b/posteriors/sgmcmc/__init__.py
index 89f1eca3..3bfb4917 100644
--- a/posteriors/sgmcmc/__init__.py
+++ b/posteriors/sgmcmc/__init__.py
@@ -1,2 +1,3 @@
 from posteriors.sgmcmc import sgld
 from posteriors.sgmcmc import sghmc
+from posteriors.sgmcmc import sgnht
diff --git a/posteriors/sgmcmc/sgnht.py b/posteriors/sgmcmc/sgnht.py
new file mode 100644
index 00000000..74cb1c3e
--- /dev/null
+++ b/posteriors/sgmcmc/sgnht.py
@@ -0,0 +1,185 @@
+from typing import Any
+from functools import partial
+import torch
+from torch.func import grad_and_value
+from optree import tree_map
+from optree.integration.torch import tree_ravel
+from dataclasses import dataclass
+
+from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
+from posteriors.tree_utils import flexi_tree_map
+from posteriors.utils import is_scalar, CatchAuxError
+
+
+def build(
+    log_posterior: LogProbFn,
+    lr: float,
+    alpha: float = 0.01,
+    beta: float = 0.0,
+    temperature: float = 1.0,
+    momenta: TensorTree | float | None = None,
+    xi: float | None = None,
+) -> Transform:
+    """Builds SGNHT transform.
+
+    Algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):
+
+    \\begin{align}
+    θ_{t+1} &= θ_t + ε m_t \\\\
+    m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
+    + N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
+    ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
+    \\end{align}
+
+    for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.
+
+    Targets $p_T(θ, m, ξ) \\propto \\exp( (\\log p(θ) - \\frac12 m^Tm - \\frac{d}{2}(ξ - α)^2) / T)$.
+
+    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
+    to ensure robust scaling for a large amount of data and variable batch size.
+
+    Args:
+        log_posterior: Function that takes parameters and input batch and
+            returns the log posterior value (which can be unnormalised)
+            as well as auxiliary information, e.g. from the model call.
+        lr: Learning rate.
+        alpha: Friction coefficient.
+        beta: Gradient noise coefficient (estimated variance).
+        temperature: Temperature of the joint parameter + momenta distribution.
+        momenta: Initial momenta. Can be a tree matching params or a scalar.
+            Defaults to random iid samples from N(0, 1).
+        xi: Initial value for the scalar thermostat ξ. Defaults to `alpha`.
+
+    Returns:
+        SGNHT transform instance.
+    """
+    init_fn = partial(init, momenta=momenta, xi=xi if xi is not None else alpha)
+    update_fn = partial(
+        update,
+        log_posterior=log_posterior,
+        lr=lr,
+        alpha=alpha,
+        beta=beta,
+        temperature=temperature,
+    )
+    return Transform(init_fn, update_fn)
+
+
+@dataclass
+class SGNHTState(TransformState):
+    """State encoding params, momenta and xi for SGNHT.
+
+    Args:
+        params: Parameters.
+        momenta: Momenta for each parameter.
+        xi: Scalar thermostat ξ.
+        log_posterior: Log posterior evaluation.
+        aux: Auxiliary information from the log_posterior call.
+    """
+
+    params: TensorTree
+    momenta: TensorTree
+    xi: float
+    log_posterior: torch.Tensor = None
+    aux: Any = None
+
+
+def init(
+    params: TensorTree, momenta: TensorTree | float | None = None, xi: float = 0.01
+) -> SGNHTState:
+    """Initialise momenta and thermostat for SGNHT.
+
+    Args:
+        params: Parameters for which to initialise.
+        momenta: Initial momenta. Can be a tree matching params or a scalar.
+            Defaults to random iid samples from N(0, 1).
+        xi: Initial value for the scalar thermostat ξ.
+
+    Returns:
+        Initial SGNHTState containing params, momenta and xi.
+    """
+    if momenta is None:
+        momenta = tree_map(
+            lambda x: torch.randn_like(x, requires_grad=x.requires_grad),
+            params,
+        )
+    elif is_scalar(momenta):
+        momenta = tree_map(
+            lambda x: torch.full_like(x, momenta, requires_grad=x.requires_grad),
+            params,
+        )
+
+    return SGNHTState(params, momenta, xi)
+
+
+def update(
+    state: SGNHTState,
+    batch: Any,
+    log_posterior: LogProbFn,
+    lr: float,
+    alpha: float = 0.01,
+    beta: float = 0.0,
+    temperature: float = 1.0,
+    inplace: bool = False,
+) -> SGNHTState:
+    """Updates parameters, momenta and xi for SGNHT.
+
+    Update rule from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):
+
+    \\begin{align}
+    θ_{t+1} &= θ_t + ε m_t \\\\
+    m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
+    + N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
+    ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
+    \\end{align}
+
+    for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.
+
+    Args:
+        state: SGNHTState containing params, momenta and xi.
+        batch: Data batch to be sent to log_posterior.
+        log_posterior: Function that takes parameters and input batch and
+            returns the log posterior value (which can be unnormalised)
+            as well as auxiliary information, e.g. from the model call.
+        lr: Learning rate.
+        alpha: Friction coefficient.
+        beta: Gradient noise coefficient (estimated variance).
+        temperature: Temperature of the joint parameter + momenta distribution.
+        inplace: Whether to modify state in place.
+
+    Returns:
+        Updated SGNHTState
+        (which contains pointers to the input state tensors if inplace=True).
+    """
+    with torch.no_grad(), CatchAuxError():
+        grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
+            state.params, batch
+        )
+
+    def transform_params(p, m):
+        return p + lr * m
+
+    def transform_momenta(m, g):
+        return (
+            m
+            + lr * g
+            - lr * state.xi * m
+            + (temperature * lr * (2 * alpha - temperature * lr * beta)) ** 0.5
+            * torch.randn_like(m)
+        )
+
+    # ξ update uses the pre-update momenta m_t, with mean(m_t**2) = m_t^T m_t / d
+    m_flat, _ = tree_ravel(state.momenta)
+    xi_new = state.xi + lr * (torch.mean(m_flat**2) - temperature)
+
+    params = flexi_tree_map(
+        transform_params, state.params, state.momenta, inplace=inplace
+    )
+    momenta = flexi_tree_map(transform_momenta, state.momenta, grads, inplace=inplace)
+
+    if inplace:
+        state.xi = xi_new
+        state.log_posterior = log_post.detach()
+        state.aux = aux
+        return state
+    return SGNHTState(params, momenta, xi_new, log_post.detach(), aux)
diff --git a/tests/sgmcmc/test_sgnht.py b/tests/sgmcmc/test_sgnht.py
new file mode 100644
index 00000000..f0d38aba
--- /dev/null
+++ b/tests/sgmcmc/test_sgnht.py
@@ -0,0 +1,85 @@
+from functools import partial
+import torch
+from optree import tree_all, tree_map
+from optree.integration.torch import tree_ravel
+
+from posteriors.sgmcmc import sgnht
+
+from tests.scenarios import batch_normal_log_prob
+
+
+def test_sgnht():
+    torch.manual_seed(42)
+    target_mean = {"a": torch.randn(2, 1) + 10, "b": torch.randn(1, 1) + 10}
+    target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)
+
+    target_mean_flat = tree_ravel(target_mean)[0]
+    target_cov = torch.diag(tree_ravel(target_sds)[0] ** 2)
+
+    batch = torch.arange(10).reshape(-1, 1)
+
+    batch_normal_log_prob_spec = partial(
+        batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
+    )
+
+    n_steps = 10000
+    lr = 1e-2
+    alpha = 1.0
+    beta = 0.0
+
+    params = tree_map(lambda x: torch.zeros_like(x), target_mean)
+    init_params_copy = tree_map(lambda x: x.clone(), params)
+
+    sampler = sgnht.build(batch_normal_log_prob_spec, lr=lr, alpha=alpha, beta=beta)
+
+    # Test inplace = False
+    sgnht_state = sampler.init(params)
+    log_posts = []
+    all_params = tree_map(lambda x: x.unsqueeze(0), params)
+
+    for _ in range(n_steps):
+        sgnht_state = sampler.update(sgnht_state, batch, inplace=False)
+
+        all_params = tree_map(
+            lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, sgnht_state.params
+        )
+
+        log_posts.append(sgnht_state.log_posterior.item())
+
+    burnin = 5000
+    all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
+    sampled_mean = all_params_flat[burnin:].mean(0)
+    sampled_cov = torch.cov(all_params_flat[burnin:].T)
+
+    assert log_posts[-1] > log_posts[0]
+    assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
+    assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
+    assert tree_all(
+        tree_map(lambda x, y: torch.all(x == y), params, init_params_copy)
+    )  # Check that the parameters are not updated
+
+    # Test inplace = True
+    sgnht_state = sampler.init(params, momenta=0.0)
+    log_posts = []
+    all_params = tree_map(lambda x: x.unsqueeze(0), params)
+
+    for _ in range(n_steps):
+        sgnht_state = sampler.update(sgnht_state, batch, inplace=True)
+
+        all_params = tree_map(
+            lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, sgnht_state.params
+        )
+
+        log_posts.append(sgnht_state.log_posterior.item())
+
+    burnin = 5000
+    all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
+    sampled_mean = all_params_flat[burnin:].mean(0)
+    sampled_cov = torch.cov(all_params_flat[burnin:].T)
+
+    assert log_posts[-1] > log_posts[0]
+    assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
+    assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
+    assert tree_all(
+        tree_map(lambda x, y: torch.all(x != y), params, init_params_copy)
+    )  # Check that the parameters are updated
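
Usage sketch (reviewer note, not part of the diff): a minimal example of how the new
transform composes with the build/init/update API introduced in this patch. The
`log_posterior` below is a stand-in for illustration only; any function returning
(log_posterior_value, aux) works, and in practice it should be scaled to the full
dataset as described in the build docstring.

    import torch
    from posteriors.sgmcmc import sgnht

    def log_posterior(params, batch):
        # Stand-in unnormalised Gaussian log density; aux returned as required
        return -0.5 * (params**2).sum() * len(batch), None

    params = torch.randn(5)
    transform = sgnht.build(log_posterior, lr=1e-2, alpha=1.0)
    state = transform.init(params)

    samples = []
    for _ in range(100):
        batch = torch.randn(10)
        state = transform.update(state, batch)  # inplace=False by default
        samples.append(state.params.clone())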