Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjusted MCLMC #675

Open
wants to merge 107 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
b60e4ca
TESTS
reubenharry May 13, 2024
0c5aa2d
TESTS
reubenharry May 13, 2024
5eeb3e1
UPDATE DOCSTRING
reubenharry May 13, 2024
4a09156
ADD STREAMING VERSION
reubenharry May 13, 2024
dfb5ee0
ADD PRECONDITIONING TO MCLMC
reubenharry May 13, 2024
2ab3365
ADD PRECONDITIONING TO TUNING FOR MCLMC
reubenharry May 13, 2024
4cc3971
UPDATE GITIGNORE
reubenharry May 13, 2024
f987da3
UPDATE GITIGNORE
reubenharry May 13, 2024
dbab9a3
UPDATE TESTS
reubenharry May 13, 2024
a7ffdb8
UPDATE TESTS
reubenharry May 13, 2024
098f5ad
UPDATE TESTS
reubenharry May 13, 2024
5bd2a3f
ADD DOCSTRING
reubenharry May 13, 2024
4fc1453
ADD TEST
reubenharry May 13, 2024
3678428
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 13, 2024
203f1fd
STREAMING AVERAGE
reubenharry May 15, 2024
fc347d6
ADD TEST
reubenharry May 15, 2024
49410f9
REFACTOR RUN_INFERENCE_ALGORITHM
reubenharry May 15, 2024
ffdca93
UPDATE DOCSTRING
reubenharry May 15, 2024
b7b7084
Precommit
reubenharry May 15, 2024
9d2601d
RESOLVE MERGE CONFLICTS
reubenharry May 15, 2024
97cfc9e
CLEAN TESTS
reubenharry May 15, 2024
45429b8
CLEAN TESTS
reubenharry May 15, 2024
beb6cbe
FIX BAD MERGE
reubenharry May 15, 2024
3e66be7
FIX BAD MERGE
reubenharry May 15, 2024
09b9dbd
ADJUSTED MCLMC
reubenharry May 15, 2024
71e3721
REMOVE BENCHMARKS:
reubenharry May 15, 2024
45bc677
ADD ADJUSTED MCLMC
reubenharry May 15, 2024
dd9fb1c
Merge branch 'preconditioned_mclmc' of https://github.com/reubenharry…
reubenharry May 15, 2024
a27dba9
GITIGNORE
reubenharry May 15, 2024
b0ac897
GITIGNORE
reubenharry May 15, 2024
7a6e42b
PRECOMMIT CLEAN UP
reubenharry May 15, 2024
2d3c3fc
FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS
reubenharry May 15, 2024
dad0060
TEMPORARILY ADD BENCHMARKS
reubenharry May 15, 2024
3ebc413
MERGE
reubenharry May 15, 2024
48fa9b1
ADD ADJUSTED MCLMC TUNING
reubenharry May 15, 2024
4330d18
CLEAN
reubenharry May 15, 2024
1c17ecf
UNIFY ADJUSTED MCLMC AND MCHMC
reubenharry May 16, 2024
6bd5ab1
ADD INITIAL_POSITION
reubenharry May 17, 2024
5615261
FIX TEST
reubenharry May 17, 2024
d66a561
Merge branch 'main' into inference_algorithm
reubenharry May 17, 2024
290addc
Merge branch 'main' into inference_algorithm
reubenharry May 18, 2024
35d4880
Merge branch 'inference_algorithm' into new_integrator
reubenharry May 18, 2024
356cd3b
CLEAN UP
reubenharry May 18, 2024
67c0002
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 18, 2024
63a8042
REMOVE BENCHMARKS
reubenharry May 18, 2024
51fee69
ADD TEST
reubenharry May 18, 2024
29994d7
REMOVE BENCHMARKS
reubenharry May 18, 2024
e4be0ae
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 18, 2024
33bda9d
REMOVE BENCHMARKS
reubenharry May 18, 2024
888fb09
MODIFY WINDOW ADAPTATION TO TAKE INTEGRATOR
reubenharry May 18, 2024
4123e4f
MODIFY WINDOW ADAPTATION TO TAKE INTEGRATOR
reubenharry May 18, 2024
64948e5
BUG FIX
reubenharry May 18, 2024
c3d44f3
CHANGE PRECISION
reubenharry May 18, 2024
94d43bd
CHANGE PRECISION
reubenharry May 18, 2024
17b7454
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 18, 2024
984dbc3
Merge branch 'new_integrator' into adjusted_mclmc
reubenharry May 18, 2024
636ef43
ADD OMELYAN TEST
reubenharry May 18, 2024
78f35b6
Merge branch 'new_integrator' into adjusted_mclmc
reubenharry May 18, 2024
7b16464
ADD ADJUSTED MCLMC TEST
reubenharry May 19, 2024
0a11a0f
ADD ADJUSTED MCLMC TEST
reubenharry May 19, 2024
178b452
RENAME O
reubenharry May 19, 2024
9c1c816
Merge branch 'inference_algorithm' of github.com:reubenharry/blackjax…
reubenharry May 19, 2024
db90cdc
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 19, 2024
a26d4a0
UPDATE STREAMING AVG
reubenharry May 19, 2024
0ff1d24
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 19, 2024
0ab0694
Merge branch 'new_integrator' into adjusted_mclmc
reubenharry May 19, 2024
9dd740f
UPDATE STREAMING AVG
reubenharry May 19, 2024
4d03b89
FIX MERGE
reubenharry May 19, 2024
4e2b7c0
MERGE
reubenharry May 20, 2024
6bacb6c
UPDATE PR
reubenharry May 24, 2024
654dacc
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 24, 2024
d389758
Merge branch 'new_integrator' into adjusted_mclmc
reubenharry May 24, 2024
cacb792
RENAME STD_MAT
reubenharry May 24, 2024
3656bb9
RENAME STD_MAT
reubenharry May 24, 2024
9c2fea7
RENAME STD_MAT
reubenharry May 24, 2024
c249a12
Merge branch 'preconditioned_mclmc' into new_integrator
reubenharry May 24, 2024
c3439bf
Merge branch 'preconditioned_mclmc' into adjusted_mclmc
reubenharry May 24, 2024
a34957e
Merge branch 'new_integrator' into adjusted_mclmc
reubenharry May 24, 2024
06dd04d
MERGE MAIN
reubenharry May 25, 2024
bb5dd9f
Merge branch 'new_integrator' into adjusted_mclmc
reubenharry May 25, 2024
88b06cb
MERGE MAIN
reubenharry May 26, 2024
abe707c
REMOVE COEFFICIENT EXPORTS
reubenharry May 26, 2024
4ffc91d
Merge branch 'new_integrator' into adjusted_mclmc
reubenharry May 27, 2024
0629fa8
REMOVE COEFFICIENT EXPORTS
reubenharry May 27, 2024
102e9e2
MYPY ISSUE
reubenharry May 27, 2024
f4e8064
RESOLVE MYPY ISSUE
reubenharry May 27, 2024
40b5b98
RESOLVE MYPY ISSUE
reubenharry May 27, 2024
0e0016b
RETURN EXPECTATION HISTORY
reubenharry May 30, 2024
ca7935b
FIX KWARG BUG
reubenharry Jun 2, 2024
36adb40
FIX KWARG BUG
reubenharry Jun 2, 2024
88409bf
Merge branch 'bugfix' into adjusted_mclmc
reubenharry Jun 2, 2024
c9c514d
FIX KWARG BUG IN ADJUSTED MCLMC
reubenharry Jun 2, 2024
93f0f46
MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT
reubenharry Jun 2, 2024
527e484
Merge branch 'window' into adjusted_mclmc
reubenharry Jun 2, 2024
bb2b9e7
L_proposal_factor
reubenharry Jun 3, 2024
3df3fd8
Merge branch 'main' into adjusted_mclmc
reubenharry Jun 3, 2024
da83d56
SPLIT TUNING FOR AMCLMC INTO SEPARATE FILE
reubenharry Jun 4, 2024
31a7be6
SPLIT TUNING FOR AMCLMC INTO SEPARATE FILE
reubenharry Jun 4, 2024
ef40045
Merge branch 'main' into adjusted_mclmc
reubenharry Jun 5, 2024
90be1be
RENAME STREAMING_AVERAGE_UPDATE ARGS IN ADJUSTED MCLMC ADAPTATION
reubenharry Jun 5, 2024
c19df21
diagnostics
reubenharry Jun 5, 2024
5ca1b22
fix bugs
reubenharry Jun 10, 2024
ccd9a28
FIX MINOR TUNING BUGS
reubenharry Jun 10, 2024
1ef4c95
UPDATE TUNING
reubenharry Jun 11, 2024
7e21f43
UPDATE TUNING
reubenharry Jun 11, 2024
17f83e2
UPDATE TUNING
reubenharry Jun 11, 2024
477b11a
Merge branch 'main' into adjusted_mclmc
junpenglao Jun 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from blackjax._version import __version__

from .adaptation.adjusted_mclmc_adaptation import adjusted_mclmc_find_L_and_step_size
from .adaptation.chees_adaptation import chees_adaptation
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
from .adaptation.meads_adaptation import meads_adaptation
Expand All @@ -11,6 +12,7 @@
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc import adjusted_mclmc as _adjusted_mclmc
from .mcmc import barker
from .mcmc import dynamic_hmc as _dynamic_hmc
from .mcmc import elliptical_slice as _elliptical_slice
Expand Down Expand Up @@ -109,6 +111,7 @@ def generate_top_level_api_from(module):
additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)

mclmc = generate_top_level_api_from(_mclmc)
adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc)
elliptical_slice = generate_top_level_api_from(_elliptical_slice)
ghmc = generate_top_level_api_from(_ghmc)
barker_proposal = generate_top_level_api_from(barker)
Expand Down Expand Up @@ -158,6 +161,7 @@ def generate_top_level_api_from(module):
"chees_adaptation",
"pathfinder_adaptation",
"mclmc_find_L_and_step_size", # mclmc adaptation
"adjusted_mclmc_find_L_and_step_size", # adjusted mclmc adaptation
"ess", # diagnostics
"rhat",
]
314 changes: 314 additions & 0 deletions blackjax/adaptation/adjusted_mclmc_adaptation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState, handle_nans
from blackjax.adaptation.step_size import (
DualAveragingAdaptationState,
dual_averaging_adaptation,
)
from blackjax.diagnostics import effective_sample_size
from blackjax.util import pytree_size, streaming_average_update

Lratio_lowerbound = 0.0
Lratio_upperbound = 2.0


def adjusted_mclmc_find_L_and_step_size(
mclmc_kernel,
num_steps,
state,
rng_key,
target,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.1,
diagonal_preconditioning=True,
params=None,
):
"""
Finds the optimal value of the parameters for the MH-MCHMC algorithm.

Parameters
----------
mclmc_kernel
The kernel function used for the MCMC algorithm.
num_steps
The number of MCMC steps that will subsequently be run, after tuning.
state
The initial state of the MCMC algorithm.
rng_key
The random number generator key.
target
The target acceptance rate for the step size adaptation.
frac_tune1
The fraction of tuning for the first step of the adaptation.
frac_tune2
The fraction of tuning for the second step of the adaptation.
frac_tune3
The fraction of tuning for the third step of the adaptation.
desired_energy_va
The desired energy variance for the MCMC algorithm.
trust_in_estimate
The trust in the estimate of optimal stepsize.
num_effective_samples
The number of effective samples for the MCMC algorithm.

Returns
-------
A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
"""

dim = pytree_size(state.position)
if params is None:
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,))
)

part1_key, part2_key = jax.random.split(rng_key, 2)

(
state,
params,
params_history,
final_da_val,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
frac_tune2=frac_tune2,
target=target,
diagonal_preconditioning=diagonal_preconditioning,
)(
state, params, num_steps, part1_key
)

if frac_tune3 != 0:
part2_key1, part2_key2 = jax.random.split(part2_key, 2)

state, params = adjusted_mclmc_make_adaptation_L(
mclmc_kernel, frac=frac_tune3, Lfactor=0.4
)(state, params, num_steps, part2_key1)

(
state,
params,
params_history,
final_da_val,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
frac_tune2=0,
target=target,
fix_L_first_da=True,
diagonal_preconditioning=diagonal_preconditioning,
)(
state, params, num_steps, part2_key2
)

return state, params, params_history, final_da_val


def adjusted_mclmc_make_L_step_size_adaptation(
kernel,
dim,
frac_tune1,
frac_tune2,
target,
diagonal_preconditioning,
fix_L_first_da=False,
):
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""

def dual_avg_step(fix_L, update_da):
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""

def step(iteration_state, weight_and_key):
mask, rng_key = weight_and_key
kernel_key, num_steps_key = jax.random.split(rng_key, 2)
(
previous_state,
params,
(adaptive_state, step_size_max),
previous_weight_and_average,
) = iteration_state

avg_num_integration_steps = params.L / params.step_size

state, info = kernel(
rng_key=kernel_key,
state=previous_state,
avg_num_integration_steps=avg_num_integration_steps,
step_size=params.step_size,
sqrt_diag_cov=params.sqrt_diag_cov,
)

# step updating
success, state, step_size_max, energy_change = handle_nans(
previous_state,
state,
params.step_size,
step_size_max,
info.energy,
)

log_step_size, log_step_size_avg, step, avg_error, mu = update_da(
adaptive_state, info.acceptance_rate
)

adaptive_state = DualAveragingAdaptationState(
mask * log_step_size + (1 - mask) * adaptive_state.log_step_size,
mask * log_step_size_avg
+ (1 - mask) * adaptive_state.log_step_size_avg,
mask * step + (1 - mask) * adaptive_state.step,
mask * avg_error + (1 - mask) * adaptive_state.avg_error,
mask * mu + (1 - mask) * adaptive_state.mu,
)

step_size = jax.lax.clamp(
1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1
)
adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size))
# step_size = 1e-3

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
previous_weight_and_average = streaming_average_update(
current_value=jnp.array([x, jnp.square(x)]),
previous_weight_and_average=previous_weight_and_average,
weight=(1 - mask) * success * step_size,
zero_prevention=mask,
)

if fix_L:
params = params._replace(
step_size=mask * step_size + (1 - mask) * params.step_size,
)

else:
params = params._replace(
step_size=mask * step_size + (1 - mask) * params.step_size,
L=mask * (params.L * (step_size / params.step_size))
+ (1 - mask) * params.L,
)

return (
state,
params,
(adaptive_state, step_size_max),
previous_weight_and_average,
), (
info,
params,
)

return step

def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da):
return jax.lax.scan(
dual_avg_step(fix_L, update_da),
init=(
state,
params,
(initial_da(params.step_size), jnp.inf), # step size max
(0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
),
xs=(mask, keys),
)

def L_step_size_adaptation(state, params, num_steps, rng_key):
num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
num_steps * frac_tune2
)

rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2)
L_step_size_adaptation_keys_pass1 = jax.random.split(
rng_key_pass1, num_steps1 + num_steps2
)
L_step_size_adaptation_keys_pass2 = jax.random.split(rng_key_pass2, num_steps1)

# determine which steps to ignore in the streaming average
mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

initial_da, update_da, final_da = dual_averaging_adaptation(target=target)

(
(state, params, (dual_avg_state, step_size_max), (_, average)),
(info, params_history),
) = step_size_adaptation(
mask,
state,
params,
L_step_size_adaptation_keys_pass1,
fix_L=fix_L_first_da,
initial_da=initial_da,
update_da=update_da,
)

# determine L
if num_steps2 != 0.0:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)

change = jax.lax.clamp(
Lratio_lowerbound,
jnp.sqrt(jnp.sum(variances)) / params.L,
Lratio_upperbound,
)
params = params._replace(
L=params.L * change, step_size=params.step_size * change
)
if diagonal_preconditioning:
params = params._replace(
sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim)
)

initial_da, update_da, final_da = dual_averaging_adaptation(target=target)
(
(state, params, (dual_avg_state, step_size_max), (_, average)),
(info, params_history),
) = step_size_adaptation(
jnp.ones(num_steps1),
state,
params,
L_step_size_adaptation_keys_pass2,
fix_L=True,
update_da=update_da,
initial_da=initial_da,
)

return state, params, params_history.step_size, final_da(dual_avg_state)

return L_step_size_adaptation


def adjusted_mclmc_make_adaptation_L(kernel, frac, Lfactor):
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""

def adaptation_L(state, params, num_steps, key):
num_steps = int(num_steps * frac)
adaptation_L_keys = jax.random.split(key, num_steps)

def step(state, key):
next_state, _ = kernel(
rng_key=key,
state=state,
step_size=params.step_size,
avg_num_integration_steps=params.L / params.step_size,
sqrt_diag_cov=params.sqrt_diag_cov,
)
return next_state, next_state.position

state, samples = jax.lax.scan(
f=step,
init=state,
xs=adaptation_L_keys,
)

flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)
ess = effective_sample_size(flat_samples[None, ...])

return state, params._replace(L=Lfactor * params.L * jnp.mean(num_steps / ess))

return adaptation_L
4 changes: 2 additions & 2 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def mclmc_find_L_and_step_size(
Example
-------
.. code::
kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel(
kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=integrator,
std_mat=std_mat,
sqrt_diag_cov=sqrt_diag_cov,
)

(
Expand Down
2 changes: 2 additions & 0 deletions blackjax/mcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import (
adjusted_mclmc,
barker,
elliptical_slice,
ghmc,
Expand All @@ -24,4 +25,5 @@
"marginal_latent_gaussian",
"random_walk",
"mclmc",
"adjusted_mclmc",
]
Loading
Loading