Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweaks and making fit for the next version #31

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions benchmarks/bouncy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,20 @@ model_lr = @model (At, y, σ) begin
end
σ = 100.0


function make_grads(model_lr, At, y, σ)
post = model_lr(At, y, σ) | (;y)
as_post = as(post)
obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ))
ℓ(θ) = -obj(θ)
@inline function dneglogp(t, x, v) # two directional derivatives
@inline function dneglogp(t, x, v, args...) # two directional derivatives
f(t) = obj(x + t*v)
u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0))
u.value, u.partials[]
end

gconfig = ForwardDiff.GradientConfig(obj, rand(25), ForwardDiff.Chunk{25}())
function ∇neglogp!(y, t, x)
function ∇neglogp!(y, t, x, args...)
ForwardDiff.gradient!(y, obj, x, gconfig)
return
end
Expand All @@ -73,8 +74,9 @@ n = 2000
c = 4.0 # initial guess for the bound

init_scale=1;
print("Pathfinder: ")
@time pf_result = pathfinder(ℓ; dim=d, init_scale);
M = PDMats.PDiagMat(diag(pf_result.fit_distribution.Σ));
#M = PDMats.PDiagMat(diag(pf_result.fit_distribution.Σ));
M = pf_result.fit_distribution.Σ;
x0 = pf_result.fit_distribution.μ;
v0 = PDMats.unwhiten(M, randn(length(x0)));
Expand All @@ -88,8 +90,8 @@ MAP = pf_result.optim_solution; # MAP, could be useful for control variates
# define BouncyParticle sampler (has two relevant parameters)
Z = BouncyParticle(missing, # graphical structure
MAP, # MAP estimate, unused
2.0, # momentum refreshment rate and sample saving rate
0.95, # momentum correlation / only gradually change momentum in refreshment/momentum update
0.2, # momentum refreshment rate and sample saving rate
0.9, # momentum correlation / only gradually change momentum in refreshment/momentum update
M, # metric (PDMat compatible object for momentum covariance)
missing # legacy
) ;
Expand All @@ -104,7 +106,7 @@ using TupleVectors: chainvec
using Tilde.MeasureTheory: transform


function collect_sampler(t, sampler, n; progress=true, progress_stops=20)
function collect_sampler(trans, sampler, n; progress=true, progress_stops=20)
if progress
prg = Progress(progress_stops, 1)
else
Expand All @@ -113,15 +115,15 @@ function collect_sampler(t, sampler, n; progress=true, progress_stops=20)
stops = ismissing(prg) ? 0 : max(prg.n - 1, 0) # allow one stop for cleanup
nstop = n/stops

x1 = transform(t, sampler.u0[2][1])
x1 = trans(sampler.u0[2][1])
tv = chainvec(x1, n)
ϕ = iterate(sampler)
j = 1
local state
while ϕ !== nothing && j < n
j += 1
val, state = ϕ
tv[j] = transform(t, val[2])
tv[j] = trans(val[2])
ϕ = iterate(sampler, state)
if j > nstop
nstop += n/stops
Expand All @@ -131,13 +133,16 @@ function collect_sampler(t, sampler, n; progress=true, progress_stops=20)
ismissing(prg) || ProgressMeter.finish!(prg)
tv, (;uT=state[1], acc=state[3][1], total=state[3][2], bound=state[4].c)
end
collect_sampler(as(post), sampler, 10; progress=false);
collect_sampler(transform(as(post)), sampler, 10; progress=false);

print("BPS: ")
elapsed_time = @elapsed @time begin
trans = transform(as(post))
global bps_samples, info
bps_samples, info = collect_sampler(as(post), sampler, n; progress=false)
bps_samples, info = collect_sampler(trans, sampler, n; progress=false)
end


using MCMCChains
bps_chain = MCMCChains.Chains(bps_samples.θ);
bps_chain = setinfo(bps_chain, (;start_time=0.0, stop_time = elapsed_time));
Expand All @@ -152,10 +157,11 @@ Tilde.sample(post, dynamichmc(
;init=(; q=init_params, κ=GaussianKineticEnergy(inv_metric)),
warmup_stages=default_warmup_stages(; middle_steps=0, doubling_stages=0),
), 1,1);
print("HMC: ")
hmc_time = @elapsed @time (hmc_samples = Tilde.sample(post, dynamichmc(
;init=(; q=init_params, κ=GaussianKineticEnergy(inv_metric)),
warmup_stages=default_warmup_stages(; middle_steps=0, doubling_stages=0),
), 2000,1));
), n,1));
hmc_chain = MCMCChains.Chains(hmc_samples.θ);
μ̂2 = round.(mean(hmc_chain).nt[:mean], sigdigits=4);
println("μ̂ (HMC) = ", μ̂2)
Expand All @@ -168,8 +174,8 @@ using UnicodePlots

plt = scatterplot(ess_bps, ess_hmc);
UnicodePlots.title!(plt, "Effective Samples Per Second");
xlabel!(plt, "Bouncy Particle Sampler");
ylabel!(plt, "DynamicHMC");
UnicodePlots.xlabel!(plt, "Bouncy Particle Sampler");
UnicodePlots.ylabel!(plt, "DynamicHMC");
plt_bounds = collect(extrema(ess_hmc));
lineplot!(plt, plt_bounds, plt_bounds);
plt
Expand Down