From b60e4ca21b8f7591b59f7e5f02e03a67ffea07a2 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 15:18:55 -0400 Subject: [PATCH 01/71] TESTS --- blackjax/util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index df527ed01..5fa7ae627 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -159,8 +159,7 @@ def run_inference_algorithm( rng_key The random state used by JAX's random numbers generator. initial_state_or_position - The initial state OR the initial position of the inference algorithm. If an initial position - is passed in, the function will automatically convert it into an initial state. + The initial state of the inference algorithm. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. num_steps From 0c5aa2d5928e566440f72e2fc8787c02aeaa9768 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 15:28:09 -0400 Subject: [PATCH 02/71] TESTS --- blackjax/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/util.py b/blackjax/util.py index 5fa7ae627..59917e68a 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -167,7 +167,7 @@ def run_inference_algorithm( progress_bar Whether to display a progress bar. transform - A transformation of the trace of states to be returned. This is useful for + A transform of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. From 5eeb3e11e7492aeebb80480b3091286df9c5994e Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 15:30:07 -0400 Subject: [PATCH 03/71] UPDATE DOCSTRING --- blackjax/util.py | 3 +-- explore.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 explore.py diff --git a/blackjax/util.py b/blackjax/util.py index df527ed01..e2654481c 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -159,8 +159,7 @@ def run_inference_algorithm( rng_key The random state used by JAX's random numbers generator. initial_state_or_position - The initial state OR the initial position of the inference algorithm. If an initial position - is passed in, the function will automatically convert it into an initial state. + The initial state of the inference algorithm. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. 
num_steps diff --git a/explore.py b/explore.py new file mode 100644 index 000000000..514029420 --- /dev/null +++ b/explore.py @@ -0,0 +1,53 @@ +import jax +import jax.numpy as jnp +from benchmarks.mcmc.sampling_algorithms import samplers +import blackjax +from blackjax.mcmc.mhmclmc import mhmclmc, rescale +from blackjax.mcmc.hmc import hmc +from blackjax.mcmc.dynamic_hmc import dynamic_hmc +from blackjax.mcmc.integrators import isokinetic_mclachlan +from blackjax.util import run_inference_algorithm + + + + + +init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3) + +def logdensity_fn(x): + return -0.5 * jnp.sum(jnp.square(x)) + +initial_position = jnp.ones(10,) + + +def run_mclmc(logdensity_fn, num_steps, initial_position): + key = jax.random.PRNGKey(0) + init_key, tune_key, run_key = jax.random.split(key, 3) + + + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + ) + + kernel = blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + ) + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + ) + + print(blackjax_mclmc_sampler_params) + +# out = run_hmc(initial_position) +out = samplers["mhmclmc"](logdensity_fn=logdensity_fn, num_steps=5000, initial_position=initial_position, key=jax.random.PRNGKey(0)) +print(out.mean(axis=0) ) + + From 4a0915673663302ceffbd33400478840576dee4b Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 16:02:30 -0400 Subject: [PATCH 04/71] ADD STREAMING VERSION --- blackjax/util.py | 55 +++++++++++++++++++++++++++++++++------------- tests/test_util.py | 40 +++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 15 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index e2654481c..55b8b3e47 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -8,7 +8,7 @@ from jax.random import normal, split from jax.tree_util import tree_leaves -from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm +from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import progress_bar_scan from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -142,12 +142,13 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree: def run_inference_algorithm( rng_key: PRNGKey, - initial_state_or_position: ArrayLikeTree, + initial_state: ArrayLikeTree, inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm], num_steps: int, progress_bar: bool = False, transform: Callable = lambda x: x, -) -> tuple[State, State, Info]: + streaming=False, +) -> tuple: """Wrapper to run an inference algorithm. Note that this utility function does not work for Stochastic Gradient MCMC samplers @@ -158,8 +159,8 @@ def run_inference_algorithm( ---------- rng_key The random state used by JAX's random numbers generator. - initial_state_or_position - The initial state of the inference algorithm. + initial_state + The initial state of the inference algorithm. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. num_steps @@ -170,6 +171,8 @@ def run_inference_algorithm( A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. 
+ streaming + if True, `run_inference_algorithm` will take a streaming average of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. Returns ------- @@ -178,14 +181,8 @@ def run_inference_algorithm( 2. The trace of states of the inference algorithm (contains the MCMC samples). 3. The trace of the info of the inference algorithm for diagnostics. """ - init_key, sample_key = split(rng_key, 2) - try: - initial_state = inference_algorithm.init(initial_state_or_position, init_key) - except (TypeError, ValueError, AttributeError): - # We assume initial_state is already in the right format. - initial_state = initial_state_or_position - keys = split(sample_key, num_steps) + keys = split(rng_key, num_steps) @jit def _one_step(state, xs): @@ -193,11 +190,39 @@ def _one_step(state, xs): state, info = inference_algorithm.step(rng_key, state) return state, (transform(state), info) + def _online_one_step(average_and_state, xs): + _, rng_key = xs + average, state = average_and_state + state, _ = inference_algorithm.step(rng_key, state) + average = streaming_average(transform, state, average) + return (average, state), None + if progress_bar: one_step = progress_bar_scan(num_steps)(_one_step) + online_one_step = progress_bar_scan(num_steps)(_online_one_step) else: one_step = _one_step + online_one_step = _online_one_step - xs = (jnp.arange(num_steps), keys) - final_state, (state_history, info_history) = lax.scan(one_step, initial_state, xs) - return final_state, state_history, info_history + if streaming: + xs = (jnp.arange(num_steps), keys) + (average, final_state), _ = lax.scan( + online_one_step, ((0, transform(initial_state)), initial_state), xs + ) + return average, transform(final_state) + + else: + xs = (jnp.arange(num_steps), keys) + final_state, (state_history, info_history) = lax.scan( + one_step, initial_state, xs + ) + return final_state, state_history, info_history + + +def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): + """streaming average of f(x)""" + total, average = streaming_avg + average = (total * average + weight * O(x)) / (total + weight + zero_prevention) + total += weight + streaming_avg = (total, average) + return streaming_avg diff --git a/tests/test_util.py b/tests/test_util.py index a6e023074..ed5cb12f0 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -33,6 +33,46 @@ def check_compatible(self, initial_state_or_position, progress_bar): transform=lambda x: x.position, ) + def test_streamning(self): + def logdensity_fn(x): + return -0.5 * jnp.sum(jnp.square(x)) + + initial_position = jnp.ones( + 10, + ) + + init_key, run_key = jax.random.split(self.key, 2) + + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + ) + + alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) + + average, states = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=50, + progress_bar=True, + transform=lambda x: x.position, + streaming=True, + ) + + print(average) + + _, states, _ = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=50, + progress_bar=False, + transform=lambda x: x.position, + streaming=False, + ) + + assert jnp.array_equal(states.mean(axis=0), average) + @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): 
self.check_compatible(jnp.array([1.0, 1.0]), progress_bar) From dfb5ee0146d5f4f50ba7f71b164320002eced394 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 16:25:48 -0400 Subject: [PATCH 05/71] ADD PRECONDITIONING TO MCLMC --- blackjax/mcmc/integrators.py | 129 ++++++++++++++++++++++++----------- blackjax/mcmc/mclmc.py | 44 +++--------- explore.py | 62 +++++++++-------- 3 files changed, 132 insertions(+), 103 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index f4009b16e..ea68a1c7c 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree +from jax.random import normal from blackjax.mcmc.metrics import KineticEnergy from blackjax.types import ArrayTree @@ -293,43 +294,48 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step( - momentum: ArrayTree, - logdensity_grad: ArrayTree, - step_size: float, - coef: float, - previous_kinetic_energy_change=None, - is_last_call=False, -): - """Momentum update based on Esh dynamics. - - The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian` - There are no exponentials e^delta, which prevents overflows when the gradient norm - is large. - """ - del is_last_call - - flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_momentum, _ = ravel_pytree(momentum) - dims = flatten_momentum.shape[0] - normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) - momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) - delta = step_size * coef * gradient_norm / (dims - 1) - zeta = jnp.exp(-delta) - new_momentum_raw = ( - normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) - + 2 * zeta * flatten_momentum - ) - new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - next_momentum = unravel_fn(new_momentum_normalized) - kinetic_energy_change = ( - delta - - jnp.log(2) - + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2) - ) * (dims - 1) - if previous_kinetic_energy_change is not None: - kinetic_energy_change += previous_kinetic_energy_change - return next_momentum, next_momentum, kinetic_energy_change +def esh_dynamics_momentum_update_one_step(std_mat): + def update( + momentum: ArrayTree, + logdensity_grad: ArrayTree, + step_size: float, + coef: float, + previous_kinetic_energy_change=None, + is_last_call=False, + ): + """Momentum update based on Esh dynamics. + + The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian` + There are no exponentials e^delta, which prevents overflows when the gradient norm + is large. 
+ """ + del is_last_call + + logdensity_grad = logdensity_grad * std_mat + flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) + flatten_momentum, _ = ravel_pytree(momentum) + dims = flatten_momentum.shape[0] + normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) + momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) + delta = step_size * coef * gradient_norm / (dims - 1) + zeta = jnp.exp(-delta) + new_momentum_raw = ( + normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) + + 2 * zeta * flatten_momentum + ) + new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) + next_momentum = unravel_fn(new_momentum_normalized) + kinetic_energy_change = ( + delta + - jnp.log(2) + + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2) + ) * (dims - 1) + if previous_kinetic_energy_change is not None: + kinetic_energy_change += previous_kinetic_energy_change + gr = std_mat * next_momentum + return next_momentum, gr, kinetic_energy_change + + return update def format_isokinetic_state_output( @@ -348,15 +354,15 @@ def format_isokinetic_state_output( ) -def generate_isokinetic_integrator(cofficients): +def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, *args, **kwargs + logdensity_fn: Callable, std_mat : ArrayTree, *args, **kwargs ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step, + esh_dynamics_momentum_update_one_step(std_mat), position_update_fn, - cofficients, + coefficients, format_output_fn=format_isokinetic_state_output, ) return one_step @@ -368,6 +374,47 @@ def isokinetic_integrator( isokinetic_yoshida = generate_isokinetic_integrator(yoshida_cofficients) isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_cofficients) +def partially_refresh_momentum(momentum, rng_key, step_size, L): + """Adds a small noise to momentum and normalizes. + + Parameters + ---------- + rng_key + The pseudo-random number generator key used to generate random numbers. + momentum + PyTree that the structure the output should to match. 
+ step_size + Step size + L + controls rate of momentum change + + Returns + ------- + momentum with random change in angle + """ + m, unravel_fn = ravel_pytree(momentum) + dim = m.shape[0] + nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) + z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) + return unravel_fn((m + z) / jnp.linalg.norm(m + z)) + + + +def with_isokinetic_maruyama(integrator): + + def stochastic_integrator(init_state, step_size, L_proposal, rng_key): + + key1, key2 = jax.random.split(rng_key) + # partial refreshment + state = init_state._replace(momentum=partially_refresh_momentum(momentum=init_state.momentum, rng_key=key1, L=L_proposal, step_size=step_size * 0.5)) + # one step of the deterministic dynamics + state, info = integrator(state, step_size) + # partial refreshment + state = state._replace(momentum=partially_refresh_momentum(momentum=state.momentum, rng_key=key2, L=L_proposal, step_size=step_size * 0.5)) + return state, info + + return stochastic_integrator + FixedPointSolver = Callable[ [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree], Tuple[ArrayTree, ArrayTree, Any], diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 406c4125d..d91fb01bf 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -15,12 +15,10 @@ from typing import Callable, NamedTuple import jax -import jax.numpy as jnp -from jax.flatten_util import ravel_pytree -from jax.random import normal + from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan +from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan, with_isokinetic_maruyama from blackjax.types import ArrayLike, PRNGKey from blackjax.util import generate_unit_vector, pytree_size @@ -58,8 +56,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): logdensity_grad=g, ) - -def build_kernel(logdensity_fn, integrator): +def build_kernel(logdensity_fn, std_mat, integrator): """Build a HMC kernel. Parameters @@ -78,19 +75,17 @@ def build_kernel(logdensity_fn, integrator): information about the transition. """ - step = integrator(logdensity_fn) + + print(std_mat, "foo") + step = with_isokinetic_maruyama(integrator(logdensity_fn, std_mat)) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, MCLMCInfo]: (position, momentum, logdensity, logdensitygrad), kinetic_change = step( - state, step_size + state, step_size, L, rng_key ) - # Langevin-like noise - momentum = partially_refresh_momentum( - momentum=momentum, rng_key=rng_key, L=L, step_size=step_size - ) return IntegratorState( position, momentum, logdensity, logdensitygrad @@ -108,6 +103,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, + std_mat=1., ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,7 +151,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(logdensity_fn, integrator) + kernel = build_kernel(logdensity_fn, std_mat, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) @@ -166,26 +162,4 @@ def update_fn(rng_key, state): return SamplingAlgorithm(init_fn, update_fn) -def partially_refresh_momentum(momentum, rng_key, step_size, L): - """Adds a small noise to momentum and normalizes. 
- - Parameters - ---------- - rng_key - The pseudo-random number generator key used to generate random numbers. - momentum - PyTree that the structure the output should to match. - step_size - Step size - L - controls rate of momentum change - Returns - ------- - momentum with random change in angle - """ - m, unravel_fn = ravel_pytree(momentum) - dim = m.shape[0] - nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) - z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) - return unravel_fn((m + z) / jnp.linalg.norm(m + z)) diff --git a/explore.py b/explore.py index 514029420..e97458051 100644 --- a/explore.py +++ b/explore.py @@ -1,53 +1,61 @@ import jax import jax.numpy as jnp -from benchmarks.mcmc.sampling_algorithms import samplers + import blackjax -from blackjax.mcmc.mhmclmc import mhmclmc, rescale -from blackjax.mcmc.hmc import hmc -from blackjax.mcmc.dynamic_hmc import dynamic_hmc -from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm - - - - init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3) + def logdensity_fn(x): return -0.5 * jnp.sum(jnp.square(x)) -initial_position = jnp.ones(10,) +initial_position = jnp.ones( + 10, +) -def run_mclmc(logdensity_fn, num_steps, initial_position): - key = jax.random.PRNGKey(0) - init_key, tune_key, run_key = jax.random.split(key, 3) +def run_mclmc(logdensity_fn, key, num_steps, initial_position): + init_key, tune_key, run_key = jax.random.split(key, 3) initial_state = blackjax.mcmc.mclmc.init( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1, std_mat=1.) 
+ + average, states = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=num_steps, + progress_bar=True, + transform=lambda x: x.position, + streaming=True, ) - ( - blackjax_state_after_tuning, - blackjax_mclmc_sampler_params, - ) = blackjax.mclmc_find_L_and_step_size( - mclmc_kernel=kernel, + print(average) + + _, states, _ = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, num_steps=num_steps, - state=initial_state, - rng_key=tune_key, + progress_bar=False, + transform=lambda x: x.position, + streaming=False, ) - print(blackjax_mclmc_sampler_params) + print(states.mean(axis=0)) -# out = run_hmc(initial_position) -out = samplers["mhmclmc"](logdensity_fn=logdensity_fn, num_steps=5000, initial_position=initial_position, key=jax.random.PRNGKey(0)) -print(out.mean(axis=0) ) + return states +# out = run_hmc(initial_position) +out = run_mclmc( + logdensity_fn=logdensity_fn, + num_steps=5, + initial_position=initial_position, + key=jax.random.PRNGKey(0), +) From 2ab3365b9bef24ad971ed40f2da2b5bc8f139225 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:20:30 -0400 Subject: [PATCH 06/71] ADD PRECONDITIONING TO TUNING FOR MCLMC --- blackjax/adaptation/mclmc_adaptation.py | 149 ++++++++++++------------ blackjax/mcmc/integrators.py | 34 ++++-- blackjax/mcmc/mclmc.py | 15 ++- blackjax/util.py | 18 ++- tests/mcmc/test_sampling.py | 9 +- 5 files changed, 126 insertions(+), 99 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 4fc322e27..dc33eb21c 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -20,7 +20,7 @@ from jax.flatten_util import ravel_pytree from blackjax.diagnostics import effective_sample_size -from blackjax.util import pytree_size +from blackjax.util import pytree_size, streaming_average class MCLMCAdaptationState(NamedTuple): @@ -30,10 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. + std_mat + A matrix used for preconditioning. """ L: float step_size: float + std_mat: float def mclmc_find_L_and_step_size( @@ -47,6 +50,7 @@ def mclmc_find_L_and_step_size( desired_energy_var=5e-4, trust_in_estimate=1.5, num_effective_samples=150, + diagonal_preconditioning=True, ): """ Finds the optimal value of the parameters for the MCLMC algorithm. @@ -77,39 +81,11 @@ def mclmc_find_L_and_step_size( Returns ------- A tuple containing the final state of the MCMC algorithm and the final hyperparameters. - - - Examples - ------- - - .. 
code:: - - # Define the kernel function - def kernel(x): - return x ** 2 - - # Define the initial state - initial_state = MCMCState(position=0, momentum=1) - - # Generate a random number generator key - rng_key = jax.random.key(0) - - # Find the optimal parameters for the MCLMC algorithm - final_state, final_params = mclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=1000, - state=initial_state, - rng_key=rng_key, - frac_tune1=0.2, - frac_tune2=0.3, - frac_tune3=0.1, - desired_energy_var=1e-4, - trust_in_estimate=2.0, - num_effective_samples=200, - ) """ dim = pytree_size(state.position) - params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25) + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, std_mat=jnp.ones((dim,)) + ) part1_key, part2_key = jax.random.split(rng_key, 2) state, params = make_L_step_size_adaptation( @@ -120,12 +96,13 @@ def kernel(x): desired_energy_var=desired_energy_var, trust_in_estimate=trust_in_estimate, num_effective_samples=num_effective_samples, + diagonal_preconditioning=diagonal_preconditioning, )(state, params, num_steps, part1_key) if frac_tune3 != 0: - state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)( - state, params, num_steps, part2_key - ) + state, params = make_adaptation_L( + mclmc_kernel(params.std_mat), frac=frac_tune3, Lfactor=0.4 + )(state, params, num_steps, part2_key) return state, params @@ -135,6 +112,7 @@ def make_L_step_size_adaptation( dim, frac_tune1, frac_tune2, + diagonal_preconditioning, desired_energy_var=1e-3, trust_in_estimate=1.5, num_effective_samples=150, @@ -150,7 +128,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): time, x_average, step_size_max = adaptive_state # dynamics - next_state, info = kernel( + next_state, info = kernel(params.std_mat)( rng_key=rng_key, state=previous_state, L=params.L, @@ -185,68 +163,87 @@ def predictor(previous_state, params, adaptive_state, rng_key): ) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences params_new = params._replace(step_size=step_size) - return state, params_new, params_new, (time, x_average, step_size_max), success - - def update_kalman(x, state, outer_weight, success, step_size): - """kalman filter to estimate the size of the posterior""" - time, x_average, x_squared_average = state - weight = outer_weight * step_size * success - zero_prevention = 1 - outer_weight - x_average = (time * x_average + weight * x) / ( - time + weight + zero_prevention - ) # Update with a Kalman filter - x_squared_average = (time * x_squared_average + weight * jnp.square(x)) / ( - time + weight + zero_prevention - ) # Update with a Kalman filter - time += weight - return (time, x_average, x_squared_average) + adaptive_state = (time, x_average, step_size_max) - adap0 = (0.0, 0.0, jnp.inf) + return state, params_new, adaptive_state, success def step(iteration_state, weight_and_key): """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" - outer_weight, rng_key = weight_and_key - state, params, adaptive_state, kalman_state = iteration_state - state, params, params_final, adaptive_state, success = predictor( + mask, rng_key = weight_and_key + state, params, adaptive_state, streaming_avg = iteration_state + + state, params, adaptive_state, success = predictor( state, params, adaptive_state, rng_key ) - position, _ = ravel_pytree(state.position) - kalman_state = update_kalman( - position, kalman_state, outer_weight, success, 
params.step_size + + # update the running average of x, x^2 + streaming_avg = streaming_average( + O=lambda x: jnp.array([x, jnp.square(x)]), + x=ravel_pytree(state.position)[0], + streaming_avg=streaming_avg, + weight=(1 - mask) * success * params.step_size, + zero_prevention=mask, ) - return (state, params_final, adaptive_state, kalman_state), None + return (state, params, adaptive_state, streaming_avg), None def L_step_size_adaptation(state, params, num_steps, rng_key): - num_steps1, num_steps2 = int(num_steps * frac_tune1), int( - num_steps * frac_tune2 + num_steps1, num_steps2 = ( + int(num_steps * frac_tune1) + 1, + int(num_steps * frac_tune2) + 1, + ) + L_step_size_adaptation_keys = jax.random.split( + rng_key, num_steps1 + num_steps2 + 1 + ) + L_step_size_adaptation_keys, final_key = ( + L_step_size_adaptation_keys[:-1], + L_step_size_adaptation_keys[-1], ) - L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2) # we use the last num_steps2 to compute the diagonal preconditioner - outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - - # initial state of the kalman filter - kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim)) + mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) # run the steps - kalman_state, *_ = jax.lax.scan( + state, params, _, (_, average) = jax.lax.scan( step, - init=(state, params, adap0, kalman_state), - xs=(outer_weights, L_step_size_adaptation_keys), - length=num_steps1 + num_steps2, - ) - state, params, _, kalman_state_output = kalman_state + init=( + state, + params, + (0.0, 0.0, jnp.inf), + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=(mask, L_step_size_adaptation_keys), + )[0] L = params.L # determine L + std_mat = params.std_mat if num_steps2 != 0.0: - _, F1, F2 = kalman_state_output - variances = F2 - jnp.square(F1) + x_average, x_squared_average = average[0], average[1] + variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) - return state, MCLMCAdaptationState(L, params.step_size) + if diagonal_preconditioning: + std_mat = jnp.sqrt(variances) + params = params._replace(std_mat=std_mat) + L = jnp.sqrt(dim) + + # readjust the stepsize + steps = num_steps2 // 3 # we do some small number of steps + keys = jax.random.split(final_key, steps) + state, params, _, (_, average) = jax.lax.scan( + step, + init=( + state, + params, + (0.0, 0.0, jnp.inf), + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=(jnp.ones(steps), keys), + )[0] + + return state, MCLMCAdaptationState(L, params.step_size, std_mat) return L_step_size_adaptation @@ -258,7 +255,6 @@ def adaptation_L(state, params, num_steps, key): num_steps = int(num_steps * frac) adaptation_L_keys = jax.random.split(key, num_steps) - # run kernel in the normal way def step(state, key): next_state, _ = kernel( rng_key=key, @@ -297,5 +293,4 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch (next_state, step_size_max, kinetic_change), (previous_state, step_size * reduced_step_size, 0.0), ) - return nonans, state, step_size, kinetic_change diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index ea68a1c7c..28c0aa8c3 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -311,8 +311,9 @@ def update( """ del is_last_call - logdensity_grad = logdensity_grad * std_mat + logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) + flatten_grads = flatten_grads * 
std_mat flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -324,6 +325,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) + gr = unravel_fn(new_momentum_normalized*std_mat) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -332,9 +334,8 @@ def update( ) * (dims - 1) if previous_kinetic_energy_change is not None: kinetic_energy_change += previous_kinetic_energy_change - gr = std_mat * next_momentum return next_momentum, gr, kinetic_energy_change - + return update @@ -356,7 +357,7 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, std_mat : ArrayTree, *args, **kwargs + logdensity_fn: Callable, std_mat: ArrayTree, *args, **kwargs ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( @@ -374,6 +375,7 @@ def isokinetic_integrator( isokinetic_yoshida = generate_isokinetic_integrator(yoshida_cofficients) isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_cofficients) + def partially_refresh_momentum(momentum, rng_key, step_size, L): """Adds a small noise to momentum and normalizes. @@ -399,22 +401,34 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): return unravel_fn((m + z) / jnp.linalg.norm(m + z)) - def with_isokinetic_maruyama(integrator): - def stochastic_integrator(init_state, step_size, L_proposal, rng_key): - key1, key2 = jax.random.split(rng_key) # partial refreshment - state = init_state._replace(momentum=partially_refresh_momentum(momentum=init_state.momentum, rng_key=key1, L=L_proposal, step_size=step_size * 0.5)) + state = init_state._replace( + momentum=partially_refresh_momentum( + momentum=init_state.momentum, + rng_key=key1, + L=L_proposal, + step_size=step_size * 0.5, + ) + ) # one step of the deterministic dynamics state, info = integrator(state, step_size) # partial refreshment - state = state._replace(momentum=partially_refresh_momentum(momentum=state.momentum, rng_key=key2, L=L_proposal, step_size=step_size * 0.5)) + state = state._replace( + momentum=partially_refresh_momentum( + momentum=state.momentum, + rng_key=key2, + L=L_proposal, + step_size=step_size * 0.5, + ) + ) return state, info - + return stochastic_integrator + FixedPointSolver = Callable[ [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree], Tuple[ArrayTree, ArrayTree, Any], diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index d91fb01bf..62a6da735 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -16,9 +16,12 @@ import jax - from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan, with_isokinetic_maruyama +from blackjax.mcmc.integrators import ( + IntegratorState, + isokinetic_mclachlan, + with_isokinetic_maruyama, +) from blackjax.types import ArrayLike, PRNGKey from blackjax.util import generate_unit_vector, pytree_size @@ -56,6 +59,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): logdensity_grad=g, ) + def build_kernel(logdensity_fn, std_mat, integrator): """Build a HMC kernel. 
@@ -76,7 +80,6 @@ def build_kernel(logdensity_fn, std_mat, integrator): """ - print(std_mat, "foo") step = with_isokinetic_maruyama(integrator(logdensity_fn, std_mat)) def kernel( @@ -86,7 +89,6 @@ def kernel( state, step_size, L, rng_key ) - return IntegratorState( position, momentum, logdensity, logdensitygrad ), MCLMCInfo( @@ -103,7 +105,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - std_mat=1., + std_mat=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -160,6 +162,3 @@ def update_fn(rng_key, state): return kernel(rng_key, state, L, step_size) return SamplingAlgorithm(init_fn, update_fn) - - - diff --git a/blackjax/util.py b/blackjax/util.py index 55b8b3e47..31b82ccf9 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -220,7 +220,23 @@ def _online_one_step(average_and_state, xs): def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): - """streaming average of f(x)""" + """Compute the streaming average of a function O(x) using a weight. + Parameters: + ---------- + O + function to be averaged + x + current state + streaming_avg + tuple of (total, average) where total is the sum of weights and average is the current average + weight + weight of the current state + zero_prevention + small value to prevent division by zero + Returns: + ---------- + new streaming average + """ total, average = streaming_avg average = (total * average + weight * O(x)) / (total + weight + zero_prevention) total += weight diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 39c1b811b..f4ec3651f 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -74,16 +74,17 @@ def regression_logprob(self, log_scale, coefs, preds, x): # reduce sum otherwise broacasting will make the logprob biased. 
return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf]) - def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): + def run_mclmc(self, logdensity_fn, num_steps, initial_position, key, diagonal_preconditioning=False): init_key, tune_key, run_key = jax.random.split(key, 3) initial_state = blackjax.mcmc.mclmc.init( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = blackjax.mcmc.mclmc.build_kernel( + kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, + std_mat=std_mat, ) ( @@ -94,12 +95,14 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): num_steps=num_steps, state=initial_state, rng_key=tune_key, + diagonal_preconditioning=diagonal_preconditioning, ) sampling_alg = blackjax.mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, + std_mat=blackjax_mclmc_sampler_params.std_mat, ) _, samples, _ = run_inference_algorithm( From 4cc39713cd46e3085cc4a06028a47b5463cbd866 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:21:31 -0400 Subject: [PATCH 07/71] UPDATE GITIGNORE --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 25b11a123..4e9be6ab1 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ *.py[cod] *$py.class +explore.py # C extensions *.so From f987da34ca137885bcdcd47fa8c42f363d54ee06 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:21:56 -0400 Subject: [PATCH 08/71] UPDATE GITIGNORE --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 4e9be6ab1..25b11a123 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,6 @@ __pycache__/ *.py[cod] *$py.class -explore.py # C extensions *.so From dbab9a3077165234c87beea50018e0d3f33befe7 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:27:26 -0400 Subject: [PATCH 09/71] UPDATE TESTS --- explore.py | 53 ------------------------------------- tests/mcmc/test_sampling.py | 2 +- tests/test_util.py | 6 ++--- 3 files changed, 4 insertions(+), 57 deletions(-) delete mode 100644 explore.py diff --git a/explore.py b/explore.py deleted file mode 100644 index 514029420..000000000 --- a/explore.py +++ /dev/null @@ -1,53 +0,0 @@ -import jax -import jax.numpy as jnp -from benchmarks.mcmc.sampling_algorithms import samplers -import blackjax -from blackjax.mcmc.mhmclmc import mhmclmc, rescale -from blackjax.mcmc.hmc import hmc -from blackjax.mcmc.dynamic_hmc import dynamic_hmc -from blackjax.mcmc.integrators import isokinetic_mclachlan -from blackjax.util import run_inference_algorithm - - - - - -init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3) - -def logdensity_fn(x): - return -0.5 * jnp.sum(jnp.square(x)) - -initial_position = jnp.ones(10,) - - -def run_mclmc(logdensity_fn, num_steps, initial_position): - key = jax.random.PRNGKey(0) - init_key, tune_key, run_key = jax.random.split(key, 3) - - - initial_state = blackjax.mcmc.mclmc.init( - position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key - ) - - kernel = blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, - ) - - ( - blackjax_state_after_tuning, - blackjax_mclmc_sampler_params, - ) = blackjax.mclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - state=initial_state, - 
rng_key=tune_key, - ) - - print(blackjax_mclmc_sampler_params) - -# out = run_hmc(initial_position) -out = samplers["mhmclmc"](logdensity_fn=logdensity_fn, num_steps=5000, initial_position=initial_position, key=jax.random.PRNGKey(0)) -print(out.mean(axis=0) ) - - diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 39c1b811b..19f72a7c2 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -104,7 +104,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): _, samples, _ = run_inference_algorithm( rng_key=run_key, - initial_state_or_position=blackjax_state_after_tuning, + initial_state=blackjax_state_after_tuning, inference_algorithm=sampling_alg, num_steps=num_steps, transform=lambda x: x.position, diff --git a/tests/test_util.py b/tests/test_util.py index ed5cb12f0..97aba5205 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -19,14 +19,14 @@ def setUp(self): ) self.num_steps = 10 - def check_compatible(self, initial_state_or_position, progress_bar): + def check_compatible(self, initial_state, progress_bar): """ Runs 10 steps with `run_inference_algorithm` starting with - `initial_state_or_position` and potentially a progress bar. + `initial_state` and potentially a progress bar. """ _ = run_inference_algorithm( self.key, - initial_state_or_position, + initial_state, self.algorithm, self.num_steps, progress_bar, From 098f5ad7d7f98f276e5814d36bcf28e53896ebc5 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:35:06 -0400 Subject: [PATCH 10/71] UPDATE TESTS --- blackjax/mcmc/integrators.py | 2 +- tests/mcmc/test_sampling.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 28c0aa8c3..0f4deeca4 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -325,7 +325,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized*std_mat) + gr = unravel_fn(new_momentum_normalized * std_mat) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 6f136d320..447adeecd 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -74,7 +74,14 @@ def regression_logprob(self, log_scale, coefs, preds, x): # reduce sum otherwise broacasting will make the logprob biased. return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf]) - def run_mclmc(self, logdensity_fn, num_steps, initial_position, key, diagonal_preconditioning=False): + def run_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): init_key, tune_key, run_key = jax.random.split(key, 3) initial_state = blackjax.mcmc.mclmc.init( From 5bd2a3f4c12aab6ba333ac2baa936e7d3df64ee0 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:41:26 -0400 Subject: [PATCH 11/71] ADD DOCSTRING --- blackjax/util.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 55b8b3e47..a9ed821f4 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -220,9 +220,25 @@ def _online_one_step(average_and_state, xs): def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): - """streaming average of f(x)""" + """Compute the streaming average of a function O(x) using a weight. 
+ Parameters: + ---------- + O + function to be averaged + x + current state + streaming_avg + tuple of (total, average) where total is the sum of weights and average is the current average + weight + weight of the current state + zero_prevention + small value to prevent division by zero + Returns: + ---------- + new streaming average + """ total, average = streaming_avg average = (total * average + weight * O(x)) / (total + weight + zero_prevention) total += weight streaming_avg = (total, average) - return streaming_avg + return streaming_avg \ No newline at end of file From 4fc1453b4b4430b5350bafc6b5fbb9b6fc7721e7 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:56:18 -0400 Subject: [PATCH 12/71] ADD TEST --- blackjax/util.py | 4 ++-- tests/test_util.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index a9ed821f4..2efb93f12 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -206,7 +206,7 @@ def _online_one_step(average_and_state, xs): if streaming: xs = (jnp.arange(num_steps), keys) - (average, final_state), _ = lax.scan( + ((_, average), final_state), _ = lax.scan( online_one_step, ((0, transform(initial_state)), initial_state), xs ) return average, transform(final_state) @@ -241,4 +241,4 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): average = (total * average + weight * O(x)) / (total + weight + zero_prevention) total += weight streaming_avg = (total, average) - return streaming_avg \ No newline at end of file + return streaming_avg diff --git a/tests/test_util.py b/tests/test_util.py index 97aba5205..1291b09e7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -33,7 +33,7 @@ def check_compatible(self, initial_state, progress_bar): transform=lambda x: x.position, ) - def test_streamning(self): + def test_streaming(self): def logdensity_fn(x): return -0.5 * jnp.sum(jnp.square(x)) @@ -54,13 +54,11 @@ def logdensity_fn(x): initial_state=initial_state, inference_algorithm=alg, num_steps=50, - progress_bar=True, + progress_bar=False, transform=lambda x: x.position, streaming=True, ) - print(average) - _, states, _ = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -71,7 +69,7 @@ def logdensity_fn(x): streaming=False, ) - assert jnp.array_equal(states.mean(axis=0), average) + assert jnp.allclose(states.mean(axis=0), average) @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): From 203f1fd4dfbd754b8e4ae49c36a26348be41c025 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 18:37:04 +0200 Subject: [PATCH 13/71] STREAMING AVERAGE --- blackjax/benchmarks/mcmc/benchmark.py | 557 ++++++++++++++++++++++++++ blackjax/util.py | 69 +++- tests/mcmc/test_sampling.py | 2 +- tests/test_util.py | 6 +- 4 files changed, 616 insertions(+), 18 deletions(-) create mode 100644 blackjax/benchmarks/mcmc/benchmark.py diff --git a/blackjax/benchmarks/mcmc/benchmark.py b/blackjax/benchmarks/mcmc/benchmark.py new file mode 100644 index 000000000..549a55364 --- /dev/null +++ b/blackjax/benchmarks/mcmc/benchmark.py @@ -0,0 +1,557 @@ +from collections import defaultdict +from functools import partial +import math +import operator +import os +import pprint +from statistics import mean, median +import jax +import jax.numpy as jnp +import pandas as pd +import scipy + +from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState, integrator_order, target_acceptance_rate_of_order + +os.environ["XLA_FLAGS"] = 
'--xla_force_host_platform_device_count=' + str(128) +num_cores = jax.local_device_count() +# print(num_cores, jax.lib.xla_bridge.get_backend().platform) + +import itertools + +import numpy as np + +import blackjax +from blackjax.benchmarks.mcmc.sampling_algorithms import run_mclmc, run_mhmclmc, run_nuts, samplers +from blackjax.benchmarks.mcmc.inference_models import Brownian, GermanCredit, ItemResponseTheory, MixedLogit, StandardNormal, StochasticVolatility, models +from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator, isokinetic_mclachlan, mclachlan_coefficients, name_integrator, omelyan_coefficients, velocity_verlet, velocity_verlet_coefficients, yoshida_coefficients +from blackjax.mcmc.mhmclmc import rescale +from blackjax.util import run_inference_algorithm + + + +def get_num_latents(target): + return target.ndims +# return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0])))) + + +def err(f_true, var_f, contract): + """Computes the error b^2 = (f - f_true)^2 / var_f + Args: + f: E_sampler[f(x)], can be a vector + f_true: E_true[f(x)] + var_f: Var_true[f(x)] + contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max + + Returns: + contract(b^2) + """ + + return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) + + + +def grads_to_low_error(err_t, grad_evals_per_step= 1, low_error= 0.01): + """Uses the error of the expectation values to compute the effective sample size neff + b^2 = 1/neff""" + + cutoff_reached = err_t[-1] < low_error + return find_crossing(err_t, low_error) * grad_evals_per_step, cutoff_reached + + +def calculate_ess(err_t, grad_evals_per_step, neff= 100): + + grads_to_low, cutoff_reached = grads_to_low_error(err_t, grad_evals_per_step, 1./neff) + + return (neff / grads_to_low) * cutoff_reached, grads_to_low*(1/cutoff_reached), cutoff_reached + + +def find_crossing(array, cutoff): + """the smallest M such that array[m] < cutoff for all m > M""" + + b = array > cutoff + indices = jnp.argwhere(b) + if indices.shape[0] == 0: + print("\n\n\nNO CROSSING FOUND!!!\n\n\n", array, cutoff) + return 1 + + return jnp.max(indices)+1 + + +def cumulative_avg(samples): + return jnp.cumsum(samples, axis = 0) / jnp.arange(1, samples.shape[0] + 1)[:, None] + + +def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps, center_L, center_step_size, contract): + results = defaultdict(float) + converged = False + keys = jax.random.split(key, iterations+1) + for i in range(iterations): + print(f"EPOCH {i}") + width = 2 + step_sizes = np.logspace(np.log10(center_step_size/width), np.log10(center_step_size*width), grid_size) + Ls = np.logspace(np.log10(center_L/2), np.log10(center_L*2),grid_size) + # print(list(itertools.product(step_sizes , Ls))) + + grid_keys = jax.random.split(keys[i], grid_size^2) + print(f"center step size {center_step_size}, center L {center_L}") + for j, (step_size, L) in enumerate(itertools.product(step_sizes , Ls)): + ess, grad_calls_until_convergence, _ , _, _ = benchmark_chains(model, sampler(step_size=step_size, L=L), grid_keys[j], n=num_steps, batch = batch, contract=contract) + results[(step_size, L)] = (ess, grad_calls_until_convergence) + + best_ess, best_grads, (step_size, L) = max([(results[r][0], results[r][1], r) for r in results], key=operator.itemgetter(0)) + # raise Exception + print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence 
{best_grads}") + if L==center_L and step_size==center_step_size: + print("converged") + converged = True + break + else: + center_L, center_step_size = L, step_size + + pprint.pp(results) + # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") + # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") + return center_L, center_step_size, converged + + +def run_mhmclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): + + def s(logdensity_fn, num_steps, initial_position, transform, key): + + integrator = generate_isokinetic_integrator(coefficients) + + num_steps_per_traj = L/step_size + alg = blackjax.mcmc.mhmclmc.mhmclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(num_steps_per_traj)) , + integrator=integrator, + std_mat=std_mat, + ) + + _, out, info = run_inference_algorithm( + rng_key=key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True) + + return out, MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), num_steps_per_traj * calls_per_integrator_step(coefficients), info.acceptance_rate.mean(), None, jnp.array([0]) + + return s + +def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.average,): + + pvmap = jax.pmap + + # def pvmap(f): + # def f(arr): + # return arr + # print(arr.shape,"shape") + # print(arr) + # arr = arr.reshape(128, -1) + # out = jax.vmap(jax.vmap(f), in_axes=0)(arr) + # return out.flatten() + # return f + + d = get_num_latents(model) + if batch is None: + batch = np.ceil(1000 / d).astype(int) + key, init_key = jax.random.split(key, 2) + keys = jax.random.split(key, batch) + + init_keys = jax.random.split(init_key, batch) + init_pos = pvmap(model.sample_init)(init_keys) + + # samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys) + samples, params, grad_calls_per_traj, acceptance_rate, step_size_over_da, final_da = pvmap(lambda pos, key: sampler(logdensity_fn=model.logdensity_fn, num_steps=n, initial_position= pos,transform= model.transform, key=key))(init_pos, keys) + avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0) + try: + print(jnp.nanmean(params.step_size,axis=0), jnp.nanmean(params.L,axis=0)) + except: pass + + full = lambda arr : err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) + err_t = pvmap(full)(samples**2) + + # outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t] + # # print(outs[:10]) + # esses = [i[0].item() for i in outs if not math.isnan(i[0].item())] + # grad_calls = [i[1].item() for i in outs if not math.isnan(i[1].item())] + # return(mean(esses), mean(grad_calls)) + # print(final_da.mean(), "final da") + + + err_t_median = jnp.median(err_t, axis=0) + # import matplotlib.pyplot as plt + # plt.plot(np.arange(1, 1+ len(err_t_median))* 2, err_t_median, color= 'teal', lw = 3) + # plt.xlabel('gradient evaluations') + # plt.ylabel('average second moment error') + # plt.xscale('log') + # plt.yscale('log') + # plt.savefig('brownian.png') + # plt.close() + esses, grad_calls, _ = calculate_ess(err_t_median, grad_evals_per_step=avg_grad_calls_per_traj) + return esses, grad_calls, params, jnp.mean(acceptance_rate, axis=0), step_size_over_da + + + + +def run_benchmarks(batch_size): + + results = 
defaultdict(tuple) + for variables in itertools.product( + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmclmc"], + # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], + [Brownian()], + # [Brownian()], + # [Brownian()], + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): + + sampler, model, coefficients = variables + num_chains = batch_size#1 + batch_size//model.ndims + + + num_steps = 100000 + + sampler, model, coefficients = variables + num_chains = batch_size # 1 + batch_size//model.ndims + + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + + contract = jnp.max + + key = jax.random.PRNGKey(11) + for i in range(1): + key1, key = jax.random.split(key) + ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients, frac_tune1=0.1, frac_tune2=0.0, frac_tune3=0.0),key1, n=num_steps, batch=num_chains, contract=contract) + + # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") + jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) + jax.numpy.save(f"acceptance.npy", acceptance_rate) + + + # print(f"grads to low bias: {grad_calls}") + # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") + + results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + print(ess.item()) + # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] + + + # print(results) + + + df = pd.Series(results).reset_index() + df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] + # df.result = df.result.apply(lambda x: x[0].item()) + # df.model = df.model.apply(lambda x: x[1]) + df.to_csv("results_simple.csv", index=False) + + return results + +# vary step_size +def run_benchmarks_step_size(batch_size): + + results = defaultdict(tuple) + for variables in itertools.product( + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmclmc"], + # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], + [StandardNormal(10)], + # [Brownian()], + # [Brownian()], + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): + + + + num_steps = 10000 + + sampler, model, coefficients = variables + num_chains = batch_size # 1 + batch_size//model.ndims + + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + + contract = jnp.average + + center = 6.534974 + key = jax.random.PRNGKey(11) + for step_size in np.linspace(center-1,center+1, 41): + # for L in np.linspace(1, 10, 41): + key1, key2, key3, key = jax.random.split(key, 4) + initial_position = model.sample_init(key2) + initial_state = blackjax.mcmc.mhmclmc.init( + position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=key3) + ess, grad_calls, params , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(initial_state=initial_state, coefficients=mclachlan_coefficients, step_size=step_size, L= 5*step_size, std_mat=1.),key1, n=num_steps, 
batch=num_chains, contract=contract) + + # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") + # jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) + # jax.numpy.save(f"acceptance.npy_{step_size}", acceptance_rate) + + + # print(f"grads to low bias: {grad_calls}") + # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") + + results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] + + + # print(results) + + + df = pd.Series(results).reset_index() + df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] + # df.result = df.result.apply(lambda x: x[0].item()) + # df.model = df.model.apply(lambda x: x[1]) + df.to_csv("results_step_size.csv", index=False) + + return results + + + +def benchmark_mhmchmc(batch_size): + + key0, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(5), 4) + results = defaultdict(tuple) + + # coefficients = [yoshida_coefficients, mclachlan_coefficients, velocity_verlet_coefficients, omelyan_coefficients] + coefficients = [mclachlan_coefficients, velocity_verlet_coefficients] + for model, coeffs in itertools.product(models, coefficients): + + num_chains = batch_size # 1 + batch_size//model.ndims + print(f"NUMBER OF CHAINS for {model.name} and MHMCLMC is {num_chains}") + num_steps = models[model]["mhmclmc"] + print(f"NUMBER OF STEPS for {model.name} and MHCMLMC is {num_steps}") + + ####### run mclmc with standard tuning + + contract = jnp.max + + + ess, grad_calls, params , _, step_size_over_da = benchmark_chains( + model, + partial(run_mclmc,coefficients=coeffs), + key0, + n=num_steps, + batch=num_chains, + contract=contract) + results[(model.name, model.ndims, "mclmc", params.L.mean().item(), params.step_size.mean().item(), name_integrator(coeffs), "standard", 1.)] = ess.item() + print(f'mclmc with tuning ESS {ess}') + + + ####### run mhmclmc with standard tuning + for target_acc_rate in [0.65, 0.9]: + # coeffs = mclachlan_coefficients + ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( + model, + partial(run_mhmclmc, target_acc_rate=target_acc_rate, coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0), + key1, + n=num_steps, + batch=num_chains, + contract=contract) + results[(model.name, model.ndims, "mhmchmc"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() + print(f'mhmclmc with tuning ESS {ess}') + + # coeffs = mclachlan_coefficients + ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( + model, + partial(run_mhmclmc, target_acc_rate=target_acc_rate,coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1), + key1, + n=num_steps, + batch=num_chains, + contract=contract) + results[(model.name, model.ndims, "mhmchmc:st3"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() + print(f'mhmclmc with tuning ESS {ess}') + + if True: + ####### run mhmclmc with standard tuning + grid search + + init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split(key2, 5) + 
initial_position = model.sample_init(init_pos_key) + + initial_state = blackjax.mcmc.mhmclmc.init( + position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( + integrator=generate_isokinetic_integrator(coeffs), + integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn) + + ( + state, + blackjax_mhmclmc_sampler_params, + _, _ + ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acceptance_rate_of_order[integrator_order(coeffs)], + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.0, + diagonal_preconditioning=False + ) + + print(f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}") + print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") + + + L, step_size, convergence = gridsearch_tune(grid_key, iterations=10, contract=contract, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coeffs, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) + # print(f"params after grid tuning are L={L}, step_size={step_size}") + + + ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coeffs, L=L, step_size=step_size, initial_state=state, std_mat=1.),bench_key, n=num_steps, batch=num_chains, contract=contract) + + print(f"grads to low bias: {grad_calls}") + + results[(model.name, model.ndims, "mhmchmc:grid", L.item(), step_size.item(), name_integrator(coeffs), f"gridsearch:{convergence}", acceptance_rate.mean().item())] = ess.item() + + ####### run nuts + + # coeffs = velocity_verlet_coefficients + ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, partial(run_nuts,coefficients=coeffs),key3, n=models[model]["nuts"], batch=num_chains, contract=contract) + results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() + + + + + + + + print(results) + + + df = pd.Series(results).reset_index() + df.columns = ["model", "dims", "sampler", "L", "step_size", "integrator", "tuning", "acc_rate", "ESS"] + # df.result = df.result.apply(lambda x: x[0].item()) + # df.model = df.model.apply(lambda x: x[1]) + df.to_csv("results.csv", index=False) + + return results + +def benchmark_omelyan(batch_size): + + + key = jax.random.PRNGKey(2) + results = defaultdict(tuple) + for variables in itertools.product( + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmchmc"], + [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int)], + # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 5)).astype(int)], + # models, + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients, omelyan_coefficients], + ): + + + sampler, model, coefficients = variables + + # num_chains = 1 + batch_size//model.ndims + num_chains = batch_size + + current_key, key = jax.random.split(key) + init_pos_key, 
init_key, tune_key, bench_key, grid_key = jax.random.split(current_key, 5) + + # num_steps = models[model][sampler] + + num_steps = 1000 + + + initial_position = model.sample_init(init_pos_key) + + initial_state = blackjax.mcmc.mhmclmc.init( + position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key + ) + + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( + integrator=generate_isokinetic_integrator(coefficients), + integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn) + + ( + state, + blackjax_mhmclmc_sampler_params, + _, _ + ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acceptance_rate_of_order[integrator_order(coefficients)], + frac_tune1=0.1, + frac_tune2=0.1, + # frac_tune3=0.1, + diagonal_preconditioning=False + ) + + print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") + + # ess, grad_calls, _ , _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) + + # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) + + L, step_size, converged = gridsearch_tune(grid_key, iterations=10, contract=jnp.average, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coefficients, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) + print(f"params after grid tuning are L={L}, step_size={step_size}") + + + ess, grad_calls, _ , _, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=L, step_size=step_size, std_mat=1., initial_state=state),bench_key, n=num_steps, batch=num_chains, contract=jnp.average) + + print(f"grads to low bias: {grad_calls}") + + results[(model.name, model.ndims, sampler, name_integrator(coefficients), converged, L.item(), step_size.item())] = ess.item() + + df = pd.Series(results).reset_index() + df.columns = ["model", "dims", "sampler", "integrator", "convergence", "L", "step_size", "ESS"] + # df.result = df.result.apply(lambda x: x[0].item()) + # df.model = df.model.apply(lambda x: x[1]) + df.to_csv("omelyan.csv", index=False) + + +def run_benchmarks_divij(): + + sampler = run_mclmc + model = StandardNormal(10) # 10 dimensional standard normal + coefficients = mclachlan_coefficients + contract = jnp.average # how we average across dimensions + num_steps = 2000 + num_chains = 100 + key1 = jax.random.PRNGKey(2) + + ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(sampler, coefficients=coefficients),key1, n=num_steps, batch=num_chains, contract=contract) + + print(f"Effective Sample Size (ESS) of 10D Normal is {ess}") + +if __name__ == "__main__": + + # run_benchmarks_divij() + + 
benchmark_mhmchmc(batch_size=128) + # run_benchmarks(128) + # run_benchmarks_step_size(128) + benchmark_omelyan(128) + # run_benchmarks(128) + #benchmark_omelyan(10) + # print("4") + + + diff --git a/blackjax/util.py b/blackjax/util.py index 59917e68a..79e665316 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -142,12 +142,13 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree: def run_inference_algorithm( rng_key: PRNGKey, - initial_state_or_position: ArrayLikeTree, + initial_state: ArrayLikeTree, inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm], num_steps: int, progress_bar: bool = False, transform: Callable = lambda x: x, -) -> tuple[State, State, Info]: + streaming=False, +) -> tuple: """Wrapper to run an inference algorithm. Note that this utility function does not work for Stochastic Gradient MCMC samplers @@ -158,7 +159,7 @@ def run_inference_algorithm( ---------- rng_key The random state used by JAX's random numbers generator. - initial_state_or_position + initial_state The initial state of the inference algorithm. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. @@ -167,9 +168,11 @@ def run_inference_algorithm( progress_bar Whether to display a progress bar. transform - A transform of the trace of states to be returned. This is useful for + A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. + streaming + if True, `run_inference_algorithm` will take a streaming average of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. Returns ------- @@ -178,14 +181,8 @@ def run_inference_algorithm( 2. The trace of states of the inference algorithm (contains the MCMC samples). 3. The trace of the info of the inference algorithm for diagnostics. """ - init_key, sample_key = split(rng_key, 2) - try: - initial_state = inference_algorithm.init(initial_state_or_position, init_key) - except (TypeError, ValueError, AttributeError): - # We assume initial_state is already in the right format. 
- initial_state = initial_state_or_position - keys = split(sample_key, num_steps) + keys = split(rng_key, num_steps) @jit def _one_step(state, xs): @@ -193,11 +190,55 @@ def _one_step(state, xs): state, info = inference_algorithm.step(rng_key, state) return state, (transform(state), info) + def _online_one_step(average_and_state, xs): + _, rng_key = xs + average, state = average_and_state + state, _ = inference_algorithm.step(rng_key, state) + average = streaming_average(transform, state, average) + return (average, state), None + if progress_bar: one_step = progress_bar_scan(num_steps)(_one_step) + online_one_step = progress_bar_scan(num_steps)(_online_one_step) else: one_step = _one_step + online_one_step = _online_one_step + + if streaming: + xs = (jnp.arange(num_steps), keys) + ((_, average), final_state), _ = lax.scan( + online_one_step, ((0, transform(initial_state)), initial_state), xs + ) + return average, transform(final_state) + + else: + xs = (jnp.arange(num_steps), keys) + final_state, (state_history, info_history) = lax.scan( + one_step, initial_state, xs + ) + return final_state, state_history, info_history - xs = (jnp.arange(num_steps), keys) - final_state, (state_history, info_history) = lax.scan(one_step, initial_state, xs) - return final_state, state_history, info_history + +def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): + """Compute the streaming average of a function O(x) using a weight. + Parameters: + ---------- + O + function to be averaged + x + current state + streaming_avg + tuple of (total, average) where total is the sum of weights and average is the current average + weight + weight of the current state + zero_prevention + small value to prevent division by zero + Returns: + ---------- + new streaming average + """ + total, average = streaming_avg + average = (total * average + weight * O(x)) / (total + weight + zero_prevention) + total += weight + streaming_avg = (total, average) + return streaming_avg \ No newline at end of file diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 39c1b811b..19f72a7c2 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -104,7 +104,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): _, samples, _ = run_inference_algorithm( rng_key=run_key, - initial_state_or_position=blackjax_state_after_tuning, + initial_state=blackjax_state_after_tuning, inference_algorithm=sampling_alg, num_steps=num_steps, transform=lambda x: x.position, diff --git a/tests/test_util.py b/tests/test_util.py index a6e023074..85a68e6e8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -19,14 +19,14 @@ def setUp(self): ) self.num_steps = 10 - def check_compatible(self, initial_state_or_position, progress_bar): + def check_compatible(self, initial_state, progress_bar): """ Runs 10 steps with `run_inference_algorithm` starting with - `initial_state_or_position` and potentially a progress bar. + `initial_state` and potentially a progress bar. 
""" _ = run_inference_algorithm( self.key, - initial_state_or_position, + initial_state, self.algorithm, self.num_steps, progress_bar, From fc347d613af53f103310d51c9dd3bac61e6012ff Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 18:39:46 +0200 Subject: [PATCH 14/71] ADD TEST --- tests/test_util.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_util.py b/tests/test_util.py index 85a68e6e8..1291b09e7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -33,6 +33,44 @@ def check_compatible(self, initial_state, progress_bar): transform=lambda x: x.position, ) + def test_streaming(self): + def logdensity_fn(x): + return -0.5 * jnp.sum(jnp.square(x)) + + initial_position = jnp.ones( + 10, + ) + + init_key, run_key = jax.random.split(self.key, 2) + + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + ) + + alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) + + average, states = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=50, + progress_bar=False, + transform=lambda x: x.position, + streaming=True, + ) + + _, states, _ = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=50, + progress_bar=False, + transform=lambda x: x.position, + streaming=False, + ) + + assert jnp.allclose(states.mean(axis=0), average) + @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): self.check_compatible(jnp.array([1.0, 1.0]), progress_bar) From 49410f9037a99f87677e1e89b32d48a350526b7f Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:04:37 +0200 Subject: [PATCH 15/71] REFACTOR RUN_INFERENCE_ALGORITHM --- blackjax/util.py | 67 +++++++++++++++++++++++----------------------- tests/test_util.py | 21 ++++++++++++--- 2 files changed, 51 insertions(+), 37 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 2efb93f12..4c58ad597 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -2,6 +2,7 @@ from functools import partial from typing import Callable, Union +import jax import jax.numpy as jnp from jax import jit, lax from jax.flatten_util import ravel_pytree @@ -147,7 +148,8 @@ def run_inference_algorithm( num_steps: int, progress_bar: bool = False, transform: Callable = lambda x: x, - streaming=False, + return_state_history=True, + expectation: Callable = lambda x: x, ) -> tuple: """Wrapper to run an inference algorithm. @@ -171,52 +173,44 @@ def run_inference_algorithm( A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. - streaming + return_expectation if True, `run_inference_algorithm` will take a streaming average of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. Returns ------- Tuple[State, State, Info] - 1. The final state of the inference algorithm. - 2. The trace of states of the inference algorithm (contains the MCMC samples). - 3. The trace of the info of the inference algorithm for diagnostics. + 1. The expectation of transform(state) over the chain. + 2. The final state of the inference algorithm. + 3. The trace of the state and info of the inference algorithm for diagnostics. 
""" keys = split(rng_key, num_steps) - @jit - def _one_step(state, xs): + def one_step(average_and_state, xs, return_state): _, rng_key = xs + average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - return state, (transform(state), info) + average = streaming_average(expectation, state, average) + if return_state: + return (average, state), (transform(state), info) + else: + return (average, state), None - def _online_one_step(average_and_state, xs): - _, rng_key = xs - average, state = average_and_state - state, _ = inference_algorithm.step(rng_key, state) - average = streaming_average(transform, state, average) - return (average, state), None + one_step = jax.jit(partial(one_step, return_state=return_state_history)) if progress_bar: - one_step = progress_bar_scan(num_steps)(_one_step) - online_one_step = progress_bar_scan(num_steps)(_online_one_step) - else: - one_step = _one_step - online_one_step = _online_one_step - - if streaming: - xs = (jnp.arange(num_steps), keys) - ((_, average), final_state), _ = lax.scan( - online_one_step, ((0, transform(initial_state)), initial_state), xs - ) - return average, transform(final_state) + one_step = progress_bar_scan(num_steps)(one_step) + xs = (jnp.arange(num_steps), keys) + ((_, average), final_state), history = lax.scan( + one_step, ((0, expectation(initial_state)), initial_state), xs + ) + + if not return_state_history: + return average, transform(final_state) else: - xs = (jnp.arange(num_steps), keys) - final_state, (state_history, info_history) = lax.scan( - one_step, initial_state, xs - ) - return final_state, state_history, info_history + state_history, info_history = history + return transform(final_state), state_history, info_history def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): @@ -237,8 +231,15 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): ---------- new streaming average """ + + expectation = O(x) + flat_expectation, unravel_fn = ravel_pytree(expectation) total, average = streaming_avg - average = (total * average + weight * O(x)) / (total + weight + zero_prevention) + flat_average, _ = ravel_pytree(average) + average = (total * flat_average + weight * flat_expectation) / ( + total + weight + zero_prevention + ) total += weight - streaming_avg = (total, average) + streaming_avg = (total, unravel_fn(average)) return streaming_avg + diff --git a/tests/test_util.py b/tests/test_util.py index 1291b09e7..1665bc2c3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -49,26 +49,39 @@ def logdensity_fn(x): alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) - average, states = run_inference_algorithm( + alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) + + _, states, info = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, inference_algorithm=alg, num_steps=50, progress_bar=False, + expectation=lambda x: x.position, transform=lambda x: x.position, - streaming=True, + return_state_history=True, ) - _, states, _ = run_inference_algorithm( + print(states.mean(axis=0)) + + + average, _ = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, inference_algorithm=alg, num_steps=50, progress_bar=False, + expectation=lambda x: x.position, transform=lambda x: x.position, - streaming=False, + return_state_history=False, ) + print(average) + print(states.mean(axis=0)[1]==average[1]) + + print(jnp.allclose(states.mean(axis=0), average)) + + assert 
jnp.allclose(states.mean(axis=0), average) @parameterized.parameters([True, False]) From ffdca93147726882fb4dc13fe0778fcd8f435d65 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:13:59 +0200 Subject: [PATCH 16/71] UPDATE DOCSTRING --- blackjax/util.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 4c58ad597..77a90ba56 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -173,15 +173,20 @@ def run_inference_algorithm( A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. - return_expectation - if True, `run_inference_algorithm` will take a streaming average of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. + expectation + A function that computes the expectation of the state. This is done incrementally, so doesn't require storing all the states. + return_state_history + if False, `run_inference_algorithm` will only return an expectation of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. Returns ------- - Tuple[State, State, Info] - 1. The expectation of transform(state) over the chain. + If return_state_history is True: + 1. The final state. + 2. The trace of the state. + 3. The trace of the info of the inference algorithm for diagnostics. + If return_state_history is False: + 1. This is the expectation of state over the chain. Otherwise the final state. 2. The final state of the inference algorithm. - 3. The trace of the state and info of the inference algorithm for diagnostics. 
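+
+    Examples
+    --------
+    A minimal sketch of the two modes; ``alg`` and ``initial_state`` stand in
+    for any blackjax sampling algorithm and a matching initial state::
+
+        alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1)
+
+        # keep the full history of transformed states, plus diagnostics
+        _, states, info = run_inference_algorithm(
+            rng_key, initial_state, alg, num_steps,
+            transform=lambda x: x.position,
+        )
+
+        # keep only a streaming average of ``expectation`` and the final (transformed) state
+        average, _ = run_inference_algorithm(
+            rng_key, initial_state, alg, num_steps,
+            expectation=lambda x: x.position,
+            return_state_history=False,
+        )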
""" keys = split(rng_key, num_steps) From b7b7084f92ea59847ea301c44b8dbc7f64d22e3b Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:14:28 +0200 Subject: [PATCH 17/71] Precommit --- blackjax/util.py | 1 - tests/test_util.py | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 77a90ba56..e579c126d 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -247,4 +247,3 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): total += weight streaming_avg = (total, unravel_fn(average)) return streaming_avg - diff --git a/tests/test_util.py b/tests/test_util.py index 1665bc2c3..6a7efd6b5 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -50,7 +50,7 @@ def logdensity_fn(x): alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) - + _, states, info = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -64,7 +64,6 @@ def logdensity_fn(x): print(states.mean(axis=0)) - average, _ = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -77,11 +76,10 @@ def logdensity_fn(x): ) print(average) - print(states.mean(axis=0)[1]==average[1]) + print(states.mean(axis=0)[1] == average[1]) print(jnp.allclose(states.mean(axis=0), average)) - assert jnp.allclose(states.mean(axis=0), average) @parameterized.parameters([True, False]) From 97cfc9eccd92ae5a2616e4bf379d6a75102abf54 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:18:54 +0200 Subject: [PATCH 18/71] CLEAN TESTS --- tests/test_util.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 6a7efd6b5..3bafca894 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -49,8 +49,6 @@ def logdensity_fn(x): alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) - alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) - _, states, info = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -62,8 +60,6 @@ def logdensity_fn(x): return_state_history=True, ) - print(states.mean(axis=0)) - average, _ = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -75,11 +71,6 @@ def logdensity_fn(x): return_state_history=False, ) - print(average) - print(states.mean(axis=0)[1] == average[1]) - - print(jnp.allclose(states.mean(axis=0), average)) - assert jnp.allclose(states.mean(axis=0), average) @parameterized.parameters([True, False]) From beb6cbe19adfa71ee1fc39e9515bea748a8d0d99 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:22:23 +0200 Subject: [PATCH 19/71] FIX BAD MERGE --- blackjax/util.py | 73 +++++++++++++++++++++++++--------------------- tests/test_util.py | 10 ++++--- 2 files changed, 45 insertions(+), 38 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 79e665316..608183cc9 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -2,6 +2,7 @@ from functools import partial from typing import Callable, Union +import jax import jax.numpy as jnp from jax import jit, lax from jax.flatten_util import ravel_pytree @@ -147,7 +148,8 @@ def run_inference_algorithm( num_steps: int, progress_bar: bool = False, transform: Callable = lambda x: x, - streaming=False, + return_state_history=True, + expectation: Callable = lambda x: x, ) -> tuple: """Wrapper to run an inference algorithm. 
@@ -171,52 +173,49 @@ def run_inference_algorithm( A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. - streaming - if True, `run_inference_algorithm` will take a streaming average of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. + expectation + A function that computes the expectation of the state. This is done incrementally, so doesn't require storing all the states. + return_state_history + if False, `run_inference_algorithm` will only return an expectation of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. Returns ------- - Tuple[State, State, Info] - 1. The final state of the inference algorithm. - 2. The trace of states of the inference algorithm (contains the MCMC samples). + If return_state_history is True: + 1. The final state. + 2. The trace of the state. 3. The trace of the info of the inference algorithm for diagnostics. + If return_state_history is False: + 1. This is the expectation of state over the chain. Otherwise the final state. + 2. The final state of the inference algorithm. """ keys = split(rng_key, num_steps) - @jit - def _one_step(state, xs): + def one_step(average_and_state, xs, return_state): _, rng_key = xs + average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - return state, (transform(state), info) + average = streaming_average(expectation, state, average) + if return_state: + return (average, state), (transform(state), info) + else: + return (average, state), None - def _online_one_step(average_and_state, xs): - _, rng_key = xs - average, state = average_and_state - state, _ = inference_algorithm.step(rng_key, state) - average = streaming_average(transform, state, average) - return (average, state), None + one_step = jax.jit(partial(one_step, return_state=return_state_history)) if progress_bar: - one_step = progress_bar_scan(num_steps)(_one_step) - online_one_step = progress_bar_scan(num_steps)(_online_one_step) - else: - one_step = _one_step - online_one_step = _online_one_step - - if streaming: - xs = (jnp.arange(num_steps), keys) - ((_, average), final_state), _ = lax.scan( - online_one_step, ((0, transform(initial_state)), initial_state), xs - ) - return average, transform(final_state) + one_step = progress_bar_scan(num_steps)(one_step) + + xs = (jnp.arange(num_steps), keys) + ((_, average), final_state), history = lax.scan( + one_step, ((0, expectation(initial_state)), initial_state), xs + ) + if not return_state_history: + return average, transform(final_state) else: - xs = (jnp.arange(num_steps), keys) - final_state, (state_history, info_history) = lax.scan( - one_step, initial_state, xs - ) - return final_state, state_history, info_history + state_history, info_history = history + return transform(final_state), state_history, info_history def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): @@ -237,8 +236,14 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): ---------- new streaming average """ + + expectation = O(x) + flat_expectation, unravel_fn = ravel_pytree(expectation) total, average = streaming_avg - average = (total * average + weight * O(x)) / (total + weight + zero_prevention) + flat_average, _ = ravel_pytree(average) + average = (total * flat_average + 
weight * flat_expectation) / ( + total + weight + zero_prevention + ) total += weight - streaming_avg = (total, average) + streaming_avg = (total, unravel_fn(average)) return streaming_avg \ No newline at end of file diff --git a/tests/test_util.py b/tests/test_util.py index 1291b09e7..3bafca894 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -49,24 +49,26 @@ def logdensity_fn(x): alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) - average, states = run_inference_algorithm( + _, states, info = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, inference_algorithm=alg, num_steps=50, progress_bar=False, + expectation=lambda x: x.position, transform=lambda x: x.position, - streaming=True, + return_state_history=True, ) - _, states, _ = run_inference_algorithm( + average, _ = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, inference_algorithm=alg, num_steps=50, progress_bar=False, + expectation=lambda x: x.position, transform=lambda x: x.position, - streaming=False, + return_state_history=False, ) assert jnp.allclose(states.mean(axis=0), average) From 09b9dbd2c16cdf0a67de07d79034db478e870334 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:47:28 +0200 Subject: [PATCH 20/71] ADJUSTED MCLMC --- blackjax/mcmc/adjusted_mclmc.py | 260 ++++++++++++++++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 blackjax/mcmc/adjusted_mclmc.py diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py new file mode 100644 index 000000000..86500ec84 --- /dev/null +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -0,0 +1,260 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. 
Hence "Hamiltonian" and not "Langevin".""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence +from blackjax.types import ArrayLike +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.proposal import static_binomial_sampling + +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = [ + "init", + "build_kernel", + "mhmclmc", +] + +def init( + position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array +): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + +# TODO: no default for std_mat +def build_kernel( + integration_steps_fn, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + std_mat=1., +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + """ + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + L_proposal : float = 1.0, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + num_integration_steps = integration_steps_fn( + state.random_generator_arg + ) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + + proposal, info, _ = mhmclmc_proposal( + # integrators.with_isokinetic_maruyama(integrator(logdensity_fn)), + lambda state, step_size, L_prop, key : (integrator(logdensity_fn, std_mat))(state, step_size), + step_size, + L_proposal, + num_integration_steps, + divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ) + ) + + return ( + DynamicHMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal : float = 0.6, + std_mat=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. 
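+
+    A rough usage sketch (``logdensity_fn``, ``initial_position`` and ``rng_key``
+    are placeholders supplied by the caller)::
+
+        alg = as_top_level_api(logdensity_fn, step_size=0.1)
+        state = alg.init(initial_position, rng_key)
+        new_state, info = alg.step(rng_key, state)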
+ + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel(integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, std_mat=std_mat, divergence_threshold=divergence_threshold) + + + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + L_proposal, + ) + + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return init(position, logdensity_fn, rng_key) + + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def mhmclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. 
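+
+    Notes
+    -----
+    With ``delta_energy = logdensity(end) - logdensity(start) - kinetic_energy``
+    (the kinetic term accumulated by the integrator over the trajectory), the
+    proposal is accepted with probability ``min(1, exp(delta_energy))``;
+    transitions with ``-delta_energy`` above ``divergence_threshold`` are
+    reported as divergent.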
+ + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator(state, step_size, L_proposal, rng_key) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop(0*num_integration_steps, num_integration_steps, step, (state, 0, rng_key)) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + # note that this is the POTENTIAL energy only + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. + """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) \ No newline at end of file From 71e372148a865321f2988b8b90b2b37e28a3d293 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:48:33 +0200 Subject: [PATCH 21/71] REMOVE BENCHMARKS: --- blackjax/benchmarks/mcmc/benchmark.py | 557 -------------------------- 1 file changed, 557 deletions(-) delete mode 100644 blackjax/benchmarks/mcmc/benchmark.py diff --git a/blackjax/benchmarks/mcmc/benchmark.py b/blackjax/benchmarks/mcmc/benchmark.py deleted file mode 100644 index 549a55364..000000000 --- a/blackjax/benchmarks/mcmc/benchmark.py +++ /dev/null @@ -1,557 +0,0 @@ -from collections import defaultdict -from functools import partial -import math -import operator -import os -import pprint -from statistics import mean, median -import jax -import jax.numpy as jnp -import pandas as pd -import scipy - -from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState, integrator_order, target_acceptance_rate_of_order - -os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=' + str(128) -num_cores = jax.local_device_count() -# print(num_cores, jax.lib.xla_bridge.get_backend().platform) - -import itertools - -import numpy as np - -import blackjax -from blackjax.benchmarks.mcmc.sampling_algorithms import run_mclmc, run_mhmclmc, run_nuts, samplers -from blackjax.benchmarks.mcmc.inference_models import Brownian, GermanCredit, ItemResponseTheory, MixedLogit, StandardNormal, StochasticVolatility, models -from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator, isokinetic_mclachlan, mclachlan_coefficients, name_integrator, omelyan_coefficients, velocity_verlet, velocity_verlet_coefficients, yoshida_coefficients -from blackjax.mcmc.mhmclmc import rescale -from blackjax.util import run_inference_algorithm - - - -def get_num_latents(target): - return target.ndims -# return int(sum(map(np.prod, 
list(jax.tree_flatten(target.event_shape)[0])))) - - -def err(f_true, var_f, contract): - """Computes the error b^2 = (f - f_true)^2 / var_f - Args: - f: E_sampler[f(x)], can be a vector - f_true: E_true[f(x)] - var_f: Var_true[f(x)] - contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max - - Returns: - contract(b^2) - """ - - return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) - - - -def grads_to_low_error(err_t, grad_evals_per_step= 1, low_error= 0.01): - """Uses the error of the expectation values to compute the effective sample size neff - b^2 = 1/neff""" - - cutoff_reached = err_t[-1] < low_error - return find_crossing(err_t, low_error) * grad_evals_per_step, cutoff_reached - - -def calculate_ess(err_t, grad_evals_per_step, neff= 100): - - grads_to_low, cutoff_reached = grads_to_low_error(err_t, grad_evals_per_step, 1./neff) - - return (neff / grads_to_low) * cutoff_reached, grads_to_low*(1/cutoff_reached), cutoff_reached - - -def find_crossing(array, cutoff): - """the smallest M such that array[m] < cutoff for all m > M""" - - b = array > cutoff - indices = jnp.argwhere(b) - if indices.shape[0] == 0: - print("\n\n\nNO CROSSING FOUND!!!\n\n\n", array, cutoff) - return 1 - - return jnp.max(indices)+1 - - -def cumulative_avg(samples): - return jnp.cumsum(samples, axis = 0) / jnp.arange(1, samples.shape[0] + 1)[:, None] - - -def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps, center_L, center_step_size, contract): - results = defaultdict(float) - converged = False - keys = jax.random.split(key, iterations+1) - for i in range(iterations): - print(f"EPOCH {i}") - width = 2 - step_sizes = np.logspace(np.log10(center_step_size/width), np.log10(center_step_size*width), grid_size) - Ls = np.logspace(np.log10(center_L/2), np.log10(center_L*2),grid_size) - # print(list(itertools.product(step_sizes , Ls))) - - grid_keys = jax.random.split(keys[i], grid_size^2) - print(f"center step size {center_step_size}, center L {center_L}") - for j, (step_size, L) in enumerate(itertools.product(step_sizes , Ls)): - ess, grad_calls_until_convergence, _ , _, _ = benchmark_chains(model, sampler(step_size=step_size, L=L), grid_keys[j], n=num_steps, batch = batch, contract=contract) - results[(step_size, L)] = (ess, grad_calls_until_convergence) - - best_ess, best_grads, (step_size, L) = max([(results[r][0], results[r][1], r) for r in results], key=operator.itemgetter(0)) - # raise Exception - print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - if L==center_L and step_size==center_step_size: - print("converged") - converged = True - break - else: - center_L, center_step_size = L, step_size - - pprint.pp(results) - # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") - return center_L, center_step_size, converged - - -def run_mhmclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): - - def s(logdensity_fn, num_steps, initial_position, transform, key): - - integrator = generate_isokinetic_integrator(coefficients) - - num_steps_per_traj = L/step_size - alg = blackjax.mcmc.mhmclmc.mhmclmc( - logdensity_fn=logdensity_fn, - step_size=step_size, - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(num_steps_per_traj)) , - integrator=integrator, - std_mat=std_mat, - ) - - _, out, info 
= run_inference_algorithm( - rng_key=key, - initial_state=initial_state, - inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True) - - return out, MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), num_steps_per_traj * calls_per_integrator_step(coefficients), info.acceptance_rate.mean(), None, jnp.array([0]) - - return s - -def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.average,): - - pvmap = jax.pmap - - # def pvmap(f): - # def f(arr): - # return arr - # print(arr.shape,"shape") - # print(arr) - # arr = arr.reshape(128, -1) - # out = jax.vmap(jax.vmap(f), in_axes=0)(arr) - # return out.flatten() - # return f - - d = get_num_latents(model) - if batch is None: - batch = np.ceil(1000 / d).astype(int) - key, init_key = jax.random.split(key, 2) - keys = jax.random.split(key, batch) - - init_keys = jax.random.split(init_key, batch) - init_pos = pvmap(model.sample_init)(init_keys) - - # samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys) - samples, params, grad_calls_per_traj, acceptance_rate, step_size_over_da, final_da = pvmap(lambda pos, key: sampler(logdensity_fn=model.logdensity_fn, num_steps=n, initial_position= pos,transform= model.transform, key=key))(init_pos, keys) - avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0) - try: - print(jnp.nanmean(params.step_size,axis=0), jnp.nanmean(params.L,axis=0)) - except: pass - - full = lambda arr : err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) - err_t = pvmap(full)(samples**2) - - # outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t] - # # print(outs[:10]) - # esses = [i[0].item() for i in outs if not math.isnan(i[0].item())] - # grad_calls = [i[1].item() for i in outs if not math.isnan(i[1].item())] - # return(mean(esses), mean(grad_calls)) - # print(final_da.mean(), "final da") - - - err_t_median = jnp.median(err_t, axis=0) - # import matplotlib.pyplot as plt - # plt.plot(np.arange(1, 1+ len(err_t_median))* 2, err_t_median, color= 'teal', lw = 3) - # plt.xlabel('gradient evaluations') - # plt.ylabel('average second moment error') - # plt.xscale('log') - # plt.yscale('log') - # plt.savefig('brownian.png') - # plt.close() - esses, grad_calls, _ = calculate_ess(err_t_median, grad_evals_per_step=avg_grad_calls_per_traj) - return esses, grad_calls, params, jnp.mean(acceptance_rate, axis=0), step_size_over_da - - - - -def run_benchmarks(batch_size): - - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], - [Brownian()], - # [Brownian()], - # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - - sampler, model, coefficients = variables - num_chains = batch_size#1 + batch_size//model.ndims - - - num_steps = 100000 - - sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims - - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - - contract = jnp.max - - key = jax.random.PRNGKey(11) - for i in range(1): - key1, key = jax.random.split(key) - ess, grad_calls, params , acceptance_rate, step_size_over_da = 
benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients, frac_tune1=0.1, frac_tune2=0.0, frac_tune3=0.0),key1, n=num_steps, batch=num_chains, contract=contract) - - # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") - jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) - jax.numpy.save(f"acceptance.npy", acceptance_rate) - - - # print(f"grads to low bias: {grad_calls}") - # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() - print(ess.item()) - # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - - - # print(results) - - - df = pd.Series(results).reset_index() - df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("results_simple.csv", index=False) - - return results - -# vary step_size -def run_benchmarks_step_size(batch_size): - - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], - [StandardNormal(10)], - # [Brownian()], - # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - - - - num_steps = 10000 - - sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims - - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - - contract = jnp.average - - center = 6.534974 - key = jax.random.PRNGKey(11) - for step_size in np.linspace(center-1,center+1, 41): - # for L in np.linspace(1, 10, 41): - key1, key2, key3, key = jax.random.split(key, 4) - initial_position = model.sample_init(key2) - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=key3) - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(initial_state=initial_state, coefficients=mclachlan_coefficients, step_size=step_size, L= 5*step_size, std_mat=1.),key1, n=num_steps, batch=num_chains, contract=contract) - - # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") - # jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) - # jax.numpy.save(f"acceptance.npy_{step_size}", acceptance_rate) - - - # print(f"grads to low bias: {grad_calls}") - # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() - # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - - - # print(results) - - - df = pd.Series(results).reset_index() - df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] - # df.result = 
df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("results_step_size.csv", index=False) - - return results - - - -def benchmark_mhmchmc(batch_size): - - key0, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(5), 4) - results = defaultdict(tuple) - - # coefficients = [yoshida_coefficients, mclachlan_coefficients, velocity_verlet_coefficients, omelyan_coefficients] - coefficients = [mclachlan_coefficients, velocity_verlet_coefficients] - for model, coeffs in itertools.product(models, coefficients): - - num_chains = batch_size # 1 + batch_size//model.ndims - print(f"NUMBER OF CHAINS for {model.name} and MHMCLMC is {num_chains}") - num_steps = models[model]["mhmclmc"] - print(f"NUMBER OF STEPS for {model.name} and MHCMLMC is {num_steps}") - - ####### run mclmc with standard tuning - - contract = jnp.max - - - ess, grad_calls, params , _, step_size_over_da = benchmark_chains( - model, - partial(run_mclmc,coefficients=coeffs), - key0, - n=num_steps, - batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mclmc", params.L.mean().item(), params.step_size.mean().item(), name_integrator(coeffs), "standard", 1.)] = ess.item() - print(f'mclmc with tuning ESS {ess}') - - - ####### run mhmclmc with standard tuning - for target_acc_rate in [0.65, 0.9]: - # coeffs = mclachlan_coefficients - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( - model, - partial(run_mhmclmc, target_acc_rate=target_acc_rate, coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0), - key1, - n=num_steps, - batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mhmchmc"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'mhmclmc with tuning ESS {ess}') - - # coeffs = mclachlan_coefficients - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( - model, - partial(run_mhmclmc, target_acc_rate=target_acc_rate,coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1), - key1, - n=num_steps, - batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mhmchmc:st3"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'mhmclmc with tuning ESS {ess}') - - if True: - ####### run mhmclmc with standard tuning + grid search - - init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split(key2, 5) - initial_position = model.sample_init(init_pos_key) - - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key - ) - - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coeffs), - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn) - - ( - state, - blackjax_mhmclmc_sampler_params, - _, _ - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - state=initial_state, - rng_key=tune_key, - target=target_acceptance_rate_of_order[integrator_order(coeffs)], - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.0, 
- diagonal_preconditioning=False - ) - - print(f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}") - print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") - - - L, step_size, convergence = gridsearch_tune(grid_key, iterations=10, contract=contract, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coeffs, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) - # print(f"params after grid tuning are L={L}, step_size={step_size}") - - - ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coeffs, L=L, step_size=step_size, initial_state=state, std_mat=1.),bench_key, n=num_steps, batch=num_chains, contract=contract) - - print(f"grads to low bias: {grad_calls}") - - results[(model.name, model.ndims, "mhmchmc:grid", L.item(), step_size.item(), name_integrator(coeffs), f"gridsearch:{convergence}", acceptance_rate.mean().item())] = ess.item() - - ####### run nuts - - # coeffs = velocity_verlet_coefficients - ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, partial(run_nuts,coefficients=coeffs),key3, n=models[model]["nuts"], batch=num_chains, contract=contract) - results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - - - - - - - - print(results) - - - df = pd.Series(results).reset_index() - df.columns = ["model", "dims", "sampler", "L", "step_size", "integrator", "tuning", "acc_rate", "ESS"] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("results.csv", index=False) - - return results - -def benchmark_omelyan(batch_size): - - - key = jax.random.PRNGKey(2) - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmchmc"], - [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int)], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 5)).astype(int)], - # models, - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients, omelyan_coefficients], - ): - - - sampler, model, coefficients = variables - - # num_chains = 1 + batch_size//model.ndims - num_chains = batch_size - - current_key, key = jax.random.split(key) - init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split(current_key, 5) - - # num_steps = models[model][sampler] - - num_steps = 1000 - - - initial_position = model.sample_init(init_pos_key) - - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key - ) - - - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coefficients), - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn) - - ( - state, - blackjax_mhmclmc_sampler_params, - _, _ - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - 
state=initial_state, - rng_key=tune_key, - target=target_acceptance_rate_of_order[integrator_order(coefficients)], - frac_tune1=0.1, - frac_tune2=0.1, - # frac_tune3=0.1, - diagonal_preconditioning=False - ) - - print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") - - # ess, grad_calls, _ , _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) - - # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) - - L, step_size, converged = gridsearch_tune(grid_key, iterations=10, contract=jnp.average, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coefficients, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) - print(f"params after grid tuning are L={L}, step_size={step_size}") - - - ess, grad_calls, _ , _, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=L, step_size=step_size, std_mat=1., initial_state=state),bench_key, n=num_steps, batch=num_chains, contract=jnp.average) - - print(f"grads to low bias: {grad_calls}") - - results[(model.name, model.ndims, sampler, name_integrator(coefficients), converged, L.item(), step_size.item())] = ess.item() - - df = pd.Series(results).reset_index() - df.columns = ["model", "dims", "sampler", "integrator", "convergence", "L", "step_size", "ESS"] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("omelyan.csv", index=False) - - -def run_benchmarks_divij(): - - sampler = run_mclmc - model = StandardNormal(10) # 10 dimensional standard normal - coefficients = mclachlan_coefficients - contract = jnp.average # how we average across dimensions - num_steps = 2000 - num_chains = 100 - key1 = jax.random.PRNGKey(2) - - ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(sampler, coefficients=coefficients),key1, n=num_steps, batch=num_chains, contract=contract) - - print(f"Effective Sample Size (ESS) of 10D Normal is {ess}") - -if __name__ == "__main__": - - # run_benchmarks_divij() - - benchmark_mhmchmc(batch_size=128) - # run_benchmarks(128) - # run_benchmarks_step_size(128) - benchmark_omelyan(128) - # run_benchmarks(128) - #benchmark_omelyan(10) - # print("4") - - - From 45bc67738dd6d7fe613f08de4c541eefef12af3f Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 21:15:41 +0200 Subject: [PATCH 22/71] ADD ADJUSTED MCLMC --- .gitignore | 3 +++ blackjax/__init__.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 25b11a123..5fca94f69 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # Created by https://www.gitignore.io/api/python # Edit at https://www.gitignore.io/?templates=python +explore.py +blackjax/benchmarks + ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/blackjax/__init__.py b/blackjax/__init__.py index dfdcfc545..6afd9454a 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -19,6 +19,7 @@ from .mcmc import mala as _mala 
from .mcmc import marginal_latent_gaussian from .mcmc import mclmc as _mclmc +from .mcmc import adjusted_mclmc as _adjusted_mclmc from .mcmc import nuts as _nuts from .mcmc import periodic_orbital, random_walk from .mcmc import rmhmc as _rmhmc @@ -109,6 +110,7 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) mclmc = generate_top_level_api_from(_mclmc) +adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) From a27dba993bfb68867b434d0462b76126eb0f0175 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 21:21:36 +0200 Subject: [PATCH 23/71] GITIGNORE --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 25b11a123..d9186a6e9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ # Created by https://www.gitignore.io/api/python # Edit at https://www.gitignore.io/?templates=python +explore.py + ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ From 7a6e42b02e4daa7adcc4be7bb57d9e1280cfd07d Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 21:26:30 +0200 Subject: [PATCH 24/71] PRECOMMIT CLEAN UP --- blackjax/benchmarks/mcmc/benchmark.py | 777 +++++++++++++++++++------- blackjax/util.py | 2 +- explore.py | 61 -- 3 files changed, 565 insertions(+), 275 deletions(-) delete mode 100644 explore.py diff --git a/blackjax/benchmarks/mcmc/benchmark.py b/blackjax/benchmarks/mcmc/benchmark.py index 549a55364..9eadc7e2f 100644 --- a/blackjax/benchmarks/mcmc/benchmark.py +++ b/blackjax/benchmarks/mcmc/benchmark.py @@ -1,18 +1,26 @@ -from collections import defaultdict -from functools import partial +# mypy: ignore-errors +# flake8: noqa + import math import operator import os import pprint +from collections import defaultdict +from functools import partial from statistics import mean, median + import jax import jax.numpy as jnp import pandas as pd import scipy -from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState, integrator_order, target_acceptance_rate_of_order +from blackjax.adaptation.mclmc_adaptation import ( + MCLMCAdaptationState, + integrator_order, + target_acceptance_rate_of_order, +) -os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=' + str(128) +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=" + str(128) num_cores = jax.local_device_count() # print(num_cores, jax.lib.xla_bridge.get_backend().platform) @@ -21,48 +29,77 @@ import numpy as np import blackjax -from blackjax.benchmarks.mcmc.sampling_algorithms import run_mclmc, run_mhmclmc, run_nuts, samplers -from blackjax.benchmarks.mcmc.inference_models import Brownian, GermanCredit, ItemResponseTheory, MixedLogit, StandardNormal, StochasticVolatility, models -from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator, isokinetic_mclachlan, mclachlan_coefficients, name_integrator, omelyan_coefficients, velocity_verlet, velocity_verlet_coefficients, yoshida_coefficients +from blackjax.benchmarks.mcmc.inference_models import ( + Brownian, + GermanCredit, + ItemResponseTheory, + MixedLogit, + StandardNormal, + StochasticVolatility, + models, +) +from blackjax.benchmarks.mcmc.sampling_algorithms import ( + run_mclmc, + run_mhmclmc, + run_nuts, + samplers, +) +from blackjax.mcmc.integrators import ( + 
calls_per_integrator_step, + generate_euclidean_integrator, + generate_isokinetic_integrator, + isokinetic_mclachlan, + mclachlan_coefficients, + name_integrator, + omelyan_coefficients, + velocity_verlet, + velocity_verlet_coefficients, + yoshida_coefficients, +) from blackjax.mcmc.mhmclmc import rescale from blackjax.util import run_inference_algorithm - def get_num_latents(target): - return target.ndims + return target.ndims + + # return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0])))) def err(f_true, var_f, contract): """Computes the error b^2 = (f - f_true)^2 / var_f - Args: - f: E_sampler[f(x)], can be a vector - f_true: E_true[f(x)] - var_f: Var_true[f(x)] - contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max - - Returns: - contract(b^2) - """ - - return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) + Args: + f: E_sampler[f(x)], can be a vector + f_true: E_true[f(x)] + var_f: Var_true[f(x)] + contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max + Returns: + contract(b^2) + """ + + return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) -def grads_to_low_error(err_t, grad_evals_per_step= 1, low_error= 0.01): +def grads_to_low_error(err_t, grad_evals_per_step=1, low_error=0.01): """Uses the error of the expectation values to compute the effective sample size neff - b^2 = 1/neff""" - + b^2 = 1/neff""" + cutoff_reached = err_t[-1] < low_error return find_crossing(err_t, low_error) * grad_evals_per_step, cutoff_reached - - -def calculate_ess(err_t, grad_evals_per_step, neff= 100): - - grads_to_low, cutoff_reached = grads_to_low_error(err_t, grad_evals_per_step, 1./neff) - - return (neff / grads_to_low) * cutoff_reached, grads_to_low*(1/cutoff_reached), cutoff_reached + + +def calculate_ess(err_t, grad_evals_per_step, neff=100): + grads_to_low, cutoff_reached = grads_to_low_error( + err_t, grad_evals_per_step, 1.0 / neff + ) + + return ( + (neff / grads_to_low) * cutoff_reached, + grads_to_low * (1 / cutoff_reached), + cutoff_reached, + ) def find_crossing(array, cutoff): @@ -74,34 +111,61 @@ def find_crossing(array, cutoff): print("\n\n\nNO CROSSING FOUND!!!\n\n\n", array, cutoff) return 1 - return jnp.max(indices)+1 + return jnp.max(indices) + 1 def cumulative_avg(samples): - return jnp.cumsum(samples, axis = 0) / jnp.arange(1, samples.shape[0] + 1)[:, None] - - -def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps, center_L, center_step_size, contract): + return jnp.cumsum(samples, axis=0) / jnp.arange(1, samples.shape[0] + 1)[:, None] + + +def gridsearch_tune( + key, + iterations, + grid_size, + model, + sampler, + batch, + num_steps, + center_L, + center_step_size, + contract, +): results = defaultdict(float) converged = False - keys = jax.random.split(key, iterations+1) + keys = jax.random.split(key, iterations + 1) for i in range(iterations): print(f"EPOCH {i}") width = 2 - step_sizes = np.logspace(np.log10(center_step_size/width), np.log10(center_step_size*width), grid_size) - Ls = np.logspace(np.log10(center_L/2), np.log10(center_L*2),grid_size) + step_sizes = np.logspace( + np.log10(center_step_size / width), + np.log10(center_step_size * width), + grid_size, + ) + Ls = np.logspace(np.log10(center_L / 2), np.log10(center_L * 2), grid_size) # print(list(itertools.product(step_sizes , Ls))) - grid_keys = jax.random.split(keys[i], grid_size^2) + grid_keys = jax.random.split(keys[i], grid_size ^ 2) print(f"center step 
size {center_step_size}, center L {center_L}") - for j, (step_size, L) in enumerate(itertools.product(step_sizes , Ls)): - ess, grad_calls_until_convergence, _ , _, _ = benchmark_chains(model, sampler(step_size=step_size, L=L), grid_keys[j], n=num_steps, batch = batch, contract=contract) + for j, (step_size, L) in enumerate(itertools.product(step_sizes, Ls)): + ess, grad_calls_until_convergence, _, _, _ = benchmark_chains( + model, + sampler(step_size=step_size, L=L), + grid_keys[j], + n=num_steps, + batch=batch, + contract=contract, + ) results[(step_size, L)] = (ess, grad_calls_until_convergence) - best_ess, best_grads, (step_size, L) = max([(results[r][0], results[r][1], r) for r in results], key=operator.itemgetter(0)) + best_ess, best_grads, (step_size, L) = max( + ((results[r][0], results[r][1], r) for r in results), + key=operator.itemgetter(0), + ) # raise Exception - print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - if L==center_L and step_size==center_step_size: + print( + f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}" + ) + if L == center_L and step_size == center_step_size: print("converged") converged = True break @@ -109,40 +173,55 @@ def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps center_L, center_step_size = L, step_size pprint.pp(results) - # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") + # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") + # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") return center_L, center_step_size, converged def run_mhmclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): - def s(logdensity_fn, num_steps, initial_position, transform, key): - integrator = generate_isokinetic_integrator(coefficients) - num_steps_per_traj = L/step_size + num_steps_per_traj = L / step_size alg = blackjax.mcmc.mhmclmc.mhmclmc( - logdensity_fn=logdensity_fn, - step_size=step_size, - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(num_steps_per_traj)) , - integrator=integrator, - std_mat=std_mat, + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(num_steps_per_traj) + ), + integrator=integrator, + std_mat=std_mat, ) _, out, info = run_inference_algorithm( - rng_key=key, - initial_state=initial_state, - inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True) + rng_key=key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True, + ) - return out, MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), num_steps_per_traj * calls_per_integrator_step(coefficients), info.acceptance_rate.mean(), None, jnp.array([0]) + return ( + out, + MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), + num_steps_per_traj * calls_per_integrator_step(coefficients), + info.acceptance_rate.mean(), + None, + jnp.array([0]), + ) return s -def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.average,): +def benchmark_chains( + model, + sampler, + key, + n=10000, + batch=None, 
+ contract=jnp.average, +): pvmap = jax.pmap # def pvmap(f): @@ -154,7 +233,7 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av # out = jax.vmap(jax.vmap(f), in_axes=0)(arr) # return out.flatten() # return f - + d = get_num_latents(model) if batch is None: batch = np.ceil(1000 / d).astype(int) @@ -165,13 +244,31 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av init_pos = pvmap(model.sample_init)(init_keys) # samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys) - samples, params, grad_calls_per_traj, acceptance_rate, step_size_over_da, final_da = pvmap(lambda pos, key: sampler(logdensity_fn=model.logdensity_fn, num_steps=n, initial_position= pos,transform= model.transform, key=key))(init_pos, keys) + ( + samples, + params, + grad_calls_per_traj, + acceptance_rate, + step_size_over_da, + final_da, + ) = pvmap( + lambda pos, key: sampler( + logdensity_fn=model.logdensity_fn, + num_steps=n, + initial_position=pos, + transform=model.transform, + key=key, + ) + )( + init_pos, keys + ) avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0) try: - print(jnp.nanmean(params.step_size,axis=0), jnp.nanmean(params.L,axis=0)) - except: pass - - full = lambda arr : err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) + print(jnp.nanmean(params.step_size, axis=0), jnp.nanmean(params.L, axis=0)) + except: + pass + + full = lambda arr: err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) err_t = pvmap(full)(samples**2) # outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t] @@ -181,7 +278,6 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av # return(mean(esses), mean(grad_calls)) # print(final_da.mean(), "final da") - err_t_median = jnp.median(err_t, axis=0) # import matplotlib.pyplot as plt # plt.plot(np.arange(1, 1+ len(err_t_median))* 2, err_t_median, color= 'teal', lw = 3) @@ -191,121 +287,200 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av # plt.yscale('log') # plt.savefig('brownian.png') # plt.close() - esses, grad_calls, _ = calculate_ess(err_t_median, grad_evals_per_step=avg_grad_calls_per_traj) - return esses, grad_calls, params, jnp.mean(acceptance_rate, axis=0), step_size_over_da - - + esses, grad_calls, _ = calculate_ess( + err_t_median, grad_evals_per_step=avg_grad_calls_per_traj + ) + return ( + esses, + grad_calls, + params, + jnp.mean(acceptance_rate, axis=0), + step_size_over_da, + ) def run_benchmarks(batch_size): - results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [Brownian()], # [Brownian()], # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): sampler, model, coefficients = variables - num_chains = batch_size#1 + batch_size//model.ndims - + num_chains = batch_size # 1 + batch_size//model.ndims num_steps = 100000 sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + 
batch_size//model.ndims - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) contract = jnp.max key = jax.random.PRNGKey(11) for i in range(1): key1, key = jax.random.split(key) - ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients, frac_tune1=0.1, frac_tune2=0.0, frac_tune3=0.0),key1, n=num_steps, batch=num_chains, contract=contract) + ( + ess, + grad_calls, + params, + acceptance_rate, + step_size_over_da, + ) = benchmark_chains( + model, + partial( + samplers[sampler], + coefficients=coefficients, + frac_tune1=0.1, + frac_tune2=0.0, + frac_tune3=0.0, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) jax.numpy.save(f"acceptance.npy", acceptance_rate) - # print(f"grads to low bias: {grad_calls}") # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + results[ + ( + (model.name, model.ndims), + sampler, + name_integrator(coefficients), + "standard", + acceptance_rate.mean().item(), + params.L.mean().item(), + params.step_size.mean().item(), + num_chains, + num_steps, + contract, + ) + ] = ess.item() print(ess.item()) # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - # print(results) - df = pd.Series(results).reset_index() - df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] + df.columns = [ + "model", + "sampler", + "integrator", + "tuning", + "acc rate", + "L", + "stepsize", + "num_chains", + "num steps", + "contraction", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results_simple.csv", index=False) return results + # vary step_size def run_benchmarks_step_size(batch_size): - results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [StandardNormal(10)], # [Brownian()], # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - - - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): num_steps = 10000 sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) contract = jnp.average center = 6.534974 key = jax.random.PRNGKey(11) - for step_size in 
np.linspace(center-1,center+1, 41): - # for L in np.linspace(1, 10, 41): + for step_size in np.linspace(center - 1, center + 1, 41): + # for L in np.linspace(1, 10, 41): key1, key2, key3, key = jax.random.split(key, 4) initial_position = model.sample_init(key2) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=key3) - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(initial_state=initial_state, coefficients=mclachlan_coefficients, step_size=step_size, L= 5*step_size, std_mat=1.),key1, n=num_steps, batch=num_chains, contract=contract) + position=initial_position, + logdensity_fn=model.logdensity_fn, + random_generator_arg=key3, + ) + ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( + model, + run_mhmclmc_no_tuning( + initial_state=initial_state, + coefficients=mclachlan_coefficients, + step_size=step_size, + L=5 * step_size, + std_mat=1.0, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") # jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) # jax.numpy.save(f"acceptance.npy_{step_size}", acceptance_rate) - # print(f"grads to low bias: {grad_calls}") # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + results[ + ( + (model.name, model.ndims), + sampler, + name_integrator(coefficients), + "standard", + acceptance_rate.mean().item(), + params.L.mean().item(), + params.step_size.mean().item(), + num_chains, + num_steps, + contract, + ) + ] = ess.item() # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - # print(results) - df = pd.Series(results).reset_index() - df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] + df.columns = [ + "model", + "sampler", + "integrator", + "tuning", + "acc rate", + "L", + "stepsize", + "num_chains", + "num steps", + "contraction", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results_step_size.csv", index=False) @@ -313,17 +488,14 @@ def run_benchmarks_step_size(batch_size): return results - def benchmark_mhmchmc(batch_size): - key0, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(5), 4) results = defaultdict(tuple) # coefficients = [yoshida_coefficients, mclachlan_coefficients, velocity_verlet_coefficients, omelyan_coefficients] coefficients = [mclachlan_coefficients, velocity_verlet_coefficients] for model, coeffs in itertools.product(models, coefficients): - - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims print(f"NUMBER OF CHAINS for {model.name} and MHMCLMC is {num_chains}") num_steps = models[model]["mhmclmc"] print(f"NUMBER OF STEPS for {model.name} and MHCMLMC is {num_steps}") @@ -331,67 +503,123 @@ def benchmark_mhmchmc(batch_size): ####### run mclmc with standard tuning contract = jnp.max - - ess, grad_calls, params , _, step_size_over_da = benchmark_chains( + ess, grad_calls, params, _, step_size_over_da = benchmark_chains( model, - partial(run_mclmc,coefficients=coeffs), + 
partial(run_mclmc, coefficients=coeffs), key0, n=num_steps, batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mclmc", params.L.mean().item(), params.step_size.mean().item(), name_integrator(coeffs), "standard", 1.)] = ess.item() - print(f'mclmc with tuning ESS {ess}') - + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "mclmc", + params.L.mean().item(), + params.step_size.mean().item(), + name_integrator(coeffs), + "standard", + 1.0, + ) + ] = ess.item() + print(f"mclmc with tuning ESS {ess}") - ####### run mhmclmc with standard tuning + ####### run mhmclmc with standard tuning for target_acc_rate in [0.65, 0.9]: # coeffs = mclachlan_coefficients - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( - model, - partial(run_mhmclmc, target_acc_rate=target_acc_rate, coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0), - key1, - n=num_steps, - batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mhmchmc"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'mhmclmc with tuning ESS {ess}') - + ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( + model, + partial( + run_mhmclmc, + target_acc_rate=target_acc_rate, + coefficients=coeffs, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.0, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "mhmchmc" + str(target_acc_rate), + jnp.nanmean(params.L).item(), + jnp.nanmean(params.step_size).item(), + name_integrator(coeffs), + "standard", + acceptance_rate.mean().item(), + ) + ] = ess.item() + print(f"mhmclmc with tuning ESS {ess}") + # coeffs = mclachlan_coefficients - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( - model, - partial(run_mhmclmc, target_acc_rate=target_acc_rate,coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1), - key1, - n=num_steps, - batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mhmchmc:st3"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'mhmclmc with tuning ESS {ess}') + ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( + model, + partial( + run_mhmclmc, + target_acc_rate=target_acc_rate, + coefficients=coeffs, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "mhmchmc:st3" + str(target_acc_rate), + jnp.nanmean(params.L).item(), + jnp.nanmean(params.step_size).item(), + name_integrator(coeffs), + "standard", + acceptance_rate.mean().item(), + ) + ] = ess.item() + print(f"mhmclmc with tuning ESS {ess}") if True: ####### run mhmclmc with standard tuning + grid search - init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split(key2, 5) + init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split( + key2, 5 + ) initial_position = model.sample_init(init_pos_key) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key + position=initial_position, + logdensity_fn=model.logdensity_fn, + random_generator_arg=init_key, ) kernel = lambda rng_key, state, avg_num_integration_steps, step_size, 
std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coeffs), - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn) + integrator=generate_isokinetic_integrator(coeffs), + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn, + ) ( state, blackjax_mhmclmc_sampler_params, - _, _ + _, + _, ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -401,96 +629,165 @@ def benchmark_mhmchmc(batch_size): frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, - diagonal_preconditioning=False + diagonal_preconditioning=False, ) - print(f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}") - print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") - + print( + f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}" + ) + print( + f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" + ) - L, step_size, convergence = gridsearch_tune(grid_key, iterations=10, contract=contract, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coeffs, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) + L, step_size, convergence = gridsearch_tune( + grid_key, + iterations=10, + contract=contract, + grid_size=5, + model=model, + sampler=partial( + run_mhmclmc_no_tuning, + coefficients=coeffs, + initial_state=state, + std_mat=1.0, + ), + batch=num_chains, + num_steps=num_steps, + center_L=blackjax_mhmclmc_sampler_params.L, + center_step_size=blackjax_mhmclmc_sampler_params.step_size, + ) # print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coeffs, L=L, step_size=step_size, initial_state=state, std_mat=1.),bench_key, n=num_steps, batch=num_chains, contract=contract) + ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( + model, + run_mhmclmc_no_tuning( + coefficients=coeffs, + L=L, + step_size=step_size, + initial_state=state, + std_mat=1.0, + ), + bench_key, + n=num_steps, + batch=num_chains, + contract=contract, + ) print(f"grads to low bias: {grad_calls}") - results[(model.name, model.ndims, "mhmchmc:grid", L.item(), step_size.item(), name_integrator(coeffs), f"gridsearch:{convergence}", acceptance_rate.mean().item())] = ess.item() + results[ + ( + model.name, + model.ndims, + "mhmchmc:grid", + L.item(), + step_size.item(), + name_integrator(coeffs), + f"gridsearch:{convergence}", + acceptance_rate.mean().item(), + ) + ] = ess.item() ####### run nuts # coeffs = velocity_verlet_coefficients - ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, partial(run_nuts,coefficients=coeffs),key3, n=models[model]["nuts"], batch=num_chains, contract=contract) - results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - - - - - + ess, 
grad_calls, _, acceptance_rate, _ = benchmark_chains( + model, + partial(run_nuts, coefficients=coeffs), + key3, + n=models[model]["nuts"], + batch=num_chains, + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "nuts", + 0.0, + 0.0, + name_integrator(coeffs), + "standard", + acceptance_rate.mean().item(), + ) + ] = ess.item() - print(results) - df = pd.Series(results).reset_index() - df.columns = ["model", "dims", "sampler", "L", "step_size", "integrator", "tuning", "acc_rate", "ESS"] + df.columns = [ + "model", + "dims", + "sampler", + "L", + "step_size", + "integrator", + "tuning", + "acc_rate", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results.csv", index=False) return results -def benchmark_omelyan(batch_size): - +def benchmark_omelyan(batch_size): key = jax.random.PRNGKey(2) results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmchmc"], - [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int)], + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmchmc"], + [ + StandardNormal(d) + for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int) + ], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 5)).astype(int)], # models, - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients, omelyan_coefficients], - ): - - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients, omelyan_coefficients], + ): sampler, model, coefficients = variables # num_chains = 1 + batch_size//model.ndims num_chains = batch_size - current_key, key = jax.random.split(key) - init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split(current_key, 5) + current_key, key = jax.random.split(key) + init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split( + current_key, 5 + ) # num_steps = models[model][sampler] num_steps = 1000 - initial_position = model.sample_init(init_pos_key) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key + position=initial_position, + logdensity_fn=model.logdensity_fn, + random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coefficients), - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn) + integrator=generate_isokinetic_integrator(coefficients), + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn, + ) ( state, blackjax_mhmclmc_sampler_params, - _, _ + _, + _, ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -500,49 +797,106 @@ def benchmark_omelyan(batch_size): frac_tune1=0.1, frac_tune2=0.1, # frac_tune3=0.1, - diagonal_preconditioning=False + diagonal_preconditioning=False, ) - print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: 
{coefficients}\nNumber of chains {num_chains}",) - print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") + print( + f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}", + ) + print( + f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" + ) # ess, grad_calls, _ , _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) - # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) - - L, step_size, converged = gridsearch_tune(grid_key, iterations=10, contract=jnp.average, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coefficients, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) + # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) + + L, step_size, converged = gridsearch_tune( + grid_key, + iterations=10, + contract=jnp.average, + grid_size=5, + model=model, + sampler=partial( + run_mhmclmc_no_tuning, + coefficients=coefficients, + initial_state=state, + std_mat=1.0, + ), + batch=num_chains, + num_steps=num_steps, + center_L=blackjax_mhmclmc_sampler_params.L, + center_step_size=blackjax_mhmclmc_sampler_params.step_size, + ) print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _ , _, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=L, step_size=step_size, std_mat=1., initial_state=state),bench_key, n=num_steps, batch=num_chains, contract=jnp.average) + ess, grad_calls, _, _, _ = benchmark_chains( + model, + run_mhmclmc_no_tuning( + coefficients=coefficients, + L=L, + step_size=step_size, + std_mat=1.0, + initial_state=state, + ), + bench_key, + n=num_steps, + batch=num_chains, + contract=jnp.average, + ) print(f"grads to low bias: {grad_calls}") - results[(model.name, model.ndims, sampler, name_integrator(coefficients), converged, L.item(), step_size.item())] = ess.item() + results[ + ( + model.name, + model.ndims, + sampler, + name_integrator(coefficients), + converged, + L.item(), + step_size.item(), + ) + ] = ess.item() df = pd.Series(results).reset_index() - df.columns = ["model", "dims", "sampler", "integrator", "convergence", "L", "step_size", "ESS"] + df.columns = [ + "model", + "dims", + "sampler", + "integrator", + "convergence", + "L", + "step_size", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("omelyan.csv", index=False) def run_benchmarks_divij(): - sampler = run_mclmc - model = StandardNormal(10) # 10 dimensional standard normal + model = StandardNormal(10) # 10 dimensional standard normal coefficients = mclachlan_coefficients - contract = jnp.average # how we average across dimensions + contract = jnp.average # how we average across dimensions num_steps = 2000 num_chains = 100 key1 = jax.random.PRNGKey(2) - ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(sampler, coefficients=coefficients),key1, 
n=num_steps, batch=num_chains, contract=contract) + ess, grad_calls, params, acceptance_rate, step_size_over_da = benchmark_chains( + model, + partial(sampler, coefficients=coefficients), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) print(f"Effective Sample Size (ESS) of 10D Normal is {ess}") -if __name__ == "__main__": +if __name__ == "__main__": # run_benchmarks_divij() benchmark_mhmchmc(batch_size=128) @@ -550,8 +904,5 @@ def run_benchmarks_divij(): # run_benchmarks_step_size(128) benchmark_omelyan(128) # run_benchmarks(128) - #benchmark_omelyan(10) + # benchmark_omelyan(10) # print("4") - - - diff --git a/blackjax/util.py b/blackjax/util.py index 2b52ca5ae..e579c126d 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -246,4 +246,4 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): ) total += weight streaming_avg = (total, unravel_fn(average)) - return streaming_avg \ No newline at end of file + return streaming_avg diff --git a/explore.py b/explore.py deleted file mode 100644 index e97458051..000000000 --- a/explore.py +++ /dev/null @@ -1,61 +0,0 @@ -import jax -import jax.numpy as jnp - -import blackjax -from blackjax.util import run_inference_algorithm - -init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3) - - -def logdensity_fn(x): - return -0.5 * jnp.sum(jnp.square(x)) - - -initial_position = jnp.ones( - 10, -) - - -def run_mclmc(logdensity_fn, key, num_steps, initial_position): - init_key, tune_key, run_key = jax.random.split(key, 3) - - initial_state = blackjax.mcmc.mclmc.init( - position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key - ) - - alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1, std_mat=1.) - - average, states = run_inference_algorithm( - rng_key=run_key, - initial_state=initial_state, - inference_algorithm=alg, - num_steps=num_steps, - progress_bar=True, - transform=lambda x: x.position, - streaming=True, - ) - - print(average) - - _, states, _ = run_inference_algorithm( - rng_key=run_key, - initial_state=initial_state, - inference_algorithm=alg, - num_steps=num_steps, - progress_bar=False, - transform=lambda x: x.position, - streaming=False, - ) - - print(states.mean(axis=0)) - - return states - - -# out = run_hmc(initial_position) -out = run_mclmc( - logdensity_fn=logdensity_fn, - num_steps=5, - initial_position=initial_position, - key=jax.random.PRNGKey(0), -) From 2d3c3fc3d002c4bb7fd700d26be502828f328a2c Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 21:48:43 +0200 Subject: [PATCH 25/71] FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS --- blackjax/mcmc/integrators.py | 80 ++++++++++++++++++++++++++++------ blackjax/mcmc/trajectory.py | 2 +- tests/mcmc/test_integrators.py | 8 ++-- 3 files changed, 71 insertions(+), 19 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 0f4deeca4..5a2f71838 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -23,13 +23,22 @@ from blackjax.types import ArrayTree __all__ = [ + "velocity_verlet_coefficients", + "mclachlan_coefficients", + "yoshida_coefficients", + "omelyan_coefficients", "mclachlan", + "omelyan", "velocity_verlet", "yoshida", - "implicit_midpoint", - "isokinetic_leapfrog", + "with_isokinetic_maruyama", + "isokinetic_velocity_verlet", "isokinetic_mclachlan", + "isokinetic_omelyan", "isokinetic_yoshida", + "implicit_midpoint", + "calls_per_integrator_step", + "name_integrator", ] @@ -70,7 +79,7 @@ def 
generalized_two_stage_integrator( .. math:: \\frac{d}{dt}f = (O_1+O_2)f - The leapfrog operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}` + The velocity_verlet operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}` by :math:`e^{\\epsilon O_1/2}e^{\\epsilon O_2}e^{\\epsilon O_1/2}`. In a standard Hamiltonian, the forms of :math:`e^{\\epsilon O_2}` and @@ -210,7 +219,7 @@ def format_euclidean_state_output( return IntegratorState(position, momentum, logdensity, logdensity_grad) -def generate_euclidean_integrator(cofficients): +def generate_euclidean_integrator(coefficients): """Generate symplectic integrator for solving a Hamiltonian system. The resulting integrator is volume-preserve and preserves the symplectic structure @@ -225,7 +234,7 @@ def euclidean_integrator( one_step = generalized_two_stage_integrator( momentum_update_fn, position_update_fn, - cofficients, + coefficients, format_output_fn=format_euclidean_state_output, ) return one_step @@ -251,8 +260,8 @@ def euclidean_integrator( of the kinetic energy. We are trading accuracy in exchange, and it is not clear whether this is the right tradeoff. """ -velocity_verlet_cofficients = [0.5, 1.0, 0.5] -velocity_verlet = generate_euclidean_integrator(velocity_verlet_cofficients) +velocity_verlet_coefficients = [0.5, 1.0, 0.5] +velocity_verlet = generate_euclidean_integrator(velocity_verlet_coefficients) """ Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. @@ -268,8 +277,8 @@ def euclidean_integrator( b1 = 0.1931833275037836 a1 = 0.5 b2 = 1 - 2 * b1 -mclachlan_cofficients = [b1, a1, b2, a1, b1] -mclachlan = generate_euclidean_integrator(mclachlan_cofficients) +mclachlan_coefficients = [b1, a1, b2, a1, b1] +mclachlan = generate_euclidean_integrator(mclachlan_coefficients) """ Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` @@ -284,8 +293,20 @@ def euclidean_integrator( a1 = 0.29619504261126 b2 = 0.5 - b1 a2 = 1 - 2 * a1 -yoshida_cofficients = [b1, a1, b2, a2, b2, a1, b1] -yoshida = generate_euclidean_integrator(yoshida_cofficients) +yoshida_coefficients = [b1, a1, b2, a2, b2, a1, b1] +yoshida = generate_euclidean_integrator(yoshida_coefficients) + +"""11 stage Omelyan integrator [I.P. Omelyan, I.M. Mryglod and R. Folk, Comput. Phys. Commun. 151 (2003) 272.], +4MN5FV in [Takaishi, Tetsuya, and Philippe De Forcrand. "Testing and tuning symplectic integrators for the hybrid Monte Carlo algorithm in lattice QCD." Physical Review E 73.3 (2006): 036706.] 
+popular in LQCD""" +b1 = 0.08398315262876693 +a1 = 0.2539785108410595 +b2 = 0.6822365335719091 +a2 = -0.03230286765269967 +b3 = 0.5 - b1 - b2 +a3 = 1 - 2 * (a1 + a2) +omelyan_coefficients = [b1, a1, b2, a2, b3, a3, b3, a2, b2, a1, b1] +omelyan = generate_euclidean_integrator(omelyan_coefficients) # Intergrators with non Euclidean updates @@ -371,9 +392,12 @@ def isokinetic_integrator( return isokinetic_integrator -isokinetic_leapfrog = generate_isokinetic_integrator(velocity_verlet_cofficients) -isokinetic_yoshida = generate_isokinetic_integrator(yoshida_cofficients) -isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_cofficients) +isokinetic_velocity_verlet = generate_isokinetic_integrator( + velocity_verlet_coefficients +) +isokinetic_yoshida = generate_isokinetic_integrator(yoshida_coefficients) +isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_coefficients) +isokinetic_omelyan = generate_isokinetic_integrator(omelyan_coefficients) def partially_refresh_momentum(momentum, rng_key, step_size, L): @@ -429,6 +453,34 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): return stochastic_integrator +def calls_per_integrator_step(c): + if c == velocity_verlet_coefficients: + return 1 + if c == mclachlan_coefficients: + return 2 + if c == yoshida_coefficients: + return 3 + if c == omelyan_coefficients: + return 5 + + else: + raise Exception + + +def name_integrator(c): + if c == velocity_verlet_coefficients: + return "velocity_verlet" + if c == mclachlan_coefficients: + return "mclachlan" + if c == yoshida_coefficients: + return "yoshida" + if c == omelyan_coefficients: + return "omelyan" + + else: + raise Exception + + FixedPointSolver = Callable[ [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree], Tuple[ArrayTree, ArrayTree, Any], diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 85891bda6..7bb1b35a5 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -357,7 +357,7 @@ def buildtree_integrate( """ if tree_depth == 0: - # Base case - take one leapfrog step in the direction v. + # Base case - take one velocity_verlet step in the direction v. 
next_state = integrator(initial_state, direction * step_size) new_proposal = generate_proposal(initial_energy, next_state) is_diverging = -new_proposal.weight > divergence_threshold diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index ddb13ad57..345290bfa 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -140,7 +140,7 @@ def kinetic_energy(p, position=None): "algorithm": integrators.implicit_midpoint, "precision": 1e-4, }, - "isokinetic_leapfrog": {"algorithm": integrators.isokinetic_leapfrog}, + "isokinetic_velocity_verlet": {"algorithm": integrators.isokinetic_velocity_verlet}, "isokinetic_mclachlan": {"algorithm": integrators.isokinetic_mclachlan}, "isokinetic_yoshida": {"algorithm": integrators.isokinetic_yoshida}, } @@ -239,13 +239,13 @@ def test_esh_momentum_update(self, dims): np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @chex.all_variants(with_pmap=False) - def test_isokinetic_leapfrog(self): + def test_isokinetic_velocity_verlet(self): cov = jnp.asarray([[1.0, 0.5, 0.1], [0.5, 2.0, -0.1], [0.1, -0.1, 3.0]]) logdensity_fn = lambda x: stats.multivariate_normal.logpdf( x, jnp.zeros([3]), cov ) - step = self.variant(integrators.isokinetic_leapfrog(logdensity_fn)) + step = self.variant(integrators.isokinetic_velocity_verlet(logdensity_fn)) rng = jax.random.key(4263456) key0, key1 = jax.random.split(rng, 2) @@ -294,7 +294,7 @@ def test_isokinetic_leapfrog(self): @chex.all_variants(with_pmap=False) @parameterized.parameters( [ - "isokinetic_leapfrog", + "isokinetic_velocity_verlet", "isokinetic_mclachlan", "isokinetic_yoshida", ], From dad0060dbcc44c9c44eb3acccb403331a0b5b0a3 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 22:18:25 +0200 Subject: [PATCH 26/71] TEMPORARILY ADD BENCHMARKS --- blackjax/benchmarks/mcmc/benchmark.py | 823 +++++----------- .../ground_truth/brownian/ground_truth.npy | Bin 0 -> 384 bytes blackjax/benchmarks/mcmc/inference_models.py | 892 ++++++++++++++++++ .../benchmarks/mcmc/sampling_algorithms.py | 188 ++++ blackjax/mcmc/integrators.py | 13 +- 5 files changed, 1340 insertions(+), 576 deletions(-) create mode 100644 blackjax/benchmarks/mcmc/ground_truth/brownian/ground_truth.npy create mode 100644 blackjax/benchmarks/mcmc/inference_models.py create mode 100644 blackjax/benchmarks/mcmc/sampling_algorithms.py diff --git a/blackjax/benchmarks/mcmc/benchmark.py b/blackjax/benchmarks/mcmc/benchmark.py index 9eadc7e2f..174cd30f7 100644 --- a/blackjax/benchmarks/mcmc/benchmark.py +++ b/blackjax/benchmarks/mcmc/benchmark.py @@ -1,26 +1,21 @@ # mypy: ignore-errors # flake8: noqa +from collections import defaultdict +from functools import partial import math import operator import os import pprint -from collections import defaultdict -from functools import partial from statistics import mean, median - import jax import jax.numpy as jnp import pandas as pd import scipy -from blackjax.adaptation.mclmc_adaptation import ( - MCLMCAdaptationState, - integrator_order, - target_acceptance_rate_of_order, -) +from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=" + str(128) +os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=' + str(128) num_cores = jax.local_device_count() # print(num_cores, jax.lib.xla_bridge.get_backend().platform) @@ -29,77 +24,48 @@ import numpy as np import blackjax -from blackjax.benchmarks.mcmc.inference_models import ( - Brownian, - GermanCredit, - 
ItemResponseTheory, - MixedLogit, - StandardNormal, - StochasticVolatility, - models, -) -from blackjax.benchmarks.mcmc.sampling_algorithms import ( - run_mclmc, - run_mhmclmc, - run_nuts, - samplers, -) -from blackjax.mcmc.integrators import ( - calls_per_integrator_step, - generate_euclidean_integrator, - generate_isokinetic_integrator, - isokinetic_mclachlan, - mclachlan_coefficients, - name_integrator, - omelyan_coefficients, - velocity_verlet, - velocity_verlet_coefficients, - yoshida_coefficients, -) -from blackjax.mcmc.mhmclmc import rescale +from blackjax.benchmarks.mcmc.sampling_algorithms import run_mclmc, run_mhmclmc, run_nuts, samplers +from blackjax.benchmarks.mcmc.inference_models import Brownian, GermanCredit, ItemResponseTheory, MixedLogit, StandardNormal, StochasticVolatility, models +from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator, integrator_order, isokinetic_mclachlan, mclachlan_coefficients, name_integrator, omelyan_coefficients, velocity_verlet, velocity_verlet_coefficients, yoshida_coefficients +# from blackjax.mcmc.mhmclmc import rescale from blackjax.util import run_inference_algorithm +target_acceptance_rate_of_order = {2 : 0.65, 4: 0.8} def get_num_latents(target): - return target.ndims - - + return target.ndims # return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0])))) def err(f_true, var_f, contract): """Computes the error b^2 = (f - f_true)^2 / var_f - Args: - f: E_sampler[f(x)], can be a vector - f_true: E_true[f(x)] - var_f: Var_true[f(x)] - contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max - - Returns: - contract(b^2) - """ - + Args: + f: E_sampler[f(x)], can be a vector + f_true: E_true[f(x)] + var_f: Var_true[f(x)] + contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max + + Returns: + contract(b^2) + """ + return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) -def grads_to_low_error(err_t, grad_evals_per_step=1, low_error=0.01): - """Uses the error of the expectation values to compute the effective sample size neff - b^2 = 1/neff""" +def grads_to_low_error(err_t, grad_evals_per_step= 1, low_error= 0.01): + """Uses the error of the expectation values to compute the effective sample size neff + b^2 = 1/neff""" + cutoff_reached = err_t[-1] < low_error return find_crossing(err_t, low_error) * grad_evals_per_step, cutoff_reached - - -def calculate_ess(err_t, grad_evals_per_step, neff=100): - grads_to_low, cutoff_reached = grads_to_low_error( - err_t, grad_evals_per_step, 1.0 / neff - ) - - return ( - (neff / grads_to_low) * cutoff_reached, - grads_to_low * (1 / cutoff_reached), - cutoff_reached, - ) + + +def calculate_ess(err_t, grad_evals_per_step, neff= 100): + + grads_to_low, cutoff_reached = grads_to_low_error(err_t, grad_evals_per_step, 1./neff) + + return (neff / grads_to_low) * cutoff_reached, grads_to_low*(1/cutoff_reached), cutoff_reached def find_crossing(array, cutoff): @@ -111,61 +77,34 @@ def find_crossing(array, cutoff): print("\n\n\nNO CROSSING FOUND!!!\n\n\n", array, cutoff) return 1 - return jnp.max(indices) + 1 + return jnp.max(indices)+1 def cumulative_avg(samples): - return jnp.cumsum(samples, axis=0) / jnp.arange(1, samples.shape[0] + 1)[:, None] - - -def gridsearch_tune( - key, - iterations, - grid_size, - model, - sampler, - batch, - num_steps, - center_L, - center_step_size, - contract, -): + return jnp.cumsum(samples, axis = 0) / 
jnp.arange(1, samples.shape[0] + 1)[:, None] + + +def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps, center_L, center_step_size, contract): results = defaultdict(float) converged = False - keys = jax.random.split(key, iterations + 1) + keys = jax.random.split(key, iterations+1) for i in range(iterations): print(f"EPOCH {i}") width = 2 - step_sizes = np.logspace( - np.log10(center_step_size / width), - np.log10(center_step_size * width), - grid_size, - ) - Ls = np.logspace(np.log10(center_L / 2), np.log10(center_L * 2), grid_size) + step_sizes = np.logspace(np.log10(center_step_size/width), np.log10(center_step_size*width), grid_size) + Ls = np.logspace(np.log10(center_L/2), np.log10(center_L*2),grid_size) # print(list(itertools.product(step_sizes , Ls))) - grid_keys = jax.random.split(keys[i], grid_size ^ 2) + grid_keys = jax.random.split(keys[i], grid_size^2) print(f"center step size {center_step_size}, center L {center_L}") - for j, (step_size, L) in enumerate(itertools.product(step_sizes, Ls)): - ess, grad_calls_until_convergence, _, _, _ = benchmark_chains( - model, - sampler(step_size=step_size, L=L), - grid_keys[j], - n=num_steps, - batch=batch, - contract=contract, - ) + for j, (step_size, L) in enumerate(itertools.product(step_sizes , Ls)): + ess, grad_calls_until_convergence, _ , _, _ = benchmark_chains(model, sampler(step_size=step_size, L=L), grid_keys[j], n=num_steps, batch = batch, contract=contract) results[(step_size, L)] = (ess, grad_calls_until_convergence) - best_ess, best_grads, (step_size, L) = max( - ((results[r][0], results[r][1], r) for r in results), - key=operator.itemgetter(0), - ) + best_ess, best_grads, (step_size, L) = max([(results[r][0], results[r][1], r) for r in results], key=operator.itemgetter(0)) # raise Exception - print( - f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}" - ) - if L == center_L and step_size == center_step_size: + print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") + if L==center_L and step_size==center_step_size: print("converged") converged = True break @@ -173,67 +112,42 @@ def gridsearch_tune( center_L, center_step_size = L, step_size pprint.pp(results) - # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") + # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") + # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") return center_L, center_step_size, converged def run_mhmclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): + def s(logdensity_fn, num_steps, initial_position, transform, key): + integrator = generate_isokinetic_integrator(coefficients) - num_steps_per_traj = L / step_size + num_steps_per_traj = L/step_size alg = blackjax.mcmc.mhmclmc.mhmclmc( - logdensity_fn=logdensity_fn, - step_size=step_size, - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(num_steps_per_traj) - ), - integrator=integrator, - std_mat=std_mat, + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(num_steps_per_traj)) , + integrator=integrator, + std_mat=std_mat, ) _, out, info = run_inference_algorithm( - rng_key=key, - 
initial_state=initial_state, - inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True, - ) + rng_key=key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True) - return ( - out, - MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), - num_steps_per_traj * calls_per_integrator_step(coefficients), - info.acceptance_rate.mean(), - None, - jnp.array([0]), - ) + return out, MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), num_steps_per_traj * calls_per_integrator_step(coefficients), info.acceptance_rate.mean(), None, jnp.array([0]) return s +def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.average,): -def benchmark_chains( - model, - sampler, - key, - n=10000, - batch=None, - contract=jnp.average, -): pvmap = jax.pmap - - # def pvmap(f): - # def f(arr): - # return arr - # print(arr.shape,"shape") - # print(arr) - # arr = arr.reshape(128, -1) - # out = jax.vmap(jax.vmap(f), in_axes=0)(arr) - # return out.flatten() - # return f - + d = get_num_latents(model) if batch is None: batch = np.ceil(1000 / d).astype(int) @@ -244,31 +158,13 @@ def benchmark_chains( init_pos = pvmap(model.sample_init)(init_keys) # samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys) - ( - samples, - params, - grad_calls_per_traj, - acceptance_rate, - step_size_over_da, - final_da, - ) = pvmap( - lambda pos, key: sampler( - logdensity_fn=model.logdensity_fn, - num_steps=n, - initial_position=pos, - transform=model.transform, - key=key, - ) - )( - init_pos, keys - ) + samples, params, grad_calls_per_traj, acceptance_rate, step_size_over_da, final_da = pvmap(lambda pos, key: sampler(logdensity_fn=model.logdensity_fn, num_steps=n, initial_position= pos,transform= model.transform, key=key))(init_pos, keys) avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0) try: - print(jnp.nanmean(params.step_size, axis=0), jnp.nanmean(params.L, axis=0)) - except: - pass - - full = lambda arr: err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) + print(jnp.nanmean(params.step_size,axis=0), jnp.nanmean(params.L,axis=0)) + except: pass + + full = lambda arr : err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) err_t = pvmap(full)(samples**2) # outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t] @@ -278,6 +174,7 @@ def benchmark_chains( # return(mean(esses), mean(grad_calls)) # print(final_da.mean(), "final da") + err_t_median = jnp.median(err_t, axis=0) # import matplotlib.pyplot as plt # plt.plot(np.arange(1, 1+ len(err_t_median))* 2, err_t_median, color= 'teal', lw = 3) @@ -287,106 +184,62 @@ def benchmark_chains( # plt.yscale('log') # plt.savefig('brownian.png') # plt.close() - esses, grad_calls, _ = calculate_ess( - err_t_median, grad_evals_per_step=avg_grad_calls_per_traj - ) - return ( - esses, - grad_calls, - params, - jnp.mean(acceptance_rate, axis=0), - step_size_over_da, - ) + esses, grad_calls, _ = calculate_ess(err_t_median, grad_evals_per_step=avg_grad_calls_per_traj) + return esses, grad_calls, params, jnp.mean(acceptance_rate, axis=0), step_size_over_da + + def run_benchmarks(batch_size): + results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmclmc"], # 
[StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [Brownian()], # [Brownian()], # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): + sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size#1 + batch_size//model.ndims + num_steps = 100000 sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) contract = jnp.max key = jax.random.PRNGKey(11) for i in range(1): key1, key = jax.random.split(key) - ( - ess, - grad_calls, - params, - acceptance_rate, - step_size_over_da, - ) = benchmark_chains( - model, - partial( - samplers[sampler], - coefficients=coefficients, - frac_tune1=0.1, - frac_tune2=0.0, - frac_tune3=0.0, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) + ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients, frac_tune1=0.1, frac_tune2=0.0, frac_tune3=0.0),key1, n=num_steps, batch=num_chains, contract=contract) # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) jax.numpy.save(f"acceptance.npy", acceptance_rate) + # print(f"grads to low bias: {grad_calls}") # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - results[ - ( - (model.name, model.ndims), - sampler, - name_integrator(coefficients), - "standard", - acceptance_rate.mean().item(), - params.L.mean().item(), - params.step_size.mean().item(), - num_chains, - num_steps, - contract, - ) - ] = ess.item() + results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() print(ess.item()) # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] + # print(results) + df = pd.Series(results).reset_index() - df.columns = [ - "model", - "sampler", - "integrator", - "tuning", - "acc rate", - "L", - "stepsize", - "num_chains", - "num steps", - "contraction", - "ESS", - ] + df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results_simple.csv", index=False) @@ -394,93 +247,92 @@ def run_benchmarks(batch_size): return results +def run_simple(): + + results = defaultdict(tuple) + for variables in itertools.product( + # ["mhmclmc", "nuts", "mclmc", ], + ["mclmc"], + # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], + [Brownian()], + # [Brownian()], + # [Brownian()], + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + 
[mclachlan_coefficients], + ): + + sampler, model, coefficients = variables + num_chains = 128 + + num_steps = 10000 + + contract = jnp.max + + key = jax.random.PRNGKey(11) + for i in range(1): + key1, key = jax.random.split(key) + ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients),key1, n=num_steps, batch=num_chains, contract=contract) + + + results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + print(ess.item()) + + + return results + # vary step_size def run_benchmarks_step_size(batch_size): + results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [StandardNormal(10)], # [Brownian()], # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): + + + num_steps = 10000 sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) contract = jnp.average center = 6.534974 key = jax.random.PRNGKey(11) - for step_size in np.linspace(center - 1, center + 1, 41): - # for L in np.linspace(1, 10, 41): + for step_size in np.linspace(center-1,center+1, 41): + # for L in np.linspace(1, 10, 41): key1, key2, key3, key = jax.random.split(key, 4) initial_position = model.sample_init(key2) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=model.logdensity_fn, - random_generator_arg=key3, - ) - ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( - model, - run_mhmclmc_no_tuning( - initial_state=initial_state, - coefficients=mclachlan_coefficients, - step_size=step_size, - L=5 * step_size, - std_mat=1.0, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) + position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=key3) + ess, grad_calls, params , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(initial_state=initial_state, coefficients=mclachlan_coefficients, step_size=step_size, L= 5*step_size, std_mat=1.),key1, n=num_steps, batch=num_chains, contract=contract) # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") # jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) # jax.numpy.save(f"acceptance.npy_{step_size}", acceptance_rate) + # print(f"grads to low bias: {grad_calls}") # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - results[ - ( - (model.name, model.ndims), - sampler, - name_integrator(coefficients), - "standard", - acceptance_rate.mean().item(), - params.L.mean().item(), - params.step_size.mean().item(), - num_chains, - num_steps, - contract, - ) - ] = ess.item() + results[((model.name, 
model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] + # print(results) + df = pd.Series(results).reset_index() - df.columns = [ - "model", - "sampler", - "integrator", - "tuning", - "acc rate", - "L", - "stepsize", - "num_chains", - "num steps", - "contraction", - "ESS", - ] + df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results_step_size.csv", index=False) @@ -488,14 +340,17 @@ def run_benchmarks_step_size(batch_size): return results + def benchmark_mhmchmc(batch_size): + key0, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(5), 4) results = defaultdict(tuple) # coefficients = [yoshida_coefficients, mclachlan_coefficients, velocity_verlet_coefficients, omelyan_coefficients] coefficients = [mclachlan_coefficients, velocity_verlet_coefficients] for model, coeffs in itertools.product(models, coefficients): - num_chains = batch_size # 1 + batch_size//model.ndims + + num_chains = batch_size # 1 + batch_size//model.ndims print(f"NUMBER OF CHAINS for {model.name} and MHMCLMC is {num_chains}") num_steps = models[model]["mhmclmc"] print(f"NUMBER OF STEPS for {model.name} and MHCMLMC is {num_steps}") @@ -503,123 +358,67 @@ def benchmark_mhmchmc(batch_size): ####### run mclmc with standard tuning contract = jnp.max + - ess, grad_calls, params, _, step_size_over_da = benchmark_chains( + ess, grad_calls, params , _, step_size_over_da = benchmark_chains( model, - partial(run_mclmc, coefficients=coeffs), + partial(run_mclmc,coefficients=coeffs), key0, n=num_steps, batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "mclmc", - params.L.mean().item(), - params.step_size.mean().item(), - name_integrator(coeffs), - "standard", - 1.0, - ) - ] = ess.item() - print(f"mclmc with tuning ESS {ess}") + contract=contract) + results[(model.name, model.ndims, "mclmc", params.L.mean().item(), params.step_size.mean().item(), name_integrator(coeffs), "standard", 1.)] = ess.item() + print(f'mclmc with tuning ESS {ess}') + - ####### run mhmclmc with standard tuning + ####### run mhmclmc with standard tuning for target_acc_rate in [0.65, 0.9]: # coeffs = mclachlan_coefficients - ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( - model, - partial( - run_mhmclmc, - target_acc_rate=target_acc_rate, - coefficients=coeffs, - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.0, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "mhmchmc" + str(target_acc_rate), - jnp.nanmean(params.L).item(), - jnp.nanmean(params.step_size).item(), - name_integrator(coeffs), - "standard", - acceptance_rate.mean().item(), - ) - ] = ess.item() - print(f"mhmclmc with tuning ESS {ess}") - + ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( + model, + partial(run_mhmclmc, target_acc_rate=target_acc_rate, coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0), + key1, + n=num_steps, + batch=num_chains, + contract=contract) + results[(model.name, model.ndims, "mhmchmc"+str(target_acc_rate), jnp.nanmean(params.L).item(), 
jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() + print(f'mhmclmc with tuning ESS {ess}') + # coeffs = mclachlan_coefficients - ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( - model, - partial( - run_mhmclmc, - target_acc_rate=target_acc_rate, - coefficients=coeffs, - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.1, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "mhmchmc:st3" + str(target_acc_rate), - jnp.nanmean(params.L).item(), - jnp.nanmean(params.step_size).item(), - name_integrator(coeffs), - "standard", - acceptance_rate.mean().item(), - ) - ] = ess.item() - print(f"mhmclmc with tuning ESS {ess}") + ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( + model, + partial(run_mhmclmc, target_acc_rate=target_acc_rate,coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1), + key1, + n=num_steps, + batch=num_chains, + contract=contract) + results[(model.name, model.ndims, "mhmchmc:st3"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() + print(f'mhmclmc with tuning ESS {ess}') if True: ####### run mhmclmc with standard tuning + grid search - init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split( - key2, 5 - ) + init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split(key2, 5) initial_position = model.sample_init(init_pos_key) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=model.logdensity_fn, - random_generator_arg=init_key, + position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key ) kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coeffs), - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(avg_num_integration_steps) - ), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn, - ) + integrator=generate_isokinetic_integrator(coeffs), + integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn) ( state, blackjax_mhmclmc_sampler_params, - _, - _, + _, _ ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -629,165 +428,96 @@ def benchmark_mhmchmc(batch_size): frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, - diagonal_preconditioning=False, + diagonal_preconditioning=False ) - print( - f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}" - ) - print( - f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" - ) + print(f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}") + print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") - L, step_size, convergence = gridsearch_tune( - grid_key, - iterations=10, - contract=contract, - grid_size=5, - model=model, - sampler=partial( - run_mhmclmc_no_tuning, - coefficients=coeffs, - initial_state=state, - 
std_mat=1.0, - ), - batch=num_chains, - num_steps=num_steps, - center_L=blackjax_mhmclmc_sampler_params.L, - center_step_size=blackjax_mhmclmc_sampler_params.step_size, - ) + + L, step_size, convergence = gridsearch_tune(grid_key, iterations=10, contract=contract, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coeffs, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) # print(f"params after grid tuning are L={L}, step_size={step_size}") - ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( - model, - run_mhmclmc_no_tuning( - coefficients=coeffs, - L=L, - step_size=step_size, - initial_state=state, - std_mat=1.0, - ), - bench_key, - n=num_steps, - batch=num_chains, - contract=contract, - ) + + ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coeffs, L=L, step_size=step_size, initial_state=state, std_mat=1.),bench_key, n=num_steps, batch=num_chains, contract=contract) print(f"grads to low bias: {grad_calls}") - results[ - ( - model.name, - model.ndims, - "mhmchmc:grid", - L.item(), - step_size.item(), - name_integrator(coeffs), - f"gridsearch:{convergence}", - acceptance_rate.mean().item(), - ) - ] = ess.item() + results[(model.name, model.ndims, "mhmchmc:grid", L.item(), step_size.item(), name_integrator(coeffs), f"gridsearch:{convergence}", acceptance_rate.mean().item())] = ess.item() ####### run nuts # coeffs = velocity_verlet_coefficients - ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( - model, - partial(run_nuts, coefficients=coeffs), - key3, - n=models[model]["nuts"], - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "nuts", - 0.0, - 0.0, - name_integrator(coeffs), - "standard", - acceptance_rate.mean().item(), - ) - ] = ess.item() + ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, partial(run_nuts,coefficients=coeffs),key3, n=models[model]["nuts"], batch=num_chains, contract=contract) + results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() + + + + + + print(results) + df = pd.Series(results).reset_index() - df.columns = [ - "model", - "dims", - "sampler", - "L", - "step_size", - "integrator", - "tuning", - "acc_rate", - "ESS", - ] + df.columns = ["model", "dims", "sampler", "L", "step_size", "integrator", "tuning", "acc_rate", "ESS"] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results.csv", index=False) return results - def benchmark_omelyan(batch_size): + + key = jax.random.PRNGKey(2) results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmchmc"], - [ - StandardNormal(d) - for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int) - ], + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmchmc"], + [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int)], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 5)).astype(int)], # models, - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients, omelyan_coefficients], - ): + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients, omelyan_coefficients], 
+ ): + + sampler, model, coefficients = variables # num_chains = 1 + batch_size//model.ndims num_chains = batch_size - current_key, key = jax.random.split(key) - init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split( - current_key, 5 - ) + current_key, key = jax.random.split(key) + init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split(current_key, 5) # num_steps = models[model][sampler] num_steps = 1000 + initial_position = model.sample_init(init_pos_key) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=model.logdensity_fn, - random_generator_arg=init_key, + position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key ) + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coefficients), - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(avg_num_integration_steps) - ), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn, - ) + integrator=generate_isokinetic_integrator(coefficients), + integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn) ( state, blackjax_mhmclmc_sampler_params, - _, - _, + _, _ ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -797,112 +527,57 @@ def benchmark_omelyan(batch_size): frac_tune1=0.1, frac_tune2=0.1, # frac_tune3=0.1, - diagonal_preconditioning=False, + diagonal_preconditioning=False ) - print( - f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}", - ) - print( - f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" - ) + print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") # ess, grad_calls, _ , _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) - # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) - - L, step_size, converged = gridsearch_tune( - grid_key, - iterations=10, - contract=jnp.average, - grid_size=5, - model=model, - sampler=partial( - run_mhmclmc_no_tuning, - coefficients=coefficients, - initial_state=state, - std_mat=1.0, - ), - batch=num_chains, - num_steps=num_steps, - center_L=blackjax_mhmclmc_sampler_params.L, - center_step_size=blackjax_mhmclmc_sampler_params.step_size, - ) + # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) + + L, step_size, converged = gridsearch_tune(grid_key, iterations=10, contract=jnp.average, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coefficients, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, 
center_step_size=blackjax_mhmclmc_sampler_params.step_size) print(f"params after grid tuning are L={L}, step_size={step_size}") - ess, grad_calls, _, _, _ = benchmark_chains( - model, - run_mhmclmc_no_tuning( - coefficients=coefficients, - L=L, - step_size=step_size, - std_mat=1.0, - initial_state=state, - ), - bench_key, - n=num_steps, - batch=num_chains, - contract=jnp.average, - ) + + ess, grad_calls, _ , _, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=L, step_size=step_size, std_mat=1., initial_state=state),bench_key, n=num_steps, batch=num_chains, contract=jnp.average) print(f"grads to low bias: {grad_calls}") - results[ - ( - model.name, - model.ndims, - sampler, - name_integrator(coefficients), - converged, - L.item(), - step_size.item(), - ) - ] = ess.item() + results[(model.name, model.ndims, sampler, name_integrator(coefficients), converged, L.item(), step_size.item())] = ess.item() df = pd.Series(results).reset_index() - df.columns = [ - "model", - "dims", - "sampler", - "integrator", - "convergence", - "L", - "step_size", - "ESS", - ] + df.columns = ["model", "dims", "sampler", "integrator", "convergence", "L", "step_size", "ESS"] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("omelyan.csv", index=False) def run_benchmarks_divij(): + sampler = run_mclmc - model = StandardNormal(10) # 10 dimensional standard normal + model = StandardNormal(10) # 10 dimensional standard normal coefficients = mclachlan_coefficients - contract = jnp.average # how we average across dimensions + contract = jnp.average # how we average across dimensions num_steps = 2000 num_chains = 100 key1 = jax.random.PRNGKey(2) - ess, grad_calls, params, acceptance_rate, step_size_over_da = benchmark_chains( - model, - partial(sampler, coefficients=coefficients), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) + ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(sampler, coefficients=coefficients),key1, n=num_steps, batch=num_chains, contract=contract) print(f"Effective Sample Size (ESS) of 10D Normal is {ess}") - if __name__ == "__main__": + # run_benchmarks_divij() - benchmark_mhmchmc(batch_size=128) - # run_benchmarks(128) + + + # benchmark_mhmchmc(batch_size=128) + run_simple() # run_benchmarks_step_size(128) - benchmark_omelyan(128) + # benchmark_omelyan(128) # run_benchmarks(128) - # benchmark_omelyan(10) + #benchmark_omelyan(10) # print("4") diff --git a/blackjax/benchmarks/mcmc/ground_truth/brownian/ground_truth.npy b/blackjax/benchmarks/mcmc/ground_truth/brownian/ground_truth.npy new file mode 100644 index 0000000000000000000000000000000000000000..d381c47de0061e8dd351fc55551c143439c7381d GIT binary patch literal 384 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Its={nmP)#3giN=HsrpV z?PsMKwrgvr+xp7gw{>nVw_R9u!1h`Jv)z%e+IICTL+spj8toc0XWQwN@36aIe8cY9 z%-?oicLnVo88z%@vZ~vcx69bC->+j|`BcjOJPWTqQmVb6{~e;&sZ_vu(iHo&1bz(|CsfW zc2}FscMEJ@Gd{4{<1S=7;gOl`9=#CT(~POMrvnmfvsEo@|2c28k=oc~ldPO%W7pAP e<0&0(^RCq0CX~P0=KmsFn?;B9Y}9Xs*#H1}DT0^) literal 0 HcmV?d00001 diff --git a/blackjax/benchmarks/mcmc/inference_models.py b/blackjax/benchmarks/mcmc/inference_models.py new file mode 100644 index 000000000..b918ce3bf --- /dev/null +++ b/blackjax/benchmarks/mcmc/inference_models.py @@ -0,0 +1,892 @@ +# mypy: ignore-errors +# flake8: noqa + +#from inference_gym import using_jax as gym +import jax +import jax.numpy as jnp +import numpy as np 
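+# The benchmark code above consumes each model through a small informal
+# interface: ndims, name, E_x2, Var_x2, logdensity_fn, transform and
+# sample_init. A minimal sketch of that interface (the IsotropicGaussian
+# class and its sigma parameter are hypothetical, shown only to make the
+# expected attributes concrete):
+#
+#     class IsotropicGaussian:
+#         def __init__(self, d, sigma=1.0):
+#             self.ndims = d
+#             self.name = 'IsotropicGaussian'
+#             self.sigma = sigma
+#             self.E_x2 = sigma**2 * jnp.ones(d)       # E[x_i^2] under the target
+#             self.Var_x2 = 2 * jnp.square(self.E_x2)  # Var[x_i^2] for a Gaussian
+#
+#         def logdensity_fn(self, x):
+#             return -0.5 * jnp.sum(jnp.square(x)) / self.sigma**2
+#
+#         def transform(self, x):
+#             return x
+#
+#         def sample_init(self, key):
+#             return jax.random.normal(key, shape=(self.ndims,)) * self.sigma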
+import os +#import numpyro.distributions as dist +dirr = os.path.dirname(os.path.realpath(__file__)) + + + +class StandardNormal(): + """Standard Normal distribution in d dimensions""" + + def __init__(self, d): + self.ndims = d + self.E_x2 = jnp.ones(d) + self.Var_x2 = 2 * self.E_x2 + self.name = 'StandardNormal' + + + def logdensity_fn(self, x): + """- log p of the target distribution""" + return -0.5 * jnp.sum(jnp.square(x), axis= -1) + + + def transform(self, x): + return x + + def sample_init(self, key): + return jax.random.normal(key, shape = (self.ndims, )) + + + +class IllConditionedGaussian(): + """Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2.""" + + + def __init__(self, d, condition_number, numpy_seed=None, prior= 'prior'): + """numpy_seed is used to generate a random rotation for the covariance matrix. + If None, the covariance matrix is diagonal.""" + + self.ndims = d + self.name = 'IllConditionedGaussian' + self.condition_number = condition_number + eigs = jnp.logspace(-0.5 * jnp.log10(condition_number), 0.5 * jnp.log10(condition_number), d) + + if numpy_seed == None: # diagonal + self.E_x2 = eigs + self.R = jnp.eye(d) + self.Hessian = jnp.diag(1 / eigs) + self.Cov = jnp.diag(eigs) + + else: # randomly rotate + rng = np.random.RandomState(seed=numpy_seed) + D = jnp.diag(eigs) + inv_D = jnp.diag(1 / eigs) + R, _ = jnp.array(np.linalg.qr(rng.randn(self.ndims, self.ndims))) # random rotation + self.R = R + self.Hessian = R @ inv_D @ R.T + self.Cov = R @ D @ R.T + self.E_x2 = jnp.diagonal(R @ D @ R.T) + + #Cov_precond = jnp.diag(1 / jnp.sqrt(self.E_x2)) @ self.Cov @ jnp.diag(1 / jnp.sqrt(self.E_x2)) + + #print(jnp.linalg.cond(Cov_precond) / jnp.linalg.cond(self.Cov)) + + self.Var_x2 = 2 * jnp.square(self.E_x2) + + + self.logdensity_fn = lambda x: -0.5 * x.T @ self.Hessian @ x + self.transform = lambda x: x + + + if prior == 'map': + self.sample_init = lambda key: jnp.zeros(self.ndims) + + elif prior == 'posterior': + self.sample_init = lambda key: self.R @ (jax.random.normal(key, shape=(self.ndims,)) * jnp.sqrt(eigs)) + + else: # N(0, sigma_true_max) + self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.max(jnp.sqrt(eigs)) + + + +class IllConditionedESH(): + """ICG from the ESH paper.""" + + def __init__(self): + self.ndims = 50 + self.name = 'IllConditionedESH' + self.variance = jnp.linspace(0.01, 1, self.ndims) + + + + + def logdensity_fn(self, x): + """- log p of the target distribution""" + return -0.5 * jnp.sum(jnp.square(x) / self.variance, axis= -1) + + + def transform(self, x): + return x + + def draw(self, key): + return jax.random.normal(key, shape = (self.ndims, )) * jnp.sqrt(self.variance) + + def sample_init(self, key): + return jax.random.normal(key, shape = (self.ndims, )) + + + + +class IllConditionedGaussianGamma(): + """Inference gym's Ill conditioned Gaussian""" + + def __init__(self, prior = 'prior'): + self.ndims = 100 + self.name = 'IllConditionedGaussianGamma' + + # define the Hessian + rng = np.random.RandomState(seed=10 & (2 ** 32 - 1)) + eigs = np.sort(rng.gamma(shape=0.5, scale=1., size=self.ndims)) #eigenvalues of the Hessian + eigs *= jnp.average(1.0/eigs) + self.entropy = 0.5 * self.ndims + self.maxmin = (1./jnp.sqrt(eigs[0]), 1./jnp.sqrt(eigs[-1])) + R, _ = np.linalg.qr(rng.randn(self.ndims, self.ndims)) #random rotation + self.map_to_worst = (R.T)[[0, -1], :] + self.Hessian = R @ np.diag(eigs) @ R.T + + # analytic ground 
truth moments + self.E_x2 = jnp.diagonal(R @ np.diag(1.0/eigs) @ R.T) + self.Var_x2 = 2 * jnp.square(self.E_x2) + + # norm = jnp.diag(1/jnp.sqrt(self.E_x2)) + # Sigma = R @ np.diag(1/eigs) @ R.T + # reduced = norm @ Sigma @ norm + # print(np.linalg.cond(reduced), np.linalg.cond(Sigma)) + + # gradient + + + if prior == 'map': + self.sample_init = lambda key: jnp.zeros(self.ndims) + + elif prior == 'posterior': + self.sample_init = lambda key: R @ (jax.random.normal(key, shape=(self.ndims,)) / jnp.sqrt(eigs)) + + else: # N(0, sigma_true_max) + self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.max(1.0/jnp.sqrt(eigs)) + + def logdensity_fn(self, x): + """- log p of the target distribution""" + return -0.5 * x.T @ self.Hessian @ x + + def transform(self, x): + return x + + + + +class Banana(): + """Banana target fromm the Inference Gym""" + + def __init__(self, prior = 'map'): + self.curvature = 0.03 + self.ndims = 2 + self.name = 'Banana' + + self.transform = lambda x: x + self.E_x2 = jnp.array([100.0, 19.0]) #the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. + self.Var_x2 = jnp.array([20000.0, 4600.898]) + + if prior == 'map': + self.sample_init = lambda key: jnp.array([0, -100.0 * self.curvature]) + elif prior == 'posterior': + self.sample_init = lambda key: self.posterior_draw(key) + elif prior == 'prior': + self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.array([10.0, 5.0]) * 2 + else: + raise ValueError('prior = '+prior +' is not defined.') + + def logdensity_fn(self, x): + mu2 = self.curvature * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + + def posterior_draw(self, key): + z = jax.random.normal(key, shape = (2, )) + x0 = 10.0 * z[0] + x1 = self.curvature * (x0 ** 2 - 100) + z[1] + return jnp.array([x0, x1]) + + def ground_truth(self): + x = jax.vmap(self.posterior_draw)(jax.random.split(jax.random.PRNGKey(0), 100000000)) + print(jnp.average(x, axis=0)) + print(jnp.average(jnp.square(x), axis=0)) + print(jnp.std(jnp.square(x[:, 0])) ** 2, jnp.std(jnp.square(x[:, 1])) ** 2) + + + + +class Cauchy(): + """d indpendent copies of the standard Cauchy distribution""" + + def __init__(self, d): + self.ndims = d + self.name = 'Cauchy' + + self.logdensity_fn = lambda x: -jnp.sum(jnp.log(1. 
+ jnp.square(x))) + + self.transform = lambda x: x + self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) + + + + +class HardConvex(): + + def __init__(self, d, kappa, theta = 0.1): + """d is the dimension, kappa = condition number, 0 < theta < 1/4""" + self.ndims = d + self.name = 'HardConvex' + self.theta, self.kappa = theta, kappa + C = jnp.power(d-1, 0.25 - theta) + self.logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x[:-1])) - (0.75 / kappa)* x[-1]**2 + 0.5 * jnp.sum(jnp.cos(C * x[:-1])) / C**2 + + self.transform = lambda x: x + + # numerically precomputed variances + num_integration = [0.93295, 0.968802, 0.990595, 0.998002, 0.999819] + if d == 100: + self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[0], jnp.ones(1) * 2.0*kappa/3.0)) + elif d == 300: + self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[1], jnp.ones(1) * 2.0*kappa/3.0)) + elif d == 1000: + self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[2], jnp.ones(1) * 2.0*kappa/3.0)) + elif d == 3000: + self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[3], jnp.ones(1) * 2.0*kappa/3.0)) + elif d == 10000: + self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[4], jnp.ones(1) * 2.0*kappa/3.0)) + else: + None + + + def sample_init(self, key): + """Gaussian prior with approximately estimating the variance along each dimension""" + scale = jnp.concatenate((jnp.ones(self.ndims-1), jnp.ones(1) * jnp.sqrt(2.0 * self.kappa / 3.0))) + return jax.random.normal(key, shape=(self.ndims,)) * scale + + + + +class BiModal(): + """A Gaussian mixture p(x) = f N(x | mu1, sigma1) + (1-f) N(x | mu2, sigma2).""" + + def __init__(self, d = 50, mu1 = 0.0, mu2 = 8.0, sigma1 = 1.0, sigma2 = 1.0, f = 0.2): + + self.ndims = d + self.name = 'BiModal' + + self.mu1 = jnp.insert(jnp.zeros(d-1), 0, mu1) + self.mu2 = jnp.insert(jnp.zeros(d - 1), 0, mu2) + self.sigma1, self.sigma2 = sigma1, sigma2 + self.f = f + self.variance = jnp.insert(jnp.ones(d-1) * ((1 - f) * sigma1**2 + f * sigma2**2), 0, (1-f)*(sigma1**2 + mu1**2) + f*(sigma2**2 + mu2**2)) + + + + def logdensity_fn(self, x): + """- log p of the target distribution""" + + N1 = (1.0 - self.f) * jnp.exp(-0.5 * jnp.sum(jnp.square(x - self.mu1), axis= -1) / self.sigma1 ** 2) / jnp.power(2 * jnp.pi * self.sigma1 ** 2, self.ndims * 0.5) + N2 = self.f * jnp.exp(-0.5 * jnp.sum(jnp.square(x - self.mu2), axis= -1) / self.sigma2 ** 2) / jnp.power(2 * jnp.pi * self.sigma2 ** 2, self.ndims * 0.5) + + return jnp.log(N1 + N2) + + + def draw(self, num_samples): + """direct sampler from a target""" + X = np.random.normal(size = (num_samples, self.ndims)) + mask = np.random.uniform(0, 1, num_samples) < self.f + X[mask, :] = (X[mask, :] * self.sigma2) + self.mu2 + X[~mask] = (X[~mask] * self.sigma1) + self.mu1 + + return X + + + def transform(self, x): + return x + + def sample_init(self, key): + z = jax.random.normal(key, shape = (self.ndims, )) *self.sigma1 + #z= z.at[0].set(self.mu1 + z[0]) + return z + + +class BiModalEqual(): + """Mixture of two Gaussians, one centered at x = [mu/2, 0, 0, ...], the other at x = [-mu/2, 0, 0, ...]. 
+ Both have equal probability mass.""" + + def __init__(self, d, mu): + + self.ndims = d + self.name = 'BiModalEqual' + self.mu = mu + + + + def logdensity_fn(self, x): + """- log p of the target distribution""" + + return -0.5 * jnp.sum(jnp.square(x), axis= -1) + jnp.log(jnp.cosh(0.5*self.mu*x[0])) - 0.5* self.ndims * jnp.log(2 * jnp.pi) - self.mu**2 / 8.0 + + + def draw(self, num_samples): + """direct sampler from a target""" + X = np.random.normal(size = (num_samples, self.ndims)) + mask = np.random.uniform(0, 1, num_samples) < 0.5 + X[mask, 0] += 0.5*self.mu + X[~mask, 0] -= 0.5 * self.mu + + return X + + def transform(self, x): + return x + + +class Funnel(): + """Noise-less funnel""" + + def __init__(self, d = 20): + + self.ndims = d + self.name = 'Funnel' + self.sigma_theta= 3.0 + + self.E_x2 = jnp.ones(d) # the transformed variables are standard Gaussian distributed + self.Var_x2 = 2 * self.E_x2 + + + + def logdensity_fn(self, x): + """ - log p of the target distribution + x = [z_0, z_1, ... z_{d-1}, theta] """ + theta = x[-1] + X = x[..., :- 1] + + return -0.5* jnp.square(theta / self.sigma_theta) - 0.5 * (self.ndims - 1) * theta - 0.5 * jnp.exp(-theta) * jnp.sum(jnp.square(X), axis = -1) + + def inverse_transform(self, xtilde): + theta = 3 * xtilde[-1] + return jnp.concatenate((xtilde[:-1] * jnp.exp(0.5 * theta), jnp.ones(1)*theta)) + + + def transform(self, x): + """gaussianization""" + xtilde = jnp.empty(x.shape) + xtilde = xtilde.at[-1].set(x.T[-1] / 3.0) + xtilde = xtilde.at[:-1].set(x.T[:-1] * jnp.exp(-0.5*x.T[-1])) + return xtilde.T + + + def sample_init(self, key): + return self.inverse_transform(jax.random.normal(key, shape = (self.ndims, ))) + + + + +class Funnel_with_Data(): + + def __init__(self, d, sigma, minibatch_size, key): + + self.ndims = d + self.name = 'Funnel_with_Data' + self.sigma_theta= 3.0 + self.theta_true = 0.0 + self.sigma_data = sigma + + + self.data = self.simulate_data() + + self.batch = minibatch_size + + def simulate_data(self): + + norm = jax.random.normal(jax.random.PRNGKey(123), shape = (2*(self.ndims-1), )) + z_true = norm[:self.ndims-1] * jnp.exp(self.theta_true * 0.5) + self.data = z_true + norm[self.ndims-1:] * self.sigma_data + + + def logdensity_fn(self, x, subset): + """ - log p of the target distribution + x = [z_0, z_1, ... z_{d-1}, theta] """ + theta = x[-1] + z = x[:- 1][subset] + + prior_theta = jnp.square(theta / self.sigma_theta) + prior_z = jnp.sum(subset) * theta + jnp.exp(-theta) * jnp.sum(jnp.square(z*subset)) + likelihood = jnp.sum(jnp.square((z - self.data)*subset / self.sigma_data)) + + return -0.5 * (prior_theta + prior_z + likelihood) + + + def transform(self, x): + """gaussianization""" + return x + + def sample_init(self, key): + key1, key2 = jax.random.split(key) + theta = jax.random.normal(key1) * self.sigma_theta + z = jax.random.normal(key2, shape = (self.ndims-1, )) * jnp.exp(theta * 0.5) + return jnp.concatenate((z, theta)) + + + + +class Rosenbrock(): + + def __init__(self, d = 36, Q = 0.1): + + self.ndims = d + self.name = 'Rosenbrock' + self.Q = Q + #ground truth moments + var_x = 2.0 + + #these two options were precomputed: + if Q == 0.1: + var_y = 10.098433122783046 # var_y is computed numerically (see class function compute_variance) + elif Q == 0.5: + var_y = 10.498957879911487 + else: + raise ValueError('Ground truth moments for Q = ' + str(Q) + ' were not precomputed. 
Use Q = 0.1 or 0.5.') + + self.variance = jnp.concatenate((var_x * jnp.ones(d//2), var_y * jnp.ones(d//2))) + + + + + def logdensity_fn(self, x): + """- log p of the target distribution""" + X, Y = x[..., :self.ndims//2], x[..., self.ndims//2:] + return -0.5 * jnp.sum(jnp.square(X - 1.0) + jnp.square(jnp.square(X) - Y) / self.Q, axis= -1) + + + + def draw(self, num): + n = self.ndims // 2 + X= np.empty((num, self.ndims)) + X[:, :n] = np.random.normal(loc= 1.0, scale= 1.0, size= (num, n)) + X[:, n:] = np.random.normal(loc= jnp.square(X[:, :n]), scale= jnp.sqrt(self.Q), size= (num, n)) + + return X + + + def transform(self, x): + return x + + + def sample_init(self, key): + return jax.random.normal(key, shape = (self.ndims, )) + + + def ground_truth(self): + num = 100000000 + x = np.random.normal(loc=1.0, scale=1.0, size=num) + y = np.random.normal(loc=np.square(x), scale=jnp.sqrt(self.Q), size=num) + + x2 = jnp.sum(jnp.square(x)) / (num - 1) + y2 = jnp.sum(jnp.square(y)) / (num - 1) + + x1 = np.average(x) + y1 = np.average(y) + + print(np.sqrt(0.5*(np.square(np.std(x)) + np.square(np.std(y))))) + + print(x2, y2) + + + +class Brownian(): + """ + log sigma_i ~ N(0, 2) + log sigma_obs ~N(0, 2) + + x ~ RandomWalk(0, sigma_i) + x_observed = (x + noise) * mask + noise ~ N(0, sigma_obs) + mask = 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 + """ + + def __init__(self): + self.num_data = 30 + self.name = 'Brownian' + self.ndims = self.num_data + 2 + + ground_truth_moments = jnp.load(dirr + '/ground_truth/brownian/ground_truth.npy') + self.E_x2, self.Var_x2 = ground_truth_moments[0], ground_truth_moments[1] + + self.data = jnp.array([0.21592641, 0.118771404, -0.07945447, 0.037677474, -0.27885845, -0.1484156, -0.3250906, -0.22957903, + -0.44110894, -0.09830782, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.8786016, -0.83736074, + -0.7384849, -0.8939254, -0.7774566, -0.70238715, -0.87771565, -0.51853573, -0.6948214, -0.6202789]) + # sigma_obs = 0.15, sigma_i = 0.1 + + self.observable = jnp.concatenate((jnp.ones(10), jnp.zeros(10), jnp.ones(10))) + self.num_observable = jnp.sum(self.observable) # = 20 + + + def logdensity_fn(self, x): + # y = softplus_to_log(x[:2]) + + lik = 0.5 * jnp.exp(-2 * x[1]) * jnp.sum(self.observable * jnp.square(x[2:] - self.data)) + x[ + 1] * self.num_observable + prior_x = 0.5 * jnp.exp(-2 * x[0]) * (x[2] ** 2 + jnp.sum(jnp.square(x[3:] - x[2:-1]))) + x[0] * self.num_data + prior_logsigma = 0.5 * jnp.sum(jnp.square(x / 2.0)) + + return -lik - prior_x - prior_logsigma + + + def transform(self, x): + return jnp.concatenate((jnp.exp(x[:2]), x[2:])) + + + def sample_init(self, key): + key_walk, key_sigma = jax.random.split(key) + + # original prior + # log_sigma = jax.random.normal(key_sigma, shape= (2, )) * 2 + + # narrower prior + log_sigma = jnp.log(np.array([0.1, 0.15])) + jax.random.normal(key_sigma, shape=( + 2,)) * 0.1 # *0.05# log sigma_i, log sigma_obs + + walk = random_walk(key_walk, self.ndims - 2) * jnp.exp(log_sigma[0]) + + return jnp.concatenate((log_sigma, walk)) + + def generate_data(self, key): + key_walk, key_sigma, key_noise = jax.random.split(key, 3) + + log_sigma = jax.random.normal(key_sigma, shape=(2,)) * 2 # log sigma_i, log sigma_obs + + walk = random_walk(key_walk, self.ndims - 2) * jnp.exp(log_sigma[0]) + noise = jax.random.normal(key_noise, shape=(self.ndims - 2,)) * jnp.exp(log_sigma[1]) + + return walk + noise + + +class GermanCredit: + """ Taken from inference gym. 
+ + x = (global scale, local scales, weights) + + global_scale ~ Gamma(0.5, 0.5) + + for i in range(num_features): + unscaled_weights[i] ~ Normal(loc=0, scale=1) + local_scales[i] ~ Gamma(0.5, 0.5) + weights[i] = unscaled_weights[i] * local_scales[i] * global_scale + + for j in range(num_datapoints): + label[j] ~ Bernoulli(features @ weights) + + We use a log transform for the scale parameters. + """ + + def __init__(self): + self.ndims = 51 #global scale + 25 local scales + 25 weights + self.name = 'GermanCredit' + + self.labels = jnp.load(dirr + '/data/gc_labels.npy') + self.features = jnp.load(dirr + '/data/gc_features.npy') + + truth = jnp.load(dirr+'/ground_truth/german_credit/ground_truth.npy') + self.E_x2, self.Var_x2 = truth[0], truth[1] + + + + + def transform(self, x): + return jnp.concatenate((jnp.exp(x[:26]), x[26:])) + + def logdensity_fn(self, x): + + scales = jnp.exp(x[:26]) + + # prior + pr = jnp.sum(0.5 * scales + 0.5 * x[:26]) + 0.5 * jnp.sum(jnp.square(x[26:])) + + # transform + transform = -jnp.sum(x[:26]) + + # likelihood + weights = scales[0] * scales[1:26] * x[26:] + logits = self.features @ weights # = jnp.einsum('nd,...d->...n', self.features, weights) + lik = jnp.sum(self.labels * jnp.logaddexp(0., -logits) + (1-self.labels)* jnp.logaddexp(0., logits)) + + return -(lik + pr + transform) + + def sample_init(self, key): + weights = jax.random.normal(key, shape = (25, )) + return jnp.concatenate((jnp.zeros(26), weights)) + + + + +class ItemResponseTheory: + """ Taken from inference gym.""" + + def __init__(self): + self.ndims = 501 + self.name = 'ItemResponseTheory' + self.students = 400 + self.questions = 100 + + self.mask = jnp.load(dirr + '/data/irt_mask.npy') + self.labels = jnp.load(dirr + '/data/irt_labels.npy') + + truth = jnp.load(dirr+'/ground_truth/item_response_theory/ground_truth.npy') + self.E_x2, self.Var_x2 = truth[0], truth[1] + + + self.transform = lambda x: x + + def logdensity_fn(self, x): + + students = x[:self.students] + mean = x[self.students] + questions = x[self.students + 1:] + + # prior + pr = 0.5 * (jnp.square(mean - 0.75) + jnp.sum(jnp.square(students)) + jnp.sum(jnp.square(questions))) + + # likelihood + logits = mean + students[:, jnp.newaxis] - questions[jnp.newaxis, :] + bern = self.labels * jnp.logaddexp(0., -logits) + (1 - self.labels) * jnp.logaddexp(0., logits) + bern = jnp.where(self.mask, bern, jnp.zeros_like(bern)) + lik = jnp.sum(bern) + + return -lik - pr + + + def sample_init(self, key): + x = jax.random.normal(key, shape = (self.ndims,)) + x = x.at[self.students].add(0.75) + return x + + + + +class StochasticVolatility(): + """Example from https://num.pyro.ai/en/latest/examples/stochastic_volatility.html""" + + def __init__(self): + self.SP500_returns = jnp.load(dirr + '/data/SP500.npy') + + self.ndims = 2429 + self.name = 'StochasticVolatility' + + self.typical_sigma, self.typical_nu = 0.02, 10.0 # := 1 / lambda + + data = jnp.load(dirr + '/ground_truth/stochastic_volatility/ground_truth_0.npy') + self.E_x2 = data[0] + self.Var_x2 = data[1] + + + + def logdensity_fn(self, x): + """- log p of the target distribution + x= [s1, s2, ... 
s2427, log sigma / typical_sigma, log nu / typical_nu]""" + + sigma = jnp.exp(x[-2]) * self.typical_sigma #we used this transformation to make x unconstrained + nu = jnp.exp(x[-1]) * self.typical_nu + + l1= (jnp.exp(x[-2]) - x[-2]) + (jnp.exp(x[-1]) - x[-1]) + l2 = (self.ndims - 2) * jnp.log(sigma) + 0.5 * (jnp.square(x[0]) + jnp.sum(jnp.square(x[1:-2] - x[:-3]))) / jnp.square(sigma) + l3 = jnp.sum(nlogp_StudentT(self.SP500_returns, nu, jnp.exp(x[:-2]))) + + return -(l1 + l2 + l3) + + + def transform(self, x): + """transforms to the variables which are used by numpyro (and in which we have the ground truth moments)""" + + z = jnp.empty(x.shape) + z = z.at[:-2].set(x[:-2]) # = s = log R + z = z.at[-2].set(jnp.exp(x[-2]) * self.typical_sigma) # = sigma + z = z.at[-1].set(jnp.exp(x[-1]) * self.typical_nu) # = nu + + return z + + + def sample_init(self, key): + """draws x from the prior""" + + key_walk, key_exp = jax.random.split(key) + + scales = jnp.array([self.typical_sigma, self.typical_nu]) + #params = jax.random.exponential(key_exp, shape = (2, )) * scales + params= scales + walk = random_walk(key_walk, self.ndims - 2) * params[0] + return jnp.concatenate((walk, jnp.log(params/scales))) + + +class MixedLogit(): + + def __init__(self): + + key = jax.random.PRNGKey(0) + key_poisson, key_x, key_beta, key_logit = jax.random.split(key, 4) + + self.ndims = 2014 + self.name = "Mixed Logit" + self.nind = 500 + self.nsessions = jax.random.poisson(key_poisson, lam=1.0, shape=(self.nind,)) + 10 + self.nbeta = 4 + nobs = jnp.sum(self.nsessions) + + mu_true = jnp.array([-1.5, -0.3, 0.8, 1.2]) + sigma_true = jnp.array([[0.5, 0.1, 0.1, 0.1], [0.1, 0.5, 0.1, 0.1], [0.1, 0.1, 0.5, 0.1], [0.1, 0.1, 0.1, 0.5]]) + beta_true = jax.random.multivariate_normal(key_beta, mu_true, sigma_true, shape=(self.nind,)) + beta_true_repeat = jnp.repeat(beta_true, self.nsessions, axis=0) + + self.x = jax.random.normal(key_x, (nobs, self.nbeta)) + self.y = 1 * jax.random.bernoulli(key_logit, (jax.nn.sigmoid(jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_true_repeat)))) + + self.d = self.nbeta + self.nbeta + (self.nbeta * (self.nbeta-1) // 2) + self.nbeta * self.nind # mu, tau, omega_chol, and (beta for each i) + self.prior_mean_mu = jnp.zeros(self.nbeta) + self.prior_var_mu = 10.0 * jnp.eye(self.nbeta) + self.prior_scale_tau = 5.0 + self.prior_concentration_omega = 1.0 + + self.grad_logp = jax.value_and_grad(self.logdensity_fn) + + def corrchol_to_reals(self,x): + '''Converts a Cholesky-correlation (lower-triangular) matrix to a vector of unconstrained reals''' + dim = x.shape[0] + z = jnp.zeros((dim, dim)) + for i in range(dim): + for j in range(i): + z = z.at[i, j].set(x[i,j] / jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) + z_lower_triang = z[jnp.tril_indices(dim, -1)] + y = 0.5 * (jnp.log(1.0 + z_lower_triang) - jnp.log(1.0 - z_lower_triang)) + + return y + + def reals_to_corrchol(self,y): + '''Converts a vector of unconstrained reals to a Cholesky-correlation (lower-triangular) matrix''' + len_vec = len(y) + dim = int(0.5 * (1 + 8 * len_vec) ** 0.5 + 0.5) + assert dim * (dim - 1) // 2 == len_vec + + z = jnp.zeros((dim, dim)) + z = z.at[jnp.tril_indices(dim, -1)].set(jnp.tanh(y)) + + x = jnp.zeros((dim, dim)) + for i in range(dim): + for j in range(i+1): + if i == j: + x = x.at[i, j].set(jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) + else: + x = x.at[i, j].set(z[i,j] * jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) + return x + + + def logdensity_fn(self, pars): + """log p of the target distribution, i.e., log 
posterior distribution up to a constant""" + + mu = pars[:self.nbeta] + dim1 = self.nbeta + self.nbeta + log_tau = pars[self.nbeta:dim1] + dim2 = self.nbeta + self.nbeta + self.nbeta * (self.nbeta - 1) // 2 + omega_chol_realvec = pars[dim1:dim2] + beta = pars[dim2:].reshape(self.nind, self.nbeta) + + omega_chol = self.reals_to_corrchol(omega_chol_realvec) + omega = jnp.dot(omega_chol, jnp.transpose(omega_chol)) + tau = jnp.exp(log_tau) + tau_diagmat = jnp.diag(tau) + sigma = jnp.dot(tau_diagmat, jnp.dot(omega, tau_diagmat)) + + beta_repeat = jnp.repeat(beta, self.nsessions, axis=0) + + log_lik = jnp.sum(self.y * jax.nn.log_sigmoid(jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat)) + (1 - self.y) * jax.nn.log_sigmoid(-jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat))) + + log_density_beta_popdist = -0.5 * self.nind * jnp.log(jnp.linalg.det(sigma)) - 0.5 * jnp.sum(jax.vmap(lambda vec, mat: jnp.dot(vec, jnp.linalg.solve(mat, vec)), in_axes=(0, None))(beta - mu, sigma)) + + muMinusPriorMean = mu - self.prior_mean_mu + log_prior_mu = -0.5 * jnp.log(jnp.linalg.det(self.prior_var_mu)) - 0.5 * jnp.dot(muMinusPriorMean, jnp.linalg.solve(self.prior_var_mu, muMinusPriorMean)) + + log_prior_tau = jnp.sum(dist.HalfCauchy(scale=self.prior_scale_tau).log_prob(tau)) + #log_prior_tau = jnp.sum(jax.vmap(lambda arg: -jnp.log(1.0 + (arg / self.prior_scale_tau) ** 2.0))(tau)) + log_prior_omega_chol = dist.LKJCholesky(self.nbeta, concentration=self.prior_concentration_omega).log_prob(omega_chol) + #log_prior_omega_chol = jnp.dot(nbeta - jnp.arange(2, nbeta+1) + 2.0 * self.prior_concentration_omega - 2.0, jnp.log(jnp.diag(omega_chol)[1:])) + + return log_lik + log_density_beta_popdist + log_prior_mu + log_prior_tau + log_prior_omega_chol + + + def transform(self, pars): + """transform pars to the original (possibly constrained) pars""" + mu = pars[:self.nbeta] + dim1 = self.nbeta + self.nbeta + log_tau = pars[self.nbeta:dim1] + dim2 = self.nbeta + self.nbeta + self.nbeta * (self.nbeta - 1) // 2 + omega_chol_realvec = pars[dim1:dim2] + beta_flattened = pars[dim2:] + + omega_chol = self.reals_to_corrchol(omega_chol_realvec) + omega = jnp.dot(omega_chol, jnp.transpose(omega_chol)) + tau = jnp.exp(log_tau) + tau_diagmat = jnp.diag(tau) + sigma = jnp.dot(tau_diagmat, jnp.dot(omega, tau_diagmat)) + + return jnp.concatenate((mu, sigma.flatten(), beta_flattened)) + + def sample_init(self, key): + """draws pars from the prior""" + + key_mu, key_omega_chol, key_tau, key_beta = jax.random.split(key, 4) + mu = jax.random.multivariate_normal(key_mu, self.prior_mean_mu, self.prior_var_mu) + omega_chol = dist.LKJCholesky(self.nbeta, concentration=self.prior_concentration_omega).sample(key_omega_chol) + tau = dist.HalfCauchy(scale=self.prior_scale_tau).sample(key_tau, (self.nbeta,)) + + omega_chol_realvec = self.corrchol_to_reals(omega_chol) + log_tau = jnp.log(tau) + + omega = jnp.dot(omega_chol, jnp.transpose(omega_chol)) + tau_diagmat = jnp.diag(tau) + sigma = jnp.dot(tau_diagmat, jnp.dot(omega, tau_diagmat)) + + beta = jax.random.multivariate_normal(key_beta, mu, sigma, shape=(self.nind,)) + + pars = jnp.concatenate((mu, log_tau, omega_chol_realvec, beta.flatten())) + return pars + + + +def nlogp_StudentT(x, df, scale): + y = x / scale + z = ( + jnp.log(scale) + + 0.5 * jnp.log(df) + + 0.5 * jnp.log(jnp.pi) + + jax.scipy.special.gammaln(0.5 * df) + - jax.scipy.special.gammaln(0.5 * (df + 1.0)) + ) + return 0.5 * (df + 1.0) * jnp.log1p(y**2.0 / df) + z + + + +def random_walk(key, 
num): + """ Genereting process for the standard normal walk: + x[0] ~ N(0, 1) + x[n+1] ~ N(x[n], 1) + + Args: + key: jax random key + num: number of points in the walk + Returns: + 1 realization of the random walk (array of length num) + """ + + def step(track, useless): + x, key = track + randkey, subkey = jax.random.split(key) + x += jax.random.normal(subkey) + return (x, randkey), x + + return jax.lax.scan(step, init=(0.0, key), xs=None, length=num)[1] + + + +models = { + + # Cauchy(100) : {'mclmc': 2000, 'mhmclmc' : 2000, 'nuts': 2000}, + # StandardNormal(100) : {'mclmc': 10000, 'mhmclmc' : 10000, 'nuts': 10000}, + # Banana() : {'mclmc': 10000, 'mhmclmc' : 10000, 'nuts': 10000}, + Brownian() : {'mclmc': 20000, 'mhmclmc' : 80000, 'nuts': 40000}, + + + # 'banana': Banana(), + # 'icg' : (IllConditionedGaussian(10, 2), {'mclmc': 2000, 'mhmclmc' : 2000, 'nuts': 2000}), + # GermanCredit(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000}, + # ItemResponseTheory(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000}, + # StochasticVolatility(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000} + } + +# models = {'Brownian Motion': (Brownian(), {'mclmc': 50000, 'mhmclmc' : 40000, 'nuts': 1000}), +# # 'Item Response Theory': (ItemResponseTheory(), {'mclmc': 50000, 'mhmclmc' : 50000, 'nuts': 1000}) +# } \ No newline at end of file diff --git a/blackjax/benchmarks/mcmc/sampling_algorithms.py b/blackjax/benchmarks/mcmc/sampling_algorithms.py new file mode 100644 index 000000000..617797b79 --- /dev/null +++ b/blackjax/benchmarks/mcmc/sampling_algorithms.py @@ -0,0 +1,188 @@ +# mypy: ignore-errors +# flake8: noqa + + +import jax +import jax.numpy as jnp +import blackjax +# from blackjax.adaptation.window_adaptation import da_adaptation +from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator +# from blackjax.mcmc.adjusted_mclmc import rescale +from blackjax.util import run_inference_algorithm +import blackjax + +__all__ = ["samplers"] + + + + +def run_nuts( + coefficients, logdensity_fn, num_steps, initial_position, transform, key): + + integrator = generate_euclidean_integrator(coefficients) + # integrator = blackjax.mcmc.integrators.velocity_verlet # note: defaulted to in nuts + + rng_key, warmup_key = jax.random.split(key, 2) + + state, params = da_adaptation( + rng_key=warmup_key, + initial_position=initial_position, + algorithm=blackjax.nuts, + logdensity_fn=logdensity_fn) + + # print(params["inverse_mass_matrix"], "inv\n\n") + # warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn, integrator=integrator) + # (state, params), _ = warmup.run(warmup_key, initial_position, 2000) + + nuts = blackjax.nuts(logdensity_fn=logdensity_fn, step_size=params['step_size'], inverse_mass_matrix= params['inverse_mass_matrix'], integrator=integrator) + + final_state, state_history, info_history = run_inference_algorithm( + rng_key=rng_key, + initial_state=state, + inference_algorithm=nuts, + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True + ) + + # print("INFO\n\n",info_history.num_integration_steps) + + return state_history, params, info_history.num_integration_steps.mean() * calls_per_integrator_step(coefficients), info_history.acceptance_rate.mean(), None, None + +def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key): + + integrator = generate_isokinetic_integrator(coefficients) + + init_key, tune_key, run_key = jax.random.split(key, 3) + + + 
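+    # Outline of what follows (a descriptive note; the steps are exactly the
+    # ones implemented below, nothing new is assumed): (1) build an initial
+    # MCLMC state from initial_position, (2) tune the trajectory length L and
+    # the step size with blackjax.mclmc_find_L_and_step_size, and (3) run the
+    # tuned blackjax.mclmc kernel for num_steps via run_inference_algorithm,
+    # applying `transform` to each position in the returned trace.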
initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + ) + + + kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=integrator, + std_mat=std_mat, + ) + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + diagonal_preconditioning=False, + # desired_energy_var= 1e-5 + ) + + # jax.debug.print("params {x}", x=(blackjax_mclmc_sampler_params.L, blackjax_mclmc_sampler_params.step_size)) + + sampling_alg = blackjax.mclmc( + logdensity_fn, + L=blackjax_mclmc_sampler_params.L, + step_size=blackjax_mclmc_sampler_params.step_size, + std_mat=blackjax_mclmc_sampler_params.std_mat, + integrator = integrator, + + # std_mat=jnp.ones((initial_position.shape[0],)), + ) + + _, samples, _ = run_inference_algorithm( + rng_key=run_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=sampling_alg, + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True, + ) + + acceptance_rate = 1. + return samples, blackjax_mclmc_sampler_params, calls_per_integrator_step(coefficients), acceptance_rate, None, None + + +def run_mhmclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, target_acc_rate=None): + integrator = generate_isokinetic_integrator(coefficients) + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.mhmclmc.init( + position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( + integrator=integrator, + integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=logdensity_fn) + + if target_acc_rate is None: + target_acc_rate = target_acceptance_rate_of_order[integrator_order(coefficients)] + print("target acc rate") + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + params_history, + final_da + ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=frac_tune1, + frac_tune2=frac_tune2, + frac_tune3=frac_tune3, + diagonal_preconditioning=False, + ) + + + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + # jax.debug.print("params {x}", x=(blackjax_mclmc_sampler_params.step_size, blackjax_mclmc_sampler_params.L)) + + + alg = blackjax.mcmc.mhmclmc.mhmclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn = lambda key: jnp.ceil(jax.random.uniform(key) * rescale(L/step_size)) , + integrator=integrator, + std_mat=blackjax_mclmc_sampler_params.std_mat, + + + ) + + + _, out, info = run_inference_algorithm( + rng_key=run_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True) + + + + return out, blackjax_mclmc_sampler_params, calls_per_integrator_step(coefficients) * (L/step_size), info.acceptance_rate, params_history, final_da + +# we should do at least: mclmc, nuts, unadjusted 
hmc, mhmclmc, langevin
+
+samplers = {
+    'nuts' : run_nuts,
+    'mclmc' : run_mclmc,
+    'mhmclmc': run_mhmclmc,
+    }
+
+
+# foo = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(20.56))
+
+# print(jnp.mean(jax.vmap(foo)(jax.random.split(jax.random.PRNGKey(1), 10000000))))
diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py
index 5a2f71838..6f402dd67 100644
--- a/blackjax/mcmc/integrators.py
+++ b/blackjax/mcmc/integrators.py
@@ -39,6 +39,7 @@
     "implicit_midpoint",
     "calls_per_integrator_step",
     "name_integrator",
+    "integrator_order"
 ]
 
 
@@ -464,7 +465,7 @@ def calls_per_integrator_step(c):
         return 5
 
     else:
-        raise Exception
+        raise Exception("No such integrator exists in blackjax")
 
 
 def name_integrator(c):
@@ -478,8 +479,16 @@ def name_integrator(c):
         return "omelyan"
 
     else:
-        raise Exception
+        raise Exception("No such integrator exists in blackjax")
 
+def integrator_order(c):
+    if c==velocity_verlet_coefficients: return 2
+    elif c==mclachlan_coefficients: return 2
+    elif c==yoshida_coefficients: return 4
+    elif c==omelyan_coefficients: return 4
+
+
+    else: raise Exception("No such integrator exists in blackjax")
 
 FixedPointSolver = Callable[
     [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree],

From 48fa9b195a0ed55b8e1f86d7cc340f8548fc86b3 Mon Sep 17 00:00:00 2001
From: =
Date: Wed, 15 May 2024 22:25:39 +0200
Subject: [PATCH 27/71] ADD ADJUSTED MCLMC TUNING

---
 blackjax/adaptation/mclmc_adaptation.py       | 335 ++++++++++++++++++
 blackjax/benchmarks/mcmc/benchmark.py         |  74 ++--
 blackjax/benchmarks/mcmc/inference_models.py  |  20 +-
 .../benchmarks/mcmc/sampling_algorithms.py    |  19 +-
 blackjax/mcmc/adjusted_mclmc.py               |   6 +-
 5 files changed, 395 insertions(+), 59 deletions(-)

diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
index dc33eb21c..476ea6654 100644
--- a/blackjax/adaptation/mclmc_adaptation.py
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -19,6 +19,7 @@
 import jax.numpy as jnp
 from jax.flatten_util import ravel_pytree
 
+from blackjax.adaptation.step_size import DualAveragingAdaptationState, dual_averaging_adaptation
 from blackjax.diagnostics import effective_sample_size
 from blackjax.util import pytree_size, streaming_average
 
@@ -281,6 +282,340 @@ def step(state, key):
     return adaptation_L
 
+
+Lratio_lowerbound = 0.0
+Lratio_upperbound = 2.
+
+
+def adjusted_mclmc_find_L_and_step_size(
+    mclmc_kernel,
+    num_steps,
+    state,
+    rng_key,
+    target,
+    frac_tune1=0.1,
+    frac_tune2=0.1,
+    frac_tune3=0.1,
+    diagonal_preconditioning=True,
+    params=None
+):
+    """
+    Finds the optimal value of the parameters for the adjusted MCLMC (MH-MCHMC) algorithm.
+
+    Parameters
+    ----------
+    mclmc_kernel
+        The kernel function used for the MCMC algorithm.
+    num_steps
+        The number of MCMC steps that will subsequently be run, after tuning.
+    state
+        The initial state of the MCMC algorithm.
+    rng_key
+        The random number generator key.
+    target
+        The target acceptance rate for the step size adaptation.
+    frac_tune1
+        The fraction of tuning for the first step of the adaptation.
+    frac_tune2
+        The fraction of tuning for the second step of the adaptation.
+    frac_tune3
+        The fraction of tuning for the third step of the adaptation.
+    diagonal_preconditioning
+        Whether to set the preconditioning matrix std_mat from the posterior variances estimated during tuning.
+    params
+        An optional initial MCLMCAdaptationState; if None, a default based on the dimension of the position is used.
+
+    Returns
+    -------
+    A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
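+
+    Example
+    -------
+    A minimal sketch of the intended call pattern (here ``kernel``, ``initial_state``
+    and ``tune_key`` are placeholders supplied by the caller, not part of this function)::
+
+        # kernel must accept (rng_key, state, avg_num_integration_steps, step_size, std_mat)
+        # and return a (state, info) pair, e.g. a closure over adjusted_mclmc's build_kernel.
+        state, params, params_history, final_da = adjusted_mclmc_find_L_and_step_size(
+            mclmc_kernel=kernel,
+            num_steps=10_000,
+            state=initial_state,
+            rng_key=tune_key,
+            target=0.65,
+        )
+        # params.L, params.step_size and params.std_mat then parametrize the sampling run.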
+ """ + + dim = pytree_size(state.position) + if params is None: + params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, std_mat=jnp.ones((dim,))) + else: + params = params + # jax.debug.print("initial params {x}", x=params) + part1_key, part2_key = jax.random.split(rng_key, 2) + + state, params, params_history, final_da_val = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=frac_tune2, + target=target, + diagonal_preconditioning=diagonal_preconditioning + )(state, params, num_steps, part1_key) + + if frac_tune3 != 0: + + part2_key1, part2_key2 = jax.random.split(part2_key, 2) + + state, params = adjusted_mclmc_make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)( + state, params, num_steps, part2_key1 + ) + + state, params, params_history, final_da_val = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=0, + target=target, + fix_L_first_da=True, + diagonal_preconditioning=diagonal_preconditioning + )(state, params, num_steps, part2_key2) + + return state, params, params_history, final_da_val + + +def adjusted_mclmc_make_L_step_size_adaptation( + kernel, + dim, + frac_tune1, + frac_tune2, + target, + diagonal_preconditioning, + fix_L_first_da=False, +): + """Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC""" + + + + + def dual_avg_step(fix_L, update_da): + """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" + + def step(iteration_state, weight_and_key): + + mask, rng_key = weight_and_key + kernel_key, num_steps_key = jax.random.split(rng_key, 2) + previous_state, params, (adaptive_state, step_size_max), streaming_avg = iteration_state + + avg_num_integration_steps = params.L/params.step_size + + state, info = kernel( + rng_key=kernel_key, + state=previous_state, + avg_num_integration_steps=avg_num_integration_steps, + step_size=params.step_size, + std_mat=params.std_mat, + ) + + # jax.debug.print("step size during {x}",x=(params.step_size, params.L)) + + # step updating + success, state, step_size_max, energy_change = handle_nans( + previous_state, + state, + params.step_size, + step_size_max, + info.energy, + ) + + # jax.debug.print("info acc rate {x}", x=(info.acceptance_rate,)) + # jax.debug.print("state {x}", x=(state.position,)) + + + log_step_size, log_step_size_avg, step, avg_error, mu = update_da( + adaptive_state, info.acceptance_rate) + + adaptive_state = DualAveragingAdaptationState( + mask * log_step_size + (1-mask)*adaptive_state.log_step_size, + mask * log_step_size_avg + (1-mask)*adaptive_state.log_step_size_avg, + mask * step + (1-mask)*adaptive_state.step, + mask * avg_error + (1-mask)*adaptive_state.avg_error, + mask * mu + (1-mask)*adaptive_state.mu, + ) + + # jax.debug.print("{x} step_size before",x=(adaptive_state.log_step_size, info.acceptance_rate,)) + # adaptive_state = update(adaptive_state, info.acceptance_rate) + # jax.debug.print("{x} step_size after",x=(adaptive_state.log_step_size,)) + + + # step_size = jax.lax.clamp(1e-3, jnp.exp(adaptive_state.log_step_size), 1e0) + # step_size = jax.lax.clamp(1e-5, jnp.exp(adaptive_state.log_step_size), step_size_max) + step_size = jax.lax.clamp(1e-5, jnp.exp(adaptive_state.log_step_size), params.L/1.1) + adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) + # step_size = 1e-3 + + # update the running average of x, x^2 + streaming_avg = streaming_average( + O=lambda 
x: jnp.array([x, jnp.square(x)]), + x=ravel_pytree(state.position)[0], + streaming_avg=streaming_avg, + weight=(1-mask)*success*step_size, + zero_prevention=mask, + ) + + + + + if fix_L: + params = params._replace( + step_size=mask * step_size + (1-mask)*params.step_size, + + # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L + # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L + ) + + else: + + params = params._replace( + step_size=mask * step_size + (1-mask)*params.step_size, + + L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L + # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L + ) + # params = params._replace(step_size=step_size, + # L=(params.L/params.step_size * step_size) + # ) + + + return (state, params, (adaptive_state, step_size_max), streaming_avg), (info, params) + return step + + def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da): + + return jax.lax.scan( + dual_avg_step(fix_L, update_da), + init=( + state, + params, + (initial_da(params.step_size), jnp.inf), # step size max + # (init(params.step_size), params.L/4), + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=(mask, keys), + ) + + def L_step_size_adaptation(state, params, num_steps, rng_key): + num_steps1, num_steps2 = int(num_steps * frac_tune1), int( + num_steps * frac_tune2 + ) + + # num_steps2=0 + + rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) + L_step_size_adaptation_keys_pass1 = jax.random.split(rng_key_pass1, num_steps1 + num_steps2) + L_step_size_adaptation_keys_pass2 = jax.random.split(rng_key_pass2, num_steps1) + + # determine which steps to ignore in the streaming average + mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + + # jax.debug.print("{x} initial num steps",x=(params.L/params.step_size)) + + ((state, params, (dual_avg_state, step_size_max), (_, average)), (info, params_history)) = step_size_adaptation(mask, state, params, L_step_size_adaptation_keys_pass1, fix_L=fix_L_first_da, initial_da=initial_da, update_da=update_da) + # jax.debug.print("final da {x}", x=final_da(dual_avg_state)) + # params = params._replace(L=params.L * (final_da(dual_avg_state)/params.step_size)) + # params = params._replace(step_size=final_da(dual_avg_state)) + + # jax.debug.print("{x} new num steps",x=(params.L/params.step_size)) + + + # jax.debug.print("{x} mean acceptance rate",x=((jnp.mean(info.acceptance_rate)))) + + + # jax.debug.print("{x} params after a round of tuning",x=(params)) + # jax.debug.print("{x} step size max",x=(step_size_max)) + # jax.debug.print("{x} final",x=(final(dual_avg_state))) + # jax.debug.print("{x} params",x=(params)) + + # raise Exception + + # determine L + if num_steps2 != 0.0: + # if False: + x_average, x_squared_average = average[0], average[1] + variances = x_squared_average - jnp.square(x_average) + # jax.debug.print("{x} frac tune 2 guess",x=(jnp.sqrt(jnp.sum(variances)))) + # jax.debug.print("{x} frac tune 2 before",x=(params.L)) + + + change = jax.lax.clamp(Lratio_lowerbound, jnp.sqrt(jnp.sum(variances))/params.L, Lratio_upperbound) + # change = jnp.sqrt(jnp.sum(variances))/params.L + # jax.debug.print("{x} L ratio, old val, new val",x=(change, params.L, params.L*change)) + # jax.debug.print("{x} variance",x=(jnp.sqrt(jnp.sum(variances)))) + params = params._replace(L=params.L*change, step_size=params.step_size*change) + 
# params = params._replace(L=16.) + # params = params._replace(L=jnp.sqrt(jnp.sum(variances))) + # jax.debug.print("{x} params after a round of tuning",x=(params)) + + if diagonal_preconditioning: + + # diagonal preconditioning + params = params._replace(std_mat=jnp.sqrt(variances)) + + # state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] + # dyn, _, hyp, adap, kalman_state = state + + + # jax.debug.print("{x} params before second round",x=(params)) + # jax.debug.print("{x}",x=("L before", params.L)) + # jax.debug.print("{x}",x=("target", target)) + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + ((state, params, (dual_avg_state, step_size_max), (_, average)), (info, params_history)) = step_size_adaptation(jnp.ones(num_steps1), state, params, L_step_size_adaptation_keys_pass2, fix_L=True, update_da=update_da, initial_da=initial_da) + # params = params._replace(L=params.L * (final_da(dual_avg_state)/params.step_size)) + # params = params._replace(step_size=final_da(dual_avg_state)) + # jax.debug.print("{x} mean acceptance rate 2",x=(jnp.mean(info.acceptance_rate,))) + # jax.debug.print("{x}",x=("L after", params.L)) + # jax.debug.print("{x} params after a round of tuning",x=(params)) + + return state, params, params_history.step_size, final_da(dual_avg_state) + + return L_step_size_adaptation + + + +def adjusted_mclmc_make_adaptation_L(kernel, frac, Lfactor): + """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + + def adaptation_L(state, params, num_steps, key): + num_steps = int(num_steps * frac) + adaptation_L_keys = jax.random.split(key, num_steps) + + # jax.debug.print("tune 1\n\n {x}", x=(params.L, params.step_size)) + + + def step(state, key): + next_state, _ = kernel( + rng_key=key, + state=state, + step_size=params.step_size, + avg_num_integration_steps=params.L/params.step_size, + std_mat=params.std_mat, + ) + return next_state, next_state.position + + state, samples = jax.lax.scan( + f=step, + init=state, + xs=adaptation_L_keys, + ) + + + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) + ess = effective_sample_size(flat_samples[None, ...]) + + change = jax.lax.clamp(Lratio_lowerbound, (Lfactor * params.step_size * jnp.mean(num_steps / ess))/params.L, Lratio_upperbound) + # change = (Lfactor * params.step_size * jnp.mean(num_steps / ess))/params.L + + # jax.debug.print("tune 3\n\n {x}", x=(params.L*change, change)) + return state, params._replace( + # L=Lfactor * params.step_size * jnp.mean(num_steps / ess) + L=params.L*change + ) + + return adaptation_L + + + + + + def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change): """if there are nans, let's reduce the stepsize, and not update the state. 
The function returns the old state in this case.""" diff --git a/blackjax/benchmarks/mcmc/benchmark.py b/blackjax/benchmarks/mcmc/benchmark.py index 174cd30f7..fd1e8a2a8 100644 --- a/blackjax/benchmarks/mcmc/benchmark.py +++ b/blackjax/benchmarks/mcmc/benchmark.py @@ -24,10 +24,10 @@ import numpy as np import blackjax -from blackjax.benchmarks.mcmc.sampling_algorithms import run_mclmc, run_mhmclmc, run_nuts, samplers +from blackjax.benchmarks.mcmc.sampling_algorithms import run_mclmc, run_adjusted_mclmc, run_nuts, samplers from blackjax.benchmarks.mcmc.inference_models import Brownian, GermanCredit, ItemResponseTheory, MixedLogit, StandardNormal, StochasticVolatility, models from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator, integrator_order, isokinetic_mclachlan, mclachlan_coefficients, name_integrator, omelyan_coefficients, velocity_verlet, velocity_verlet_coefficients, yoshida_coefficients -# from blackjax.mcmc.mhmclmc import rescale +# from blackjax.mcmc.adjusted_mclmc import rescale from blackjax.util import run_inference_algorithm target_acceptance_rate_of_order = {2 : 0.65, 4: 0.8} @@ -117,14 +117,14 @@ def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps return center_L, center_step_size, converged -def run_mhmclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): +def run_adjusted_mclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): def s(logdensity_fn, num_steps, initial_position, transform, key): integrator = generate_isokinetic_integrator(coefficients) num_steps_per_traj = L/step_size - alg = blackjax.mcmc.mhmclmc.mhmclmc( + alg = blackjax.mcmc.adjusted_mclmc.adjusted_mclmc( logdensity_fn=logdensity_fn, step_size=step_size, integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(num_steps_per_traj)) , @@ -194,8 +194,8 @@ def run_benchmarks(batch_size): results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], + # ["adjusted_mclmc", "nuts", "mclmc", ], + ["adjusted_mclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [Brownian()], # [Brownian()], @@ -251,8 +251,8 @@ def run_simple(): results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mclmc"], + # ["adjusted_mclmc", "nuts", "mclmc", ], + ["adjusted_mclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [Brownian()], # [Brownian()], @@ -264,7 +264,7 @@ def run_simple(): sampler, model, coefficients = variables num_chains = 128 - num_steps = 10000 + num_steps = 40000 contract = jnp.max @@ -285,8 +285,8 @@ def run_benchmarks_step_size(batch_size): results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], + # ["adjusted_mclmc", "nuts", "mclmc", ], + ["adjusted_mclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [StandardNormal(10)], # [Brownian()], @@ -312,9 +312,9 @@ def run_benchmarks_step_size(batch_size): # for L in np.linspace(1, 10, 41): key1, key2, key3, key = jax.random.split(key, 4) initial_position = model.sample_init(key2) - initial_state = blackjax.mcmc.mhmclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc.init( position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=key3) - ess, grad_calls, params , 
acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(initial_state=initial_state, coefficients=mclachlan_coefficients, step_size=step_size, L= 5*step_size, std_mat=1.),key1, n=num_steps, batch=num_chains, contract=contract) + ess, grad_calls, params , acceptance_rate, _ = benchmark_chains(model, run_adjusted_mclmc_no_tuning(initial_state=initial_state, coefficients=mclachlan_coefficients, step_size=step_size, L= 5*step_size, std_mat=1.),key1, n=num_steps, batch=num_chains, contract=contract) # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") # jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) @@ -351,8 +351,8 @@ def benchmark_mhmchmc(batch_size): for model, coeffs in itertools.product(models, coefficients): num_chains = batch_size # 1 + batch_size//model.ndims - print(f"NUMBER OF CHAINS for {model.name} and MHMCLMC is {num_chains}") - num_steps = models[model]["mhmclmc"] + print(f"NUMBER OF CHAINS for {model.name} and adjusted_mclmc is {num_chains}") + num_steps = models[model]["adjusted_mclmc"] print(f"NUMBER OF STEPS for {model.name} and MHCMLMC is {num_steps}") ####### run mclmc with standard tuning @@ -371,41 +371,41 @@ def benchmark_mhmchmc(batch_size): print(f'mclmc with tuning ESS {ess}') - ####### run mhmclmc with standard tuning + ####### run adjusted_mclmc with standard tuning for target_acc_rate in [0.65, 0.9]: # coeffs = mclachlan_coefficients ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( model, - partial(run_mhmclmc, target_acc_rate=target_acc_rate, coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0), + partial(run_adjusted_mclmc, target_acc_rate=target_acc_rate, coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0), key1, n=num_steps, batch=num_chains, contract=contract) results[(model.name, model.ndims, "mhmchmc"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'mhmclmc with tuning ESS {ess}') + print(f'adjusted_mclmc with tuning ESS {ess}') # coeffs = mclachlan_coefficients ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( model, - partial(run_mhmclmc, target_acc_rate=target_acc_rate,coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1), + partial(run_adjusted_mclmc, target_acc_rate=target_acc_rate,coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1), key1, n=num_steps, batch=num_chains, contract=contract) results[(model.name, model.ndims, "mhmchmc:st3"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'mhmclmc with tuning ESS {ess}') + print(f'adjusted_mclmc with tuning ESS {ess}') if True: - ####### run mhmclmc with standard tuning + grid search + ####### run adjusted_mclmc with standard tuning + grid search init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split(key2, 5) initial_position = model.sample_init(init_pos_key) - initial_state = blackjax.mcmc.mhmclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc.init( position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( 
integrator=generate_isokinetic_integrator(coeffs), integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), std_mat=std_mat, @@ -417,9 +417,9 @@ def benchmark_mhmchmc(batch_size): ( state, - blackjax_mhmclmc_sampler_params, + blackjax_adjusted_mclmc_sampler_params, _, _ - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( + ) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state, @@ -432,14 +432,14 @@ def benchmark_mhmchmc(batch_size): ) print(f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}") - print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") + print(f"params after initial tuning are L={blackjax_adjusted_mclmc_sampler_params.L}, step_size={blackjax_adjusted_mclmc_sampler_params.step_size}") - L, step_size, convergence = gridsearch_tune(grid_key, iterations=10, contract=contract, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coeffs, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) + L, step_size, convergence = gridsearch_tune(grid_key, iterations=10, contract=contract, grid_size=5, model=model, sampler=partial(run_adjusted_mclmc_no_tuning, coefficients=coeffs, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_adjusted_mclmc_sampler_params.L, center_step_size=blackjax_adjusted_mclmc_sampler_params.step_size) # print(f"params after grid tuning are L={L}, step_size={step_size}") - ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coeffs, L=L, step_size=step_size, initial_state=state, std_mat=1.),bench_key, n=num_steps, batch=num_chains, contract=contract) + ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, run_adjusted_mclmc_no_tuning(coefficients=coeffs, L=L, step_size=step_size, initial_state=state, std_mat=1.),bench_key, n=num_steps, batch=num_chains, contract=contract) print(f"grads to low bias: {grad_calls}") @@ -474,7 +474,7 @@ def benchmark_omelyan(batch_size): key = jax.random.PRNGKey(2) results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], + # ["adjusted_mclmc", "nuts", "mclmc", ], ["mhmchmc"], [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int)], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 5)).astype(int)], @@ -499,12 +499,12 @@ def benchmark_omelyan(batch_size): initial_position = model.sample_init(init_pos_key) - initial_state = blackjax.mcmc.mhmclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc.init( position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=generate_isokinetic_integrator(coefficients), integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), std_mat=std_mat, @@ -516,9 +516,9 @@ def benchmark_omelyan(batch_size): ( state, - blackjax_mhmclmc_sampler_params, + 
blackjax_adjusted_mclmc_sampler_params, _, _ - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( + ) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state, @@ -531,17 +531,17 @@ def benchmark_omelyan(batch_size): ) print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") + print(f"params after initial tuning are L={blackjax_adjusted_mclmc_sampler_params.L}, step_size={blackjax_adjusted_mclmc_sampler_params.step_size}") - # ess, grad_calls, _ , _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) + # ess, grad_calls, _ , _ = benchmark_chains(model, run_adjusted_mclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) - L, step_size, converged = gridsearch_tune(grid_key, iterations=10, contract=jnp.average, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coefficients, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) + L, step_size, converged = gridsearch_tune(grid_key, iterations=10, contract=jnp.average, grid_size=5, model=model, sampler=partial(run_adjusted_mclmc_no_tuning, coefficients=coefficients, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_adjusted_mclmc_sampler_params.L, center_step_size=blackjax_adjusted_mclmc_sampler_params.step_size) print(f"params after grid tuning are L={L}, step_size={step_size}") - ess, grad_calls, _ , _, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=L, step_size=step_size, std_mat=1., initial_state=state),bench_key, n=num_steps, batch=num_chains, contract=jnp.average) + ess, grad_calls, _ , _, _ = benchmark_chains(model, run_adjusted_mclmc_no_tuning(coefficients=coefficients, L=L, step_size=step_size, std_mat=1., initial_state=state),bench_key, n=num_steps, batch=num_chains, contract=jnp.average) print(f"grads to low bias: {grad_calls}") diff --git a/blackjax/benchmarks/mcmc/inference_models.py b/blackjax/benchmarks/mcmc/inference_models.py index b918ce3bf..b3f87e9ea 100644 --- a/blackjax/benchmarks/mcmc/inference_models.py +++ b/blackjax/benchmarks/mcmc/inference_models.py @@ -874,19 +874,19 @@ def step(track, useless): models = { - # Cauchy(100) : {'mclmc': 2000, 'mhmclmc' : 2000, 'nuts': 2000}, - # StandardNormal(100) : {'mclmc': 10000, 'mhmclmc' : 10000, 'nuts': 10000}, - # Banana() : {'mclmc': 10000, 'mhmclmc' : 10000, 'nuts': 10000}, - Brownian() : {'mclmc': 20000, 'mhmclmc' : 80000, 'nuts': 40000}, + # Cauchy(100) : {'mclmc': 2000, 'adjusted_mclmc' : 2000, 'nuts': 2000}, + # StandardNormal(100) : {'mclmc': 10000, 'adjusted_mclmc' : 10000, 'nuts': 10000}, + # Banana() : {'mclmc': 10000, 'adjusted_mclmc' : 10000, 'nuts': 10000}, + Brownian() : 
{'mclmc': 20000, 'adjusted_mclmc' : 80000, 'nuts': 40000}, # 'banana': Banana(), - # 'icg' : (IllConditionedGaussian(10, 2), {'mclmc': 2000, 'mhmclmc' : 2000, 'nuts': 2000}), - # GermanCredit(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000}, - # ItemResponseTheory(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000}, - # StochasticVolatility(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000} + # 'icg' : (IllConditionedGaussian(10, 2), {'mclmc': 2000, 'adjusted_mclmc' : 2000, 'nuts': 2000}), + # GermanCredit(): {'mclmc': 20000, 'adjusted_mclmc' : 40000, 'nuts': 20000}, + # ItemResponseTheory(): {'mclmc': 20000, 'adjusted_mclmc' : 40000, 'nuts': 20000}, + # StochasticVolatility(): {'mclmc': 20000, 'adjusted_mclmc' : 40000, 'nuts': 20000} } -# models = {'Brownian Motion': (Brownian(), {'mclmc': 50000, 'mhmclmc' : 40000, 'nuts': 1000}), -# # 'Item Response Theory': (ItemResponseTheory(), {'mclmc': 50000, 'mhmclmc' : 50000, 'nuts': 1000}) +# models = {'Brownian Motion': (Brownian(), {'mclmc': 50000, 'adjusted_mclmc' : 40000, 'nuts': 1000}), +# # 'Item Response Theory': (ItemResponseTheory(), {'mclmc': 50000, 'adjusted_mclmc' : 50000, 'nuts': 1000}) # } \ No newline at end of file diff --git a/blackjax/benchmarks/mcmc/sampling_algorithms.py b/blackjax/benchmarks/mcmc/sampling_algorithms.py index 617797b79..bbd7e3e59 100644 --- a/blackjax/benchmarks/mcmc/sampling_algorithms.py +++ b/blackjax/benchmarks/mcmc/sampling_algorithms.py @@ -6,7 +6,8 @@ import jax.numpy as jnp import blackjax # from blackjax.adaptation.window_adaptation import da_adaptation -from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator +from blackjax.mcmc.adjusted_mclmc import rescale +from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator, integrator_order # from blackjax.mcmc.adjusted_mclmc import rescale from blackjax.util import run_inference_algorithm import blackjax @@ -14,7 +15,7 @@ __all__ = ["samplers"] - +target_acceptance_rate_of_order = {2 : 0.65, 4: 0.8} def run_nuts( coefficients, logdensity_fn, num_steps, initial_position, transform, key): @@ -104,16 +105,16 @@ def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transfor return samples, blackjax_mclmc_sampler_params, calls_per_integrator_step(coefficients), acceptance_rate, None, None -def run_mhmclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, target_acc_rate=None): +def run_adjusted_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, target_acc_rate=None): integrator = generate_isokinetic_integrator(coefficients) init_key, tune_key, run_key = jax.random.split(key, 3) - initial_state = blackjax.mcmc.mhmclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc.init( position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), std_mat=std_mat, @@ -132,7 +133,7 @@ def run_mhmclmc(coefficients, logdensity_fn, num_steps, initial_position, transf 
blackjax_mclmc_sampler_params, params_history, final_da - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( + ) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state, @@ -151,7 +152,7 @@ def run_mhmclmc(coefficients, logdensity_fn, num_steps, initial_position, transf # jax.debug.print("params {x}", x=(blackjax_mclmc_sampler_params.step_size, blackjax_mclmc_sampler_params.L)) - alg = blackjax.mcmc.mhmclmc.mhmclmc( + alg = blackjax.adjusted_mclmc( logdensity_fn=logdensity_fn, step_size=step_size, integration_steps_fn = lambda key: jnp.ceil(jax.random.uniform(key) * rescale(L/step_size)) , @@ -174,12 +175,12 @@ def run_mhmclmc(coefficients, logdensity_fn, num_steps, initial_position, transf return out, blackjax_mclmc_sampler_params, calls_per_integrator_step(coefficients) * (L/step_size), info.acceptance_rate, params_history, final_da -# we should do at least: mclmc, nuts, unadjusted hmc, mhmclmc, langevin +# we should do at least: mclmc, nuts, unadjusted hmc, adjusted_mclmc, langevin samplers = { 'nuts' : run_nuts, 'mclmc' : run_mclmc, - 'mhmclmc': run_mhmclmc, + 'adjusted_mclmc': run_adjusted_mclmc, } diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 86500ec84..aef0b3f57 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -30,7 +30,7 @@ __all__ = [ "init", "build_kernel", - "mhmclmc", + "adjusted_mclmc", ] def init( @@ -84,7 +84,7 @@ def kernel( key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) - proposal, info, _ = mhmclmc_proposal( + proposal, info, _ = adjusted_mclmc_proposal( # integrators.with_isokinetic_maruyama(integrator(logdensity_fn)), lambda state, step_size, L_prop, key : (integrator(logdensity_fn, std_mat))(state, step_size), step_size, @@ -170,7 +170,7 @@ def init_fn(position: ArrayLike, rng_key: PRNGKey): return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] -def mhmclmc_proposal( +def adjusted_mclmc_proposal( integrator: Callable, step_size: Union[float, ArrayLikeTree], L_proposal: float, From 4330d1897f2ffe1deca59802220d2814745a665b Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 22:27:10 +0200 Subject: [PATCH 28/71] CLEAN --- blackjax/__init__.py | 2 +- blackjax/adaptation/mclmc_adaptation.py | 220 +++-- blackjax/benchmarks/mcmc/benchmark.py | 815 +++++++++++++----- blackjax/benchmarks/mcmc/inference_models.py | 753 +++++++++------- .../benchmarks/mcmc/sampling_algorithms.py | 142 +-- blackjax/mcmc/adjusted_mclmc.py | 83 +- blackjax/mcmc/integrators.py | 20 +- 7 files changed, 1306 insertions(+), 729 deletions(-) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 6afd9454a..4b23614f5 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -11,6 +11,7 @@ from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat +from .mcmc import adjusted_mclmc as _adjusted_mclmc from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc from .mcmc import elliptical_slice as _elliptical_slice @@ -19,7 +20,6 @@ from .mcmc import mala as _mala from .mcmc import marginal_latent_gaussian from .mcmc import mclmc as _mclmc -from .mcmc import adjusted_mclmc as _adjusted_mclmc from .mcmc import nuts as _nuts from .mcmc import periodic_orbital, random_walk from .mcmc import rmhmc as _rmhmc diff --git 
a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 476ea6654..76c481c0a 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -19,7 +19,10 @@ import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from blackjax.adaptation.step_size import DualAveragingAdaptationState, dual_averaging_adaptation +from blackjax.adaptation.step_size import ( + DualAveragingAdaptationState, + dual_averaging_adaptation, +) from blackjax.diagnostics import effective_sample_size from blackjax.util import pytree_size, streaming_average @@ -282,9 +285,8 @@ def step(state, key): return adaptation_L - Lratio_lowerbound = 0.0 -Lratio_upperbound = 2. +Lratio_upperbound = 2.0 def adjusted_mclmc_find_L_and_step_size( @@ -297,7 +299,7 @@ def adjusted_mclmc_find_L_and_step_size( frac_tune2=0.1, frac_tune3=0.1, diagonal_preconditioning=True, - params=None + params=None, ): """ Finds the optimal value of the parameters for the MH-MCHMC algorithm. @@ -334,39 +336,54 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: - params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, std_mat=jnp.ones((dim,))) + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, std_mat=jnp.ones((dim,)) + ) else: params = params # jax.debug.print("initial params {x}", x=params) part1_key, part2_key = jax.random.split(rng_key, 2) - state, params, params_history, final_da_val = adjusted_mclmc_make_L_step_size_adaptation( + ( + state, + params, + params_history, + final_da_val, + ) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, frac_tune2=frac_tune2, target=target, - diagonal_preconditioning=diagonal_preconditioning - )(state, params, num_steps, part1_key) + diagonal_preconditioning=diagonal_preconditioning, + )( + state, params, num_steps, part1_key + ) if frac_tune3 != 0: - part2_key1, part2_key2 = jax.random.split(part2_key, 2) - state, params = adjusted_mclmc_make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)( - state, params, num_steps, part2_key1 + state, params = adjusted_mclmc_make_adaptation_L( + mclmc_kernel, frac=frac_tune3, Lfactor=0.4 + )(state, params, num_steps, part2_key1) + + ( + state, + params, + params_history, + final_da_val, + ) = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=0, + target=target, + fix_L_first_da=True, + diagonal_preconditioning=diagonal_preconditioning, + )( + state, params, num_steps, part2_key2 ) - state, params, params_history, final_da_val = adjusted_mclmc_make_L_step_size_adaptation( - kernel=mclmc_kernel, - dim=dim, - frac_tune1=frac_tune1, - frac_tune2=0, - target=target, - fix_L_first_da=True, - diagonal_preconditioning=diagonal_preconditioning - )(state, params, num_steps, part2_key2) - return state, params, params_history, final_da_val @@ -381,19 +398,20 @@ def adjusted_mclmc_make_L_step_size_adaptation( ): """Adapts the stepsize and L of the MCLMC kernel. 
Designed for the unadjusted MCLMC""" - - - def dual_avg_step(fix_L, update_da): """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" def step(iteration_state, weight_and_key): - mask, rng_key = weight_and_key kernel_key, num_steps_key = jax.random.split(rng_key, 2) - previous_state, params, (adaptive_state, step_size_max), streaming_avg = iteration_state + ( + previous_state, + params, + (adaptive_state, step_size_max), + streaming_avg, + ) = iteration_state - avg_num_integration_steps = params.L/params.step_size + avg_num_integration_steps = params.L / params.step_size state, info = kernel( rng_key=kernel_key, @@ -417,26 +435,28 @@ def step(iteration_state, weight_and_key): # jax.debug.print("info acc rate {x}", x=(info.acceptance_rate,)) # jax.debug.print("state {x}", x=(state.position,)) - log_step_size, log_step_size_avg, step, avg_error, mu = update_da( - adaptive_state, info.acceptance_rate) - + adaptive_state, info.acceptance_rate + ) + adaptive_state = DualAveragingAdaptationState( - mask * log_step_size + (1-mask)*adaptive_state.log_step_size, - mask * log_step_size_avg + (1-mask)*adaptive_state.log_step_size_avg, - mask * step + (1-mask)*adaptive_state.step, - mask * avg_error + (1-mask)*adaptive_state.avg_error, - mask * mu + (1-mask)*adaptive_state.mu, - ) + mask * log_step_size + (1 - mask) * adaptive_state.log_step_size, + mask * log_step_size_avg + + (1 - mask) * adaptive_state.log_step_size_avg, + mask * step + (1 - mask) * adaptive_state.step, + mask * avg_error + (1 - mask) * adaptive_state.avg_error, + mask * mu + (1 - mask) * adaptive_state.mu, + ) # jax.debug.print("{x} step_size before",x=(adaptive_state.log_step_size, info.acceptance_rate,)) # adaptive_state = update(adaptive_state, info.acceptance_rate) # jax.debug.print("{x} step_size after",x=(adaptive_state.log_step_size,)) - # step_size = jax.lax.clamp(1e-3, jnp.exp(adaptive_state.log_step_size), 1e0) # step_size = jax.lax.clamp(1e-5, jnp.exp(adaptive_state.log_step_size), step_size_max) - step_size = jax.lax.clamp(1e-5, jnp.exp(adaptive_state.log_step_size), params.L/1.1) + step_size = jax.lax.clamp( + 1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1 + ) adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) # step_size = 1e-3 @@ -445,50 +465,47 @@ def step(iteration_state, weight_and_key): O=lambda x: jnp.array([x, jnp.square(x)]), x=ravel_pytree(state.position)[0], streaming_avg=streaming_avg, - weight=(1-mask)*success*step_size, + weight=(1 - mask) * success * step_size, zero_prevention=mask, ) - - - if fix_L: params = params._replace( - step_size=mask * step_size + (1-mask)*params.step_size, + step_size=mask * step_size + (1 - mask) * params.step_size, + # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L + # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L + ) - # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L - # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L - ) - else: - params = params._replace( - step_size=mask * step_size + (1-mask)*params.step_size, - - L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L - # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L - ) - # params = params._replace(step_size=step_size, + step_size=mask * step_size + (1 - mask) * params.step_size, + L=mask * (params.L * (step_size / params.step_size)) + + (1 - mask) * params.L + # L=mask * 
((params.L * (step_size / params.step_size))) + (1-mask)*params.L + ) + # params = params._replace(step_size=step_size, # L=(params.L/params.step_size * step_size) # ) + return (state, params, (adaptive_state, step_size_max), streaming_avg), ( + info, + params, + ) - return (state, params, (adaptive_state, step_size_max), streaming_avg), (info, params) return step - + def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da): - return jax.lax.scan( dual_avg_step(fix_L, update_da), init=( state, params, - (initial_da(params.step_size), jnp.inf), # step size max + (initial_da(params.step_size), jnp.inf), # step size max # (init(params.step_size), params.L/4), (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), ), xs=(mask, keys), - ) + ) def L_step_size_adaptation(state, params, num_steps, rng_key): num_steps1, num_steps2 = int(num_steps * frac_tune1), int( @@ -498,7 +515,9 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): # num_steps2=0 rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) - L_step_size_adaptation_keys_pass1 = jax.random.split(rng_key_pass1, num_steps1 + num_steps2) + L_step_size_adaptation_keys_pass1 = jax.random.split( + rng_key_pass1, num_steps1 + num_steps2 + ) L_step_size_adaptation_keys_pass2 = jax.random.split(rng_key_pass2, num_steps1) # determine which steps to ignore in the streaming average @@ -508,17 +527,26 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): # jax.debug.print("{x} initial num steps",x=(params.L/params.step_size)) - ((state, params, (dual_avg_state, step_size_max), (_, average)), (info, params_history)) = step_size_adaptation(mask, state, params, L_step_size_adaptation_keys_pass1, fix_L=fix_L_first_da, initial_da=initial_da, update_da=update_da) + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, params_history), + ) = step_size_adaptation( + mask, + state, + params, + L_step_size_adaptation_keys_pass1, + fix_L=fix_L_first_da, + initial_da=initial_da, + update_da=update_da, + ) # jax.debug.print("final da {x}", x=final_da(dual_avg_state)) - # params = params._replace(L=params.L * (final_da(dual_avg_state)/params.step_size)) - # params = params._replace(step_size=final_da(dual_avg_state)) + # params = params._replace(L=params.L * (final_da(dual_avg_state)/params.step_size)) + # params = params._replace(step_size=final_da(dual_avg_state)) # jax.debug.print("{x} new num steps",x=(params.L/params.step_size)) - # jax.debug.print("{x} mean acceptance rate",x=((jnp.mean(info.acceptance_rate)))) - # jax.debug.print("{x} params after a round of tuning",x=(params)) # jax.debug.print("{x} step size max",x=(step_size_max)) # jax.debug.print("{x} final",x=(final(dual_avg_state))) @@ -528,46 +556,59 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): # determine L if num_steps2 != 0.0: - # if False: + # if False: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) # jax.debug.print("{x} frac tune 2 guess",x=(jnp.sqrt(jnp.sum(variances)))) # jax.debug.print("{x} frac tune 2 before",x=(params.L)) - - - change = jax.lax.clamp(Lratio_lowerbound, jnp.sqrt(jnp.sum(variances))/params.L, Lratio_upperbound) + + change = jax.lax.clamp( + Lratio_lowerbound, + jnp.sqrt(jnp.sum(variances)) / params.L, + Lratio_upperbound, + ) # change = jnp.sqrt(jnp.sum(variances))/params.L # jax.debug.print("{x} L ratio, old val, new val",x=(change, params.L, params.L*change)) # jax.debug.print("{x} 
variance",x=(jnp.sqrt(jnp.sum(variances)))) - params = params._replace(L=params.L*change, step_size=params.step_size*change) + params = params._replace( + L=params.L * change, step_size=params.step_size * change + ) # params = params._replace(L=16.) - # params = params._replace(L=jnp.sqrt(jnp.sum(variances))) + # params = params._replace(L=jnp.sqrt(jnp.sum(variances))) # jax.debug.print("{x} params after a round of tuning",x=(params)) if diagonal_preconditioning: + # diagonal preconditioning + params = params._replace(std_mat=jnp.sqrt(variances)) - # diagonal preconditioning - params = params._replace(std_mat=jnp.sqrt(variances)) + # state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] + # dyn, _, hyp, adap, kalman_state = state - # state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] - # dyn, _, hyp, adap, kalman_state = state - - # jax.debug.print("{x} params before second round",x=(params)) # jax.debug.print("{x}",x=("L before", params.L)) # jax.debug.print("{x}",x=("target", target)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) - ((state, params, (dual_avg_state, step_size_max), (_, average)), (info, params_history)) = step_size_adaptation(jnp.ones(num_steps1), state, params, L_step_size_adaptation_keys_pass2, fix_L=True, update_da=update_da, initial_da=initial_da) - # params = params._replace(L=params.L * (final_da(dual_avg_state)/params.step_size)) + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, params_history), + ) = step_size_adaptation( + jnp.ones(num_steps1), + state, + params, + L_step_size_adaptation_keys_pass2, + fix_L=True, + update_da=update_da, + initial_da=initial_da, + ) + # params = params._replace(L=params.L * (final_da(dual_avg_state)/params.step_size)) # params = params._replace(step_size=final_da(dual_avg_state)) # jax.debug.print("{x} mean acceptance rate 2",x=(jnp.mean(info.acceptance_rate,))) # jax.debug.print("{x}",x=("L after", params.L)) # jax.debug.print("{x} params after a round of tuning",x=(params)) return state, params, params_history.step_size, final_da(dual_avg_state) - - return L_step_size_adaptation + return L_step_size_adaptation def adjusted_mclmc_make_adaptation_L(kernel, frac, Lfactor): @@ -579,13 +620,12 @@ def adaptation_L(state, params, num_steps, key): # jax.debug.print("tune 1\n\n {x}", x=(params.L, params.step_size)) - def step(state, key): next_state, _ = kernel( rng_key=key, state=state, step_size=params.step_size, - avg_num_integration_steps=params.L/params.step_size, + avg_num_integration_steps=params.L / params.step_size, std_mat=params.std_mat, ) return next_state, next_state.position @@ -596,26 +636,26 @@ def step(state, key): xs=adaptation_L_keys, ) - flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) ess = effective_sample_size(flat_samples[None, ...]) - change = jax.lax.clamp(Lratio_lowerbound, (Lfactor * params.step_size * jnp.mean(num_steps / ess))/params.L, Lratio_upperbound) + change = jax.lax.clamp( + Lratio_lowerbound, + (Lfactor * params.step_size * jnp.mean(num_steps / ess)) / params.L, + Lratio_upperbound, + ) # change = (Lfactor * params.step_size * jnp.mean(num_steps / ess))/params.L # jax.debug.print("tune 3\n\n {x}", x=(params.L*change, change)) return state, params._replace( # L=Lfactor * params.step_size * jnp.mean(num_steps / ess) - L=params.L*change + L=params.L + * change ) return adaptation_L - - - - def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change): """if 
there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" diff --git a/blackjax/benchmarks/mcmc/benchmark.py b/blackjax/benchmarks/mcmc/benchmark.py index fd1e8a2a8..062e1f2e3 100644 --- a/blackjax/benchmarks/mcmc/benchmark.py +++ b/blackjax/benchmarks/mcmc/benchmark.py @@ -1,13 +1,14 @@ # mypy: ignore-errors # flake8: noqa -from collections import defaultdict -from functools import partial import math import operator import os import pprint +from collections import defaultdict +from functools import partial from statistics import mean, median + import jax import jax.numpy as jnp import pandas as pd @@ -15,7 +16,7 @@ from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState -os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=' + str(128) +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=" + str(128) num_cores = jax.local_device_count() # print(num_cores, jax.lib.xla_bridge.get_backend().platform) @@ -24,48 +25,81 @@ import numpy as np import blackjax -from blackjax.benchmarks.mcmc.sampling_algorithms import run_mclmc, run_adjusted_mclmc, run_nuts, samplers -from blackjax.benchmarks.mcmc.inference_models import Brownian, GermanCredit, ItemResponseTheory, MixedLogit, StandardNormal, StochasticVolatility, models -from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator, integrator_order, isokinetic_mclachlan, mclachlan_coefficients, name_integrator, omelyan_coefficients, velocity_verlet, velocity_verlet_coefficients, yoshida_coefficients +from blackjax.benchmarks.mcmc.inference_models import ( + Brownian, + GermanCredit, + ItemResponseTheory, + MixedLogit, + StandardNormal, + StochasticVolatility, + models, +) +from blackjax.benchmarks.mcmc.sampling_algorithms import ( + run_adjusted_mclmc, + run_mclmc, + run_nuts, + samplers, +) +from blackjax.mcmc.integrators import ( + calls_per_integrator_step, + generate_euclidean_integrator, + generate_isokinetic_integrator, + integrator_order, + isokinetic_mclachlan, + mclachlan_coefficients, + name_integrator, + omelyan_coefficients, + velocity_verlet, + velocity_verlet_coefficients, + yoshida_coefficients, +) + # from blackjax.mcmc.adjusted_mclmc import rescale from blackjax.util import run_inference_algorithm -target_acceptance_rate_of_order = {2 : 0.65, 4: 0.8} +target_acceptance_rate_of_order = {2: 0.65, 4: 0.8} + def get_num_latents(target): - return target.ndims + return target.ndims + + # return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0])))) def err(f_true, var_f, contract): """Computes the error b^2 = (f - f_true)^2 / var_f - Args: - f: E_sampler[f(x)], can be a vector - f_true: E_true[f(x)] - var_f: Var_true[f(x)] - contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max - - Returns: - contract(b^2) - """ - - return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) + Args: + f: E_sampler[f(x)], can be a vector + f_true: E_true[f(x)] + var_f: Var_true[f(x)] + contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max + Returns: + contract(b^2) + """ + + return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) -def grads_to_low_error(err_t, grad_evals_per_step= 1, low_error= 0.01): +def grads_to_low_error(err_t, grad_evals_per_step=1, low_error=0.01): """Uses the error of the expectation values to compute the effective sample size neff - b^2 = 
1/neff""" - + b^2 = 1/neff""" + cutoff_reached = err_t[-1] < low_error return find_crossing(err_t, low_error) * grad_evals_per_step, cutoff_reached - - -def calculate_ess(err_t, grad_evals_per_step, neff= 100): - - grads_to_low, cutoff_reached = grads_to_low_error(err_t, grad_evals_per_step, 1./neff) - - return (neff / grads_to_low) * cutoff_reached, grads_to_low*(1/cutoff_reached), cutoff_reached + + +def calculate_ess(err_t, grad_evals_per_step, neff=100): + grads_to_low, cutoff_reached = grads_to_low_error( + err_t, grad_evals_per_step, 1.0 / neff + ) + + return ( + (neff / grads_to_low) * cutoff_reached, + grads_to_low * (1 / cutoff_reached), + cutoff_reached, + ) def find_crossing(array, cutoff): @@ -77,34 +111,61 @@ def find_crossing(array, cutoff): print("\n\n\nNO CROSSING FOUND!!!\n\n\n", array, cutoff) return 1 - return jnp.max(indices)+1 + return jnp.max(indices) + 1 def cumulative_avg(samples): - return jnp.cumsum(samples, axis = 0) / jnp.arange(1, samples.shape[0] + 1)[:, None] - - -def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps, center_L, center_step_size, contract): + return jnp.cumsum(samples, axis=0) / jnp.arange(1, samples.shape[0] + 1)[:, None] + + +def gridsearch_tune( + key, + iterations, + grid_size, + model, + sampler, + batch, + num_steps, + center_L, + center_step_size, + contract, +): results = defaultdict(float) converged = False - keys = jax.random.split(key, iterations+1) + keys = jax.random.split(key, iterations + 1) for i in range(iterations): print(f"EPOCH {i}") width = 2 - step_sizes = np.logspace(np.log10(center_step_size/width), np.log10(center_step_size*width), grid_size) - Ls = np.logspace(np.log10(center_L/2), np.log10(center_L*2),grid_size) + step_sizes = np.logspace( + np.log10(center_step_size / width), + np.log10(center_step_size * width), + grid_size, + ) + Ls = np.logspace(np.log10(center_L / 2), np.log10(center_L * 2), grid_size) # print(list(itertools.product(step_sizes , Ls))) - grid_keys = jax.random.split(keys[i], grid_size^2) + grid_keys = jax.random.split(keys[i], grid_size ^ 2) print(f"center step size {center_step_size}, center L {center_L}") - for j, (step_size, L) in enumerate(itertools.product(step_sizes , Ls)): - ess, grad_calls_until_convergence, _ , _, _ = benchmark_chains(model, sampler(step_size=step_size, L=L), grid_keys[j], n=num_steps, batch = batch, contract=contract) + for j, (step_size, L) in enumerate(itertools.product(step_sizes, Ls)): + ess, grad_calls_until_convergence, _, _, _ = benchmark_chains( + model, + sampler(step_size=step_size, L=L), + grid_keys[j], + n=num_steps, + batch=batch, + contract=contract, + ) results[(step_size, L)] = (ess, grad_calls_until_convergence) - best_ess, best_grads, (step_size, L) = max([(results[r][0], results[r][1], r) for r in results], key=operator.itemgetter(0)) + best_ess, best_grads, (step_size, L) = max( + ((results[r][0], results[r][1], r) for r in results), + key=operator.itemgetter(0), + ) # raise Exception - print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - if L==center_L and step_size==center_step_size: + print( + f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}" + ) + if L == center_L and step_size == center_step_size: print("converged") converged = True break @@ -112,42 +173,57 @@ def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps center_L, center_step_size = L, step_size 
pprint.pp(results) - # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") + # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") + # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") return center_L, center_step_size, converged def run_adjusted_mclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): - def s(logdensity_fn, num_steps, initial_position, transform, key): - integrator = generate_isokinetic_integrator(coefficients) - num_steps_per_traj = L/step_size + num_steps_per_traj = L / step_size alg = blackjax.mcmc.adjusted_mclmc.adjusted_mclmc( - logdensity_fn=logdensity_fn, - step_size=step_size, - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(num_steps_per_traj)) , - integrator=integrator, - std_mat=std_mat, + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(num_steps_per_traj) + ), + integrator=integrator, + std_mat=std_mat, ) _, out, info = run_inference_algorithm( - rng_key=key, - initial_state=initial_state, - inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True) + rng_key=key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True, + ) - return out, MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), num_steps_per_traj * calls_per_integrator_step(coefficients), info.acceptance_rate.mean(), None, jnp.array([0]) + return ( + out, + MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), + num_steps_per_traj * calls_per_integrator_step(coefficients), + info.acceptance_rate.mean(), + None, + jnp.array([0]), + ) return s -def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.average,): +def benchmark_chains( + model, + sampler, + key, + n=10000, + batch=None, + contract=jnp.average, +): pvmap = jax.pmap - + d = get_num_latents(model) if batch is None: batch = np.ceil(1000 / d).astype(int) @@ -158,13 +234,31 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av init_pos = pvmap(model.sample_init)(init_keys) # samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys) - samples, params, grad_calls_per_traj, acceptance_rate, step_size_over_da, final_da = pvmap(lambda pos, key: sampler(logdensity_fn=model.logdensity_fn, num_steps=n, initial_position= pos,transform= model.transform, key=key))(init_pos, keys) + ( + samples, + params, + grad_calls_per_traj, + acceptance_rate, + step_size_over_da, + final_da, + ) = pvmap( + lambda pos, key: sampler( + logdensity_fn=model.logdensity_fn, + num_steps=n, + initial_position=pos, + transform=model.transform, + key=key, + ) + )( + init_pos, keys + ) avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0) try: - print(jnp.nanmean(params.step_size,axis=0), jnp.nanmean(params.L,axis=0)) - except: pass - - full = lambda arr : err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) + print(jnp.nanmean(params.step_size, axis=0), jnp.nanmean(params.L, axis=0)) + except: + pass + + full = lambda arr: err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) err_t = 
pvmap(full)(samples**2) # outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t] @@ -174,7 +268,6 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av # return(mean(esses), mean(grad_calls)) # print(final_da.mean(), "final da") - err_t_median = jnp.median(err_t, axis=0) # import matplotlib.pyplot as plt # plt.plot(np.arange(1, 1+ len(err_t_median))* 2, err_t_median, color= 'teal', lw = 3) @@ -184,62 +277,106 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av # plt.yscale('log') # plt.savefig('brownian.png') # plt.close() - esses, grad_calls, _ = calculate_ess(err_t_median, grad_evals_per_step=avg_grad_calls_per_traj) - return esses, grad_calls, params, jnp.mean(acceptance_rate, axis=0), step_size_over_da - - + esses, grad_calls, _ = calculate_ess( + err_t_median, grad_evals_per_step=avg_grad_calls_per_traj + ) + return ( + esses, + grad_calls, + params, + jnp.mean(acceptance_rate, axis=0), + step_size_over_da, + ) def run_benchmarks(batch_size): - results = defaultdict(tuple) for variables in itertools.product( - # ["adjusted_mclmc", "nuts", "mclmc", ], - ["adjusted_mclmc"], + # ["adjusted_mclmc", "nuts", "mclmc", ], + ["adjusted_mclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [Brownian()], # [Brownian()], # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): sampler, model, coefficients = variables - num_chains = batch_size#1 + batch_size//model.ndims - + num_chains = batch_size # 1 + batch_size//model.ndims num_steps = 100000 sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) contract = jnp.max key = jax.random.PRNGKey(11) for i in range(1): key1, key = jax.random.split(key) - ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients, frac_tune1=0.1, frac_tune2=0.0, frac_tune3=0.0),key1, n=num_steps, batch=num_chains, contract=contract) + ( + ess, + grad_calls, + params, + acceptance_rate, + step_size_over_da, + ) = benchmark_chains( + model, + partial( + samplers[sampler], + coefficients=coefficients, + frac_tune1=0.1, + frac_tune2=0.0, + frac_tune3=0.0, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) jax.numpy.save(f"acceptance.npy", acceptance_rate) - # print(f"grads to low bias: {grad_calls}") # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + results[ + ( + (model.name, model.ndims), + sampler, + name_integrator(coefficients), + "standard", + 
acceptance_rate.mean().item(), + params.L.mean().item(), + params.step_size.mean().item(), + num_chains, + num_steps, + contract, + ) + ] = ess.item() print(ess.item()) # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - # print(results) - df = pd.Series(results).reset_index() - df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] + df.columns = [ + "model", + "sampler", + "integrator", + "tuning", + "acc rate", + "L", + "stepsize", + "num_chains", + "num steps", + "contraction", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results_simple.csv", index=False) @@ -248,19 +385,17 @@ def run_benchmarks(batch_size): def run_simple(): - results = defaultdict(tuple) for variables in itertools.product( - # ["adjusted_mclmc", "nuts", "mclmc", ], - ["adjusted_mclmc"], + # ["adjusted_mclmc", "nuts", "mclmc", ], + ["adjusted_mclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [Brownian()], # [Brownian()], # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): sampler, model, coefficients = variables num_chains = 128 @@ -271,68 +406,127 @@ def run_simple(): key = jax.random.PRNGKey(11) for i in range(1): key1, key = jax.random.split(key) - ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients),key1, n=num_steps, batch=num_chains, contract=contract) + ( + ess, + grad_calls, + params, + acceptance_rate, + step_size_over_da, + ) = benchmark_chains( + model, + partial(samplers[sampler], coefficients=coefficients), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) - - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + results[ + ( + (model.name, model.ndims), + sampler, + name_integrator(coefficients), + "standard", + acceptance_rate.mean().item(), + params.L.mean().item(), + params.step_size.mean().item(), + num_chains, + num_steps, + contract, + ) + ] = ess.item() print(ess.item()) - return results + # vary step_size def run_benchmarks_step_size(batch_size): - results = defaultdict(tuple) for variables in itertools.product( - # ["adjusted_mclmc", "nuts", "mclmc", ], - ["adjusted_mclmc"], + # ["adjusted_mclmc", "nuts", "mclmc", ], + ["adjusted_mclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [StandardNormal(10)], # [Brownian()], # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - - - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): num_steps = 10000 sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: 
{coefficients}\nNumber of chains {num_chains}",) + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) contract = jnp.average center = 6.534974 key = jax.random.PRNGKey(11) - for step_size in np.linspace(center-1,center+1, 41): - # for L in np.linspace(1, 10, 41): + for step_size in np.linspace(center - 1, center + 1, 41): + # for L in np.linspace(1, 10, 41): key1, key2, key3, key = jax.random.split(key, 4) initial_position = model.sample_init(key2) initial_state = blackjax.mcmc.adjusted_mclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=key3) - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains(model, run_adjusted_mclmc_no_tuning(initial_state=initial_state, coefficients=mclachlan_coefficients, step_size=step_size, L= 5*step_size, std_mat=1.),key1, n=num_steps, batch=num_chains, contract=contract) + position=initial_position, + logdensity_fn=model.logdensity_fn, + random_generator_arg=key3, + ) + ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( + model, + run_adjusted_mclmc_no_tuning( + initial_state=initial_state, + coefficients=mclachlan_coefficients, + step_size=step_size, + L=5 * step_size, + std_mat=1.0, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") # jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) # jax.numpy.save(f"acceptance.npy_{step_size}", acceptance_rate) - # print(f"grads to low bias: {grad_calls}") # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + results[ + ( + (model.name, model.ndims), + sampler, + name_integrator(coefficients), + "standard", + acceptance_rate.mean().item(), + params.L.mean().item(), + params.step_size.mean().item(), + num_chains, + num_steps, + contract, + ) + ] = ess.item() # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - # print(results) - df = pd.Series(results).reset_index() - df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] + df.columns = [ + "model", + "sampler", + "integrator", + "tuning", + "acc rate", + "L", + "stepsize", + "num_chains", + "num steps", + "contraction", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results_step_size.csv", index=False) @@ -340,17 +534,14 @@ def run_benchmarks_step_size(batch_size): return results - def benchmark_mhmchmc(batch_size): - key0, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(5), 4) results = defaultdict(tuple) # coefficients = [yoshida_coefficients, mclachlan_coefficients, velocity_verlet_coefficients, omelyan_coefficients] coefficients = [mclachlan_coefficients, velocity_verlet_coefficients] for model, coeffs in itertools.product(models, coefficients): - - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims print(f"NUMBER OF CHAINS for {model.name} and adjusted_mclmc is {num_chains}") num_steps = models[model]["adjusted_mclmc"] print(f"NUMBER OF STEPS for {model.name} and MHCMLMC is {num_steps}") @@ -358,67 
+549,123 @@ def benchmark_mhmchmc(batch_size): ####### run mclmc with standard tuning contract = jnp.max - - ess, grad_calls, params , _, step_size_over_da = benchmark_chains( + ess, grad_calls, params, _, step_size_over_da = benchmark_chains( model, - partial(run_mclmc,coefficients=coeffs), + partial(run_mclmc, coefficients=coeffs), key0, n=num_steps, batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mclmc", params.L.mean().item(), params.step_size.mean().item(), name_integrator(coeffs), "standard", 1.)] = ess.item() - print(f'mclmc with tuning ESS {ess}') - + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "mclmc", + params.L.mean().item(), + params.step_size.mean().item(), + name_integrator(coeffs), + "standard", + 1.0, + ) + ] = ess.item() + print(f"mclmc with tuning ESS {ess}") - ####### run adjusted_mclmc with standard tuning + ####### run adjusted_mclmc with standard tuning for target_acc_rate in [0.65, 0.9]: # coeffs = mclachlan_coefficients - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( - model, - partial(run_adjusted_mclmc, target_acc_rate=target_acc_rate, coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0), - key1, - n=num_steps, - batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mhmchmc"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'adjusted_mclmc with tuning ESS {ess}') - + ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( + model, + partial( + run_adjusted_mclmc, + target_acc_rate=target_acc_rate, + coefficients=coeffs, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.0, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "mhmchmc" + str(target_acc_rate), + jnp.nanmean(params.L).item(), + jnp.nanmean(params.step_size).item(), + name_integrator(coeffs), + "standard", + acceptance_rate.mean().item(), + ) + ] = ess.item() + print(f"adjusted_mclmc with tuning ESS {ess}") + # coeffs = mclachlan_coefficients - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( - model, - partial(run_adjusted_mclmc, target_acc_rate=target_acc_rate,coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1), - key1, - n=num_steps, - batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mhmchmc:st3"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'adjusted_mclmc with tuning ESS {ess}') + ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( + model, + partial( + run_adjusted_mclmc, + target_acc_rate=target_acc_rate, + coefficients=coeffs, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "mhmchmc:st3" + str(target_acc_rate), + jnp.nanmean(params.L).item(), + jnp.nanmean(params.step_size).item(), + name_integrator(coeffs), + "standard", + acceptance_rate.mean().item(), + ) + ] = ess.item() + print(f"adjusted_mclmc with tuning ESS {ess}") if True: ####### run adjusted_mclmc with standard tuning + grid search - init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split(key2, 5) + init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split( + key2, 5 
+ ) initial_position = model.sample_init(init_pos_key) initial_state = blackjax.mcmc.adjusted_mclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key + position=initial_position, + logdensity_fn=model.logdensity_fn, + random_generator_arg=init_key, ) kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( - integrator=generate_isokinetic_integrator(coeffs), - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn) + integrator=generate_isokinetic_integrator(coeffs), + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn, + ) ( state, blackjax_adjusted_mclmc_sampler_params, - _, _ + _, + _, ) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -428,96 +675,165 @@ def benchmark_mhmchmc(batch_size): frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, - diagonal_preconditioning=False + diagonal_preconditioning=False, ) - print(f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}") - print(f"params after initial tuning are L={blackjax_adjusted_mclmc_sampler_params.L}, step_size={blackjax_adjusted_mclmc_sampler_params.step_size}") - + print( + f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}" + ) + print( + f"params after initial tuning are L={blackjax_adjusted_mclmc_sampler_params.L}, step_size={blackjax_adjusted_mclmc_sampler_params.step_size}" + ) - L, step_size, convergence = gridsearch_tune(grid_key, iterations=10, contract=contract, grid_size=5, model=model, sampler=partial(run_adjusted_mclmc_no_tuning, coefficients=coeffs, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_adjusted_mclmc_sampler_params.L, center_step_size=blackjax_adjusted_mclmc_sampler_params.step_size) + L, step_size, convergence = gridsearch_tune( + grid_key, + iterations=10, + contract=contract, + grid_size=5, + model=model, + sampler=partial( + run_adjusted_mclmc_no_tuning, + coefficients=coeffs, + initial_state=state, + std_mat=1.0, + ), + batch=num_chains, + num_steps=num_steps, + center_L=blackjax_adjusted_mclmc_sampler_params.L, + center_step_size=blackjax_adjusted_mclmc_sampler_params.step_size, + ) # print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, run_adjusted_mclmc_no_tuning(coefficients=coeffs, L=L, step_size=step_size, initial_state=state, std_mat=1.),bench_key, n=num_steps, batch=num_chains, contract=contract) + ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( + model, + run_adjusted_mclmc_no_tuning( + coefficients=coeffs, + L=L, + step_size=step_size, + initial_state=state, + std_mat=1.0, + ), + bench_key, + n=num_steps, + batch=num_chains, + contract=contract, + ) print(f"grads to low bias: {grad_calls}") - results[(model.name, model.ndims, "mhmchmc:grid", L.item(), step_size.item(), name_integrator(coeffs), f"gridsearch:{convergence}", acceptance_rate.mean().item())] = ess.item() + results[ + ( + model.name, + model.ndims, + "mhmchmc:grid", + L.item(), + step_size.item(), + 
name_integrator(coeffs), + f"gridsearch:{convergence}", + acceptance_rate.mean().item(), + ) + ] = ess.item() ####### run nuts # coeffs = velocity_verlet_coefficients - ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, partial(run_nuts,coefficients=coeffs),key3, n=models[model]["nuts"], batch=num_chains, contract=contract) - results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - - - - - + ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( + model, + partial(run_nuts, coefficients=coeffs), + key3, + n=models[model]["nuts"], + batch=num_chains, + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "nuts", + 0.0, + 0.0, + name_integrator(coeffs), + "standard", + acceptance_rate.mean().item(), + ) + ] = ess.item() - print(results) - df = pd.Series(results).reset_index() - df.columns = ["model", "dims", "sampler", "L", "step_size", "integrator", "tuning", "acc_rate", "ESS"] + df.columns = [ + "model", + "dims", + "sampler", + "L", + "step_size", + "integrator", + "tuning", + "acc_rate", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results.csv", index=False) return results -def benchmark_omelyan(batch_size): - +def benchmark_omelyan(batch_size): key = jax.random.PRNGKey(2) results = defaultdict(tuple) for variables in itertools.product( - # ["adjusted_mclmc", "nuts", "mclmc", ], - ["mhmchmc"], - [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int)], + # ["adjusted_mclmc", "nuts", "mclmc", ], + ["mhmchmc"], + [ + StandardNormal(d) + for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int) + ], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 5)).astype(int)], # models, - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients, omelyan_coefficients], - ): - - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients, omelyan_coefficients], + ): sampler, model, coefficients = variables # num_chains = 1 + batch_size//model.ndims num_chains = batch_size - current_key, key = jax.random.split(key) - init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split(current_key, 5) + current_key, key = jax.random.split(key) + init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split( + current_key, 5 + ) # num_steps = models[model][sampler] num_steps = 1000 - initial_position = model.sample_init(init_pos_key) initial_state = blackjax.mcmc.adjusted_mclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key + position=initial_position, + logdensity_fn=model.logdensity_fn, + random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( - integrator=generate_isokinetic_integrator(coefficients), - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn) + integrator=generate_isokinetic_integrator(coefficients), + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + std_mat=std_mat, + )( + rng_key=rng_key, + 
state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn, + ) ( state, blackjax_adjusted_mclmc_sampler_params, - _, _ + _, + _, ) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -527,57 +843,112 @@ def benchmark_omelyan(batch_size): frac_tune1=0.1, frac_tune2=0.1, # frac_tune3=0.1, - diagonal_preconditioning=False + diagonal_preconditioning=False, ) - print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - print(f"params after initial tuning are L={blackjax_adjusted_mclmc_sampler_params.L}, step_size={blackjax_adjusted_mclmc_sampler_params.step_size}") + print( + f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}", + ) + print( + f"params after initial tuning are L={blackjax_adjusted_mclmc_sampler_params.L}, step_size={blackjax_adjusted_mclmc_sampler_params.step_size}" + ) # ess, grad_calls, _ , _ = benchmark_chains(model, run_adjusted_mclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) - # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) - - L, step_size, converged = gridsearch_tune(grid_key, iterations=10, contract=jnp.average, grid_size=5, model=model, sampler=partial(run_adjusted_mclmc_no_tuning, coefficients=coefficients, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_adjusted_mclmc_sampler_params.L, center_step_size=blackjax_adjusted_mclmc_sampler_params.step_size) + # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) + + L, step_size, converged = gridsearch_tune( + grid_key, + iterations=10, + contract=jnp.average, + grid_size=5, + model=model, + sampler=partial( + run_adjusted_mclmc_no_tuning, + coefficients=coefficients, + initial_state=state, + std_mat=1.0, + ), + batch=num_chains, + num_steps=num_steps, + center_L=blackjax_adjusted_mclmc_sampler_params.L, + center_step_size=blackjax_adjusted_mclmc_sampler_params.step_size, + ) print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _ , _, _ = benchmark_chains(model, run_adjusted_mclmc_no_tuning(coefficients=coefficients, L=L, step_size=step_size, std_mat=1., initial_state=state),bench_key, n=num_steps, batch=num_chains, contract=jnp.average) + ess, grad_calls, _, _, _ = benchmark_chains( + model, + run_adjusted_mclmc_no_tuning( + coefficients=coefficients, + L=L, + step_size=step_size, + std_mat=1.0, + initial_state=state, + ), + bench_key, + n=num_steps, + batch=num_chains, + contract=jnp.average, + ) print(f"grads to low bias: {grad_calls}") - results[(model.name, model.ndims, sampler, name_integrator(coefficients), converged, L.item(), step_size.item())] = ess.item() + results[ + ( + model.name, + model.ndims, + sampler, + name_integrator(coefficients), + converged, + L.item(), + step_size.item(), + ) + ] = ess.item() df = pd.Series(results).reset_index() - df.columns = ["model", "dims", "sampler", "integrator", "convergence", "L", "step_size", "ESS"] + df.columns = [ + "model", + "dims", + "sampler", + "integrator", + "convergence", + "L", + "step_size", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # 
df.model = df.model.apply(lambda x: x[1]) df.to_csv("omelyan.csv", index=False) def run_benchmarks_divij(): - sampler = run_mclmc - model = StandardNormal(10) # 10 dimensional standard normal + model = StandardNormal(10) # 10 dimensional standard normal coefficients = mclachlan_coefficients - contract = jnp.average # how we average across dimensions + contract = jnp.average # how we average across dimensions num_steps = 2000 num_chains = 100 key1 = jax.random.PRNGKey(2) - ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(sampler, coefficients=coefficients),key1, n=num_steps, batch=num_chains, contract=contract) + ess, grad_calls, params, acceptance_rate, step_size_over_da = benchmark_chains( + model, + partial(sampler, coefficients=coefficients), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) print(f"Effective Sample Size (ESS) of 10D Normal is {ess}") -if __name__ == "__main__": +if __name__ == "__main__": # run_benchmarks_divij() - - # benchmark_mhmchmc(batch_size=128) run_simple() # run_benchmarks_step_size(128) # benchmark_omelyan(128) # run_benchmarks(128) - #benchmark_omelyan(10) + # benchmark_omelyan(10) # print("4") diff --git a/blackjax/benchmarks/mcmc/inference_models.py b/blackjax/benchmarks/mcmc/inference_models.py index b3f87e9ea..2477e2560 100644 --- a/blackjax/benchmarks/mcmc/inference_models.py +++ b/blackjax/benchmarks/mcmc/inference_models.py @@ -1,51 +1,50 @@ # mypy: ignore-errors # flake8: noqa -#from inference_gym import using_jax as gym +import os + +# from inference_gym import using_jax as gym import jax import jax.numpy as jnp import numpy as np -import os -#import numpyro.distributions as dist -dirr = os.path.dirname(os.path.realpath(__file__)) +# import numpyro.distributions as dist +dirr = os.path.dirname(os.path.realpath(__file__)) -class StandardNormal(): +class StandardNormal: """Standard Normal distribution in d dimensions""" def __init__(self, d): self.ndims = d self.E_x2 = jnp.ones(d) self.Var_x2 = 2 * self.E_x2 - self.name = 'StandardNormal' - + self.name = "StandardNormal" def logdensity_fn(self, x): """- log p of the target distribution""" - return -0.5 * jnp.sum(jnp.square(x), axis= -1) - + return -0.5 * jnp.sum(jnp.square(x), axis=-1) def transform(self, x): return x def sample_init(self, key): - return jax.random.normal(key, shape = (self.ndims, )) + return jax.random.normal(key, shape=(self.ndims,)) - -class IllConditionedGaussian(): +class IllConditionedGaussian: """Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2.""" - - def __init__(self, d, condition_number, numpy_seed=None, prior= 'prior'): + def __init__(self, d, condition_number, numpy_seed=None, prior="prior"): """numpy_seed is used to generate a random rotation for the covariance matrix. 
- If None, the covariance matrix is diagonal.""" + If None, the covariance matrix is diagonal.""" self.ndims = d - self.name = 'IllConditionedGaussian' + self.name = "IllConditionedGaussian" self.condition_number = condition_number - eigs = jnp.logspace(-0.5 * jnp.log10(condition_number), 0.5 * jnp.log10(condition_number), d) + eigs = jnp.logspace( + -0.5 * jnp.log10(condition_number), 0.5 * jnp.log10(condition_number), d + ) if numpy_seed == None: # diagonal self.E_x2 = eigs @@ -57,268 +56,296 @@ def __init__(self, d, condition_number, numpy_seed=None, prior= 'prior'): rng = np.random.RandomState(seed=numpy_seed) D = jnp.diag(eigs) inv_D = jnp.diag(1 / eigs) - R, _ = jnp.array(np.linalg.qr(rng.randn(self.ndims, self.ndims))) # random rotation + R, _ = jnp.array( + np.linalg.qr(rng.randn(self.ndims, self.ndims)) + ) # random rotation self.R = R self.Hessian = R @ inv_D @ R.T self.Cov = R @ D @ R.T self.E_x2 = jnp.diagonal(R @ D @ R.T) - #Cov_precond = jnp.diag(1 / jnp.sqrt(self.E_x2)) @ self.Cov @ jnp.diag(1 / jnp.sqrt(self.E_x2)) + # Cov_precond = jnp.diag(1 / jnp.sqrt(self.E_x2)) @ self.Cov @ jnp.diag(1 / jnp.sqrt(self.E_x2)) - #print(jnp.linalg.cond(Cov_precond) / jnp.linalg.cond(self.Cov)) + # print(jnp.linalg.cond(Cov_precond) / jnp.linalg.cond(self.Cov)) self.Var_x2 = 2 * jnp.square(self.E_x2) - self.logdensity_fn = lambda x: -0.5 * x.T @ self.Hessian @ x self.transform = lambda x: x - - if prior == 'map': + if prior == "map": self.sample_init = lambda key: jnp.zeros(self.ndims) - elif prior == 'posterior': - self.sample_init = lambda key: self.R @ (jax.random.normal(key, shape=(self.ndims,)) * jnp.sqrt(eigs)) + elif prior == "posterior": + self.sample_init = lambda key: self.R @ ( + jax.random.normal(key, shape=(self.ndims,)) * jnp.sqrt(eigs) + ) - else: # N(0, sigma_true_max) - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.max(jnp.sqrt(eigs)) + else: # N(0, sigma_true_max) + self.sample_init = lambda key: jax.random.normal( + key, shape=(self.ndims,) + ) * jnp.max(jnp.sqrt(eigs)) - -class IllConditionedESH(): +class IllConditionedESH: """ICG from the ESH paper.""" def __init__(self): self.ndims = 50 - self.name = 'IllConditionedESH' + self.name = "IllConditionedESH" self.variance = jnp.linspace(0.01, 1, self.ndims) - - - def logdensity_fn(self, x): """- log p of the target distribution""" - return -0.5 * jnp.sum(jnp.square(x) / self.variance, axis= -1) - + return -0.5 * jnp.sum(jnp.square(x) / self.variance, axis=-1) def transform(self, x): return x def draw(self, key): - return jax.random.normal(key, shape = (self.ndims, )) * jnp.sqrt(self.variance) + return jax.random.normal(key, shape=(self.ndims,)) * jnp.sqrt(self.variance) def sample_init(self, key): - return jax.random.normal(key, shape = (self.ndims, )) - - + return jax.random.normal(key, shape=(self.ndims,)) -class IllConditionedGaussianGamma(): +class IllConditionedGaussianGamma: """Inference gym's Ill conditioned Gaussian""" - def __init__(self, prior = 'prior'): + def __init__(self, prior="prior"): self.ndims = 100 - self.name = 'IllConditionedGaussianGamma' + self.name = "IllConditionedGaussianGamma" # define the Hessian - rng = np.random.RandomState(seed=10 & (2 ** 32 - 1)) - eigs = np.sort(rng.gamma(shape=0.5, scale=1., size=self.ndims)) #eigenvalues of the Hessian - eigs *= jnp.average(1.0/eigs) + rng = np.random.RandomState(seed=10 & (2**32 - 1)) + eigs = np.sort( + rng.gamma(shape=0.5, scale=1.0, size=self.ndims) + ) # eigenvalues of the Hessian + eigs *= jnp.average(1.0 / eigs) 
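As a reference for how these ill-conditioned targets are put together, below is a compact stand-alone construction of a rotated Gaussian with a prescribed condition number, in the same spirit as IllConditionedGaussian above; the dimension, condition number and seed are illustrative choices only.

import numpy as np
import jax.numpy as jnp

d, condition_number, seed = 4, 100.0, 0

# covariance eigenvalues log-spaced from 1/sqrt(kappa) to sqrt(kappa)
eigs = jnp.logspace(
    -0.5 * jnp.log10(condition_number), 0.5 * jnp.log10(condition_number), d
)

rng = np.random.RandomState(seed)
R, _ = np.linalg.qr(rng.randn(d, d))  # random rotation
R = jnp.array(R)

Cov = R @ jnp.diag(eigs) @ R.T            # covariance matrix
Hessian = R @ jnp.diag(1.0 / eigs) @ R.T  # inverse covariance

logdensity_fn = lambda x: -0.5 * x.T @ Hessian @ x

E_x2 = jnp.diagonal(Cov)         # ground-truth E[x_i^2]
Var_x2 = 2.0 * jnp.square(E_x2)  # for a zero-mean Gaussian, Var[x_i^2] = 2 E[x_i^2]^2

# the covariance reproduces the requested condition number (up to numerical error)
print(jnp.linalg.cond(Cov), logdensity_fn(jnp.ones(d)))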
self.entropy = 0.5 * self.ndims - self.maxmin = (1./jnp.sqrt(eigs[0]), 1./jnp.sqrt(eigs[-1])) - R, _ = np.linalg.qr(rng.randn(self.ndims, self.ndims)) #random rotation + self.maxmin = (1.0 / jnp.sqrt(eigs[0]), 1.0 / jnp.sqrt(eigs[-1])) + R, _ = np.linalg.qr(rng.randn(self.ndims, self.ndims)) # random rotation self.map_to_worst = (R.T)[[0, -1], :] self.Hessian = R @ np.diag(eigs) @ R.T # analytic ground truth moments - self.E_x2 = jnp.diagonal(R @ np.diag(1.0/eigs) @ R.T) + self.E_x2 = jnp.diagonal(R @ np.diag(1.0 / eigs) @ R.T) self.Var_x2 = 2 * jnp.square(self.E_x2) # norm = jnp.diag(1/jnp.sqrt(self.E_x2)) # Sigma = R @ np.diag(1/eigs) @ R.T # reduced = norm @ Sigma @ norm # print(np.linalg.cond(reduced), np.linalg.cond(Sigma)) - + # gradient - - if prior == 'map': + if prior == "map": self.sample_init = lambda key: jnp.zeros(self.ndims) - elif prior == 'posterior': - self.sample_init = lambda key: R @ (jax.random.normal(key, shape=(self.ndims,)) / jnp.sqrt(eigs)) + elif prior == "posterior": + self.sample_init = lambda key: R @ ( + jax.random.normal(key, shape=(self.ndims,)) / jnp.sqrt(eigs) + ) + + else: # N(0, sigma_true_max) + self.sample_init = lambda key: jax.random.normal( + key, shape=(self.ndims,) + ) * jnp.max(1.0 / jnp.sqrt(eigs)) - else: # N(0, sigma_true_max) - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.max(1.0/jnp.sqrt(eigs)) - def logdensity_fn(self, x): """- log p of the target distribution""" return -0.5 * x.T @ self.Hessian @ x def transform(self, x): return x - - -class Banana(): +class Banana: """Banana target fromm the Inference Gym""" - def __init__(self, prior = 'map'): + def __init__(self, prior="map"): self.curvature = 0.03 self.ndims = 2 - self.name = 'Banana' - + self.name = "Banana" + self.transform = lambda x: x - self.E_x2 = jnp.array([100.0, 19.0]) #the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. + self.E_x2 = jnp.array( + [100.0, 19.0] + ) # the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. 
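The two Banana ground-truth moments quoted here can be checked directly from the generative model used in posterior_draw: with x0 = 10 z0 and x1 = 0.03 (x0^2 - 100) + z1 for z ~ N(0, I), E[x0^2] = 100 exactly and E[x1^2] = 0.03^2 * Var(x0^2) + 1 = 0.0009 * 20000 + 1 = 19. The sketch below verifies this by sampling; the 10^6 draws are an illustrative choice, far smaller than the 10^8 used for the file's numbers.

import jax
import jax.numpy as jnp

def banana_draw(key, curvature=0.03):
    # one exact draw from the Banana generative model
    z = jax.random.normal(key, shape=(2,))
    x0 = 10.0 * z[0]
    x1 = curvature * (x0**2 - 100.0) + z[1]
    return jnp.array([x0, x1])

keys = jax.random.split(jax.random.PRNGKey(0), 1_000_000)
x = jax.vmap(banana_draw)(keys)

print(jnp.mean(jnp.square(x), axis=0))  # close to E_x2   = [100.0, 19.0]
print(jnp.var(jnp.square(x), axis=0))   # close to Var_x2 = [20000.0, 4600.9]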
self.Var_x2 = jnp.array([20000.0, 4600.898]) - if prior == 'map': + if prior == "map": self.sample_init = lambda key: jnp.array([0, -100.0 * self.curvature]) - elif prior == 'posterior': + elif prior == "posterior": self.sample_init = lambda key: self.posterior_draw(key) - elif prior == 'prior': - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.array([10.0, 5.0]) * 2 + elif prior == "prior": + self.sample_init = ( + lambda key: jax.random.normal(key, shape=(self.ndims,)) + * jnp.array([10.0, 5.0]) + * 2 + ) else: - raise ValueError('prior = '+prior +' is not defined.') + raise ValueError("prior = " + prior + " is not defined.") def logdensity_fn(self, x): mu2 = self.curvature * (x[0] ** 2 - 100) return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) def posterior_draw(self, key): - z = jax.random.normal(key, shape = (2, )) + z = jax.random.normal(key, shape=(2,)) x0 = 10.0 * z[0] - x1 = self.curvature * (x0 ** 2 - 100) + z[1] + x1 = self.curvature * (x0**2 - 100) + z[1] return jnp.array([x0, x1]) def ground_truth(self): - x = jax.vmap(self.posterior_draw)(jax.random.split(jax.random.PRNGKey(0), 100000000)) + x = jax.vmap(self.posterior_draw)( + jax.random.split(jax.random.PRNGKey(0), 100000000) + ) print(jnp.average(x, axis=0)) print(jnp.average(jnp.square(x), axis=0)) print(jnp.std(jnp.square(x[:, 0])) ** 2, jnp.std(jnp.square(x[:, 1])) ** 2) - - -class Cauchy(): +class Cauchy: """d indpendent copies of the standard Cauchy distribution""" def __init__(self, d): self.ndims = d - self.name = 'Cauchy' - - self.logdensity_fn = lambda x: -jnp.sum(jnp.log(1. + jnp.square(x))) - - self.transform = lambda x: x - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) - + self.name = "Cauchy" + self.logdensity_fn = lambda x: -jnp.sum(jnp.log(1.0 + jnp.square(x))) + self.transform = lambda x: x + self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) -class HardConvex(): - def __init__(self, d, kappa, theta = 0.1): +class HardConvex: + def __init__(self, d, kappa, theta=0.1): """d is the dimension, kappa = condition number, 0 < theta < 1/4""" self.ndims = d - self.name = 'HardConvex' + self.name = "HardConvex" self.theta, self.kappa = theta, kappa - C = jnp.power(d-1, 0.25 - theta) - self.logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x[:-1])) - (0.75 / kappa)* x[-1]**2 + 0.5 * jnp.sum(jnp.cos(C * x[:-1])) / C**2 - + C = jnp.power(d - 1, 0.25 - theta) + self.logdensity_fn = ( + lambda x: -0.5 * jnp.sum(jnp.square(x[:-1])) + - (0.75 / kappa) * x[-1] ** 2 + + 0.5 * jnp.sum(jnp.cos(C * x[:-1])) / C**2 + ) + self.transform = lambda x: x # numerically precomputed variances num_integration = [0.93295, 0.968802, 0.990595, 0.998002, 0.999819] if d == 100: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[0], jnp.ones(1) * 2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[0], jnp.ones(1) * 2.0 * kappa / 3.0) + ) elif d == 300: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[1], jnp.ones(1) * 2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[1], jnp.ones(1) * 2.0 * kappa / 3.0) + ) elif d == 1000: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[2], jnp.ones(1) * 2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[2], jnp.ones(1) * 2.0 * kappa / 3.0) + ) elif d == 3000: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[3], jnp.ones(1) * 
2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[3], jnp.ones(1) * 2.0 * kappa / 3.0) + ) elif d == 10000: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[4], jnp.ones(1) * 2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[4], jnp.ones(1) * 2.0 * kappa / 3.0) + ) else: None - def sample_init(self, key): """Gaussian prior with approximately estimating the variance along each dimension""" - scale = jnp.concatenate((jnp.ones(self.ndims-1), jnp.ones(1) * jnp.sqrt(2.0 * self.kappa / 3.0))) + scale = jnp.concatenate( + (jnp.ones(self.ndims - 1), jnp.ones(1) * jnp.sqrt(2.0 * self.kappa / 3.0)) + ) return jax.random.normal(key, shape=(self.ndims,)) * scale - - -class BiModal(): +class BiModal: """A Gaussian mixture p(x) = f N(x | mu1, sigma1) + (1-f) N(x | mu2, sigma2).""" - def __init__(self, d = 50, mu1 = 0.0, mu2 = 8.0, sigma1 = 1.0, sigma2 = 1.0, f = 0.2): - + def __init__(self, d=50, mu1=0.0, mu2=8.0, sigma1=1.0, sigma2=1.0, f=0.2): self.ndims = d - self.name = 'BiModal' + self.name = "BiModal" - self.mu1 = jnp.insert(jnp.zeros(d-1), 0, mu1) + self.mu1 = jnp.insert(jnp.zeros(d - 1), 0, mu1) self.mu2 = jnp.insert(jnp.zeros(d - 1), 0, mu2) self.sigma1, self.sigma2 = sigma1, sigma2 self.f = f - self.variance = jnp.insert(jnp.ones(d-1) * ((1 - f) * sigma1**2 + f * sigma2**2), 0, (1-f)*(sigma1**2 + mu1**2) + f*(sigma2**2 + mu2**2)) - - + self.variance = jnp.insert( + jnp.ones(d - 1) * ((1 - f) * sigma1**2 + f * sigma2**2), + 0, + (1 - f) * (sigma1**2 + mu1**2) + f * (sigma2**2 + mu2**2), + ) def logdensity_fn(self, x): """- log p of the target distribution""" - N1 = (1.0 - self.f) * jnp.exp(-0.5 * jnp.sum(jnp.square(x - self.mu1), axis= -1) / self.sigma1 ** 2) / jnp.power(2 * jnp.pi * self.sigma1 ** 2, self.ndims * 0.5) - N2 = self.f * jnp.exp(-0.5 * jnp.sum(jnp.square(x - self.mu2), axis= -1) / self.sigma2 ** 2) / jnp.power(2 * jnp.pi * self.sigma2 ** 2, self.ndims * 0.5) + N1 = ( + (1.0 - self.f) + * jnp.exp( + -0.5 * jnp.sum(jnp.square(x - self.mu1), axis=-1) / self.sigma1**2 + ) + / jnp.power(2 * jnp.pi * self.sigma1**2, self.ndims * 0.5) + ) + N2 = ( + self.f + * jnp.exp( + -0.5 * jnp.sum(jnp.square(x - self.mu2), axis=-1) / self.sigma2**2 + ) + / jnp.power(2 * jnp.pi * self.sigma2**2, self.ndims * 0.5) + ) return jnp.log(N1 + N2) - def draw(self, num_samples): """direct sampler from a target""" - X = np.random.normal(size = (num_samples, self.ndims)) + X = np.random.normal(size=(num_samples, self.ndims)) mask = np.random.uniform(0, 1, num_samples) < self.f X[mask, :] = (X[mask, :] * self.sigma2) + self.mu2 X[~mask] = (X[~mask] * self.sigma1) + self.mu1 return X - def transform(self, x): return x def sample_init(self, key): - z = jax.random.normal(key, shape = (self.ndims, )) *self.sigma1 - #z= z.at[0].set(self.mu1 + z[0]) + z = jax.random.normal(key, shape=(self.ndims,)) * self.sigma1 + # z= z.at[0].set(self.mu1 + z[0]) return z -class BiModalEqual(): +class BiModalEqual: """Mixture of two Gaussians, one centered at x = [mu/2, 0, 0, ...], the other at x = [-mu/2, 0, 0, ...]. 
- Both have equal probability mass.""" + Both have equal probability mass.""" def __init__(self, d, mu): - self.ndims = d - self.name = 'BiModalEqual' + self.name = "BiModalEqual" self.mu = mu - - def logdensity_fn(self, x): """- log p of the target distribution""" - return -0.5 * jnp.sum(jnp.square(x), axis= -1) + jnp.log(jnp.cosh(0.5*self.mu*x[0])) - 0.5* self.ndims * jnp.log(2 * jnp.pi) - self.mu**2 / 8.0 - + return ( + -0.5 * jnp.sum(jnp.square(x), axis=-1) + + jnp.log(jnp.cosh(0.5 * self.mu * x[0])) + - 0.5 * self.ndims * jnp.log(2 * jnp.pi) + - self.mu**2 / 8.0 + ) def draw(self, num_samples): """direct sampler from a target""" - X = np.random.normal(size = (num_samples, self.ndims)) + X = np.random.normal(size=(num_samples, self.ndims)) mask = np.random.uniform(0, 1, num_samples) < 0.5 - X[mask, 0] += 0.5*self.mu + X[mask, 0] += 0.5 * self.mu X[~mask, 0] -= 0.5 * self.mu return X @@ -327,82 +354,79 @@ def transform(self, x): return x -class Funnel(): +class Funnel: """Noise-less funnel""" - def __init__(self, d = 20): - + def __init__(self, d=20): self.ndims = d - self.name = 'Funnel' - self.sigma_theta= 3.0 - - self.E_x2 = jnp.ones(d) # the transformed variables are standard Gaussian distributed - self.Var_x2 = 2 * self.E_x2 - + self.name = "Funnel" + self.sigma_theta = 3.0 + self.E_x2 = jnp.ones( + d + ) # the transformed variables are standard Gaussian distributed + self.Var_x2 = 2 * self.E_x2 def logdensity_fn(self, x): - """ - log p of the target distribution - x = [z_0, z_1, ... z_{d-1}, theta] """ + """- log p of the target distribution + x = [z_0, z_1, ... z_{d-1}, theta]""" theta = x[-1] - X = x[..., :- 1] + X = x[..., :-1] - return -0.5* jnp.square(theta / self.sigma_theta) - 0.5 * (self.ndims - 1) * theta - 0.5 * jnp.exp(-theta) * jnp.sum(jnp.square(X), axis = -1) + return ( + -0.5 * jnp.square(theta / self.sigma_theta) + - 0.5 * (self.ndims - 1) * theta + - 0.5 * jnp.exp(-theta) * jnp.sum(jnp.square(X), axis=-1) + ) def inverse_transform(self, xtilde): theta = 3 * xtilde[-1] - return jnp.concatenate((xtilde[:-1] * jnp.exp(0.5 * theta), jnp.ones(1)*theta)) - + return jnp.concatenate( + (xtilde[:-1] * jnp.exp(0.5 * theta), jnp.ones(1) * theta) + ) def transform(self, x): """gaussianization""" xtilde = jnp.empty(x.shape) xtilde = xtilde.at[-1].set(x.T[-1] / 3.0) - xtilde = xtilde.at[:-1].set(x.T[:-1] * jnp.exp(-0.5*x.T[-1])) + xtilde = xtilde.at[:-1].set(x.T[:-1] * jnp.exp(-0.5 * x.T[-1])) return xtilde.T - def sample_init(self, key): - return self.inverse_transform(jax.random.normal(key, shape = (self.ndims, ))) - - - + return self.inverse_transform(jax.random.normal(key, shape=(self.ndims,))) -class Funnel_with_Data(): +class Funnel_with_Data: def __init__(self, d, sigma, minibatch_size, key): - self.ndims = d - self.name = 'Funnel_with_Data' - self.sigma_theta= 3.0 + self.name = "Funnel_with_Data" + self.sigma_theta = 3.0 self.theta_true = 0.0 self.sigma_data = sigma - self.data = self.simulate_data() self.batch = minibatch_size def simulate_data(self): - - norm = jax.random.normal(jax.random.PRNGKey(123), shape = (2*(self.ndims-1), )) - z_true = norm[:self.ndims-1] * jnp.exp(self.theta_true * 0.5) - self.data = z_true + norm[self.ndims-1:] * self.sigma_data - + norm = jax.random.normal(jax.random.PRNGKey(123), shape=(2 * (self.ndims - 1),)) + z_true = norm[: self.ndims - 1] * jnp.exp(self.theta_true * 0.5) + self.data = z_true + norm[self.ndims - 1 :] * self.sigma_data def logdensity_fn(self, x, subset): - """ - log p of the target distribution - x = [z_0, z_1, ... 
z_{d-1}, theta] """ + """- log p of the target distribution + x = [z_0, z_1, ... z_{d-1}, theta]""" theta = x[-1] - z = x[:- 1][subset] + z = x[:-1][subset] prior_theta = jnp.square(theta / self.sigma_theta) - prior_z = jnp.sum(subset) * theta + jnp.exp(-theta) * jnp.sum(jnp.square(z*subset)) - likelihood = jnp.sum(jnp.square((z - self.data)*subset / self.sigma_data)) + prior_z = jnp.sum(subset) * theta + jnp.exp(-theta) * jnp.sum( + jnp.square(z * subset) + ) + likelihood = jnp.sum(jnp.square((z - self.data) * subset / self.sigma_data)) return -0.5 * (prior_theta + prior_z + likelihood) - def transform(self, x): """gaussianization""" return x @@ -410,58 +434,56 @@ def transform(self, x): def sample_init(self, key): key1, key2 = jax.random.split(key) theta = jax.random.normal(key1) * self.sigma_theta - z = jax.random.normal(key2, shape = (self.ndims-1, )) * jnp.exp(theta * 0.5) + z = jax.random.normal(key2, shape=(self.ndims - 1,)) * jnp.exp(theta * 0.5) return jnp.concatenate((z, theta)) - - -class Rosenbrock(): - - def __init__(self, d = 36, Q = 0.1): - +class Rosenbrock: + def __init__(self, d=36, Q=0.1): self.ndims = d - self.name = 'Rosenbrock' + self.name = "Rosenbrock" self.Q = Q - #ground truth moments + # ground truth moments var_x = 2.0 - #these two options were precomputed: + # these two options were precomputed: if Q == 0.1: - var_y = 10.098433122783046 # var_y is computed numerically (see class function compute_variance) + var_y = 10.098433122783046 # var_y is computed numerically (see class function compute_variance) elif Q == 0.5: var_y = 10.498957879911487 else: - raise ValueError('Ground truth moments for Q = ' + str(Q) + ' were not precomputed. Use Q = 0.1 or 0.5.') - - self.variance = jnp.concatenate((var_x * jnp.ones(d//2), var_y * jnp.ones(d//2))) - - + raise ValueError( + "Ground truth moments for Q = " + + str(Q) + + " were not precomputed. Use Q = 0.1 or 0.5." 
+ ) + self.variance = jnp.concatenate( + (var_x * jnp.ones(d // 2), var_y * jnp.ones(d // 2)) + ) def logdensity_fn(self, x): """- log p of the target distribution""" - X, Y = x[..., :self.ndims//2], x[..., self.ndims//2:] - return -0.5 * jnp.sum(jnp.square(X - 1.0) + jnp.square(jnp.square(X) - Y) / self.Q, axis= -1) - - + X, Y = x[..., : self.ndims // 2], x[..., self.ndims // 2 :] + return -0.5 * jnp.sum( + jnp.square(X - 1.0) + jnp.square(jnp.square(X) - Y) / self.Q, axis=-1 + ) def draw(self, num): n = self.ndims // 2 - X= np.empty((num, self.ndims)) - X[:, :n] = np.random.normal(loc= 1.0, scale= 1.0, size= (num, n)) - X[:, n:] = np.random.normal(loc= jnp.square(X[:, :n]), scale= jnp.sqrt(self.Q), size= (num, n)) + X = np.empty((num, self.ndims)) + X[:, :n] = np.random.normal(loc=1.0, scale=1.0, size=(num, n)) + X[:, n:] = np.random.normal( + loc=jnp.square(X[:, :n]), scale=jnp.sqrt(self.Q), size=(num, n) + ) return X - def transform(self, x): return x - def sample_init(self, key): - return jax.random.normal(key, shape = (self.ndims, )) - + return jax.random.normal(key, shape=(self.ndims,)) def ground_truth(self): num = 100000000 @@ -474,13 +496,12 @@ def ground_truth(self): x1 = np.average(x) y1 = np.average(y) - print(np.sqrt(0.5*(np.square(np.std(x)) + np.square(np.std(y))))) + print(np.sqrt(0.5 * (np.square(np.std(x)) + np.square(np.std(y))))) print(x2, y2) - -class Brownian(): +class Brownian: """ log sigma_i ~ N(0, 2) log sigma_obs ~N(0, 2) @@ -493,36 +514,75 @@ class Brownian(): def __init__(self): self.num_data = 30 - self.name = 'Brownian' + self.name = "Brownian" self.ndims = self.num_data + 2 - ground_truth_moments = jnp.load(dirr + '/ground_truth/brownian/ground_truth.npy') + ground_truth_moments = jnp.load( + dirr + "/ground_truth/brownian/ground_truth.npy" + ) self.E_x2, self.Var_x2 = ground_truth_moments[0], ground_truth_moments[1] - self.data = jnp.array([0.21592641, 0.118771404, -0.07945447, 0.037677474, -0.27885845, -0.1484156, -0.3250906, -0.22957903, - -0.44110894, -0.09830782, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.8786016, -0.83736074, - -0.7384849, -0.8939254, -0.7774566, -0.70238715, -0.87771565, -0.51853573, -0.6948214, -0.6202789]) + self.data = jnp.array( + [ + 0.21592641, + 0.118771404, + -0.07945447, + 0.037677474, + -0.27885845, + -0.1484156, + -0.3250906, + -0.22957903, + -0.44110894, + -0.09830782, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + -0.8786016, + -0.83736074, + -0.7384849, + -0.8939254, + -0.7774566, + -0.70238715, + -0.87771565, + -0.51853573, + -0.6948214, + -0.6202789, + ] + ) # sigma_obs = 0.15, sigma_i = 0.1 self.observable = jnp.concatenate((jnp.ones(10), jnp.zeros(10), jnp.ones(10))) self.num_observable = jnp.sum(self.observable) # = 20 - def logdensity_fn(self, x): # y = softplus_to_log(x[:2]) - lik = 0.5 * jnp.exp(-2 * x[1]) * jnp.sum(self.observable * jnp.square(x[2:] - self.data)) + x[ - 1] * self.num_observable - prior_x = 0.5 * jnp.exp(-2 * x[0]) * (x[2] ** 2 + jnp.sum(jnp.square(x[3:] - x[2:-1]))) + x[0] * self.num_data + lik = ( + 0.5 + * jnp.exp(-2 * x[1]) + * jnp.sum(self.observable * jnp.square(x[2:] - self.data)) + + x[1] * self.num_observable + ) + prior_x = ( + 0.5 + * jnp.exp(-2 * x[0]) + * (x[2] ** 2 + jnp.sum(jnp.square(x[3:] - x[2:-1]))) + + x[0] * self.num_data + ) prior_logsigma = 0.5 * jnp.sum(jnp.square(x / 2.0)) return -lik - prior_x - prior_logsigma - def transform(self, x): return jnp.concatenate((jnp.exp(x[:2]), x[2:])) - def sample_init(self, key): key_walk, key_sigma = 
jax.random.split(key) @@ -530,8 +590,10 @@ def sample_init(self, key): # log_sigma = jax.random.normal(key_sigma, shape= (2, )) * 2 # narrower prior - log_sigma = jnp.log(np.array([0.1, 0.15])) + jax.random.normal(key_sigma, shape=( - 2,)) * 0.1 # *0.05# log sigma_i, log sigma_obs + log_sigma = ( + jnp.log(np.array([0.1, 0.15])) + + jax.random.normal(key_sigma, shape=(2,)) * 0.1 + ) # *0.05# log sigma_i, log sigma_obs walk = random_walk(key_walk, self.ndims - 2) * jnp.exp(log_sigma[0]) @@ -540,50 +602,50 @@ def sample_init(self, key): def generate_data(self, key): key_walk, key_sigma, key_noise = jax.random.split(key, 3) - log_sigma = jax.random.normal(key_sigma, shape=(2,)) * 2 # log sigma_i, log sigma_obs + log_sigma = ( + jax.random.normal(key_sigma, shape=(2,)) * 2 + ) # log sigma_i, log sigma_obs walk = random_walk(key_walk, self.ndims - 2) * jnp.exp(log_sigma[0]) - noise = jax.random.normal(key_noise, shape=(self.ndims - 2,)) * jnp.exp(log_sigma[1]) + noise = jax.random.normal(key_noise, shape=(self.ndims - 2,)) * jnp.exp( + log_sigma[1] + ) return walk + noise class GermanCredit: - """ Taken from inference gym. + """Taken from inference gym. - x = (global scale, local scales, weights) + x = (global scale, local scales, weights) - global_scale ~ Gamma(0.5, 0.5) + global_scale ~ Gamma(0.5, 0.5) - for i in range(num_features): - unscaled_weights[i] ~ Normal(loc=0, scale=1) - local_scales[i] ~ Gamma(0.5, 0.5) - weights[i] = unscaled_weights[i] * local_scales[i] * global_scale + for i in range(num_features): + unscaled_weights[i] ~ Normal(loc=0, scale=1) + local_scales[i] ~ Gamma(0.5, 0.5) + weights[i] = unscaled_weights[i] * local_scales[i] * global_scale - for j in range(num_datapoints): - label[j] ~ Bernoulli(features @ weights) + for j in range(num_datapoints): + label[j] ~ Bernoulli(features @ weights) - We use a log transform for the scale parameters. + We use a log transform for the scale parameters. 
""" def __init__(self): - self.ndims = 51 #global scale + 25 local scales + 25 weights - self.name = 'GermanCredit' + self.ndims = 51 # global scale + 25 local scales + 25 weights + self.name = "GermanCredit" - self.labels = jnp.load(dirr + '/data/gc_labels.npy') - self.features = jnp.load(dirr + '/data/gc_features.npy') + self.labels = jnp.load(dirr + "/data/gc_labels.npy") + self.features = jnp.load(dirr + "/data/gc_features.npy") - truth = jnp.load(dirr+'/ground_truth/german_credit/ground_truth.npy') + truth = jnp.load(dirr + "/ground_truth/german_credit/ground_truth.npy") self.E_x2, self.Var_x2 = truth[0], truth[1] - - - def transform(self, x): return jnp.concatenate((jnp.exp(x[:26]), x[26:])) def logdensity_fn(self, x): - scales = jnp.exp(x[:26]) # prior @@ -594,139 +656,166 @@ def logdensity_fn(self, x): # likelihood weights = scales[0] * scales[1:26] * x[26:] - logits = self.features @ weights # = jnp.einsum('nd,...d->...n', self.features, weights) - lik = jnp.sum(self.labels * jnp.logaddexp(0., -logits) + (1-self.labels)* jnp.logaddexp(0., logits)) + logits = ( + self.features @ weights + ) # = jnp.einsum('nd,...d->...n', self.features, weights) + lik = jnp.sum( + self.labels * jnp.logaddexp(0.0, -logits) + + (1 - self.labels) * jnp.logaddexp(0.0, logits) + ) return -(lik + pr + transform) def sample_init(self, key): - weights = jax.random.normal(key, shape = (25, )) + weights = jax.random.normal(key, shape=(25,)) return jnp.concatenate((jnp.zeros(26), weights)) - - class ItemResponseTheory: - """ Taken from inference gym.""" + """Taken from inference gym.""" def __init__(self): self.ndims = 501 - self.name = 'ItemResponseTheory' + self.name = "ItemResponseTheory" self.students = 400 self.questions = 100 - self.mask = jnp.load(dirr + '/data/irt_mask.npy') - self.labels = jnp.load(dirr + '/data/irt_labels.npy') + self.mask = jnp.load(dirr + "/data/irt_mask.npy") + self.labels = jnp.load(dirr + "/data/irt_labels.npy") - truth = jnp.load(dirr+'/ground_truth/item_response_theory/ground_truth.npy') + truth = jnp.load(dirr + "/ground_truth/item_response_theory/ground_truth.npy") self.E_x2, self.Var_x2 = truth[0], truth[1] - self.transform = lambda x: x def logdensity_fn(self, x): - - students = x[:self.students] + students = x[: self.students] mean = x[self.students] - questions = x[self.students + 1:] + questions = x[self.students + 1 :] # prior - pr = 0.5 * (jnp.square(mean - 0.75) + jnp.sum(jnp.square(students)) + jnp.sum(jnp.square(questions))) + pr = 0.5 * ( + jnp.square(mean - 0.75) + + jnp.sum(jnp.square(students)) + + jnp.sum(jnp.square(questions)) + ) # likelihood logits = mean + students[:, jnp.newaxis] - questions[jnp.newaxis, :] - bern = self.labels * jnp.logaddexp(0., -logits) + (1 - self.labels) * jnp.logaddexp(0., logits) + bern = self.labels * jnp.logaddexp(0.0, -logits) + ( + 1 - self.labels + ) * jnp.logaddexp(0.0, logits) bern = jnp.where(self.mask, bern, jnp.zeros_like(bern)) lik = jnp.sum(bern) return -lik - pr - def sample_init(self, key): - x = jax.random.normal(key, shape = (self.ndims,)) + x = jax.random.normal(key, shape=(self.ndims,)) x = x.at[self.students].add(0.75) return x - - -class StochasticVolatility(): +class StochasticVolatility: """Example from https://num.pyro.ai/en/latest/examples/stochastic_volatility.html""" def __init__(self): - self.SP500_returns = jnp.load(dirr + '/data/SP500.npy') + self.SP500_returns = jnp.load(dirr + "/data/SP500.npy") self.ndims = 2429 - self.name = 'StochasticVolatility' + self.name = "StochasticVolatility" - 
self.typical_sigma, self.typical_nu = 0.02, 10.0 # := 1 / lambda + self.typical_sigma, self.typical_nu = 0.02, 10.0 # := 1 / lambda - data = jnp.load(dirr + '/ground_truth/stochastic_volatility/ground_truth_0.npy') + data = jnp.load(dirr + "/ground_truth/stochastic_volatility/ground_truth_0.npy") self.E_x2 = data[0] self.Var_x2 = data[1] - - def logdensity_fn(self, x): """- log p of the target distribution - x= [s1, s2, ... s2427, log sigma / typical_sigma, log nu / typical_nu]""" + x= [s1, s2, ... s2427, log sigma / typical_sigma, log nu / typical_nu]""" - sigma = jnp.exp(x[-2]) * self.typical_sigma #we used this transformation to make x unconstrained + sigma = ( + jnp.exp(x[-2]) * self.typical_sigma + ) # we used this transformation to make x unconstrained nu = jnp.exp(x[-1]) * self.typical_nu - l1= (jnp.exp(x[-2]) - x[-2]) + (jnp.exp(x[-1]) - x[-1]) - l2 = (self.ndims - 2) * jnp.log(sigma) + 0.5 * (jnp.square(x[0]) + jnp.sum(jnp.square(x[1:-2] - x[:-3]))) / jnp.square(sigma) + l1 = (jnp.exp(x[-2]) - x[-2]) + (jnp.exp(x[-1]) - x[-1]) + l2 = (self.ndims - 2) * jnp.log(sigma) + 0.5 * ( + jnp.square(x[0]) + jnp.sum(jnp.square(x[1:-2] - x[:-3])) + ) / jnp.square(sigma) l3 = jnp.sum(nlogp_StudentT(self.SP500_returns, nu, jnp.exp(x[:-2]))) return -(l1 + l2 + l3) - def transform(self, x): """transforms to the variables which are used by numpyro (and in which we have the ground truth moments)""" z = jnp.empty(x.shape) - z = z.at[:-2].set(x[:-2]) # = s = log R - z = z.at[-2].set(jnp.exp(x[-2]) * self.typical_sigma) # = sigma - z = z.at[-1].set(jnp.exp(x[-1]) * self.typical_nu) # = nu + z = z.at[:-2].set(x[:-2]) # = s = log R + z = z.at[-2].set(jnp.exp(x[-2]) * self.typical_sigma) # = sigma + z = z.at[-1].set(jnp.exp(x[-1]) * self.typical_nu) # = nu return z - def sample_init(self, key): """draws x from the prior""" key_walk, key_exp = jax.random.split(key) scales = jnp.array([self.typical_sigma, self.typical_nu]) - #params = jax.random.exponential(key_exp, shape = (2, )) * scales - params= scales + # params = jax.random.exponential(key_exp, shape = (2, )) * scales + params = scales walk = random_walk(key_walk, self.ndims - 2) * params[0] - return jnp.concatenate((walk, jnp.log(params/scales))) - + return jnp.concatenate((walk, jnp.log(params / scales))) -class MixedLogit(): +class MixedLogit: def __init__(self): - key = jax.random.PRNGKey(0) key_poisson, key_x, key_beta, key_logit = jax.random.split(key, 4) self.ndims = 2014 self.name = "Mixed Logit" self.nind = 500 - self.nsessions = jax.random.poisson(key_poisson, lam=1.0, shape=(self.nind,)) + 10 + self.nsessions = ( + jax.random.poisson(key_poisson, lam=1.0, shape=(self.nind,)) + 10 + ) self.nbeta = 4 nobs = jnp.sum(self.nsessions) mu_true = jnp.array([-1.5, -0.3, 0.8, 1.2]) - sigma_true = jnp.array([[0.5, 0.1, 0.1, 0.1], [0.1, 0.5, 0.1, 0.1], [0.1, 0.1, 0.5, 0.1], [0.1, 0.1, 0.1, 0.5]]) - beta_true = jax.random.multivariate_normal(key_beta, mu_true, sigma_true, shape=(self.nind,)) + sigma_true = jnp.array( + [ + [0.5, 0.1, 0.1, 0.1], + [0.1, 0.5, 0.1, 0.1], + [0.1, 0.1, 0.5, 0.1], + [0.1, 0.1, 0.1, 0.5], + ] + ) + beta_true = jax.random.multivariate_normal( + key_beta, mu_true, sigma_true, shape=(self.nind,) + ) beta_true_repeat = jnp.repeat(beta_true, self.nsessions, axis=0) self.x = jax.random.normal(key_x, (nobs, self.nbeta)) - self.y = 1 * jax.random.bernoulli(key_logit, (jax.nn.sigmoid(jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_true_repeat)))) - - self.d = self.nbeta + self.nbeta + (self.nbeta * (self.nbeta-1) // 2) + 
self.nbeta * self.nind # mu, tau, omega_chol, and (beta for each i) + self.y = 1 * jax.random.bernoulli( + key_logit, + ( + jax.nn.sigmoid( + jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))( + self.x, beta_true_repeat + ) + ) + ), + ) + + self.d = ( + self.nbeta + + self.nbeta + + (self.nbeta * (self.nbeta - 1) // 2) + + self.nbeta * self.nind + ) # mu, tau, omega_chol, and (beta for each i) self.prior_mean_mu = jnp.zeros(self.nbeta) self.prior_var_mu = 10.0 * jnp.eye(self.nbeta) self.prior_scale_tau = 5.0 @@ -734,20 +823,20 @@ def __init__(self): self.grad_logp = jax.value_and_grad(self.logdensity_fn) - def corrchol_to_reals(self,x): - '''Converts a Cholesky-correlation (lower-triangular) matrix to a vector of unconstrained reals''' + def corrchol_to_reals(self, x): + """Converts a Cholesky-correlation (lower-triangular) matrix to a vector of unconstrained reals""" dim = x.shape[0] z = jnp.zeros((dim, dim)) for i in range(dim): for j in range(i): - z = z.at[i, j].set(x[i,j] / jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) + z = z.at[i, j].set(x[i, j] / jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) z_lower_triang = z[jnp.tril_indices(dim, -1)] y = 0.5 * (jnp.log(1.0 + z_lower_triang) - jnp.log(1.0 - z_lower_triang)) return y - def reals_to_corrchol(self,y): - '''Converts a vector of unconstrained reals to a Cholesky-correlation (lower-triangular) matrix''' + def reals_to_corrchol(self, y): + """Converts a vector of unconstrained reals to a Cholesky-correlation (lower-triangular) matrix""" len_vec = len(y) dim = int(0.5 * (1 + 8 * len_vec) ** 0.5 + 0.5) assert dim * (dim - 1) // 2 == len_vec @@ -757,20 +846,21 @@ def reals_to_corrchol(self,y): x = jnp.zeros((dim, dim)) for i in range(dim): - for j in range(i+1): + for j in range(i + 1): if i == j: x = x.at[i, j].set(jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) else: - x = x.at[i, j].set(z[i,j] * jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) + x = x.at[i, j].set( + z[i, j] * jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0)) + ) return x - def logdensity_fn(self, pars): """log p of the target distribution, i.e., log posterior distribution up to a constant""" - mu = pars[:self.nbeta] + mu = pars[: self.nbeta] dim1 = self.nbeta + self.nbeta - log_tau = pars[self.nbeta:dim1] + log_tau = pars[self.nbeta : dim1] dim2 = self.nbeta + self.nbeta + self.nbeta * (self.nbeta - 1) // 2 omega_chol_realvec = pars[dim1:dim2] beta = pars[dim2:].reshape(self.nind, self.nbeta) @@ -783,26 +873,55 @@ def logdensity_fn(self, pars): beta_repeat = jnp.repeat(beta, self.nsessions, axis=0) - log_lik = jnp.sum(self.y * jax.nn.log_sigmoid(jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat)) + (1 - self.y) * jax.nn.log_sigmoid(-jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat))) - - log_density_beta_popdist = -0.5 * self.nind * jnp.log(jnp.linalg.det(sigma)) - 0.5 * jnp.sum(jax.vmap(lambda vec, mat: jnp.dot(vec, jnp.linalg.solve(mat, vec)), in_axes=(0, None))(beta - mu, sigma)) + log_lik = jnp.sum( + self.y + * jax.nn.log_sigmoid( + jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat) + ) + + (1 - self.y) + * jax.nn.log_sigmoid( + -jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat) + ) + ) + + log_density_beta_popdist = -0.5 * self.nind * jnp.log( + jnp.linalg.det(sigma) + ) - 0.5 * jnp.sum( + jax.vmap( + lambda vec, mat: jnp.dot(vec, jnp.linalg.solve(mat, vec)), + in_axes=(0, None), + )(beta - mu, sigma) + ) muMinusPriorMean = mu - self.prior_mean_mu - log_prior_mu = -0.5 * jnp.log(jnp.linalg.det(self.prior_var_mu)) - 
0.5 * jnp.dot(muMinusPriorMean, jnp.linalg.solve(self.prior_var_mu, muMinusPriorMean)) - - log_prior_tau = jnp.sum(dist.HalfCauchy(scale=self.prior_scale_tau).log_prob(tau)) - #log_prior_tau = jnp.sum(jax.vmap(lambda arg: -jnp.log(1.0 + (arg / self.prior_scale_tau) ** 2.0))(tau)) - log_prior_omega_chol = dist.LKJCholesky(self.nbeta, concentration=self.prior_concentration_omega).log_prob(omega_chol) - #log_prior_omega_chol = jnp.dot(nbeta - jnp.arange(2, nbeta+1) + 2.0 * self.prior_concentration_omega - 2.0, jnp.log(jnp.diag(omega_chol)[1:])) - - return log_lik + log_density_beta_popdist + log_prior_mu + log_prior_tau + log_prior_omega_chol - + log_prior_mu = -0.5 * jnp.log( + jnp.linalg.det(self.prior_var_mu) + ) - 0.5 * jnp.dot( + muMinusPriorMean, jnp.linalg.solve(self.prior_var_mu, muMinusPriorMean) + ) + + log_prior_tau = jnp.sum( + dist.HalfCauchy(scale=self.prior_scale_tau).log_prob(tau) + ) + # log_prior_tau = jnp.sum(jax.vmap(lambda arg: -jnp.log(1.0 + (arg / self.prior_scale_tau) ** 2.0))(tau)) + log_prior_omega_chol = dist.LKJCholesky( + self.nbeta, concentration=self.prior_concentration_omega + ).log_prob(omega_chol) + # log_prior_omega_chol = jnp.dot(nbeta - jnp.arange(2, nbeta+1) + 2.0 * self.prior_concentration_omega - 2.0, jnp.log(jnp.diag(omega_chol)[1:])) + + return ( + log_lik + + log_density_beta_popdist + + log_prior_mu + + log_prior_tau + + log_prior_omega_chol + ) def transform(self, pars): """transform pars to the original (possibly constrained) pars""" - mu = pars[:self.nbeta] + mu = pars[: self.nbeta] dim1 = self.nbeta + self.nbeta - log_tau = pars[self.nbeta:dim1] + log_tau = pars[self.nbeta : dim1] dim2 = self.nbeta + self.nbeta + self.nbeta * (self.nbeta - 1) // 2 omega_chol_realvec = pars[dim1:dim2] beta_flattened = pars[dim2:] @@ -819,8 +938,12 @@ def sample_init(self, key): """draws pars from the prior""" key_mu, key_omega_chol, key_tau, key_beta = jax.random.split(key, 4) - mu = jax.random.multivariate_normal(key_mu, self.prior_mean_mu, self.prior_var_mu) - omega_chol = dist.LKJCholesky(self.nbeta, concentration=self.prior_concentration_omega).sample(key_omega_chol) + mu = jax.random.multivariate_normal( + key_mu, self.prior_mean_mu, self.prior_var_mu + ) + omega_chol = dist.LKJCholesky( + self.nbeta, concentration=self.prior_concentration_omega + ).sample(key_omega_chol) tau = dist.HalfCauchy(scale=self.prior_scale_tau).sample(key_tau, (self.nbeta,)) omega_chol_realvec = self.corrchol_to_reals(omega_chol) @@ -836,7 +959,6 @@ def sample_init(self, key): return pars - def nlogp_StudentT(x, df, scale): y = x / scale z = ( @@ -849,17 +971,16 @@ def nlogp_StudentT(x, df, scale): return 0.5 * (df + 1.0) * jnp.log1p(y**2.0 / df) + z - def random_walk(key, num): - """ Genereting process for the standard normal walk: - x[0] ~ N(0, 1) - x[n+1] ~ N(x[n], 1) - - Args: - key: jax random key - num: number of points in the walk - Returns: - 1 realization of the random walk (array of length num) + """Genereting process for the standard normal walk: + x[0] ~ N(0, 1) + x[n+1] ~ N(x[n], 1) + + Args: + key: jax random key + num: number of points in the walk + Returns: + 1 realization of the random walk (array of length num) """ def step(track, useless): @@ -871,22 +992,18 @@ def step(track, useless): return jax.lax.scan(step, init=(0.0, key), xs=None, length=num)[1] - models = { - - # Cauchy(100) : {'mclmc': 2000, 'adjusted_mclmc' : 2000, 'nuts': 2000}, - # StandardNormal(100) : {'mclmc': 10000, 'adjusted_mclmc' : 10000, 'nuts': 10000}, - # Banana() : {'mclmc': 10000, 
'adjusted_mclmc' : 10000, 'nuts': 10000}, - Brownian() : {'mclmc': 20000, 'adjusted_mclmc' : 80000, 'nuts': 40000}, - - - # 'banana': Banana(), + # Cauchy(100) : {'mclmc': 2000, 'adjusted_mclmc' : 2000, 'nuts': 2000}, + # StandardNormal(100) : {'mclmc': 10000, 'adjusted_mclmc' : 10000, 'nuts': 10000}, + # Banana() : {'mclmc': 10000, 'adjusted_mclmc' : 10000, 'nuts': 10000}, + Brownian(): {"mclmc": 20000, "adjusted_mclmc": 80000, "nuts": 40000}, + # 'banana': Banana(), # 'icg' : (IllConditionedGaussian(10, 2), {'mclmc': 2000, 'adjusted_mclmc' : 2000, 'nuts': 2000}), # GermanCredit(): {'mclmc': 20000, 'adjusted_mclmc' : 40000, 'nuts': 20000}, # ItemResponseTheory(): {'mclmc': 20000, 'adjusted_mclmc' : 40000, 'nuts': 20000}, # StochasticVolatility(): {'mclmc': 20000, 'adjusted_mclmc' : 40000, 'nuts': 20000} - } +} # models = {'Brownian Motion': (Brownian(), {'mclmc': 50000, 'adjusted_mclmc' : 40000, 'nuts': 1000}), # # 'Item Response Theory': (ItemResponseTheory(), {'mclmc': 50000, 'adjusted_mclmc' : 50000, 'nuts': 1000}) -# } \ No newline at end of file +# } diff --git a/blackjax/benchmarks/mcmc/sampling_algorithms.py b/blackjax/benchmarks/mcmc/sampling_algorithms.py index bbd7e3e59..b559d3c11 100644 --- a/blackjax/benchmarks/mcmc/sampling_algorithms.py +++ b/blackjax/benchmarks/mcmc/sampling_algorithms.py @@ -4,38 +4,50 @@ import jax import jax.numpy as jnp + import blackjax + # from blackjax.adaptation.window_adaptation import da_adaptation from blackjax.mcmc.adjusted_mclmc import rescale -from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator, integrator_order +from blackjax.mcmc.integrators import ( + calls_per_integrator_step, + generate_euclidean_integrator, + generate_isokinetic_integrator, + integrator_order, +) + # from blackjax.mcmc.adjusted_mclmc import rescale from blackjax.util import run_inference_algorithm -import blackjax __all__ = ["samplers"] -target_acceptance_rate_of_order = {2 : 0.65, 4: 0.8} +target_acceptance_rate_of_order = {2: 0.65, 4: 0.8} + -def run_nuts( - coefficients, logdensity_fn, num_steps, initial_position, transform, key): - +def run_nuts(coefficients, logdensity_fn, num_steps, initial_position, transform, key): integrator = generate_euclidean_integrator(coefficients) # integrator = blackjax.mcmc.integrators.velocity_verlet # note: defaulted to in nuts rng_key, warmup_key = jax.random.split(key, 2) state, params = da_adaptation( - rng_key=warmup_key, - initial_position=initial_position, + rng_key=warmup_key, + initial_position=initial_position, algorithm=blackjax.nuts, - logdensity_fn=logdensity_fn) - + logdensity_fn=logdensity_fn, + ) + # print(params["inverse_mass_matrix"], "inv\n\n") # warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn, integrator=integrator) # (state, params), _ = warmup.run(warmup_key, initial_position, 2000) - nuts = blackjax.nuts(logdensity_fn=logdensity_fn, step_size=params['step_size'], inverse_mass_matrix= params['inverse_mass_matrix'], integrator=integrator) + nuts = blackjax.nuts( + logdensity_fn=logdensity_fn, + step_size=params["step_size"], + inverse_mass_matrix=params["inverse_mass_matrix"], + integrator=integrator, + ) final_state, state_history, info_history = run_inference_algorithm( rng_key=rng_key, @@ -43,26 +55,32 @@ def run_nuts( inference_algorithm=nuts, num_steps=num_steps, transform=lambda x: transform(x.position), - progress_bar=True + progress_bar=True, ) # print("INFO\n\n",info_history.num_integration_steps) - return 
state_history, params, info_history.num_integration_steps.mean() * calls_per_integrator_step(coefficients), info_history.acceptance_rate.mean(), None, None + return ( + state_history, + params, + info_history.num_integration_steps.mean() + * calls_per_integrator_step(coefficients), + info_history.acceptance_rate.mean(), + None, + None, + ) -def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key): +def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key): integrator = generate_isokinetic_integrator(coefficients) init_key, tune_key, run_key = jax.random.split(key, 3) - initial_state = blackjax.mcmc.mclmc.init( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - - kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, std_mat=std_mat, @@ -87,8 +105,7 @@ def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transfor L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=blackjax_mclmc_sampler_params.std_mat, - integrator = integrator, - + integrator=integrator, # std_mat=jnp.ones((initial_position.shape[0],)), ) @@ -101,38 +118,60 @@ def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transfor progress_bar=True, ) - acceptance_rate = 1. - return samples, blackjax_mclmc_sampler_params, calls_per_integrator_step(coefficients), acceptance_rate, None, None + acceptance_rate = 1.0 + return ( + samples, + blackjax_mclmc_sampler_params, + calls_per_integrator_step(coefficients), + acceptance_rate, + None, + None, + ) -def run_adjusted_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, target_acc_rate=None): +def run_adjusted_mclmc( + coefficients, + logdensity_fn, + num_steps, + initial_position, + transform, + key, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.0, + target_acc_rate=None, +): integrator = generate_isokinetic_integrator(coefficients) init_key, tune_key, run_key = jax.random.split(key, 3) initial_state = blackjax.mcmc.adjusted_mclmc.init( - position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key + position=initial_position, + logdensity_fn=logdensity_fn, + random_generator_arg=init_key, ) kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( - integrator=integrator, - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=logdensity_fn) - + integrator=integrator, + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + std_mat=std_mat, + )( + rng_key=rng_key, state=state, step_size=step_size, logdensity_fn=logdensity_fn + ) + if target_acc_rate is None: - target_acc_rate = target_acceptance_rate_of_order[integrator_order(coefficients)] + target_acc_rate = target_acceptance_rate_of_order[ + integrator_order(coefficients) + ] print("target acc rate") ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, params_history, - final_da + final_da, ) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -145,43 +184,46 @@ def run_adjusted_mclmc(coefficients, 
logdensity_fn, num_steps, initial_position, diagonal_preconditioning=False, ) - - step_size = blackjax_mclmc_sampler_params.step_size L = blackjax_mclmc_sampler_params.L # jax.debug.print("params {x}", x=(blackjax_mclmc_sampler_params.step_size, blackjax_mclmc_sampler_params.L)) - alg = blackjax.adjusted_mclmc( logdensity_fn=logdensity_fn, step_size=step_size, - integration_steps_fn = lambda key: jnp.ceil(jax.random.uniform(key) * rescale(L/step_size)) , + integration_steps_fn=lambda key: jnp.ceil( + jax.random.uniform(key) * rescale(L / step_size) + ), integrator=integrator, std_mat=blackjax_mclmc_sampler_params.std_mat, - - ) - _, out, info = run_inference_algorithm( rng_key=run_key, initial_state=blackjax_state_after_tuning, inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True) - + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True, + ) + return ( + out, + blackjax_mclmc_sampler_params, + calls_per_integrator_step(coefficients) * (L / step_size), + info.acceptance_rate, + params_history, + final_da, + ) - return out, blackjax_mclmc_sampler_params, calls_per_integrator_step(coefficients) * (L/step_size), info.acceptance_rate, params_history, final_da # we should do at least: mclmc, nuts, unadjusted hmc, adjusted_mclmc, langevin samplers = { - 'nuts' : run_nuts, - 'mclmc' : run_mclmc, - 'adjusted_mclmc': run_adjusted_mclmc, - } + "nuts": run_nuts, + "mclmc": run_mclmc, + "adjusted_mclmc": run_adjusted_mclmc, +} # foo = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(20.56)) diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index aef0b3f57..188b4ef71 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -17,35 +17,29 @@ import jax import jax.numpy as jnp -from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence -from blackjax.types import ArrayLike import blackjax.mcmc.integrators as integrators from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence from blackjax.mcmc.hmc import HMCInfo from blackjax.mcmc.proposal import static_binomial_sampling - from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_unit_vector -__all__ = [ - "init", - "build_kernel", - "adjusted_mclmc", -] +__all__ = ["init", "build_kernel", "as_top_level_api"] -def init( - position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array -): + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + # TODO: no default for std_mat def build_kernel( integration_steps_fn, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - std_mat=1., + std_mat=1.0, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. 
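# Editor's sketch (not part of the patch): how the pieces above are meant to be
# wired together, mirroring the run_adjusted_mclmc helper earlier in this series.
# The toy logdensity, the dimension (10), rescale(10.0) and step_size=0.1 are
# placeholder choices; init, build_kernel, rescale and the keyword names are
# taken from this patch as it stands here.
import jax
import jax.numpy as jnp
import blackjax
from blackjax.mcmc.adjusted_mclmc import rescale

logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
init_key, run_key = jax.random.split(jax.random.PRNGKey(0))

# build the initial DynamicHMCState
state = blackjax.mcmc.adjusted_mclmc.init(
    position=jnp.ones(10), logdensity_fn=logdensity_fn, random_generator_arg=init_key
)
# kernel drawing a random number of integration steps (mean ~10 per proposal)
kernel = blackjax.mcmc.adjusted_mclmc.build_kernel(
    integration_steps_fn=lambda k: jnp.ceil(jax.random.uniform(k) * rescale(10.0)),
    std_mat=1.0,
)
new_state, info = kernel(
    rng_key=run_key, state=state, step_size=0.1, logdensity_fn=logdensity_fn
)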
@@ -73,29 +67,29 @@ def kernel( state: DynamicHMCState, logdensity_fn: Callable, step_size: float, - L_proposal : float = 1.0, + L_proposal: float = 1.0, ) -> tuple[DynamicHMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" - - num_integration_steps = integration_steps_fn( - state.random_generator_arg - ) + + num_integration_steps = integration_steps_fn(state.random_generator_arg) key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( # integrators.with_isokinetic_maruyama(integrator(logdensity_fn)), - lambda state, step_size, L_prop, key : (integrator(logdensity_fn, std_mat))(state, step_size), + lambda state, step_size, L_prop, key: (integrator(logdensity_fn, std_mat))( + state, step_size + ), step_size, L_proposal, num_integration_steps, divergence_threshold, )( - key_integrator, + key_integrator, integrators.IntegratorState( - state.position, momentum, state.logdensity, state.logdensity_grad - ) + state.position, momentum, state.logdensity, state.logdensity_grad + ), ) return ( @@ -110,16 +104,17 @@ def kernel( return kernel + def as_top_level_api( - logdensity_fn: Callable, - step_size: float, - L_proposal : float = 0.6, - std_mat=1.0, - *, - divergence_threshold: int = 1000, - integrator: Callable = integrators.isokinetic_mclachlan, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), + logdensity_fn: Callable, + step_size: float, + L_proposal: float = 0.6, + std_mat=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), ) -> SamplingAlgorithm: """Implements the (basic) user interface for the dynamic MHMCHMC kernel. @@ -147,9 +142,13 @@ def as_top_level_api( A ``SamplingAlgorithm``. 
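    Examples
    --------
    Editor's illustration (not part of the patch), modelled on the
    run_adjusted_mclmc helper earlier in this series; the target, step size
    and trajectory length below are arbitrary placeholder values:

    .. code::

        import jax, jax.numpy as jnp
        import blackjax
        from blackjax.mcmc.adjusted_mclmc import rescale
        from blackjax.util import run_inference_algorithm

        logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
        init_key, run_key = jax.random.split(jax.random.PRNGKey(0))
        step_size, L = 0.5, 2.5

        alg = blackjax.adjusted_mclmc(
            logdensity_fn=logdensity_fn,
            step_size=step_size,
            integration_steps_fn=lambda key: jnp.ceil(
                jax.random.uniform(key) * rescale(L / step_size)
            ),
        )
        _, samples, info = run_inference_algorithm(
            rng_key=run_key,
            initial_state=alg.init(jnp.ones(5), init_key),
            inference_algorithm=alg,
            num_steps=1_000,
            transform=lambda state: state.position,
        )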
""" - kernel = build_kernel(integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, std_mat=std_mat, divergence_threshold=divergence_threshold) - - + kernel = build_kernel( + integration_steps_fn=integration_steps_fn, + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + std_mat=std_mat, + divergence_threshold=divergence_threshold, + ) def init_fn(position: ArrayLikeTree, rng_key: Array): return init(position, logdensity_fn, rng_key) @@ -162,11 +161,7 @@ def update_fn(rng_key: PRNGKey, state): step_size, L_proposal, ) - - def init_fn(position: ArrayLike, rng_key: PRNGKey): - return init(position, logdensity_fn, rng_key) - return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] @@ -209,12 +204,16 @@ def adjusted_mclmc_proposal( def step(i, vars): state, kinetic_energy, rng_key = vars rng_key, next_rng_key = jax.random.split(rng_key) - next_state, next_kinetic_energy = integrator(state, step_size, L_proposal, rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal, rng_key + ) return next_state, kinetic_energy + next_kinetic_energy, next_rng_key def build_trajectory(state, num_integration_steps, rng_key): - return jax.lax.fori_loop(0*num_integration_steps, num_integration_steps, step, (state, 0, rng_key)) + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) def generate( rng_key, state: integrators.IntegratorState @@ -225,7 +224,7 @@ def generate( ) # note that this is the POTENTIAL energy only - new_energy = -end_state.logdensity + new_energy = -end_state.logdensity delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) is_diverging = -delta_energy > divergence_threshold @@ -246,6 +245,7 @@ def generate( return generate + def rescale(mu): """returns s, such that round(U(0, 1) * s + 0.5) @@ -255,6 +255,7 @@ def rescale(mu): x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) return k + x + def trajectory_length(t, mu): s = rescale(mu) - return jnp.rint(0.5 + halton_sequence(t) * s) \ No newline at end of file + return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 6f402dd67..aefc3d4b9 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -39,7 +39,7 @@ "implicit_midpoint", "calls_per_integrator_step", "name_integrator", - "integrator_order" + "integrator_order", ] @@ -481,14 +481,20 @@ def name_integrator(c): else: raise Exception("No such integrator exists in blackjax") + def integrator_order(c): - if c==velocity_verlet_coefficients: return 2 - if c==mclachlan_coefficients: return 2 - if c==yoshida_coefficients: return 4 - if c==omelyan_coefficients: return 4 - + if c == velocity_verlet_coefficients: + return 2 + if c == mclachlan_coefficients: + return 2 + if c == yoshida_coefficients: + return 4 + if c == omelyan_coefficients: + return 4 + + else: + raise Exception("No such integrator exists in blackjax") - else: raise Exception("No such integrator exists in blackjax") FixedPointSolver = Callable[ [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree], From 1c17ecf3f02870673881de26f48a84059bc3ba9c Mon Sep 17 00:00:00 2001 From: = Date: Fri, 17 May 2024 00:55:59 +0200 Subject: [PATCH 29/71] UNIFY ADJUSTED MCLMC AND MCHMC --- blackjax/mcmc/adjusted_mclmc.py | 18 +++++++----------- blackjax/mcmc/integrators.py | 10 +++++++++- 2 files changed, 
16 insertions(+), 12 deletions(-) diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 188b4ef71..282f67658 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -67,7 +67,7 @@ def kernel( state: DynamicHMCState, logdensity_fn: Callable, step_size: float, - L_proposal: float = 1.0, + L_proposal: float = jnp.inf, ) -> tuple[DynamicHMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" @@ -75,16 +75,12 @@ def kernel( key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) - proposal, info, _ = adjusted_mclmc_proposal( - # integrators.with_isokinetic_maruyama(integrator(logdensity_fn)), - lambda state, step_size, L_prop, key: (integrator(logdensity_fn, std_mat))( - state, step_size - ), - step_size, - L_proposal, - num_integration_steps, - divergence_threshold, + integrator=integrators.with_isokinetic_maruyama(integrator(logdensity_fn, std_mat)), + step_size=step_size, + L_proposal=L_proposal*num_integration_steps, + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, )( key_integrator, integrators.IntegratorState( @@ -108,7 +104,7 @@ def kernel( def as_top_level_api( logdensity_fn: Callable, step_size: float, - L_proposal: float = 0.6, + L_proposal: float = jnp.inf, std_mat=1.0, *, divergence_threshold: int = 1000, diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index aefc3d4b9..dcd57e2e2 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -419,11 +419,19 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): ------- momentum with random change in angle """ + m, unravel_fn = ravel_pytree(momentum) dim = m.shape[0] nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) - return unravel_fn((m + z) / jnp.linalg.norm(m + z)) + new_momentum = unravel_fn((m + z) / jnp.linalg.norm(m + z)) + # return new_momentum + return jax.lax.cond( + jnp.isinf(L), + lambda _: momentum, + lambda _: new_momentum, + operand=None, + ) def with_isokinetic_maruyama(integrator): From 6bd5ab1f134d5864cd66502fd80e02bb76110325 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 17 May 2024 19:03:58 +0200 Subject: [PATCH 30/71] ADD INITIAL_POSITION --- blackjax/util.py | 16 +++++++++++- tests/adaptation/test_adaptation.py | 7 ++++- tests/mcmc/test_sampling.py | 40 ++++++++++++++++++++++------- tests/test_benchmarks.py | 5 +++- tests/test_util.py | 10 ++++---- 5 files changed, 61 insertions(+), 17 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index e579c126d..746613808 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -143,9 +143,10 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree: def run_inference_algorithm( rng_key: PRNGKey, - initial_state: ArrayLikeTree, inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm], num_steps: int, + initial_state: ArrayLikeTree = None, + initial_position: ArrayLikeTree = None, progress_bar: bool = False, transform: Callable = lambda x: x, return_state_history=True, @@ -163,6 +164,8 @@ def run_inference_algorithm( The random state used by JAX's random numbers generator. initial_state The initial state of the inference algorithm. + initial_position + The initial position of the inference algorithm. This is used when the initial state is not provided. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. 
num_steps @@ -189,6 +192,17 @@ def run_inference_algorithm( 2. The final state of the inference algorithm. """ + if initial_state is None and initial_position is None: + raise ValueError("Either initial_state or initial_position must be provided.") + if initial_state is not None and initial_position is not None: + raise ValueError( + "Only one of initial_state or initial_position must be provided." + ) + + rng_key, init_key = split(rng_key, 2) + if initial_position is not None: + initial_state = inference_algorithm.init(initial_state, init_key) + keys = split(rng_key, num_steps) def one_step(average_and_state, xs, return_state): diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py index f54d18c21..4450e61f9 100644 --- a/tests/adaptation/test_adaptation.py +++ b/tests/adaptation/test_adaptation.py @@ -62,7 +62,12 @@ def test_chees_adaptation(): chain_keys = jax.random.split(inference_key, num_chains) _, _, infos = jax.vmap( - lambda key, state: run_inference_algorithm(key, state, algorithm, num_results) + lambda key, state: run_inference_algorithm( + rng_key=key, + initial_state=state, + inference_algorithm=algorithm, + num_steps=num_results, + ) )(chain_keys, last_states) harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 19f72a7c2..f334206e6 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -141,7 +141,10 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal): inference_algorithm = case["algorithm"](logposterior_fn, **parameters) _, states, _ = run_inference_algorithm( - inference_key, state, inference_algorithm, case["num_sampling_steps"] + rng_key=inference_key, + initial_state=state, + inference_algorithm=inference_algorithm, + num_steps=case["num_sampling_steps"], ) coefs_samples = states.position["coefs"] @@ -163,7 +166,12 @@ def test_mala(self): mala = blackjax.mala(logposterior_fn, 1e-5) state = mala.init({"coefs": 1.0, "log_scale": 1.0}) - _, states, _ = run_inference_algorithm(inference_key, state, mala, 10_000) + _, states, _ = run_inference_algorithm( + rng_key=inference_key, + initial_state=state, + inference_algorithm=mala, + num_steps=10_000, + ) coefs_samples = states.position["coefs"][3000:] scale_samples = np.exp(states.position["log_scale"][3000:]) @@ -229,7 +237,10 @@ def test_pathfinder_adaptation( inference_algorithm = algorithm(logposterior_fn, **parameters) _, states, _ = run_inference_algorithm( - inference_key, state, inference_algorithm, num_sampling_steps + rng_key=inference_key, + initial_state=state, + inference_algorithm=inference_algorithm, + num_steps=num_sampling_steps, ) coefs_samples = states.position["coefs"] @@ -270,7 +281,10 @@ def test_meads(self): chain_keys = jax.random.split(inference_key, num_chains) _, states, _ = jax.vmap( lambda key, state: run_inference_algorithm( - key, state, inference_algorithm, 100 + rng_key=key, + initial_state=state, + inference_algorithm=inference_algorithm, + num_steps=100, ) )(chain_keys, last_states) @@ -314,7 +328,10 @@ def test_chees(self, jitter_generator): chain_keys = jax.random.split(inference_key, num_chains) _, states, _ = jax.vmap( lambda key, state: run_inference_algorithm( - key, state, inference_algorithm, 100 + rng_key=key, + initial_state=state, + inference_algorithm=inference_algorithm, + num_steps=100, ) )(chain_keys, last_states) @@ -338,7 +355,12 @@ def test_barker(self): barker = blackjax.barker_proposal(logposterior_fn, 1e-1) state = 
barker.init({"coefs": 1.0, "log_scale": 1.0}) - _, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000) + _, states, _ = run_inference_algorithm( + rng_key=inference_key, + initial_state=state, + inference_algorithm=barker, + num_steps=10_000, + ) coefs_samples = states.position["coefs"][3000:] scale_samples = np.exp(states.position["log_scale"][3000:]) @@ -524,7 +546,7 @@ def test_latent_gaussian(self): inference_algorithm=inference_algorithm, num_steps=self.sampling_steps, ), - )(self.key, initial_state) + )(rng_key=self.key, initial_state=initial_state) np.testing.assert_allclose( np.var(states.position[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2 @@ -568,7 +590,7 @@ def univariate_normal_test_case( inference_algorithm=inference_algorithm, num_steps=num_sampling_steps, ) - )(inference_key, initial_state) + )(rng_key=inference_key, initial_state=initial_state) # else: if postprocess_samples: @@ -839,7 +861,7 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): ) ) _, states, _ = inference_loop_multiple_chains( - multi_chain_sample_key, initial_states + rng_key=multi_chain_sample_key, initial_state=initial_states ) posterior_samples = states.position[:, -1000:] diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index d8f09cea0..c2295e7e2 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -49,7 +49,10 @@ def run_regression(algorithm, **parameters): inference_algorithm = algorithm(logdensity_fn, **parameters) _, states, _ = run_inference_algorithm( - inference_key, state, inference_algorithm, 10_000 + rng_key=inference_key, + initial_state=state, + inference_algorithm=inference_algorithm, + num_steps=10_000, ) return states diff --git a/tests/test_util.py b/tests/test_util.py index 3bafca894..df649efa4 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -25,11 +25,11 @@ def check_compatible(self, initial_state, progress_bar): `initial_state` and potentially a progress bar. 
""" _ = run_inference_algorithm( - self.key, - initial_state, - self.algorithm, - self.num_steps, - progress_bar, + rng_key=self.key, + initial_state=initial_state, + inference_algorithm=self.algorithm, + num_steps=self.num_steps, + progress_bar=progress_bar, transform=lambda x: x.position, ) From 561526135f78a66272b57a88b50157b9367d3c22 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 17 May 2024 19:32:31 +0200 Subject: [PATCH 31/71] FIX TEST --- blackjax/util.py | 2 +- tests/test_util.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 746613808..070ca8687 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -201,7 +201,7 @@ def run_inference_algorithm( rng_key, init_key = split(rng_key, 2) if initial_position is not None: - initial_state = inference_algorithm.init(initial_state, init_key) + initial_state = inference_algorithm.init(initial_position, init_key) keys = split(rng_key, num_steps) diff --git a/tests/test_util.py b/tests/test_util.py index df649efa4..83955acd7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -75,7 +75,14 @@ def logdensity_fn(x): @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): - self.check_compatible(jnp.array([1.0, 1.0]), progress_bar) + _ = run_inference_algorithm( + rng_key=self.key, + initial_position=jnp.array([1.0, 1.0]), + inference_algorithm=self.algorithm, + num_steps=self.num_steps, + progress_bar=progress_bar, + transform=lambda x: x.position, + ) @parameterized.parameters([True, False]) def test_compatible_with_initial_state(self, progress_bar): From 356cd3be9327e6ac0b5c531200a23c06b2269923 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 May 2024 19:37:02 +0200 Subject: [PATCH 32/71] CLEAN UP --- .gitignore | 1 + blackjax/benchmarks/mcmc/benchmark.py | 815 +++++++++++++----- blackjax/benchmarks/mcmc/inference_models.py | 753 +++++++++------- .../benchmarks/mcmc/sampling_algorithms.py | 140 +-- blackjax/mcmc/integrators.py | 20 +- 5 files changed, 1131 insertions(+), 598 deletions(-) diff --git a/.gitignore b/.gitignore index d9186a6e9..d313a6b81 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ # Edit at https://www.gitignore.io/?templates=python explore.py +blackjax/benchmarks/ ### Python ### # Byte-compiled / optimized / DLL files diff --git a/blackjax/benchmarks/mcmc/benchmark.py b/blackjax/benchmarks/mcmc/benchmark.py index 174cd30f7..0ab3b4c5e 100644 --- a/blackjax/benchmarks/mcmc/benchmark.py +++ b/blackjax/benchmarks/mcmc/benchmark.py @@ -1,13 +1,14 @@ # mypy: ignore-errors # flake8: noqa -from collections import defaultdict -from functools import partial import math import operator import os import pprint +from collections import defaultdict +from functools import partial from statistics import mean, median + import jax import jax.numpy as jnp import pandas as pd @@ -15,7 +16,7 @@ from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState -os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=' + str(128) +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=" + str(128) num_cores = jax.local_device_count() # print(num_cores, jax.lib.xla_bridge.get_backend().platform) @@ -24,48 +25,81 @@ import numpy as np import blackjax -from blackjax.benchmarks.mcmc.sampling_algorithms import run_mclmc, run_mhmclmc, run_nuts, samplers -from blackjax.benchmarks.mcmc.inference_models import Brownian, GermanCredit, ItemResponseTheory, MixedLogit, StandardNormal, StochasticVolatility, models 
-from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator, integrator_order, isokinetic_mclachlan, mclachlan_coefficients, name_integrator, omelyan_coefficients, velocity_verlet, velocity_verlet_coefficients, yoshida_coefficients +from blackjax.benchmarks.mcmc.inference_models import ( + Brownian, + GermanCredit, + ItemResponseTheory, + MixedLogit, + StandardNormal, + StochasticVolatility, + models, +) +from blackjax.benchmarks.mcmc.sampling_algorithms import ( + run_mclmc, + run_mhmclmc, + run_nuts, + samplers, +) +from blackjax.mcmc.integrators import ( + calls_per_integrator_step, + generate_euclidean_integrator, + generate_isokinetic_integrator, + integrator_order, + isokinetic_mclachlan, + mclachlan_coefficients, + name_integrator, + omelyan_coefficients, + velocity_verlet, + velocity_verlet_coefficients, + yoshida_coefficients, +) + # from blackjax.mcmc.mhmclmc import rescale from blackjax.util import run_inference_algorithm -target_acceptance_rate_of_order = {2 : 0.65, 4: 0.8} +target_acceptance_rate_of_order = {2: 0.65, 4: 0.8} + def get_num_latents(target): - return target.ndims + return target.ndims + + # return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0])))) def err(f_true, var_f, contract): """Computes the error b^2 = (f - f_true)^2 / var_f - Args: - f: E_sampler[f(x)], can be a vector - f_true: E_true[f(x)] - var_f: Var_true[f(x)] - contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max - - Returns: - contract(b^2) - """ - - return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) + Args: + f: E_sampler[f(x)], can be a vector + f_true: E_true[f(x)] + var_f: Var_true[f(x)] + contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max + Returns: + contract(b^2) + """ + + return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) -def grads_to_low_error(err_t, grad_evals_per_step= 1, low_error= 0.01): +def grads_to_low_error(err_t, grad_evals_per_step=1, low_error=0.01): """Uses the error of the expectation values to compute the effective sample size neff - b^2 = 1/neff""" - + b^2 = 1/neff""" + cutoff_reached = err_t[-1] < low_error return find_crossing(err_t, low_error) * grad_evals_per_step, cutoff_reached - - -def calculate_ess(err_t, grad_evals_per_step, neff= 100): - - grads_to_low, cutoff_reached = grads_to_low_error(err_t, grad_evals_per_step, 1./neff) - - return (neff / grads_to_low) * cutoff_reached, grads_to_low*(1/cutoff_reached), cutoff_reached + + +def calculate_ess(err_t, grad_evals_per_step, neff=100): + grads_to_low, cutoff_reached = grads_to_low_error( + err_t, grad_evals_per_step, 1.0 / neff + ) + + return ( + (neff / grads_to_low) * cutoff_reached, + grads_to_low * (1 / cutoff_reached), + cutoff_reached, + ) def find_crossing(array, cutoff): @@ -77,34 +111,61 @@ def find_crossing(array, cutoff): print("\n\n\nNO CROSSING FOUND!!!\n\n\n", array, cutoff) return 1 - return jnp.max(indices)+1 + return jnp.max(indices) + 1 def cumulative_avg(samples): - return jnp.cumsum(samples, axis = 0) / jnp.arange(1, samples.shape[0] + 1)[:, None] - - -def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps, center_L, center_step_size, contract): + return jnp.cumsum(samples, axis=0) / jnp.arange(1, samples.shape[0] + 1)[:, None] + + +def gridsearch_tune( + key, + iterations, + grid_size, + model, + sampler, + batch, + num_steps, + center_L, + center_step_size, + 
contract, +): results = defaultdict(float) converged = False - keys = jax.random.split(key, iterations+1) + keys = jax.random.split(key, iterations + 1) for i in range(iterations): print(f"EPOCH {i}") width = 2 - step_sizes = np.logspace(np.log10(center_step_size/width), np.log10(center_step_size*width), grid_size) - Ls = np.logspace(np.log10(center_L/2), np.log10(center_L*2),grid_size) + step_sizes = np.logspace( + np.log10(center_step_size / width), + np.log10(center_step_size * width), + grid_size, + ) + Ls = np.logspace(np.log10(center_L / 2), np.log10(center_L * 2), grid_size) # print(list(itertools.product(step_sizes , Ls))) - grid_keys = jax.random.split(keys[i], grid_size^2) + grid_keys = jax.random.split(keys[i], grid_size ^ 2) print(f"center step size {center_step_size}, center L {center_L}") - for j, (step_size, L) in enumerate(itertools.product(step_sizes , Ls)): - ess, grad_calls_until_convergence, _ , _, _ = benchmark_chains(model, sampler(step_size=step_size, L=L), grid_keys[j], n=num_steps, batch = batch, contract=contract) + for j, (step_size, L) in enumerate(itertools.product(step_sizes, Ls)): + ess, grad_calls_until_convergence, _, _, _ = benchmark_chains( + model, + sampler(step_size=step_size, L=L), + grid_keys[j], + n=num_steps, + batch=batch, + contract=contract, + ) results[(step_size, L)] = (ess, grad_calls_until_convergence) - best_ess, best_grads, (step_size, L) = max([(results[r][0], results[r][1], r) for r in results], key=operator.itemgetter(0)) + best_ess, best_grads, (step_size, L) = max( + ((results[r][0], results[r][1], r) for r in results), + key=operator.itemgetter(0), + ) # raise Exception - print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - if L==center_L and step_size==center_step_size: + print( + f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}" + ) + if L == center_L and step_size == center_step_size: print("converged") converged = True break @@ -112,42 +173,57 @@ def gridsearch_tune(key, iterations, grid_size, model, sampler, batch, num_steps center_L, center_step_size = L, step_size pprint.pp(results) - # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") + # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") + # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") return center_L, center_step_size, converged def run_mhmclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): - def s(logdensity_fn, num_steps, initial_position, transform, key): - integrator = generate_isokinetic_integrator(coefficients) - num_steps_per_traj = L/step_size + num_steps_per_traj = L / step_size alg = blackjax.mcmc.mhmclmc.mhmclmc( - logdensity_fn=logdensity_fn, - step_size=step_size, - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(num_steps_per_traj)) , - integrator=integrator, - std_mat=std_mat, + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(num_steps_per_traj) + ), + integrator=integrator, + std_mat=std_mat, ) _, out, info = run_inference_algorithm( - rng_key=key, - initial_state=initial_state, - inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: 
transform(x.position), - progress_bar=True) + rng_key=key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True, + ) - return out, MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), num_steps_per_traj * calls_per_integrator_step(coefficients), info.acceptance_rate.mean(), None, jnp.array([0]) + return ( + out, + MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), + num_steps_per_traj * calls_per_integrator_step(coefficients), + info.acceptance_rate.mean(), + None, + jnp.array([0]), + ) return s -def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.average,): +def benchmark_chains( + model, + sampler, + key, + n=10000, + batch=None, + contract=jnp.average, +): pvmap = jax.pmap - + d = get_num_latents(model) if batch is None: batch = np.ceil(1000 / d).astype(int) @@ -158,13 +234,31 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av init_pos = pvmap(model.sample_init)(init_keys) # samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys) - samples, params, grad_calls_per_traj, acceptance_rate, step_size_over_da, final_da = pvmap(lambda pos, key: sampler(logdensity_fn=model.logdensity_fn, num_steps=n, initial_position= pos,transform= model.transform, key=key))(init_pos, keys) + ( + samples, + params, + grad_calls_per_traj, + acceptance_rate, + step_size_over_da, + final_da, + ) = pvmap( + lambda pos, key: sampler( + logdensity_fn=model.logdensity_fn, + num_steps=n, + initial_position=pos, + transform=model.transform, + key=key, + ) + )( + init_pos, keys + ) avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0) try: - print(jnp.nanmean(params.step_size,axis=0), jnp.nanmean(params.L,axis=0)) - except: pass - - full = lambda arr : err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) + print(jnp.nanmean(params.step_size, axis=0), jnp.nanmean(params.L, axis=0)) + except: + pass + + full = lambda arr: err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) err_t = pvmap(full)(samples**2) # outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t] @@ -174,7 +268,6 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av # return(mean(esses), mean(grad_calls)) # print(final_da.mean(), "final da") - err_t_median = jnp.median(err_t, axis=0) # import matplotlib.pyplot as plt # plt.plot(np.arange(1, 1+ len(err_t_median))* 2, err_t_median, color= 'teal', lw = 3) @@ -184,62 +277,106 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av # plt.yscale('log') # plt.savefig('brownian.png') # plt.close() - esses, grad_calls, _ = calculate_ess(err_t_median, grad_evals_per_step=avg_grad_calls_per_traj) - return esses, grad_calls, params, jnp.mean(acceptance_rate, axis=0), step_size_over_da - - + esses, grad_calls, _ = calculate_ess( + err_t_median, grad_evals_per_step=avg_grad_calls_per_traj + ) + return ( + esses, + grad_calls, + params, + jnp.mean(acceptance_rate, axis=0), + step_size_over_da, + ) def run_benchmarks(batch_size): - results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [Brownian()], # [Brownian()], # [Brownian()], - # 
[velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): sampler, model, coefficients = variables - num_chains = batch_size#1 + batch_size//model.ndims - + num_chains = batch_size # 1 + batch_size//model.ndims num_steps = 100000 sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) contract = jnp.max key = jax.random.PRNGKey(11) for i in range(1): key1, key = jax.random.split(key) - ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients, frac_tune1=0.1, frac_tune2=0.0, frac_tune3=0.0),key1, n=num_steps, batch=num_chains, contract=contract) + ( + ess, + grad_calls, + params, + acceptance_rate, + step_size_over_da, + ) = benchmark_chains( + model, + partial( + samplers[sampler], + coefficients=coefficients, + frac_tune1=0.1, + frac_tune2=0.0, + frac_tune3=0.0, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) jax.numpy.save(f"acceptance.npy", acceptance_rate) - # print(f"grads to low bias: {grad_calls}") # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + results[ + ( + (model.name, model.ndims), + sampler, + name_integrator(coefficients), + "standard", + acceptance_rate.mean().item(), + params.L.mean().item(), + params.step_size.mean().item(), + num_chains, + num_steps, + contract, + ) + ] = ess.item() print(ess.item()) # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - # print(results) - df = pd.Series(results).reset_index() - df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] + df.columns = [ + "model", + "sampler", + "integrator", + "tuning", + "acc rate", + "L", + "stepsize", + "num_chains", + "num steps", + "contraction", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results_simple.csv", index=False) @@ -248,19 +385,17 @@ def run_benchmarks(batch_size): def run_simple(): - results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mclmc"], + # ["mhmclmc", "nuts", "mclmc", ], + ["mclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [Brownian()], # [Brownian()], # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + 
[mclachlan_coefficients], + ): sampler, model, coefficients = variables num_chains = 128 @@ -271,68 +406,127 @@ def run_simple(): key = jax.random.PRNGKey(11) for i in range(1): key1, key = jax.random.split(key) - ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients),key1, n=num_steps, batch=num_chains, contract=contract) + ( + ess, + grad_calls, + params, + acceptance_rate, + step_size_over_da, + ) = benchmark_chains( + model, + partial(samplers[sampler], coefficients=coefficients), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) - - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + results[ + ( + (model.name, model.ndims), + sampler, + name_integrator(coefficients), + "standard", + acceptance_rate.mean().item(), + params.L.mean().item(), + params.step_size.mean().item(), + num_chains, + num_steps, + contract, + ) + ] = ess.item() print(ess.item()) - return results + # vary step_size def run_benchmarks_step_size(batch_size): - results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmclmc"], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], [StandardNormal(10)], # [Brownian()], # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - - - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients], + ): num_steps = 10000 sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) + # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) contract = jnp.average center = 6.534974 key = jax.random.PRNGKey(11) - for step_size in np.linspace(center-1,center+1, 41): - # for L in np.linspace(1, 10, 41): + for step_size in np.linspace(center - 1, center + 1, 41): + # for L in np.linspace(1, 10, 41): key1, key2, key3, key = jax.random.split(key, 4) initial_position = model.sample_init(key2) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=key3) - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(initial_state=initial_state, coefficients=mclachlan_coefficients, step_size=step_size, L= 5*step_size, std_mat=1.),key1, n=num_steps, batch=num_chains, contract=contract) + position=initial_position, + logdensity_fn=model.logdensity_fn, + random_generator_arg=key3, + ) + ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( + model, + run_mhmclmc_no_tuning( + initial_state=initial_state, + coefficients=mclachlan_coefficients, + step_size=step_size, + L=5 * step_size, + std_mat=1.0, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") # jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) # 
jax.numpy.save(f"acceptance.npy_{step_size}", acceptance_rate) - # print(f"grads to low bias: {grad_calls}") # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - results[((model.name, model.ndims), sampler, name_integrator(coefficients), "standard", acceptance_rate.mean().item(), params.L.mean().item(), params.step_size.mean().item(), num_chains, num_steps, contract)] = ess.item() + results[ + ( + (model.name, model.ndims), + sampler, + name_integrator(coefficients), + "standard", + acceptance_rate.mean().item(), + params.L.mean().item(), + params.step_size.mean().item(), + num_chains, + num_steps, + contract, + ) + ] = ess.item() # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - # print(results) - df = pd.Series(results).reset_index() - df.columns = ["model", "sampler", "integrator", "tuning", "acc rate", "L", "stepsize", "num_chains", "num steps", "contraction", "ESS"] + df.columns = [ + "model", + "sampler", + "integrator", + "tuning", + "acc rate", + "L", + "stepsize", + "num_chains", + "num steps", + "contraction", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results_step_size.csv", index=False) @@ -340,17 +534,14 @@ def run_benchmarks_step_size(batch_size): return results - def benchmark_mhmchmc(batch_size): - key0, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(5), 4) results = defaultdict(tuple) # coefficients = [yoshida_coefficients, mclachlan_coefficients, velocity_verlet_coefficients, omelyan_coefficients] coefficients = [mclachlan_coefficients, velocity_verlet_coefficients] for model, coeffs in itertools.product(models, coefficients): - - num_chains = batch_size # 1 + batch_size//model.ndims + num_chains = batch_size # 1 + batch_size//model.ndims print(f"NUMBER OF CHAINS for {model.name} and MHMCLMC is {num_chains}") num_steps = models[model]["mhmclmc"] print(f"NUMBER OF STEPS for {model.name} and MHCMLMC is {num_steps}") @@ -358,67 +549,123 @@ def benchmark_mhmchmc(batch_size): ####### run mclmc with standard tuning contract = jnp.max - - ess, grad_calls, params , _, step_size_over_da = benchmark_chains( + ess, grad_calls, params, _, step_size_over_da = benchmark_chains( model, - partial(run_mclmc,coefficients=coeffs), + partial(run_mclmc, coefficients=coeffs), key0, n=num_steps, batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mclmc", params.L.mean().item(), params.step_size.mean().item(), name_integrator(coeffs), "standard", 1.)] = ess.item() - print(f'mclmc with tuning ESS {ess}') - + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "mclmc", + params.L.mean().item(), + params.step_size.mean().item(), + name_integrator(coeffs), + "standard", + 1.0, + ) + ] = ess.item() + print(f"mclmc with tuning ESS {ess}") - ####### run mhmclmc with standard tuning + ####### run mhmclmc with standard tuning for target_acc_rate in [0.65, 0.9]: # coeffs = mclachlan_coefficients - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( - model, - partial(run_mhmclmc, target_acc_rate=target_acc_rate, coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0), - key1, - n=num_steps, - batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mhmchmc"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'mhmclmc with tuning 
ESS {ess}') - + ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( + model, + partial( + run_mhmclmc, + target_acc_rate=target_acc_rate, + coefficients=coeffs, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.0, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "mhmchmc" + str(target_acc_rate), + jnp.nanmean(params.L).item(), + jnp.nanmean(params.step_size).item(), + name_integrator(coeffs), + "standard", + acceptance_rate.mean().item(), + ) + ] = ess.item() + print(f"mhmclmc with tuning ESS {ess}") + # coeffs = mclachlan_coefficients - ess, grad_calls, params , acceptance_rate, _ = benchmark_chains( - model, - partial(run_mhmclmc, target_acc_rate=target_acc_rate,coefficients=coeffs, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.1), - key1, - n=num_steps, - batch=num_chains, - contract=contract) - results[(model.name, model.ndims, "mhmchmc:st3"+str(target_acc_rate), jnp.nanmean(params.L).item(), jnp.nanmean(params.step_size).item(), name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - print(f'mhmclmc with tuning ESS {ess}') + ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( + model, + partial( + run_mhmclmc, + target_acc_rate=target_acc_rate, + coefficients=coeffs, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + ), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "mhmchmc:st3" + str(target_acc_rate), + jnp.nanmean(params.L).item(), + jnp.nanmean(params.step_size).item(), + name_integrator(coeffs), + "standard", + acceptance_rate.mean().item(), + ) + ] = ess.item() + print(f"mhmclmc with tuning ESS {ess}") if True: ####### run mhmclmc with standard tuning + grid search - init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split(key2, 5) + init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split( + key2, 5 + ) initial_position = model.sample_init(init_pos_key) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key + position=initial_position, + logdensity_fn=model.logdensity_fn, + random_generator_arg=init_key, ) kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coeffs), - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn) + integrator=generate_isokinetic_integrator(coeffs), + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn, + ) ( state, blackjax_mhmclmc_sampler_params, - _, _ + _, + _, ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -428,96 +675,165 @@ def benchmark_mhmchmc(batch_size): frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, - diagonal_preconditioning=False + diagonal_preconditioning=False, ) - print(f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}") - print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") - + print( + f"target acceptance rate 
{target_acceptance_rate_of_order[integrator_order(coeffs)]}" + ) + print( + f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" + ) - L, step_size, convergence = gridsearch_tune(grid_key, iterations=10, contract=contract, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coeffs, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) + L, step_size, convergence = gridsearch_tune( + grid_key, + iterations=10, + contract=contract, + grid_size=5, + model=model, + sampler=partial( + run_mhmclmc_no_tuning, + coefficients=coeffs, + initial_state=state, + std_mat=1.0, + ), + batch=num_chains, + num_steps=num_steps, + center_L=blackjax_mhmclmc_sampler_params.L, + center_step_size=blackjax_mhmclmc_sampler_params.step_size, + ) # print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coeffs, L=L, step_size=step_size, initial_state=state, std_mat=1.),bench_key, n=num_steps, batch=num_chains, contract=contract) + ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( + model, + run_mhmclmc_no_tuning( + coefficients=coeffs, + L=L, + step_size=step_size, + initial_state=state, + std_mat=1.0, + ), + bench_key, + n=num_steps, + batch=num_chains, + contract=contract, + ) print(f"grads to low bias: {grad_calls}") - results[(model.name, model.ndims, "mhmchmc:grid", L.item(), step_size.item(), name_integrator(coeffs), f"gridsearch:{convergence}", acceptance_rate.mean().item())] = ess.item() + results[ + ( + model.name, + model.ndims, + "mhmchmc:grid", + L.item(), + step_size.item(), + name_integrator(coeffs), + f"gridsearch:{convergence}", + acceptance_rate.mean().item(), + ) + ] = ess.item() ####### run nuts # coeffs = velocity_verlet_coefficients - ess, grad_calls, _ , acceptance_rate, _ = benchmark_chains(model, partial(run_nuts,coefficients=coeffs),key3, n=models[model]["nuts"], batch=num_chains, contract=contract) - results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate.mean().item())] = ess.item() - - - - - + ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( + model, + partial(run_nuts, coefficients=coeffs), + key3, + n=models[model]["nuts"], + batch=num_chains, + contract=contract, + ) + results[ + ( + model.name, + model.ndims, + "nuts", + 0.0, + 0.0, + name_integrator(coeffs), + "standard", + acceptance_rate.mean().item(), + ) + ] = ess.item() - print(results) - df = pd.Series(results).reset_index() - df.columns = ["model", "dims", "sampler", "L", "step_size", "integrator", "tuning", "acc_rate", "ESS"] + df.columns = [ + "model", + "dims", + "sampler", + "L", + "step_size", + "integrator", + "tuning", + "acc_rate", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("results.csv", index=False) return results -def benchmark_omelyan(batch_size): - +def benchmark_omelyan(batch_size): key = jax.random.PRNGKey(2) results = defaultdict(tuple) for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmchmc"], - [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int)], + # ["mhmclmc", "nuts", "mclmc", ], + ["mhmchmc"], + [ + StandardNormal(d) + for d in np.ceil(np.logspace(np.log10(10), 
np.log10(1000), 4)).astype(int) + ], # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 5)).astype(int)], # models, - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients, omelyan_coefficients], - ): - - + # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], + [mclachlan_coefficients, omelyan_coefficients], + ): sampler, model, coefficients = variables # num_chains = 1 + batch_size//model.ndims num_chains = batch_size - current_key, key = jax.random.split(key) - init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split(current_key, 5) + current_key, key = jax.random.split(key) + init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split( + current_key, 5 + ) # num_steps = models[model][sampler] num_steps = 1000 - initial_position = model.sample_init(init_pos_key) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=model.logdensity_fn, random_generator_arg=init_key + position=initial_position, + logdensity_fn=model.logdensity_fn, + random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coefficients), - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn) + integrator=generate_isokinetic_integrator(coefficients), + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + std_mat=std_mat, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=model.logdensity_fn, + ) ( state, blackjax_mhmclmc_sampler_params, - _, _ + _, + _, ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -527,57 +843,112 @@ def benchmark_omelyan(batch_size): frac_tune1=0.1, frac_tune2=0.1, # frac_tune3=0.1, - diagonal_preconditioning=False + diagonal_preconditioning=False, ) - print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - print(f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}") + print( + f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}", + ) + print( + f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" + ) # ess, grad_calls, _ , _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) - # results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) - - L, step_size, converged = gridsearch_tune(grid_key, iterations=10, contract=jnp.average, grid_size=5, model=model, sampler=partial(run_mhmclmc_no_tuning, coefficients=coefficients, initial_state=state, std_mat=1.), batch=num_chains, num_steps=num_steps, center_L=blackjax_mhmclmc_sampler_params.L, center_step_size=blackjax_mhmclmc_sampler_params.step_size) + # 
results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) + + L, step_size, converged = gridsearch_tune( + grid_key, + iterations=10, + contract=jnp.average, + grid_size=5, + model=model, + sampler=partial( + run_mhmclmc_no_tuning, + coefficients=coefficients, + initial_state=state, + std_mat=1.0, + ), + batch=num_chains, + num_steps=num_steps, + center_L=blackjax_mhmclmc_sampler_params.L, + center_step_size=blackjax_mhmclmc_sampler_params.step_size, + ) print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _ , _, _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=L, step_size=step_size, std_mat=1., initial_state=state),bench_key, n=num_steps, batch=num_chains, contract=jnp.average) + ess, grad_calls, _, _, _ = benchmark_chains( + model, + run_mhmclmc_no_tuning( + coefficients=coefficients, + L=L, + step_size=step_size, + std_mat=1.0, + initial_state=state, + ), + bench_key, + n=num_steps, + batch=num_chains, + contract=jnp.average, + ) print(f"grads to low bias: {grad_calls}") - results[(model.name, model.ndims, sampler, name_integrator(coefficients), converged, L.item(), step_size.item())] = ess.item() + results[ + ( + model.name, + model.ndims, + sampler, + name_integrator(coefficients), + converged, + L.item(), + step_size.item(), + ) + ] = ess.item() df = pd.Series(results).reset_index() - df.columns = ["model", "dims", "sampler", "integrator", "convergence", "L", "step_size", "ESS"] + df.columns = [ + "model", + "dims", + "sampler", + "integrator", + "convergence", + "L", + "step_size", + "ESS", + ] # df.result = df.result.apply(lambda x: x[0].item()) # df.model = df.model.apply(lambda x: x[1]) df.to_csv("omelyan.csv", index=False) def run_benchmarks_divij(): - sampler = run_mclmc - model = StandardNormal(10) # 10 dimensional standard normal + model = StandardNormal(10) # 10 dimensional standard normal coefficients = mclachlan_coefficients - contract = jnp.average # how we average across dimensions + contract = jnp.average # how we average across dimensions num_steps = 2000 num_chains = 100 key1 = jax.random.PRNGKey(2) - ess, grad_calls, params , acceptance_rate, step_size_over_da = benchmark_chains(model, partial(sampler, coefficients=coefficients),key1, n=num_steps, batch=num_chains, contract=contract) + ess, grad_calls, params, acceptance_rate, step_size_over_da = benchmark_chains( + model, + partial(sampler, coefficients=coefficients), + key1, + n=num_steps, + batch=num_chains, + contract=contract, + ) print(f"Effective Sample Size (ESS) of 10D Normal is {ess}") -if __name__ == "__main__": +if __name__ == "__main__": # run_benchmarks_divij() - - # benchmark_mhmchmc(batch_size=128) run_simple() # run_benchmarks_step_size(128) # benchmark_omelyan(128) # run_benchmarks(128) - #benchmark_omelyan(10) + # benchmark_omelyan(10) # print("4") diff --git a/blackjax/benchmarks/mcmc/inference_models.py b/blackjax/benchmarks/mcmc/inference_models.py index b918ce3bf..715bd4c14 100644 --- a/blackjax/benchmarks/mcmc/inference_models.py +++ b/blackjax/benchmarks/mcmc/inference_models.py @@ -1,51 +1,50 @@ # mypy: ignore-errors # flake8: noqa -#from inference_gym import using_jax as gym +import os + +# from inference_gym import using_jax as gym import jax import jax.numpy as jnp import numpy as np -import os -#import numpyro.distributions as dist -dirr = os.path.dirname(os.path.realpath(__file__)) +# import numpyro.distributions as dist +dirr = 
os.path.dirname(os.path.realpath(__file__)) -class StandardNormal(): +class StandardNormal: """Standard Normal distribution in d dimensions""" def __init__(self, d): self.ndims = d self.E_x2 = jnp.ones(d) self.Var_x2 = 2 * self.E_x2 - self.name = 'StandardNormal' - + self.name = "StandardNormal" def logdensity_fn(self, x): """- log p of the target distribution""" - return -0.5 * jnp.sum(jnp.square(x), axis= -1) - + return -0.5 * jnp.sum(jnp.square(x), axis=-1) def transform(self, x): return x def sample_init(self, key): - return jax.random.normal(key, shape = (self.ndims, )) + return jax.random.normal(key, shape=(self.ndims,)) - -class IllConditionedGaussian(): +class IllConditionedGaussian: """Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2.""" - - def __init__(self, d, condition_number, numpy_seed=None, prior= 'prior'): + def __init__(self, d, condition_number, numpy_seed=None, prior="prior"): """numpy_seed is used to generate a random rotation for the covariance matrix. - If None, the covariance matrix is diagonal.""" + If None, the covariance matrix is diagonal.""" self.ndims = d - self.name = 'IllConditionedGaussian' + self.name = "IllConditionedGaussian" self.condition_number = condition_number - eigs = jnp.logspace(-0.5 * jnp.log10(condition_number), 0.5 * jnp.log10(condition_number), d) + eigs = jnp.logspace( + -0.5 * jnp.log10(condition_number), 0.5 * jnp.log10(condition_number), d + ) if numpy_seed == None: # diagonal self.E_x2 = eigs @@ -57,268 +56,296 @@ def __init__(self, d, condition_number, numpy_seed=None, prior= 'prior'): rng = np.random.RandomState(seed=numpy_seed) D = jnp.diag(eigs) inv_D = jnp.diag(1 / eigs) - R, _ = jnp.array(np.linalg.qr(rng.randn(self.ndims, self.ndims))) # random rotation + R, _ = jnp.array( + np.linalg.qr(rng.randn(self.ndims, self.ndims)) + ) # random rotation self.R = R self.Hessian = R @ inv_D @ R.T self.Cov = R @ D @ R.T self.E_x2 = jnp.diagonal(R @ D @ R.T) - #Cov_precond = jnp.diag(1 / jnp.sqrt(self.E_x2)) @ self.Cov @ jnp.diag(1 / jnp.sqrt(self.E_x2)) + # Cov_precond = jnp.diag(1 / jnp.sqrt(self.E_x2)) @ self.Cov @ jnp.diag(1 / jnp.sqrt(self.E_x2)) - #print(jnp.linalg.cond(Cov_precond) / jnp.linalg.cond(self.Cov)) + # print(jnp.linalg.cond(Cov_precond) / jnp.linalg.cond(self.Cov)) self.Var_x2 = 2 * jnp.square(self.E_x2) - self.logdensity_fn = lambda x: -0.5 * x.T @ self.Hessian @ x self.transform = lambda x: x - - if prior == 'map': + if prior == "map": self.sample_init = lambda key: jnp.zeros(self.ndims) - elif prior == 'posterior': - self.sample_init = lambda key: self.R @ (jax.random.normal(key, shape=(self.ndims,)) * jnp.sqrt(eigs)) + elif prior == "posterior": + self.sample_init = lambda key: self.R @ ( + jax.random.normal(key, shape=(self.ndims,)) * jnp.sqrt(eigs) + ) - else: # N(0, sigma_true_max) - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.max(jnp.sqrt(eigs)) + else: # N(0, sigma_true_max) + self.sample_init = lambda key: jax.random.normal( + key, shape=(self.ndims,) + ) * jnp.max(jnp.sqrt(eigs)) - -class IllConditionedESH(): +class IllConditionedESH: """ICG from the ESH paper.""" def __init__(self): self.ndims = 50 - self.name = 'IllConditionedESH' + self.name = "IllConditionedESH" self.variance = jnp.linspace(0.01, 1, self.ndims) - - - def logdensity_fn(self, x): """- log p of the target distribution""" - return -0.5 * jnp.sum(jnp.square(x) / self.variance, axis= -1) - + return -0.5 * 
jnp.sum(jnp.square(x) / self.variance, axis=-1) def transform(self, x): return x def draw(self, key): - return jax.random.normal(key, shape = (self.ndims, )) * jnp.sqrt(self.variance) + return jax.random.normal(key, shape=(self.ndims,)) * jnp.sqrt(self.variance) def sample_init(self, key): - return jax.random.normal(key, shape = (self.ndims, )) - - + return jax.random.normal(key, shape=(self.ndims,)) -class IllConditionedGaussianGamma(): +class IllConditionedGaussianGamma: """Inference gym's Ill conditioned Gaussian""" - def __init__(self, prior = 'prior'): + def __init__(self, prior="prior"): self.ndims = 100 - self.name = 'IllConditionedGaussianGamma' + self.name = "IllConditionedGaussianGamma" # define the Hessian - rng = np.random.RandomState(seed=10 & (2 ** 32 - 1)) - eigs = np.sort(rng.gamma(shape=0.5, scale=1., size=self.ndims)) #eigenvalues of the Hessian - eigs *= jnp.average(1.0/eigs) + rng = np.random.RandomState(seed=10 & (2**32 - 1)) + eigs = np.sort( + rng.gamma(shape=0.5, scale=1.0, size=self.ndims) + ) # eigenvalues of the Hessian + eigs *= jnp.average(1.0 / eigs) self.entropy = 0.5 * self.ndims - self.maxmin = (1./jnp.sqrt(eigs[0]), 1./jnp.sqrt(eigs[-1])) - R, _ = np.linalg.qr(rng.randn(self.ndims, self.ndims)) #random rotation + self.maxmin = (1.0 / jnp.sqrt(eigs[0]), 1.0 / jnp.sqrt(eigs[-1])) + R, _ = np.linalg.qr(rng.randn(self.ndims, self.ndims)) # random rotation self.map_to_worst = (R.T)[[0, -1], :] self.Hessian = R @ np.diag(eigs) @ R.T # analytic ground truth moments - self.E_x2 = jnp.diagonal(R @ np.diag(1.0/eigs) @ R.T) + self.E_x2 = jnp.diagonal(R @ np.diag(1.0 / eigs) @ R.T) self.Var_x2 = 2 * jnp.square(self.E_x2) # norm = jnp.diag(1/jnp.sqrt(self.E_x2)) # Sigma = R @ np.diag(1/eigs) @ R.T # reduced = norm @ Sigma @ norm # print(np.linalg.cond(reduced), np.linalg.cond(Sigma)) - + # gradient - - if prior == 'map': + if prior == "map": self.sample_init = lambda key: jnp.zeros(self.ndims) - elif prior == 'posterior': - self.sample_init = lambda key: R @ (jax.random.normal(key, shape=(self.ndims,)) / jnp.sqrt(eigs)) + elif prior == "posterior": + self.sample_init = lambda key: R @ ( + jax.random.normal(key, shape=(self.ndims,)) / jnp.sqrt(eigs) + ) + + else: # N(0, sigma_true_max) + self.sample_init = lambda key: jax.random.normal( + key, shape=(self.ndims,) + ) * jnp.max(1.0 / jnp.sqrt(eigs)) - else: # N(0, sigma_true_max) - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.max(1.0/jnp.sqrt(eigs)) - def logdensity_fn(self, x): """- log p of the target distribution""" return -0.5 * x.T @ self.Hessian @ x def transform(self, x): return x - - -class Banana(): +class Banana: """Banana target fromm the Inference Gym""" - def __init__(self, prior = 'map'): + def __init__(self, prior="map"): self.curvature = 0.03 self.ndims = 2 - self.name = 'Banana' - + self.name = "Banana" + self.transform = lambda x: x - self.E_x2 = jnp.array([100.0, 19.0]) #the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. + self.E_x2 = jnp.array( + [100.0, 19.0] + ) # the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. 
self.Var_x2 = jnp.array([20000.0, 4600.898]) - if prior == 'map': + if prior == "map": self.sample_init = lambda key: jnp.array([0, -100.0 * self.curvature]) - elif prior == 'posterior': + elif prior == "posterior": self.sample_init = lambda key: self.posterior_draw(key) - elif prior == 'prior': - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.array([10.0, 5.0]) * 2 + elif prior == "prior": + self.sample_init = ( + lambda key: jax.random.normal(key, shape=(self.ndims,)) + * jnp.array([10.0, 5.0]) + * 2 + ) else: - raise ValueError('prior = '+prior +' is not defined.') + raise ValueError("prior = " + prior + " is not defined.") def logdensity_fn(self, x): mu2 = self.curvature * (x[0] ** 2 - 100) return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) def posterior_draw(self, key): - z = jax.random.normal(key, shape = (2, )) + z = jax.random.normal(key, shape=(2,)) x0 = 10.0 * z[0] - x1 = self.curvature * (x0 ** 2 - 100) + z[1] + x1 = self.curvature * (x0**2 - 100) + z[1] return jnp.array([x0, x1]) def ground_truth(self): - x = jax.vmap(self.posterior_draw)(jax.random.split(jax.random.PRNGKey(0), 100000000)) + x = jax.vmap(self.posterior_draw)( + jax.random.split(jax.random.PRNGKey(0), 100000000) + ) print(jnp.average(x, axis=0)) print(jnp.average(jnp.square(x), axis=0)) print(jnp.std(jnp.square(x[:, 0])) ** 2, jnp.std(jnp.square(x[:, 1])) ** 2) - - -class Cauchy(): +class Cauchy: """d indpendent copies of the standard Cauchy distribution""" def __init__(self, d): self.ndims = d - self.name = 'Cauchy' - - self.logdensity_fn = lambda x: -jnp.sum(jnp.log(1. + jnp.square(x))) - - self.transform = lambda x: x - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) - + self.name = "Cauchy" + self.logdensity_fn = lambda x: -jnp.sum(jnp.log(1.0 + jnp.square(x))) + self.transform = lambda x: x + self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) -class HardConvex(): - def __init__(self, d, kappa, theta = 0.1): +class HardConvex: + def __init__(self, d, kappa, theta=0.1): """d is the dimension, kappa = condition number, 0 < theta < 1/4""" self.ndims = d - self.name = 'HardConvex' + self.name = "HardConvex" self.theta, self.kappa = theta, kappa - C = jnp.power(d-1, 0.25 - theta) - self.logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x[:-1])) - (0.75 / kappa)* x[-1]**2 + 0.5 * jnp.sum(jnp.cos(C * x[:-1])) / C**2 - + C = jnp.power(d - 1, 0.25 - theta) + self.logdensity_fn = ( + lambda x: -0.5 * jnp.sum(jnp.square(x[:-1])) + - (0.75 / kappa) * x[-1] ** 2 + + 0.5 * jnp.sum(jnp.cos(C * x[:-1])) / C**2 + ) + self.transform = lambda x: x # numerically precomputed variances num_integration = [0.93295, 0.968802, 0.990595, 0.998002, 0.999819] if d == 100: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[0], jnp.ones(1) * 2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[0], jnp.ones(1) * 2.0 * kappa / 3.0) + ) elif d == 300: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[1], jnp.ones(1) * 2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[1], jnp.ones(1) * 2.0 * kappa / 3.0) + ) elif d == 1000: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[2], jnp.ones(1) * 2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[2], jnp.ones(1) * 2.0 * kappa / 3.0) + ) elif d == 3000: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[3], jnp.ones(1) * 
2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[3], jnp.ones(1) * 2.0 * kappa / 3.0) + ) elif d == 10000: - self.variance = jnp.concatenate((jnp.ones(d-1) * num_integration[4], jnp.ones(1) * 2.0*kappa/3.0)) + self.variance = jnp.concatenate( + (jnp.ones(d - 1) * num_integration[4], jnp.ones(1) * 2.0 * kappa / 3.0) + ) else: None - def sample_init(self, key): """Gaussian prior with approximately estimating the variance along each dimension""" - scale = jnp.concatenate((jnp.ones(self.ndims-1), jnp.ones(1) * jnp.sqrt(2.0 * self.kappa / 3.0))) + scale = jnp.concatenate( + (jnp.ones(self.ndims - 1), jnp.ones(1) * jnp.sqrt(2.0 * self.kappa / 3.0)) + ) return jax.random.normal(key, shape=(self.ndims,)) * scale - - -class BiModal(): +class BiModal: """A Gaussian mixture p(x) = f N(x | mu1, sigma1) + (1-f) N(x | mu2, sigma2).""" - def __init__(self, d = 50, mu1 = 0.0, mu2 = 8.0, sigma1 = 1.0, sigma2 = 1.0, f = 0.2): - + def __init__(self, d=50, mu1=0.0, mu2=8.0, sigma1=1.0, sigma2=1.0, f=0.2): self.ndims = d - self.name = 'BiModal' + self.name = "BiModal" - self.mu1 = jnp.insert(jnp.zeros(d-1), 0, mu1) + self.mu1 = jnp.insert(jnp.zeros(d - 1), 0, mu1) self.mu2 = jnp.insert(jnp.zeros(d - 1), 0, mu2) self.sigma1, self.sigma2 = sigma1, sigma2 self.f = f - self.variance = jnp.insert(jnp.ones(d-1) * ((1 - f) * sigma1**2 + f * sigma2**2), 0, (1-f)*(sigma1**2 + mu1**2) + f*(sigma2**2 + mu2**2)) - - + self.variance = jnp.insert( + jnp.ones(d - 1) * ((1 - f) * sigma1**2 + f * sigma2**2), + 0, + (1 - f) * (sigma1**2 + mu1**2) + f * (sigma2**2 + mu2**2), + ) def logdensity_fn(self, x): """- log p of the target distribution""" - N1 = (1.0 - self.f) * jnp.exp(-0.5 * jnp.sum(jnp.square(x - self.mu1), axis= -1) / self.sigma1 ** 2) / jnp.power(2 * jnp.pi * self.sigma1 ** 2, self.ndims * 0.5) - N2 = self.f * jnp.exp(-0.5 * jnp.sum(jnp.square(x - self.mu2), axis= -1) / self.sigma2 ** 2) / jnp.power(2 * jnp.pi * self.sigma2 ** 2, self.ndims * 0.5) + N1 = ( + (1.0 - self.f) + * jnp.exp( + -0.5 * jnp.sum(jnp.square(x - self.mu1), axis=-1) / self.sigma1**2 + ) + / jnp.power(2 * jnp.pi * self.sigma1**2, self.ndims * 0.5) + ) + N2 = ( + self.f + * jnp.exp( + -0.5 * jnp.sum(jnp.square(x - self.mu2), axis=-1) / self.sigma2**2 + ) + / jnp.power(2 * jnp.pi * self.sigma2**2, self.ndims * 0.5) + ) return jnp.log(N1 + N2) - def draw(self, num_samples): """direct sampler from a target""" - X = np.random.normal(size = (num_samples, self.ndims)) + X = np.random.normal(size=(num_samples, self.ndims)) mask = np.random.uniform(0, 1, num_samples) < self.f X[mask, :] = (X[mask, :] * self.sigma2) + self.mu2 X[~mask] = (X[~mask] * self.sigma1) + self.mu1 return X - def transform(self, x): return x def sample_init(self, key): - z = jax.random.normal(key, shape = (self.ndims, )) *self.sigma1 - #z= z.at[0].set(self.mu1 + z[0]) + z = jax.random.normal(key, shape=(self.ndims,)) * self.sigma1 + # z= z.at[0].set(self.mu1 + z[0]) return z -class BiModalEqual(): +class BiModalEqual: """Mixture of two Gaussians, one centered at x = [mu/2, 0, 0, ...], the other at x = [-mu/2, 0, 0, ...]. 
- Both have equal probability mass.""" + Both have equal probability mass.""" def __init__(self, d, mu): - self.ndims = d - self.name = 'BiModalEqual' + self.name = "BiModalEqual" self.mu = mu - - def logdensity_fn(self, x): """- log p of the target distribution""" - return -0.5 * jnp.sum(jnp.square(x), axis= -1) + jnp.log(jnp.cosh(0.5*self.mu*x[0])) - 0.5* self.ndims * jnp.log(2 * jnp.pi) - self.mu**2 / 8.0 - + return ( + -0.5 * jnp.sum(jnp.square(x), axis=-1) + + jnp.log(jnp.cosh(0.5 * self.mu * x[0])) + - 0.5 * self.ndims * jnp.log(2 * jnp.pi) + - self.mu**2 / 8.0 + ) def draw(self, num_samples): """direct sampler from a target""" - X = np.random.normal(size = (num_samples, self.ndims)) + X = np.random.normal(size=(num_samples, self.ndims)) mask = np.random.uniform(0, 1, num_samples) < 0.5 - X[mask, 0] += 0.5*self.mu + X[mask, 0] += 0.5 * self.mu X[~mask, 0] -= 0.5 * self.mu return X @@ -327,82 +354,79 @@ def transform(self, x): return x -class Funnel(): +class Funnel: """Noise-less funnel""" - def __init__(self, d = 20): - + def __init__(self, d=20): self.ndims = d - self.name = 'Funnel' - self.sigma_theta= 3.0 - - self.E_x2 = jnp.ones(d) # the transformed variables are standard Gaussian distributed - self.Var_x2 = 2 * self.E_x2 - + self.name = "Funnel" + self.sigma_theta = 3.0 + self.E_x2 = jnp.ones( + d + ) # the transformed variables are standard Gaussian distributed + self.Var_x2 = 2 * self.E_x2 def logdensity_fn(self, x): - """ - log p of the target distribution - x = [z_0, z_1, ... z_{d-1}, theta] """ + """- log p of the target distribution + x = [z_0, z_1, ... z_{d-1}, theta]""" theta = x[-1] - X = x[..., :- 1] + X = x[..., :-1] - return -0.5* jnp.square(theta / self.sigma_theta) - 0.5 * (self.ndims - 1) * theta - 0.5 * jnp.exp(-theta) * jnp.sum(jnp.square(X), axis = -1) + return ( + -0.5 * jnp.square(theta / self.sigma_theta) + - 0.5 * (self.ndims - 1) * theta + - 0.5 * jnp.exp(-theta) * jnp.sum(jnp.square(X), axis=-1) + ) def inverse_transform(self, xtilde): theta = 3 * xtilde[-1] - return jnp.concatenate((xtilde[:-1] * jnp.exp(0.5 * theta), jnp.ones(1)*theta)) - + return jnp.concatenate( + (xtilde[:-1] * jnp.exp(0.5 * theta), jnp.ones(1) * theta) + ) def transform(self, x): """gaussianization""" xtilde = jnp.empty(x.shape) xtilde = xtilde.at[-1].set(x.T[-1] / 3.0) - xtilde = xtilde.at[:-1].set(x.T[:-1] * jnp.exp(-0.5*x.T[-1])) + xtilde = xtilde.at[:-1].set(x.T[:-1] * jnp.exp(-0.5 * x.T[-1])) return xtilde.T - def sample_init(self, key): - return self.inverse_transform(jax.random.normal(key, shape = (self.ndims, ))) - - - + return self.inverse_transform(jax.random.normal(key, shape=(self.ndims,))) -class Funnel_with_Data(): +class Funnel_with_Data: def __init__(self, d, sigma, minibatch_size, key): - self.ndims = d - self.name = 'Funnel_with_Data' - self.sigma_theta= 3.0 + self.name = "Funnel_with_Data" + self.sigma_theta = 3.0 self.theta_true = 0.0 self.sigma_data = sigma - self.data = self.simulate_data() self.batch = minibatch_size def simulate_data(self): - - norm = jax.random.normal(jax.random.PRNGKey(123), shape = (2*(self.ndims-1), )) - z_true = norm[:self.ndims-1] * jnp.exp(self.theta_true * 0.5) - self.data = z_true + norm[self.ndims-1:] * self.sigma_data - + norm = jax.random.normal(jax.random.PRNGKey(123), shape=(2 * (self.ndims - 1),)) + z_true = norm[: self.ndims - 1] * jnp.exp(self.theta_true * 0.5) + self.data = z_true + norm[self.ndims - 1 :] * self.sigma_data def logdensity_fn(self, x, subset): - """ - log p of the target distribution - x = [z_0, z_1, ... 
z_{d-1}, theta] """ + """- log p of the target distribution + x = [z_0, z_1, ... z_{d-1}, theta]""" theta = x[-1] - z = x[:- 1][subset] + z = x[:-1][subset] prior_theta = jnp.square(theta / self.sigma_theta) - prior_z = jnp.sum(subset) * theta + jnp.exp(-theta) * jnp.sum(jnp.square(z*subset)) - likelihood = jnp.sum(jnp.square((z - self.data)*subset / self.sigma_data)) + prior_z = jnp.sum(subset) * theta + jnp.exp(-theta) * jnp.sum( + jnp.square(z * subset) + ) + likelihood = jnp.sum(jnp.square((z - self.data) * subset / self.sigma_data)) return -0.5 * (prior_theta + prior_z + likelihood) - def transform(self, x): """gaussianization""" return x @@ -410,58 +434,56 @@ def transform(self, x): def sample_init(self, key): key1, key2 = jax.random.split(key) theta = jax.random.normal(key1) * self.sigma_theta - z = jax.random.normal(key2, shape = (self.ndims-1, )) * jnp.exp(theta * 0.5) + z = jax.random.normal(key2, shape=(self.ndims - 1,)) * jnp.exp(theta * 0.5) return jnp.concatenate((z, theta)) - - -class Rosenbrock(): - - def __init__(self, d = 36, Q = 0.1): - +class Rosenbrock: + def __init__(self, d=36, Q=0.1): self.ndims = d - self.name = 'Rosenbrock' + self.name = "Rosenbrock" self.Q = Q - #ground truth moments + # ground truth moments var_x = 2.0 - #these two options were precomputed: + # these two options were precomputed: if Q == 0.1: - var_y = 10.098433122783046 # var_y is computed numerically (see class function compute_variance) + var_y = 10.098433122783046 # var_y is computed numerically (see class function compute_variance) elif Q == 0.5: var_y = 10.498957879911487 else: - raise ValueError('Ground truth moments for Q = ' + str(Q) + ' were not precomputed. Use Q = 0.1 or 0.5.') - - self.variance = jnp.concatenate((var_x * jnp.ones(d//2), var_y * jnp.ones(d//2))) - - + raise ValueError( + "Ground truth moments for Q = " + + str(Q) + + " were not precomputed. Use Q = 0.1 or 0.5." 
+ ) + self.variance = jnp.concatenate( + (var_x * jnp.ones(d // 2), var_y * jnp.ones(d // 2)) + ) def logdensity_fn(self, x): """- log p of the target distribution""" - X, Y = x[..., :self.ndims//2], x[..., self.ndims//2:] - return -0.5 * jnp.sum(jnp.square(X - 1.0) + jnp.square(jnp.square(X) - Y) / self.Q, axis= -1) - - + X, Y = x[..., : self.ndims // 2], x[..., self.ndims // 2 :] + return -0.5 * jnp.sum( + jnp.square(X - 1.0) + jnp.square(jnp.square(X) - Y) / self.Q, axis=-1 + ) def draw(self, num): n = self.ndims // 2 - X= np.empty((num, self.ndims)) - X[:, :n] = np.random.normal(loc= 1.0, scale= 1.0, size= (num, n)) - X[:, n:] = np.random.normal(loc= jnp.square(X[:, :n]), scale= jnp.sqrt(self.Q), size= (num, n)) + X = np.empty((num, self.ndims)) + X[:, :n] = np.random.normal(loc=1.0, scale=1.0, size=(num, n)) + X[:, n:] = np.random.normal( + loc=jnp.square(X[:, :n]), scale=jnp.sqrt(self.Q), size=(num, n) + ) return X - def transform(self, x): return x - def sample_init(self, key): - return jax.random.normal(key, shape = (self.ndims, )) - + return jax.random.normal(key, shape=(self.ndims,)) def ground_truth(self): num = 100000000 @@ -474,13 +496,12 @@ def ground_truth(self): x1 = np.average(x) y1 = np.average(y) - print(np.sqrt(0.5*(np.square(np.std(x)) + np.square(np.std(y))))) + print(np.sqrt(0.5 * (np.square(np.std(x)) + np.square(np.std(y))))) print(x2, y2) - -class Brownian(): +class Brownian: """ log sigma_i ~ N(0, 2) log sigma_obs ~N(0, 2) @@ -493,36 +514,75 @@ class Brownian(): def __init__(self): self.num_data = 30 - self.name = 'Brownian' + self.name = "Brownian" self.ndims = self.num_data + 2 - ground_truth_moments = jnp.load(dirr + '/ground_truth/brownian/ground_truth.npy') + ground_truth_moments = jnp.load( + dirr + "/ground_truth/brownian/ground_truth.npy" + ) self.E_x2, self.Var_x2 = ground_truth_moments[0], ground_truth_moments[1] - self.data = jnp.array([0.21592641, 0.118771404, -0.07945447, 0.037677474, -0.27885845, -0.1484156, -0.3250906, -0.22957903, - -0.44110894, -0.09830782, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.8786016, -0.83736074, - -0.7384849, -0.8939254, -0.7774566, -0.70238715, -0.87771565, -0.51853573, -0.6948214, -0.6202789]) + self.data = jnp.array( + [ + 0.21592641, + 0.118771404, + -0.07945447, + 0.037677474, + -0.27885845, + -0.1484156, + -0.3250906, + -0.22957903, + -0.44110894, + -0.09830782, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + -0.8786016, + -0.83736074, + -0.7384849, + -0.8939254, + -0.7774566, + -0.70238715, + -0.87771565, + -0.51853573, + -0.6948214, + -0.6202789, + ] + ) # sigma_obs = 0.15, sigma_i = 0.1 self.observable = jnp.concatenate((jnp.ones(10), jnp.zeros(10), jnp.ones(10))) self.num_observable = jnp.sum(self.observable) # = 20 - def logdensity_fn(self, x): # y = softplus_to_log(x[:2]) - lik = 0.5 * jnp.exp(-2 * x[1]) * jnp.sum(self.observable * jnp.square(x[2:] - self.data)) + x[ - 1] * self.num_observable - prior_x = 0.5 * jnp.exp(-2 * x[0]) * (x[2] ** 2 + jnp.sum(jnp.square(x[3:] - x[2:-1]))) + x[0] * self.num_data + lik = ( + 0.5 + * jnp.exp(-2 * x[1]) + * jnp.sum(self.observable * jnp.square(x[2:] - self.data)) + + x[1] * self.num_observable + ) + prior_x = ( + 0.5 + * jnp.exp(-2 * x[0]) + * (x[2] ** 2 + jnp.sum(jnp.square(x[3:] - x[2:-1]))) + + x[0] * self.num_data + ) prior_logsigma = 0.5 * jnp.sum(jnp.square(x / 2.0)) return -lik - prior_x - prior_logsigma - def transform(self, x): return jnp.concatenate((jnp.exp(x[:2]), x[2:])) - def sample_init(self, key): key_walk, key_sigma = 
jax.random.split(key) @@ -530,8 +590,10 @@ def sample_init(self, key): # log_sigma = jax.random.normal(key_sigma, shape= (2, )) * 2 # narrower prior - log_sigma = jnp.log(np.array([0.1, 0.15])) + jax.random.normal(key_sigma, shape=( - 2,)) * 0.1 # *0.05# log sigma_i, log sigma_obs + log_sigma = ( + jnp.log(np.array([0.1, 0.15])) + + jax.random.normal(key_sigma, shape=(2,)) * 0.1 + ) # *0.05# log sigma_i, log sigma_obs walk = random_walk(key_walk, self.ndims - 2) * jnp.exp(log_sigma[0]) @@ -540,50 +602,50 @@ def sample_init(self, key): def generate_data(self, key): key_walk, key_sigma, key_noise = jax.random.split(key, 3) - log_sigma = jax.random.normal(key_sigma, shape=(2,)) * 2 # log sigma_i, log sigma_obs + log_sigma = ( + jax.random.normal(key_sigma, shape=(2,)) * 2 + ) # log sigma_i, log sigma_obs walk = random_walk(key_walk, self.ndims - 2) * jnp.exp(log_sigma[0]) - noise = jax.random.normal(key_noise, shape=(self.ndims - 2,)) * jnp.exp(log_sigma[1]) + noise = jax.random.normal(key_noise, shape=(self.ndims - 2,)) * jnp.exp( + log_sigma[1] + ) return walk + noise class GermanCredit: - """ Taken from inference gym. + """Taken from inference gym. - x = (global scale, local scales, weights) + x = (global scale, local scales, weights) - global_scale ~ Gamma(0.5, 0.5) + global_scale ~ Gamma(0.5, 0.5) - for i in range(num_features): - unscaled_weights[i] ~ Normal(loc=0, scale=1) - local_scales[i] ~ Gamma(0.5, 0.5) - weights[i] = unscaled_weights[i] * local_scales[i] * global_scale + for i in range(num_features): + unscaled_weights[i] ~ Normal(loc=0, scale=1) + local_scales[i] ~ Gamma(0.5, 0.5) + weights[i] = unscaled_weights[i] * local_scales[i] * global_scale - for j in range(num_datapoints): - label[j] ~ Bernoulli(features @ weights) + for j in range(num_datapoints): + label[j] ~ Bernoulli(features @ weights) - We use a log transform for the scale parameters. + We use a log transform for the scale parameters. 
""" def __init__(self): - self.ndims = 51 #global scale + 25 local scales + 25 weights - self.name = 'GermanCredit' + self.ndims = 51 # global scale + 25 local scales + 25 weights + self.name = "GermanCredit" - self.labels = jnp.load(dirr + '/data/gc_labels.npy') - self.features = jnp.load(dirr + '/data/gc_features.npy') + self.labels = jnp.load(dirr + "/data/gc_labels.npy") + self.features = jnp.load(dirr + "/data/gc_features.npy") - truth = jnp.load(dirr+'/ground_truth/german_credit/ground_truth.npy') + truth = jnp.load(dirr + "/ground_truth/german_credit/ground_truth.npy") self.E_x2, self.Var_x2 = truth[0], truth[1] - - - def transform(self, x): return jnp.concatenate((jnp.exp(x[:26]), x[26:])) def logdensity_fn(self, x): - scales = jnp.exp(x[:26]) # prior @@ -594,139 +656,166 @@ def logdensity_fn(self, x): # likelihood weights = scales[0] * scales[1:26] * x[26:] - logits = self.features @ weights # = jnp.einsum('nd,...d->...n', self.features, weights) - lik = jnp.sum(self.labels * jnp.logaddexp(0., -logits) + (1-self.labels)* jnp.logaddexp(0., logits)) + logits = ( + self.features @ weights + ) # = jnp.einsum('nd,...d->...n', self.features, weights) + lik = jnp.sum( + self.labels * jnp.logaddexp(0.0, -logits) + + (1 - self.labels) * jnp.logaddexp(0.0, logits) + ) return -(lik + pr + transform) def sample_init(self, key): - weights = jax.random.normal(key, shape = (25, )) + weights = jax.random.normal(key, shape=(25,)) return jnp.concatenate((jnp.zeros(26), weights)) - - class ItemResponseTheory: - """ Taken from inference gym.""" + """Taken from inference gym.""" def __init__(self): self.ndims = 501 - self.name = 'ItemResponseTheory' + self.name = "ItemResponseTheory" self.students = 400 self.questions = 100 - self.mask = jnp.load(dirr + '/data/irt_mask.npy') - self.labels = jnp.load(dirr + '/data/irt_labels.npy') + self.mask = jnp.load(dirr + "/data/irt_mask.npy") + self.labels = jnp.load(dirr + "/data/irt_labels.npy") - truth = jnp.load(dirr+'/ground_truth/item_response_theory/ground_truth.npy') + truth = jnp.load(dirr + "/ground_truth/item_response_theory/ground_truth.npy") self.E_x2, self.Var_x2 = truth[0], truth[1] - self.transform = lambda x: x def logdensity_fn(self, x): - - students = x[:self.students] + students = x[: self.students] mean = x[self.students] - questions = x[self.students + 1:] + questions = x[self.students + 1 :] # prior - pr = 0.5 * (jnp.square(mean - 0.75) + jnp.sum(jnp.square(students)) + jnp.sum(jnp.square(questions))) + pr = 0.5 * ( + jnp.square(mean - 0.75) + + jnp.sum(jnp.square(students)) + + jnp.sum(jnp.square(questions)) + ) # likelihood logits = mean + students[:, jnp.newaxis] - questions[jnp.newaxis, :] - bern = self.labels * jnp.logaddexp(0., -logits) + (1 - self.labels) * jnp.logaddexp(0., logits) + bern = self.labels * jnp.logaddexp(0.0, -logits) + ( + 1 - self.labels + ) * jnp.logaddexp(0.0, logits) bern = jnp.where(self.mask, bern, jnp.zeros_like(bern)) lik = jnp.sum(bern) return -lik - pr - def sample_init(self, key): - x = jax.random.normal(key, shape = (self.ndims,)) + x = jax.random.normal(key, shape=(self.ndims,)) x = x.at[self.students].add(0.75) return x - - -class StochasticVolatility(): +class StochasticVolatility: """Example from https://num.pyro.ai/en/latest/examples/stochastic_volatility.html""" def __init__(self): - self.SP500_returns = jnp.load(dirr + '/data/SP500.npy') + self.SP500_returns = jnp.load(dirr + "/data/SP500.npy") self.ndims = 2429 - self.name = 'StochasticVolatility' + self.name = "StochasticVolatility" - 
self.typical_sigma, self.typical_nu = 0.02, 10.0 # := 1 / lambda + self.typical_sigma, self.typical_nu = 0.02, 10.0 # := 1 / lambda - data = jnp.load(dirr + '/ground_truth/stochastic_volatility/ground_truth_0.npy') + data = jnp.load(dirr + "/ground_truth/stochastic_volatility/ground_truth_0.npy") self.E_x2 = data[0] self.Var_x2 = data[1] - - def logdensity_fn(self, x): """- log p of the target distribution - x= [s1, s2, ... s2427, log sigma / typical_sigma, log nu / typical_nu]""" + x= [s1, s2, ... s2427, log sigma / typical_sigma, log nu / typical_nu]""" - sigma = jnp.exp(x[-2]) * self.typical_sigma #we used this transformation to make x unconstrained + sigma = ( + jnp.exp(x[-2]) * self.typical_sigma + ) # we used this transformation to make x unconstrained nu = jnp.exp(x[-1]) * self.typical_nu - l1= (jnp.exp(x[-2]) - x[-2]) + (jnp.exp(x[-1]) - x[-1]) - l2 = (self.ndims - 2) * jnp.log(sigma) + 0.5 * (jnp.square(x[0]) + jnp.sum(jnp.square(x[1:-2] - x[:-3]))) / jnp.square(sigma) + l1 = (jnp.exp(x[-2]) - x[-2]) + (jnp.exp(x[-1]) - x[-1]) + l2 = (self.ndims - 2) * jnp.log(sigma) + 0.5 * ( + jnp.square(x[0]) + jnp.sum(jnp.square(x[1:-2] - x[:-3])) + ) / jnp.square(sigma) l3 = jnp.sum(nlogp_StudentT(self.SP500_returns, nu, jnp.exp(x[:-2]))) return -(l1 + l2 + l3) - def transform(self, x): """transforms to the variables which are used by numpyro (and in which we have the ground truth moments)""" z = jnp.empty(x.shape) - z = z.at[:-2].set(x[:-2]) # = s = log R - z = z.at[-2].set(jnp.exp(x[-2]) * self.typical_sigma) # = sigma - z = z.at[-1].set(jnp.exp(x[-1]) * self.typical_nu) # = nu + z = z.at[:-2].set(x[:-2]) # = s = log R + z = z.at[-2].set(jnp.exp(x[-2]) * self.typical_sigma) # = sigma + z = z.at[-1].set(jnp.exp(x[-1]) * self.typical_nu) # = nu return z - def sample_init(self, key): """draws x from the prior""" key_walk, key_exp = jax.random.split(key) scales = jnp.array([self.typical_sigma, self.typical_nu]) - #params = jax.random.exponential(key_exp, shape = (2, )) * scales - params= scales + # params = jax.random.exponential(key_exp, shape = (2, )) * scales + params = scales walk = random_walk(key_walk, self.ndims - 2) * params[0] - return jnp.concatenate((walk, jnp.log(params/scales))) - + return jnp.concatenate((walk, jnp.log(params / scales))) -class MixedLogit(): +class MixedLogit: def __init__(self): - key = jax.random.PRNGKey(0) key_poisson, key_x, key_beta, key_logit = jax.random.split(key, 4) self.ndims = 2014 self.name = "Mixed Logit" self.nind = 500 - self.nsessions = jax.random.poisson(key_poisson, lam=1.0, shape=(self.nind,)) + 10 + self.nsessions = ( + jax.random.poisson(key_poisson, lam=1.0, shape=(self.nind,)) + 10 + ) self.nbeta = 4 nobs = jnp.sum(self.nsessions) mu_true = jnp.array([-1.5, -0.3, 0.8, 1.2]) - sigma_true = jnp.array([[0.5, 0.1, 0.1, 0.1], [0.1, 0.5, 0.1, 0.1], [0.1, 0.1, 0.5, 0.1], [0.1, 0.1, 0.1, 0.5]]) - beta_true = jax.random.multivariate_normal(key_beta, mu_true, sigma_true, shape=(self.nind,)) + sigma_true = jnp.array( + [ + [0.5, 0.1, 0.1, 0.1], + [0.1, 0.5, 0.1, 0.1], + [0.1, 0.1, 0.5, 0.1], + [0.1, 0.1, 0.1, 0.5], + ] + ) + beta_true = jax.random.multivariate_normal( + key_beta, mu_true, sigma_true, shape=(self.nind,) + ) beta_true_repeat = jnp.repeat(beta_true, self.nsessions, axis=0) self.x = jax.random.normal(key_x, (nobs, self.nbeta)) - self.y = 1 * jax.random.bernoulli(key_logit, (jax.nn.sigmoid(jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_true_repeat)))) - - self.d = self.nbeta + self.nbeta + (self.nbeta * (self.nbeta-1) // 2) + 
self.nbeta * self.nind # mu, tau, omega_chol, and (beta for each i) + self.y = 1 * jax.random.bernoulli( + key_logit, + ( + jax.nn.sigmoid( + jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))( + self.x, beta_true_repeat + ) + ) + ), + ) + + self.d = ( + self.nbeta + + self.nbeta + + (self.nbeta * (self.nbeta - 1) // 2) + + self.nbeta * self.nind + ) # mu, tau, omega_chol, and (beta for each i) self.prior_mean_mu = jnp.zeros(self.nbeta) self.prior_var_mu = 10.0 * jnp.eye(self.nbeta) self.prior_scale_tau = 5.0 @@ -734,20 +823,20 @@ def __init__(self): self.grad_logp = jax.value_and_grad(self.logdensity_fn) - def corrchol_to_reals(self,x): - '''Converts a Cholesky-correlation (lower-triangular) matrix to a vector of unconstrained reals''' + def corrchol_to_reals(self, x): + """Converts a Cholesky-correlation (lower-triangular) matrix to a vector of unconstrained reals""" dim = x.shape[0] z = jnp.zeros((dim, dim)) for i in range(dim): for j in range(i): - z = z.at[i, j].set(x[i,j] / jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) + z = z.at[i, j].set(x[i, j] / jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) z_lower_triang = z[jnp.tril_indices(dim, -1)] y = 0.5 * (jnp.log(1.0 + z_lower_triang) - jnp.log(1.0 - z_lower_triang)) return y - def reals_to_corrchol(self,y): - '''Converts a vector of unconstrained reals to a Cholesky-correlation (lower-triangular) matrix''' + def reals_to_corrchol(self, y): + """Converts a vector of unconstrained reals to a Cholesky-correlation (lower-triangular) matrix""" len_vec = len(y) dim = int(0.5 * (1 + 8 * len_vec) ** 0.5 + 0.5) assert dim * (dim - 1) // 2 == len_vec @@ -757,20 +846,21 @@ def reals_to_corrchol(self,y): x = jnp.zeros((dim, dim)) for i in range(dim): - for j in range(i+1): + for j in range(i + 1): if i == j: x = x.at[i, j].set(jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) else: - x = x.at[i, j].set(z[i,j] * jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) + x = x.at[i, j].set( + z[i, j] * jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0)) + ) return x - def logdensity_fn(self, pars): """log p of the target distribution, i.e., log posterior distribution up to a constant""" - mu = pars[:self.nbeta] + mu = pars[: self.nbeta] dim1 = self.nbeta + self.nbeta - log_tau = pars[self.nbeta:dim1] + log_tau = pars[self.nbeta : dim1] dim2 = self.nbeta + self.nbeta + self.nbeta * (self.nbeta - 1) // 2 omega_chol_realvec = pars[dim1:dim2] beta = pars[dim2:].reshape(self.nind, self.nbeta) @@ -783,26 +873,55 @@ def logdensity_fn(self, pars): beta_repeat = jnp.repeat(beta, self.nsessions, axis=0) - log_lik = jnp.sum(self.y * jax.nn.log_sigmoid(jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat)) + (1 - self.y) * jax.nn.log_sigmoid(-jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat))) - - log_density_beta_popdist = -0.5 * self.nind * jnp.log(jnp.linalg.det(sigma)) - 0.5 * jnp.sum(jax.vmap(lambda vec, mat: jnp.dot(vec, jnp.linalg.solve(mat, vec)), in_axes=(0, None))(beta - mu, sigma)) + log_lik = jnp.sum( + self.y + * jax.nn.log_sigmoid( + jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat) + ) + + (1 - self.y) + * jax.nn.log_sigmoid( + -jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat) + ) + ) + + log_density_beta_popdist = -0.5 * self.nind * jnp.log( + jnp.linalg.det(sigma) + ) - 0.5 * jnp.sum( + jax.vmap( + lambda vec, mat: jnp.dot(vec, jnp.linalg.solve(mat, vec)), + in_axes=(0, None), + )(beta - mu, sigma) + ) muMinusPriorMean = mu - self.prior_mean_mu - log_prior_mu = -0.5 * jnp.log(jnp.linalg.det(self.prior_var_mu)) - 
0.5 * jnp.dot(muMinusPriorMean, jnp.linalg.solve(self.prior_var_mu, muMinusPriorMean)) - - log_prior_tau = jnp.sum(dist.HalfCauchy(scale=self.prior_scale_tau).log_prob(tau)) - #log_prior_tau = jnp.sum(jax.vmap(lambda arg: -jnp.log(1.0 + (arg / self.prior_scale_tau) ** 2.0))(tau)) - log_prior_omega_chol = dist.LKJCholesky(self.nbeta, concentration=self.prior_concentration_omega).log_prob(omega_chol) - #log_prior_omega_chol = jnp.dot(nbeta - jnp.arange(2, nbeta+1) + 2.0 * self.prior_concentration_omega - 2.0, jnp.log(jnp.diag(omega_chol)[1:])) - - return log_lik + log_density_beta_popdist + log_prior_mu + log_prior_tau + log_prior_omega_chol - + log_prior_mu = -0.5 * jnp.log( + jnp.linalg.det(self.prior_var_mu) + ) - 0.5 * jnp.dot( + muMinusPriorMean, jnp.linalg.solve(self.prior_var_mu, muMinusPriorMean) + ) + + log_prior_tau = jnp.sum( + dist.HalfCauchy(scale=self.prior_scale_tau).log_prob(tau) + ) + # log_prior_tau = jnp.sum(jax.vmap(lambda arg: -jnp.log(1.0 + (arg / self.prior_scale_tau) ** 2.0))(tau)) + log_prior_omega_chol = dist.LKJCholesky( + self.nbeta, concentration=self.prior_concentration_omega + ).log_prob(omega_chol) + # log_prior_omega_chol = jnp.dot(nbeta - jnp.arange(2, nbeta+1) + 2.0 * self.prior_concentration_omega - 2.0, jnp.log(jnp.diag(omega_chol)[1:])) + + return ( + log_lik + + log_density_beta_popdist + + log_prior_mu + + log_prior_tau + + log_prior_omega_chol + ) def transform(self, pars): """transform pars to the original (possibly constrained) pars""" - mu = pars[:self.nbeta] + mu = pars[: self.nbeta] dim1 = self.nbeta + self.nbeta - log_tau = pars[self.nbeta:dim1] + log_tau = pars[self.nbeta : dim1] dim2 = self.nbeta + self.nbeta + self.nbeta * (self.nbeta - 1) // 2 omega_chol_realvec = pars[dim1:dim2] beta_flattened = pars[dim2:] @@ -819,8 +938,12 @@ def sample_init(self, key): """draws pars from the prior""" key_mu, key_omega_chol, key_tau, key_beta = jax.random.split(key, 4) - mu = jax.random.multivariate_normal(key_mu, self.prior_mean_mu, self.prior_var_mu) - omega_chol = dist.LKJCholesky(self.nbeta, concentration=self.prior_concentration_omega).sample(key_omega_chol) + mu = jax.random.multivariate_normal( + key_mu, self.prior_mean_mu, self.prior_var_mu + ) + omega_chol = dist.LKJCholesky( + self.nbeta, concentration=self.prior_concentration_omega + ).sample(key_omega_chol) tau = dist.HalfCauchy(scale=self.prior_scale_tau).sample(key_tau, (self.nbeta,)) omega_chol_realvec = self.corrchol_to_reals(omega_chol) @@ -836,7 +959,6 @@ def sample_init(self, key): return pars - def nlogp_StudentT(x, df, scale): y = x / scale z = ( @@ -849,17 +971,16 @@ def nlogp_StudentT(x, df, scale): return 0.5 * (df + 1.0) * jnp.log1p(y**2.0 / df) + z - def random_walk(key, num): - """ Genereting process for the standard normal walk: - x[0] ~ N(0, 1) - x[n+1] ~ N(x[n], 1) - - Args: - key: jax random key - num: number of points in the walk - Returns: - 1 realization of the random walk (array of length num) + """Genereting process for the standard normal walk: + x[0] ~ N(0, 1) + x[n+1] ~ N(x[n], 1) + + Args: + key: jax random key + num: number of points in the walk + Returns: + 1 realization of the random walk (array of length num) """ def step(track, useless): @@ -871,22 +992,18 @@ def step(track, useless): return jax.lax.scan(step, init=(0.0, key), xs=None, length=num)[1] - models = { - - # Cauchy(100) : {'mclmc': 2000, 'mhmclmc' : 2000, 'nuts': 2000}, - # StandardNormal(100) : {'mclmc': 10000, 'mhmclmc' : 10000, 'nuts': 10000}, - # Banana() : {'mclmc': 10000, 'mhmclmc' : 10000, 
'nuts': 10000}, - Brownian() : {'mclmc': 20000, 'mhmclmc' : 80000, 'nuts': 40000}, - - - # 'banana': Banana(), + # Cauchy(100) : {'mclmc': 2000, 'mhmclmc' : 2000, 'nuts': 2000}, + # StandardNormal(100) : {'mclmc': 10000, 'mhmclmc' : 10000, 'nuts': 10000}, + # Banana() : {'mclmc': 10000, 'mhmclmc' : 10000, 'nuts': 10000}, + Brownian(): {"mclmc": 20000, "mhmclmc": 80000, "nuts": 40000}, + # 'banana': Banana(), # 'icg' : (IllConditionedGaussian(10, 2), {'mclmc': 2000, 'mhmclmc' : 2000, 'nuts': 2000}), # GermanCredit(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000}, # ItemResponseTheory(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000}, # StochasticVolatility(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000} - } +} # models = {'Brownian Motion': (Brownian(), {'mclmc': 50000, 'mhmclmc' : 40000, 'nuts': 1000}), # # 'Item Response Theory': (ItemResponseTheory(), {'mclmc': 50000, 'mhmclmc' : 50000, 'nuts': 1000}) -# } \ No newline at end of file +# } diff --git a/blackjax/benchmarks/mcmc/sampling_algorithms.py b/blackjax/benchmarks/mcmc/sampling_algorithms.py index 617797b79..b0a43edc4 100644 --- a/blackjax/benchmarks/mcmc/sampling_algorithms.py +++ b/blackjax/benchmarks/mcmc/sampling_algorithms.py @@ -4,37 +4,45 @@ import jax import jax.numpy as jnp + import blackjax + # from blackjax.adaptation.window_adaptation import da_adaptation -from blackjax.mcmc.integrators import calls_per_integrator_step, generate_euclidean_integrator, generate_isokinetic_integrator +from blackjax.mcmc.integrators import ( + calls_per_integrator_step, + generate_euclidean_integrator, + generate_isokinetic_integrator, +) + # from blackjax.mcmc.adjusted_mclmc import rescale from blackjax.util import run_inference_algorithm -import blackjax __all__ = ["samplers"] - - -def run_nuts( - coefficients, logdensity_fn, num_steps, initial_position, transform, key): - +def run_nuts(coefficients, logdensity_fn, num_steps, initial_position, transform, key): integrator = generate_euclidean_integrator(coefficients) # integrator = blackjax.mcmc.integrators.velocity_verlet # note: defaulted to in nuts rng_key, warmup_key = jax.random.split(key, 2) state, params = da_adaptation( - rng_key=warmup_key, - initial_position=initial_position, + rng_key=warmup_key, + initial_position=initial_position, algorithm=blackjax.nuts, - logdensity_fn=logdensity_fn) - + logdensity_fn=logdensity_fn, + ) + # print(params["inverse_mass_matrix"], "inv\n\n") # warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn, integrator=integrator) # (state, params), _ = warmup.run(warmup_key, initial_position, 2000) - nuts = blackjax.nuts(logdensity_fn=logdensity_fn, step_size=params['step_size'], inverse_mass_matrix= params['inverse_mass_matrix'], integrator=integrator) + nuts = blackjax.nuts( + logdensity_fn=logdensity_fn, + step_size=params["step_size"], + inverse_mass_matrix=params["inverse_mass_matrix"], + integrator=integrator, + ) final_state, state_history, info_history = run_inference_algorithm( rng_key=rng_key, @@ -42,26 +50,32 @@ def run_nuts( inference_algorithm=nuts, num_steps=num_steps, transform=lambda x: transform(x.position), - progress_bar=True + progress_bar=True, ) # print("INFO\n\n",info_history.num_integration_steps) - return state_history, params, info_history.num_integration_steps.mean() * calls_per_integrator_step(coefficients), info_history.acceptance_rate.mean(), None, None + return ( + state_history, + params, + info_history.num_integration_steps.mean() + * calls_per_integrator_step(coefficients), + 
info_history.acceptance_rate.mean(), + None, + None, + ) -def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key): +def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key): integrator = generate_isokinetic_integrator(coefficients) init_key, tune_key, run_key = jax.random.split(key, 3) - initial_state = blackjax.mcmc.mclmc.init( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - - kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, std_mat=std_mat, @@ -86,8 +100,7 @@ def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transfor L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=blackjax_mclmc_sampler_params.std_mat, - integrator = integrator, - + integrator=integrator, # std_mat=jnp.ones((initial_position.shape[0],)), ) @@ -100,38 +113,60 @@ def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transfor progress_bar=True, ) - acceptance_rate = 1. - return samples, blackjax_mclmc_sampler_params, calls_per_integrator_step(coefficients), acceptance_rate, None, None + acceptance_rate = 1.0 + return ( + samples, + blackjax_mclmc_sampler_params, + calls_per_integrator_step(coefficients), + acceptance_rate, + None, + None, + ) -def run_mhmclmc(coefficients, logdensity_fn, num_steps, initial_position, transform, key, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, target_acc_rate=None): +def run_mhmclmc( + coefficients, + logdensity_fn, + num_steps, + initial_position, + transform, + key, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.0, + target_acc_rate=None, +): integrator = generate_isokinetic_integrator(coefficients) init_key, tune_key, run_key = jax.random.split(key, 3) initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key + position=initial_position, + logdensity_fn=logdensity_fn, + random_generator_arg=init_key, ) kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=integrator, - integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=logdensity_fn) - + integrator=integrator, + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + std_mat=std_mat, + )( + rng_key=rng_key, state=state, step_size=step_size, logdensity_fn=logdensity_fn + ) + if target_acc_rate is None: - target_acc_rate = target_acceptance_rate_of_order[integrator_order(coefficients)] + target_acc_rate = target_acceptance_rate_of_order[ + integrator_order(coefficients) + ] print("target acc rate") ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, params_history, - final_da + final_da, ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -144,43 +179,46 @@ def run_mhmclmc(coefficients, logdensity_fn, num_steps, initial_position, transf diagonal_preconditioning=False, ) - - step_size = blackjax_mclmc_sampler_params.step_size L = blackjax_mclmc_sampler_params.L # jax.debug.print("params {x}", x=(blackjax_mclmc_sampler_params.step_size, blackjax_mclmc_sampler_params.L)) - alg = blackjax.mcmc.mhmclmc.mhmclmc( 
logdensity_fn=logdensity_fn, step_size=step_size, - integration_steps_fn = lambda key: jnp.ceil(jax.random.uniform(key) * rescale(L/step_size)) , + integration_steps_fn=lambda key: jnp.ceil( + jax.random.uniform(key) * rescale(L / step_size) + ), integrator=integrator, std_mat=blackjax_mclmc_sampler_params.std_mat, - - ) - _, out, info = run_inference_algorithm( rng_key=run_key, initial_state=blackjax_state_after_tuning, inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True) - + num_steps=num_steps, + transform=lambda x: transform(x.position), + progress_bar=True, + ) + return ( + out, + blackjax_mclmc_sampler_params, + calls_per_integrator_step(coefficients) * (L / step_size), + info.acceptance_rate, + params_history, + final_da, + ) - return out, blackjax_mclmc_sampler_params, calls_per_integrator_step(coefficients) * (L/step_size), info.acceptance_rate, params_history, final_da # we should do at least: mclmc, nuts, unadjusted hmc, mhmclmc, langevin samplers = { - 'nuts' : run_nuts, - 'mclmc' : run_mclmc, - 'mhmclmc': run_mhmclmc, - } + "nuts": run_nuts, + "mclmc": run_mclmc, + "mhmclmc": run_mhmclmc, +} # foo = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(20.56)) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 6f402dd67..aefc3d4b9 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -39,7 +39,7 @@ "implicit_midpoint", "calls_per_integrator_step", "name_integrator", - "integrator_order" + "integrator_order", ] @@ -481,14 +481,20 @@ def name_integrator(c): else: raise Exception("No such integrator exists in blackjax") + def integrator_order(c): - if c==velocity_verlet_coefficients: return 2 - if c==mclachlan_coefficients: return 2 - if c==yoshida_coefficients: return 4 - if c==omelyan_coefficients: return 4 - + if c == velocity_verlet_coefficients: + return 2 + if c == mclachlan_coefficients: + return 2 + if c == yoshida_coefficients: + return 4 + if c == omelyan_coefficients: + return 4 + + else: + raise Exception("No such integrator exists in blackjax") - else: raise Exception("No such integrator exists in blackjax") FixedPointSolver = Callable[ [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree], From 63a8042070f1d8b06e1559960c0786e0f795346c Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 May 2024 19:50:41 +0200 Subject: [PATCH 33/71] REMOVE BENCHMARKS --- blackjax/benchmarks/mcmc/benchmark.py | 954 ---------------- .../ground_truth/brownian/ground_truth.npy | Bin 384 -> 0 bytes blackjax/benchmarks/mcmc/inference_models.py | 1009 ----------------- .../benchmarks/mcmc/sampling_algorithms.py | 226 ---- 4 files changed, 2189 deletions(-) delete mode 100644 blackjax/benchmarks/mcmc/benchmark.py delete mode 100644 blackjax/benchmarks/mcmc/ground_truth/brownian/ground_truth.npy delete mode 100644 blackjax/benchmarks/mcmc/inference_models.py delete mode 100644 blackjax/benchmarks/mcmc/sampling_algorithms.py diff --git a/blackjax/benchmarks/mcmc/benchmark.py b/blackjax/benchmarks/mcmc/benchmark.py deleted file mode 100644 index 0ab3b4c5e..000000000 --- a/blackjax/benchmarks/mcmc/benchmark.py +++ /dev/null @@ -1,954 +0,0 @@ -# mypy: ignore-errors -# flake8: noqa - -import math -import operator -import os -import pprint -from collections import defaultdict -from functools import partial -from statistics import mean, median - -import jax -import jax.numpy as jnp -import pandas as pd -import scipy - -from blackjax.adaptation.mclmc_adaptation import 
MCLMCAdaptationState - -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=" + str(128) -num_cores = jax.local_device_count() -# print(num_cores, jax.lib.xla_bridge.get_backend().platform) - -import itertools - -import numpy as np - -import blackjax -from blackjax.benchmarks.mcmc.inference_models import ( - Brownian, - GermanCredit, - ItemResponseTheory, - MixedLogit, - StandardNormal, - StochasticVolatility, - models, -) -from blackjax.benchmarks.mcmc.sampling_algorithms import ( - run_mclmc, - run_mhmclmc, - run_nuts, - samplers, -) -from blackjax.mcmc.integrators import ( - calls_per_integrator_step, - generate_euclidean_integrator, - generate_isokinetic_integrator, - integrator_order, - isokinetic_mclachlan, - mclachlan_coefficients, - name_integrator, - omelyan_coefficients, - velocity_verlet, - velocity_verlet_coefficients, - yoshida_coefficients, -) - -# from blackjax.mcmc.mhmclmc import rescale -from blackjax.util import run_inference_algorithm - -target_acceptance_rate_of_order = {2: 0.65, 4: 0.8} - - -def get_num_latents(target): - return target.ndims - - -# return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0])))) - - -def err(f_true, var_f, contract): - """Computes the error b^2 = (f - f_true)^2 / var_f - Args: - f: E_sampler[f(x)], can be a vector - f_true: E_true[f(x)] - var_f: Var_true[f(x)] - contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max - - Returns: - contract(b^2) - """ - - return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) - - -def grads_to_low_error(err_t, grad_evals_per_step=1, low_error=0.01): - """Uses the error of the expectation values to compute the effective sample size neff - b^2 = 1/neff""" - - cutoff_reached = err_t[-1] < low_error - return find_crossing(err_t, low_error) * grad_evals_per_step, cutoff_reached - - -def calculate_ess(err_t, grad_evals_per_step, neff=100): - grads_to_low, cutoff_reached = grads_to_low_error( - err_t, grad_evals_per_step, 1.0 / neff - ) - - return ( - (neff / grads_to_low) * cutoff_reached, - grads_to_low * (1 / cutoff_reached), - cutoff_reached, - ) - - -def find_crossing(array, cutoff): - """the smallest M such that array[m] < cutoff for all m > M""" - - b = array > cutoff - indices = jnp.argwhere(b) - if indices.shape[0] == 0: - print("\n\n\nNO CROSSING FOUND!!!\n\n\n", array, cutoff) - return 1 - - return jnp.max(indices) + 1 - - -def cumulative_avg(samples): - return jnp.cumsum(samples, axis=0) / jnp.arange(1, samples.shape[0] + 1)[:, None] - - -def gridsearch_tune( - key, - iterations, - grid_size, - model, - sampler, - batch, - num_steps, - center_L, - center_step_size, - contract, -): - results = defaultdict(float) - converged = False - keys = jax.random.split(key, iterations + 1) - for i in range(iterations): - print(f"EPOCH {i}") - width = 2 - step_sizes = np.logspace( - np.log10(center_step_size / width), - np.log10(center_step_size * width), - grid_size, - ) - Ls = np.logspace(np.log10(center_L / 2), np.log10(center_L * 2), grid_size) - # print(list(itertools.product(step_sizes , Ls))) - - grid_keys = jax.random.split(keys[i], grid_size ^ 2) - print(f"center step size {center_step_size}, center L {center_L}") - for j, (step_size, L) in enumerate(itertools.product(step_sizes, Ls)): - ess, grad_calls_until_convergence, _, _, _ = benchmark_chains( - model, - sampler(step_size=step_size, L=L), - grid_keys[j], - n=num_steps, - batch=batch, - contract=contract, - ) - results[(step_size, L)] = (ess, 
grad_calls_until_convergence) - - best_ess, best_grads, (step_size, L) = max( - ((results[r][0], results[r][1], r) for r in results), - key=operator.itemgetter(0), - ) - # raise Exception - print( - f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}" - ) - if L == center_L and step_size == center_step_size: - print("converged") - converged = True - break - else: - center_L, center_step_size = L, step_size - - pprint.pp(results) - # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") - return center_L, center_step_size, converged - - -def run_mhmclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): - def s(logdensity_fn, num_steps, initial_position, transform, key): - integrator = generate_isokinetic_integrator(coefficients) - - num_steps_per_traj = L / step_size - alg = blackjax.mcmc.mhmclmc.mhmclmc( - logdensity_fn=logdensity_fn, - step_size=step_size, - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(num_steps_per_traj) - ), - integrator=integrator, - std_mat=std_mat, - ) - - _, out, info = run_inference_algorithm( - rng_key=key, - initial_state=initial_state, - inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True, - ) - - return ( - out, - MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), - num_steps_per_traj * calls_per_integrator_step(coefficients), - info.acceptance_rate.mean(), - None, - jnp.array([0]), - ) - - return s - - -def benchmark_chains( - model, - sampler, - key, - n=10000, - batch=None, - contract=jnp.average, -): - pvmap = jax.pmap - - d = get_num_latents(model) - if batch is None: - batch = np.ceil(1000 / d).astype(int) - key, init_key = jax.random.split(key, 2) - keys = jax.random.split(key, batch) - - init_keys = jax.random.split(init_key, batch) - init_pos = pvmap(model.sample_init)(init_keys) - - # samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys) - ( - samples, - params, - grad_calls_per_traj, - acceptance_rate, - step_size_over_da, - final_da, - ) = pvmap( - lambda pos, key: sampler( - logdensity_fn=model.logdensity_fn, - num_steps=n, - initial_position=pos, - transform=model.transform, - key=key, - ) - )( - init_pos, keys - ) - avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0) - try: - print(jnp.nanmean(params.step_size, axis=0), jnp.nanmean(params.L, axis=0)) - except: - pass - - full = lambda arr: err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) - err_t = pvmap(full)(samples**2) - - # outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t] - # # print(outs[:10]) - # esses = [i[0].item() for i in outs if not math.isnan(i[0].item())] - # grad_calls = [i[1].item() for i in outs if not math.isnan(i[1].item())] - # return(mean(esses), mean(grad_calls)) - # print(final_da.mean(), "final da") - - err_t_median = jnp.median(err_t, axis=0) - # import matplotlib.pyplot as plt - # plt.plot(np.arange(1, 1+ len(err_t_median))* 2, err_t_median, color= 'teal', lw = 3) - # plt.xlabel('gradient evaluations') - # plt.ylabel('average second moment error') - # plt.xscale('log') - # plt.yscale('log') - # plt.savefig('brownian.png') - # plt.close() - esses, grad_calls, _ = calculate_ess( - err_t_median, 
grad_evals_per_step=avg_grad_calls_per_traj - ) - return ( - esses, - grad_calls, - params, - jnp.mean(acceptance_rate, axis=0), - step_size_over_da, - ) - - -def run_benchmarks(batch_size): - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], - [Brownian()], - # [Brownian()], - # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims - - num_steps = 100000 - - sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims - - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - - contract = jnp.max - - key = jax.random.PRNGKey(11) - for i in range(1): - key1, key = jax.random.split(key) - ( - ess, - grad_calls, - params, - acceptance_rate, - step_size_over_da, - ) = benchmark_chains( - model, - partial( - samplers[sampler], - coefficients=coefficients, - frac_tune1=0.1, - frac_tune2=0.0, - frac_tune3=0.0, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - - # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") - jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) - jax.numpy.save(f"acceptance.npy", acceptance_rate) - - # print(f"grads to low bias: {grad_calls}") - # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - - results[ - ( - (model.name, model.ndims), - sampler, - name_integrator(coefficients), - "standard", - acceptance_rate.mean().item(), - params.L.mean().item(), - params.step_size.mean().item(), - num_chains, - num_steps, - contract, - ) - ] = ess.item() - print(ess.item()) - # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - - # print(results) - - df = pd.Series(results).reset_index() - df.columns = [ - "model", - "sampler", - "integrator", - "tuning", - "acc rate", - "L", - "stepsize", - "num_chains", - "num steps", - "contraction", - "ESS", - ] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("results_simple.csv", index=False) - - return results - - -def run_simple(): - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mclmc"], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], - [Brownian()], - # [Brownian()], - # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - sampler, model, coefficients = variables - num_chains = 128 - - num_steps = 10000 - - contract = jnp.max - - key = jax.random.PRNGKey(11) - for i in range(1): - key1, key = jax.random.split(key) - ( - ess, - grad_calls, - params, - acceptance_rate, - step_size_over_da, - ) = benchmark_chains( - model, - partial(samplers[sampler], coefficients=coefficients), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - - results[ - ( - (model.name, model.ndims), - sampler, - name_integrator(coefficients), - "standard", - acceptance_rate.mean().item(), - params.L.mean().item(), - params.step_size.mean().item(), - num_chains, - num_steps, 
- contract, - ) - ] = ess.item() - print(ess.item()) - - return results - - -# vary step_size -def run_benchmarks_step_size(batch_size): - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], - [StandardNormal(10)], - # [Brownian()], - # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - num_steps = 10000 - - sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims - - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - - contract = jnp.average - - center = 6.534974 - key = jax.random.PRNGKey(11) - for step_size in np.linspace(center - 1, center + 1, 41): - # for L in np.linspace(1, 10, 41): - key1, key2, key3, key = jax.random.split(key, 4) - initial_position = model.sample_init(key2) - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=model.logdensity_fn, - random_generator_arg=key3, - ) - ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( - model, - run_mhmclmc_no_tuning( - initial_state=initial_state, - coefficients=mclachlan_coefficients, - step_size=step_size, - L=5 * step_size, - std_mat=1.0, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - - # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") - # jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) - # jax.numpy.save(f"acceptance.npy_{step_size}", acceptance_rate) - - # print(f"grads to low bias: {grad_calls}") - # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - - results[ - ( - (model.name, model.ndims), - sampler, - name_integrator(coefficients), - "standard", - acceptance_rate.mean().item(), - params.L.mean().item(), - params.step_size.mean().item(), - num_chains, - num_steps, - contract, - ) - ] = ess.item() - # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - - # print(results) - - df = pd.Series(results).reset_index() - df.columns = [ - "model", - "sampler", - "integrator", - "tuning", - "acc rate", - "L", - "stepsize", - "num_chains", - "num steps", - "contraction", - "ESS", - ] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("results_step_size.csv", index=False) - - return results - - -def benchmark_mhmchmc(batch_size): - key0, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(5), 4) - results = defaultdict(tuple) - - # coefficients = [yoshida_coefficients, mclachlan_coefficients, velocity_verlet_coefficients, omelyan_coefficients] - coefficients = [mclachlan_coefficients, velocity_verlet_coefficients] - for model, coeffs in itertools.product(models, coefficients): - num_chains = batch_size # 1 + batch_size//model.ndims - print(f"NUMBER OF CHAINS for {model.name} and MHMCLMC is {num_chains}") - num_steps = models[model]["mhmclmc"] - print(f"NUMBER OF STEPS for {model.name} and MHCMLMC is {num_steps}") - - ####### run mclmc with standard tuning - - contract = jnp.max - - ess, grad_calls, params, _, step_size_over_da = benchmark_chains( - model, - partial(run_mclmc, coefficients=coeffs), - key0, - n=num_steps, - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - 
model.ndims, - "mclmc", - params.L.mean().item(), - params.step_size.mean().item(), - name_integrator(coeffs), - "standard", - 1.0, - ) - ] = ess.item() - print(f"mclmc with tuning ESS {ess}") - - ####### run mhmclmc with standard tuning - for target_acc_rate in [0.65, 0.9]: - # coeffs = mclachlan_coefficients - ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( - model, - partial( - run_mhmclmc, - target_acc_rate=target_acc_rate, - coefficients=coeffs, - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.0, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "mhmchmc" + str(target_acc_rate), - jnp.nanmean(params.L).item(), - jnp.nanmean(params.step_size).item(), - name_integrator(coeffs), - "standard", - acceptance_rate.mean().item(), - ) - ] = ess.item() - print(f"mhmclmc with tuning ESS {ess}") - - # coeffs = mclachlan_coefficients - ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( - model, - partial( - run_mhmclmc, - target_acc_rate=target_acc_rate, - coefficients=coeffs, - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.1, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "mhmchmc:st3" + str(target_acc_rate), - jnp.nanmean(params.L).item(), - jnp.nanmean(params.step_size).item(), - name_integrator(coeffs), - "standard", - acceptance_rate.mean().item(), - ) - ] = ess.item() - print(f"mhmclmc with tuning ESS {ess}") - - if True: - ####### run mhmclmc with standard tuning + grid search - - init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split( - key2, 5 - ) - initial_position = model.sample_init(init_pos_key) - - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=model.logdensity_fn, - random_generator_arg=init_key, - ) - - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coeffs), - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(avg_num_integration_steps) - ), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn, - ) - - ( - state, - blackjax_mhmclmc_sampler_params, - _, - _, - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - state=initial_state, - rng_key=tune_key, - target=target_acceptance_rate_of_order[integrator_order(coeffs)], - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.0, - diagonal_preconditioning=False, - ) - - print( - f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}" - ) - print( - f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" - ) - - L, step_size, convergence = gridsearch_tune( - grid_key, - iterations=10, - contract=contract, - grid_size=5, - model=model, - sampler=partial( - run_mhmclmc_no_tuning, - coefficients=coeffs, - initial_state=state, - std_mat=1.0, - ), - batch=num_chains, - num_steps=num_steps, - center_L=blackjax_mhmclmc_sampler_params.L, - center_step_size=blackjax_mhmclmc_sampler_params.step_size, - ) - # print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( - model, - run_mhmclmc_no_tuning( - coefficients=coeffs, - L=L, - step_size=step_size, - initial_state=state, - std_mat=1.0, 
- ), - bench_key, - n=num_steps, - batch=num_chains, - contract=contract, - ) - - print(f"grads to low bias: {grad_calls}") - - results[ - ( - model.name, - model.ndims, - "mhmchmc:grid", - L.item(), - step_size.item(), - name_integrator(coeffs), - f"gridsearch:{convergence}", - acceptance_rate.mean().item(), - ) - ] = ess.item() - - ####### run nuts - - # coeffs = velocity_verlet_coefficients - ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( - model, - partial(run_nuts, coefficients=coeffs), - key3, - n=models[model]["nuts"], - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "nuts", - 0.0, - 0.0, - name_integrator(coeffs), - "standard", - acceptance_rate.mean().item(), - ) - ] = ess.item() - - print(results) - - df = pd.Series(results).reset_index() - df.columns = [ - "model", - "dims", - "sampler", - "L", - "step_size", - "integrator", - "tuning", - "acc_rate", - "ESS", - ] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("results.csv", index=False) - - return results - - -def benchmark_omelyan(batch_size): - key = jax.random.PRNGKey(2) - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmchmc"], - [ - StandardNormal(d) - for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int) - ], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 5)).astype(int)], - # models, - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients, omelyan_coefficients], - ): - sampler, model, coefficients = variables - - # num_chains = 1 + batch_size//model.ndims - num_chains = batch_size - - current_key, key = jax.random.split(key) - init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split( - current_key, 5 - ) - - # num_steps = models[model][sampler] - - num_steps = 1000 - - initial_position = model.sample_init(init_pos_key) - - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=model.logdensity_fn, - random_generator_arg=init_key, - ) - - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coefficients), - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(avg_num_integration_steps) - ), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn, - ) - - ( - state, - blackjax_mhmclmc_sampler_params, - _, - _, - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - state=initial_state, - rng_key=tune_key, - target=target_acceptance_rate_of_order[integrator_order(coefficients)], - frac_tune1=0.1, - frac_tune2=0.1, - # frac_tune3=0.1, - diagonal_preconditioning=False, - ) - - print( - f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}", - ) - print( - f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" - ) - - # ess, grad_calls, _ , _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) - - # 
results[((model.name, model.ndims), sampler, name_integrator(coefficients), "without grid search")] = (ess, grad_calls) - - L, step_size, converged = gridsearch_tune( - grid_key, - iterations=10, - contract=jnp.average, - grid_size=5, - model=model, - sampler=partial( - run_mhmclmc_no_tuning, - coefficients=coefficients, - initial_state=state, - std_mat=1.0, - ), - batch=num_chains, - num_steps=num_steps, - center_L=blackjax_mhmclmc_sampler_params.L, - center_step_size=blackjax_mhmclmc_sampler_params.step_size, - ) - print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _, _, _ = benchmark_chains( - model, - run_mhmclmc_no_tuning( - coefficients=coefficients, - L=L, - step_size=step_size, - std_mat=1.0, - initial_state=state, - ), - bench_key, - n=num_steps, - batch=num_chains, - contract=jnp.average, - ) - - print(f"grads to low bias: {grad_calls}") - - results[ - ( - model.name, - model.ndims, - sampler, - name_integrator(coefficients), - converged, - L.item(), - step_size.item(), - ) - ] = ess.item() - - df = pd.Series(results).reset_index() - df.columns = [ - "model", - "dims", - "sampler", - "integrator", - "convergence", - "L", - "step_size", - "ESS", - ] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("omelyan.csv", index=False) - - -def run_benchmarks_divij(): - sampler = run_mclmc - model = StandardNormal(10) # 10 dimensional standard normal - coefficients = mclachlan_coefficients - contract = jnp.average # how we average across dimensions - num_steps = 2000 - num_chains = 100 - key1 = jax.random.PRNGKey(2) - - ess, grad_calls, params, acceptance_rate, step_size_over_da = benchmark_chains( - model, - partial(sampler, coefficients=coefficients), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - - print(f"Effective Sample Size (ESS) of 10D Normal is {ess}") - - -if __name__ == "__main__": - # run_benchmarks_divij() - - # benchmark_mhmchmc(batch_size=128) - run_simple() - # run_benchmarks_step_size(128) - # benchmark_omelyan(128) - # run_benchmarks(128) - # benchmark_omelyan(10) - # print("4") diff --git a/blackjax/benchmarks/mcmc/ground_truth/brownian/ground_truth.npy b/blackjax/benchmarks/mcmc/ground_truth/brownian/ground_truth.npy deleted file mode 100644 index d381c47de0061e8dd351fc55551c143439c7381d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 384 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Its={nmP)#3giN=HsrpV z?PsMKwrgvr+xp7gw{>nVw_R9u!1h`Jv)z%e+IICTL+spj8toc0XWQwN@36aIe8cY9 z%-?oicLnVo88z%@vZ~vcx69bC->+j|`BcjOJPWTqQmVb6{~e;&sZ_vu(iHo&1bz(|CsfW zc2}FscMEJ@Gd{4{<1S=7;gOl`9=#CT(~POMrvnmfvsEo@|2c28k=oc~ldPO%W7pAP e<0&0(^RCq0CX~P0=KmsFn?;B9Y}9Xs*#H1}DT0^) diff --git a/blackjax/benchmarks/mcmc/inference_models.py b/blackjax/benchmarks/mcmc/inference_models.py deleted file mode 100644 index 715bd4c14..000000000 --- a/blackjax/benchmarks/mcmc/inference_models.py +++ /dev/null @@ -1,1009 +0,0 @@ -# mypy: ignore-errors -# flake8: noqa - -import os - -# from inference_gym import using_jax as gym -import jax -import jax.numpy as jnp -import numpy as np - -# import numpyro.distributions as dist -dirr = os.path.dirname(os.path.realpath(__file__)) - - -class StandardNormal: - """Standard Normal distribution in d dimensions""" - - def __init__(self, d): - self.ndims = d - self.E_x2 = jnp.ones(d) - self.Var_x2 = 2 * self.E_x2 - self.name = "StandardNormal" - - def logdensity_fn(self, x): - """- log 
p of the target distribution""" - return -0.5 * jnp.sum(jnp.square(x), axis=-1) - - def transform(self, x): - return x - - def sample_init(self, key): - return jax.random.normal(key, shape=(self.ndims,)) - - -class IllConditionedGaussian: - """Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2.""" - - def __init__(self, d, condition_number, numpy_seed=None, prior="prior"): - """numpy_seed is used to generate a random rotation for the covariance matrix. - If None, the covariance matrix is diagonal.""" - - self.ndims = d - self.name = "IllConditionedGaussian" - self.condition_number = condition_number - eigs = jnp.logspace( - -0.5 * jnp.log10(condition_number), 0.5 * jnp.log10(condition_number), d - ) - - if numpy_seed == None: # diagonal - self.E_x2 = eigs - self.R = jnp.eye(d) - self.Hessian = jnp.diag(1 / eigs) - self.Cov = jnp.diag(eigs) - - else: # randomly rotate - rng = np.random.RandomState(seed=numpy_seed) - D = jnp.diag(eigs) - inv_D = jnp.diag(1 / eigs) - R, _ = jnp.array( - np.linalg.qr(rng.randn(self.ndims, self.ndims)) - ) # random rotation - self.R = R - self.Hessian = R @ inv_D @ R.T - self.Cov = R @ D @ R.T - self.E_x2 = jnp.diagonal(R @ D @ R.T) - - # Cov_precond = jnp.diag(1 / jnp.sqrt(self.E_x2)) @ self.Cov @ jnp.diag(1 / jnp.sqrt(self.E_x2)) - - # print(jnp.linalg.cond(Cov_precond) / jnp.linalg.cond(self.Cov)) - - self.Var_x2 = 2 * jnp.square(self.E_x2) - - self.logdensity_fn = lambda x: -0.5 * x.T @ self.Hessian @ x - self.transform = lambda x: x - - if prior == "map": - self.sample_init = lambda key: jnp.zeros(self.ndims) - - elif prior == "posterior": - self.sample_init = lambda key: self.R @ ( - jax.random.normal(key, shape=(self.ndims,)) * jnp.sqrt(eigs) - ) - - else: # N(0, sigma_true_max) - self.sample_init = lambda key: jax.random.normal( - key, shape=(self.ndims,) - ) * jnp.max(jnp.sqrt(eigs)) - - -class IllConditionedESH: - """ICG from the ESH paper.""" - - def __init__(self): - self.ndims = 50 - self.name = "IllConditionedESH" - self.variance = jnp.linspace(0.01, 1, self.ndims) - - def logdensity_fn(self, x): - """- log p of the target distribution""" - return -0.5 * jnp.sum(jnp.square(x) / self.variance, axis=-1) - - def transform(self, x): - return x - - def draw(self, key): - return jax.random.normal(key, shape=(self.ndims,)) * jnp.sqrt(self.variance) - - def sample_init(self, key): - return jax.random.normal(key, shape=(self.ndims,)) - - -class IllConditionedGaussianGamma: - """Inference gym's Ill conditioned Gaussian""" - - def __init__(self, prior="prior"): - self.ndims = 100 - self.name = "IllConditionedGaussianGamma" - - # define the Hessian - rng = np.random.RandomState(seed=10 & (2**32 - 1)) - eigs = np.sort( - rng.gamma(shape=0.5, scale=1.0, size=self.ndims) - ) # eigenvalues of the Hessian - eigs *= jnp.average(1.0 / eigs) - self.entropy = 0.5 * self.ndims - self.maxmin = (1.0 / jnp.sqrt(eigs[0]), 1.0 / jnp.sqrt(eigs[-1])) - R, _ = np.linalg.qr(rng.randn(self.ndims, self.ndims)) # random rotation - self.map_to_worst = (R.T)[[0, -1], :] - self.Hessian = R @ np.diag(eigs) @ R.T - - # analytic ground truth moments - self.E_x2 = jnp.diagonal(R @ np.diag(1.0 / eigs) @ R.T) - self.Var_x2 = 2 * jnp.square(self.E_x2) - - # norm = jnp.diag(1/jnp.sqrt(self.E_x2)) - # Sigma = R @ np.diag(1/eigs) @ R.T - # reduced = norm @ Sigma @ norm - # print(np.linalg.cond(reduced), np.linalg.cond(Sigma)) - - # gradient - - if prior == "map": - self.sample_init = lambda key: 
jnp.zeros(self.ndims) - - elif prior == "posterior": - self.sample_init = lambda key: R @ ( - jax.random.normal(key, shape=(self.ndims,)) / jnp.sqrt(eigs) - ) - - else: # N(0, sigma_true_max) - self.sample_init = lambda key: jax.random.normal( - key, shape=(self.ndims,) - ) * jnp.max(1.0 / jnp.sqrt(eigs)) - - def logdensity_fn(self, x): - """- log p of the target distribution""" - return -0.5 * x.T @ self.Hessian @ x - - def transform(self, x): - return x - - -class Banana: - """Banana target fromm the Inference Gym""" - - def __init__(self, prior="map"): - self.curvature = 0.03 - self.ndims = 2 - self.name = "Banana" - - self.transform = lambda x: x - self.E_x2 = jnp.array( - [100.0, 19.0] - ) # the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. - self.Var_x2 = jnp.array([20000.0, 4600.898]) - - if prior == "map": - self.sample_init = lambda key: jnp.array([0, -100.0 * self.curvature]) - elif prior == "posterior": - self.sample_init = lambda key: self.posterior_draw(key) - elif prior == "prior": - self.sample_init = ( - lambda key: jax.random.normal(key, shape=(self.ndims,)) - * jnp.array([10.0, 5.0]) - * 2 - ) - else: - raise ValueError("prior = " + prior + " is not defined.") - - def logdensity_fn(self, x): - mu2 = self.curvature * (x[0] ** 2 - 100) - return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) - - def posterior_draw(self, key): - z = jax.random.normal(key, shape=(2,)) - x0 = 10.0 * z[0] - x1 = self.curvature * (x0**2 - 100) + z[1] - return jnp.array([x0, x1]) - - def ground_truth(self): - x = jax.vmap(self.posterior_draw)( - jax.random.split(jax.random.PRNGKey(0), 100000000) - ) - print(jnp.average(x, axis=0)) - print(jnp.average(jnp.square(x), axis=0)) - print(jnp.std(jnp.square(x[:, 0])) ** 2, jnp.std(jnp.square(x[:, 1])) ** 2) - - -class Cauchy: - """d indpendent copies of the standard Cauchy distribution""" - - def __init__(self, d): - self.ndims = d - self.name = "Cauchy" - - self.logdensity_fn = lambda x: -jnp.sum(jnp.log(1.0 + jnp.square(x))) - - self.transform = lambda x: x - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) - - -class HardConvex: - def __init__(self, d, kappa, theta=0.1): - """d is the dimension, kappa = condition number, 0 < theta < 1/4""" - self.ndims = d - self.name = "HardConvex" - self.theta, self.kappa = theta, kappa - C = jnp.power(d - 1, 0.25 - theta) - self.logdensity_fn = ( - lambda x: -0.5 * jnp.sum(jnp.square(x[:-1])) - - (0.75 / kappa) * x[-1] ** 2 - + 0.5 * jnp.sum(jnp.cos(C * x[:-1])) / C**2 - ) - - self.transform = lambda x: x - - # numerically precomputed variances - num_integration = [0.93295, 0.968802, 0.990595, 0.998002, 0.999819] - if d == 100: - self.variance = jnp.concatenate( - (jnp.ones(d - 1) * num_integration[0], jnp.ones(1) * 2.0 * kappa / 3.0) - ) - elif d == 300: - self.variance = jnp.concatenate( - (jnp.ones(d - 1) * num_integration[1], jnp.ones(1) * 2.0 * kappa / 3.0) - ) - elif d == 1000: - self.variance = jnp.concatenate( - (jnp.ones(d - 1) * num_integration[2], jnp.ones(1) * 2.0 * kappa / 3.0) - ) - elif d == 3000: - self.variance = jnp.concatenate( - (jnp.ones(d - 1) * num_integration[3], jnp.ones(1) * 2.0 * kappa / 3.0) - ) - elif d == 10000: - self.variance = jnp.concatenate( - (jnp.ones(d - 1) * num_integration[4], jnp.ones(1) * 2.0 * kappa / 3.0) - ) - else: - None - - def sample_init(self, key): - """Gaussian prior with approximately estimating the variance along each dimension""" - scale = 
jnp.concatenate( - (jnp.ones(self.ndims - 1), jnp.ones(1) * jnp.sqrt(2.0 * self.kappa / 3.0)) - ) - return jax.random.normal(key, shape=(self.ndims,)) * scale - - -class BiModal: - """A Gaussian mixture p(x) = f N(x | mu1, sigma1) + (1-f) N(x | mu2, sigma2).""" - - def __init__(self, d=50, mu1=0.0, mu2=8.0, sigma1=1.0, sigma2=1.0, f=0.2): - self.ndims = d - self.name = "BiModal" - - self.mu1 = jnp.insert(jnp.zeros(d - 1), 0, mu1) - self.mu2 = jnp.insert(jnp.zeros(d - 1), 0, mu2) - self.sigma1, self.sigma2 = sigma1, sigma2 - self.f = f - self.variance = jnp.insert( - jnp.ones(d - 1) * ((1 - f) * sigma1**2 + f * sigma2**2), - 0, - (1 - f) * (sigma1**2 + mu1**2) + f * (sigma2**2 + mu2**2), - ) - - def logdensity_fn(self, x): - """- log p of the target distribution""" - - N1 = ( - (1.0 - self.f) - * jnp.exp( - -0.5 * jnp.sum(jnp.square(x - self.mu1), axis=-1) / self.sigma1**2 - ) - / jnp.power(2 * jnp.pi * self.sigma1**2, self.ndims * 0.5) - ) - N2 = ( - self.f - * jnp.exp( - -0.5 * jnp.sum(jnp.square(x - self.mu2), axis=-1) / self.sigma2**2 - ) - / jnp.power(2 * jnp.pi * self.sigma2**2, self.ndims * 0.5) - ) - - return jnp.log(N1 + N2) - - def draw(self, num_samples): - """direct sampler from a target""" - X = np.random.normal(size=(num_samples, self.ndims)) - mask = np.random.uniform(0, 1, num_samples) < self.f - X[mask, :] = (X[mask, :] * self.sigma2) + self.mu2 - X[~mask] = (X[~mask] * self.sigma1) + self.mu1 - - return X - - def transform(self, x): - return x - - def sample_init(self, key): - z = jax.random.normal(key, shape=(self.ndims,)) * self.sigma1 - # z= z.at[0].set(self.mu1 + z[0]) - return z - - -class BiModalEqual: - """Mixture of two Gaussians, one centered at x = [mu/2, 0, 0, ...], the other at x = [-mu/2, 0, 0, ...]. - Both have equal probability mass.""" - - def __init__(self, d, mu): - self.ndims = d - self.name = "BiModalEqual" - self.mu = mu - - def logdensity_fn(self, x): - """- log p of the target distribution""" - - return ( - -0.5 * jnp.sum(jnp.square(x), axis=-1) - + jnp.log(jnp.cosh(0.5 * self.mu * x[0])) - - 0.5 * self.ndims * jnp.log(2 * jnp.pi) - - self.mu**2 / 8.0 - ) - - def draw(self, num_samples): - """direct sampler from a target""" - X = np.random.normal(size=(num_samples, self.ndims)) - mask = np.random.uniform(0, 1, num_samples) < 0.5 - X[mask, 0] += 0.5 * self.mu - X[~mask, 0] -= 0.5 * self.mu - - return X - - def transform(self, x): - return x - - -class Funnel: - """Noise-less funnel""" - - def __init__(self, d=20): - self.ndims = d - self.name = "Funnel" - self.sigma_theta = 3.0 - - self.E_x2 = jnp.ones( - d - ) # the transformed variables are standard Gaussian distributed - self.Var_x2 = 2 * self.E_x2 - - def logdensity_fn(self, x): - """- log p of the target distribution - x = [z_0, z_1, ... 
z_{d-1}, theta]""" - theta = x[-1] - X = x[..., :-1] - - return ( - -0.5 * jnp.square(theta / self.sigma_theta) - - 0.5 * (self.ndims - 1) * theta - - 0.5 * jnp.exp(-theta) * jnp.sum(jnp.square(X), axis=-1) - ) - - def inverse_transform(self, xtilde): - theta = 3 * xtilde[-1] - return jnp.concatenate( - (xtilde[:-1] * jnp.exp(0.5 * theta), jnp.ones(1) * theta) - ) - - def transform(self, x): - """gaussianization""" - xtilde = jnp.empty(x.shape) - xtilde = xtilde.at[-1].set(x.T[-1] / 3.0) - xtilde = xtilde.at[:-1].set(x.T[:-1] * jnp.exp(-0.5 * x.T[-1])) - return xtilde.T - - def sample_init(self, key): - return self.inverse_transform(jax.random.normal(key, shape=(self.ndims,))) - - -class Funnel_with_Data: - def __init__(self, d, sigma, minibatch_size, key): - self.ndims = d - self.name = "Funnel_with_Data" - self.sigma_theta = 3.0 - self.theta_true = 0.0 - self.sigma_data = sigma - - self.data = self.simulate_data() - - self.batch = minibatch_size - - def simulate_data(self): - norm = jax.random.normal(jax.random.PRNGKey(123), shape=(2 * (self.ndims - 1),)) - z_true = norm[: self.ndims - 1] * jnp.exp(self.theta_true * 0.5) - self.data = z_true + norm[self.ndims - 1 :] * self.sigma_data - - def logdensity_fn(self, x, subset): - """- log p of the target distribution - x = [z_0, z_1, ... z_{d-1}, theta]""" - theta = x[-1] - z = x[:-1][subset] - - prior_theta = jnp.square(theta / self.sigma_theta) - prior_z = jnp.sum(subset) * theta + jnp.exp(-theta) * jnp.sum( - jnp.square(z * subset) - ) - likelihood = jnp.sum(jnp.square((z - self.data) * subset / self.sigma_data)) - - return -0.5 * (prior_theta + prior_z + likelihood) - - def transform(self, x): - """gaussianization""" - return x - - def sample_init(self, key): - key1, key2 = jax.random.split(key) - theta = jax.random.normal(key1) * self.sigma_theta - z = jax.random.normal(key2, shape=(self.ndims - 1,)) * jnp.exp(theta * 0.5) - return jnp.concatenate((z, theta)) - - -class Rosenbrock: - def __init__(self, d=36, Q=0.1): - self.ndims = d - self.name = "Rosenbrock" - self.Q = Q - # ground truth moments - var_x = 2.0 - - # these two options were precomputed: - if Q == 0.1: - var_y = 10.098433122783046 # var_y is computed numerically (see class function compute_variance) - elif Q == 0.5: - var_y = 10.498957879911487 - else: - raise ValueError( - "Ground truth moments for Q = " - + str(Q) - + " were not precomputed. Use Q = 0.1 or 0.5." 
- ) - - self.variance = jnp.concatenate( - (var_x * jnp.ones(d // 2), var_y * jnp.ones(d // 2)) - ) - - def logdensity_fn(self, x): - """- log p of the target distribution""" - X, Y = x[..., : self.ndims // 2], x[..., self.ndims // 2 :] - return -0.5 * jnp.sum( - jnp.square(X - 1.0) + jnp.square(jnp.square(X) - Y) / self.Q, axis=-1 - ) - - def draw(self, num): - n = self.ndims // 2 - X = np.empty((num, self.ndims)) - X[:, :n] = np.random.normal(loc=1.0, scale=1.0, size=(num, n)) - X[:, n:] = np.random.normal( - loc=jnp.square(X[:, :n]), scale=jnp.sqrt(self.Q), size=(num, n) - ) - - return X - - def transform(self, x): - return x - - def sample_init(self, key): - return jax.random.normal(key, shape=(self.ndims,)) - - def ground_truth(self): - num = 100000000 - x = np.random.normal(loc=1.0, scale=1.0, size=num) - y = np.random.normal(loc=np.square(x), scale=jnp.sqrt(self.Q), size=num) - - x2 = jnp.sum(jnp.square(x)) / (num - 1) - y2 = jnp.sum(jnp.square(y)) / (num - 1) - - x1 = np.average(x) - y1 = np.average(y) - - print(np.sqrt(0.5 * (np.square(np.std(x)) + np.square(np.std(y))))) - - print(x2, y2) - - -class Brownian: - """ - log sigma_i ~ N(0, 2) - log sigma_obs ~N(0, 2) - - x ~ RandomWalk(0, sigma_i) - x_observed = (x + noise) * mask - noise ~ N(0, sigma_obs) - mask = 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 - """ - - def __init__(self): - self.num_data = 30 - self.name = "Brownian" - self.ndims = self.num_data + 2 - - ground_truth_moments = jnp.load( - dirr + "/ground_truth/brownian/ground_truth.npy" - ) - self.E_x2, self.Var_x2 = ground_truth_moments[0], ground_truth_moments[1] - - self.data = jnp.array( - [ - 0.21592641, - 0.118771404, - -0.07945447, - 0.037677474, - -0.27885845, - -0.1484156, - -0.3250906, - -0.22957903, - -0.44110894, - -0.09830782, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - -0.8786016, - -0.83736074, - -0.7384849, - -0.8939254, - -0.7774566, - -0.70238715, - -0.87771565, - -0.51853573, - -0.6948214, - -0.6202789, - ] - ) - # sigma_obs = 0.15, sigma_i = 0.1 - - self.observable = jnp.concatenate((jnp.ones(10), jnp.zeros(10), jnp.ones(10))) - self.num_observable = jnp.sum(self.observable) # = 20 - - def logdensity_fn(self, x): - # y = softplus_to_log(x[:2]) - - lik = ( - 0.5 - * jnp.exp(-2 * x[1]) - * jnp.sum(self.observable * jnp.square(x[2:] - self.data)) - + x[1] * self.num_observable - ) - prior_x = ( - 0.5 - * jnp.exp(-2 * x[0]) - * (x[2] ** 2 + jnp.sum(jnp.square(x[3:] - x[2:-1]))) - + x[0] * self.num_data - ) - prior_logsigma = 0.5 * jnp.sum(jnp.square(x / 2.0)) - - return -lik - prior_x - prior_logsigma - - def transform(self, x): - return jnp.concatenate((jnp.exp(x[:2]), x[2:])) - - def sample_init(self, key): - key_walk, key_sigma = jax.random.split(key) - - # original prior - # log_sigma = jax.random.normal(key_sigma, shape= (2, )) * 2 - - # narrower prior - log_sigma = ( - jnp.log(np.array([0.1, 0.15])) - + jax.random.normal(key_sigma, shape=(2,)) * 0.1 - ) # *0.05# log sigma_i, log sigma_obs - - walk = random_walk(key_walk, self.ndims - 2) * jnp.exp(log_sigma[0]) - - return jnp.concatenate((log_sigma, walk)) - - def generate_data(self, key): - key_walk, key_sigma, key_noise = jax.random.split(key, 3) - - log_sigma = ( - jax.random.normal(key_sigma, shape=(2,)) * 2 - ) # log sigma_i, log sigma_obs - - walk = random_walk(key_walk, self.ndims - 2) * jnp.exp(log_sigma[0]) - noise = jax.random.normal(key_noise, shape=(self.ndims - 2,)) * jnp.exp( - log_sigma[1] - ) - - return walk + noise - - -class 
GermanCredit: - """Taken from inference gym. - - x = (global scale, local scales, weights) - - global_scale ~ Gamma(0.5, 0.5) - - for i in range(num_features): - unscaled_weights[i] ~ Normal(loc=0, scale=1) - local_scales[i] ~ Gamma(0.5, 0.5) - weights[i] = unscaled_weights[i] * local_scales[i] * global_scale - - for j in range(num_datapoints): - label[j] ~ Bernoulli(features @ weights) - - We use a log transform for the scale parameters. - """ - - def __init__(self): - self.ndims = 51 # global scale + 25 local scales + 25 weights - self.name = "GermanCredit" - - self.labels = jnp.load(dirr + "/data/gc_labels.npy") - self.features = jnp.load(dirr + "/data/gc_features.npy") - - truth = jnp.load(dirr + "/ground_truth/german_credit/ground_truth.npy") - self.E_x2, self.Var_x2 = truth[0], truth[1] - - def transform(self, x): - return jnp.concatenate((jnp.exp(x[:26]), x[26:])) - - def logdensity_fn(self, x): - scales = jnp.exp(x[:26]) - - # prior - pr = jnp.sum(0.5 * scales + 0.5 * x[:26]) + 0.5 * jnp.sum(jnp.square(x[26:])) - - # transform - transform = -jnp.sum(x[:26]) - - # likelihood - weights = scales[0] * scales[1:26] * x[26:] - logits = ( - self.features @ weights - ) # = jnp.einsum('nd,...d->...n', self.features, weights) - lik = jnp.sum( - self.labels * jnp.logaddexp(0.0, -logits) - + (1 - self.labels) * jnp.logaddexp(0.0, logits) - ) - - return -(lik + pr + transform) - - def sample_init(self, key): - weights = jax.random.normal(key, shape=(25,)) - return jnp.concatenate((jnp.zeros(26), weights)) - - -class ItemResponseTheory: - """Taken from inference gym.""" - - def __init__(self): - self.ndims = 501 - self.name = "ItemResponseTheory" - self.students = 400 - self.questions = 100 - - self.mask = jnp.load(dirr + "/data/irt_mask.npy") - self.labels = jnp.load(dirr + "/data/irt_labels.npy") - - truth = jnp.load(dirr + "/ground_truth/item_response_theory/ground_truth.npy") - self.E_x2, self.Var_x2 = truth[0], truth[1] - - self.transform = lambda x: x - - def logdensity_fn(self, x): - students = x[: self.students] - mean = x[self.students] - questions = x[self.students + 1 :] - - # prior - pr = 0.5 * ( - jnp.square(mean - 0.75) - + jnp.sum(jnp.square(students)) - + jnp.sum(jnp.square(questions)) - ) - - # likelihood - logits = mean + students[:, jnp.newaxis] - questions[jnp.newaxis, :] - bern = self.labels * jnp.logaddexp(0.0, -logits) + ( - 1 - self.labels - ) * jnp.logaddexp(0.0, logits) - bern = jnp.where(self.mask, bern, jnp.zeros_like(bern)) - lik = jnp.sum(bern) - - return -lik - pr - - def sample_init(self, key): - x = jax.random.normal(key, shape=(self.ndims,)) - x = x.at[self.students].add(0.75) - return x - - -class StochasticVolatility: - """Example from https://num.pyro.ai/en/latest/examples/stochastic_volatility.html""" - - def __init__(self): - self.SP500_returns = jnp.load(dirr + "/data/SP500.npy") - - self.ndims = 2429 - self.name = "StochasticVolatility" - - self.typical_sigma, self.typical_nu = 0.02, 10.0 # := 1 / lambda - - data = jnp.load(dirr + "/ground_truth/stochastic_volatility/ground_truth_0.npy") - self.E_x2 = data[0] - self.Var_x2 = data[1] - - def logdensity_fn(self, x): - """- log p of the target distribution - x= [s1, s2, ... 
s2427, log sigma / typical_sigma, log nu / typical_nu]""" - - sigma = ( - jnp.exp(x[-2]) * self.typical_sigma - ) # we used this transformation to make x unconstrained - nu = jnp.exp(x[-1]) * self.typical_nu - - l1 = (jnp.exp(x[-2]) - x[-2]) + (jnp.exp(x[-1]) - x[-1]) - l2 = (self.ndims - 2) * jnp.log(sigma) + 0.5 * ( - jnp.square(x[0]) + jnp.sum(jnp.square(x[1:-2] - x[:-3])) - ) / jnp.square(sigma) - l3 = jnp.sum(nlogp_StudentT(self.SP500_returns, nu, jnp.exp(x[:-2]))) - - return -(l1 + l2 + l3) - - def transform(self, x): - """transforms to the variables which are used by numpyro (and in which we have the ground truth moments)""" - - z = jnp.empty(x.shape) - z = z.at[:-2].set(x[:-2]) # = s = log R - z = z.at[-2].set(jnp.exp(x[-2]) * self.typical_sigma) # = sigma - z = z.at[-1].set(jnp.exp(x[-1]) * self.typical_nu) # = nu - - return z - - def sample_init(self, key): - """draws x from the prior""" - - key_walk, key_exp = jax.random.split(key) - - scales = jnp.array([self.typical_sigma, self.typical_nu]) - # params = jax.random.exponential(key_exp, shape = (2, )) * scales - params = scales - walk = random_walk(key_walk, self.ndims - 2) * params[0] - return jnp.concatenate((walk, jnp.log(params / scales))) - - -class MixedLogit: - def __init__(self): - key = jax.random.PRNGKey(0) - key_poisson, key_x, key_beta, key_logit = jax.random.split(key, 4) - - self.ndims = 2014 - self.name = "Mixed Logit" - self.nind = 500 - self.nsessions = ( - jax.random.poisson(key_poisson, lam=1.0, shape=(self.nind,)) + 10 - ) - self.nbeta = 4 - nobs = jnp.sum(self.nsessions) - - mu_true = jnp.array([-1.5, -0.3, 0.8, 1.2]) - sigma_true = jnp.array( - [ - [0.5, 0.1, 0.1, 0.1], - [0.1, 0.5, 0.1, 0.1], - [0.1, 0.1, 0.5, 0.1], - [0.1, 0.1, 0.1, 0.5], - ] - ) - beta_true = jax.random.multivariate_normal( - key_beta, mu_true, sigma_true, shape=(self.nind,) - ) - beta_true_repeat = jnp.repeat(beta_true, self.nsessions, axis=0) - - self.x = jax.random.normal(key_x, (nobs, self.nbeta)) - self.y = 1 * jax.random.bernoulli( - key_logit, - ( - jax.nn.sigmoid( - jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))( - self.x, beta_true_repeat - ) - ) - ), - ) - - self.d = ( - self.nbeta - + self.nbeta - + (self.nbeta * (self.nbeta - 1) // 2) - + self.nbeta * self.nind - ) # mu, tau, omega_chol, and (beta for each i) - self.prior_mean_mu = jnp.zeros(self.nbeta) - self.prior_var_mu = 10.0 * jnp.eye(self.nbeta) - self.prior_scale_tau = 5.0 - self.prior_concentration_omega = 1.0 - - self.grad_logp = jax.value_and_grad(self.logdensity_fn) - - def corrchol_to_reals(self, x): - """Converts a Cholesky-correlation (lower-triangular) matrix to a vector of unconstrained reals""" - dim = x.shape[0] - z = jnp.zeros((dim, dim)) - for i in range(dim): - for j in range(i): - z = z.at[i, j].set(x[i, j] / jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) - z_lower_triang = z[jnp.tril_indices(dim, -1)] - y = 0.5 * (jnp.log(1.0 + z_lower_triang) - jnp.log(1.0 - z_lower_triang)) - - return y - - def reals_to_corrchol(self, y): - """Converts a vector of unconstrained reals to a Cholesky-correlation (lower-triangular) matrix""" - len_vec = len(y) - dim = int(0.5 * (1 + 8 * len_vec) ** 0.5 + 0.5) - assert dim * (dim - 1) // 2 == len_vec - - z = jnp.zeros((dim, dim)) - z = z.at[jnp.tril_indices(dim, -1)].set(jnp.tanh(y)) - - x = jnp.zeros((dim, dim)) - for i in range(dim): - for j in range(i + 1): - if i == j: - x = x.at[i, j].set(jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0))) - else: - x = x.at[i, j].set( - z[i, j] * jnp.sqrt(1.0 - jnp.sum(x[i, :j] ** 2.0)) - ) - 
return x - - def logdensity_fn(self, pars): - """log p of the target distribution, i.e., log posterior distribution up to a constant""" - - mu = pars[: self.nbeta] - dim1 = self.nbeta + self.nbeta - log_tau = pars[self.nbeta : dim1] - dim2 = self.nbeta + self.nbeta + self.nbeta * (self.nbeta - 1) // 2 - omega_chol_realvec = pars[dim1:dim2] - beta = pars[dim2:].reshape(self.nind, self.nbeta) - - omega_chol = self.reals_to_corrchol(omega_chol_realvec) - omega = jnp.dot(omega_chol, jnp.transpose(omega_chol)) - tau = jnp.exp(log_tau) - tau_diagmat = jnp.diag(tau) - sigma = jnp.dot(tau_diagmat, jnp.dot(omega, tau_diagmat)) - - beta_repeat = jnp.repeat(beta, self.nsessions, axis=0) - - log_lik = jnp.sum( - self.y - * jax.nn.log_sigmoid( - jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat) - ) - + (1 - self.y) - * jax.nn.log_sigmoid( - -jax.vmap(lambda vec1, vec2: jnp.dot(vec1, vec2))(self.x, beta_repeat) - ) - ) - - log_density_beta_popdist = -0.5 * self.nind * jnp.log( - jnp.linalg.det(sigma) - ) - 0.5 * jnp.sum( - jax.vmap( - lambda vec, mat: jnp.dot(vec, jnp.linalg.solve(mat, vec)), - in_axes=(0, None), - )(beta - mu, sigma) - ) - - muMinusPriorMean = mu - self.prior_mean_mu - log_prior_mu = -0.5 * jnp.log( - jnp.linalg.det(self.prior_var_mu) - ) - 0.5 * jnp.dot( - muMinusPriorMean, jnp.linalg.solve(self.prior_var_mu, muMinusPriorMean) - ) - - log_prior_tau = jnp.sum( - dist.HalfCauchy(scale=self.prior_scale_tau).log_prob(tau) - ) - # log_prior_tau = jnp.sum(jax.vmap(lambda arg: -jnp.log(1.0 + (arg / self.prior_scale_tau) ** 2.0))(tau)) - log_prior_omega_chol = dist.LKJCholesky( - self.nbeta, concentration=self.prior_concentration_omega - ).log_prob(omega_chol) - # log_prior_omega_chol = jnp.dot(nbeta - jnp.arange(2, nbeta+1) + 2.0 * self.prior_concentration_omega - 2.0, jnp.log(jnp.diag(omega_chol)[1:])) - - return ( - log_lik - + log_density_beta_popdist - + log_prior_mu - + log_prior_tau - + log_prior_omega_chol - ) - - def transform(self, pars): - """transform pars to the original (possibly constrained) pars""" - mu = pars[: self.nbeta] - dim1 = self.nbeta + self.nbeta - log_tau = pars[self.nbeta : dim1] - dim2 = self.nbeta + self.nbeta + self.nbeta * (self.nbeta - 1) // 2 - omega_chol_realvec = pars[dim1:dim2] - beta_flattened = pars[dim2:] - - omega_chol = self.reals_to_corrchol(omega_chol_realvec) - omega = jnp.dot(omega_chol, jnp.transpose(omega_chol)) - tau = jnp.exp(log_tau) - tau_diagmat = jnp.diag(tau) - sigma = jnp.dot(tau_diagmat, jnp.dot(omega, tau_diagmat)) - - return jnp.concatenate((mu, sigma.flatten(), beta_flattened)) - - def sample_init(self, key): - """draws pars from the prior""" - - key_mu, key_omega_chol, key_tau, key_beta = jax.random.split(key, 4) - mu = jax.random.multivariate_normal( - key_mu, self.prior_mean_mu, self.prior_var_mu - ) - omega_chol = dist.LKJCholesky( - self.nbeta, concentration=self.prior_concentration_omega - ).sample(key_omega_chol) - tau = dist.HalfCauchy(scale=self.prior_scale_tau).sample(key_tau, (self.nbeta,)) - - omega_chol_realvec = self.corrchol_to_reals(omega_chol) - log_tau = jnp.log(tau) - - omega = jnp.dot(omega_chol, jnp.transpose(omega_chol)) - tau_diagmat = jnp.diag(tau) - sigma = jnp.dot(tau_diagmat, jnp.dot(omega, tau_diagmat)) - - beta = jax.random.multivariate_normal(key_beta, mu, sigma, shape=(self.nind,)) - - pars = jnp.concatenate((mu, log_tau, omega_chol_realvec, beta.flatten())) - return pars - - -def nlogp_StudentT(x, df, scale): - y = x / scale - z = ( - jnp.log(scale) - + 0.5 * jnp.log(df) - + 0.5 * 
jnp.log(jnp.pi) - + jax.scipy.special.gammaln(0.5 * df) - - jax.scipy.special.gammaln(0.5 * (df + 1.0)) - ) - return 0.5 * (df + 1.0) * jnp.log1p(y**2.0 / df) + z - - -def random_walk(key, num): - """Genereting process for the standard normal walk: - x[0] ~ N(0, 1) - x[n+1] ~ N(x[n], 1) - - Args: - key: jax random key - num: number of points in the walk - Returns: - 1 realization of the random walk (array of length num) - """ - - def step(track, useless): - x, key = track - randkey, subkey = jax.random.split(key) - x += jax.random.normal(subkey) - return (x, randkey), x - - return jax.lax.scan(step, init=(0.0, key), xs=None, length=num)[1] - - -models = { - # Cauchy(100) : {'mclmc': 2000, 'mhmclmc' : 2000, 'nuts': 2000}, - # StandardNormal(100) : {'mclmc': 10000, 'mhmclmc' : 10000, 'nuts': 10000}, - # Banana() : {'mclmc': 10000, 'mhmclmc' : 10000, 'nuts': 10000}, - Brownian(): {"mclmc": 20000, "mhmclmc": 80000, "nuts": 40000}, - # 'banana': Banana(), - # 'icg' : (IllConditionedGaussian(10, 2), {'mclmc': 2000, 'mhmclmc' : 2000, 'nuts': 2000}), - # GermanCredit(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000}, - # ItemResponseTheory(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000}, - # StochasticVolatility(): {'mclmc': 20000, 'mhmclmc' : 40000, 'nuts': 20000} -} - -# models = {'Brownian Motion': (Brownian(), {'mclmc': 50000, 'mhmclmc' : 40000, 'nuts': 1000}), -# # 'Item Response Theory': (ItemResponseTheory(), {'mclmc': 50000, 'mhmclmc' : 50000, 'nuts': 1000}) -# } diff --git a/blackjax/benchmarks/mcmc/sampling_algorithms.py b/blackjax/benchmarks/mcmc/sampling_algorithms.py deleted file mode 100644 index b0a43edc4..000000000 --- a/blackjax/benchmarks/mcmc/sampling_algorithms.py +++ /dev/null @@ -1,226 +0,0 @@ -# mypy: ignore-errors -# flake8: noqa - - -import jax -import jax.numpy as jnp - -import blackjax - -# from blackjax.adaptation.window_adaptation import da_adaptation -from blackjax.mcmc.integrators import ( - calls_per_integrator_step, - generate_euclidean_integrator, - generate_isokinetic_integrator, -) - -# from blackjax.mcmc.adjusted_mclmc import rescale -from blackjax.util import run_inference_algorithm - -__all__ = ["samplers"] - - -def run_nuts(coefficients, logdensity_fn, num_steps, initial_position, transform, key): - integrator = generate_euclidean_integrator(coefficients) - # integrator = blackjax.mcmc.integrators.velocity_verlet # note: defaulted to in nuts - - rng_key, warmup_key = jax.random.split(key, 2) - - state, params = da_adaptation( - rng_key=warmup_key, - initial_position=initial_position, - algorithm=blackjax.nuts, - logdensity_fn=logdensity_fn, - ) - - # print(params["inverse_mass_matrix"], "inv\n\n") - # warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn, integrator=integrator) - # (state, params), _ = warmup.run(warmup_key, initial_position, 2000) - - nuts = blackjax.nuts( - logdensity_fn=logdensity_fn, - step_size=params["step_size"], - inverse_mass_matrix=params["inverse_mass_matrix"], - integrator=integrator, - ) - - final_state, state_history, info_history = run_inference_algorithm( - rng_key=rng_key, - initial_state=state, - inference_algorithm=nuts, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True, - ) - - # print("INFO\n\n",info_history.num_integration_steps) - - return ( - state_history, - params, - info_history.num_integration_steps.mean() - * calls_per_integrator_step(coefficients), - info_history.acceptance_rate.mean(), - None, - None, - ) - - -def run_mclmc(coefficients, logdensity_fn, 
num_steps, initial_position, transform, key): - integrator = generate_isokinetic_integrator(coefficients) - - init_key, tune_key, run_key = jax.random.split(key, 3) - - initial_state = blackjax.mcmc.mclmc.init( - position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key - ) - - kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=logdensity_fn, - integrator=integrator, - std_mat=std_mat, - ) - - ( - blackjax_state_after_tuning, - blackjax_mclmc_sampler_params, - ) = blackjax.mclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - state=initial_state, - rng_key=tune_key, - diagonal_preconditioning=False, - # desired_energy_var= 1e-5 - ) - - # jax.debug.print("params {x}", x=(blackjax_mclmc_sampler_params.L, blackjax_mclmc_sampler_params.step_size)) - - sampling_alg = blackjax.mclmc( - logdensity_fn, - L=blackjax_mclmc_sampler_params.L, - step_size=blackjax_mclmc_sampler_params.step_size, - std_mat=blackjax_mclmc_sampler_params.std_mat, - integrator=integrator, - # std_mat=jnp.ones((initial_position.shape[0],)), - ) - - _, samples, _ = run_inference_algorithm( - rng_key=run_key, - initial_state=blackjax_state_after_tuning, - inference_algorithm=sampling_alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True, - ) - - acceptance_rate = 1.0 - return ( - samples, - blackjax_mclmc_sampler_params, - calls_per_integrator_step(coefficients), - acceptance_rate, - None, - None, - ) - - -def run_mhmclmc( - coefficients, - logdensity_fn, - num_steps, - initial_position, - transform, - key, - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.0, - target_acc_rate=None, -): - integrator = generate_isokinetic_integrator(coefficients) - - init_key, tune_key, run_key = jax.random.split(key, 3) - - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=logdensity_fn, - random_generator_arg=init_key, - ) - - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=integrator, - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(avg_num_integration_steps) - ), - std_mat=std_mat, - )( - rng_key=rng_key, state=state, step_size=step_size, logdensity_fn=logdensity_fn - ) - - if target_acc_rate is None: - target_acc_rate = target_acceptance_rate_of_order[ - integrator_order(coefficients) - ] - print("target acc rate") - - ( - blackjax_state_after_tuning, - blackjax_mclmc_sampler_params, - params_history, - final_da, - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - state=initial_state, - rng_key=tune_key, - target=target_acc_rate, - frac_tune1=frac_tune1, - frac_tune2=frac_tune2, - frac_tune3=frac_tune3, - diagonal_preconditioning=False, - ) - - step_size = blackjax_mclmc_sampler_params.step_size - L = blackjax_mclmc_sampler_params.L - # jax.debug.print("params {x}", x=(blackjax_mclmc_sampler_params.step_size, blackjax_mclmc_sampler_params.L)) - - alg = blackjax.mcmc.mhmclmc.mhmclmc( - logdensity_fn=logdensity_fn, - step_size=step_size, - integration_steps_fn=lambda key: jnp.ceil( - jax.random.uniform(key) * rescale(L / step_size) - ), - integrator=integrator, - std_mat=blackjax_mclmc_sampler_params.std_mat, - ) - - _, out, info = run_inference_algorithm( - rng_key=run_key, - initial_state=blackjax_state_after_tuning, - inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - 
progress_bar=True, - ) - - return ( - out, - blackjax_mclmc_sampler_params, - calls_per_integrator_step(coefficients) * (L / step_size), - info.acceptance_rate, - params_history, - final_da, - ) - - -# we should do at least: mclmc, nuts, unadjusted hmc, mhmclmc, langevin - -samplers = { - "nuts": run_nuts, - "mclmc": run_mclmc, - "mhmclmc": run_mhmclmc, -} - - -# foo = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(20.56)) - -# print(jnp.mean(jax.vmap(foo)(jax.random.split(jax.random.PRNGKey(1), 10000000)))) From 51fee690c854fd6492de5686140c4a54537ccb23 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 May 2024 20:52:36 +0200 Subject: [PATCH 34/71] ADD TEST --- tests/mcmc/test_sampling.py | 83 +++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index c484fd6c1..604316e48 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -14,6 +14,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info +from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -259,6 +260,88 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + def test_mclmc_preconditioning(self): + class IllConditionedGaussian: + """Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2.""" + + def __init__(self, d, condition_number): + """numpy_seed is used to generate a random rotation for the covariance matrix. + If None, the covariance matrix is diagonal.""" + + self.ndims = d + self.name = "IllConditionedGaussian" + self.condition_number = condition_number + eigs = jnp.logspace( + -0.5 * jnp.log10(condition_number), + 0.5 * jnp.log10(condition_number), + d, + ) + self.E_x2 = eigs + self.R = jnp.eye(d) + self.Hessian = jnp.diag(1 / eigs) + self.Cov = jnp.diag(eigs) + self.Var_x2 = 2 * jnp.square(self.E_x2) + + self.logdensity_fn = lambda x: -0.5 * x.T @ self.Hessian @ x + self.transform = lambda x: x + + self.sample_init = lambda key: jax.random.normal( + key, shape=(self.ndims,) + ) * jnp.max(jnp.sqrt(eigs)) + + dim = 100 + condition_number = 10 + eigs = jnp.logspace( + -0.5 * jnp.log10(condition_number), 0.5 * jnp.log10(condition_number), dim + ) + model = IllConditionedGaussian(dim, condition_number) + num_steps = 20000 + key = jax.random.PRNGKey(2) + + integrator = isokinetic_mclachlan + + def get_std_mat(): + init_key, tune_key = jax.random.split(key) + + initial_position = model.sample_init(init_key) + + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, + logdensity_fn=model.logdensity_fn, + rng_key=init_key, + ) + + kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=model.logdensity_fn, + integrator=integrator, + std_mat=std_mat, + ) + + ( + _, + blackjax_mclmc_sampler_params, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + diagonal_preconditioning=True, + ) + + return blackjax_mclmc_sampler_params.std_mat + + std_mat = get_std_mat() + assert ( + jnp.abs( + jnp.dot( + (std_mat**2) / jnp.linalg.norm(std_mat**2), + eigs / jnp.linalg.norm(eigs), + ) + - 1 + ) + < 0.1 + ) + @parameterized.parameters(regression_test_cases) def 
test_pathfinder_adaptation( self, From 29994d765e44529dbb69e11df77816254931ec5f Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 May 2024 20:53:02 +0200 Subject: [PATCH 35/71] REMOVE BENCHMARKS --- blackjax/benchmarks/mcmc/benchmark.py | 908 -------------------------- 1 file changed, 908 deletions(-) delete mode 100644 blackjax/benchmarks/mcmc/benchmark.py diff --git a/blackjax/benchmarks/mcmc/benchmark.py b/blackjax/benchmarks/mcmc/benchmark.py deleted file mode 100644 index 9eadc7e2f..000000000 --- a/blackjax/benchmarks/mcmc/benchmark.py +++ /dev/null @@ -1,908 +0,0 @@ -# mypy: ignore-errors -# flake8: noqa - -import math -import operator -import os -import pprint -from collections import defaultdict -from functools import partial -from statistics import mean, median - -import jax -import jax.numpy as jnp -import pandas as pd -import scipy - -from blackjax.adaptation.mclmc_adaptation import ( - MCLMCAdaptationState, - integrator_order, - target_acceptance_rate_of_order, -) - -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=" + str(128) -num_cores = jax.local_device_count() -# print(num_cores, jax.lib.xla_bridge.get_backend().platform) - -import itertools - -import numpy as np - -import blackjax -from blackjax.benchmarks.mcmc.inference_models import ( - Brownian, - GermanCredit, - ItemResponseTheory, - MixedLogit, - StandardNormal, - StochasticVolatility, - models, -) -from blackjax.benchmarks.mcmc.sampling_algorithms import ( - run_mclmc, - run_mhmclmc, - run_nuts, - samplers, -) -from blackjax.mcmc.integrators import ( - calls_per_integrator_step, - generate_euclidean_integrator, - generate_isokinetic_integrator, - isokinetic_mclachlan, - mclachlan_coefficients, - name_integrator, - omelyan_coefficients, - velocity_verlet, - velocity_verlet_coefficients, - yoshida_coefficients, -) -from blackjax.mcmc.mhmclmc import rescale -from blackjax.util import run_inference_algorithm - - -def get_num_latents(target): - return target.ndims - - -# return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0])))) - - -def err(f_true, var_f, contract): - """Computes the error b^2 = (f - f_true)^2 / var_f - Args: - f: E_sampler[f(x)], can be a vector - f_true: E_true[f(x)] - var_f: Var_true[f(x)] - contract: how to combine a vector f in a single number, can be for example jnp.average or jnp.max - - Returns: - contract(b^2) - """ - - return jax.vmap(lambda f: contract(jnp.square(f - f_true) / var_f)) - - -def grads_to_low_error(err_t, grad_evals_per_step=1, low_error=0.01): - """Uses the error of the expectation values to compute the effective sample size neff - b^2 = 1/neff""" - - cutoff_reached = err_t[-1] < low_error - return find_crossing(err_t, low_error) * grad_evals_per_step, cutoff_reached - - -def calculate_ess(err_t, grad_evals_per_step, neff=100): - grads_to_low, cutoff_reached = grads_to_low_error( - err_t, grad_evals_per_step, 1.0 / neff - ) - - return ( - (neff / grads_to_low) * cutoff_reached, - grads_to_low * (1 / cutoff_reached), - cutoff_reached, - ) - - -def find_crossing(array, cutoff): - """the smallest M such that array[m] < cutoff for all m > M""" - - b = array > cutoff - indices = jnp.argwhere(b) - if indices.shape[0] == 0: - print("\n\n\nNO CROSSING FOUND!!!\n\n\n", array, cutoff) - return 1 - - return jnp.max(indices) + 1 - - -def cumulative_avg(samples): - return jnp.cumsum(samples, axis=0) / jnp.arange(1, samples.shape[0] + 1)[:, None] - - -def gridsearch_tune( - key, - iterations, - grid_size, - model, - sampler, - batch, - num_steps, - center_L, - 
center_step_size, - contract, -): - results = defaultdict(float) - converged = False - keys = jax.random.split(key, iterations + 1) - for i in range(iterations): - print(f"EPOCH {i}") - width = 2 - step_sizes = np.logspace( - np.log10(center_step_size / width), - np.log10(center_step_size * width), - grid_size, - ) - Ls = np.logspace(np.log10(center_L / 2), np.log10(center_L * 2), grid_size) - # print(list(itertools.product(step_sizes , Ls))) - - grid_keys = jax.random.split(keys[i], grid_size ^ 2) - print(f"center step size {center_step_size}, center L {center_L}") - for j, (step_size, L) in enumerate(itertools.product(step_sizes, Ls)): - ess, grad_calls_until_convergence, _, _, _ = benchmark_chains( - model, - sampler(step_size=step_size, L=L), - grid_keys[j], - n=num_steps, - batch=batch, - contract=contract, - ) - results[(step_size, L)] = (ess, grad_calls_until_convergence) - - best_ess, best_grads, (step_size, L) = max( - ((results[r][0], results[r][1], r) for r in results), - key=operator.itemgetter(0), - ) - # raise Exception - print( - f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}" - ) - if L == center_L and step_size == center_step_size: - print("converged") - converged = True - break - else: - center_L, center_step_size = L, step_size - - pprint.pp(results) - # print(f"best params on iteration {i} are stepsize {step_size} and L {L} with Grad Calls until Convergence {best_grads}") - # print(f"L from ESS (0.4 * step_size/ESS): {0.4 * step_size/best_ess}") - return center_L, center_step_size, converged - - -def run_mhmclmc_no_tuning(initial_state, coefficients, step_size, L, std_mat): - def s(logdensity_fn, num_steps, initial_position, transform, key): - integrator = generate_isokinetic_integrator(coefficients) - - num_steps_per_traj = L / step_size - alg = blackjax.mcmc.mhmclmc.mhmclmc( - logdensity_fn=logdensity_fn, - step_size=step_size, - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(num_steps_per_traj) - ), - integrator=integrator, - std_mat=std_mat, - ) - - _, out, info = run_inference_algorithm( - rng_key=key, - initial_state=initial_state, - inference_algorithm=alg, - num_steps=num_steps, - transform=lambda x: transform(x.position), - progress_bar=True, - ) - - return ( - out, - MCLMCAdaptationState(L=L, step_size=step_size, std_mat=std_mat), - num_steps_per_traj * calls_per_integrator_step(coefficients), - info.acceptance_rate.mean(), - None, - jnp.array([0]), - ) - - return s - - -def benchmark_chains( - model, - sampler, - key, - n=10000, - batch=None, - contract=jnp.average, -): - pvmap = jax.pmap - - # def pvmap(f): - # def f(arr): - # return arr - # print(arr.shape,"shape") - # print(arr) - # arr = arr.reshape(128, -1) - # out = jax.vmap(jax.vmap(f), in_axes=0)(arr) - # return out.flatten() - # return f - - d = get_num_latents(model) - if batch is None: - batch = np.ceil(1000 / d).astype(int) - key, init_key = jax.random.split(key, 2) - keys = jax.random.split(key, batch) - - init_keys = jax.random.split(init_key, batch) - init_pos = pvmap(model.sample_init)(init_keys) - - # samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys) - ( - samples, - params, - grad_calls_per_traj, - acceptance_rate, - step_size_over_da, - final_da, - ) = pvmap( - lambda pos, key: sampler( - logdensity_fn=model.logdensity_fn, - num_steps=n, - initial_position=pos, - transform=model.transform, - key=key, - ) - 
)( - init_pos, keys - ) - avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0) - try: - print(jnp.nanmean(params.step_size, axis=0), jnp.nanmean(params.L, axis=0)) - except: - pass - - full = lambda arr: err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr)) - err_t = pvmap(full)(samples**2) - - # outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t] - # # print(outs[:10]) - # esses = [i[0].item() for i in outs if not math.isnan(i[0].item())] - # grad_calls = [i[1].item() for i in outs if not math.isnan(i[1].item())] - # return(mean(esses), mean(grad_calls)) - # print(final_da.mean(), "final da") - - err_t_median = jnp.median(err_t, axis=0) - # import matplotlib.pyplot as plt - # plt.plot(np.arange(1, 1+ len(err_t_median))* 2, err_t_median, color= 'teal', lw = 3) - # plt.xlabel('gradient evaluations') - # plt.ylabel('average second moment error') - # plt.xscale('log') - # plt.yscale('log') - # plt.savefig('brownian.png') - # plt.close() - esses, grad_calls, _ = calculate_ess( - err_t_median, grad_evals_per_step=avg_grad_calls_per_traj - ) - return ( - esses, - grad_calls, - params, - jnp.mean(acceptance_rate, axis=0), - step_size_over_da, - ) - - -def run_benchmarks(batch_size): - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], - [Brownian()], - # [Brownian()], - # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims - - num_steps = 100000 - - sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims - - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - - contract = jnp.max - - key = jax.random.PRNGKey(11) - for i in range(1): - key1, key = jax.random.split(key) - ( - ess, - grad_calls, - params, - acceptance_rate, - step_size_over_da, - ) = benchmark_chains( - model, - partial( - samplers[sampler], - coefficients=coefficients, - frac_tune1=0.1, - frac_tune2=0.0, - frac_tune3=0.0, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - - # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") - jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) - jax.numpy.save(f"acceptance.npy", acceptance_rate) - - # print(f"grads to low bias: {grad_calls}") - # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - - results[ - ( - (model.name, model.ndims), - sampler, - name_integrator(coefficients), - "standard", - acceptance_rate.mean().item(), - params.L.mean().item(), - params.step_size.mean().item(), - num_chains, - num_steps, - contract, - ) - ] = ess.item() - print(ess.item()) - # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - - # print(results) - - df = pd.Series(results).reset_index() - df.columns = [ - "model", - "sampler", - "integrator", - "tuning", - "acc rate", - "L", - "stepsize", - "num_chains", - "num steps", - "contraction", - "ESS", - ] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("results_simple.csv", index=False) - - return results - - -# vary step_size -def 
run_benchmarks_step_size(batch_size): - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmclmc"], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)], - [StandardNormal(10)], - # [Brownian()], - # [Brownian()], - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients], - ): - num_steps = 10000 - - sampler, model, coefficients = variables - num_chains = batch_size # 1 + batch_size//model.ndims - - # print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",) - - contract = jnp.average - - center = 6.534974 - key = jax.random.PRNGKey(11) - for step_size in np.linspace(center - 1, center + 1, 41): - # for L in np.linspace(1, 10, 41): - key1, key2, key3, key = jax.random.split(key, 4) - initial_position = model.sample_init(key2) - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=model.logdensity_fn, - random_generator_arg=key3, - ) - ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( - model, - run_mhmclmc_no_tuning( - initial_state=initial_state, - coefficients=mclachlan_coefficients, - step_size=step_size, - L=5 * step_size, - std_mat=1.0, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - - # print(f"step size over da {step_size_over_da.shape} \n\n\n\n") - # jax.numpy.save(f"step_size_over_da.npy", step_size_over_da.mean(axis=0)) - # jax.numpy.save(f"acceptance.npy_{step_size}", acceptance_rate) - - # print(f"grads to low bias: {grad_calls}") - # print(f"acceptance rate is {acceptance_rate, acceptance_rate.mean()}") - - results[ - ( - (model.name, model.ndims), - sampler, - name_integrator(coefficients), - "standard", - acceptance_rate.mean().item(), - params.L.mean().item(), - params.step_size.mean().item(), - num_chains, - num_steps, - contract, - ) - ] = ess.item() - # results[(model.name, model.ndims, "nuts", 0., 0., name_integrator(coeffs), "standard", acceptance_rate)] - - # print(results) - - df = pd.Series(results).reset_index() - df.columns = [ - "model", - "sampler", - "integrator", - "tuning", - "acc rate", - "L", - "stepsize", - "num_chains", - "num steps", - "contraction", - "ESS", - ] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("results_step_size.csv", index=False) - - return results - - -def benchmark_mhmchmc(batch_size): - key0, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(5), 4) - results = defaultdict(tuple) - - # coefficients = [yoshida_coefficients, mclachlan_coefficients, velocity_verlet_coefficients, omelyan_coefficients] - coefficients = [mclachlan_coefficients, velocity_verlet_coefficients] - for model, coeffs in itertools.product(models, coefficients): - num_chains = batch_size # 1 + batch_size//model.ndims - print(f"NUMBER OF CHAINS for {model.name} and MHMCLMC is {num_chains}") - num_steps = models[model]["mhmclmc"] - print(f"NUMBER OF STEPS for {model.name} and MHCMLMC is {num_steps}") - - ####### run mclmc with standard tuning - - contract = jnp.max - - ess, grad_calls, params, _, step_size_over_da = benchmark_chains( - model, - partial(run_mclmc, coefficients=coeffs), - key0, - n=num_steps, - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "mclmc", - params.L.mean().item(), - params.step_size.mean().item(), - 
name_integrator(coeffs), - "standard", - 1.0, - ) - ] = ess.item() - print(f"mclmc with tuning ESS {ess}") - - ####### run mhmclmc with standard tuning - for target_acc_rate in [0.65, 0.9]: - # coeffs = mclachlan_coefficients - ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( - model, - partial( - run_mhmclmc, - target_acc_rate=target_acc_rate, - coefficients=coeffs, - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.0, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "mhmchmc" + str(target_acc_rate), - jnp.nanmean(params.L).item(), - jnp.nanmean(params.step_size).item(), - name_integrator(coeffs), - "standard", - acceptance_rate.mean().item(), - ) - ] = ess.item() - print(f"mhmclmc with tuning ESS {ess}") - - # coeffs = mclachlan_coefficients - ess, grad_calls, params, acceptance_rate, _ = benchmark_chains( - model, - partial( - run_mhmclmc, - target_acc_rate=target_acc_rate, - coefficients=coeffs, - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.1, - ), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "mhmchmc:st3" + str(target_acc_rate), - jnp.nanmean(params.L).item(), - jnp.nanmean(params.step_size).item(), - name_integrator(coeffs), - "standard", - acceptance_rate.mean().item(), - ) - ] = ess.item() - print(f"mhmclmc with tuning ESS {ess}") - - if True: - ####### run mhmclmc with standard tuning + grid search - - init_pos_key, init_key, tune_key, grid_key, bench_key = jax.random.split( - key2, 5 - ) - initial_position = model.sample_init(init_pos_key) - - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=model.logdensity_fn, - random_generator_arg=init_key, - ) - - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coeffs), - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(avg_num_integration_steps) - ), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn, - ) - - ( - state, - blackjax_mhmclmc_sampler_params, - _, - _, - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - state=initial_state, - rng_key=tune_key, - target=target_acceptance_rate_of_order[integrator_order(coeffs)], - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.0, - diagonal_preconditioning=False, - ) - - print( - f"target acceptance rate {target_acceptance_rate_of_order[integrator_order(coeffs)]}" - ) - print( - f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" - ) - - L, step_size, convergence = gridsearch_tune( - grid_key, - iterations=10, - contract=contract, - grid_size=5, - model=model, - sampler=partial( - run_mhmclmc_no_tuning, - coefficients=coeffs, - initial_state=state, - std_mat=1.0, - ), - batch=num_chains, - num_steps=num_steps, - center_L=blackjax_mhmclmc_sampler_params.L, - center_step_size=blackjax_mhmclmc_sampler_params.step_size, - ) - # print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( - model, - run_mhmclmc_no_tuning( - coefficients=coeffs, - L=L, - step_size=step_size, - initial_state=state, - std_mat=1.0, - ), - bench_key, - n=num_steps, - batch=num_chains, - contract=contract, - ) - - 
print(f"grads to low bias: {grad_calls}") - - results[ - ( - model.name, - model.ndims, - "mhmchmc:grid", - L.item(), - step_size.item(), - name_integrator(coeffs), - f"gridsearch:{convergence}", - acceptance_rate.mean().item(), - ) - ] = ess.item() - - ####### run nuts - - # coeffs = velocity_verlet_coefficients - ess, grad_calls, _, acceptance_rate, _ = benchmark_chains( - model, - partial(run_nuts, coefficients=coeffs), - key3, - n=models[model]["nuts"], - batch=num_chains, - contract=contract, - ) - results[ - ( - model.name, - model.ndims, - "nuts", - 0.0, - 0.0, - name_integrator(coeffs), - "standard", - acceptance_rate.mean().item(), - ) - ] = ess.item() - - print(results) - - df = pd.Series(results).reset_index() - df.columns = [ - "model", - "dims", - "sampler", - "L", - "step_size", - "integrator", - "tuning", - "acc_rate", - "ESS", - ] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("results.csv", index=False) - - return results - - -def benchmark_omelyan(batch_size): - key = jax.random.PRNGKey(2) - results = defaultdict(tuple) - for variables in itertools.product( - # ["mhmclmc", "nuts", "mclmc", ], - ["mhmchmc"], - [ - StandardNormal(d) - for d in np.ceil(np.logspace(np.log10(10), np.log10(1000), 4)).astype(int) - ], - # [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 5)).astype(int)], - # models, - # [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients], - [mclachlan_coefficients, omelyan_coefficients], - ): - sampler, model, coefficients = variables - - # num_chains = 1 + batch_size//model.ndims - num_chains = batch_size - - current_key, key = jax.random.split(key) - init_pos_key, init_key, tune_key, bench_key, grid_key = jax.random.split( - current_key, 5 - ) - - # num_steps = models[model][sampler] - - num_steps = 1000 - - initial_position = model.sample_init(init_pos_key) - - initial_state = blackjax.mcmc.mhmclmc.init( - position=initial_position, - logdensity_fn=model.logdensity_fn, - random_generator_arg=init_key, - ) - - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.mhmclmc.build_kernel( - integrator=generate_isokinetic_integrator(coefficients), - integration_steps_fn=lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(avg_num_integration_steps) - ), - std_mat=std_mat, - )( - rng_key=rng_key, - state=state, - step_size=step_size, - logdensity_fn=model.logdensity_fn, - ) - - ( - state, - blackjax_mhmclmc_sampler_params, - _, - _, - ) = blackjax.adaptation.mclmc_adaptation.mhmclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - state=initial_state, - rng_key=tune_key, - target=target_acceptance_rate_of_order[integrator_order(coefficients)], - frac_tune1=0.1, - frac_tune2=0.1, - # frac_tune3=0.1, - diagonal_preconditioning=False, - ) - - print( - f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}", - ) - print( - f"params after initial tuning are L={blackjax_mhmclmc_sampler_params.L}, step_size={blackjax_mhmclmc_sampler_params.step_size}" - ) - - # ess, grad_calls, _ , _ = benchmark_chains(model, run_mhmclmc_no_tuning(coefficients=coefficients, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, std_mat=1.),bench_key_pre_grid, n=num_steps, batch=num_chains, contract=jnp.average) - - # results[((model.name, model.ndims), sampler, name_integrator(coefficients), 
"without grid search")] = (ess, grad_calls) - - L, step_size, converged = gridsearch_tune( - grid_key, - iterations=10, - contract=jnp.average, - grid_size=5, - model=model, - sampler=partial( - run_mhmclmc_no_tuning, - coefficients=coefficients, - initial_state=state, - std_mat=1.0, - ), - batch=num_chains, - num_steps=num_steps, - center_L=blackjax_mhmclmc_sampler_params.L, - center_step_size=blackjax_mhmclmc_sampler_params.step_size, - ) - print(f"params after grid tuning are L={L}, step_size={step_size}") - - ess, grad_calls, _, _, _ = benchmark_chains( - model, - run_mhmclmc_no_tuning( - coefficients=coefficients, - L=L, - step_size=step_size, - std_mat=1.0, - initial_state=state, - ), - bench_key, - n=num_steps, - batch=num_chains, - contract=jnp.average, - ) - - print(f"grads to low bias: {grad_calls}") - - results[ - ( - model.name, - model.ndims, - sampler, - name_integrator(coefficients), - converged, - L.item(), - step_size.item(), - ) - ] = ess.item() - - df = pd.Series(results).reset_index() - df.columns = [ - "model", - "dims", - "sampler", - "integrator", - "convergence", - "L", - "step_size", - "ESS", - ] - # df.result = df.result.apply(lambda x: x[0].item()) - # df.model = df.model.apply(lambda x: x[1]) - df.to_csv("omelyan.csv", index=False) - - -def run_benchmarks_divij(): - sampler = run_mclmc - model = StandardNormal(10) # 10 dimensional standard normal - coefficients = mclachlan_coefficients - contract = jnp.average # how we average across dimensions - num_steps = 2000 - num_chains = 100 - key1 = jax.random.PRNGKey(2) - - ess, grad_calls, params, acceptance_rate, step_size_over_da = benchmark_chains( - model, - partial(sampler, coefficients=coefficients), - key1, - n=num_steps, - batch=num_chains, - contract=contract, - ) - - print(f"Effective Sample Size (ESS) of 10D Normal is {ess}") - - -if __name__ == "__main__": - # run_benchmarks_divij() - - benchmark_mhmchmc(batch_size=128) - # run_benchmarks(128) - # run_benchmarks_step_size(128) - benchmark_omelyan(128) - # run_benchmarks(128) - # benchmark_omelyan(10) - # print("4") From 888fb09f970cced5a003de46189f622493203929 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 May 2024 21:17:51 +0200 Subject: [PATCH 36/71] MODIFY WINDOW ADAPTATION TO TAKE INTEGRATOR --- blackjax/adaptation/window_adaptation.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index dd3e7b282..1026e7ac1 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Implementation of the Stan warmup for the HMC family of sampling algorithms.""" -from typing import Callable, NamedTuple +from typing import Callable, NamedTuple, Union import jax import jax.numpy as jnp -from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info +import blackjax.mcmc as mcmc +from blackjax.adaptation.base import AdaptationInfo, AdaptationResults from blackjax.adaptation.mass_matrix import ( MassMatrixAdaptationState, mass_matrix_adaptation, @@ -240,19 +241,18 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]: return init, update, final - def window_adaptation( - algorithm, + algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts], logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, progress_bar: bool = False, - adaptation_info_fn: Callable = return_all_adapt_info, + integrator = mcmc.integrators.velocity_verlet, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of - algorithms in the HMC family. See Blackjax.hmc_family + algorithms in the HMC fmaily. Algorithms in the HMC family on a euclidean manifold depend on the value of at least two parameters: the step size, related to the trajectory @@ -279,11 +279,6 @@ def window_adaptation( The acceptance rate that we target during step size adaptation. progress_bar Whether we should display a progress bar. - adaptation_info_fn - Function to select the adaptation info returned. See return_all_adapt_info - and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all - information is saved - this can result in excessive memory usage if the - information is unused. **extra_parameters The extra parameters to pass to the algorithm, e.g. the number of integration steps for HMC. 
@@ -294,7 +289,7 @@ def window_adaptation( """ - mcmc_kernel = algorithm.build_kernel() + mcmc_kernel = algorithm.build_kernel(integrator) adapt_init, adapt_step, adapt_final = base( is_mass_matrix_diagonal, @@ -322,7 +317,7 @@ def one_step(carry, xs): return ( (new_state, new_adaptation_state), - adaptation_info_fn(new_state, info, new_adaptation_state), + AdaptationInfo(new_state, info, new_adaptation_state), ) def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): From 4123e4f7740f0178a40de944aea369865f825b96 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 May 2024 23:12:19 +0200 Subject: [PATCH 37/71] MODIFY WINDOW ADAPTATION TO TAKE INTEGRATOR --- blackjax/adaptation/window_adaptation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 1026e7ac1..2efe4b1d4 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -242,7 +242,7 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]: return init, update, final def window_adaptation( - algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts], + algorithm, logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, From 64948e5cd0e5a4df2356c62a7e7370a536ea8596 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 May 2024 23:25:39 +0200 Subject: [PATCH 38/71] BUG FIX --- blackjax/mcmc/integrators.py | 2 +- tests/mcmc/test_integrators.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 0f4deeca4..f995bd564 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -357,7 +357,7 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, std_mat: ArrayTree, *args, **kwargs + logdensity_fn: Callable, std_mat: ArrayTree=1., *args, **kwargs ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index ddb13ad57..04a13b2b7 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -234,7 +234,7 @@ def test_esh_momentum_update(self, dims): ) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta))) # Efficient implementation - update_stable = self.variant(esh_dynamics_momentum_update_one_step) + update_stable = self.variant(esh_dynamics_momentum_update_one_step(std_mat=1.0)) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -258,7 +258,7 @@ def test_isokinetic_leapfrog(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step + op1 = esh_dynamics_momentum_update_one_step(std_mat=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( From c3d44f3e2ccfb490be1624f275badc6678a9a843 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 May 2024 23:31:10 +0200 Subject: [PATCH 39/71] CHANGE PRECISION --- tests/mcmc/test_integrators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 
04a13b2b7..68c12c499 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -134,8 +134,8 @@ def kinetic_energy(p, position=None): algorithms = { "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, - "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, - "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, + "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-4}, + "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-4}, "implicit_midpoint": { "algorithm": integrators.implicit_midpoint, "precision": 1e-4, From 94d43bd0155767f4e32987ae48fe5ffc3c3ce463 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 18 May 2024 23:41:25 +0200 Subject: [PATCH 40/71] CHANGE PRECISION --- blackjax/mcmc/integrators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index f995bd564..ed11fb1a0 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -357,7 +357,7 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, std_mat: ArrayTree=1., *args, **kwargs + logdensity_fn: Callable, std_mat: ArrayTree = 1.0, *args, **kwargs ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( From 636ef43b08ea98fccf37141027574530676b3b90 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 May 2024 00:23:01 +0200 Subject: [PATCH 41/71] ADD OMELYAN TEST --- tests/mcmc/test_integrators.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 2aad97bb7..0bb79813e 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -136,6 +136,7 @@ def kinetic_energy(p, position=None): "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-4}, "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-4}, + "omelyan": {"algorithm": integrators.omelyan, "precision": 1e-4}, "implicit_midpoint": { "algorithm": integrators.implicit_midpoint, "precision": 1e-4, @@ -143,6 +144,7 @@ def kinetic_energy(p, position=None): "isokinetic_velocity_verlet": {"algorithm": integrators.isokinetic_velocity_verlet}, "isokinetic_mclachlan": {"algorithm": integrators.isokinetic_mclachlan}, "isokinetic_yoshida": {"algorithm": integrators.isokinetic_yoshida}, + "isokinetic_omelyan": {"algorithm": integrators.isokinetic_omelyan}, } @@ -168,6 +170,7 @@ class IntegratorTest(chex.TestCase): "velocity_verlet", "mclachlan", "yoshida", + "omelyan", "implicit_midpoint", ], ) @@ -297,6 +300,7 @@ def test_isokinetic_velocity_verlet(self): "isokinetic_velocity_verlet", "isokinetic_mclachlan", "isokinetic_yoshida", + "isokinetic_omelyan", ], ) def test_isokinetic_integrator(self, integrator_name): From 7b16464a8f89a30e353e191a9d0507f795b83c55 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 May 2024 17:40:22 +0200 Subject: [PATCH 42/71] ADD ADJUSTED MCLMC TEST --- blackjax/adaptation/window_adaptation.py | 5 +- blackjax/mcmc/adjusted_mclmc.py | 6 +- blackjax/util.py | 1 + tests/mcmc/test_sampling.py | 113 ++++++++++++++++++++++- 4 files changed, 120 insertions(+), 5 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 2efe4b1d4..cacd0b4a6 100644 --- 
a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Stan warmup for the HMC family of sampling algorithms.""" -from typing import Callable, NamedTuple, Union +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -241,6 +241,7 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]: return init, update, final + def window_adaptation( algorithm, logdensity_fn: Callable, @@ -248,7 +249,7 @@ def window_adaptation( initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, progress_bar: bool = False, - integrator = mcmc.integrators.velocity_verlet, + integrator=mcmc.integrators.velocity_verlet, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 282f67658..6a6722f2c 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -76,9 +76,11 @@ def kernel( key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( - integrator=integrators.with_isokinetic_maruyama(integrator(logdensity_fn, std_mat)), + integrator=integrators.with_isokinetic_maruyama( + integrator(logdensity_fn, std_mat) + ), step_size=step_size, - L_proposal=L_proposal*num_integration_steps, + L_proposal=L_proposal * num_integration_steps, num_integration_steps=num_integration_steps, divergence_threshold=divergence_threshold, )( diff --git a/blackjax/util.py b/blackjax/util.py index 070ca8687..37133a4b5 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -251,6 +251,7 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): new streaming average """ + # x, _ = ravel_pytree(x) expectation = O(x) flat_expectation, unravel_fn = ravel_pytree(expectation) total, average = streaming_avg diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 604316e48..3585d4023 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -14,7 +14,13 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info -from blackjax.mcmc.integrators import isokinetic_mclachlan +from blackjax.mcmc.adjusted_mclmc import rescale +from blackjax.mcmc.integrators import ( + generate_isokinetic_integrator, + integrator_order, + isokinetic_mclachlan, + mclachlan_coefficients, +) from blackjax.util import run_inference_algorithm @@ -145,6 +151,86 @@ def run_mclmc( return samples + def run_adjusted_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + coefficients = mclachlan_coefficients + integrator = generate_isokinetic_integrator(coefficients) + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc.init( + position=initial_position, + logdensity_fn=logdensity_fn, + random_generator_arg=init_key, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( + integrator=integrator, + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + std_mat=std_mat, + )( + 
rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=logdensity_fn, + ) + + target_acceptance_rate_of_order = {2: 0.65, 4: 0.8} + + target_acc_rate = target_acceptance_rate_of_order[ + integrator_order(coefficients) + ] + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + params_history, + final_da, + ) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=lambda key: jnp.ceil( + jax.random.uniform(key) * rescale(L / step_size) + ), + integrator=integrator, + std_mat=blackjax_mclmc_sampler_params.std_mat, + ) + + _, out, info = run_inference_algorithm( + rng_key=run_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda x: x.position, + expectation=lambda x: x.position, + progress_bar=False, + ) + + return out + @parameterized.parameters( itertools.product( regression_test_cases, [True, False], window_adaptation_filters @@ -260,6 +346,31 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + def test_adjusted_mclmc(self): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + def test_mclmc_preconditioning(self): class IllConditionedGaussian: """Gaussian distribution. 
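As an aside, the integration_steps_fn pattern in the adjusted MCLMC code above draws a random number of integration steps whose average is controlled through rescale. A small check of that behaviour, mirroring a commented-out snippet from the removed benchmark module; the concrete numbers are illustrative.

import jax
import jax.numpy as jnp
from blackjax.mcmc.adjusted_mclmc import rescale

avg_num_integration_steps = 10.0
integration_steps_fn = lambda k: jnp.ceil(
    jax.random.uniform(k) * rescale(avg_num_integration_steps)
)

keys = jax.random.split(jax.random.PRNGKey(0), 100_000)
draws = jax.vmap(integration_steps_fn)(keys)
# rescale is chosen so that the ceiling of the uniform draw averages out
# to roughly avg_num_integration_steps
print(draws.mean())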
Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2.""" From 0a11a0f62f7412d7fd9ccc6071901e866cc8ed00 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 May 2024 17:46:39 +0200 Subject: [PATCH 43/71] ADD ADJUSTED MCLMC TEST --- blackjax/util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/blackjax/util.py b/blackjax/util.py index 37133a4b5..070ca8687 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -251,7 +251,6 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): new streaming average """ - # x, _ = ravel_pytree(x) expectation = O(x) flat_expectation, unravel_fn = ravel_pytree(expectation) total, average = streaming_avg From 178b452b77548b349140339dfb88af0fd80b380b Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 May 2024 17:49:15 +0200 Subject: [PATCH 44/71] RENAME O --- blackjax/util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 070ca8687..600a7a961 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -209,7 +209,7 @@ def one_step(average_and_state, xs, return_state): _, rng_key = xs average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - average = streaming_average(expectation, state, average) + average = streaming_average(expectation(state), average) if return_state: return (average, state), (transform(state), info) else: @@ -232,7 +232,7 @@ def one_step(average_and_state, xs, return_state): return transform(final_state), state_history, info_history -def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): +def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0): """Compute the streaming average of a function O(x) using a weight. 
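A self-contained sketch of the weighted running average this helper maintains, written against the new call signature; the exact update rule is not shown in this diff, so the one below is an assumption kept deliberately simple.

import jax.numpy as jnp

def weighted_streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
    # streaming_avg = (total weight so far, running average); zero_prevention keeps
    # fully masked steps (weight = 0) from dividing by zero
    total, average = streaming_avg
    average = (total * average + weight * expectation) / (total + weight + zero_prevention)
    return total + weight, average

streaming_avg = (0.0, jnp.zeros(2))
for x in (1.0, 2.0, 3.0):
    streaming_avg = weighted_streaming_average(jnp.array([x, x**2]), streaming_avg)
print(streaming_avg)  # running averages of x and x^2: (2.0, ~4.67)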
Parameters: ---------- @@ -251,7 +251,6 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): new streaming average """ - expectation = O(x) flat_expectation, unravel_fn = ravel_pytree(expectation) total, average = streaming_avg flat_average, _ = ravel_pytree(average) From a26d4a002b85c75a4df353e04b27282cc9fa6dbc Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 May 2024 17:53:22 +0200 Subject: [PATCH 45/71] UPDATE STREAMING AVG --- blackjax/adaptation/mclmc_adaptation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index dc33eb21c..27321321a 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -177,10 +177,10 @@ def step(iteration_state, weight_and_key): state, params, adaptive_state, rng_key ) + x = ravel_pytree(state.position)[0] # update the running average of x, x^2 streaming_avg = streaming_average( - O=lambda x: jnp.array([x, jnp.square(x)]), - x=ravel_pytree(state.position)[0], + expectation=jnp.array([x, jnp.square(x)]), streaming_avg=streaming_avg, weight=(1 - mask) * success * params.step_size, zero_prevention=mask, From 9dd740f9e36c995be8f56a3229ae0ca0dddb5dbb Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 May 2024 17:55:12 +0200 Subject: [PATCH 46/71] UPDATE STREAMING AVG --- blackjax/adaptation/mclmc_adaptation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 5f22b4026..59d17c300 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -460,10 +460,10 @@ def step(iteration_state, weight_and_key): adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) # step_size = 1e-3 + x = ravel_pytree(state.position)[0] # update the running average of x, x^2 streaming_avg = streaming_average( - O=lambda x: jnp.array([x, jnp.square(x)]), - x=ravel_pytree(state.position)[0], + expectation=jnp.array([x, jnp.square(x)]), streaming_avg=streaming_avg, weight=(1 - mask) * success * step_size, zero_prevention=mask, From 4d03b8936d488f4da41e4ebd5ffa6d0792ccebde Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 May 2024 18:27:59 +0200 Subject: [PATCH 47/71] FIX MERGE --- blackjax/adaptation/window_adaptation.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index cacd0b4a6..63c54bad0 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import blackjax.mcmc as mcmc -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.adaptation.mass_matrix import ( MassMatrixAdaptationState, mass_matrix_adaptation, @@ -249,11 +249,12 @@ def window_adaptation( initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, progress_bar: bool = False, + adaptation_info_fn: Callable = return_all_adapt_info, integrator=mcmc.integrators.velocity_verlet, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of - algorithms in the HMC fmaily. + algorithms in the HMC fmaily. 
See Blackjax.hmc_family Algorithms in the HMC family on a euclidean manifold depend on the value of at least two parameters: the step size, related to the trajectory @@ -280,6 +281,11 @@ def window_adaptation( The acceptance rate that we target during step size adaptation. progress_bar Whether we should display a progress bar. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. **extra_parameters The extra parameters to pass to the algorithm, e.g. the number of integration steps for HMC. @@ -318,7 +324,7 @@ def one_step(carry, xs): return ( (new_state, new_adaptation_state), - AdaptationInfo(new_state, info, new_adaptation_state), + adaptation_info_fn(new_state, info, new_adaptation_state), ) def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): From 6bacb6c03676d366498a7347b000cd6ad13992b1 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 24 May 2024 15:30:46 +0200 Subject: [PATCH 48/71] UPDATE PR --- blackjax/adaptation/mclmc_adaptation.py | 81 +++++++++++++++---------- blackjax/mcmc/integrators.py | 11 ++-- blackjax/mcmc/mclmc.py | 8 +-- blackjax/util.py | 6 +- tests/mcmc/test_integrators.py | 6 +- tests/mcmc/test_sampling.py | 18 +++--- 6 files changed, 76 insertions(+), 54 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 27321321a..73fa6a327 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -20,7 +20,7 @@ from jax.flatten_util import ravel_pytree from blackjax.diagnostics import effective_sample_size -from blackjax.util import pytree_size, streaming_average +from blackjax.util import pytree_size, streaming_average_update class MCLMCAdaptationState(NamedTuple): @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - std_mat + sqrt_diag_cov_mat A matrix used for preconditioning. """ L: float step_size: float - std_mat: float + sqrt_diag_cov_mat: float def mclmc_find_L_and_step_size( @@ -81,10 +81,30 @@ def mclmc_find_L_and_step_size( Returns ------- A tuple containing the final state of the MCMC algorithm and the final hyperparameters. + + Example + ------- + .. 
code:: + kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=integrator, + std_mat=std_mat, + ) + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + diagonal_preconditioning=preconditioning, + ) """ dim = pytree_size(state.position) params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, std_mat=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov_mat=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -101,7 +121,7 @@ def mclmc_find_L_and_step_size( if frac_tune3 != 0: state, params = make_adaptation_L( - mclmc_kernel(params.std_mat), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.sqrt_diag_cov_mat), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) return state, params @@ -128,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): time, x_average, step_size_max = adaptive_state # dynamics - next_state, info = kernel(params.std_mat)( + next_state, info = kernel(params.sqrt_diag_cov_mat)( rng_key=rng_key, state=previous_state, L=params.L, @@ -179,7 +199,7 @@ def step(iteration_state, weight_and_key): x = ravel_pytree(state.position)[0] # update the running average of x, x^2 - streaming_avg = streaming_average( + streaming_avg = streaming_average_update( expectation=jnp.array([x, jnp.square(x)]), streaming_avg=streaming_avg, weight=(1 - mask) * success * params.step_size, @@ -188,6 +208,17 @@ def step(iteration_state, weight_and_key): return (state, params, adaptive_state, streaming_avg), None + run_steps = lambda xs, state, params: jax.lax.scan( + step, + init=( + state, + params, + (0.0, 0.0, jnp.inf), + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=xs, + )[0] + def L_step_size_adaptation(state, params, num_steps, rng_key): num_steps1, num_steps2 = ( int(num_steps * frac_tune1) + 1, @@ -205,45 +236,31 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) # run the steps - state, params, _, (_, average) = jax.lax.scan( - step, - init=( - state, - params, - (0.0, 0.0, jnp.inf), - (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), - ), - xs=(mask, L_step_size_adaptation_keys), - )[0] + state, params, _, (_, average) = run_steps( + xs=(mask, L_step_size_adaptation_keys), state=state, params=params + ) L = params.L # determine L - std_mat = params.std_mat + sqrt_diag_cov_mat = params.sqrt_diag_cov_mat if num_steps2 != 0.0: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - std_mat = jnp.sqrt(variances) - params = params._replace(std_mat=std_mat) + sqrt_diag_cov_mat = jnp.sqrt(variances) + params = params._replace(sqrt_diag_cov_mat=sqrt_diag_cov_mat) L = jnp.sqrt(dim) # readjust the stepsize steps = num_steps2 // 3 # we do some small number of steps keys = jax.random.split(final_key, steps) - state, params, _, (_, average) = jax.lax.scan( - step, - init=( - state, - params, - (0.0, 0.0, jnp.inf), - (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), - ), - xs=(jnp.ones(steps), keys), - )[0] - - return state, MCLMCAdaptationState(L, params.step_size, std_mat) + state, params, _, (_, average) = run_steps( + xs=(jnp.ones(steps), keys), state=state, params=params + ) + + return state, 
MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov_mat) return L_step_size_adaptation diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index ed11fb1a0..2dce5671e 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -294,7 +294,7 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(std_mat): +def esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0): def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -313,7 +313,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * std_mat + flatten_grads = flatten_grads * sqrt_diag_cov_mat flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -325,7 +325,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * std_mat) + gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov_mat) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -357,11 +357,12 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, std_mat: ArrayTree = 1.0, *args, **kwargs + logdensity_fn: Callable, *args, **kwargs ) -> GeneralIntegrator: + sqrt_diag_cov_mat = kwargs.get("sqrt_diag_cov_mat", 1.0) position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(std_mat), + esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 62a6da735..d841f64e3 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, std_mat, integrator): +def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator): """Build a HMC kernel. Parameters @@ -80,7 +80,7 @@ def build_kernel(logdensity_fn, std_mat, integrator): """ - step = with_isokinetic_maruyama(integrator(logdensity_fn, std_mat)) + step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov_mat)) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float @@ -105,7 +105,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - std_mat=1.0, + sqrt_diag_cov_mat=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -153,7 +153,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. 
""" - kernel = build_kernel(logdensity_fn, std_mat, integrator) + kernel = build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/blackjax/util.py b/blackjax/util.py index 02c27e51c..71d7345fb 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -209,7 +209,7 @@ def one_step(average_and_state, xs, return_state): _, rng_key = xs average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - average = streaming_average(expectation(transform(state)), average) + average = streaming_average_update(expectation(transform(state)), average) if return_state: return (average, state), (transform(state), info) else: @@ -232,7 +232,9 @@ def one_step(average_and_state, xs, return_state): return transform(final_state), state_history, info_history -def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0): +def streaming_average_update( + expectation, streaming_avg, weight=1.0, zero_prevention=0.0 +): """Compute the streaming average of a function O(x) using a weight. Parameters: ---------- diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 68c12c499..3439f52e6 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -234,7 +234,9 @@ def test_esh_momentum_update(self, dims): ) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta))) # Efficient implementation - update_stable = self.variant(esh_dynamics_momentum_update_one_step(std_mat=1.0)) + update_stable = self.variant( + esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0) + ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -258,7 +260,7 @@ def test_isokinetic_leapfrog(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(std_mat=1.0) + op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 604316e48..fb272ae7a 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -111,10 +111,10 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - std_mat=std_mat, + sqrt_diag_cov_mat=sqrt_diag_cov_mat, ) ( @@ -132,7 +132,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - std_mat=blackjax_mclmc_sampler_params.std_mat, + sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat, ) _, samples, _ = run_inference_algorithm( @@ -300,7 +300,7 @@ def __init__(self, d, condition_number): integrator = isokinetic_mclachlan - def get_std_mat(): + def get_sqrt_diag_cov_mat(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -311,10 +311,10 @@ def get_std_mat(): rng_key=init_key, ) - kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov_mat: 
blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - std_mat=std_mat, + sqrt_diag_cov_mat=sqrt_diag_cov_mat, ) ( @@ -328,13 +328,13 @@ def get_std_mat(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.std_mat + return blackjax_mclmc_sampler_params.sqrt_diag_cov_mat - std_mat = get_std_mat() + sqrt_diag_cov_mat = get_sqrt_diag_cov_mat() assert ( jnp.abs( jnp.dot( - (std_mat**2) / jnp.linalg.norm(std_mat**2), + (sqrt_diag_cov_mat**2) / jnp.linalg.norm(sqrt_diag_cov_mat**2), eigs / jnp.linalg.norm(eigs), ) - 1 From cacb7924c61f607327e96618d3814ac403c952b3 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 24 May 2024 17:51:47 +0200 Subject: [PATCH 49/71] RENAME STD_MAT --- blackjax/adaptation/mclmc_adaptation.py | 12 ++++++------ blackjax/mcmc/adjusted_mclmc.py | 10 +++++----- tests/mcmc/test_sampling.py | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 9d2c08af6..5cce91305 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -89,10 +89,10 @@ def mclmc_find_L_and_step_size( Example ------- .. code:: - kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov_mat : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, - std_mat=std_mat, + sqrt_diag_cov_mat=sqrt_diag_cov_mat, ) ( @@ -354,7 +354,7 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, std_mat=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov_mat=jnp.ones((dim,)) ) else: params = params @@ -435,7 +435,7 @@ def step(iteration_state, weight_and_key): state=previous_state, avg_num_integration_steps=avg_num_integration_steps, step_size=params.step_size, - std_mat=params.std_mat, + sqrt_diag_cov_mat=params.sqrt_diag_cov_mat, ) # jax.debug.print("step size during {x}",x=(params.step_size, params.L)) @@ -596,7 +596,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): if diagonal_preconditioning: # diagonal preconditioning - params = params._replace(std_mat=jnp.sqrt(variances)) + params = params._replace(sqrt_diag_cov_mat=jnp.sqrt(variances)) # state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] # dyn, _, hyp, adap, kalman_state = state @@ -643,7 +643,7 @@ def step(state, key): state=state, step_size=params.step_size, avg_num_integration_steps=params.L / params.step_size, - std_mat=params.std_mat, + sqrt_diag_cov_mat=params.sqrt_diag_cov_mat, ) return next_state, next_state.position diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 6a6722f2c..9c48d82d5 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -33,13 +33,13 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) -# TODO: no default for std_mat +# TODO: no default for sqrt_diag_cov_mat def build_kernel( integration_steps_fn, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - std_mat=1.0, + sqrt_diag_cov_mat=1.0, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. 
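In the tests added later in this series, the random number of integration steps for this kernel is drawn by taking the ceiling of a rescaled uniform variate around a target average. A minimal sketch of that pattern, assuming `rescale` is importable from `blackjax.mcmc.adjusted_mclmc` as in those tests and using an illustrative average of 20 steps:

import jax
import jax.numpy as jnp

from blackjax.mcmc.adjusted_mclmc import rescale

# Illustrative average; in the adaptation code this is L / step_size.
avg_num_integration_steps = 20.0

# Ceiling of a rescaled uniform draw; `rescale` calibrates the draw so the
# average step count comes out near avg_num_integration_steps.
integration_steps_fn = lambda k: jnp.ceil(
    jax.random.uniform(k) * rescale(avg_num_integration_steps)
)

num_steps_drawn = integration_steps_fn(jax.random.PRNGKey(0))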
@@ -77,7 +77,7 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn, std_mat) + integrator(logdensity_fn, sqrt_diag_cov_mat) ), step_size=step_size, L_proposal=L_proposal * num_integration_steps, @@ -107,7 +107,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal: float = jnp.inf, - std_mat=1.0, + sqrt_diag_cov_mat=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -144,7 +144,7 @@ def as_top_level_api( integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, - std_mat=std_mat, + sqrt_diag_cov_mat=sqrt_diag_cov_mat, divergence_threshold=divergence_threshold, ) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index ad69614f5..333e6947b 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -170,12 +170,12 @@ def run_adjusted_mclmc( random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, std_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) ), - std_mat=std_mat, + sqrt_diag_cov_mat=sqrt_diag_cov_mat, )( rng_key=rng_key, state=state, @@ -216,7 +216,7 @@ def run_adjusted_mclmc( jax.random.uniform(key) * rescale(L / step_size) ), integrator=integrator, - std_mat=blackjax_mclmc_sampler_params.std_mat, + sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat, ) _, out, info = run_inference_algorithm( From 3656bb9e6502751b0d27cd1ffe8ad482237ef99e Mon Sep 17 00:00:00 2001 From: = Date: Fri, 24 May 2024 18:04:50 +0200 Subject: [PATCH 50/71] RENAME STD_MAT --- blackjax/adaptation/mclmc_adaptation.py | 32 ++++++++++++------------- blackjax/mcmc/adjusted_mclmc.py | 10 ++++---- blackjax/mcmc/integrators.py | 10 ++++---- blackjax/mcmc/mclmc.py | 8 +++---- tests/mcmc/test_integrators.py | 4 ++-- tests/mcmc/test_sampling.py | 24 +++++++++---------- 6 files changed, 44 insertions(+), 44 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 5cce91305..422ae1e67 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -34,13 +34,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - sqrt_diag_cov_mat + sqrt_diag_cov A matrix used for preconditioning. """ L: float step_size: float - sqrt_diag_cov_mat: float + sqrt_diag_cov: float def mclmc_find_L_and_step_size( @@ -89,10 +89,10 @@ def mclmc_find_L_and_step_size( Example ------- .. 
code:: - kernel = lambda sqrt_diag_cov_mat : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, - sqrt_diag_cov_mat=sqrt_diag_cov_mat, + sqrt_diag_cov=sqrt_diag_cov, ) ( @@ -108,7 +108,7 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov_mat=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -125,7 +125,7 @@ def mclmc_find_L_and_step_size( if frac_tune3 != 0: state, params = make_adaptation_L( - mclmc_kernel(params.sqrt_diag_cov_mat), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) return state, params @@ -152,7 +152,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): time, x_average, step_size_max = adaptive_state # dynamics - next_state, info = kernel(params.sqrt_diag_cov_mat)( + next_state, info = kernel(params.sqrt_diag_cov)( rng_key=rng_key, state=previous_state, L=params.L, @@ -246,15 +246,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L - sqrt_diag_cov_mat = params.sqrt_diag_cov_mat + sqrt_diag_cov = params.sqrt_diag_cov if num_steps2 != 0.0: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - sqrt_diag_cov_mat = jnp.sqrt(variances) - params = params._replace(sqrt_diag_cov_mat=sqrt_diag_cov_mat) + sqrt_diag_cov = jnp.sqrt(variances) + params = params._replace(sqrt_diag_cov=sqrt_diag_cov) L = jnp.sqrt(dim) # readjust the stepsize @@ -264,7 +264,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) - return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov_mat) + return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) return L_step_size_adaptation @@ -354,7 +354,7 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov_mat=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) ) else: params = params @@ -435,7 +435,7 @@ def step(iteration_state, weight_and_key): state=previous_state, avg_num_integration_steps=avg_num_integration_steps, step_size=params.step_size, - sqrt_diag_cov_mat=params.sqrt_diag_cov_mat, + sqrt_diag_cov=params.sqrt_diag_cov, ) # jax.debug.print("step size during {x}",x=(params.step_size, params.L)) @@ -479,7 +479,7 @@ def step(iteration_state, weight_and_key): x = ravel_pytree(state.position)[0] # update the running average of x, x^2 - streaming_avg = streaming_average( + streaming_avg = streaming_average_update( expectation=jnp.array([x, jnp.square(x)]), streaming_avg=streaming_avg, weight=(1 - mask) * success * step_size, @@ -596,7 +596,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): if diagonal_preconditioning: # diagonal preconditioning - params = params._replace(sqrt_diag_cov_mat=jnp.sqrt(variances)) + params = params._replace(sqrt_diag_cov=jnp.sqrt(variances)) # state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] # dyn, _, hyp, adap, kalman_state = state @@ -643,7 +643,7 @@ def step(state, key): 
state=state, step_size=params.step_size, avg_num_integration_steps=params.L / params.step_size, - sqrt_diag_cov_mat=params.sqrt_diag_cov_mat, + sqrt_diag_cov=params.sqrt_diag_cov, ) return next_state, next_state.position diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 9c48d82d5..97717afc3 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -33,13 +33,13 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) -# TODO: no default for sqrt_diag_cov_mat +# TODO: no default for sqrt_diag_cov def build_kernel( integration_steps_fn, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - sqrt_diag_cov_mat=1.0, + sqrt_diag_cov=1.0, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -77,7 +77,7 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn, sqrt_diag_cov_mat) + integrator(logdensity_fn, sqrt_diag_cov) ), step_size=step_size, L_proposal=L_proposal * num_integration_steps, @@ -107,7 +107,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal: float = jnp.inf, - sqrt_diag_cov_mat=1.0, + sqrt_diag_cov=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -144,7 +144,7 @@ def as_top_level_api( integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, - sqrt_diag_cov_mat=sqrt_diag_cov_mat, + sqrt_diag_cov=sqrt_diag_cov, divergence_threshold=divergence_threshold, ) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index fca3991bb..4ae3e92f2 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -316,7 +316,7 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0): +def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -335,7 +335,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * sqrt_diag_cov_mat + flatten_grads = flatten_grads * sqrt_diag_cov flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -347,7 +347,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov_mat) + gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -381,10 +381,10 @@ def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( logdensity_fn: Callable, *args, **kwargs ) -> GeneralIntegrator: - sqrt_diag_cov_mat = kwargs.get("sqrt_diag_cov_mat", 1.0) + sqrt_diag_cov = kwargs.get("sqrt_diag_cov", 1.0) position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat), + 
esh_dynamics_momentum_update_one_step(sqrt_diag_cov), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index d841f64e3..27b5c2e9c 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator): +def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """Build a HMC kernel. Parameters @@ -80,7 +80,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator): """ - step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov_mat)) + step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov)) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float @@ -105,7 +105,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - sqrt_diag_cov_mat=1.0, + sqrt_diag_cov=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -153,7 +153,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator) + kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index c55eff6ff..c38009e5e 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -238,7 +238,7 @@ def test_esh_momentum_update(self, dims): # Efficient implementation update_stable = self.variant( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0) + esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -263,7 +263,7 @@ def test_isokinetic_velocity_verlet(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0) + op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 333e6947b..aa439d21b 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -117,10 +117,10 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - sqrt_diag_cov_mat=sqrt_diag_cov_mat, + sqrt_diag_cov=sqrt_diag_cov, ) ( @@ -138,7 +138,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, ) _, samples, _ = run_inference_algorithm( @@ -170,12 +170,12 @@ def run_adjusted_mclmc( random_generator_arg=init_key, ) - kernel = lambda rng_key, state, 
avg_num_integration_steps, step_size, sqrt_diag_cov_mat: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) ), - sqrt_diag_cov_mat=sqrt_diag_cov_mat, + sqrt_diag_cov=sqrt_diag_cov, )( rng_key=rng_key, state=state, @@ -216,7 +216,7 @@ def run_adjusted_mclmc( jax.random.uniform(key) * rescale(L / step_size) ), integrator=integrator, - sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, ) _, out, info = run_inference_algorithm( @@ -411,7 +411,7 @@ def __init__(self, d, condition_number): integrator = isokinetic_mclachlan - def get_sqrt_diag_cov_mat(): + def get_sqrt_diag_cov(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -422,10 +422,10 @@ def get_sqrt_diag_cov_mat(): rng_key=init_key, ) - kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - sqrt_diag_cov_mat=sqrt_diag_cov_mat, + sqrt_diag_cov=sqrt_diag_cov, ) ( @@ -439,13 +439,13 @@ def get_sqrt_diag_cov_mat(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.sqrt_diag_cov_mat + return blackjax_mclmc_sampler_params.sqrt_diag_cov - sqrt_diag_cov_mat = get_sqrt_diag_cov_mat() + sqrt_diag_cov = get_sqrt_diag_cov() assert ( jnp.abs( jnp.dot( - (sqrt_diag_cov_mat**2) / jnp.linalg.norm(sqrt_diag_cov_mat**2), + (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), eigs / jnp.linalg.norm(eigs), ) - 1 From 9c2fea78f43d81d1b4cf594dbe348899245fffbc Mon Sep 17 00:00:00 2001 From: = Date: Fri, 24 May 2024 18:05:03 +0200 Subject: [PATCH 51/71] RENAME STD_MAT --- blackjax/adaptation/mclmc_adaptation.py | 18 +++++++++--------- blackjax/mcmc/integrators.py | 10 +++++----- blackjax/mcmc/mclmc.py | 8 ++++---- tests/mcmc/test_integrators.py | 4 ++-- tests/mcmc/test_sampling.py | 18 +++++++++--------- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 73fa6a327..b1b012c70 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - sqrt_diag_cov_mat + sqrt_diag_cov A matrix used for preconditioning. 
""" L: float step_size: float - sqrt_diag_cov_mat: float + sqrt_diag_cov: float def mclmc_find_L_and_step_size( @@ -104,7 +104,7 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov_mat=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -121,7 +121,7 @@ def mclmc_find_L_and_step_size( if frac_tune3 != 0: state, params = make_adaptation_L( - mclmc_kernel(params.sqrt_diag_cov_mat), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) return state, params @@ -148,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): time, x_average, step_size_max = adaptive_state # dynamics - next_state, info = kernel(params.sqrt_diag_cov_mat)( + next_state, info = kernel(params.sqrt_diag_cov)( rng_key=rng_key, state=previous_state, L=params.L, @@ -242,15 +242,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L - sqrt_diag_cov_mat = params.sqrt_diag_cov_mat + sqrt_diag_cov = params.sqrt_diag_cov if num_steps2 != 0.0: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - sqrt_diag_cov_mat = jnp.sqrt(variances) - params = params._replace(sqrt_diag_cov_mat=sqrt_diag_cov_mat) + sqrt_diag_cov = jnp.sqrt(variances) + params = params._replace(sqrt_diag_cov=sqrt_diag_cov) L = jnp.sqrt(dim) # readjust the stepsize @@ -260,7 +260,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) - return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov_mat) + return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) return L_step_size_adaptation diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 2dce5671e..1cc698e8f 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -294,7 +294,7 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0): +def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -313,7 +313,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * sqrt_diag_cov_mat + flatten_grads = flatten_grads * sqrt_diag_cov flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -325,7 +325,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov_mat) + gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -359,10 +359,10 @@ def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( logdensity_fn: Callable, *args, **kwargs ) -> GeneralIntegrator: - sqrt_diag_cov_mat = kwargs.get("sqrt_diag_cov_mat", 1.0) + sqrt_diag_cov = kwargs.get("sqrt_diag_cov", 1.0) position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = 
generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat), + esh_dynamics_momentum_update_one_step(sqrt_diag_cov), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index d841f64e3..27b5c2e9c 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator): +def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """Build a HMC kernel. Parameters @@ -80,7 +80,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator): """ - step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov_mat)) + step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov)) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float @@ -105,7 +105,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - sqrt_diag_cov_mat=1.0, + sqrt_diag_cov=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -153,7 +153,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator) + kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 3439f52e6..937339aaf 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -235,7 +235,7 @@ def test_esh_momentum_update(self, dims): # Efficient implementation update_stable = self.variant( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0) + esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -260,7 +260,7 @@ def test_isokinetic_leapfrog(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0) + op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index fb272ae7a..63eada8ac 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -111,10 +111,10 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - sqrt_diag_cov_mat=sqrt_diag_cov_mat, + sqrt_diag_cov=sqrt_diag_cov, ) ( @@ -132,7 +132,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, ) _, samples, _ = run_inference_algorithm( @@ -300,7 +300,7 @@ def __init__(self, d, 
condition_number): integrator = isokinetic_mclachlan - def get_sqrt_diag_cov_mat(): + def get_sqrt_diag_cov(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -311,10 +311,10 @@ def get_sqrt_diag_cov_mat(): rng_key=init_key, ) - kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - sqrt_diag_cov_mat=sqrt_diag_cov_mat, + sqrt_diag_cov=sqrt_diag_cov, ) ( @@ -328,13 +328,13 @@ def get_sqrt_diag_cov_mat(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.sqrt_diag_cov_mat + return blackjax_mclmc_sampler_params.sqrt_diag_cov - sqrt_diag_cov_mat = get_sqrt_diag_cov_mat() + sqrt_diag_cov = get_sqrt_diag_cov() assert ( jnp.abs( jnp.dot( - (sqrt_diag_cov_mat**2) / jnp.linalg.norm(sqrt_diag_cov_mat**2), + (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), eigs / jnp.linalg.norm(eigs), ) - 1 From 88b06cb8bec5d4d4d9c0ff5ce173635023792c52 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 26 May 2024 14:16:56 +0200 Subject: [PATCH 52/71] MERGE MAIN --- blackjax/mcmc/integrators.py | 45 ------------------------------------ 1 file changed, 45 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 08ea04343..ec3ff0942 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -37,9 +37,6 @@ "isokinetic_omelyan", "isokinetic_yoshida", "implicit_midpoint", - "calls_per_integrator_step", - "name_integrator", - "integrator_order", ] @@ -402,48 +399,6 @@ def isokinetic_integrator( isokinetic_omelyan = generate_isokinetic_integrator(omelyan_coefficients) -def calls_per_integrator_step(c): - if c == velocity_verlet_coefficients: - return 1 - if c == mclachlan_coefficients: - return 2 - if c == yoshida_coefficients: - return 3 - if c == omelyan_coefficients: - return 5 - - else: - raise Exception("No such integrator exists in blackjax") - - -def name_integrator(c): - if c == velocity_verlet_coefficients: - return "velocity_verlet" - if c == mclachlan_coefficients: - return "mclachlan" - if c == yoshida_coefficients: - return "yoshida" - if c == omelyan_coefficients: - return "omelyan" - - else: - raise Exception("No such integrator exists in blackjax") - - -def integrator_order(c): - if c == velocity_verlet_coefficients: - return 2 - if c == mclachlan_coefficients: - return 2 - if c == yoshida_coefficients: - return 4 - if c == omelyan_coefficients: - return 4 - - else: - raise Exception("No such integrator exists in blackjax") - - def partially_refresh_momentum(momentum, rng_key, step_size, L): """Adds a small noise to momentum and normalizes. 
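The hunk above removes calls_per_integrator_step, name_integrator, and integrator_order from blackjax.mcmc.integrators. Downstream code that still needs this metadata can keep a small local table with the same values; the following is a sketch maintained outside blackjax, not part of the patch:

# Metadata for blackjax's integrators, mirroring the helpers deleted above:
# gradient calls per integrator step and order of the integrator.
INTEGRATOR_METADATA = {
    "velocity_verlet": {"calls_per_step": 1, "order": 2},
    "mclachlan": {"calls_per_step": 2, "order": 2},
    "yoshida": {"calls_per_step": 3, "order": 4},
    "omelyan": {"calls_per_step": 5, "order": 4},
}


def calls_per_integrator_step(name: str) -> int:
    return INTEGRATOR_METADATA[name]["calls_per_step"]


def integrator_order(name: str) -> int:
    return INTEGRATOR_METADATA[name]["order"]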
From abe707c0511a33cadd83ba1c129ff7cad9679802 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 26 May 2024 23:08:24 +0200 Subject: [PATCH 53/71] REMOVE COEFFICIENT EXPORTS --- blackjax/mcmc/integrators.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index ec3ff0942..58e7ed810 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -23,10 +23,6 @@ from blackjax.types import ArrayTree __all__ = [ - "velocity_verlet_coefficients", - "mclachlan_coefficients", - "yoshida_coefficients", - "omelyan_coefficients", "mclachlan", "omelyan", "velocity_verlet", From 0629fa893e2991230f1e8481de983932eae9dcc5 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 27 May 2024 18:33:45 +0200 Subject: [PATCH 54/71] REMOVE COEFFICIENT EXPORTS --- tests/mcmc/test_sampling.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index aa439d21b..3bd80c92d 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -15,12 +15,7 @@ import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info from blackjax.mcmc.adjusted_mclmc import rescale -from blackjax.mcmc.integrators import ( - generate_isokinetic_integrator, - integrator_order, - isokinetic_mclachlan, - mclachlan_coefficients, -) +from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -159,8 +154,7 @@ def run_adjusted_mclmc( key, diagonal_preconditioning=False, ): - coefficients = mclachlan_coefficients - integrator = generate_isokinetic_integrator(coefficients) + integrator = isokinetic_mclachlan init_key, tune_key, run_key = jax.random.split(key, 3) @@ -183,11 +177,7 @@ def run_adjusted_mclmc( logdensity_fn=logdensity_fn, ) - target_acceptance_rate_of_order = {2: 0.65, 4: 0.8} - - target_acc_rate = target_acceptance_rate_of_order[ - integrator_order(coefficients) - ] + target_acc_rate = 0.65 ( blackjax_state_after_tuning, @@ -225,7 +215,7 @@ def run_adjusted_mclmc( inference_algorithm=alg, num_steps=num_steps, transform=lambda x: x.position, - expectation=lambda x: x.position, + expectation=lambda x: x, progress_bar=False, ) From f4e80641be64f33f4794a757a0b2038130e5c65e Mon Sep 17 00:00:00 2001 From: = Date: Mon, 27 May 2024 18:48:07 +0200 Subject: [PATCH 55/71] RESOLVE MYPY ISSUE --- blackjax/adaptation/mclmc_adaptation.py | 69 ++----------------------- 1 file changed, 3 insertions(+), 66 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 422ae1e67..3f0994307 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -358,7 +358,6 @@ def adjusted_mclmc_find_L_and_step_size( ) else: params = params - # jax.debug.print("initial params {x}", x=params) part1_key, part2_key = jax.random.split(rng_key, 2) ( @@ -438,8 +437,6 @@ def step(iteration_state, weight_and_key): sqrt_diag_cov=params.sqrt_diag_cov, ) - # jax.debug.print("step size during {x}",x=(params.step_size, params.L)) - # step updating success, state, step_size_max, energy_change = handle_nans( previous_state, @@ -449,9 +446,6 @@ def step(iteration_state, weight_and_key): info.energy, ) - # jax.debug.print("info acc rate {x}", x=(info.acceptance_rate,)) - # jax.debug.print("state {x}", x=(state.position,)) - log_step_size, log_step_size_avg, step, avg_error, mu = update_da( adaptive_state, 
info.acceptance_rate ) @@ -465,12 +459,6 @@ def step(iteration_state, weight_and_key): mask * mu + (1 - mask) * adaptive_state.mu, ) - # jax.debug.print("{x} step_size before",x=(adaptive_state.log_step_size, info.acceptance_rate,)) - # adaptive_state = update(adaptive_state, info.acceptance_rate) - # jax.debug.print("{x} step_size after",x=(adaptive_state.log_step_size,)) - - # step_size = jax.lax.clamp(1e-3, jnp.exp(adaptive_state.log_step_size), 1e0) - # step_size = jax.lax.clamp(1e-5, jnp.exp(adaptive_state.log_step_size), step_size_max) step_size = jax.lax.clamp( 1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1 ) @@ -489,20 +477,14 @@ def step(iteration_state, weight_and_key): if fix_L: params = params._replace( step_size=mask * step_size + (1 - mask) * params.step_size, - # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L - # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L ) else: params = params._replace( step_size=mask * step_size + (1 - mask) * params.step_size, L=mask * (params.L * (step_size / params.step_size)) - + (1 - mask) * params.L - # L=mask * ((params.L * (step_size / params.step_size))) + (1-mask)*params.L + + (1 - mask) * params.L, ) - # params = params._replace(step_size=step_size, - # L=(params.L/params.step_size * step_size) - # ) return (state, params, (adaptive_state, step_size_max), streaming_avg), ( info, @@ -518,7 +500,6 @@ def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da state, params, (initial_da(params.step_size), jnp.inf), # step size max - # (init(params.step_size), params.L/4), (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), ), xs=(mask, keys), @@ -529,8 +510,6 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): num_steps * frac_tune2 ) - # num_steps2=0 - rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) L_step_size_adaptation_keys_pass1 = jax.random.split( rng_key_pass1, num_steps1 + num_steps2 @@ -542,8 +521,6 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): initial_da, update_da, final_da = dual_averaging_adaptation(target=target) - # jax.debug.print("{x} initial num steps",x=(params.L/params.step_size)) - ( (state, params, (dual_avg_state, step_size_max), (_, average)), (info, params_history), @@ -556,54 +533,27 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): initial_da=initial_da, update_da=update_da, ) - # jax.debug.print("final da {x}", x=final_da(dual_avg_state)) - # params = params._replace(L=params.L * (final_da(dual_avg_state)/params.step_size)) - # params = params._replace(step_size=final_da(dual_avg_state)) - - # jax.debug.print("{x} new num steps",x=(params.L/params.step_size)) - - # jax.debug.print("{x} mean acceptance rate",x=((jnp.mean(info.acceptance_rate)))) - - # jax.debug.print("{x} params after a round of tuning",x=(params)) - # jax.debug.print("{x} step size max",x=(step_size_max)) - # jax.debug.print("{x} final",x=(final(dual_avg_state))) - # jax.debug.print("{x} params",x=(params)) - - # raise Exception # determine L if num_steps2 != 0.0: - # if False: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) - # jax.debug.print("{x} frac tune 2 guess",x=(jnp.sqrt(jnp.sum(variances)))) - # jax.debug.print("{x} frac tune 2 before",x=(params.L)) change = jax.lax.clamp( Lratio_lowerbound, jnp.sqrt(jnp.sum(variances)) / params.L, Lratio_upperbound, ) - # change = jnp.sqrt(jnp.sum(variances))/params.L - # jax.debug.print("{x} L 
ratio, old val, new val",x=(change, params.L, params.L*change)) - # jax.debug.print("{x} variance",x=(jnp.sqrt(jnp.sum(variances)))) params = params._replace( L=params.L * change, step_size=params.step_size * change ) - # params = params._replace(L=16.) - # params = params._replace(L=jnp.sqrt(jnp.sum(variances))) - # jax.debug.print("{x} params after a round of tuning",x=(params)) - if diagonal_preconditioning: - # diagonal preconditioning params = params._replace(sqrt_diag_cov=jnp.sqrt(variances)) # state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] # dyn, _, hyp, adap, kalman_state = state + # TODO ^ - # jax.debug.print("{x} params before second round",x=(params)) - # jax.debug.print("{x}",x=("L before", params.L)) - # jax.debug.print("{x}",x=("target", target)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( (state, params, (dual_avg_state, step_size_max), (_, average)), @@ -617,11 +567,6 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): update_da=update_da, initial_da=initial_da, ) - # params = params._replace(L=params.L * (final_da(dual_avg_state)/params.step_size)) - # params = params._replace(step_size=final_da(dual_avg_state)) - # jax.debug.print("{x} mean acceptance rate 2",x=(jnp.mean(info.acceptance_rate,))) - # jax.debug.print("{x}",x=("L after", params.L)) - # jax.debug.print("{x} params after a round of tuning",x=(params)) return state, params, params_history.step_size, final_da(dual_avg_state) @@ -635,8 +580,6 @@ def adaptation_L(state, params, num_steps, key): num_steps = int(num_steps * frac) adaptation_L_keys = jax.random.split(key, num_steps) - # jax.debug.print("tune 1\n\n {x}", x=(params.L, params.step_size)) - def step(state, key): next_state, _ = kernel( rng_key=key, @@ -661,14 +604,8 @@ def step(state, key): (Lfactor * params.step_size * jnp.mean(num_steps / ess)) / params.L, Lratio_upperbound, ) - # change = (Lfactor * params.step_size * jnp.mean(num_steps / ess))/params.L - # jax.debug.print("tune 3\n\n {x}", x=(params.L*change, change)) - return state, params._replace( - # L=Lfactor * params.step_size * jnp.mean(num_steps / ess) - L=params.L - * change - ) + return state, params._replace(L=params.L * change) return adaptation_L From 40b5b982ed9cb021f17f7bc55e3df75cdc124d6c Mon Sep 17 00:00:00 2001 From: = Date: Mon, 27 May 2024 18:50:02 +0200 Subject: [PATCH 56/71] RESOLVE MYPY ISSUE --- blackjax/adaptation/mclmc_adaptation.py | 4 ---- blackjax/mcmc/adjusted_mclmc.py | 1 - 2 files changed, 5 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 3f0994307..699dd42d9 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -550,10 +550,6 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): if diagonal_preconditioning: params = params._replace(sqrt_diag_cov=jnp.sqrt(variances)) - # state = jax.lax.scan(step, init= state, xs= jnp.ones(steps), length= steps)[0] - # dyn, _, hyp, adap, kalman_state = state - # TODO ^ - initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( (state, params, (dual_avg_state, step_size_max), (_, average)), diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 97717afc3..ab4a51d74 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -33,7 +33,6 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: return DynamicHMCState(position, 
logdensity, logdensity_grad, random_generator_arg) -# TODO: no default for sqrt_diag_cov def build_kernel( integration_steps_fn, integrator: Callable = integrators.isokinetic_mclachlan, From 0e0016b2534fc51e08491ea346f267ebe0afd113 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 30 May 2024 19:40:29 +0200 Subject: [PATCH 57/71] RETURN EXPECTATION HISTORY --- blackjax/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index b9b7a250c..8c373fb42 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -220,7 +220,7 @@ def one_step(average_and_state, xs, return_state): if return_state: return (average, state), (transform(state), info) else: - return (average, state), None + return (average, state), average[1] one_step = jax.jit(partial(one_step, return_state=return_state_history)) @@ -233,7 +233,7 @@ def one_step(average_and_state, xs, return_state): ) if not return_state_history: - return average, transform(final_state) + return average, transform(final_state), history else: state_history, info_history = history return transform(final_state), state_history, info_history From ca7935b604e07c74d2f39e4b03846ab37ec5823d Mon Sep 17 00:00:00 2001 From: = Date: Sun, 2 Jun 2024 15:42:13 -0400 Subject: [PATCH 58/71] FIX KWARG BUG --- blackjax/mcmc/mclmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 27b5c2e9c..9d19eadd9 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -80,7 +80,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """ - step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov)) + step = with_isokinetic_maruyama(integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov)) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float From 36adb401db0506d483ccc9d9fc460a2c3aa6583e Mon Sep 17 00:00:00 2001 From: = Date: Sun, 2 Jun 2024 15:45:44 -0400 Subject: [PATCH 59/71] FIX KWARG BUG --- blackjax/mcmc/mclmc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 9d19eadd9..e7a69849b 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -80,7 +80,9 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """ - step = with_isokinetic_maruyama(integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov)) + step = with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float From c9c514dec17df3ede4c969b99ac9a483bc561bde Mon Sep 17 00:00:00 2001 From: = Date: Sun, 2 Jun 2024 16:04:51 -0400 Subject: [PATCH 60/71] FIX KWARG BUG IN ADJUSTED MCLMC --- blackjax/mcmc/adjusted_mclmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index ab4a51d74..644af515d 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -76,7 +76,7 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn, sqrt_diag_cov) + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) ), step_size=step_size, L_proposal=L_proposal * num_integration_steps, From 93f0f466cd20d18cbd081fa2277744fe76b881e0 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 2 Jun 
2024 16:09:24 -0400 Subject: [PATCH 61/71] MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT --- blackjax/adaptation/window_adaptation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index dd3e7b282..63c54bad0 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp +import blackjax.mcmc as mcmc from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.adaptation.mass_matrix import ( MassMatrixAdaptationState, @@ -249,10 +250,11 @@ def window_adaptation( target_acceptance_rate: float = 0.80, progress_bar: bool = False, adaptation_info_fn: Callable = return_all_adapt_info, + integrator=mcmc.integrators.velocity_verlet, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of - algorithms in the HMC family. See Blackjax.hmc_family + algorithms in the HMC fmaily. See Blackjax.hmc_family Algorithms in the HMC family on a euclidean manifold depend on the value of at least two parameters: the step size, related to the trajectory @@ -294,7 +296,7 @@ def window_adaptation( """ - mcmc_kernel = algorithm.build_kernel() + mcmc_kernel = algorithm.build_kernel(integrator) adapt_init, adapt_step, adapt_final = base( is_mass_matrix_diagonal, From bb2b9e7746515593993f3c4a792a0759ed664c31 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 3 Jun 2024 15:44:08 -0400 Subject: [PATCH 62/71] L_proposal_factor --- blackjax/mcmc/adjusted_mclmc.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 644af515d..261783a76 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -66,7 +66,7 @@ def kernel( state: DynamicHMCState, logdensity_fn: Callable, step_size: float, - L_proposal: float = jnp.inf, + L_proposal_factor: float = jnp.inf, ) -> tuple[DynamicHMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" @@ -79,7 +79,7 @@ def kernel( integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) ), step_size=step_size, - L_proposal=L_proposal * num_integration_steps, + L_proposal_factor=L_proposal_factor * (num_integration_steps/step_size), num_integration_steps=num_integration_steps, divergence_threshold=divergence_threshold, )( @@ -105,7 +105,7 @@ def kernel( def as_top_level_api( logdensity_fn: Callable, step_size: float, - L_proposal: float = jnp.inf, + L_proposal_factor: float = jnp.inf, sqrt_diag_cov=1.0, *, divergence_threshold: int = 1000, @@ -156,7 +156,7 @@ def update_fn(rng_key: PRNGKey, state): state, logdensity_fn, step_size, - L_proposal, + L_proposal_factor, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] @@ -165,7 +165,7 @@ def update_fn(rng_key: PRNGKey, state): def adjusted_mclmc_proposal( integrator: Callable, step_size: Union[float, ArrayLikeTree], - L_proposal: float, + L_proposal_factor: float, num_integration_steps: int = 1, divergence_threshold: float = 1000, *, @@ -202,7 +202,7 @@ def step(i, vars): state, kinetic_energy, rng_key = vars rng_key, next_rng_key = jax.random.split(rng_key) next_state, next_kinetic_energy = integrator( - state, step_size, L_proposal, rng_key + state, step_size, L_proposal_factor, rng_key ) return next_state, kinetic_energy + next_kinetic_energy, next_rng_key From 
da83d568a9d12174a932f0c82ed60fdcfd268d26 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 4 Jun 2024 06:55:20 -0400 Subject: [PATCH 63/71] SPLIT TUNING FOR AMCLMC INTO SEPARATE FILE --- blackjax/__init__.py | 2 + .../adaptation/adjusted_mclmc_adaptation.py | 313 ++++++++++++++++++ blackjax/adaptation/mclmc_adaptation.py | 308 ----------------- blackjax/mcmc/adjusted_mclmc.py | 2 +- tests/mcmc/test_sampling.py | 2 +- 5 files changed, 317 insertions(+), 310 deletions(-) create mode 100644 blackjax/adaptation/adjusted_mclmc_adaptation.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 4b23614f5..8cd5fc759 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -3,6 +3,7 @@ from blackjax._version import __version__ +from .adaptation.adjusted_mclmc_adaptation import adjusted_mclmc_find_L_and_step_size from .adaptation.chees_adaptation import chees_adaptation from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size from .adaptation.meads_adaptation import meads_adaptation @@ -160,6 +161,7 @@ def generate_top_level_api_from(module): "chees_adaptation", "pathfinder_adaptation", "mclmc_find_L_and_step_size", # mclmc adaptation + "adjusted_mclmc_find_L_and_step_size", # adjusted mclmc adaptation "ess", # diagnostics "rhat", ] diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py new file mode 100644 index 000000000..5003aa523 --- /dev/null +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -0,0 +1,313 @@ +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + +from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState, handle_nans +from blackjax.adaptation.step_size import ( + DualAveragingAdaptationState, + dual_averaging_adaptation, +) +from blackjax.diagnostics import effective_sample_size +from blackjax.util import pytree_size, streaming_average_update + +Lratio_lowerbound = 0.0 +Lratio_upperbound = 2.0 + + +def adjusted_mclmc_find_L_and_step_size( + mclmc_kernel, + num_steps, + state, + rng_key, + target, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=True, + params=None, +): + """ + Finds the optimal value of the parameters for the MH-MCHMC algorithm. + + Parameters + ---------- + mclmc_kernel + The kernel function used for the MCMC algorithm. + num_steps + The number of MCMC steps that will subsequently be run, after tuning. + state + The initial state of the MCMC algorithm. + rng_key + The random number generator key. + target + The target acceptance rate for the step size adaptation. + frac_tune1 + The fraction of tuning for the first step of the adaptation. + frac_tune2 + The fraction of tuning for the second step of the adaptation. + frac_tune3 + The fraction of tuning for the third step of the adaptation. + desired_energy_va + The desired energy variance for the MCMC algorithm. + trust_in_estimate + The trust in the estimate of optimal stepsize. + num_effective_samples + The number of effective samples for the MCMC algorithm. + + Returns + ------- + A tuple containing the final state of the MCMC algorithm and the final hyperparameters. 
+ """ + + dim = pytree_size(state.position) + if params is None: + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) + ) + + part1_key, part2_key = jax.random.split(rng_key, 2) + + ( + state, + params, + params_history, + final_da_val, + ) = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=frac_tune2, + target=target, + diagonal_preconditioning=diagonal_preconditioning, + )( + state, params, num_steps, part1_key + ) + + if frac_tune3 != 0: + part2_key1, part2_key2 = jax.random.split(part2_key, 2) + + state, params = adjusted_mclmc_make_adaptation_L( + mclmc_kernel, frac=frac_tune3, Lfactor=0.4 + )(state, params, num_steps, part2_key1) + + ( + state, + params, + params_history, + final_da_val, + ) = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=0, + target=target, + fix_L_first_da=True, + diagonal_preconditioning=diagonal_preconditioning, + )( + state, params, num_steps, part2_key2 + ) + + return state, params, params_history, final_da_val + + +def adjusted_mclmc_make_L_step_size_adaptation( + kernel, + dim, + frac_tune1, + frac_tune2, + target, + diagonal_preconditioning, + fix_L_first_da=False, +): + """Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC""" + + def dual_avg_step(fix_L, update_da): + """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" + + def step(iteration_state, weight_and_key): + mask, rng_key = weight_and_key + kernel_key, num_steps_key = jax.random.split(rng_key, 2) + ( + previous_state, + params, + (adaptive_state, step_size_max), + streaming_avg, + ) = iteration_state + + avg_num_integration_steps = params.L / params.step_size + + state, info = kernel( + rng_key=kernel_key, + state=previous_state, + avg_num_integration_steps=avg_num_integration_steps, + step_size=params.step_size, + sqrt_diag_cov=params.sqrt_diag_cov, + ) + + # step updating + success, state, step_size_max, energy_change = handle_nans( + previous_state, + state, + params.step_size, + step_size_max, + info.energy, + ) + + log_step_size, log_step_size_avg, step, avg_error, mu = update_da( + adaptive_state, info.acceptance_rate + ) + + adaptive_state = DualAveragingAdaptationState( + mask * log_step_size + (1 - mask) * adaptive_state.log_step_size, + mask * log_step_size_avg + + (1 - mask) * adaptive_state.log_step_size_avg, + mask * step + (1 - mask) * adaptive_state.step, + mask * avg_error + (1 - mask) * adaptive_state.avg_error, + mask * mu + (1 - mask) * adaptive_state.mu, + ) + + step_size = jax.lax.clamp( + 1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1 + ) + adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) + # step_size = 1e-3 + + x = ravel_pytree(state.position)[0] + # update the running average of x, x^2 + streaming_avg = streaming_average_update( + expectation=jnp.array([x, jnp.square(x)]), + streaming_avg=streaming_avg, + weight=(1 - mask) * success * step_size, + zero_prevention=mask, + ) + + if fix_L: + params = params._replace( + step_size=mask * step_size + (1 - mask) * params.step_size, + ) + + else: + params = params._replace( + step_size=mask * step_size + (1 - mask) * params.step_size, + L=mask * (params.L * (step_size / params.step_size)) + + (1 - mask) * params.L, + ) + + return (state, params, (adaptive_state, step_size_max), streaming_avg), ( + info, + params, + ) + + return 
step + + def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da): + return jax.lax.scan( + dual_avg_step(fix_L, update_da), + init=( + state, + params, + (initial_da(params.step_size), jnp.inf), # step size max + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=(mask, keys), + ) + + def L_step_size_adaptation(state, params, num_steps, rng_key): + num_steps1, num_steps2 = int(num_steps * frac_tune1), int( + num_steps * frac_tune2 + ) + + rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) + L_step_size_adaptation_keys_pass1 = jax.random.split( + rng_key_pass1, num_steps1 + num_steps2 + ) + L_step_size_adaptation_keys_pass2 = jax.random.split(rng_key_pass2, num_steps1) + + # determine which steps to ignore in the streaming average + mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, params_history), + ) = step_size_adaptation( + mask, + state, + params, + L_step_size_adaptation_keys_pass1, + fix_L=fix_L_first_da, + initial_da=initial_da, + update_da=update_da, + ) + + # determine L + if num_steps2 != 0.0: + x_average, x_squared_average = average[0], average[1] + variances = x_squared_average - jnp.square(x_average) + + change = jax.lax.clamp( + Lratio_lowerbound, + jnp.sqrt(jnp.sum(variances)) / params.L, + Lratio_upperbound, + ) + params = params._replace( + L=params.L * change, step_size=params.step_size * change + ) + if diagonal_preconditioning: + params = params._replace(sqrt_diag_cov=jnp.sqrt(variances)) + + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, params_history), + ) = step_size_adaptation( + jnp.ones(num_steps1), + state, + params, + L_step_size_adaptation_keys_pass2, + fix_L=True, + update_da=update_da, + initial_da=initial_da, + ) + + return state, params, params_history.step_size, final_da(dual_avg_state) + + return L_step_size_adaptation + + +def adjusted_mclmc_make_adaptation_L(kernel, frac, Lfactor): + """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + + def adaptation_L(state, params, num_steps, key): + num_steps = int(num_steps * frac) + adaptation_L_keys = jax.random.split(key, num_steps) + + def step(state, key): + next_state, _ = kernel( + rng_key=key, + state=state, + step_size=params.step_size, + avg_num_integration_steps=params.L / params.step_size, + sqrt_diag_cov=params.sqrt_diag_cov, + ) + return next_state, next_state.position + + state, samples = jax.lax.scan( + f=step, + init=state, + xs=adaptation_L_keys, + ) + + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) + ess = effective_sample_size(flat_samples[None, ...]) + + change = jax.lax.clamp( + Lratio_lowerbound, + (Lfactor * params.step_size * jnp.mean(num_steps / ess)) / params.L, + Lratio_upperbound, + ) + + return state, params._replace(L=params.L * change) + + return adaptation_L diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 699dd42d9..4050eeb1c 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -19,10 +19,6 @@ import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from blackjax.adaptation.step_size import ( - DualAveragingAdaptationState, - dual_averaging_adaptation, -) from 
blackjax.diagnostics import effective_sample_size from blackjax.util import pytree_size, streaming_average_update @@ -302,310 +298,6 @@ def step(state, key): return adaptation_L -Lratio_lowerbound = 0.0 -Lratio_upperbound = 2.0 - - -def adjusted_mclmc_find_L_and_step_size( - mclmc_kernel, - num_steps, - state, - rng_key, - target, - frac_tune1=0.1, - frac_tune2=0.1, - frac_tune3=0.1, - diagonal_preconditioning=True, - params=None, -): - """ - Finds the optimal value of the parameters for the MH-MCHMC algorithm. - - Parameters - ---------- - mclmc_kernel - The kernel function used for the MCMC algorithm. - num_steps - The number of MCMC steps that will subsequently be run, after tuning. - state - The initial state of the MCMC algorithm. - rng_key - The random number generator key. - target - The target acceptance rate for the step size adaptation. - frac_tune1 - The fraction of tuning for the first step of the adaptation. - frac_tune2 - The fraction of tuning for the second step of the adaptation. - frac_tune3 - The fraction of tuning for the third step of the adaptation. - desired_energy_va - The desired energy variance for the MCMC algorithm. - trust_in_estimate - The trust in the estimate of optimal stepsize. - num_effective_samples - The number of effective samples for the MCMC algorithm. - - Returns - ------- - A tuple containing the final state of the MCMC algorithm and the final hyperparameters. - """ - - dim = pytree_size(state.position) - if params is None: - params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) - ) - else: - params = params - part1_key, part2_key = jax.random.split(rng_key, 2) - - ( - state, - params, - params_history, - final_da_val, - ) = adjusted_mclmc_make_L_step_size_adaptation( - kernel=mclmc_kernel, - dim=dim, - frac_tune1=frac_tune1, - frac_tune2=frac_tune2, - target=target, - diagonal_preconditioning=diagonal_preconditioning, - )( - state, params, num_steps, part1_key - ) - - if frac_tune3 != 0: - part2_key1, part2_key2 = jax.random.split(part2_key, 2) - - state, params = adjusted_mclmc_make_adaptation_L( - mclmc_kernel, frac=frac_tune3, Lfactor=0.4 - )(state, params, num_steps, part2_key1) - - ( - state, - params, - params_history, - final_da_val, - ) = adjusted_mclmc_make_L_step_size_adaptation( - kernel=mclmc_kernel, - dim=dim, - frac_tune1=frac_tune1, - frac_tune2=0, - target=target, - fix_L_first_da=True, - diagonal_preconditioning=diagonal_preconditioning, - )( - state, params, num_steps, part2_key2 - ) - - return state, params, params_history, final_da_val - - -def adjusted_mclmc_make_L_step_size_adaptation( - kernel, - dim, - frac_tune1, - frac_tune2, - target, - diagonal_preconditioning, - fix_L_first_da=False, -): - """Adapts the stepsize and L of the MCLMC kernel. 
Designed for the unadjusted MCLMC""" - - def dual_avg_step(fix_L, update_da): - """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" - - def step(iteration_state, weight_and_key): - mask, rng_key = weight_and_key - kernel_key, num_steps_key = jax.random.split(rng_key, 2) - ( - previous_state, - params, - (adaptive_state, step_size_max), - streaming_avg, - ) = iteration_state - - avg_num_integration_steps = params.L / params.step_size - - state, info = kernel( - rng_key=kernel_key, - state=previous_state, - avg_num_integration_steps=avg_num_integration_steps, - step_size=params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, - ) - - # step updating - success, state, step_size_max, energy_change = handle_nans( - previous_state, - state, - params.step_size, - step_size_max, - info.energy, - ) - - log_step_size, log_step_size_avg, step, avg_error, mu = update_da( - adaptive_state, info.acceptance_rate - ) - - adaptive_state = DualAveragingAdaptationState( - mask * log_step_size + (1 - mask) * adaptive_state.log_step_size, - mask * log_step_size_avg - + (1 - mask) * adaptive_state.log_step_size_avg, - mask * step + (1 - mask) * adaptive_state.step, - mask * avg_error + (1 - mask) * adaptive_state.avg_error, - mask * mu + (1 - mask) * adaptive_state.mu, - ) - - step_size = jax.lax.clamp( - 1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1 - ) - adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) - # step_size = 1e-3 - - x = ravel_pytree(state.position)[0] - # update the running average of x, x^2 - streaming_avg = streaming_average_update( - expectation=jnp.array([x, jnp.square(x)]), - streaming_avg=streaming_avg, - weight=(1 - mask) * success * step_size, - zero_prevention=mask, - ) - - if fix_L: - params = params._replace( - step_size=mask * step_size + (1 - mask) * params.step_size, - ) - - else: - params = params._replace( - step_size=mask * step_size + (1 - mask) * params.step_size, - L=mask * (params.L * (step_size / params.step_size)) - + (1 - mask) * params.L, - ) - - return (state, params, (adaptive_state, step_size_max), streaming_avg), ( - info, - params, - ) - - return step - - def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da): - return jax.lax.scan( - dual_avg_step(fix_L, update_da), - init=( - state, - params, - (initial_da(params.step_size), jnp.inf), # step size max - (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), - ), - xs=(mask, keys), - ) - - def L_step_size_adaptation(state, params, num_steps, rng_key): - num_steps1, num_steps2 = int(num_steps * frac_tune1), int( - num_steps * frac_tune2 - ) - - rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) - L_step_size_adaptation_keys_pass1 = jax.random.split( - rng_key_pass1, num_steps1 + num_steps2 - ) - L_step_size_adaptation_keys_pass2 = jax.random.split(rng_key_pass2, num_steps1) - - # determine which steps to ignore in the streaming average - mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) - - initial_da, update_da, final_da = dual_averaging_adaptation(target=target) - - ( - (state, params, (dual_avg_state, step_size_max), (_, average)), - (info, params_history), - ) = step_size_adaptation( - mask, - state, - params, - L_step_size_adaptation_keys_pass1, - fix_L=fix_L_first_da, - initial_da=initial_da, - update_da=update_da, - ) - - # determine L - if num_steps2 != 0.0: - x_average, x_squared_average = average[0], average[1] - variances = x_squared_average - 
jnp.square(x_average) - - change = jax.lax.clamp( - Lratio_lowerbound, - jnp.sqrt(jnp.sum(variances)) / params.L, - Lratio_upperbound, - ) - params = params._replace( - L=params.L * change, step_size=params.step_size * change - ) - if diagonal_preconditioning: - params = params._replace(sqrt_diag_cov=jnp.sqrt(variances)) - - initial_da, update_da, final_da = dual_averaging_adaptation(target=target) - ( - (state, params, (dual_avg_state, step_size_max), (_, average)), - (info, params_history), - ) = step_size_adaptation( - jnp.ones(num_steps1), - state, - params, - L_step_size_adaptation_keys_pass2, - fix_L=True, - update_da=update_da, - initial_da=initial_da, - ) - - return state, params, params_history.step_size, final_da(dual_avg_state) - - return L_step_size_adaptation - - -def adjusted_mclmc_make_adaptation_L(kernel, frac, Lfactor): - """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" - - def adaptation_L(state, params, num_steps, key): - num_steps = int(num_steps * frac) - adaptation_L_keys = jax.random.split(key, num_steps) - - def step(state, key): - next_state, _ = kernel( - rng_key=key, - state=state, - step_size=params.step_size, - avg_num_integration_steps=params.L / params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, - ) - return next_state, next_state.position - - state, samples = jax.lax.scan( - f=step, - init=state, - xs=adaptation_L_keys, - ) - - flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) - ess = effective_sample_size(flat_samples[None, ...]) - - change = jax.lax.clamp( - Lratio_lowerbound, - (Lfactor * params.step_size * jnp.mean(num_steps / ess)) / params.L, - Lratio_upperbound, - ) - - return state, params._replace(L=params.L * change) - - return adaptation_L - - def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change): """if there are nans, let's reduce the stepsize, and not update the state. 
The function returns the old state in this case.""" diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 261783a76..cd29a35a1 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -79,7 +79,7 @@ def kernel( integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) ), step_size=step_size, - L_proposal_factor=L_proposal_factor * (num_integration_steps/step_size), + L_proposal_factor=L_proposal_factor * (num_integration_steps / step_size), num_integration_steps=num_integration_steps, divergence_threshold=divergence_threshold, )( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 8e431e1ac..c276ff898 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -184,7 +184,7 @@ def run_adjusted_mclmc( blackjax_mclmc_sampler_params, params_history, final_da, - ) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size( + ) = blackjax.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state, From 31a7be676e736cff2c7dbe7bf89a459bda3be50a Mon Sep 17 00:00:00 2001 From: = Date: Tue, 4 Jun 2024 07:10:48 -0400 Subject: [PATCH 64/71] SPLIT TUNING FOR AMCLMC INTO SEPARATE FILE --- blackjax/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 8c373fb42..b9b7a250c 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -220,7 +220,7 @@ def one_step(average_and_state, xs, return_state): if return_state: return (average, state), (transform(state), info) else: - return (average, state), average[1] + return (average, state), None one_step = jax.jit(partial(one_step, return_state=return_state_history)) @@ -233,7 +233,7 @@ def one_step(average_and_state, xs, return_state): ) if not return_state_history: - return average, transform(final_state), history + return average, transform(final_state) else: state_history, info_history = history return transform(final_state), state_history, info_history From 90be1be1375e95be7411af6caf5ea68bef542e27 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 5 Jun 2024 11:45:09 -0400 Subject: [PATCH 65/71] RENAME STREAMING_AVERAGE_UPDATE ARGS IN ADJUSTED MCLMC ADAPTATION --- blackjax/adaptation/adjusted_mclmc_adaptation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 5003aa523..e4a6f1d5b 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -131,7 +131,7 @@ def step(iteration_state, weight_and_key): previous_state, params, (adaptive_state, step_size_max), - streaming_avg, + previous_weight_and_average, ) = iteration_state avg_num_integration_steps = params.L / params.step_size @@ -174,9 +174,9 @@ def step(iteration_state, weight_and_key): x = ravel_pytree(state.position)[0] # update the running average of x, x^2 - streaming_avg = streaming_average_update( - expectation=jnp.array([x, jnp.square(x)]), - streaming_avg=streaming_avg, + previous_weight_and_average = streaming_average_update( + current_value=jnp.array([x, jnp.square(x)]), + previous_weight_and_average=previous_weight_and_average, weight=(1 - mask) * success * step_size, zero_prevention=mask, ) @@ -193,7 +193,7 @@ def step(iteration_state, weight_and_key): + (1 - mask) * params.L, ) - return (state, params, (adaptive_state, step_size_max), streaming_avg), ( + return (state, params, 
(adaptive_state, step_size_max), previous_weight_and_average), ( info, params, ) From c19df2186a7ec5504e6555a84276680cee5c5ede Mon Sep 17 00:00:00 2001 From: = Date: Wed, 5 Jun 2024 12:16:28 -0400 Subject: [PATCH 66/71] diagnostics --- blackjax/adaptation/adjusted_mclmc_adaptation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index e4a6f1d5b..92133d142 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -193,7 +193,12 @@ def step(iteration_state, weight_and_key): + (1 - mask) * params.L, ) - return (state, params, (adaptive_state, step_size_max), previous_weight_and_average), ( + return ( + state, + params, + (adaptive_state, step_size_max), + previous_weight_and_average, + ), ( info, params, ) From 5ca1b229eb8c1355260018c21d605a63de6de2ff Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Jun 2024 23:16:06 +0200 Subject: [PATCH 67/71] fix bugs --- blackjax/adaptation/adjusted_mclmc_adaptation.py | 2 +- blackjax/mcmc/adjusted_mclmc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 92133d142..26944a9fe 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -260,7 +260,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L=params.L * change, step_size=params.step_size * change ) if diagonal_preconditioning: - params = params._replace(sqrt_diag_cov=jnp.sqrt(variances)) + params = params._replace(sqrt_diag_cov=jnp.sqrt(variances), L = jnp.sqrt(dim)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index cd29a35a1..0b1bf1bc4 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -79,7 +79,7 @@ def kernel( integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) ), step_size=step_size, - L_proposal_factor=L_proposal_factor * (num_integration_steps / step_size), + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), num_integration_steps=num_integration_steps, divergence_threshold=divergence_threshold, )( From ccd9a284a239f104a2104197565ab868f2c69cfc Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Jun 2024 00:01:40 +0200 Subject: [PATCH 68/71] FIX MINOR TUNING BUGS --- blackjax/adaptation/adjusted_mclmc_adaptation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 26944a9fe..47df58d2f 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -260,7 +260,9 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L=params.L * change, step_size=params.step_size * change ) if diagonal_preconditioning: - params = params._replace(sqrt_diag_cov=jnp.sqrt(variances), L = jnp.sqrt(dim)) + params = params._replace( + sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim) + ) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( From 1ef4c95f4e12eeb4e7a2542162f029c459596d90 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Jun 2024 15:50:52 +0200 Subject: [PATCH 69/71] UPDATE TUNING --- blackjax/adaptation/adjusted_mclmc_adaptation.py | 
8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 47df58d2f..04f4a7f8d 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -309,12 +309,6 @@ def step(state, key): flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) ess = effective_sample_size(flat_samples[None, ...]) - change = jax.lax.clamp( - Lratio_lowerbound, - (Lfactor * params.step_size * jnp.mean(num_steps / ess)) / params.L, - Lratio_upperbound, - ) - - return state, params._replace(L=params.L * change) + return state, params._replace(L=(0.4 * params.L) / ess) return adaptation_L From 7e21f43461ccabf9a5128ebced4cdea827df353b Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Jun 2024 15:55:34 +0200 Subject: [PATCH 70/71] UPDATE TUNING --- blackjax/adaptation/adjusted_mclmc_adaptation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 04f4a7f8d..a8d180c28 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -309,6 +309,6 @@ def step(state, key): flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) ess = effective_sample_size(flat_samples[None, ...]) - return state, params._replace(L=(0.4 * params.L) / ess) + return state, params._replace(L=(0.4 * params.L) / jnp.mean(ess)) return adaptation_L From 17f83e2522cd8af60e3e4d34c29574e6fc43546f Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Jun 2024 17:54:36 +0200 Subject: [PATCH 71/71] UPDATE TUNING --- blackjax/adaptation/adjusted_mclmc_adaptation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index a8d180c28..249766977 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -309,6 +309,6 @@ def step(state, key): flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) ess = effective_sample_size(flat_samples[None, ...]) - return state, params._replace(L=(0.4 * params.L) / jnp.mean(ess)) + return state, params._replace(L=Lfactor * params.L * jnp.mean(num_steps / ess)) return adaptation_L
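
Patch 61/71 threads an integrator argument through window_adaptation (defaulting to velocity_verlet, the previous behaviour) and forwards it to algorithm.build_kernel. The sketch below shows how the new keyword would be used; it assumes the parts of the top-level API that this series does not touch (blackjax.nuts, the warmup.run(rng_key, position, num_steps) interface and the parameters dict it returns) behave as in released blackjax.

import jax
import jax.numpy as jnp
import blackjax

def logdensity_fn(x):
    # standard-normal target, purely for illustration
    return -0.5 * jnp.sum(jnp.square(x))

# Select a non-default integrator for the HMC-family kernel being adapted.
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    logdensity_fn,
    integrator=blackjax.mcmc.integrators.mclachlan,
)
(state, parameters), _ = warmup.run(
    jax.random.PRNGKey(0), jnp.ones(4), num_steps=200
)
print(parameters["step_size"], parameters["inverse_mass_matrix"])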
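
Patch 65/71 renames the streaming_average_update arguments at the adaptation call site to current_value and previous_weight_and_average. The helper maintains a weighted running average without storing the trace, which is how the tuner accumulates E[x] and E[x^2] during adaptation. A small self-contained sketch, assuming the keyword interface shown at that call site (current_value, previous_weight_and_average, weight, zero_prevention) and that uniform weights with zero_prevention=0.0 reduce the estimate to the ordinary mean:

import jax
import jax.numpy as jnp
from blackjax.util import streaming_average_update

# A stream of values whose mean we want without keeping the whole history.
values = jax.random.normal(jax.random.PRNGKey(1), (1000, 2))

def one_step(weight_and_average, x):
    # weight=1.0: every draw counts equally; zero_prevention only matters
    # when a step's weight is masked to zero, as in the tuner.
    return streaming_average_update(
        current_value=x,
        previous_weight_and_average=weight_and_average,
        weight=1.0,
        zero_prevention=0.0,
    ), None

init = (jnp.array(0.0), jnp.zeros(2))
(total_weight, streaming_mean), _ = jax.lax.scan(one_step, init, values)
print(streaming_mean)          # streaming estimate
print(values.mean(axis=0))     # batch mean, for comparison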
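
Patches 69-71/71 settle the frac_tune3 pass on a simpler rule: instead of clamping a ratio built from the step size, the trajectory length is rescaled directly by the mean number of steps per effective sample, L <- Lfactor * L * mean(num_steps / ESS). (An intermediate revision, (0.4 * L) / ess, divided by the per-dimension ESS vector and was corrected by the last two patches.) A standalone sketch of the final update, with iid draws standing in for the positions collected during tuning:

import jax
import jax.numpy as jnp
from blackjax.diagnostics import effective_sample_size

num_steps, dim = 500, 3
# Stand-in for the flattened positions gathered over the frac_tune3 steps.
flat_samples = jax.random.normal(jax.random.PRNGKey(0), (num_steps, dim))

Lfactor = 0.4              # same constant the tuner passes in
L = jnp.sqrt(dim)          # a previous estimate of the trajectory length

# One chain: add a leading chain axis; ESS is returned per dimension.
ess = effective_sample_size(flat_samples[None, ...])

# Final form of the update (patch 71/71).
L_new = Lfactor * L * jnp.mean(num_steps / ess)
print(L_new)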