Skip to content

Commit

Permalink
Change proposed sample in hmc to weighted mean, Fix weight assignment…
Browse files Browse the repository at this point in the history
… error in mcmc_stepgit add -A ()
  • Loading branch information
Micki-D committed Feb 14, 2025
1 parent a52127d commit d949377
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
16 changes: 12 additions & 4 deletions ext/ahmc_impl/ahmc_sampler_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,12 @@ function BAT.mcmc_propose!!(mc_state::HMCState)

p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate

x_proposed[:] = f_transform(z_proposed)
x_proposed[:], ladj = with_logabsdet_jacobian(f_transform, z_proposed)
logd_x_proposed = logdensityof(target, x_proposed)
samples.logd[proposed_x_idx] = logd_x_proposed

sample_z.logd[proposed_z_idx] = logd_x_proposed + ladj

return mc_state, accepted, p_accept
end

Expand Down Expand Up @@ -200,6 +202,7 @@ function _bat_transition(
termination = AdvancedHMC.Termination(false, false)
zcand = z0
proposed_zs = Vector[]
accept_probs = Float64[]

j = 0
while !AdvancedHMC.isterminated(termination) && j < τ.termination_criterion.max_depth
Expand All @@ -213,14 +216,18 @@ function _bat_transition(
AdvancedHMC.build_tree(rng, τ, h, tree.zright, sampler, v, j, H0)
treeleft, treeright = tree, tree′
end

# This acceptance prob. is specific to AdvancedHMC.MultinomialTS
p_tmp = min(1, exp(sampler′.ℓw - sampler.ℓw))
push!(accept_probs, p_tmp)
push!(proposed_zs, sampler′.zcand.θ)

if !AdvancedHMC.isterminated(termination′)
j = j + 1
if AdvancedHMC.mh_accept(rng, sampler, sampler′)
zcand = sampler′.zcand
end
end
push!(proposed_zs, sampler′.zcand.θ)

tree = AdvancedHMC.combine(treeleft, treeright)
sampler = AdvancedHMC.combine(zcand, sampler, sampler′)
termination =
Expand All @@ -245,7 +252,8 @@ function _bat_transition(
AdvancedHMC.stat.integrator),
)

z_proposed = proposed_zs[end]
accept_total = sum(accept_probs)
z_proposed = iszero(accept_total) ? sum(proposed_zs) / length(proposed_zs) : sum(accept_probs .* proposed_zs) / accept_total
p_accept = tstat.acceptance_rate

return AdvancedHMC.Transition(zcand, tstat), z_proposed, p_accept
Expand Down
8 changes: 3 additions & 5 deletions src/samplers/mcmc/mcmc_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,13 @@ function mcmc_step!!(mcmc_state::MCMCState)

chain_state, accepted, p_accept = mcmc_propose!!(chain_state)

mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept)

chain_state = mcmc_state_new.chain_state

current = _current_sample_idx(chain_state)
proposed = _proposed_sample_idx(chain_state)

_accept_reject!(chain_state, accepted, p_accept, current, proposed)

mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept)

chain_state = mcmc_state_new.chain_state
mcmc_state_final = @set mcmc_state_new.chain_state = chain_state

return mcmc_state_final
Expand Down
1 change: 0 additions & 1 deletion test/samplers/mcmc/test_hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import AdvancedHMC
@testset "MCMC iteration" begin
v_init = bat_initval(target, InitFromTarget(), context).result
# Note: No @inferred, since MCMCChainState is not type stable (yet) with HamiltonianMC
# TODO: MD, reactivate
@test BAT.MCMCChainState(samplingalg, target, 1, unshaped(v_init, varshape(target)), deepcopy(context)) isa BAT.HMCState
mcmc_state = BAT.MCMCState(samplingalg, target, 1, unshaped(v_init, varshape(target)), deepcopy(context))
nsteps = 10^4
Expand Down

0 comments on commit d949377

Please sign in to comment.