From 2e2562488468194b672557fe59b13a47b289cdd2 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Thu, 28 Mar 2024 22:16:02 +0100 Subject: [PATCH] Migrate from deprecated `host_callback` to `io_callback` (#651) * Migrate from deprecated `host_callback` to `io_callback` Co-Authored-By: George Necula * Format file * Fix bug --- blackjax/progress_bar.py | 51 ++++++++++++++-------------------------- 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index d4fa45ca5..ac509b9b6 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -16,7 +16,7 @@ """ from fastprogress.fastprogress import progress_bar from jax import lax -from jax.experimental import host_callback +from jax.experimental import io_callback def progress_bar_scan(num_samples, print_rate=None): @@ -29,55 +29,39 @@ def progress_bar_scan(num_samples, print_rate=None): else: print_rate = 1 # if you run the sampler for less than 20 iterations - def _define_bar(arg, transform, device): + def _define_bar(arg): + del arg progress_bars[0] = progress_bar(range(num_samples)) progress_bars[0].update(0) - def _update_bar(arg, transform, device): - progress_bars[0].update_bar(arg) + def _update_bar(arg): + progress_bars[0].update_bar(arg + 1) + + def _close_bar(arg): + del arg + progress_bars[0].on_iter_end() def _update_progress_bar(iter_num): "Updates progress bar of a JAX scan or loop" _ = lax.cond( iter_num == 0, - lambda _: host_callback.id_tap( - _define_bar, iter_num, result=iter_num, tap_with_device=True - ), - lambda _: iter_num, + lambda _: io_callback(_define_bar, None, iter_num), + lambda _: None, operand=None, ) _ = lax.cond( # update every multiple of `print_rate` except at the end - (iter_num % print_rate == 0), - lambda _: host_callback.id_tap( - _update_bar, iter_num, result=iter_num, tap_with_device=True - ), - lambda _: iter_num, + (iter_num % print_rate == 0) | (iter_num == (num_samples - 1)), + lambda _: io_callback(_update_bar, None, iter_num), + lambda _: None, operand=None, ) _ = lax.cond( - # update by `remainder` - iter_num == num_samples - 1, - lambda _: host_callback.id_tap( - _update_bar, num_samples, result=iter_num, tap_with_device=True - ), - lambda _: iter_num, - operand=None, - ) - - def _close_bar(arg, transform, device): - progress_bars[0].on_iter_end() - print() - - def close_bar(result, iter_num): - return lax.cond( iter_num == num_samples - 1, - lambda _: host_callback.id_tap( - _close_bar, None, result=result, tap_with_device=True - ), - lambda _: result, + lambda _: io_callback(_close_bar, None, None), + lambda _: None, operand=None, ) @@ -94,8 +78,7 @@ def wrapper_progress_bar(carry, x): else: iter_num = x _update_progress_bar(iter_num) - result = func(carry, x) - return close_bar(result, iter_num) + return func(carry, x) return wrapper_progress_bar