Simplify run_inference_algorithm (#714)
* fix minor type errors

* storing only expectation values

* fixed memory efficient sampling

* clean up

* renaming vars

* precommit fixes

* fixing tests

* fixing tests

* fixing tests

* fixing tests

* fixing tests

* merge main

* burn in and fix tests

* burn in and fix tests

* minor fixes

* minor fixes

* minor fixes

---------

Co-authored-by: [email protected] <[email protected]>
reubenharry and JakobRobnik authored Aug 12, 2024
1 parent 148c028 commit 7135fd7
Showing 6 changed files with 171 additions and 110 deletions.
8 changes: 4 additions & 4 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
-from blackjax.util import pytree_size, streaming_average_update
+from blackjax.util import incremental_value_update, pytree_size


class MCLMCAdaptationState(NamedTuple):
@@ -199,9 +199,9 @@ def step(iteration_state, weight_and_key):

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
-streaming_avg = streaming_average_update(
-current_value=jnp.array([x, jnp.square(x)]),
-previous_weight_and_average=streaming_avg,
+streaming_avg = incremental_value_update(
+expectation=jnp.array([x, jnp.square(x)]),
+incremental_val=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
)
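The two rows kept in the running average are E[x] and E[x^2], which together give a streaming variance estimate. A minimal sketch with invented values (not part of this commit):

    import jax.numpy as jnp

    # streaming_avg follows the (total_weight, average) convention used above;
    # average[0] approximates E[x] and average[1] approximates E[x^2] per dimension.
    total_weight = 3.0
    average = jnp.array([[0.5, 0.3], [0.35, 0.14]])
    variance_estimate = average[1] - jnp.square(average[0])  # Var[x] = E[x^2] - E[x]^2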
156 changes: 99 additions & 57 deletions blackjax/util.py
@@ -3,12 +3,11 @@
from functools import partial
from typing import Callable, Union

-import jax
import jax.numpy as jnp
from jax import jit, lax
from jax.flatten_util import ravel_pytree
from jax.random import normal, split
-from jax.tree_util import tree_leaves
+from jax.tree_util import tree_leaves, tree_map

from blackjax.base import SamplingAlgorithm, VIAlgorithm
from blackjax.progress_bar import gen_scan_fn
@@ -149,9 +148,7 @@ def run_inference_algorithm(
initial_state: ArrayLikeTree = None,
initial_position: ArrayLikeTree = None,
progress_bar: bool = False,
-transform: Callable = lambda x: x,
-return_state_history=True,
-expectation: Callable = lambda x: x,
+transform: Callable = lambda state, info: (state, info),
) -> tuple:
"""Wrapper to run an inference algorithm.
@@ -166,35 +163,22 @@
initial_state
The initial state of the inference algorithm.
initial_position
-The initial position of the inference algorithm. This is used when the initial
-state is not provided.
+The initial position of the inference algorithm. This is used when the initial state is not provided.
inference_algorithm
One of blackjax's sampling algorithms or variational inference algorithms.
num_steps
Number of MCMC steps.
progress_bar
Whether to display a progress bar.
transform
-A transformation of the trace of states to be returned. This is useful for
+A transformation of the trace of states (and info) to be returned. This is useful for
computing deterministic variables, or returning a subset of the states.
By default, the states are returned as is.
-expectation
-A function that computes the expectation of the state. This is done
-incrementally, so doesn't require storing all the states.
-return_state_history
-if False, `run_inference_algorithm` will only return an expectation of the value
-of transform, and return that average instead of the full set of samples. This
-is useful when memory is a bottleneck.
Returns
-------
-If return_state_history is True:
1. The final state.
-2. The trace of the state.
-3. The trace of the info of the inference algorithm for diagnostics.
-If return_state_history is False:
-1. This is the expectation of state over the chain. Otherwise the final state.
-2. The final state of the inference algorithm.
+2. The history of states.
"""

if initial_state is None and initial_position is None:
@@ -212,58 +196,116 @@

keys = split(rng_key, num_steps)

-def one_step(average_and_state, xs, return_state):
+def one_step(state, xs):
_, rng_key = xs
-average, state = average_and_state
state, info = inference_algorithm.step(rng_key, state)
-average = streaming_average_update(expectation(transform(state)), average)
-if return_state:
-return (average, state), (transform(state), info)
-else:
-return (average, state), None
+return state, transform(state, info)

-one_step = jax.jit(partial(one_step, return_state=return_state_history))

-xs = (jnp.arange(num_steps), keys)
scan_fn = gen_scan_fn(num_steps, progress_bar)
-((_, average), final_state), history = scan_fn(
-one_step,
-((0, expectation(transform(initial_state))), initial_state),
-xs,
-)

-if not return_state_history:
-return average, transform(final_state)
-else:
-state_history, info_history = history
-return transform(final_state), state_history, info_history
+xs = jnp.arange(num_steps), keys
+final_state, history = scan_fn(one_step, initial_state, xs)

+return final_state, history
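A minimal sketch of calling the simplified interface (not part of this commit; `algorithm` and `init_state` are placeholders, and `.position` assumes a state that carries that field):

    import jax

    # `transform` now receives (state, info); its per-step outputs are stacked into `history`.
    final_state, (positions, infos) = run_inference_algorithm(
        rng_key=jax.random.PRNGKey(0),
        initial_state=init_state,  # placeholder initial state
        inference_algorithm=algorithm,  # placeholder blackjax sampling algorithm
        num_steps=1000,
        transform=lambda state, info: (state.position, info),  # keep positions only
    )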


-def streaming_average_update(
-current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
+def store_only_expectation_values(
+sampling_algorithm,
+state_transform=lambda x: x,
+incremental_value_transform=lambda x: x,
+burn_in=0,
):
"""Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same
kernel but only stores the streaming expectation values of some observables, not the full states; to save memory.
It saves incremental_value_transform(E[state_transform(x)]) at each step i, where expectation is computed with samples up to i-th sample.
+Example:
+.. code::
+init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3)
+model = StandardNormal(2)
+initial_position = model.sample_init(init_key)
+initial_state = blackjax.mcmc.mclmc.init(
+position=initial_position, logdensity_fn=model.logdensity_fn, rng_key=state_key
+)
+integrator_type = "mclachlan"
+L = 1.0
+step_size = 0.1
+num_steps = 4
+integrator = map_integrator_type_to_integrator['mclmc'][integrator_type]
+sampling_alg = blackjax.mclmc(model.logdensity_fn, L=L, step_size=step_size, integrator=integrator)  # assumed definition; the original snippet uses sampling_alg without defining it
+state_transform = lambda state: state.position
+memory_efficient_sampling_alg, transform = store_only_expectation_values(
+sampling_algorithm=sampling_alg,
+state_transform=state_transform)
+initial_state = memory_efficient_sampling_alg.init(initial_state)
+final_state, trace_at_every_step = run_inference_algorithm(
+rng_key=run_key,
+initial_state=initial_state,
+inference_algorithm=memory_efficient_sampling_alg,
+num_steps=num_steps,
+transform=transform,
+progress_bar=True,
+)
+"""

+def init_fn(state):
+averaging_state = (0.0, state_transform(state))
+return (state, averaging_state)

+def update_fn(rng_key, state_and_incremental_val):
+state, averaging_state = state_and_incremental_val
+state, info = sampling_algorithm.step(
+rng_key, state
+)  # update the state with the sampling algorithm
+averaging_state = incremental_value_update(
+state_transform(state),
+averaging_state,
+weight=(
+averaging_state[0] >= burn_in
+),  # If we want to eliminate some number of steps as a burn-in
+zero_prevention=1e-10 * (burn_in > 0),
+)
+# update the expectation value with the running average
+return (state, averaging_state), info

+def transform(state_and_incremental_val, info):
+(state, (_, incremental_value)) = state_and_incremental_val
+return incremental_value_transform(incremental_value), info

+return SamplingAlgorithm(init_fn, update_fn), transform
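Since the wrapper records incremental_value_transform(E[state_transform(x)]) at every step, the two hooks can be combined to track a running variance estimate. A hedged sketch with hypothetical choices (`sampling_alg` as in the docstring example above):

    import jax.numpy as jnp

    # Track E[x] and E[x^2]; map them to a variance estimate at read-out time.
    state_transform = lambda state: jnp.array(
        [state.position, jnp.square(state.position)]
    )
    incremental_value_transform = lambda v: v[1] - jnp.square(v[0])  # E[x^2] - E[x]^2
    memory_efficient_alg, transform = store_only_expectation_values(
        sampling_algorithm=sampling_alg,  # placeholder, as in the docstring example
        state_transform=state_transform,
        incremental_value_transform=incremental_value_transform,
    )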


+def incremental_value_update(
+expectation, incremental_val, weight=1.0, zero_prevention=0.0
+):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
-current_value
-the current value of the function that we want to take average of
-previous_weight_and_average
-tuple of (previous_weight, previous_average) where previous_weight is the
-sum of weights and average is the current estimated average
+expectation
+the value of the expectation at the current timestep
+incremental_val
+tuple of (total, average) where total is the sum of weights and average is the current average
weight
weight of the current state
zero_prevention
small value to prevent division by zero
Returns:
----------
-new total weight and streaming average
+new streaming average
"""
-previous_weight, previous_average = previous_weight_and_average
-current_weight = previous_weight + weight
-current_average = jax.tree.map(
-lambda x, avg: (previous_weight * avg + weight * x)
-/ (current_weight + zero_prevention),
-current_value,
-previous_average,
+total, average = incremental_val
+average = tree_map(
+lambda exp, av: (total * av + weight * exp)
+/ (total + weight + zero_prevention),
+expectation,
+average,
)
-return current_weight, current_average
+total += weight
+return total, average
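A small worked example of the update rule above, with invented numbers:

    # Two unit-weight samples with mean 3.0 seen so far; fold in a new value 5.0.
    total, average = incremental_value_update(
        expectation=5.0, incremental_val=(2.0, 3.0), weight=1.0
    )
    # average == (2.0 * 3.0 + 1.0 * 5.0) / (2.0 + 1.0) == 11 / 3 ≈ 3.667
    # total == 3.0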
2 changes: 1 addition & 1 deletion tests/adaptation/test_adaptation.py
@@ -90,7 +90,7 @@ def test_chees_adaptation(adaptation_filters):
algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)

chain_keys = jax.random.split(inference_key, num_chains)
-_, _, infos = jax.vmap(
+_, (_, infos) = jax.vmap(
lambda key, state: run_inference_algorithm(
rng_key=key,
initial_state=state,
