Skip to content

Commit

Permalink
Migrate from deprecated host_callback to io_callback (#651)
Browse files Browse the repository at this point in the history
* Migrate from deprecated `host_callback` to `io_callback`

Co-Authored-By:
George Necula <[email protected]>

* Format file

* Fix bug
  • Loading branch information
junpenglao authored Mar 28, 2024
1 parent 3dc3809 commit 2e25624
Showing 1 changed file with 17 additions and 34 deletions.
51 changes: 17 additions & 34 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
from fastprogress.fastprogress import progress_bar
from jax import lax
from jax.experimental import host_callback
from jax.experimental import io_callback


def progress_bar_scan(num_samples, print_rate=None):
Expand All @@ -29,55 +29,39 @@ def progress_bar_scan(num_samples, print_rate=None):
else:
print_rate = 1 # if you run the sampler for less than 20 iterations

def _define_bar(arg, transform, device):
def _define_bar(arg):
del arg
progress_bars[0] = progress_bar(range(num_samples))
progress_bars[0].update(0)

def _update_bar(arg, transform, device):
progress_bars[0].update_bar(arg)
def _update_bar(arg):
progress_bars[0].update_bar(arg + 1)

def _close_bar(arg):
del arg
progress_bars[0].on_iter_end()

def _update_progress_bar(iter_num):
"Updates progress bar of a JAX scan or loop"
_ = lax.cond(
iter_num == 0,
lambda _: host_callback.id_tap(
_define_bar, iter_num, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_define_bar, None, iter_num),
lambda _: None,
operand=None,
)

_ = lax.cond(
# update every multiple of `print_rate` except at the end
(iter_num % print_rate == 0),
lambda _: host_callback.id_tap(
_update_bar, iter_num, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
(iter_num % print_rate == 0) | (iter_num == (num_samples - 1)),
lambda _: io_callback(_update_bar, None, iter_num),
lambda _: None,
operand=None,
)

_ = lax.cond(
# update by `remainder`
iter_num == num_samples - 1,
lambda _: host_callback.id_tap(
_update_bar, num_samples, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
operand=None,
)

def _close_bar(arg, transform, device):
progress_bars[0].on_iter_end()
print()

def close_bar(result, iter_num):
return lax.cond(
iter_num == num_samples - 1,
lambda _: host_callback.id_tap(
_close_bar, None, result=result, tap_with_device=True
),
lambda _: result,
lambda _: io_callback(_close_bar, None, None),
lambda _: None,
operand=None,
)

Expand All @@ -94,8 +78,7 @@ def wrapper_progress_bar(carry, x):
else:
iter_num = x
_update_progress_bar(iter_num)
result = func(carry, x)
return close_bar(result, iter_num)
return func(carry, x)

return wrapper_progress_bar

Expand Down

0 comments on commit 2e25624

Please sign in to comment.