Add SGNHT #89
Merged
@@ -0,0 +1,3 @@
# SGNHT

::: posteriors.sgmcmc.sgnht
@@ -1,2 +1,3 @@
from posteriors.sgmcmc import sgld
from posteriors.sgmcmc import sghmc
from posteriors.sgmcmc import sgnht
@@ -0,0 +1,183 @@
from typing import Any
from functools import partial
import torch
from torch.func import grad_and_value
from optree import tree_map
from optree.integration.torch import tree_ravel
from dataclasses import dataclass

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.tree_utils import flexi_tree_map
from posteriors.utils import is_scalar, CatchAuxError


def build(
    log_posterior: LogProbFn,
    lr: float,
    alpha: float = 0.01,
    beta: float = 0.0,
    temperature: float = 1.0,
    momenta: TensorTree | float | None = None,
    xi: float | None = None,
) -> Transform:
    """Builds SGNHT transform.

    Algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):

    \\begin{align}
    θ_{t+1} &= θ_t + ε m_t \\\\
    m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
    + N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
    ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
    \\end{align}

    for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.

    Targets $p_T(θ, m, ξ) \\propto \\exp( (\\log p(θ) - \\frac12 m^Tm - \\frac{d}{2}(ξ - α)^2) / T)$.

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling for a large amount of data and variable batch size.

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate.
        alpha: Friction coefficient.
        beta: Gradient noise coefficient (estimated variance).
        temperature: Temperature of the joint parameter + momenta distribution.
        momenta: Initial momenta. Can be tree like params or scalar.
            Defaults to random iid samples from N(0, 1).
        xi: Initial value for scalar thermostat ξ. Defaults to `alpha`.

    Returns:
        SGNHT transform instance.
    """
    init_fn = partial(init, momenta=momenta, xi=alpha if xi is None else xi)
    update_fn = partial(
        update,
        log_posterior=log_posterior,
        lr=lr,
        alpha=alpha,
        beta=beta,
        temperature=temperature,
    )
    return Transform(init_fn, update_fn)


@dataclass
class SGNHTState(TransformState):
    """State encoding params, momenta and thermostat for SGNHT.

    Args:
        params: Parameters.
        momenta: Momenta for each parameter.
        xi: Scalar thermostat value.
        log_posterior: Log posterior evaluation.
        aux: Auxiliary information from the log_posterior call.
    """

    params: TensorTree
    momenta: TensorTree
    xi: float
    log_posterior: torch.Tensor | None = None
    aux: Any = None


def init(
    params: TensorTree, momenta: TensorTree | float | None = None, xi: float = 0.01
) -> SGNHTState:
    """Initialise momenta and thermostat for SGNHT.

    Args:
        params: Parameters for which to initialise.
        momenta: Initial momenta. Can be tree like params or scalar.
            Defaults to random iid samples from N(0, 1).
        xi: Initial value for scalar thermostat ξ.

    Returns:
        Initial SGNHTState containing params, momenta and xi.
    """
    if momenta is None:
        momenta = tree_map(
            lambda x: torch.randn_like(x, requires_grad=x.requires_grad),
            params,
        )
    elif is_scalar(momenta):
        momenta = tree_map(
            lambda x: torch.full_like(x, momenta, requires_grad=x.requires_grad),
            params,
        )

    return SGNHTState(params, momenta, xi)


def update(
    state: SGNHTState,
    batch: Any,
    log_posterior: LogProbFn,
    lr: float,
    alpha: float = 0.01,
    beta: float = 0.0,
    temperature: float = 1.0,
    inplace: bool = False,
) -> SGNHTState:
    """Updates parameters, momenta and xi for SGNHT.

    Update rule from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):

    \\begin{align}
    θ_{t+1} &= θ_t + ε m_t \\\\
    m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
    + N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
    ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
    \\end{align}

    for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.

    Args:
        state: SGNHTState containing params, momenta and xi.
        batch: Data batch to be sent to log_posterior.
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate.
        alpha: Friction coefficient.
        beta: Gradient noise coefficient (estimated variance).
        temperature: Temperature of the joint parameter + momenta distribution.
        inplace: Whether to modify state in place.

    Returns:
        Updated SGNHTState
        (which contains pointers to the inputted state tensors if inplace=True).
    """
    with torch.no_grad(), CatchAuxError():
        grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
            state.params, batch
        )

    def transform_params(p, m):
        return p + lr * m

    def transform_momenta(m, g):
        return (
            m
            + lr * g
            - lr * state.xi * m
            + (temperature * lr * (2 * alpha - temperature * lr * beta)) ** 0.5
            * torch.randn_like(m)
        )

    # Thermostat update uses the current (pre-update) momenta,
    # matching ξ_{t+1} = ξ_t + ε (m_t^T m_t / d - T)
    m_flat, _ = tree_ravel(state.momenta)
    xi_new = state.xi + lr * (torch.mean(m_flat**2) - temperature)

    params = flexi_tree_map(
        transform_params, state.params, state.momenta, inplace=inplace
    )
    momenta = flexi_tree_map(transform_momenta, state.momenta, grads, inplace=inplace)

    if inplace:
        state.xi = xi_new
        state.log_posterior = log_post.detach()
        state.aux = aux
        return state
    return SGNHTState(params, momenta, xi_new, log_post.detach(), aux)
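For readers without the library at hand, the update rule above can be sketched as a standalone toy loop on a single flat tensor. This is a hypothetical minimal re-implementation for illustration only, not the posteriors API: the function name, the full-gradient shortcut and the standard-normal target are all assumptions of this sketch.

```python
import torch


def sgnht_step(theta, m, xi, grad_log_post, lr, alpha=0.01, beta=0.0, temperature=1.0):
    """One toy SGNHT step on a single flat tensor (illustrative, not the library)."""
    grad = grad_log_post(theta)  # stands in for the minibatch gradient estimate
    theta_new = theta + lr * m  # θ_{t+1} = θ_t + ε m_t
    noise_std = (temperature * lr * (2 * alpha - temperature * lr * beta)) ** 0.5
    m_new = m + lr * grad - lr * xi * m + noise_std * torch.randn_like(m)
    xi_new = xi + lr * ((m**2).mean() - temperature)  # thermostat uses pre-update m_t
    return theta_new, m_new, xi_new


# Toy target: standard normal, so grad log p(θ) = -θ
torch.manual_seed(0)
theta, m, xi = torch.zeros(3), torch.randn(3), 1.0
samples = []
for step in range(2000):
    theta, m, xi = sgnht_step(theta, m, xi, lambda t: -t, lr=0.1, alpha=1.0)
    if step >= 1000:  # discard burn-in
        samples.append(theta.clone())
est_mean = torch.stack(samples).mean()
```

With friction alpha = 1.0 the thermostat keeps the mean squared momenta near the temperature, and the post-burn-in sample mean should land near the target mean of zero.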
@@ -0,0 +1,85 @@
from functools import partial
import torch
from optree import tree_map, tree_all
from optree.integration.torch import tree_ravel

from posteriors.sgmcmc import sgnht

from tests.scenarios import batch_normal_log_prob


def test_sgnht():
    torch.manual_seed(42)
    target_mean = {"a": torch.randn(2, 1) + 10, "b": torch.randn(1, 1) + 10}
    target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

    target_mean_flat = tree_ravel(target_mean)[0]
    target_cov = torch.diag(tree_ravel(target_sds)[0] ** 2)

    batch = torch.arange(10).reshape(-1, 1)

    batch_normal_log_prob_spec = partial(
        batch_normal_log_prob, mean=target_mean, sd_diag=target_sds
    )

    n_steps = 10000
    lr = 1e-2
    alpha = 1.0
    beta = 0.0

    params = tree_map(lambda x: torch.zeros_like(x), target_mean)
    init_params_copy = tree_map(lambda x: x.clone(), params)

    sampler = sgnht.build(batch_normal_log_prob_spec, lr=lr, alpha=alpha, beta=beta)

    # Test inplace = False
    sgnht_state = sampler.init(params)
    log_posts = []
    all_params = tree_map(lambda x: x.unsqueeze(0), params)

    for _ in range(n_steps):
        sgnht_state = sampler.update(sgnht_state, batch, inplace=False)

        all_params = tree_map(
            lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, sgnht_state.params
        )

        log_posts.append(sgnht_state.log_posterior.item())

    burnin = 5000

    all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
    sampled_mean = all_params_flat[burnin:].mean(0)
    sampled_cov = torch.cov(all_params_flat[burnin:].T)

    assert log_posts[-1] > log_posts[0]
    assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
    assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
    # tree_all reduces the elementwise comparison tree to a single bool
    assert tree_all(
        tree_map(lambda x, y: torch.all(x == y), params, init_params_copy)
    )  # Check that the parameters are not updated

    # Test inplace = True
    sgnht_state = sampler.init(params, momenta=0.0)
    log_posts = []
    all_params = tree_map(lambda x: x.unsqueeze(0), params)

    for _ in range(n_steps):
        sgnht_state = sampler.update(sgnht_state, batch, inplace=True)

        all_params = tree_map(
            lambda x, y: torch.cat((x, y.unsqueeze(0))), all_params, sgnht_state.params
        )

        log_posts.append(sgnht_state.log_posterior.item())

    burnin = 5000
    all_params_flat = torch.vmap(lambda x: tree_ravel(x)[0])(all_params)
    sampled_mean = all_params_flat[burnin:].mean(0)
    sampled_cov = torch.cov(all_params_flat[burnin:].T)

    assert log_posts[-1] > log_posts[0]
    assert torch.allclose(sampled_mean, target_mean_flat, atol=1e-0, rtol=1e-1)
    assert torch.allclose(sampled_cov, target_cov, atol=1e-0, rtol=1e-1)
    assert tree_all(
        tree_map(lambda x, y: torch.all(x != y), params, init_params_copy)
    )  # Check that the parameters are updated