From 720b84d0f45832cfbb70035c5fe429ac0628f008 Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Sun, 8 Dec 2024 16:08:16 +0100 Subject: [PATCH 01/11] Create new phasepoint during each AHMC step --- ext/BATAdvancedHMCExt.jl | 8 +++-- ext/ahmc_impl/ahmc_config_impl.jl | 6 ++-- ext/ahmc_impl/ahmc_sampler_impl.jl | 41 +++++++++++++++---------- ext/ahmc_impl/ahmc_tuner_impl.jl | 4 +-- src/samplers/mcmc/chain_pool_init.jl | 6 +++- src/samplers/mcmc/mcmc_algorithm.jl | 8 +++-- src/samplers/mcmc/mcmc_sample.jl | 4 +++ src/samplers/mcmc/mcmc_state.jl | 27 +++++++++------- src/samplers/mcmc/mh_sampler.jl | 1 - src/samplers/mcmc/multi_cycle_burnin.jl | 3 +- 10 files changed, 68 insertions(+), 40 deletions(-) diff --git a/ext/BATAdvancedHMCExt.jl b/ext/BATAdvancedHMCExt.jl index 323d90c6f..055c4959d 100644 --- a/ext/BATAdvancedHMCExt.jl +++ b/ext/BATAdvancedHMCExt.jl @@ -15,8 +15,8 @@ using BAT: MeasureLike, BATMeasure using BAT: get_context, get_adselector, _NoADSelected using BAT: getproposal, mcmc_target -using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering -using BAT: _current_sample_idx, _proposed_sample_idx, _cleanup_samples +using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering, NoMCMCTransformTuning +using BAT: _current_sample_idx, _proposed_sample_idx, _cleanup_samples, current_sample_z using BAT: AbstractTransformTarget, NoAdaptiveTransform using BAT: RNGPartition, get_rng, set_rng! 
using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio @@ -30,9 +30,11 @@ using BAT: AHMCSampleID, AHMCSampleIDVector using BAT: HMCMetric, DiagEuclideanMetric, UnitEuclideanMetric, DenseEuclideanMetric using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanHMCTuning +using MeasureBase: pullback + using ValueShapes: varshape -using Accessors: @set +using Accessors: @set, @reset BAT.ext_default(::BAT.PackageExtension{:AdvancedHMC}, ::Val{:DEFAULT_INTEGRATOR}) = AdvancedHMC.Leapfrog(NaN) diff --git a/ext/ahmc_impl/ahmc_config_impl.jl b/ext/ahmc_impl/ahmc_config_impl.jl index 642ffdf58..29849be1a 100644 --- a/ext/ahmc_impl/ahmc_config_impl.jl +++ b/ext/ahmc_impl/ahmc_config_impl.jl @@ -3,12 +3,12 @@ # Integrator ============================================== -function _ahmc_set_step_size(integrator::AdvancedHMC.AbstractIntegrator, hamiltonian::AdvancedHMC.Hamiltonian, θ_init::AbstractVector{<:Real}) +function _ahmc_set_step_size(integrator::AdvancedHMC.AbstractIntegrator, hamiltonian::AdvancedHMC.Hamiltonian, θ_init::AbstractVector{<:Real}, rng::AbstractRNG) # ToDo: Add way to specify max_n_iters T = eltype(θ_init) step_size = integrator.ϵ if isnan(step_size) - new_step_size = AdvancedHMC.find_good_stepsize(hamiltonian, θ_init, max_n_iters = 100) + new_step_size = AdvancedHMC.find_good_stepsize(rng, hamiltonian, θ_init, max_n_iters = 100) @set integrator.ϵ = T(new_step_size) else @set integrator.ϵ = T(step_size) @@ -57,7 +57,7 @@ function ahmc_adaptor( θ_init::AbstractVector{<:Real} ) T = eltype(θ_init) - return AdvancedHMC.StepSizeAdaptor(tuning.target_acceptance, integrator) + return AdvancedHMC.StepSizeAdaptor(T(tuning.target_acceptance), integrator) end function ahmc_adaptor( diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index cf9af351a..c71c060e8 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -5,6 +5,8 @@ 
BAT.bat_default(::Type{TransformedMCMC}, ::Val{:pretransform}, proposal::Hamilto BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = StanHMCTuning() +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, proposal::HamiltonianMC) = NoMCMCTransformTuning() + BAT.bat_default(::Type{TransformedMCMC}, ::Val{:adaptive_transform}, proposal::HamiltonianMC) = NoAdaptiveTransform() BAT.bat_default(::Type{TransformedMCMC}, ::Val{:tempering}, proposal::HamiltonianMC) = NoMCMCTempering() @@ -38,7 +40,7 @@ function BAT._create_proposal_state( metric = ahmc_metric(proposal.metric, params_vec) init_hamiltonian = AdvancedHMC.Hamiltonian(metric, f, fg) hamiltonian, init_transition = AdvancedHMC.sample_init(rng, init_hamiltonian, params_vec) - integrator = _ahmc_set_step_size(proposal.integrator, hamiltonian, params_vec) + integrator = _ahmc_set_step_size(proposal.integrator, hamiltonian, params_vec, rng) termination = _ahmc_convert_termination(proposal.termination, params_vec) kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, termination)) @@ -66,7 +68,7 @@ function BAT.next_cycle!(mc_state::HMCState) mc_state.nsamples = 0 mc_state.stepno = 0 - reset_rng_counters!(mc_state) + #reset_rng_counters!(mc_state) resize!(mc_state.samples, 1) @@ -85,9 +87,7 @@ function BAT.next_cycle!(mc_state::HMCState) mc_state end -# TODO: MD, should this be a !! function? 
function BAT.mcmc_propose!!(mc_state::HMCState) - # @unpack target, proposal, f_transform, samples, context = mc_state target = mc_state.target proposal = mc_state.proposal f_transform = mc_state.f_transform @@ -96,27 +96,36 @@ function BAT.mcmc_propose!!(mc_state::HMCState) rng = get_rng(context) - current = _current_sample_idx(mc_state) - proposed = _proposed_sample_idx(mc_state) + current_x_idx = _current_sample_idx(mc_state) + proposed_x_idx = _proposed_sample_idx(mc_state) + + z_current = current_sample_z(mc_state).v + + x_current = samples.v[current_x_idx] + x_proposed = samples.v[proposed_x_idx] + current_log_posterior = samples.logd[current_x_idx] + + + + τ = deepcopy(proposal.kernel.τ) + @reset τ.integrator = AdvancedHMC.jitter(rng, τ.integrator) + + hamiltonian = proposal.hamiltonian + z = AdvancedHMC.phasepoint(hamiltonian, vec(x_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic, vec(x_current[:]))) + + proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z) + tstat = AdvancedHMC.stat(proposal.transition) + p_accept = tstat.acceptance_rate - x_current = samples.v[current] - x_proposed = samples.v[proposed] - current_log_posterior = samples.logd[current] - proposal.transition = AdvancedHMC.transition(rng, proposal.hamiltonian, proposal.kernel, proposal.transition.z) x_proposed[:] = proposal.transition.z.θ proposed_log_posterior = logdensityof(target, x_proposed) - samples.logd[proposed] = proposed_log_posterior + samples.logd[proposed_x_idx] = proposed_log_posterior accepted = x_current != x_proposed - - # TODO: Setting p_accept to 1 or 0 for now. - # Use AdvancedHMC.stat(transition).acceptance_rate in the future? 
- p_accept = Float64(accepted) - return mc_state, accepted, p_accept end diff --git a/ext/ahmc_impl/ahmc_tuner_impl.jl b/ext/ahmc_impl/ahmc_tuner_impl.jl index b7b58e504..cf6b33b22 100644 --- a/ext/ahmc_impl/ahmc_tuner_impl.jl +++ b/ext/ahmc_impl/ahmc_tuner_impl.jl @@ -49,8 +49,8 @@ function BAT.mcmc_tuning_finalize!!(tuner::HMCProposalTunerState, chain_state::H adaptor = tuner.adaptor proposal = chain_state.proposal AdvancedHMC.finalize!(adaptor) - proposal.hamiltonian = AdvancedHMC.update(proposal.hamiltonian, adaptor) - proposal.kernel = AdvancedHMC.update(proposal.kernel, adaptor) + proposal.hamiltonian = AdvancedHMC.update(proposal.hamiltonian, adaptor) # Remove for transition to trafo based tuning + proposal.kernel = AdvancedHMC.update(proposal.kernel, adaptor) nothing end diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index d39fca2a7..7ca40166a 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -91,6 +91,7 @@ function mcmc_init!( @debug "Generating $n $(cycle > 1 ? "additional " : "")candidate MCMC chain state(s)." new_mcmc_states = _gen_mcmc_states(samplingalg, target, rngpart, ncandidates .+ (one(Int64):n), initval_alg, context) + global st_cp_init_pre_misc = deepcopy(new_mcmc_states) filter!(isvalidstate, new_mcmc_states) @@ -102,13 +103,16 @@ function mcmc_init!( ncandidates += n @debug "Testing $(length(new_mcmc_states)) candidate MCMC chain state(s)." 
+ global st_cp_init_post_gen = (deepcopy(new_mcmc_states), deepcopy(new_outputs), init_alg, nonzero_weights) + #BREAK_INIT new_mcmc_states = mcmc_iterate!!( new_outputs, new_mcmc_states; max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), nonzero_weights = nonzero_weights ) - + global st_cp_init_post_it = deepcopy(new_mcmc_states) + #BREAK_cp_init viable_idxs = findall(isviablestate.(new_mcmc_states)) viable_mcmc_states = new_mcmc_states[viable_idxs] viable_outputs = new_outputs[viable_idxs] diff --git a/src/samplers/mcmc/mcmc_algorithm.jl b/src/samplers/mcmc/mcmc_algorithm.jl index 09b3b42a1..be514147f 100644 --- a/src/samplers/mcmc/mcmc_algorithm.jl +++ b/src/samplers/mcmc/mcmc_algorithm.jl @@ -276,11 +276,15 @@ function mcmc_iterate!!( log_time = start_time start_nsteps = nsteps(mcmc_state) start_nsamples = nsamples(mcmc_state) - + + global states = Any[] while ( (nsteps(mcmc_state) - start_nsteps) < max_nsteps && (time() - start_time) < max_time - ) + ) + + push!(states, deepcopy(mcmc_state)) + #BREAK_sample_pre_step mcmc_state = mcmc_step!!(mcmc_state) if !isnothing(output) diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index f0c7a4ec0..db6160227 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -85,6 +85,10 @@ function bat_sample_impl(m::BATMeasure, samplingalg::TransformedMCMC, context::B context ) + global st_sample_post_init = deepcopy(mcmc_states) + global st_sample_out = deepcopy(chain_outputs) + #BREAK_sample + if !samplingalg.store_burnin chain_outputs .= DensitySampleVector.(mcmc_states) end diff --git a/src/samplers/mcmc/mcmc_state.jl b/src/samplers/mcmc/mcmc_state.jl index 41bdb4e2b..d1c935f68 100644 --- a/src/samplers/mcmc/mcmc_state.jl +++ b/src/samplers/mcmc/mcmc_state.jl @@ -46,19 +46,22 @@ function MCMCChainState( n_dims = getdof(target) #Create Proposal state. 
Necessary in particular for AHMC proposal + global state_proposal_init = (samplingalg, target, context, v_init, rng) + #BREAK_state_init + proposal = _create_proposal_state(samplingalg.proposal, target, context, v_init, rng) stepno::Int64 = 0 cycle::Int32 = 0 nsamples::Int64 = 0 - g = init_adaptive_transform(samplingalg.adaptive_transform, target, context) - + f = init_adaptive_transform(samplingalg.adaptive_transform, target, context) + f_inv = inverse(f) logd_x = logdensityof(target_unevaluated, v_init) - inverse_g = inverse(g) - z = inverse_g(v_init) - logd_z = logdensityof(MeasureBase.pullback(g, target_unevaluated), z) - + + z = f_inv(v_init) + logd_z = logdensityof(MeasureBase.pullback(f, target_unevaluated), z) + W = mcmc_weight_type(samplingalg.sample_weighting) T = typeof(logd_x) @@ -76,7 +79,7 @@ function MCMCChainState( state = MCMCChainState( target, proposal, - g, + f, samplingalg.sample_weighting, samples, sample_z, @@ -88,7 +91,7 @@ function MCMCChainState( ) # TODO: MD, resetting the counters necessary/desired? 
- reset_rng_counters!(state) + #reset_rng_counters!(state) state end @@ -148,7 +151,7 @@ end function mcmc_step!!(mcmc_state::MCMCState) _cleanup_samples(mcmc_state) - reset_rng_counters!(mcmc_state) + #reset_rng_counters!(mcmc_state) chain_state = mcmc_state.chain_state @@ -159,7 +162,9 @@ function mcmc_step!!(mcmc_state::MCMCState) resize!(samples, size(samples, 1) + 1) samples.info[lastindex(samples)] = _get_sample_id(proposal, chain_state.info.id, chain_state.info.cycle, chain_state.stepno, PROPOSED_SAMPLE)[1] - + + global it_state = deepcopy(chain_state) + # BREEAK chain_state, accepted, p_accept = mcmc_propose!!(chain_state) mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept) @@ -222,7 +227,7 @@ function next_cycle!(chain_state::MCMCChainState) chain_state.nsamples = 0 chain_state.stepno = 0 - reset_rng_counters!(chain_state) + #reset_rng_counters!(chain_state) resize!(chain_state.samples, 1) diff --git a/src/samplers/mcmc/mh_sampler.jl b/src/samplers/mcmc/mh_sampler.jl index ea9ad59e9..aec4e2f06 100644 --- a/src/samplers/mcmc/mh_sampler.jl +++ b/src/samplers/mcmc/mh_sampler.jl @@ -103,7 +103,6 @@ end const MHChainState = MCMCChainState{<:BATMeasure, <:RNGPartition, <:Function, <:MHProposalState} - function mcmc_propose!!(mc_state::MHChainState) @unpack target, proposal, f_transform, context = mc_state rng = get_rng(context) diff --git a/src/samplers/mcmc/multi_cycle_burnin.jl b/src/samplers/mcmc/multi_cycle_burnin.jl index fffe87a6f..cf0da0928 100644 --- a/src/samplers/mcmc/multi_cycle_burnin.jl +++ b/src/samplers/mcmc/multi_cycle_burnin.jl @@ -52,7 +52,8 @@ function mcmc_burnin!( max_nsteps = burnin.nsteps_per_cycle, nonzero_weights = nonzero_weights ) - + global st_burnin_post_it = mcmc_states + #BREAK_burnin mcmc_states = mcmc_tune_post_cycle!!.(mcmc_states, new_outputs) isnothing(outputs) || append!.(outputs, new_outputs) From a3f2c39c4667ce66a0b94f483be94f4053b13720 Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Sat, 25 Jan 2025 
16:51:06 +0100 Subject: [PATCH 02/11] HMC trafo refactor preparation --- ext/BATAdvancedHMCExt.jl | 9 +-- ext/ahmc_impl/ahmc_sampler_impl.jl | 57 +++++++++++++------ ext/ahmc_impl/ahmc_tuner_impl.jl | 12 ++-- src/measures/bat_pushfwd_measure.jl | 2 +- src/samplers/mcmc/chain_pool_init.jl | 9 +-- src/samplers/mcmc/mcmc_algorithm.jl | 4 -- src/samplers/mcmc/mcmc_sample.jl | 4 -- src/samplers/mcmc/mcmc_state.jl | 36 ++++-------- .../mcmc_tuning/mcmc_adaptive_mh_tuner.jl | 8 ++- .../mcmc/mcmc_tuning/mcmc_noop_tuner.jl | 4 +- .../mcmc/mcmc_tuning/mcmc_ram_tuner.jl | 9 +-- src/samplers/mcmc/mh_sampler.jl | 8 ++- src/samplers/mcmc/multi_cycle_burnin.jl | 3 +- src/variates/shaped_variates.jl | 4 +- 14 files changed, 86 insertions(+), 83 deletions(-) diff --git a/ext/BATAdvancedHMCExt.jl b/ext/BATAdvancedHMCExt.jl index 055c4959d..e0795f61a 100644 --- a/ext/BATAdvancedHMCExt.jl +++ b/ext/BATAdvancedHMCExt.jl @@ -15,13 +15,14 @@ using BAT: MeasureLike, BATMeasure using BAT: get_context, get_adselector, _NoADSelected using BAT: getproposal, mcmc_target -using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering, NoMCMCTransformTuning -using BAT: _current_sample_idx, _proposed_sample_idx, _cleanup_samples, current_sample_z -using BAT: AbstractTransformTarget, NoAdaptiveTransform +using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin +using BAT: MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering, NoMCMCTransformTuning, RAMTuning +using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples, current_sample_z +using BAT: AbstractTransformTarget, NoAdaptiveTransform, TriangularAffineTransform, valgrad_func using BAT: RNGPartition, get_rng, set_rng! 
using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio using BAT: get_samples!, reset_rng_counters! -using BAT: create_trafo_tuner_state, create_proposal_tuner_state, mcmc_tuning_init!!, mcmc_tuning_postinit!!, mcmc_tuning_reinit!!, mcmc_tune_transform_post_cycle!!, transform_mcmc_tuning_finalize!! +using BAT: create_trafo_tuner_state, create_proposal_tuner_state, mcmc_tuning_init!!, mcmc_tuning_postinit!!, mcmc_tuning_reinit!!, mcmc_tune_transform_post_cycle!!, transform_mcmc_tuning_finalize!!, set_mc_state_transform!! using BAT: totalndof, measure_support, checked_logdensityof using BAT: CURRENT_SAMPLE, PROPOSED_SAMPLE, INVALID_SAMPLE, ACCEPTED_SAMPLE, REJECTED_SAMPLE diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index c71c060e8..8dba913fa 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -3,11 +3,11 @@ BAT.bat_default(::Type{TransformedMCMC}, ::Val{:pretransform}, proposal::HamiltonianMC) = PriorToNormal() -BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = StanHMCTuning() +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = StepSizeAdaptor() -BAT.bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, proposal::HamiltonianMC) = NoMCMCTransformTuning() +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, proposal::HamiltonianMC) = RAMTuning() -BAT.bat_default(::Type{TransformedMCMC}, ::Val{:adaptive_transform}, proposal::HamiltonianMC) = NoAdaptiveTransform() +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:adaptive_transform}, proposal::HamiltonianMC) = TriangularAffineTransform() BAT.bat_default(::Type{TransformedMCMC}, ::Val{:tempering}, proposal::HamiltonianMC) = NoMCMCTempering() @@ -19,12 +19,13 @@ BAT.bat_default(::Type{TransformedMCMC}, ::Val{:init}, proposal::HamiltonianMC, BAT.bat_default(::Type{TransformedMCMC}, ::Val{:burnin}, 
proposal::HamiltonianMC, pretransform::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = MCMCMultiCycleBurnin(nsteps_per_cycle = max(div(nsteps, 10), 250), max_ncycles = 4) - +# Change to incorporate the initial adaptive transform into f and fg function BAT._create_proposal_state( proposal::HamiltonianMC, target::BATMeasure, context::BATContext, v_init::AbstractVector{P}, + f_transform::Function, rng::AbstractRNG ) where {P<:Real} vs = varshape(target) @@ -34,7 +35,7 @@ function BAT._create_proposal_state( params_vec .= v_init adsel = get_adselector(context) - f = checked_logdensityof(target) + f = checked_logdensityof(pullback(f_transform, target)) fg = valgrad_func(f, adsel) metric = ahmc_metric(proposal.metric, params_vec) @@ -92,40 +93,45 @@ function BAT.mcmc_propose!!(mc_state::HMCState) proposal = mc_state.proposal f_transform = mc_state.f_transform samples = mc_state.samples + sample_z = mc_state.sample_z context = mc_state.context rng = get_rng(context) - current_x_idx = _current_sample_idx(mc_state) + current_z_idx = _current_sample_z_idx(mc_state) + proposed_z_idx = _proposed_sample_z_idx(mc_state) + proposed_x_idx = _proposed_sample_idx(mc_state) - z_current = current_sample_z(mc_state).v - - x_current = samples.v[current_x_idx] + # location in normalized (or generally transformed) space ("z-space") + z_current = sample_z.v[current_z_idx] + z_proposed = sample_z.v[proposed_z_idx] + + # location in target space ("x-space") x_proposed = samples.v[proposed_x_idx] - current_log_posterior = samples.logd[current_x_idx] - - τ = deepcopy(proposal.kernel.τ) @reset τ.integrator = AdvancedHMC.jitter(rng, τ.integrator) hamiltonian = proposal.hamiltonian - z = AdvancedHMC.phasepoint(hamiltonian, vec(x_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic, vec(x_current[:]))) - proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z) + # Current location in the phase space for Hamiltonian MonteCarlo + z_phase = 
AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic)) + # Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics. + + proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase) tstat = AdvancedHMC.stat(proposal.transition) p_accept = tstat.acceptance_rate + z_proposed[:] = proposal.transition.z.θ - - x_proposed[:] = proposal.transition.z.θ + x_proposed[:] = f_transform(z_proposed) proposed_log_posterior = logdensityof(target, x_proposed) samples.logd[proposed_x_idx] = proposed_log_posterior - accepted = x_current != x_proposed + accepted = z_current != z_proposed return mc_state, accepted, p_accept end @@ -155,5 +161,20 @@ function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float samples.weight[proposed] = w_proposed end - BAT.eff_acceptance_ratio(mc_state::HMCState) = nsamples(mc_state) / nsteps(mc_state) + +function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Function) + adsel = get_adselector(mc_state.context) + f = checked_logdensityof(pullback(f_transform_new, mc_state.target)) + fg = valgrad_func(f, adsel) + + h = mc_state.proposal.hamiltonian + + h = @set h.ℓπ = f + h = @set h.∂ℓπ∂θ = fg + + mc_state_new = @set mc_state.proposal.hamiltonian = h + + mc_state_new = @set mc_state_new.f_transform = f_transform_new + return mc_state_new +end diff --git a/ext/ahmc_impl/ahmc_tuner_impl.jl b/ext/ahmc_impl/ahmc_tuner_impl.jl index cf6b33b22..59df2602b 100644 --- a/ext/ahmc_impl/ahmc_tuner_impl.jl +++ b/ext/ahmc_impl/ahmc_tuner_impl.jl @@ -41,7 +41,7 @@ function BAT.mcmc_tune_post_cycle!!(tuner::HMCProposalTunerState, chain_state::H chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = false) @debug "MCMC chain $(chain_state.info.id) *not* tuned, acceptance ratio = $(Float32(accept_ratio)), integrator = $(chain_state.proposal.τ.integrator), max. 
log posterior = $(Float32(max_log_posterior))" end - return chain_state, tuner, false + return chain_state, tuner end @@ -49,8 +49,8 @@ function BAT.mcmc_tuning_finalize!!(tuner::HMCProposalTunerState, chain_state::H adaptor = tuner.adaptor proposal = chain_state.proposal AdvancedHMC.finalize!(adaptor) - proposal.hamiltonian = AdvancedHMC.update(proposal.hamiltonian, adaptor) # Remove for transition to trafo based tuning - proposal.kernel = AdvancedHMC.update(proposal.kernel, adaptor) + proposal.hamiltonian = AdvancedHMC.update(proposal.hamiltonian, adaptor) + proposal.kernel = AdvancedHMC.update(proposal.kernel, adaptor) nothing end @@ -66,12 +66,14 @@ function BAT.mcmc_tune_post_step!!( tstat = AdvancedHMC.stat(proposal_new.transition) AdvancedHMC.adapt!(adaptor, proposal_new.transition.z.θ, tstat.acceptance_rate) - proposal_new.hamiltonian = AdvancedHMC.update(proposal_new.hamiltonian, adaptor) + h = proposal_new.hamiltonian + h = AdvancedHMC.update(h, adaptor) + proposal_new.kernel = AdvancedHMC.update(proposal_new.kernel, adaptor) tstat = merge(tstat, (is_adapt =true,)) chain_state_tmp = @set chain_state.proposal.transition.stat = tstat chain_state_final = @set chain_state_tmp.proposal = proposal_new - return chain_state_final, tuner_state, false + return chain_state_final, tuner_state end diff --git a/src/measures/bat_pushfwd_measure.jl b/src/measures/bat_pushfwd_measure.jl index 443f2a8ed..aab2a6575 100644 --- a/src/measures/bat_pushfwd_measure.jl +++ b/src/measures/bat_pushfwd_measure.jl @@ -71,7 +71,7 @@ end #!!!!!!!!! 
Use return type of f with testvalue, if no shape change return varshape(m.orig) directly -#ValueShapes.varshape(m::BATPushFwdMeasure) = f(varshape(m.orig)) +# ValueShapes.varshape(m::BATPushFwdMeasure) = varshape(m.origin) ValueShapes.varshape(m::BATPushFwdMeasure{<:DistributionTransform}) = varshape(m.f.target_dist) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index 7ca40166a..717216cc5 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -91,7 +91,6 @@ function mcmc_init!( @debug "Generating $n $(cycle > 1 ? "additional " : "")candidate MCMC chain state(s)." new_mcmc_states = _gen_mcmc_states(samplingalg, target, rngpart, ncandidates .+ (one(Int64):n), initval_alg, context) - global st_cp_init_pre_misc = deepcopy(new_mcmc_states) filter!(isvalidstate, new_mcmc_states) @@ -99,20 +98,18 @@ function mcmc_init!( next_cycle!.(new_mcmc_states) mcmc_tuning_init!!.(new_mcmc_states, init_alg.nsteps_init) + new_mcmc_states = mcmc_update_z_position!!.(new_mcmc_states) ncandidates += n @debug "Testing $(length(new_mcmc_states)) candidate MCMC chain state(s)." 
- global st_cp_init_post_gen = (deepcopy(new_mcmc_states), deepcopy(new_outputs), init_alg, nonzero_weights) - - #BREAK_INIT + new_mcmc_states = mcmc_iterate!!( new_outputs, new_mcmc_states; max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), nonzero_weights = nonzero_weights ) - global st_cp_init_post_it = deepcopy(new_mcmc_states) - #BREAK_cp_init + viable_idxs = findall(isviablestate.(new_mcmc_states)) viable_mcmc_states = new_mcmc_states[viable_idxs] viable_outputs = new_outputs[viable_idxs] diff --git a/src/samplers/mcmc/mcmc_algorithm.jl b/src/samplers/mcmc/mcmc_algorithm.jl index be514147f..07b7d38ee 100644 --- a/src/samplers/mcmc/mcmc_algorithm.jl +++ b/src/samplers/mcmc/mcmc_algorithm.jl @@ -277,14 +277,10 @@ function mcmc_iterate!!( start_nsteps = nsteps(mcmc_state) start_nsamples = nsamples(mcmc_state) - global states = Any[] while ( (nsteps(mcmc_state) - start_nsteps) < max_nsteps && (time() - start_time) < max_time ) - - push!(states, deepcopy(mcmc_state)) - #BREAK_sample_pre_step mcmc_state = mcmc_step!!(mcmc_state) if !isnothing(output) diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index db6160227..f0c7a4ec0 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -85,10 +85,6 @@ function bat_sample_impl(m::BATMeasure, samplingalg::TransformedMCMC, context::B context ) - global st_sample_post_init = deepcopy(mcmc_states) - global st_sample_out = deepcopy(chain_outputs) - #BREAK_sample - if !samplingalg.store_burnin chain_outputs .= DensitySampleVector.(mcmc_states) end diff --git a/src/samplers/mcmc/mcmc_state.jl b/src/samplers/mcmc/mcmc_state.jl index d1c935f68..3f1cf6150 100644 --- a/src/samplers/mcmc/mcmc_state.jl +++ b/src/samplers/mcmc/mcmc_state.jl @@ -45,23 +45,18 @@ function MCMCChainState( rng = get_rng(context) n_dims = getdof(target) - #Create Proposal state. 
Necessary in particular for AHMC proposal - global state_proposal_init = (samplingalg, target, context, v_init, rng) - #BREAK_state_init - - proposal = _create_proposal_state(samplingalg.proposal, target, context, v_init, rng) - stepno::Int64 = 0 - - cycle::Int32 = 0 - nsamples::Int64 = 0 - f = init_adaptive_transform(samplingalg.adaptive_transform, target, context) f_inv = inverse(f) logd_x = logdensityof(target_unevaluated, v_init) z = f_inv(v_init) logd_z = logdensityof(MeasureBase.pullback(f, target_unevaluated), z) - + + proposal = _create_proposal_state(samplingalg.proposal, target, context, v_init, f, rng) + stepno::Int64 = 0 + cycle::Int32 = 0 + nsamples::Int64 = 0 + W = mcmc_weight_type(samplingalg.sample_weighting) T = typeof(logd_x) @@ -92,7 +87,6 @@ function MCMCChainState( # TODO: MD, resetting the counters necessary/desired? #reset_rng_counters!(state) - state end @@ -163,8 +157,6 @@ function mcmc_step!!(mcmc_state::MCMCState) samples.info[lastindex(samples)] = _get_sample_id(proposal, chain_state.info.id, chain_state.info.cycle, chain_state.stepno, PROPOSED_SAMPLE)[1] - global it_state = deepcopy(chain_state) - # BREEAK chain_state, accepted, p_accept = mcmc_propose!!(chain_state) mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept) @@ -333,12 +325,8 @@ end # TODO: MD, when should the z-position be updated? Before or after the proposal tuning? 
function mcmc_tune_post_cycle!!(state::MCMCState, samples::DensitySampleVector) - chain_state_tmp, trafo_tuner_state_new, trafo_changed = mcmc_tune_post_cycle!!(state.trafo_tuner_state, state.chain_state, samples) - chain_state_new, proposal_tuner_state_new, _ = mcmc_tune_post_cycle!!(state.proposal_tuner_state, chain_state_tmp, samples) - - if trafo_changed - chain_state_new = mcmc_update_z_position!!(chain_state_new) - end + chain_state_tmp, trafo_tuner_state_new = mcmc_tune_post_cycle!!(state.trafo_tuner_state, state.chain_state, samples) + chain_state_new, proposal_tuner_state_new = mcmc_tune_post_cycle!!(state.proposal_tuner_state, chain_state_tmp, samples) mcmc_state_cs = @set state.chain_state = chain_state_new mcmc_state_tt = @set mcmc_state_cs.trafo_tuner_state = trafo_tuner_state_new @@ -348,12 +336,8 @@ function mcmc_tune_post_cycle!!(state::MCMCState, samples::DensitySampleVector) end function mcmc_tune_post_step!!(state::MCMCState, p_accept::Real) - chain_state_tmp, trafo_tuner_state_new, trafo_changed = mcmc_tune_post_step!!(state.trafo_tuner_state, state.chain_state, p_accept) - chain_state_new, proposal_tuner_state_new, _ = mcmc_tune_post_step!!(state.proposal_tuner_state, chain_state_tmp, p_accept) - - if trafo_changed - chain_state_new = mcmc_update_z_position!!(chain_state_new) - end + chain_state_tmp, trafo_tuner_state_new = mcmc_tune_post_step!!(state.trafo_tuner_state, state.chain_state, p_accept) + chain_state_new, proposal_tuner_state_new = mcmc_tune_post_step!!(state.proposal_tuner_state, chain_state_tmp, p_accept) # TODO: MD, inelegant, use AccessorsExtra.jl to set several fields at once? 
https://github.com/JuliaAPlavin/AccessorsExtra.jl mcmc_state_cs = @set state.chain_state = chain_state_new diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl index 40568a866..d170c4df1 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl @@ -134,8 +134,10 @@ function mcmc_tune_post_cycle!!(tuner::AdaptiveAffineTuningState, chain_state::M tuner.iteration += 1 - # TODO: MD, think about keeping old z_position if transform changes only slightly, and return a bool accordingly, instead of always 'true' - chain_state, tuner, true + # TODO: MD, think about keeping old z_position if transform changes only slightly + chain_state_new = mcmc_update_z_position!!(chain_state) + + return chain_state_new, tuner end @@ -147,5 +149,5 @@ function mcmc_tune_post_step!!( chain_state::MCMCChainState, p_accept::Real ) - return chain_state, tuner, false + return chain_state, tuner end diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_noop_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_noop_tuner.jl index 2d7d90de0..913c018ae 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_noop_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_noop_tuner.jl @@ -47,8 +47,8 @@ mcmc_tuning_reinit!!(::NoMCMCProposalTunerState, ::MCMCChainState, ::Integer) = mcmc_tuning_postinit!!(::NoMCMCProposalTunerState, ::MCMCChainState, ::DensitySampleVector) = nothing -mcmc_tune_post_cycle!!(tuner::NoMCMCProposalTunerState, chain_state::MCMCChainState, ::DensitySampleVector) = chain_state, tuner, false +mcmc_tune_post_cycle!!(tuner::NoMCMCProposalTunerState, chain_state::MCMCChainState, ::DensitySampleVector) = chain_state, tuner mcmc_tuning_finalize!!(::NoMCMCProposalTunerState, ::MCMCChainState) = nothing -mcmc_tune_post_step!!(tuner::NoMCMCProposalTunerState, chain_state::MCMCChainState, ::Real) = chain_state, tuner, false 
+mcmc_tune_post_step!!(tuner::NoMCMCProposalTunerState, chain_state::MCMCChainState, ::Real) = chain_state, tuner diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl index dd60a178d..036a0d962 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -66,7 +66,7 @@ function mcmc_tune_post_cycle!!(tuner::RAMTrafoTunerState, chain_state::MCMCChai chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = false) @debug "MCMC chain $(chain_state.info.id) *not* tuned, acceptance ratio = $(Float32(α)), max. log posterior = $(Float32(max_log_posterior))" end - return chain_state, tuner, false + return chain_state, tuner end mcmc_tuning_finalize!!(tuner::RAMTrafoTunerState, chain::MCMCChainState) = nothing @@ -95,11 +95,12 @@ function mcmc_tune_post_step!!( α = mean_update_rate * p_accept new_b = oftype(b, (1- α) * b + α * x.v) - f_transform_new = MulAdd(new_s_L, new_b) + f_transform_new = MulAdd(new_s_L, new_b) tuner_state_new = @set tuner_state.nsteps = tuner_state.nsteps + 1 - mc_state_new = @set mc_state.f_transform = f_transform_new + mc_state_new = set_mc_state_transform!!(mc_state, f_transform_new) + mc_state_new = mcmc_update_z_position!!(mc_state_new) - return mc_state_new, tuner_state_new, true + return mc_state_new, tuner_state_new end diff --git a/src/samplers/mcmc/mh_sampler.jl b/src/samplers/mcmc/mh_sampler.jl index aec4e2f06..ae5b6fe98 100644 --- a/src/samplers/mcmc/mh_sampler.jl +++ b/src/samplers/mcmc/mh_sampler.jl @@ -53,7 +53,8 @@ function _create_proposal_state( proposal::RandomWalk, target::BATMeasure, context::BATContext, - v_init::AbstractVector{<:Real}, + v_init::AbstractVector{<:Real}, + f_transform::Function, rng::AbstractRNG ) n_dims = length(v_init) @@ -156,3 +157,8 @@ end eff_acceptance_ratio(mc_state::MHChainState) = nsamples(mc_state) / nsteps(mc_state) + +function set_mc_state_transform!!(mc_state::MHChainState, 
f_transform_new::Function) + mc_state_new = @set mc_state.f_transform = f_transform_new + return mc_state_new +end diff --git a/src/samplers/mcmc/multi_cycle_burnin.jl b/src/samplers/mcmc/multi_cycle_burnin.jl index cf0da0928..05f59b1ff 100644 --- a/src/samplers/mcmc/multi_cycle_burnin.jl +++ b/src/samplers/mcmc/multi_cycle_burnin.jl @@ -52,8 +52,7 @@ function mcmc_burnin!( max_nsteps = burnin.nsteps_per_cycle, nonzero_weights = nonzero_weights ) - global st_burnin_post_it = mcmc_states - #BREAK_burnin + mcmc_states = mcmc_tune_post_cycle!!.(mcmc_states, new_outputs) isnothing(outputs) || append!.(outputs, new_outputs) diff --git a/src/variates/shaped_variates.jl b/src/variates/shaped_variates.jl index 5401dd3d4..e6bc36d79 100644 --- a/src/variates/shaped_variates.jl +++ b/src/variates/shaped_variates.jl @@ -111,6 +111,4 @@ function check_variate(trgshape::Any, v::Any) throw(ArgumentError("Shape of variate incompatible with target variate trgshape, with variate of type $(typeof(v)) and expected trgshape $(trgshape)")) end -function check_variate(trgshape::Missing, v::Any) - throw(ArgumentError("Cannot evaluate without value trgshape information")) -end +check_variate(trgshape::Missing, v::Any) = nothing From 8b9e07302fac126984fb05d5c85978d684bb2183 Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Tue, 4 Feb 2025 17:55:55 +0100 Subject: [PATCH 03/11] Adjust MCMC RAMTuner in case p_accept = 0 for proposed samples --- ext/ahmc_impl/ahmc_sampler_impl.jl | 12 +++++------- src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl | 15 +++++++++++---- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index 8dba913fa..bc0604d3e 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -107,7 +107,7 @@ function BAT.mcmc_propose!!(mc_state::HMCState) z_current = sample_z.v[current_z_idx] z_proposed = sample_z.v[proposed_z_idx] - # location in target space 
("x-space") + # location in target space ("x-space") which is generally pre-transformed x_proposed = samples.v[proposed_x_idx] τ = deepcopy(proposal.kernel.τ) @@ -120,18 +120,16 @@ function BAT.mcmc_propose!!(mc_state::HMCState) # Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics. proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase) - tstat = AdvancedHMC.stat(proposal.transition) - p_accept = tstat.acceptance_rate - z_proposed[:] = proposal.transition.z.θ - x_proposed[:] = f_transform(z_proposed) - - proposed_log_posterior = logdensityof(target, x_proposed) + proposed_log_posterior = logdensityof(target, x_proposed) samples.logd[proposed_x_idx] = proposed_log_posterior accepted = z_current != z_proposed + tstat = AdvancedHMC.stat(proposal.transition) + p_accept = accepted ? tstat.acceptance_rate : 0.0 + return mc_state, accepted, p_accept end diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl index 036a0d962..c3d22d854 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -77,8 +77,16 @@ function mcmc_tune_post_step!!( mc_state::MCMCChainState, p_accept::Real, ) - (; target_acceptance, gamma) = tuner_state.tuning (; f_transform, sample_z) = mc_state + + tuner_state_new = @set tuner_state.nsteps = tuner_state.nsteps + 1 + + # TODO: MD: Discuss; apparently the RandomwWalk sampler wants the trafo to be tuned even if p_accept = 0. If not, the burnin does not converge. 
+ if iszero(p_accept) && !(mc_state isa MHChainState) + return mc_state, tuner_state_new + end + + (; target_acceptance, gamma) = tuner_state.tuning b = f_transform.b n_dims = size(sample_z.v[1], 1) @@ -88,16 +96,15 @@ function mcmc_tune_post_step!!( u = sample_z.v[2] - sample_z.v[1] # proposed - current M = s_L * (I + η * (p_accept - target_acceptance) * (u * u') / norm(u)^2 ) * s_L' - new_s_L = oftype(s_L, cholesky(Positive, M).L) + new_s_L = oftype(s_L, cholesky(Positive, M).L) + x = mc_state.samples[_proposed_sample_idx(mc_state)] # proposed in x-space mean_update_rate = η / 10 # heuristic α = mean_update_rate * p_accept new_b = oftype(b, (1- α) * b + α * x.v) f_transform_new = MulAdd(new_s_L, new_b) - - tuner_state_new = @set tuner_state.nsteps = tuner_state.nsteps + 1 mc_state_new = set_mc_state_transform!!(mc_state, f_transform_new) mc_state_new = mcmc_update_z_position!!(mc_state_new) From 34b50bff609482f322a0b66a50e0d892dfa36d69 Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Thu, 6 Feb 2025 19:20:08 +0100 Subject: [PATCH 04/11] Add StanHMCTuning() for transformation tuning --- ext/BATAdvancedHMCExt.jl | 21 +++-- ext/ahmc_impl/ahmc_stan_tuner_impl.jl | 78 +++++++++++++++++++ src/extdefs/ahmc_defs/ahmc_config.jl | 7 +- .../mcmc/mcmc_tuning/mcmc_ram_tuner.jl | 14 ++-- test/samplers/mcmc/test_hmc.jl | 17 ++-- 5 files changed, 111 insertions(+), 26 deletions(-) create mode 100644 ext/ahmc_impl/ahmc_stan_tuner_impl.jl diff --git a/ext/BATAdvancedHMCExt.jl b/ext/BATAdvancedHMCExt.jl index e0795f61a..0964eea0b 100644 --- a/ext/BATAdvancedHMCExt.jl +++ b/ext/BATAdvancedHMCExt.jl @@ -11,18 +11,24 @@ using Random using DensityInterface using HeterogeneousComputing, AutoDiffOperators +using Accessors: @set, @reset + +using AffineMaps: MulAdd + using BAT: MeasureLike, BATMeasure using BAT: get_context, get_adselector, _NoADSelected using BAT: getproposal, mcmc_target using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, 
MCMCChainPoolInit, MCMCMultiCycleBurnin -using BAT: MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering, NoMCMCTransformTuning, RAMTuning -using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples, current_sample_z +using BAT: MCMCBasicStats, push!, reweight_relative! +using BAT: RAMTuning +using BAT: MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering, NoMCMCTransformTuning +using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples, current_sample_z, proposed_sample using BAT: AbstractTransformTarget, NoAdaptiveTransform, TriangularAffineTransform, valgrad_func using BAT: RNGPartition, get_rng, set_rng! using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio using BAT: get_samples!, reset_rng_counters! -using BAT: create_trafo_tuner_state, create_proposal_tuner_state, mcmc_tuning_init!!, mcmc_tuning_postinit!!, mcmc_tuning_reinit!!, mcmc_tune_transform_post_cycle!!, transform_mcmc_tuning_finalize!!, set_mc_state_transform!! +using BAT: create_trafo_tuner_state, create_proposal_tuner_state, mcmc_tuning_init!!, mcmc_tuning_postinit!!, mcmc_tuning_reinit!!, mcmc_tune_transform_post_cycle!!, transform_mcmc_tuning_finalize!!, set_mc_state_transform!!, mcmc_update_z_position!! 
using BAT: totalndof, measure_support, checked_logdensityof using BAT: CURRENT_SAMPLE, PROPOSED_SAMPLE, INVALID_SAMPLE, ACCEPTED_SAMPLE, REJECTED_SAMPLE @@ -31,20 +37,23 @@ using BAT: AHMCSampleID, AHMCSampleIDVector using BAT: HMCMetric, DiagEuclideanMetric, UnitEuclideanMetric, DenseEuclideanMetric using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanHMCTuning +using LinearAlgebra: cholesky + using MeasureBase: pullback -using ValueShapes: varshape +using Parameters: @with_kw -using Accessors: @set, @reset +using PositiveFactorizations: Positive +using ValueShapes: varshape BAT.ext_default(::BAT.PackageExtension{:AdvancedHMC}, ::Val{:DEFAULT_INTEGRATOR}) = AdvancedHMC.Leapfrog(NaN) BAT.ext_default(::BAT.PackageExtension{:AdvancedHMC}, ::Val{:DEFAULT_TERMINATION_CRITERION}) = AdvancedHMC.GeneralisedNoUTurn() +include("ahmc_impl/ahmc_stan_tuner_impl.jl") include("ahmc_impl/ahmc_config_impl.jl") include("ahmc_impl/ahmc_sampler_impl.jl") include("ahmc_impl/ahmc_tuner_impl.jl") - end # module BATAdvancedHMCExt diff --git a/ext/ahmc_impl/ahmc_stan_tuner_impl.jl b/ext/ahmc_impl/ahmc_stan_tuner_impl.jl new file mode 100644 index 000000000..03584329b --- /dev/null +++ b/ext/ahmc_impl/ahmc_stan_tuner_impl.jl @@ -0,0 +1,78 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). 
+ +mutable struct StanHMCTrafoTunerState{ + S<:MCMCBasicStats, +} <: MCMCTransformTunerState + tuning::StanHMCTuning + target_acceptance::Float64 + stats::S + stan_state::AdvancedHMC.Adaptation.StanHMCAdaptorState +end + +BAT.create_trafo_tuner_state(tuning::StanHMCTuning, chain_state::MCMCChainState, n_steps_hint::Integer) = StanHMCTrafoTunerState(tuning, tuning.target_acceptance, MCMCBasicStats(chain_state), AdvancedHMC.Adaptation.StanHMCAdaptorState()) + +function BAT.mcmc_tuning_init!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState, max_nsteps::Integer) + tuning = tuner.tuning + AdvancedHMC.Adaptation.initialize!(tuner.stan_state, tuning.init_buffer, tuning.term_buffer, tuning.window_size, Int(max_nsteps - 1)) + nothing +end + +function BAT.mcmc_tuning_reinit!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState, max_nsteps::Integer) + tuning = tuner.tuning + AdvancedHMC.Adaptation.initialize!(tuner.stan_state, tuning.init_buffer, tuning.term_buffer, tuning.window_size, Int(max_nsteps - 1)) + nothing +end + +BAT.mcmc_tuning_postinit!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState, samples::DensitySampleVector) = nothing + + +function BAT.mcmc_tune_post_cycle!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState, samples::DensitySampleVector) + max_log_posterior = maximum(samples.logd) + accept_ratio = eff_acceptance_ratio(chain_state) + if accept_ratio >= 0.9 * tuner.target_acceptance + chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = true) + @debug "MCMC chain $(chain_state.info.id) tuned, acceptance ratio = $(Float32(accept_ratio)), integrator = $(chain_state.proposal.τ.integrator), max. log posterior = $(Float32(max_log_posterior))" + else + chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = false) + @debug "MCMC chain $(chain_state.info.id) *not* tuned, acceptance ratio = $(Float32(accept_ratio)), integrator = $(chain_state.proposal.τ.integrator), max. 
log posterior = $(Float32(max_log_posterior))" + end + return chain_state, tuner +end + + +BAT.mcmc_tuning_finalize!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState) = nothing + + +function BAT.mcmc_tune_post_step!!( + tuner::StanHMCTrafoTunerState, + chain_state::MCMCChainState, + p_accept::Real +) + stan_state = tuner.stan_state + stan_state.i += 1 + + stats = tuner.stats + is_in_window = stan_state.i >= stan_state.window_start && stan_state.i <= stan_state.window_end + is_window_end = stan_state.i in stan_state.window_splits + + # What to append? + is_in_window && BAT.push!(stats, proposed_sample(chain_state)) + + if is_window_end + A = chain_state.f_transform.A + T = eltype(A) + n_dims = size(A, 2) + + M = convert(Array, stats.param_stats.cov) + A_new = T.(cholesky(Positive, M).L) + + reweight_relative!(stats, 0) + + f_transform_new = MulAdd(A_new, zeros(T, n_dims)) + chain_state = set_mc_state_transform!!(chain_state, f_transform_new) + end + + chain_state_new = mcmc_update_z_position!!(chain_state) + + return chain_state_new, tuner +end diff --git a/src/extdefs/ahmc_defs/ahmc_config.jl b/src/extdefs/ahmc_defs/ahmc_config.jl index 653ca0b9d..c72b5df85 100644 --- a/src/extdefs/ahmc_defs/ahmc_config.jl +++ b/src/extdefs/ahmc_defs/ahmc_config.jl @@ -30,17 +30,18 @@ end # Uses Stan (also AdvancedHMC) defaults # (see https://mc-stan.org/docs/2_26/reference-manual/hmc-algorithm-parameters.html): -@with_kw struct StanHMCTuning <: HMCTuning +@with_kw struct StanHMCTuning <: MCMCTransformTuning "target acceptance rate" target_acceptance::Float64 = 0.8 "width of initial fast adaptation interval" - initial_bufsize::Int = 75 + init_buffer::Int = 75 "width of final fast adaptation interval" - term_bufsize::Int = 50 + term_buffer::Int = 50 "initial width of slow adaptation interval" window_size::Int = 25 end + export StanHMCTuning diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl index c3d22d854..d49014253 
100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -39,7 +39,7 @@ end mutable struct RAMProposalTunerState <: MCMCTransformTunerState end -create_trafo_tuner_state(tuning::RAMTuning, chain::MCMCChainState, n_steps_hint::Integer) = RAMTrafoTunerState(tuning, 0) +create_trafo_tuner_state(tuning::RAMTuning, chain_state::MCMCChainState, n_steps_hint::Integer) = RAMTrafoTunerState(tuning, 0) function mcmc_tuning_init!!(tuner_state::RAMTrafoTunerState, chain_state::MCMCChainState, max_nsteps::Integer) chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = false) # TODO ? @@ -71,24 +71,22 @@ end mcmc_tuning_finalize!!(tuner::RAMTrafoTunerState, chain::MCMCChainState) = nothing -# Return mc_state instead of f_transform function mcmc_tune_post_step!!( tuner_state::RAMTrafoTunerState, mc_state::MCMCChainState, p_accept::Real, ) - (; f_transform, sample_z) = mc_state - - tuner_state_new = @set tuner_state.nsteps = tuner_state.nsteps + 1 - - # TODO: MD: Discuss; apparently the RandomwWalk sampler wants the trafo to be tuned even if p_accept = 0. If not, the burnin does not converge. + # TODO: MD: Discuss; apparently the RandomWalk sampler wants the trafo to be tuned even if p_accept = 0. If not, the burnin does not converge. 
if iszero(p_accept) && !(mc_state isa MHChainState) - return mc_state, tuner_state_new + return mc_state, tuner_state end + (; f_transform, sample_z) = mc_state (; target_acceptance, gamma) = tuner_state.tuning b = f_transform.b + tuner_state_new = @set tuner_state.nsteps = tuner_state.nsteps + 1 + n_dims = size(sample_z.v[1], 1) η = min(1, n_dims * tuner_state.nsteps^(-gamma)) diff --git a/test/samplers/mcmc/test_hmc.jl b/test/samplers/mcmc/test_hmc.jl index 34afafd27..3575f6a1d 100644 --- a/test/samplers/mcmc/test_hmc.jl +++ b/test/samplers/mcmc/test_hmc.jl @@ -19,9 +19,9 @@ import AdvancedHMC @test target isa BAT.BATDistMeasure proposal = HamiltonianMC() - proposal_tuning = StanHMCTuning() + transform_tuning = StanHMCTuning() nchains = 4 - samplingalg = TransformedMCMC(proposal = proposal, proposal_tuning = proposal_tuning, nchains = nchains) + samplingalg = TransformedMCMC(proposal = proposal, transform_tuning = transform_tuning, nchains = nchains) @testset "MCMC iteration" begin v_init = bat_initval(target, InitFromTarget(), context).result @@ -29,7 +29,6 @@ import AdvancedHMC # TODO: MD, reactivate @test BAT.MCMCChainState(samplingalg, target, 1, unshaped(v_init, varshape(target)), deepcopy(context)) isa BAT.HMCState mcmc_state = BAT.MCMCState(samplingalg, target, 1, unshaped(v_init, varshape(target)), deepcopy(context)) - tuner = BAT.create_proposal_tuner_state(StanHMCTuning(), mcmc_state.chain_state, 0) nsteps = 10^4 BAT.mcmc_tuning_init!!(mcmc_state, 0) BAT.mcmc_tuning_reinit!!(mcmc_state, div(nsteps, 10)) @@ -49,7 +48,7 @@ import AdvancedHMC @testset "MCMC tuning and burn-in" begin max_nsteps = 10^5 - proposal_tuning = BAT.StanHMCTuning() + transform_tuning = BAT.StanHMCTuning() pretransform = DoNotTransform() init_alg = bat_default(TransformedMCMC, Val(:init), proposal, pretransform, nchains, max_nsteps) burnin_alg = bat_default(TransformedMCMC, Val(:burnin), proposal, pretransform, nchains, max_nsteps) @@ -59,7 +58,7 @@ import AdvancedHMC callback = 
(x...) -> nothing samplingalg = TransformedMCMC(proposal = proposal, - proposal_tuning = proposal_tuning, + transform_tuning = transform_tuning, pretransform = pretransform, init = init_alg, burnin = burnin_alg, @@ -108,7 +107,7 @@ import AdvancedHMC shaped_target, TransformedMCMC( proposal = proposal, - proposal_tuning = StanHMCTuning(), + transform_tuning = StanHMCTuning(), pretransform = DoNotTransform(), nsteps = 10^4, store_burnin = true @@ -124,7 +123,7 @@ import AdvancedHMC shaped_target, TransformedMCMC( proposal = proposal, - proposal_tuning = StanHMCTuning(), + transform_tuning = StanHMCTuning(), pretransform = DoNotTransform(), nsteps = 10^4, store_burnin = false @@ -144,7 +143,7 @@ import AdvancedHMC inner_posterior = PosteriorMeasure(likelihood, prior) # Test with nested posteriors: posterior = PosteriorMeasure(likelihood, inner_posterior) - @test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = HamiltonianMC(), proposal_tuning = StanHMCTuning(), pretransform = PriorToNormal()), prior.dist, context).verified + @test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = HamiltonianMC(), transform_tuning = StanHMCTuning(), pretransform = PriorToNormal()), prior.dist, context).verified end @testset "HMC autodiff" begin @@ -156,7 +155,7 @@ import AdvancedHMC hmc_samplingalg = TransformedMCMC( proposal = HamiltonianMC(), - proposal_tuning = StanHMCTuning(), + transform_tuning = StanHMCTuning(), nchains = 2, nsteps = 100, init = MCMCChainPoolInit(init_tries_per_chain = 2..2, nsteps_init = 5), From 070336b149838ef65adb6dbe721c5a8545a4212d Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Wed, 12 Feb 2025 21:07:23 +0100 Subject: [PATCH 05/11] Use proposed samples from AdvancedHMC.jl for tuners in case of HMC Proposal --- ext/BATAdvancedHMCExt.jl | 4 +- ext/ahmc_impl/ahmc_sampler_impl.jl | 97 ++++++++++++++++--- ext/ahmc_impl/ahmc_stan_tuner_impl.jl | 5 +- src/samplers/mcmc/mcmc_algorithm.jl | 2 +- src/samplers/mcmc/mcmc_sample.jl | 9 +- 
src/samplers/mcmc/mcmc_state.jl | 7 +- .../mcmc/mcmc_tuning/mcmc_ram_tuner.jl | 16 +-- src/samplers/mcmc/mh_sampler.jl | 2 - 8 files changed, 105 insertions(+), 37 deletions(-) diff --git a/ext/BATAdvancedHMCExt.jl b/ext/BATAdvancedHMCExt.jl index 0964eea0b..e8e5b8f7b 100644 --- a/ext/BATAdvancedHMCExt.jl +++ b/ext/BATAdvancedHMCExt.jl @@ -23,7 +23,7 @@ using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainS using BAT: MCMCBasicStats, push!, reweight_relative! using BAT: RAMTuning using BAT: MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering, NoMCMCTransformTuning -using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples, current_sample_z, proposed_sample +using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples, current_sample_z, proposed_sample_z, proposed_sample using BAT: AbstractTransformTarget, NoAdaptiveTransform, TriangularAffineTransform, valgrad_func using BAT: RNGPartition, get_rng, set_rng! 
using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio @@ -37,6 +37,8 @@ using BAT: AHMCSampleID, AHMCSampleIDVector using BAT: HMCMetric, DiagEuclideanMetric, UnitEuclideanMetric, DenseEuclideanMetric using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanHMCTuning +using ChangesOfVariables: with_logabsdet_jacobian + using LinearAlgebra: cholesky using MeasureBase: pullback diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index bc0604d3e..9d4f325e8 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -69,7 +69,7 @@ function BAT.next_cycle!(mc_state::HMCState) mc_state.nsamples = 0 mc_state.stepno = 0 - #reset_rng_counters!(mc_state) + reset_rng_counters!(mc_state) resize!(mc_state.samples, 1) @@ -119,16 +119,15 @@ function BAT.mcmc_propose!!(mc_state::HMCState) z_phase = AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic)) # Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics. - proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase) - z_proposed[:] = proposal.transition.z.θ - x_proposed[:] = f_transform(z_proposed) - - proposed_log_posterior = logdensityof(target, x_proposed) - samples.logd[proposed_x_idx] = proposed_log_posterior + proposal.transition, z_proposed_hmc, p_accept = _bat_transition(rng, τ, hamiltonian, z_phase) + accepted = z_current[:] != proposal.transition.z.θ + z_proposed[:] = accepted ? proposal.transition.z.θ : z_proposed_hmc + + p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate - accepted = z_current != z_proposed - tstat = AdvancedHMC.stat(proposal.transition) - p_accept = accepted ? 
tstat.acceptance_rate : 0.0 + x_proposed[:] = f_transform(z_proposed) + logd_x_proposed = logdensityof(target, x_proposed) + samples.logd[proposed_x_idx] = logd_x_proposed return mc_state, accepted, p_accept end @@ -142,7 +141,7 @@ function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float samples.info.sampletype[current] = ACCEPTED_SAMPLE samples.info.sampletype[proposed] = CURRENT_SAMPLE mc_state.nsamples += 1 - + tstat = AdvancedHMC.stat(proposal.transition) samples.info.hamiltonian_energy[proposed] = tstat.hamiltonian_energy # ToDo: Handle proposal-dependent tstat (only NUTS has tree_depth): @@ -176,3 +175,79 @@ function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Funct mc_state_new = @set mc_state_new.f_transform = f_transform_new return mc_state_new end + + +# Copied from AdvancedHMC.jl, but also return proposed point +function _bat_transition( + rng::AbstractRNG, + τ::AdvancedHMC.Trajectory{TS,I,TC}, + h::AdvancedHMC.Hamiltonian, + z0::AdvancedHMC.PhasePoint, +) where { + TS<:AdvancedHMC.AbstractTrajectorySampler, + I<:AdvancedHMC.AbstractIntegrator, + TC<:AdvancedHMC.DynamicTerminationCriterion, +} + H0 = AdvancedHMC.energy(z0) + tree = AdvancedHMC.BinaryTree( + z0, + z0, + AdvancedHMC.TurnStatistic(τ.termination_criterion, z0), + zero(H0), + zero(Int), + zero(H0), + ) + sampler = TS(rng, z0) + termination = AdvancedHMC.Termination(false, false) + zcand = z0 + proposed_zs = Vector[] + + j = 0 + while !AdvancedHMC.isterminated(termination) && j < τ.termination_criterion.max_depth + v = rand(rng, [-1, 1]) + if v == -1 + tree′, sampler′, termination′ = + AdvancedHMC.build_tree(rng, τ, h, tree.zleft, sampler, v, j, H0) + treeleft, treeright = tree′, tree + else + tree′, sampler′, termination′ = + AdvancedHMC.build_tree(rng, τ, h, tree.zright, sampler, v, j, H0) + treeleft, treeright = tree, tree′ + end + if !AdvancedHMC.isterminated(termination′) + j = j + 1 + if AdvancedHMC.mh_accept(rng, sampler, sampler′) + zcand = 
sampler′.zcand + end + end + push!(proposed_zs, sampler′.zcand.θ) + + tree = AdvancedHMC.combine(treeleft, treeright) + sampler = AdvancedHMC.combine(zcand, sampler, sampler′) + termination = + termination * + termination′ * + AdvancedHMC.isterminated(τ.termination_criterion, h, tree, treeleft, treeright) + end + + H = AdvancedHMC.energy(zcand) + tstat = AdvancedHMC.merge( + ( + n_steps = tree.nα, + is_accept = true, + acceptance_rate = tree.sum_α / tree.nα, + log_density = zcand.ℓπ.value, + hamiltonian_energy = H, + hamiltonian_energy_error = H - H0, + max_hamiltonian_energy_error = tree.ΔH_max, + tree_depth = j, + numerical_error = termination.numerical, + ), + AdvancedHMC.stat(τ.integrator), + ) + + z_proposed = proposed_zs[end] + p_accept = tstat.acceptance_rate + + return AdvancedHMC.Transition(zcand, tstat), z_proposed, p_accept +end diff --git a/ext/ahmc_impl/ahmc_stan_tuner_impl.jl b/ext/ahmc_impl/ahmc_stan_tuner_impl.jl index 03584329b..0e5da380b 100644 --- a/ext/ahmc_impl/ahmc_stan_tuner_impl.jl +++ b/ext/ahmc_impl/ahmc_stan_tuner_impl.jl @@ -55,8 +55,9 @@ function BAT.mcmc_tune_post_step!!( is_in_window = stan_state.i >= stan_state.window_start && stan_state.i <= stan_state.window_end is_window_end = stan_state.i in stan_state.window_splits - # What to append? 
- is_in_window && BAT.push!(stats, proposed_sample(chain_state)) + if is_in_window + BAT.push!(stats, proposed_sample(chain_state)) + end if is_window_end A = chain_state.f_transform.A diff --git a/src/samplers/mcmc/mcmc_algorithm.jl b/src/samplers/mcmc/mcmc_algorithm.jl index 07b7d38ee..ce5f09e70 100644 --- a/src/samplers/mcmc/mcmc_algorithm.jl +++ b/src/samplers/mcmc/mcmc_algorithm.jl @@ -280,7 +280,7 @@ function mcmc_iterate!!( while ( (nsteps(mcmc_state) - start_nsteps) < max_nsteps && (time() - start_time) < max_time - ) + ) mcmc_state = mcmc_step!!(mcmc_state) if !isnothing(output) diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index f0c7a4ec0..4dba17f5c 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -88,23 +88,22 @@ function bat_sample_impl(m::BATMeasure, samplingalg::TransformedMCMC, context::B if !samplingalg.store_burnin chain_outputs .= DensitySampleVector.(mcmc_states) end - + mcmc_states = mcmc_burnin!( samplingalg.store_burnin ? chain_outputs : nothing, mcmc_states, samplingalg, samplingalg.store_burnin ? 
samplingalg.callback : nop_func ) - + next_cycle!.(mcmc_states) - + mcmc_states = mcmc_iterate!!( chain_outputs, mcmc_states; max_nsteps = samplingalg.nsteps, nonzero_weights = samplingalg.nonzero_weights - ) - + ) samples_transformed = DensitySampleVector(first(mcmc_states)) isempty(chain_outputs) || append!.(Ref(samples_transformed), chain_outputs) diff --git a/src/samplers/mcmc/mcmc_state.jl b/src/samplers/mcmc/mcmc_state.jl index 3f1cf6150..784df7469 100644 --- a/src/samplers/mcmc/mcmc_state.jl +++ b/src/samplers/mcmc/mcmc_state.jl @@ -145,7 +145,7 @@ end function mcmc_step!!(mcmc_state::MCMCState) _cleanup_samples(mcmc_state) - #reset_rng_counters!(mcmc_state) + reset_rng_counters!(mcmc_state) chain_state = mcmc_state.chain_state @@ -156,7 +156,7 @@ function mcmc_step!!(mcmc_state::MCMCState) resize!(samples, size(samples, 1) + 1) samples.info[lastindex(samples)] = _get_sample_id(proposal, chain_state.info.id, chain_state.info.cycle, chain_state.stepno, PROPOSED_SAMPLE)[1] - + chain_state, accepted, p_accept = mcmc_propose!!(chain_state) mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept) @@ -219,7 +219,7 @@ function next_cycle!(chain_state::MCMCChainState) chain_state.nsamples = 0 chain_state.stepno = 0 - #reset_rng_counters!(chain_state) + reset_rng_counters!(chain_state) resize!(chain_state.samples, 1) @@ -277,7 +277,6 @@ end function mcmc_update_z_position!!(mc_state::MCMCChainState) f_transform = mc_state.f_transform - sample_z = mc_state.sample_z current_sample_x = current_sample(mc_state) proposed_sample_x = proposed_sample(mc_state) diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl index d49014253..cf19a602a 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -76,34 +76,28 @@ function mcmc_tune_post_step!!( mc_state::MCMCChainState, p_accept::Real, ) - # TODO: MD: Discuss; apparently the RandomWalk sampler wants the 
trafo to be tuned even if p_accept = 0. If not, the burnin does not converge. - if iszero(p_accept) && !(mc_state isa MHChainState) - return mc_state, tuner_state - end - (; f_transform, sample_z) = mc_state (; target_acceptance, gamma) = tuner_state.tuning b = f_transform.b - + tuner_state_new = @set tuner_state.nsteps = tuner_state.nsteps + 1 - + n_dims = size(sample_z.v[1], 1) η = min(1, n_dims * tuner_state.nsteps^(-gamma)) s_L = f_transform.A u = sample_z.v[2] - sample_z.v[1] # proposed - current - M = s_L * (I + η * (p_accept - target_acceptance) * (u * u') / norm(u)^2 ) * s_L' - + M = s_L * (I + η * (p_accept - target_acceptance) * (u * u') / norm(u)^2) * s_L' new_s_L = oftype(s_L, cholesky(Positive, M).L) - + x = mc_state.samples[_proposed_sample_idx(mc_state)] # proposed in x-space mean_update_rate = η / 10 # heuristic α = mean_update_rate * p_accept new_b = oftype(b, (1- α) * b + α * x.v) f_transform_new = MulAdd(new_s_L, new_b) - + mc_state_new = set_mc_state_transform!!(mc_state, f_transform_new) mc_state_new = mcmc_update_z_position!!(mc_state_new) diff --git a/src/samplers/mcmc/mh_sampler.jl b/src/samplers/mcmc/mh_sampler.jl index ae5b6fe98..06366d4fa 100644 --- a/src/samplers/mcmc/mh_sampler.jl +++ b/src/samplers/mcmc/mh_sampler.jl @@ -144,8 +144,6 @@ function _accept_reject!(mc_state::MHChainState, accepted::Bool, p_accept::Float samples.info.sampletype[proposed] = CURRENT_SAMPLE mc_state.nsamples += 1 - - mc_state.sample_z[1] = deepcopy(proposed_sample_z(mc_state)) else samples.info.sampletype[proposed] = REJECTED_SAMPLE end From 91cad66470948377985c3090af37f1850f5b1c6e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 13 Feb 2025 09:18:39 +0100 Subject: [PATCH 06/11] Revert some whitespace and comment changes --- src/measures/bat_pushfwd_measure.jl | 2 +- src/samplers/mcmc/chain_pool_init.jl | 5 ++--- src/samplers/mcmc/mcmc_algorithm.jl | 2 +- src/samplers/mcmc/mcmc_sample.jl | 9 +++++---- src/samplers/mcmc/multi_cycle_burnin.jl | 2 +- 5 
files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/measures/bat_pushfwd_measure.jl b/src/measures/bat_pushfwd_measure.jl index aab2a6575..443f2a8ed 100644 --- a/src/measures/bat_pushfwd_measure.jl +++ b/src/measures/bat_pushfwd_measure.jl @@ -71,7 +71,7 @@ end #!!!!!!!!! Use return type of f with testvalue, if no shape change return varshape(m.orig) directly -# ValueShapes.varshape(m::BATPushFwdMeasure) = varshape(m.origin) +#ValueShapes.varshape(m::BATPushFwdMeasure) = f(varshape(m.orig)) ValueShapes.varshape(m::BATPushFwdMeasure{<:DistributionTransform}) = varshape(m.f.target_dist) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index 717216cc5..d39fca2a7 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -98,18 +98,17 @@ function mcmc_init!( next_cycle!.(new_mcmc_states) mcmc_tuning_init!!.(new_mcmc_states, init_alg.nsteps_init) - new_mcmc_states = mcmc_update_z_position!!.(new_mcmc_states) ncandidates += n @debug "Testing $(length(new_mcmc_states)) candidate MCMC chain state(s)." 
- + new_mcmc_states = mcmc_iterate!!( new_outputs, new_mcmc_states; max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), nonzero_weights = nonzero_weights ) - + viable_idxs = findall(isviablestate.(new_mcmc_states)) viable_mcmc_states = new_mcmc_states[viable_idxs] viable_outputs = new_outputs[viable_idxs] diff --git a/src/samplers/mcmc/mcmc_algorithm.jl b/src/samplers/mcmc/mcmc_algorithm.jl index ce5f09e70..09b3b42a1 100644 --- a/src/samplers/mcmc/mcmc_algorithm.jl +++ b/src/samplers/mcmc/mcmc_algorithm.jl @@ -276,7 +276,7 @@ function mcmc_iterate!!( log_time = start_time start_nsteps = nsteps(mcmc_state) start_nsamples = nsamples(mcmc_state) - + while ( (nsteps(mcmc_state) - start_nsteps) < max_nsteps && (time() - start_time) < max_time diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index 4dba17f5c..f0c7a4ec0 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -88,22 +88,23 @@ function bat_sample_impl(m::BATMeasure, samplingalg::TransformedMCMC, context::B if !samplingalg.store_burnin chain_outputs .= DensitySampleVector.(mcmc_states) end - + mcmc_states = mcmc_burnin!( samplingalg.store_burnin ? chain_outputs : nothing, mcmc_states, samplingalg, samplingalg.store_burnin ? 
samplingalg.callback : nop_func ) - + next_cycle!.(mcmc_states) - + mcmc_states = mcmc_iterate!!( chain_outputs, mcmc_states; max_nsteps = samplingalg.nsteps, nonzero_weights = samplingalg.nonzero_weights - ) + ) + samples_transformed = DensitySampleVector(first(mcmc_states)) isempty(chain_outputs) || append!.(Ref(samples_transformed), chain_outputs) diff --git a/src/samplers/mcmc/multi_cycle_burnin.jl b/src/samplers/mcmc/multi_cycle_burnin.jl index 05f59b1ff..fffe87a6f 100644 --- a/src/samplers/mcmc/multi_cycle_burnin.jl +++ b/src/samplers/mcmc/multi_cycle_burnin.jl @@ -52,7 +52,7 @@ function mcmc_burnin!( max_nsteps = burnin.nsteps_per_cycle, nonzero_weights = nonzero_weights ) - + mcmc_states = mcmc_tune_post_cycle!!.(mcmc_states, new_outputs) isnothing(outputs) || append!.(outputs, new_outputs) From 379e5f55ed0dbc619404b82d5ec7c67857668295 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 13 Feb 2025 09:43:00 +0100 Subject: [PATCH 07/11] Rename StanHMCTuning to StanLikeTuning --- ext/BATAdvancedHMCExt.jl | 2 +- ext/ahmc_impl/ahmc_config_impl.jl | 2 +- ext/ahmc_impl/ahmc_stan_tuner_impl.jl | 18 +++++++++--------- src/extdefs/ahmc_defs/ahmc_config.jl | 4 ++-- test/samplers/mcmc/test_hmc.jl | 12 ++++++------ 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/ext/BATAdvancedHMCExt.jl b/ext/BATAdvancedHMCExt.jl index e8e5b8f7b..f5b63c8d6 100644 --- a/ext/BATAdvancedHMCExt.jl +++ b/ext/BATAdvancedHMCExt.jl @@ -35,7 +35,7 @@ using BAT: CURRENT_SAMPLE, PROPOSED_SAMPLE, INVALID_SAMPLE, ACCEPTED_SAMPLE, REJ using BAT: HamiltonianMC using BAT: AHMCSampleID, AHMCSampleIDVector using BAT: HMCMetric, DiagEuclideanMetric, UnitEuclideanMetric, DenseEuclideanMetric -using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanHMCTuning +using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanLikeTuning using ChangesOfVariables: with_logabsdet_jacobian diff --git a/ext/ahmc_impl/ahmc_config_impl.jl 
b/ext/ahmc_impl/ahmc_config_impl.jl index 29849be1a..a5b72ea86 100644 --- a/ext/ahmc_impl/ahmc_config_impl.jl +++ b/ext/ahmc_impl/ahmc_config_impl.jl @@ -73,7 +73,7 @@ function ahmc_adaptor( end function ahmc_adaptor( - tuning::StanHMCTuning, + tuning::StanLikeTuning, metric::AdvancedHMC.AbstractMetric, integrator::AdvancedHMC.AbstractIntegrator, θ_init::AbstractVector{<:Real} diff --git a/ext/ahmc_impl/ahmc_stan_tuner_impl.jl b/ext/ahmc_impl/ahmc_stan_tuner_impl.jl index 0e5da380b..dcd3b0d74 100644 --- a/ext/ahmc_impl/ahmc_stan_tuner_impl.jl +++ b/ext/ahmc_impl/ahmc_stan_tuner_impl.jl @@ -1,32 +1,32 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -mutable struct StanHMCTrafoTunerState{ +mutable struct StanLikeTunerState{ S<:MCMCBasicStats, } <: MCMCTransformTunerState - tuning::StanHMCTuning + tuning::StanLikeTuning target_acceptance::Float64 stats::S stan_state::AdvancedHMC.Adaptation.StanHMCAdaptorState end -BAT.create_trafo_tuner_state(tuning::StanHMCTuning, chain_state::MCMCChainState, n_steps_hint::Integer) = StanHMCTrafoTunerState(tuning, tuning.target_acceptance, MCMCBasicStats(chain_state), AdvancedHMC.Adaptation.StanHMCAdaptorState()) +BAT.create_trafo_tuner_state(tuning::StanLikeTuning, chain_state::MCMCChainState, n_steps_hint::Integer) = StanLikeTunerState(tuning, tuning.target_acceptance, MCMCBasicStats(chain_state), AdvancedHMC.Adaptation.StanHMCAdaptorState()) -function BAT.mcmc_tuning_init!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState, max_nsteps::Integer) +function BAT.mcmc_tuning_init!!(tuner::StanLikeTunerState, chain_state::HMCState, max_nsteps::Integer) tuning = tuner.tuning AdvancedHMC.Adaptation.initialize!(tuner.stan_state, tuning.init_buffer, tuning.term_buffer, tuning.window_size, Int(max_nsteps - 1)) nothing end -function BAT.mcmc_tuning_reinit!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState, max_nsteps::Integer) +function BAT.mcmc_tuning_reinit!!(tuner::StanLikeTunerState, chain_state::HMCState, 
max_nsteps::Integer) tuning = tuner.tuning AdvancedHMC.Adaptation.initialize!(tuner.stan_state, tuning.init_buffer, tuning.term_buffer, tuning.window_size, Int(max_nsteps - 1)) nothing end -BAT.mcmc_tuning_postinit!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState, samples::DensitySampleVector) = nothing +BAT.mcmc_tuning_postinit!!(tuner::StanLikeTunerState, chain_state::HMCState, samples::DensitySampleVector) = nothing -function BAT.mcmc_tune_post_cycle!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState, samples::DensitySampleVector) +function BAT.mcmc_tune_post_cycle!!(tuner::StanLikeTunerState, chain_state::HMCState, samples::DensitySampleVector) max_log_posterior = maximum(samples.logd) accept_ratio = eff_acceptance_ratio(chain_state) if accept_ratio >= 0.9 * tuner.target_acceptance @@ -40,11 +40,11 @@ function BAT.mcmc_tune_post_cycle!!(tuner::StanHMCTrafoTunerState, chain_state:: end -BAT.mcmc_tuning_finalize!!(tuner::StanHMCTrafoTunerState, chain_state::HMCState) = nothing +BAT.mcmc_tuning_finalize!!(tuner::StanLikeTunerState, chain_state::HMCState) = nothing function BAT.mcmc_tune_post_step!!( - tuner::StanHMCTrafoTunerState, + tuner::StanLikeTunerState, chain_state::MCMCChainState, p_accept::Real ) diff --git a/src/extdefs/ahmc_defs/ahmc_config.jl b/src/extdefs/ahmc_defs/ahmc_config.jl index c72b5df85..9cd0568d4 100644 --- a/src/extdefs/ahmc_defs/ahmc_config.jl +++ b/src/extdefs/ahmc_defs/ahmc_config.jl @@ -30,7 +30,7 @@ end # Uses Stan (also AdvancedHMC) defaults # (see https://mc-stan.org/docs/2_26/reference-manual/hmc-algorithm-parameters.html): -@with_kw struct StanHMCTuning <: MCMCTransformTuning +@with_kw struct StanLikeTuning <: MCMCTransformTuning "target acceptance rate" target_acceptance::Float64 = 0.8 @@ -44,4 +44,4 @@ end window_size::Int = 25 end -export StanHMCTuning +export StanLikeTuning diff --git a/test/samplers/mcmc/test_hmc.jl b/test/samplers/mcmc/test_hmc.jl index 3575f6a1d..eefa3a534 100644 --- 
a/test/samplers/mcmc/test_hmc.jl +++ b/test/samplers/mcmc/test_hmc.jl @@ -19,7 +19,7 @@ import AdvancedHMC @test target isa BAT.BATDistMeasure proposal = HamiltonianMC() - transform_tuning = StanHMCTuning() + transform_tuning = StanLikeTuning() nchains = 4 samplingalg = TransformedMCMC(proposal = proposal, transform_tuning = transform_tuning, nchains = nchains) @@ -48,7 +48,7 @@ import AdvancedHMC @testset "MCMC tuning and burn-in" begin max_nsteps = 10^5 - transform_tuning = BAT.StanHMCTuning() + transform_tuning = BAT.StanLikeTuning() pretransform = DoNotTransform() init_alg = bat_default(TransformedMCMC, Val(:init), proposal, pretransform, nchains, max_nsteps) burnin_alg = bat_default(TransformedMCMC, Val(:burnin), proposal, pretransform, nchains, max_nsteps) @@ -107,7 +107,7 @@ import AdvancedHMC shaped_target, TransformedMCMC( proposal = proposal, - transform_tuning = StanHMCTuning(), + transform_tuning = StanLikeTuning(), pretransform = DoNotTransform(), nsteps = 10^4, store_burnin = true @@ -123,7 +123,7 @@ import AdvancedHMC shaped_target, TransformedMCMC( proposal = proposal, - transform_tuning = StanHMCTuning(), + transform_tuning = StanLikeTuning(), pretransform = DoNotTransform(), nsteps = 10^4, store_burnin = false @@ -143,7 +143,7 @@ import AdvancedHMC inner_posterior = PosteriorMeasure(likelihood, prior) # Test with nested posteriors: posterior = PosteriorMeasure(likelihood, inner_posterior) - @test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = HamiltonianMC(), transform_tuning = StanHMCTuning(), pretransform = PriorToNormal()), prior.dist, context).verified + @test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = HamiltonianMC(), transform_tuning = StanLikeTuning(), pretransform = PriorToNormal()), prior.dist, context).verified end @testset "HMC autodiff" begin @@ -155,7 +155,7 @@ import AdvancedHMC hmc_samplingalg = TransformedMCMC( proposal = HamiltonianMC(), - transform_tuning = StanHMCTuning(), + transform_tuning = 
StanLikeTuning(), nchains = 2, nsteps = 100, init = MCMCChainPoolInit(init_tries_per_chain = 2..2, nsteps_init = 5), From 36b51c93aa1d9406f0424dc21c02765481cb51dc Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Thu, 13 Feb 2025 12:44:02 +0100 Subject: [PATCH 08/11] Change default HMC Metric for HamiltonianMC MCMCProposal to UnitEuclideanMetric --- ext/ahmc_impl/ahmc_sampler_impl.jl | 1 - src/extdefs/ahmc_defs/ahmc_alg.jl | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index 9d4f325e8..da585c505 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -133,7 +133,6 @@ function BAT.mcmc_propose!!(mc_state::HMCState) end function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float64, current::Integer, proposed::Integer) - # @unpack samples, proposal = mc_state samples = mc_state.samples proposal = mc_state.proposal diff --git a/src/extdefs/ahmc_defs/ahmc_alg.jl b/src/extdefs/ahmc_defs/ahmc_alg.jl index 3772987b4..b697158d0 100644 --- a/src/extdefs/ahmc_defs/ahmc_alg.jl +++ b/src/extdefs/ahmc_defs/ahmc_alg.jl @@ -35,7 +35,7 @@ $(TYPEDFIELDS) IT, TC } <: MCMCProposal - metric::MT = DiagEuclideanMetric() + metric::MT = UnitEuclideanMetric() integrator::IT = ext_default(pkgext(Val(:AdvancedHMC)), Val(:DEFAULT_INTEGRATOR)) termination::TC = ext_default(pkgext(Val(:AdvancedHMC)), Val(:DEFAULT_TERMINATION_CRITERION)) end From 18d11153f86ade9ee57f685ce96a77081f06e861 Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Thu, 13 Feb 2025 13:59:33 +0100 Subject: [PATCH 09/11] Adjust HMC Tests --- test/samplers/mcmc/test_hmc.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/test/samplers/mcmc/test_hmc.jl b/test/samplers/mcmc/test_hmc.jl index eefa3a534..eab3c1d1d 100644 --- a/test/samplers/mcmc/test_hmc.jl +++ b/test/samplers/mcmc/test_hmc.jl @@ -56,7 +56,7 @@ import AdvancedHMC 
strict = true nonzero_weights = false callback = (x...) -> nothing - + samplingalg = TransformedMCMC(proposal = proposal, transform_tuning = transform_tuning, pretransform = pretransform, @@ -66,7 +66,7 @@ import AdvancedHMC strict = strict, nonzero_weights = nonzero_weights ) - + # Note: No @inferred, not type stable (yet) with HamiltonianMC init_result = BAT.mcmc_init!( samplingalg, @@ -75,33 +75,34 @@ import AdvancedHMC callback, context ) - + (mcmc_states, outputs) = init_result - # @test mcmc_states isa AbstractVector{<:BAT.HMCState} # TODO: MD, reactivate, works for AbstractVector{<:MCMCChainState}, but doesn't seen to like the typealias - # @test tuners isa AbstractVector{<:BAT.HMCState} + @test mcmc_states isa AbstractVector{<:BAT.MCMCState} @test outputs isa AbstractVector{<:DensitySampleVector} - - BAT.mcmc_burnin!( + + mcmc_states = BAT.mcmc_burnin!( outputs, mcmc_states, samplingalg, callback ) - + + BAT.next_cycle!.(mcmc_states) + mcmc_states = BAT.mcmc_iterate!!( outputs, mcmc_states; max_nsteps = div(max_nsteps, length(mcmc_states)), nonzero_weights = nonzero_weights ) - + samples = DensitySampleVector(first(mcmc_states)) append!.(Ref(samples), outputs) @test length(samples) == sum(samples.weight) @test BAT.test_dist_samples(unshaped(objective), samples) end - + @testset "bat_sample" begin samples = bat_sample( shaped_target, From 5f5396a7fb5a969b9dbc76899caa605692d94af4 Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Fri, 14 Feb 2025 17:42:13 +0100 Subject: [PATCH 10/11] Change proposed sample in hmc to weighted mean, Fix weight assignment error in mcmc_step --- ext/ahmc_impl/ahmc_sampler_impl.jl | 16 ++++++++++++---- src/samplers/mcmc/mcmc_state.jl | 8 +++----- test/samplers/mcmc/test_hmc.jl | 1 - 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index da585c505..ec2f9a355 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++
b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -125,10 +125,12 @@ function BAT.mcmc_propose!!(mc_state::HMCState) p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate - x_proposed[:] = f_transform(z_proposed) + x_proposed[:], ladj = with_logabsdet_jacobian(f_transform, z_proposed) logd_x_proposed = logdensityof(target, x_proposed) samples.logd[proposed_x_idx] = logd_x_proposed + sample_z.logd[proposed_z_idx] = logd_x_proposed + ladj + return mc_state, accepted, p_accept end @@ -200,6 +202,7 @@ function _bat_transition( termination = AdvancedHMC.Termination(false, false) zcand = z0 proposed_zs = Vector[] + accept_probs = Float64[] j = 0 while !AdvancedHMC.isterminated(termination) && j < τ.termination_criterion.max_depth @@ -213,14 +216,18 @@ function _bat_transition( AdvancedHMC.build_tree(rng, τ, h, tree.zright, sampler, v, j, H0) treeleft, treeright = tree, tree′ end + + # This acceptance prob. is specific to AdvancedHMC.MultinomialTS + p_tmp = min(1, exp(sampler′.ℓw - sampler.ℓw)) + push!(accept_probs, p_tmp) + push!(proposed_zs, sampler′.zcand.θ) + if !AdvancedHMC.isterminated(termination′) j = j + 1 if AdvancedHMC.mh_accept(rng, sampler, sampler′) zcand = sampler′.zcand end end - push!(proposed_zs, sampler′.zcand.θ) - tree = AdvancedHMC.combine(treeleft, treeright) sampler = AdvancedHMC.combine(zcand, sampler, sampler′) termination = @@ -245,7 +252,8 @@ function _bat_transition( AdvancedHMC.stat(τ.integrator), ) - z_proposed = proposed_zs[end] + accept_total = sum(accept_probs) + z_proposed = iszero(accept_total) ? 
sum(proposed_zs) / length(proposed_zs) : sum(accept_probs .* proposed_zs) / accept_total p_accept = tstat.acceptance_rate return AdvancedHMC.Transition(zcand, tstat), z_proposed, p_accept diff --git a/src/samplers/mcmc/mcmc_state.jl b/src/samplers/mcmc/mcmc_state.jl index 784df7469..0f0ef0c8c 100644 --- a/src/samplers/mcmc/mcmc_state.jl +++ b/src/samplers/mcmc/mcmc_state.jl @@ -159,15 +159,13 @@ function mcmc_step!!(mcmc_state::MCMCState) chain_state, accepted, p_accept = mcmc_propose!!(chain_state) - mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept) - - chain_state = mcmc_state_new.chain_state - current = _current_sample_idx(chain_state) proposed = _proposed_sample_idx(chain_state) - _accept_reject!(chain_state, accepted, p_accept, current, proposed) + mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept) + + chain_state = mcmc_state_new.chain_state mcmc_state_final = @set mcmc_state_new.chain_state = chain_state return mcmc_state_final diff --git a/test/samplers/mcmc/test_hmc.jl b/test/samplers/mcmc/test_hmc.jl index eab3c1d1d..bd96313d6 100644 --- a/test/samplers/mcmc/test_hmc.jl +++ b/test/samplers/mcmc/test_hmc.jl @@ -26,7 +26,6 @@ import AdvancedHMC @testset "MCMC iteration" begin v_init = bat_initval(target, InitFromTarget(), context).result # Note: No @inferred, since MCMCChainState is not type stable (yet) with HamiltonianMC - # TODO: MD, reactivate @test BAT.MCMCChainState(samplingalg, target, 1, unshaped(v_init, varshape(target)), deepcopy(context)) isa BAT.HMCState mcmc_state = BAT.MCMCState(samplingalg, target, 1, unshaped(v_init, varshape(target)), deepcopy(context)) nsteps = 10^4 From 68565f2048011d8ba0620c3d278b74d4792844fa Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Fri, 14 Feb 2025 23:17:53 +0100 Subject: [PATCH 11/11] Remove proposed sample extraction for HMC altogether --- ext/ahmc_impl/ahmc_sampler_impl.jl | 90 ++----------------- src/samplers/mcmc/mcmc_state.jl | 2 + .../mcmc/mcmc_tuning/mcmc_ram_tuner.jl | 5 ++ 3 
files changed, 12 insertions(+), 85 deletions(-) diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index ec2f9a355..808eba975 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -119,9 +119,11 @@ function BAT.mcmc_propose!!(mc_state::HMCState) z_phase = AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic)) # Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics. - proposal.transition, z_proposed_hmc, p_accept = _bat_transition(rng, τ, hamiltonian, z_phase) - accepted = z_current[:] != proposal.transition.z.θ - z_proposed[:] = accepted ? proposal.transition.z.θ : z_proposed_hmc + proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase) + p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate + + z_proposed[:] = proposal.transition.z.θ + accepted = z_current[:] != z_proposed[:] p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate @@ -176,85 +178,3 @@ function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Funct mc_state_new = @set mc_state_new.f_transform = f_transform_new return mc_state_new end - - -# Copied from AdvancedHMC.jl, but also return proposed point -function _bat_transition( - rng::AbstractRNG, - τ::AdvancedHMC.Trajectory{TS,I,TC}, - h::AdvancedHMC.Hamiltonian, - z0::AdvancedHMC.PhasePoint, -) where { - TS<:AdvancedHMC.AbstractTrajectorySampler, - I<:AdvancedHMC.AbstractIntegrator, - TC<:AdvancedHMC.DynamicTerminationCriterion, -} - H0 = AdvancedHMC.energy(z0) - tree = AdvancedHMC.BinaryTree( - z0, - z0, - AdvancedHMC.TurnStatistic(τ.termination_criterion, z0), - zero(H0), - zero(Int), - zero(H0), - ) - sampler = TS(rng, z0) - termination = AdvancedHMC.Termination(false, false) - zcand = z0 - proposed_zs = Vector[] - accept_probs = Float64[] - - j = 0 - while 
!AdvancedHMC.isterminated(termination) && j < τ.termination_criterion.max_depth - v = rand(rng, [-1, 1]) - if v == -1 - tree′, sampler′, termination′ = - AdvancedHMC.build_tree(rng, τ, h, tree.zleft, sampler, v, j, H0) - treeleft, treeright = tree′, tree - else - tree′, sampler′, termination′ = - AdvancedHMC.build_tree(rng, τ, h, tree.zright, sampler, v, j, H0) - treeleft, treeright = tree, tree′ - end - - # This acceptance prob. is specific to AdvancedHMC.MultinomialTS - p_tmp = min(1, exp(sampler′.ℓw - sampler.ℓw)) - push!(accept_probs, p_tmp) - push!(proposed_zs, sampler′.zcand.θ) - - if !AdvancedHMC.isterminated(termination′) - j = j + 1 - if AdvancedHMC.mh_accept(rng, sampler, sampler′) - zcand = sampler′.zcand - end - end - tree = AdvancedHMC.combine(treeleft, treeright) - sampler = AdvancedHMC.combine(zcand, sampler, sampler′) - termination = - termination * - termination′ * - AdvancedHMC.isterminated(τ.termination_criterion, h, tree, treeleft, treeright) - end - - H = AdvancedHMC.energy(zcand) - tstat = AdvancedHMC.merge( - ( - n_steps = tree.nα, - is_accept = true, - acceptance_rate = tree.sum_α / tree.nα, - log_density = zcand.ℓπ.value, - hamiltonian_energy = H, - hamiltonian_energy_error = H - H0, - max_hamiltonian_energy_error = tree.ΔH_max, - tree_depth = j, - numerical_error = termination.numerical, - ), - AdvancedHMC.stat(τ.integrator), - ) - - accept_total = sum(accept_probs) - z_proposed = iszero(accept_total) ? 
sum(proposed_zs) / length(proposed_zs) : sum(accept_probs .* proposed_zs) / accept_total - p_accept = tstat.acceptance_rate - - return AdvancedHMC.Transition(zcand, tstat), z_proposed, p_accept -end diff --git a/src/samplers/mcmc/mcmc_state.jl b/src/samplers/mcmc/mcmc_state.jl index 0f0ef0c8c..1b51a3743 100644 --- a/src/samplers/mcmc/mcmc_state.jl +++ b/src/samplers/mcmc/mcmc_state.jl @@ -161,6 +161,8 @@ function mcmc_step!!(mcmc_state::MCMCState) current = _current_sample_idx(chain_state) proposed = _proposed_sample_idx(chain_state) + + # This does not change `sample_z` in the chain_state, that happens in the next mcmc step in `_cleanup_samples()`. _accept_reject!(chain_state, accepted, p_accept, current, proposed) mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept) diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl index cf19a602a..1c1922f60 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -76,6 +76,11 @@ function mcmc_tune_post_step!!( mc_state::MCMCChainState, p_accept::Real, ) + + if current_sample_z(mc_state).v == proposed_sample_z(mc_state) + return mc_state, tuner_state + end + (; f_transform, sample_z) = mc_state (; target_acceptance, gamma) = tuner_state.tuning b = f_transform.b