From 32d04462df79deabdb0a57b5682261c7945254f5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Jan 2025 13:10:31 +0000 Subject: [PATCH] Initial `AbstractMCMC.step` should not sample (#366) * initial step should just construct the initial transitio and step; it should not perform any sampling of the parameters. * bump patch version * Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed issue in callback for first iteration * attempt at fixing callback issues --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/abstractmcmc.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 9403bffe..0e2cdb63 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -165,8 +165,8 @@ function AbstractMCMC.step( # Compute next transition and state. state = HMCState(0, t, metric, κ, adaptor) - # Take actual first step. - return AbstractMCMC.step(rng, model, spl, state; kwargs...) + # Return the initial transition and state. + return Transition(t.z, merge(stat(t), (is_adapt = false,))), state end function AbstractMCMC.step( @@ -260,10 +260,13 @@ function (cb::HMCProgressCallback)( κ = state.κ tstat = t.stat isadapted = tstat.is_adapt - if isadapted - cb.num_divergent_transitions_during_adaption[] += tstat.numerical_error - else - cb.num_divergent_transitions[] += tstat.numerical_error + # The initial transition will not much information beyond the `is_adapt` field. + if haskey(tstat, :numerical_error) + if isadapted + cb.num_divergent_transitions_during_adaption[] += tstat.numerical_error + else + cb.num_divergent_transitions[] += tstat.numerical_error + end end # Update progress meter