Skip to content

Commit 61613a7

Browse files
committed
Update progressbar managers with existing fit results from ZarrTrace
1 parent 0ab1596 commit 61613a7

File tree

4 files changed

+34
-6
lines changed

4 files changed

+34
-6
lines changed

pymc/sampling/mcmc.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1157,13 +1157,24 @@ def _sample_many(
11571157

11581158
with progress_manager:
11591159
for i in range(chains):
1160+
trace = traces[i]
1161+
if isinstance(trace, ZarrChain):
1162+
progress_manager.set_initial_state(*trace.completed_draws_and_divergences())
1163+
progress_manager._progress.update(
1164+
progress_manager.tasks[i],
1165+
draws=progress_manager.completed_draws
1166+
if progress_manager.combined_progress
1167+
else progress_manager.draws,
1168+
divergences=progress_manager.divergences,
1169+
refresh=True,
1170+
)
11601171
step.sampling_state = initial_step_state
11611172
_sample(
11621173
draws=draws,
11631174
chain=i,
11641175
start=start[i],
11651176
step=step,
1166-
trace=traces[i],
1177+
trace=trace,
11671178
rng=rngs[i],
11681179
callback=callback,
11691180
progress_manager=progress_manager,

pymc/sampling/parallel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,10 @@ def __init__(
509509
progressbar=progressbar,
510510
progressbar_theme=progressbar_theme,
511511
)
512+
if self.zarr_recording:
513+
self._progress.set_initial_state(
514+
*cast(ZarrChain, zarr_chains)[0].completed_draws_and_divergences()
515+
)
512516

513517
def _make_active(self):
514518
while self._inactive and len(self._active) < self._max_active:

pymc/sampling/population.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def _sample_population(
110110

111111
with CustomProgress(disable=not progressbar) as progress:
112112
task = progress.add_task("[red]Sampling...", total=draws)
113+
if isinstance(traces[0], ZarrChain):
114+
completed_draws, _ = traces[0].completed_draws_and_divergences()
115+
progress.update(task, completed=completed_draws)
113116
for _ in sampling:
114117
progress.update(task)
115118

@@ -197,6 +200,7 @@ def __init__(
197200
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
198201
# ):
199202
task = self._progress.add_task(description=f"Chain {c}")
203+
self._progress.update(task, completed=first_draw_idx)
200204
secondary_end, primary_end = multiprocessing.Pipe()
201205
stepper_dumps = cloudpickle.dumps(stepper, protocol=4)
202206
process = multiprocessing.Process(

pymc/util.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,7 @@ def __init__(
812812

813813
self._show_progress = show_progress
814814
self.divergences = 0
815+
self.draws = 0
815816
self.completed_draws = 0
816817
self.total_draws = draws + tune
817818
self.desc = "Sampling chain"
@@ -827,27 +828,35 @@ def __enter__(self):
827828
def __exit__(self, exc_type, exc_val, exc_tb):
828829
return self._progress.__exit__(exc_type, exc_val, exc_tb)
829830

831+
def set_initial_state(self, draws: int = 0, divergences: int = 0):
832+
self.draws = draws
833+
self.completed_draws += draws
834+
self.divergences = divergences
835+
830836
def _initialize_tasks(self):
831837
if self.combined_progress:
832838
self.tasks = [
833839
self._progress.add_task(
834840
self.desc.format(self),
835-
completed=0,
836-
draws=0,
841+
completed=self.completed_draws,
842+
draws=self.completed_draws,
837843
total=self.total_draws * self.chains - 1,
838844
chain_idx=0,
839845
sampling_speed=0,
840846
speed_unit="draws/s",
841-
**{stat: value[0] for stat, value in self.progress_stats.items()},
847+
**{
848+
stat: value[0] if stat != "diverging" else self.divergences
849+
for stat, value in self.progress_stats.items()
850+
},
842851
)
843852
]
844853

845854
else:
846855
self.tasks = [
847856
self._progress.add_task(
848857
self.desc.format(self),
849-
completed=0,
850-
draws=0,
858+
completed=self.completed_draws,
859+
draws=self.draws,
851860
total=self.total_draws - 1,
852861
chain_idx=chain_idx,
853862
sampling_speed=0,

0 commit comments

Comments
 (0)