Skip to content

Commit 477ae17

Browse files
torfjeldedevmotionyebai
authored
Compatibility with new DPPL version (#1900)
* use unflatten in evaluation of LogDensityFunction * make AD-related functions able to take AbstractVarInfo * use unflatten where appropriate * updated Gibbs * updated HMC * move to using BangBang versions of link and invlink * use link!! * update tests to be compatible with new DynamicPPL.TestUtils * updated deps for tests * fixed tests for ESS * upper-bound distributions in tests because otherwise depwarns will cause timeouts * replace link! with link!!, etc. * added Setfield and updated optimization stuff * updated the contrib to use link!!, etc. * updated AD tests * updated DPPL versions * removed usage of deprecated inv * made some function signatures more restrictive * Update src/inference/mh.jl * fixed MH sampler * increase atol for certain tests to make them pass on MacOS * reduce atol for a MH test * disable emcee tests for now * Update Project.toml Co-authored-by: David Widmann <[email protected]> * further reductions in atol to make tests pass * Update test/runtests.jl * Update mh.jl * restrict ForwardDiff for tests to avoid issue with cholesky * increased number of samples and lowered atol for MH tests Co-authored-by: David Widmann <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent ad57181 commit 477ae17

File tree

15 files changed

+142
-123
lines changed

15 files changed

+142
-123
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2727
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2828
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2929
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
30+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
3031
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3132
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3233
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -45,7 +46,7 @@ DataStructures = "0.18"
4546
Distributions = "0.23.3, 0.24, 0.25"
4647
DistributionsAD = "0.6"
4748
DocStringExtensions = "0.8, 0.9"
48-
DynamicPPL = "0.20"
49+
DynamicPPL = "0.21"
4950
EllipticalSliceSampling = "0.5, 1"
5051
ForwardDiff = "0.10.3"
5152
Libtask = "0.6.7, 0.7"
@@ -55,6 +56,7 @@ NamedArrays = "0.9"
5556
Reexport = "0.2, 1"
5657
Requires = "0.5, 1.0"
5758
SciMLBase = "1.37.1"
59+
Setfield = "0.8"
5860
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
5961
StatsBase = "0.32, 0.33"
6062
StatsFuns = "0.8, 0.9, 1"

src/Turing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ struct LogDensityFunction{V,M,S,C}
3535
end
3636

3737
function (f::LogDensityFunction)(θ::AbstractVector)
38-
return getlogp(last(DynamicPPL.evaluate!!(f.model, VarInfo(f.varinfo, f.sampler, θ), f.sampler, f.context)))
38+
vi_new = DynamicPPL.unflatten(f.varinfo, f.sampler, θ)
39+
return getlogp(last(DynamicPPL.evaluate!!(f.model, vi_new, f.sampler, f.context)))
3940
end
4041

4142
# LogDensityProblems interface

src/contrib/inference/dynamichmc.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ function DynamicPPL.initialstep(
6060
)
6161
# Ensure that initial sample is in unconstrained space.
6262
if !DynamicPPL.islinked(vi, spl)
63-
DynamicPPL.link!(vi, spl)
64-
model(rng, vi, spl)
63+
vi = DynamicPPL.link!!(vi, spl, model)
64+
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
6565
end
6666

6767
# Define log-density function.
@@ -79,8 +79,8 @@ function DynamicPPL.initialstep(
7979
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
8080

8181
# Update the variables.
82-
vi[spl] = Q.q
83-
DynamicPPL.setlogp!!(vi, Q.ℓq)
82+
vi = DynamicPPL.setindex!!(vi, Q.q, spl)
83+
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
8484

8585
# Create first sample and state.
8686
sample = Transition(vi)
@@ -109,8 +109,8 @@ function AbstractMCMC.step(
109109
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
110110

111111
# Update the variables.
112-
vi[spl] = Q.q
113-
DynamicPPL.setlogp!!(vi, Q.ℓq)
112+
vi = DynamicPPL.setindex!!(vi, Q.q, spl)
113+
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
114114

115115
# Create next sample and state.
116116
sample = Transition(vi)

src/contrib/inference/sghmc.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ function DynamicPPL.initialstep(
5656
)
5757
# Transform the samples to unconstrained space and compute the joint log probability.
5858
if !DynamicPPL.islinked(vi, spl)
59-
DynamicPPL.link!(vi, spl)
60-
model(rng, vi, spl)
59+
vi = DynamicPPL.link!!(vi, spl, model)
60+
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
6161
end
6262

6363
# Compute initial sample and state.
@@ -90,8 +90,8 @@ function AbstractMCMC.step(
9090
newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
9191

9292
# Save new variables and recompute log density.
93-
vi[spl] = θ
94-
model(rng, vi, spl)
93+
vi = DynamicPPL.setindex!!(vi, θ, spl)
94+
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
9595

9696
# Compute next sample and state.
9797
sample = Transition(vi)
@@ -209,8 +209,8 @@ function DynamicPPL.initialstep(
209209
)
210210
# Transform the samples to unconstrained space and compute the joint log probability.
211211
if !DynamicPPL.islinked(vi, spl)
212-
DynamicPPL.link!(vi, spl)
213-
model(rng, vi, spl)
212+
vi = DynamicPPL.link!!(vi, spl, model)
213+
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
214214
end
215215

216216
# Create first sample and state.
@@ -238,8 +238,8 @@ function AbstractMCMC.step(
238238
θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))
239239

240240
# Save new variables and recompute log density.
241-
vi[spl] = θ
242-
model(rng, vi, spl)
241+
vi = DynamicPPL.setindex!!(vi, θ, spl)
242+
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
243243

244244
# Compute next sample and state.
245245
sample = SGLDTransition(vi, stepsize)

src/inference/emcee.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function AbstractMCMC.step(
4343
ArgumentError("initial parameters have to be specified for each walker")
4444
)
4545
vis = map(vis, init_params) do vi, init
46-
vi = DynamicPPL.initialize_parameters!!(vi, init, spl)
46+
vi = DynamicPPL.initialize_parameters!!(vi, init, spl, model)
4747

4848
# Update log joint probability.
4949
last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior()))
@@ -57,7 +57,7 @@ function AbstractMCMC.step(
5757
state = EmceeState(
5858
vis[1],
5959
map(vis) do vi
60-
DynamicPPL.link!(vi, spl)
60+
vi = DynamicPPL.link!!(vi, spl, model)
6161
AMH.Transition(vi[spl], getlogp(vi))
6262
end
6363
)
@@ -82,9 +82,9 @@ function AbstractMCMC.step(
8282
# Compute the next transition and state.
8383
transition = map(states) do _state
8484
vi = setindex!!(vi, _state.params, spl)
85-
DynamicPPL.invlink!(vi, spl)
85+
vi = DynamicPPL.invlink!!(vi, spl, model)
8686
t = Transition(tonamedtuple(vi), _state.lp)
87-
DynamicPPL.link!(vi, spl)
87+
vi = DynamicPPL.link!!(vi, spl, model)
8888
return t
8989
end
9090
newstate = EmceeState(vi, states)

src/inference/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ function DynamicPPL.initialstep(
199199
states = map(samplers) do local_spl
200200
# Recompute `vi.logp` if needed.
201201
if local_spl.selector.rerun
202-
model(rng, vi, local_spl)
202+
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, local_spl)))
203203
end
204204

205205
# Compute initial state.

src/inference/hmc.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ function DynamicPPL.initialstep(
150150
kwargs...
151151
)
152152
# Transform the samples to unconstrained space and compute the joint log probability.
153-
link!(vi, spl)
154-
vi = last(DynamicPPL.evaluate!!(model, rng, vi, spl))
153+
vi = link!!(vi, spl, model)
155154

156155
# Extract parameters.
157156
theta = vi[spl]
@@ -173,8 +172,8 @@ function DynamicPPL.initialstep(
173172
# and its gradient are finite.
174173
if init_params === nothing
175174
while !isfinite(z)
175+
# NOTE: This will sample in the unconstrained space.
176176
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
177-
link!(vi, spl)
178177
theta = vi[spl]
179178

180179
hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)
@@ -210,10 +209,10 @@ function DynamicPPL.initialstep(
210209

211210
# Update `vi` based on acceptance
212211
if t.stat.is_accept
213-
vi = setindex!!(vi, t.z.θ, spl)
212+
vi = DynamicPPL.unflatten(vi, spl, t.z.θ)
214213
vi = setlogp!!(vi, t.stat.log_density)
215214
else
216-
vi = setindex!!(vi, theta, spl)
215+
vi = DynamicPPL.unflatten(vi, spl, theta)
217216
vi = setlogp!!(vi, log_density_old)
218217
end
219218

@@ -252,7 +251,7 @@ function AbstractMCMC.step(
252251
# Update variables
253252
vi = state.vi
254253
if t.stat.is_accept
255-
vi = setindex!!(vi, t.z.θ, spl)
254+
vi = DynamicPPL.unflatten(vi, spl, t.z.θ)
256255
vi = setlogp!!(vi, t.stat.log_density)
257256
end
258257

@@ -532,8 +531,9 @@ function HMCState(
532531
kwargs...
533532
)
534533
# Link everything if needed.
535-
if !islinked(vi, spl)
536-
link!(vi, spl)
534+
waslinked = islinked(vi, spl)
535+
if !waslinked
536+
vi = link!!(vi, spl, model)
537537
end
538538

539539
# Get the initial log pdf and gradient functions.
@@ -562,8 +562,10 @@ function HMCState(
562562
# Generate a phasepoint. Replaced during sample_init!
563563
h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ.
564564

565-
# Unlink everything.
566-
invlink!(vi, spl)
565+
# Unlink everything, if it was indeed linked before.
566+
if waslinked
567+
vi = invlink!!(vi, spl, model)
568+
end
567569

568570
return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
569571
end

src/inference/mh.jl

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,11 @@ end
197197
198198
Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
199199
"""
200-
function set_namedtuple!(vi::VarInfo, nt::NamedTuple)
200+
function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTuple)
201+
# TODO: Replace this with something like
202+
# for vn in keys(vi)
203+
# vi = DynamicPPL.setindex!!(vi, get(nt, vn))
204+
# end
201205
for (n, vals) in pairs(nt)
202206
vns = vi.metadata[n].vns
203207
nvns = length(vns)
@@ -245,6 +249,7 @@ This variant uses the `set_namedtuple!` function to update the `VarInfo`.
245249
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}
246250

247251
function (f::MHLogDensityFunction)(x::NamedTuple)
252+
# TODO: Make this work with immutable `f.varinfo` too.
248253
sampler = f.sampler
249254
vi = f.varinfo
250255

@@ -286,14 +291,14 @@ function reconstruct(
286291
end
287292

288293
"""
289-
dist_val_tuple(spl::Sampler{<:MH}, vi::AbstractVarInfo)
294+
dist_val_tuple(spl::Sampler{<:MH}, vi::VarInfo)
290295
291296
Return two `NamedTuples`.
292297
293298
The first `NamedTuple` has symbols as keys and distributions as values.
294299
The second `NamedTuple` has model symbols as keys and their stored values as values.
295300
"""
296-
function dist_val_tuple(spl::Sampler{<:MH}, vi::AbstractVarInfo)
301+
function dist_val_tuple(spl::Sampler{<:MH}, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo)
297302
vns = _getvns(vi, spl)
298303
dt = _dist_tuple(spl.alg.proposals, vi, vns)
299304
vt = _val_tuple(vi, vns)
@@ -349,15 +354,12 @@ function should_link(
349354
return true
350355
end
351356

352-
function maybe_link!(varinfo, sampler, proposal)
353-
if should_link(varinfo, sampler, proposal)
354-
link!(varinfo, sampler)
355-
end
356-
return nothing
357+
function maybe_link!!(varinfo, sampler, proposal, model)
358+
return should_link(varinfo, sampler, proposal) ? link!!(varinfo, sampler, model) : varinfo
357359
end
358360

359361
# Make a proposal if we don't have a covariance proposal matrix (the default).
360-
function propose!(
362+
function propose!!(
361363
rng::AbstractRNG,
362364
vi::AbstractVarInfo,
363365
model::Model,
@@ -378,13 +380,11 @@ function propose!(
378380
# TODO: Make this compatible with immutable `VarInfo`.
379381
# Update the values in the VarInfo.
380382
set_namedtuple!(vi, trans.params)
381-
setlogp!!(vi, trans.lp)
382-
383-
return vi
383+
return setlogp!!(vi, trans.lp)
384384
end
385385

386386
# Make a proposal if we DO have a covariance proposal matrix.
387-
function propose!(
387+
function propose!!(
388388
rng::AbstractRNG,
389389
vi::AbstractVarInfo,
390390
model::Model,
@@ -403,12 +403,7 @@ function propose!(
403403
densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
404404
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
405405

406-
# TODO: Make this compatible with immutable `VarInfo`.
407-
# Update the values in the VarInfo.
408-
setindex!!(vi, trans.params, spl)
409-
setlogp!!(vi, trans.lp)
410-
411-
return vi
406+
return setlogp!!(DynamicPPL.unflatten(vi, spl, trans.params), trans.lp)
412407
end
413408

414409
function DynamicPPL.initialstep(
@@ -420,7 +415,7 @@ function DynamicPPL.initialstep(
420415
)
421416
# If we're doing random walk with a covariance matrix,
422417
# just link everything before sampling.
423-
maybe_link!(vi, spl, spl.alg.proposals)
418+
vi = maybe_link!!(vi, spl, spl.alg.proposals, model)
424419

425420
return Transition(vi), vi
426421
end
@@ -435,7 +430,7 @@ function AbstractMCMC.step(
435430
# Cases:
436431
# 1. A covariance proposal matrix
437432
# 2. A bunch of NamedTuples that specify the proposal space
438-
propose!(rng, vi, model, spl, spl.alg.proposals)
433+
vi = propose!!(rng, vi, model, spl, spl.alg.proposals)
439434

440435
return Transition(vi), vi
441436
end

0 commit comments

Comments
 (0)