diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index ec2f9a355..808eba975 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -119,9 +119,11 @@ function BAT.mcmc_propose!!(mc_state::HMCState) z_phase = AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic)) # Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics. - proposal.transition, z_proposed_hmc, p_accept = _bat_transition(rng, τ, hamiltonian, z_phase) - accepted = z_current[:] != proposal.transition.z.θ - z_proposed[:] = accepted ? proposal.transition.z.θ : z_proposed_hmc + proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase) + p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate + + z_proposed[:] = proposal.transition.z.θ + accepted = z_current[:] != z_proposed[:] p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate @@ -176,85 +178,3 @@ function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Funct mc_state_new = @set mc_state_new.f_transform = f_transform_new return mc_state_new end - - -# Copied from AdvancedHMC.jl, but also return proposed point -function _bat_transition( - rng::AbstractRNG, - τ::AdvancedHMC.Trajectory{TS,I,TC}, - h::AdvancedHMC.Hamiltonian, - z0::AdvancedHMC.PhasePoint, -) where { - TS<:AdvancedHMC.AbstractTrajectorySampler, - I<:AdvancedHMC.AbstractIntegrator, - TC<:AdvancedHMC.DynamicTerminationCriterion, -} - H0 = AdvancedHMC.energy(z0) - tree = AdvancedHMC.BinaryTree( - z0, - z0, - AdvancedHMC.TurnStatistic(τ.termination_criterion, z0), - zero(H0), - zero(Int), - zero(H0), - ) - sampler = TS(rng, z0) - termination = AdvancedHMC.Termination(false, false) - zcand = z0 - proposed_zs = Vector[] - accept_probs = Float64[] - - j = 0 - while !AdvancedHMC.isterminated(termination) && j < τ.termination_criterion.max_depth - v = rand(rng, [-1, 1]) - if v == -1 - tree′, sampler′, termination′ = - AdvancedHMC.build_tree(rng, τ, h, tree.zleft, sampler, v, j, H0) - treeleft, treeright = tree′, tree - else - tree′, sampler′, termination′ = - AdvancedHMC.build_tree(rng, τ, h, tree.zright, sampler, v, j, H0) - treeleft, treeright = tree, tree′ - end - - # This acceptance prob. is specific to AdvancedHMC.MultinomialTS - p_tmp = min(1, exp(sampler′.ℓw - sampler.ℓw)) - push!(accept_probs, p_tmp) - push!(proposed_zs, sampler′.zcand.θ) - - if !AdvancedHMC.isterminated(termination′) - j = j + 1 - if AdvancedHMC.mh_accept(rng, sampler, sampler′) - zcand = sampler′.zcand - end - end - tree = AdvancedHMC.combine(treeleft, treeright) - sampler = AdvancedHMC.combine(zcand, sampler, sampler′) - termination = - termination * - termination′ * - AdvancedHMC.isterminated(τ.termination_criterion, h, tree, treeleft, treeright) - end - - H = AdvancedHMC.energy(zcand) - tstat = AdvancedHMC.merge( - ( - n_steps = tree.nα, - is_accept = true, - acceptance_rate = tree.sum_α / tree.nα, - log_density = zcand.ℓπ.value, - hamiltonian_energy = H, - hamiltonian_energy_error = H - H0, - max_hamiltonian_energy_error = tree.ΔH_max, - tree_depth = j, - numerical_error = termination.numerical, - ), - AdvancedHMC.stat(τ.integrator), - ) - - accept_total = sum(accept_probs) - z_proposed = iszero(accept_total) ? sum(proposed_zs) / length(proposed_zs) : sum(accept_probs .* proposed_zs) / accept_total - p_accept = tstat.acceptance_rate - - return AdvancedHMC.Transition(zcand, tstat), z_proposed, p_accept -end diff --git a/src/samplers/mcmc/mcmc_state.jl b/src/samplers/mcmc/mcmc_state.jl index 0f0ef0c8c..1b51a3743 100644 --- a/src/samplers/mcmc/mcmc_state.jl +++ b/src/samplers/mcmc/mcmc_state.jl @@ -161,6 +161,8 @@ function mcmc_step!!(mcmc_state::MCMCState) current = _current_sample_idx(chain_state) proposed = _proposed_sample_idx(chain_state) + + # This does not change `sample_z` in the chain_state, that happens in the next mcmc step in `_cleanup_samples()`. _accept_reject!(chain_state, accepted, p_accept, current, proposed) mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept) diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl index cf19a602a..1c1922f60 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -76,6 +76,11 @@ function mcmc_tune_post_step!!( mc_state::MCMCChainState, p_accept::Real, ) + + if current_sample_z(mc_state).v == proposed_sample_z(mc_state) + return mc_state, tuner_state + end + (; f_transform, sample_z) = mc_state (; target_acceptance, gamma) = tuner_state.tuning b = f_transform.b