Skip to content

Commit

Permalink
Remove proposed sample extraction for HMC altogether
Browse files Browse the repository at this point in the history
  • Loading branch information
Micki-D committed Feb 14, 2025
1 parent d949377 commit 0b5e29f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 85 deletions.
90 changes: 5 additions & 85 deletions ext/ahmc_impl/ahmc_sampler_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,11 @@ function BAT.mcmc_propose!!(mc_state::HMCState)
z_phase = AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic))
# Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics.

proposal.transition, z_proposed_hmc, p_accept = _bat_transition(rng, τ, hamiltonian, z_phase)
accepted = z_current[:] != proposal.transition.z.θ
z_proposed[:] = accepted ? proposal.transition.z.θ : z_proposed_hmc
proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase)
p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate

z_proposed[:] = proposal.transition.z.θ
accepted = z_current[:] != z_proposed[:]

p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate

Expand Down Expand Up @@ -176,85 +178,3 @@ function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Funct
mc_state_new = @set mc_state_new.f_transform = f_transform_new
return mc_state_new
end


# Copied from AdvancedHMC.jl, but also return proposed point
function _bat_transition(
rng::AbstractRNG,
τ::AdvancedHMC.Trajectory{TS,I,TC},
h::AdvancedHMC.Hamiltonian,
z0::AdvancedHMC.PhasePoint,
) where {
TS<:AdvancedHMC.AbstractTrajectorySampler,
I<:AdvancedHMC.AbstractIntegrator,
TC<:AdvancedHMC.DynamicTerminationCriterion,
}
H0 = AdvancedHMC.energy(z0)
tree = AdvancedHMC.BinaryTree(
z0,
z0,
AdvancedHMC.TurnStatistic.termination_criterion, z0),
zero(H0),
zero(Int),
zero(H0),
)
sampler = TS(rng, z0)
termination = AdvancedHMC.Termination(false, false)
zcand = z0
proposed_zs = Vector[]
accept_probs = Float64[]

j = 0
while !AdvancedHMC.isterminated(termination) && j < τ.termination_criterion.max_depth
v = rand(rng, [-1, 1])
if v == -1
tree′, sampler′, termination′ =
AdvancedHMC.build_tree(rng, τ, h, tree.zleft, sampler, v, j, H0)
treeleft, treeright = tree′, tree
else
tree′, sampler′, termination′ =
AdvancedHMC.build_tree(rng, τ, h, tree.zright, sampler, v, j, H0)
treeleft, treeright = tree, tree′
end

# This acceptance prob. is specific to AdvancedHMC.MultinomialTS
p_tmp = min(1, exp(sampler′.ℓw - sampler.ℓw))
push!(accept_probs, p_tmp)
push!(proposed_zs, sampler′.zcand.θ)

if !AdvancedHMC.isterminated(termination′)
j = j + 1
if AdvancedHMC.mh_accept(rng, sampler, sampler′)
zcand = sampler′.zcand
end
end
tree = AdvancedHMC.combine(treeleft, treeright)
sampler = AdvancedHMC.combine(zcand, sampler, sampler′)
termination =
termination *
termination′ *
AdvancedHMC.isterminated.termination_criterion, h, tree, treeleft, treeright)
end

H = AdvancedHMC.energy(zcand)
tstat = AdvancedHMC.merge(
(
n_steps = tree.nα,
is_accept = true,
acceptance_rate = tree.sum_α / tree.nα,
log_density = zcand.ℓπ.value,
hamiltonian_energy = H,
hamiltonian_energy_error = H - H0,
max_hamiltonian_energy_error = tree.ΔH_max,
tree_depth = j,
numerical_error = termination.numerical,
),
AdvancedHMC.stat.integrator),
)

accept_total = sum(accept_probs)
z_proposed = iszero(accept_total) ? sum(proposed_zs) / length(proposed_zs) : sum(accept_probs .* proposed_zs) / accept_total
p_accept = tstat.acceptance_rate

return AdvancedHMC.Transition(zcand, tstat), z_proposed, p_accept
end
2 changes: 2 additions & 0 deletions src/samplers/mcmc/mcmc_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ function mcmc_step!!(mcmc_state::MCMCState)

current = _current_sample_idx(chain_state)
proposed = _proposed_sample_idx(chain_state)

# This does not change `sample_z` in the chain_state, that happens in the next mcmc step in `_cleanup_samples()`.
_accept_reject!(chain_state, accepted, p_accept, current, proposed)

mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept)
Expand Down
5 changes: 5 additions & 0 deletions src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ function mcmc_tune_post_step!!(
mc_state::MCMCChainState,
p_accept::Real,
)

if current_sample_z(mc_state).v == proposed_sample_z(mc_state)
return mc_state, tuner_state

Check warning on line 81 in src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl#L81

Added line #L81 was not covered by tests
end

(; f_transform, sample_z) = mc_state
(; target_acceptance, gamma) = tuner_state.tuning
b = f_transform.b
Expand Down

0 comments on commit 0b5e29f

Please sign in to comment.