Skip to content

Commit

Permalink
Enable progress bar under pmap (#712)
Browse files Browse the repository at this point in the history
* enable pmap progbar

* fix bar creation

* add locking

* fix formatting

* switch to using chain state
  • Loading branch information
andrewdipper authored Aug 7, 2024
1 parent 441412a commit 27dfc9e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 27 deletions.
8 changes: 7 additions & 1 deletion blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,16 +334,22 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
if progress_bar:
print("Running window adaptation")
one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step))
start_state = ((init_state, init_adaptation_state), -1)
else:
one_step_ = jax.jit(one_step)
start_state = (init_state, init_adaptation_state)

keys = jax.random.split(rng_key, num_steps)
schedule = build_schedule(num_steps)
last_state, info = jax.lax.scan(
one_step_,
(init_state, init_adaptation_state),
start_state,
(jnp.arange(num_steps), keys, schedule),
)

if progress_bar:
last_state, _ = last_state

last_chain_state, last_warmup_state, *_ = last_state

step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
Expand Down
55 changes: 33 additions & 22 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,64 @@
"""Progress bar decorators for use with step functions.
Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`.
"""
from threading import Lock

from fastprogress.fastprogress import progress_bar
from jax import lax
from jax.experimental import io_callback
from jax.numpy import array


def progress_bar_scan(num_samples, print_rate=None):
"Progress bar for a JAX scan"
progress_bars = {}
idx_counter = 0
lock = Lock()

if print_rate is None:
if num_samples > 20:
print_rate = int(num_samples / 20)
else:
print_rate = 1 # if you run the sampler for less than 20 iterations

def _define_bar(arg):
del arg
progress_bars[0] = progress_bar(range(num_samples))
progress_bars[0].update(0)
def _calc_chain_idx(iter_num):
nonlocal idx_counter
with lock:
idx = idx_counter
idx_counter += 1
return idx

def _update_bar(arg, chain_id):
chain_id = int(chain_id)
if arg == 0:
chain_id = _calc_chain_idx(arg)
progress_bars[chain_id] = progress_bar(range(num_samples))
progress_bars[chain_id].update(0)

def _update_bar(arg):
progress_bars[0].update_bar(arg + 1)
progress_bars[chain_id].update_bar(arg + 1)
return chain_id

def _close_bar(arg):
del arg
progress_bars[0].on_iter_end()
def _close_bar(arg, chain_id):
progress_bars[int(chain_id)].on_iter_end()

def _update_progress_bar(iter_num):
def _update_progress_bar(iter_num, chain_id):
"Updates progress bar of a JAX scan or loop"
_ = lax.cond(
iter_num == 0,
lambda _: io_callback(_define_bar, None, iter_num),
lambda _: None,
operand=None,
)

_ = lax.cond(
chain_id = lax.cond(
# update every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) | (iter_num == (num_samples - 1)),
lambda _: io_callback(_update_bar, None, iter_num),
lambda _: None,
lambda _: io_callback(_update_bar, array(0), iter_num, chain_id),
lambda _: chain_id,
operand=None,
)

_ = lax.cond(
iter_num == num_samples - 1,
lambda _: io_callback(_close_bar, None, None),
lambda _: io_callback(_close_bar, None, iter_num + 1, chain_id),
lambda _: None,
operand=None,
)
return chain_id

def _progress_bar_scan(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
Expand All @@ -77,8 +85,11 @@ def wrapper_progress_bar(carry, x):
iter_num, *_ = x
else:
iter_num = x
_update_progress_bar(iter_num)
return func(carry, x)
subcarry, chain_id = carry
chain_id = _update_progress_bar(iter_num, chain_id)
subcarry, y = func(subcarry, x)

return (subcarry, chain_id), y

return wrapper_progress_bar

Expand Down
14 changes: 10 additions & 4 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,19 @@ def one_step(average_and_state, xs, return_state):

one_step = jax.jit(partial(one_step, return_state=return_state_history))

xs = (jnp.arange(num_steps), keys)
if progress_bar:
one_step = progress_bar_scan(num_steps)(one_step)
(((_, average), final_state), _), history = lax.scan(
one_step,
(((0, expectation(transform(initial_state))), initial_state), -1),
xs,
)

xs = (jnp.arange(num_steps), keys)
((_, average), final_state), history = lax.scan(
one_step, ((0, expectation(transform(initial_state))), initial_state), xs
)
else:
((_, average), final_state), history = lax.scan(
one_step, ((0, expectation(transform(initial_state))), initial_state), xs
)

if not return_state_history:
return average, transform(final_state)
Expand Down

0 comments on commit 27dfc9e

Please sign in to comment.