Skip to content

Commit

Permalink
Adjust HMC Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Micki-D committed Feb 13, 2025
1 parent 03971c1 commit a52127d
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions test/samplers/mcmc/test_hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import AdvancedHMC
strict = true
nonzero_weights = false
callback = (x...) -> nothing

samplingalg = TransformedMCMC(proposal = proposal,
transform_tuning = transform_tuning,
pretransform = pretransform,
Expand All @@ -66,7 +66,7 @@ import AdvancedHMC
strict = strict,
nonzero_weights = nonzero_weights
)

# Note: No @inferred, not type stable (yet) with HamiltonianMC
init_result = BAT.mcmc_init!(
samplingalg,
Expand All @@ -75,33 +75,34 @@ import AdvancedHMC
callback,
context
)

(mcmc_states, outputs) = init_result
# @test mcmc_states isa AbstractVector{<:BAT.HMCState} # TODO: MD, reactivate, works for AbstractVector{<:MCMCChainState}, but doesn't seen to like the typealias
# @test tuners isa AbstractVector{<:BAT.HMCState}
@test mcmc_states isa AbstractVector{<:BAT.MCMCState}
@test outputs isa AbstractVector{<:DensitySampleVector}

BAT.mcmc_burnin!(
mcmc_states = BAT.mcmc_burnin!(
outputs,
mcmc_states,
samplingalg,
callback
)


BAT.next_cycle!.(mcmc_states)

mcmc_states = BAT.mcmc_iterate!!(
outputs,
mcmc_states;
max_nsteps = div(max_nsteps, length(mcmc_states)),
nonzero_weights = nonzero_weights
)

samples = DensitySampleVector(first(mcmc_states))
append!.(Ref(samples), outputs)

@test length(samples) == sum(samples.weight)
@test BAT.test_dist_samples(unshaped(objective), samples)
end

@testset "bat_sample" begin
samples = bat_sample(
shaped_target,
Expand Down

0 comments on commit a52127d

Please sign in to comment.