diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 76a016242..7645a890b 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -20,7 +20,7 @@ from jax.flatten_util import ravel_pytree from blackjax.diagnostics import effective_sample_size -from blackjax.util import pytree_size, streaming_average_update +from blackjax.util import incremental_value_update, pytree_size class MCLMCAdaptationState(NamedTuple): @@ -199,9 +199,9 @@ def step(iteration_state, weight_and_key): x = ravel_pytree(state.position)[0] # update the running average of x, x^2 - streaming_avg = streaming_average_update( - current_value=jnp.array([x, jnp.square(x)]), - previous_weight_and_average=streaming_avg, + streaming_avg = incremental_value_update( + expectation=jnp.array([x, jnp.square(x)]), + incremental_val=streaming_avg, weight=(1 - mask) * success * params.step_size, zero_prevention=mask, ) diff --git a/blackjax/util.py b/blackjax/util.py index 9f4d6f9c7..b6c5367b5 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -3,12 +3,11 @@ from functools import partial from typing import Callable, Union -import jax import jax.numpy as jnp from jax import jit, lax from jax.flatten_util import ravel_pytree from jax.random import normal, split -from jax.tree_util import tree_leaves +from jax.tree_util import tree_leaves, tree_map from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn @@ -149,9 +148,7 @@ def run_inference_algorithm( initial_state: ArrayLikeTree = None, initial_position: ArrayLikeTree = None, progress_bar: bool = False, - transform: Callable = lambda x: x, - return_state_history=True, - expectation: Callable = lambda x: x, + transform: Callable = lambda state, info: (state, info), ) -> tuple: """Wrapper to run an inference algorithm. @@ -166,8 +163,7 @@ def run_inference_algorithm( initial_state The initial state of the inference algorithm. initial_position - The initial position of the inference algorithm. This is used when the initial - state is not provided. + The initial position of the inference algorithm. This is used when the initial state is not provided. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. num_steps @@ -175,26 +171,14 @@ def run_inference_algorithm( progress_bar Whether to display a progress bar. transform - A transformation of the trace of states to be returned. This is useful for + A transformation of the trace of states (and info) to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. - expectation - A function that computes the expectation of the state. This is done - incrementally, so doesn't require storing all the states. - return_state_history - if False, `run_inference_algorithm` will only return an expectation of the value - of transform, and return that average instead of the full set of samples. This - is useful when memory is a bottleneck. Returns ------- - If return_state_history is True: 1. The final state. - 2. The trace of the state. - 3. The trace of the info of the inference algorithm for diagnostics. - If return_state_history is False: - 1. This is the expectation of state over the chain. Otherwise the final state. - 2. The final state of the inference algorithm. + 2. The history of states. """ if initial_state is None and initial_position is None: @@ -212,58 +196,116 @@ def run_inference_algorithm( keys = split(rng_key, num_steps) - def one_step(average_and_state, xs, return_state): + def one_step(state, xs): _, rng_key = xs - average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - average = streaming_average_update(expectation(transform(state)), average) - if return_state: - return (average, state), (transform(state), info) - else: - return (average, state), None + return state, transform(state, info) - one_step = jax.jit(partial(one_step, return_state=return_state_history)) - - xs = (jnp.arange(num_steps), keys) scan_fn = gen_scan_fn(num_steps, progress_bar) - ((_, average), final_state), history = scan_fn( - one_step, - ((0, expectation(transform(initial_state))), initial_state), - xs, - ) - if not return_state_history: - return average, transform(final_state) - else: - state_history, info_history = history - return transform(final_state), state_history, info_history + xs = jnp.arange(num_steps), keys + final_state, history = scan_fn(one_step, initial_state, xs) + + return final_state, history -def streaming_average_update( - current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0 +def store_only_expectation_values( + sampling_algorithm, + state_transform=lambda x: x, + incremental_value_transform=lambda x: x, + burn_in=0, +): + """Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same + kernel but only stores the streaming expectation values of some observables, not the full states; to save memory. + + It saves incremental_value_transform(E[state_transform(x)]) at each step i, where expectation is computed with samples up to i-th sample. + + Example: + + .. code:: + + init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0),3) + model = StandardNormal(2) + initial_position = model.sample_init(init_key) + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=model.logdensity_fn, rng_key=state_key + ) + integrator_type = "mclachlan" + L = 1.0 + step_size = 0.1 + num_steps = 4 + + integrator = map_integrator_type_to_integrator['mclmc'][integrator_type] + state_transform = lambda state: state.position + memory_efficient_sampling_alg, transform = store_only_expectation_values( + sampling_algorithm=sampling_alg, + state_transform=state_transform) + + initial_state = memory_efficient_sampling_alg.init(initial_state) + + final_state, trace_at_every_step = run_inference_algorithm( + + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=memory_efficient_sampling_alg, + num_steps=num_steps, + transform=transform, + progress_bar=True, + ) + """ + + def init_fn(state): + averaging_state = (0.0, state_transform(state)) + return (state, averaging_state) + + def update_fn(rng_key, state_and_incremental_val): + state, averaging_state = state_and_incremental_val + state, info = sampling_algorithm.step( + rng_key, state + ) # update the state with the sampling algorithm + averaging_state = incremental_value_update( + state_transform(state), + averaging_state, + weight=( + averaging_state[0] >= burn_in + ), # If we want to eliminate some number of steps as a burn-in + zero_prevention=1e-10 * (burn_in > 0), + ) + # update the expectation value with the running average + return (state, averaging_state), info + + def transform(state_and_incremental_val, info): + (state, (_, incremental_value)) = state_and_incremental_val + return incremental_value_transform(incremental_value), info + + return SamplingAlgorithm(init_fn, update_fn), transform + + +def incremental_value_update( + expectation, incremental_val, weight=1.0, zero_prevention=0.0 ): """Compute the streaming average of a function O(x) using a weight. Parameters: ---------- - current_value - the current value of the function that we want to take average of - previous_weight_and_average - tuple of (previous_weight, previous_average) where previous_weight is the - sum of weights and average is the current estimated average + expectation + the value of the expectation at the current timestep + incremental_val + tuple of (total, average) where total is the sum of weights and average is the current average weight weight of the current state zero_prevention small value to prevent division by zero Returns: ---------- - new total weight and streaming average + new streaming average """ - previous_weight, previous_average = previous_weight_and_average - current_weight = previous_weight + weight - current_average = jax.tree.map( - lambda x, avg: (previous_weight * avg + weight * x) - / (current_weight + zero_prevention), - current_value, - previous_average, + + total, average = incremental_val + average = tree_map( + lambda exp, av: (total * av + weight * exp) + / (total + weight + zero_prevention), + expectation, + average, ) - return current_weight, current_average + total += weight + return total, average diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py index 68751bee8..4b34511be 100644 --- a/tests/adaptation/test_adaptation.py +++ b/tests/adaptation/test_adaptation.py @@ -90,7 +90,7 @@ def test_chees_adaptation(adaptation_filters): algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - _, _, infos = jax.vmap( + _, (_, infos) = jax.vmap( lambda key, state: run_inference_algorithm( rng_key=key, initial_state=state, diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 18a07625b..c399929da 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -135,12 +135,12 @@ def run_mclmc( sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, ) - _, samples, _ = run_inference_algorithm( + _, samples = run_inference_algorithm( rng_key=run_key, initial_state=blackjax_state_after_tuning, inference_algorithm=sampling_alg, num_steps=num_steps, - transform=lambda x: x.position, + transform=lambda state, info: state.position, ) return samples @@ -197,7 +197,7 @@ def check_attrs(attribute, keyset): for i, attribute in enumerate(["state", "info", "adaptation_state"]): check_attrs(attribute, keysets[i]) - _, states, _ = run_inference_algorithm( + _, (states, _) = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=inference_algorithm, @@ -223,15 +223,16 @@ def test_mala(self): mala = blackjax.mala(logposterior_fn, 1e-5) state = mala.init({"coefs": 1.0, "log_scale": 1.0}) - _, states, _ = run_inference_algorithm( + _, states = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=mala, + transform=lambda state, info: state.position, num_steps=10_000, ) - coefs_samples = states.position["coefs"][3000:] - scale_samples = np.exp(states.position["log_scale"][3000:]) + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -375,15 +376,16 @@ def test_pathfinder_adaptation( ) inference_algorithm = algorithm(logposterior_fn, **parameters) - _, states, _ = run_inference_algorithm( + _, states = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=inference_algorithm, num_steps=num_sampling_steps, + transform=lambda state, info: state.position, ) - coefs_samples = states.position["coefs"] - scale_samples = np.exp(states.position["log_scale"]) + coefs_samples = states["coefs"] + scale_samples = np.exp(states["log_scale"]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -418,17 +420,18 @@ def test_meads(self): inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - _, states, _ = jax.vmap( + _, states = jax.vmap( lambda key, state: run_inference_algorithm( rng_key=key, initial_state=state, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=100, ) )(chain_keys, last_states) - coefs_samples = states.position["coefs"] - scale_samples = np.exp(states.position["log_scale"]) + coefs_samples = states["coefs"] + scale_samples = np.exp(states["log_scale"]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -465,17 +468,18 @@ def test_chees(self, jitter_generator): inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - _, states, _ = jax.vmap( + _, states = jax.vmap( lambda key, state: run_inference_algorithm( rng_key=key, initial_state=state, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=100, ) )(chain_keys, last_states) - coefs_samples = states.position["coefs"] - scale_samples = np.exp(states.position["log_scale"]) + coefs_samples = states["coefs"] + scale_samples = np.exp(states["log_scale"]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @@ -494,15 +498,16 @@ def test_barker(self): barker = blackjax.barker_proposal(logposterior_fn, 1e-1) state = barker.init({"coefs": 1.0, "log_scale": 1.0}) - _, states, _ = run_inference_algorithm( + _, states = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=barker, + transform=lambda state, info: state.position, num_steps=10_000, ) - coefs_samples = states.position["coefs"][3000:] - scale_samples = np.exp(states.position["log_scale"][3000:]) + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @@ -679,19 +684,20 @@ def test_latent_gaussian(self): initial_state = inference_algorithm.init(jnp.zeros((1,))) - _, states, _ = self.variant( + _, states = self.variant( functools.partial( run_inference_algorithm, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=self.sampling_steps, ), )(rng_key=self.key, initial_state=initial_state) np.testing.assert_allclose( - np.var(states.position[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2 + np.var(states[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2 ) np.testing.assert_allclose( - np.mean(states.position[self.burnin :]), 2 / 3, rtol=1e-2, atol=1e-2 + np.mean(states[self.burnin :]), 2 / 3, rtol=1e-2, atol=1e-2 ) @@ -724,7 +730,7 @@ def univariate_normal_test_case( **kwargs, ): inference_key, orbit_key = jax.random.split(rng_key) - _, states, _ = self.variant( + _, (states, info) = self.variant( functools.partial( run_inference_algorithm, inference_algorithm=inference_algorithm, @@ -855,7 +861,7 @@ def postprocess_samples(states, key): 20_000, burnin, postprocess_samples, - transform=lambda x: (x.positions, x.weights), + transform=lambda state, info: ((state.positions, state.weights), info), ) @chex.all_variants(with_pmap=False) @@ -997,14 +1003,15 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): functools.partial( run_inference_algorithm, inference_algorithm=inference_algorithm, + transform=lambda state, info: state.position, num_steps=2_000, ) ) - _, states, _ = inference_loop_multiple_chains( + _, states = inference_loop_multiple_chains( rng_key=multi_chain_sample_key, initial_state=initial_states ) - posterior_samples = states.position[:, -1000:] + posterior_samples = states[:, -1000:] posterior_delta = posterior_samples - true_loc posterior_variance = posterior_delta**2.0 posterior_correlation = jnp.prod(posterior_delta, axis=-1, keepdims=True) / ( diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index c2295e7e2..2d108a48d 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -48,7 +48,7 @@ def run_regression(algorithm, **parameters): ) inference_algorithm = algorithm(logdensity_fn, **parameters) - _, states, _ = run_inference_algorithm( + _, (states, _) = run_inference_algorithm( rng_key=inference_key, initial_state=state, inference_algorithm=inference_algorithm, diff --git a/tests/test_util.py b/tests/test_util.py index 1f03498dd..78198f013 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,7 +4,7 @@ from absl.testing import absltest, parameterized import blackjax -from blackjax.util import run_inference_algorithm +from blackjax.util import run_inference_algorithm, store_only_expectation_values class RunInferenceAlgorithmTest(chex.TestCase): @@ -30,7 +30,7 @@ def check_compatible(self, initial_state, progress_bar): inference_algorithm=self.algorithm, num_steps=self.num_steps, progress_bar=progress_bar, - transform=lambda x: x.position, + transform=lambda state, info: state.position, ) def test_streaming(self): @@ -41,37 +41,49 @@ def logdensity_fn(x): 10, ) - init_key, run_key = jax.random.split(self.key, 2) - + init_key, state_key, run_key = jax.random.split(self.key, 3) initial_state = blackjax.mcmc.mclmc.init( - position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + position=initial_position, logdensity_fn=logdensity_fn, rng_key=state_key + ) + L = 1.0 + step_size = 0.1 + num_steps = 4 + + sampling_alg = blackjax.mclmc( + logdensity_fn, + L=L, + step_size=step_size, ) - alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) + state_transform = lambda x: x.position - _, states, info = run_inference_algorithm( + _, samples = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, - inference_algorithm=alg, - num_steps=50, - progress_bar=False, - expectation=lambda x: x, - transform=lambda x: x.position, - return_state_history=True, + inference_algorithm=sampling_alg, + num_steps=num_steps, + transform=lambda state, info: state_transform(state), + progress_bar=True, + ) + + print("average of steps (slow way):", samples.mean(axis=0)) + + memory_efficient_sampling_alg, transform = store_only_expectation_values( + sampling_algorithm=sampling_alg, state_transform=state_transform ) - average, _ = run_inference_algorithm( + initial_state = memory_efficient_sampling_alg.init(initial_state) + + final_state, trace_at_every_step = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, - inference_algorithm=alg, - num_steps=50, - progress_bar=False, - expectation=lambda x: x, - transform=lambda x: x.position, - return_state_history=False, + inference_algorithm=memory_efficient_sampling_alg, + num_steps=num_steps, + transform=transform, + progress_bar=True, ) - assert jnp.allclose(states.mean(axis=0), average) + assert jnp.allclose(trace_at_every_step[0][-1], samples.mean(axis=0)) @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): @@ -81,7 +93,7 @@ def test_compatible_with_initial_pos(self, progress_bar): inference_algorithm=self.algorithm, num_steps=self.num_steps, progress_bar=progress_bar, - transform=lambda x: x.position, + transform=lambda state, info: state.position, ) @parameterized.parameters([True, False])