Skip to content

Commit

Permalink
Add sigma param to SGHMC and SGNHT
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDuffield committed May 15, 2024
1 parent 8202a93 commit 0664173
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
21 changes: 14 additions & 7 deletions posteriors/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def build(
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
sigma: float = 1.0,
temperature: float = 1.0,
momenta: TensorTree | float | None = None,
) -> Transform:
Expand All @@ -23,14 +24,14 @@ def build(
Algorithm from [Chen et al, 2014](https://arxiv.org/abs/1402.4102):
\\begin{align}
θ_{t+1} &= θ_t + ε m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε α m_t
θ_{t+1} &= θ_t + ε σ^{-2} m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε σ^{-2} α m_t
+ N(0, ε T (2 α - ε β T) \\mathbb{I})\\
\\end{align}
for learning rate $\\epsilon$ and temperature $T$
Targets $p_T(θ, m) \\propto \\exp( (\\log p(θ) - \\frac12 m^Tm) / T)$
Targets $p_T(θ, m) \\propto \\exp( (\\log p(θ) - \\frac{1}{2σ^2} m^Tm) / T)$
with temperature $T$.
The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
Expand All @@ -43,6 +44,7 @@ def build(
lr: Learning rate.
alpha: Friction coefficient.
beta: Gradient noise coefficient (estimated variance).
sigma: Standard deviation of momenta target distribution.
temperature: Temperature of the joint parameter + momenta distribution.
momenta: Initial momenta. Can be tree like params or scalar.
Defaults to random iid samples from N(0, 1).
Expand All @@ -57,6 +59,7 @@ def build(
lr=lr,
alpha=alpha,
beta=beta,
sigma=sigma,
temperature=temperature,
)
return Transform(init_fn, update_fn)
Expand Down Expand Up @@ -111,6 +114,7 @@ def update(
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
sigma: float = 1.0,
temperature: float = 1.0,
inplace: bool = False,
) -> SGHMCState:
Expand All @@ -119,8 +123,8 @@ def update(
Update rule from [Chen et al, 2014](https://arxiv.org/abs/1402.4102):
\\begin{align}
θ_{t+1} &= θ_t + ε m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε α m_t
θ_{t+1} &= θ_t + ε σ^{-2} m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε σ^{-2} α m_t
+ N(0, ε T (2 α - ε β T) \\mathbb{I})\\
\\end{align}
Expand All @@ -135,6 +139,7 @@ def update(
lr: Learning rate.
alpha: Friction coefficient.
beta: Gradient noise coefficient (estimated variance).
sigma: Standard deviation of momenta target distribution.
temperature: Temperature of the joint parameter + momenta distribution.
inplace: Whether to modify state in place.
Expand All @@ -147,14 +152,16 @@ def update(
state.params, batch
)

prec = sigma**-2

def transform_params(p, m):
return p + lr * m
return p + lr * prec * m

def transform_momenta(m, g):
return (
m
+ lr * g
- lr * alpha * m
- lr * prec * alpha * m
+ (temperature * lr * (2 * alpha - temperature * lr * beta)) ** 0.5
* torch.randn_like(m)
)
Expand Down
29 changes: 18 additions & 11 deletions posteriors/sgmcmc/sgnht.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def build(
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
sigma: float = 1.0,
temperature: float = 1.0,
momenta: TensorTree | float | None = None,
xi: float = None,
Expand All @@ -25,15 +26,15 @@ def build(
Algorithm from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):
\\begin{align}
θ_{t+1} &= θ_t + ε m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
θ_{t+1} &= θ_t + ε σ^{-2} m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε σ^{-2} ξ_t m_t
+ N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
ξ_{t+1} &= ξ_t + ε (σ^{-2} d^{-1} m_t^T m_t - T)
\\end{align}
for learning rate $\\epsilon$, temperature $T$ and parameter dimension $d$.
Targets $p_T(θ, m, ξ) \\propto \\exp( (\\log p(θ) - \\frac12 m^Tm + \\frac{d}{2}(ξ - α)^2) / T)$.
Targets $p_T(θ, m, ξ) \\propto \\exp( (\\log p(θ) - \\frac{1}{2σ^2} m^Tm - \\frac{d}{2}(ξ - α)^2) / T)$.
The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
to ensure robust scaling for a large amount of data and variable batch size.
Expand All @@ -45,6 +46,7 @@ def build(
lr: Learning rate.
alpha: Friction coefficient.
beta: Gradient noise coefficient (estimated variance).
sigma: Standard deviation of momenta target distribution.
temperature: Temperature of the joint parameter + momenta distribution.
momenta: Initial momenta. Can be tree like params or scalar.
Defaults to random iid samples from N(0, 1).
Expand All @@ -60,6 +62,7 @@ def build(
lr=lr,
alpha=alpha,
beta=beta,
sigma=sigma,
temperature=temperature,
)
return Transform(init_fn, update_fn)
Expand Down Expand Up @@ -118,6 +121,7 @@ def update(
lr: float,
alpha: float = 0.01,
beta: float = 0.0,
sigma: float = 1.0,
temperature: float = 1.0,
inplace: bool = False,
) -> SGNHTState:
Expand All @@ -126,10 +130,10 @@ def update(
Update rule from [Ding et al, 2014](https://proceedings.neurips.cc/paper/2014/file/21fe5b8ba755eeaece7a450849876228-Paper.pdf):
\\begin{align}
θ_{t+1} &= θ_t + ε m_t \\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε ξ_t m_t
+ N(0, ε T (2 α - ε β T) \\mathbb{I})\\
ξ_{t+1} &= ξ_t + ε (m_t^T m_t / d - T)
θ_{t+1} &= θ_t + ε σ^{-2} m_t \\\\
m_{t+1} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}) - ε σ^{-2} ξ_t m_t
+ N(0, ε T (2 α - ε β T) \\mathbb{I})\\\\
ξ_{t+1} &= ξ_t + ε (σ^{-2} d^{-1} m_t^T m_t - T)
\\end{align}
for learning rate $\\epsilon$ and temperature $T$
Expand All @@ -143,6 +147,7 @@ def update(
lr: Learning rate.
alpha: Friction coefficient.
beta: Gradient noise coefficient (estimated variance).
sigma: Standard deviation of momenta target distribution.
temperature: Temperature of the joint parameter + momenta distribution.
inplace: Whether to modify state in place.
Expand All @@ -155,20 +160,22 @@ def update(
state.params, batch
)

prec = sigma**-2

def transform_params(p, m):
return p + lr * m
return p + lr * prec * m

def transform_momenta(m, g):
return (
m
+ lr * g
- lr * state.xi * m
- lr * prec * state.xi * m
+ (temperature * lr * (2 * alpha - temperature * lr * beta)) ** 0.5
* torch.randn_like(m)
)

m_flat, _ = tree_ravel(state.momenta)
xi_new = state.xi + lr * (torch.mean(m_flat**2) - temperature)
xi_new = state.xi + lr * (prec * torch.mean(m_flat**2) - temperature)

params = flexi_tree_map(
transform_params, state.params, state.momenta, inplace=inplace
Expand Down

0 comments on commit 0664173

Please sign in to comment.