Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hmc proposed sample handling refactor #468

Merged
merged 11 commits into from
Feb 17, 2025
30 changes: 22 additions & 8 deletions ext/BATAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,51 @@ using Random
using DensityInterface
using HeterogeneousComputing, AutoDiffOperators

using Accessors: @set, @reset

using AffineMaps: MulAdd

using BAT: MeasureLike, BATMeasure

using BAT: get_context, get_adselector, _NoADSelected
using BAT: getproposal, mcmc_target
using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering
using BAT: _current_sample_idx, _proposed_sample_idx, _cleanup_samples
using BAT: AbstractTransformTarget, NoAdaptiveTransform
using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin
using BAT: MCMCBasicStats, push!, reweight_relative!
using BAT: RAMTuning
using BAT: MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering, NoMCMCTransformTuning
using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples, current_sample_z, proposed_sample_z, proposed_sample
using BAT: AbstractTransformTarget, NoAdaptiveTransform, TriangularAffineTransform, valgrad_func
using BAT: RNGPartition, get_rng, set_rng!
using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio
using BAT: get_samples!, reset_rng_counters!
using BAT: create_trafo_tuner_state, create_proposal_tuner_state, mcmc_tuning_init!!, mcmc_tuning_postinit!!, mcmc_tuning_reinit!!, mcmc_tune_transform_post_cycle!!, transform_mcmc_tuning_finalize!!
using BAT: create_trafo_tuner_state, create_proposal_tuner_state, mcmc_tuning_init!!, mcmc_tuning_postinit!!, mcmc_tuning_reinit!!, mcmc_tune_transform_post_cycle!!, transform_mcmc_tuning_finalize!!, set_mc_state_transform!!, mcmc_update_z_position!!
using BAT: totalndof, measure_support, checked_logdensityof
using BAT: CURRENT_SAMPLE, PROPOSED_SAMPLE, INVALID_SAMPLE, ACCEPTED_SAMPLE, REJECTED_SAMPLE

using BAT: HamiltonianMC
using BAT: AHMCSampleID, AHMCSampleIDVector
using BAT: HMCMetric, DiagEuclideanMetric, UnitEuclideanMetric, DenseEuclideanMetric
using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanHMCTuning
using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanLikeTuning

using ValueShapes: varshape
using ChangesOfVariables: with_logabsdet_jacobian

using LinearAlgebra: cholesky

using Accessors: @set
using MeasureBase: pullback

using Parameters: @with_kw

using PositiveFactorizations: Positive

using ValueShapes: varshape

# Default integrator for the AdvancedHMC extension. The NaN step size is a
# placeholder: `_ahmc_set_step_size` checks `isnan(integrator.ϵ)` and, if so,
# replaces it with the result of `AdvancedHMC.find_good_stepsize`.
BAT.ext_default(::BAT.PackageExtension{:AdvancedHMC}, ::Val{:DEFAULT_INTEGRATOR}) = AdvancedHMC.Leapfrog(NaN)
# Default termination criterion for the AdvancedHMC extension.
BAT.ext_default(::BAT.PackageExtension{:AdvancedHMC}, ::Val{:DEFAULT_TERMINATION_CRITERION}) = AdvancedHMC.GeneralisedNoUTurn()


include("ahmc_impl/ahmc_stan_tuner_impl.jl")
include("ahmc_impl/ahmc_config_impl.jl")
include("ahmc_impl/ahmc_sampler_impl.jl")
include("ahmc_impl/ahmc_tuner_impl.jl")


end # module BATAdvancedHMCExt
8 changes: 4 additions & 4 deletions ext/ahmc_impl/ahmc_config_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

# Integrator ==============================================

function _ahmc_set_step_size(integrator::AdvancedHMC.AbstractIntegrator, hamiltonian::AdvancedHMC.Hamiltonian, θ_init::AbstractVector{<:Real})
function _ahmc_set_step_size(integrator::AdvancedHMC.AbstractIntegrator, hamiltonian::AdvancedHMC.Hamiltonian, θ_init::AbstractVector{<:Real}, rng::AbstractRNG)
# ToDo: Add way to specify max_n_iters
T = eltype(θ_init)
step_size = integrator.ϵ
if isnan(step_size)
new_step_size = AdvancedHMC.find_good_stepsize(hamiltonian, θ_init, max_n_iters = 100)
new_step_size = AdvancedHMC.find_good_stepsize(rng, hamiltonian, θ_init, max_n_iters = 100)
@set integrator.ϵ = T(new_step_size)
else
@set integrator.ϵ = T(step_size)
Expand Down Expand Up @@ -57,7 +57,7 @@ function ahmc_adaptor(
θ_init::AbstractVector{<:Real}
)
T = eltype(θ_init)
return AdvancedHMC.StepSizeAdaptor(tuning.target_acceptance, integrator)
return AdvancedHMC.StepSizeAdaptor(T(tuning.target_acceptance), integrator)
end

function ahmc_adaptor(
Expand All @@ -73,7 +73,7 @@ function ahmc_adaptor(
end

function ahmc_adaptor(
tuning::StanHMCTuning,
tuning::StanLikeTuning,
metric::AdvancedHMC.AbstractMetric,
integrator::AdvancedHMC.AbstractIntegrator,
θ_init::AbstractVector{<:Real}
Expand Down
78 changes: 54 additions & 24 deletions ext/ahmc_impl/ahmc_sampler_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

BAT.bat_default(::Type{TransformedMCMC}, ::Val{:pretransform}, proposal::HamiltonianMC) = PriorToNormal()

BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = StanHMCTuning()
BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = StepSizeAdaptor()

BAT.bat_default(::Type{TransformedMCMC}, ::Val{:adaptive_transform}, proposal::HamiltonianMC) = NoAdaptiveTransform()
BAT.bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, proposal::HamiltonianMC) = RAMTuning()

BAT.bat_default(::Type{TransformedMCMC}, ::Val{:adaptive_transform}, proposal::HamiltonianMC) = TriangularAffineTransform()

BAT.bat_default(::Type{TransformedMCMC}, ::Val{:tempering}, proposal::HamiltonianMC) = NoMCMCTempering()

Expand All @@ -17,12 +19,13 @@ BAT.bat_default(::Type{TransformedMCMC}, ::Val{:init}, proposal::HamiltonianMC,
# Default burn-in strategy for HamiltonianMC proposals: multi-cycle burn-in with
# at most 4 cycles, each cycle running a tenth of `nsteps` but never fewer than
# 250 steps.
function BAT.bat_default(
    ::Type{TransformedMCMC},
    ::Val{:burnin},
    proposal::HamiltonianMC,
    pretransform::AbstractTransformTarget,
    nchains::Integer,
    nsteps::Integer
)
    steps_per_cycle = max(div(nsteps, 10), 250)
    return MCMCMultiCycleBurnin(nsteps_per_cycle = steps_per_cycle, max_ncycles = 4)
end


# The initial adaptive transform is incorporated into f and fg by pulling the target back through f_transform
function BAT._create_proposal_state(
proposal::HamiltonianMC,
target::BATMeasure,
context::BATContext,
v_init::AbstractVector{P},
f_transform::Function,
rng::AbstractRNG
) where {P<:Real}
vs = varshape(target)
Expand All @@ -32,13 +35,13 @@ function BAT._create_proposal_state(
params_vec .= v_init

adsel = get_adselector(context)
f = checked_logdensityof(target)
f = checked_logdensityof(pullback(f_transform, target))
fg = valgrad_func(f, adsel)

metric = ahmc_metric(proposal.metric, params_vec)
init_hamiltonian = AdvancedHMC.Hamiltonian(metric, f, fg)
hamiltonian, init_transition = AdvancedHMC.sample_init(rng, init_hamiltonian, params_vec)
integrator = _ahmc_set_step_size(proposal.integrator, hamiltonian, params_vec)
integrator = _ahmc_set_step_size(proposal.integrator, hamiltonian, params_vec, rng)
termination = _ahmc_convert_termination(proposal.termination, params_vec)
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, termination))

Expand Down Expand Up @@ -85,51 +88,63 @@ function BAT.next_cycle!(mc_state::HMCState)
mc_state
end

# TODO: MD, should this be a !! function?
function BAT.mcmc_propose!!(mc_state::HMCState)
# @unpack target, proposal, f_transform, samples, context = mc_state
target = mc_state.target
proposal = mc_state.proposal
f_transform = mc_state.f_transform
samples = mc_state.samples
sample_z = mc_state.sample_z
context = mc_state.context

rng = get_rng(context)

current = _current_sample_idx(mc_state)
proposed = _proposed_sample_idx(mc_state)

x_current = samples.v[current]
x_proposed = samples.v[proposed]
current_log_posterior = samples.logd[current]
current_z_idx = _current_sample_z_idx(mc_state)
proposed_z_idx = _proposed_sample_z_idx(mc_state)

proposed_x_idx = _proposed_sample_idx(mc_state)

proposal.transition = AdvancedHMC.transition(rng, proposal.hamiltonian, proposal.kernel, proposal.transition.z)
x_proposed[:] = proposal.transition.z.θ
# location in normalized (or generally transformed) space ("z-space")
z_current = sample_z.v[current_z_idx]
z_proposed = sample_z.v[proposed_z_idx]

proposed_log_posterior = logdensityof(target, x_proposed)
# location in target space ("x-space") which is generally pre-transformed
x_proposed = samples.v[proposed_x_idx]

samples.logd[proposed] = proposed_log_posterior
τ = deepcopy(proposal.kernel.τ)
@reset τ.integrator = AdvancedHMC.jitter(rng, τ.integrator)

hamiltonian = proposal.hamiltonian

# Current location in the phase space for Hamiltonian MonteCarlo
z_phase = AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic))
# Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics.

accepted = x_current != x_proposed
proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase)
p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate

# TODO: Setting p_accept to 1 or 0 for now.
# Use AdvancedHMC.stat(transition).acceptance_rate in the future?
p_accept = Float64(accepted)
z_proposed[:] = proposal.transition.z.θ
accepted = z_current[:] != z_proposed[:]

p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate

x_proposed[:], ladj = with_logabsdet_jacobian(f_transform, z_proposed)
logd_x_proposed = logdensityof(target, x_proposed)
samples.logd[proposed_x_idx] = logd_x_proposed

sample_z.logd[proposed_z_idx] = logd_x_proposed + ladj

return mc_state, accepted, p_accept
end

function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float64, current::Integer, proposed::Integer)
# @unpack samples, proposal = mc_state
samples = mc_state.samples
proposal = mc_state.proposal

if accepted
samples.info.sampletype[current] = ACCEPTED_SAMPLE
samples.info.sampletype[proposed] = CURRENT_SAMPLE
mc_state.nsamples += 1

tstat = AdvancedHMC.stat(proposal.transition)
samples.info.hamiltonian_energy[proposed] = tstat.hamiltonian_energy
# ToDo: Handle proposal-dependent tstat (only NUTS has tree_depth):
Expand All @@ -146,5 +161,20 @@ function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float
samples.weight[proposed] = w_proposed
end


# Effective acceptance ratio of an HMC chain state: `nsamples` divided by `nsteps`.
function BAT.eff_acceptance_ratio(mc_state::HMCState)
    return nsamples(mc_state) / nsteps(mc_state)
end

# Return a copy of `mc_state` that uses `f_transform_new` as its transform.
#
# The Hamiltonian of the proposal is rebuilt so that its log-density (`ℓπ`) and
# its value-and-gradient function (`∂ℓπ∂θ`) refer to the target pulled back
# through `f_transform_new`. The input `mc_state` itself is not modified; the
# updated state is returned.
function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Function)
    ad_selector = get_adselector(mc_state.context)

    # Log-density in the transformed space and its value-and-gradient function
    logd_func = checked_logdensityof(pullback(f_transform_new, mc_state.target))
    logd_valgrad = valgrad_func(logd_func, ad_selector)

    # Swap both density functions inside the existing Hamiltonian
    hamiltonian_old = mc_state.proposal.hamiltonian
    hamiltonian_logd = @set hamiltonian_old.ℓπ = logd_func
    hamiltonian_new = @set hamiltonian_logd.∂ℓπ∂θ = logd_valgrad

    # Install the new Hamiltonian first, then the new transform
    state_with_hamiltonian = @set mc_state.proposal.hamiltonian = hamiltonian_new
    state_final = @set state_with_hamiltonian.f_transform = f_transform_new
    return state_final
end
79 changes: 79 additions & 0 deletions ext/ahmc_impl/ahmc_stan_tuner_impl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).

# Tuner state for `StanLikeTuning`: combines the tuning configuration, running
# sample statistics and the AdvancedHMC Stan adaptor state.
mutable struct StanLikeTunerState{S<:MCMCBasicStats} <: MCMCTransformTunerState
    # Tuning configuration this state was created from
    tuning::StanLikeTuning
    # Target acceptance rate used to decide whether the chain counts as tuned
    target_acceptance::Float64
    # Running sample statistics; their parameter covariance is used to build
    # a new transform at the end of each adaptation window
    stats::S
    # Step counter and adaptation window bookkeeping of the Stan adaptor
    stan_state::AdvancedHMC.Adaptation.StanHMCAdaptorState
end

# Create the transform tuner state for `StanLikeTuning`.
#
# The statistics are initialized from `chain_state` and the Stan adaptor state
# starts out default-constructed. `n_steps_hint` is not used here.
function BAT.create_trafo_tuner_state(tuning::StanLikeTuning, chain_state::MCMCChainState, n_steps_hint::Integer)
    basic_stats = MCMCBasicStats(chain_state)
    adaptor_state = AdvancedHMC.Adaptation.StanHMCAdaptorState()
    return StanLikeTunerState(tuning, tuning.target_acceptance, basic_stats, adaptor_state)
end

# Initialize the Stan adaptor state of `tuner` for a tuning run of at most
# `max_nsteps` steps, using the buffer and window sizes from the tuning
# configuration. The adaptor is initialized for `max_nsteps - 1` adaptation
# steps. `chain_state` is not used. Returns `nothing`.
function BAT.mcmc_tuning_init!!(tuner::StanLikeTunerState, chain_state::HMCState, max_nsteps::Integer)
    cfg = tuner.tuning
    n_adapts = Int(max_nsteps - 1)
    AdvancedHMC.Adaptation.initialize!(
        tuner.stan_state, cfg.init_buffer, cfg.term_buffer, cfg.window_size, n_adapts
    )
    return nothing
end

# Re-initialize the Stan adaptor state of `tuner` for a new tuning run of at
# most `max_nsteps` steps. Performs the same adaptor initialization as
# `mcmc_tuning_init!!`: buffer and window sizes come from the tuning
# configuration, with `max_nsteps - 1` adaptation steps. `chain_state` is not
# used. Returns `nothing`.
function BAT.mcmc_tuning_reinit!!(tuner::StanLikeTunerState, chain_state::HMCState, max_nsteps::Integer)
    cfg = tuner.tuning
    n_adapts = Int(max_nsteps - 1)
    AdvancedHMC.Adaptation.initialize!(
        tuner.stan_state, cfg.init_buffer, cfg.term_buffer, cfg.window_size, n_adapts
    )
    return nothing
end

# No post-initialization work is needed for the Stan-like tuner.
function BAT.mcmc_tuning_postinit!!(tuner::StanLikeTunerState, chain_state::HMCState, samples::DensitySampleVector)
    return nothing
end


# Decide after a tuning cycle whether the chain counts as tuned.
#
# The chain is marked as tuned when its effective acceptance ratio reaches at
# least 90% of the tuner's target acceptance. The result is stored in
# `chain_state.info` and reported via `@debug`.
#
# Arguments:
# - `tuner`: Stan-like tuner state providing `target_acceptance`.
# - `chain_state`: HMC chain state; its `info` field is updated in place.
# - `samples`: samples of the cycle, only used to report the maximum
#   log posterior in the debug message.
#
# Returns the tuple `(chain_state, tuner)`.
function BAT.mcmc_tune_post_cycle!!(tuner::StanLikeTunerState, chain_state::HMCState, samples::DensitySampleVector)
    max_log_posterior = maximum(samples.logd)
    accept_ratio = eff_acceptance_ratio(chain_state)

    is_tuned = accept_ratio >= 0.9 * tuner.target_acceptance
    chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = is_tuned)

    # The trajectory `τ` (and with it the integrator) is a field of the HMC
    # kernel, so it has to be reached via `proposal.kernel`.
    status = is_tuned ? "tuned" : "*not* tuned"
    @debug "MCMC chain $(chain_state.info.id) $(status), acceptance ratio = $(Float32(accept_ratio)), integrator = $(chain_state.proposal.kernel.τ.integrator), max. log posterior = $(Float32(max_log_posterior))"

    return chain_state, tuner
end


# Nothing needs to be finalized for the Stan-like tuner.
function BAT.mcmc_tuning_finalize!!(tuner::StanLikeTunerState, chain_state::HMCState)
    return nothing
end


# Per-step update of the Stan-like transform tuner.
#
# Advances the adaptor step counter, collects the proposed sample into the
# running statistics while the counter lies between `window_start` and
# `window_end`, and, when the counter hits one of the `window_splits`, builds a
# new linear transform from the Cholesky factor of the accumulated parameter
# covariance and installs it in the chain state. Finally the z-position of the
# chain state is updated.
#
# `p_accept` is not used here.
#
# Returns the tuple `(chain_state_new, tuner)`.
function BAT.mcmc_tune_post_step!!(
    tuner::StanLikeTunerState,
    chain_state::MCMCChainState,
    p_accept::Real
)
    adapt_state = tuner.stan_state
    adapt_state.i += 1
    step_idx = adapt_state.i

    sample_stats = tuner.stats
    at_window_end = step_idx in adapt_state.window_splits

    if adapt_state.window_start <= step_idx <= adapt_state.window_end
        BAT.push!(sample_stats, proposed_sample(chain_state))
    end

    updated_state = chain_state
    if at_window_end
        old_A = chain_state.f_transform.A
        elT = eltype(old_A)

        # Lower Cholesky factor of the sample covariance becomes the new
        # transform matrix, converted to the element type of the old one
        cov_mat = convert(Array, sample_stats.param_stats.cov)
        new_A = elT.(cholesky(Positive, cov_mat).L)

        reweight_relative!(sample_stats, 0)

        zero_shift = zeros(elT, size(old_A, 2))
        updated_state = set_mc_state_transform!!(chain_state, MulAdd(new_A, zero_shift))
    end

    return mcmc_update_z_position!!(updated_state), tuner
end
8 changes: 5 additions & 3 deletions ext/ahmc_impl/ahmc_tuner_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function BAT.mcmc_tune_post_cycle!!(tuner::HMCProposalTunerState, chain_state::H
chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = false)
@debug "MCMC chain $(chain_state.info.id) *not* tuned, acceptance ratio = $(Float32(accept_ratio)), integrator = $(chain_state.proposal.τ.integrator), max. log posterior = $(Float32(max_log_posterior))"
end
return chain_state, tuner, false
return chain_state, tuner
end


Expand All @@ -66,12 +66,14 @@ function BAT.mcmc_tune_post_step!!(
tstat = AdvancedHMC.stat(proposal_new.transition)

AdvancedHMC.adapt!(adaptor, proposal_new.transition.z.θ, tstat.acceptance_rate)
proposal_new.hamiltonian = AdvancedHMC.update(proposal_new.hamiltonian, adaptor)
h = proposal_new.hamiltonian
h = AdvancedHMC.update(h, adaptor)

proposal_new.kernel = AdvancedHMC.update(proposal_new.kernel, adaptor)
tstat = merge(tstat, (is_adapt =true,))

chain_state_tmp = @set chain_state.proposal.transition.stat = tstat
chain_state_final = @set chain_state_tmp.proposal = proposal_new

return chain_state_final, tuner_state, false
return chain_state_final, tuner_state
end
2 changes: 1 addition & 1 deletion src/extdefs/ahmc_defs/ahmc_alg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ $(TYPEDFIELDS)
IT,
TC
} <: MCMCProposal
metric::MT = DiagEuclideanMetric()
metric::MT = UnitEuclideanMetric()
integrator::IT = ext_default(pkgext(Val(:AdvancedHMC)), Val(:DEFAULT_INTEGRATOR))
termination::TC = ext_default(pkgext(Val(:AdvancedHMC)), Val(:DEFAULT_TERMINATION_CRITERION))
end
Expand Down
9 changes: 5 additions & 4 deletions src/extdefs/ahmc_defs/ahmc_config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@ end

# Uses Stan (also AdvancedHMC) defaults
# (see https://mc-stan.org/docs/2_26/reference-manual/hmc-algorithm-parameters.html):
@with_kw struct StanHMCTuning <: HMCTuning
@with_kw struct StanLikeTuning <: MCMCTransformTuning
"target acceptance rate"
target_acceptance::Float64 = 0.8

"width of initial fast adaptation interval"
initial_bufsize::Int = 75
init_buffer::Int = 75

"width of final fast adaptation interval"
term_bufsize::Int = 50
term_buffer::Int = 50

"initial width of slow adaptation interval"
window_size::Int = 25
end
export StanHMCTuning

export StanLikeTuning
Loading
Loading