sample with LogDensityFunction: part 1 - hmc.jl, sghmc.jl, DynamicHMCExt #2588
Conversation
Codecov Report
Attention: Patch coverage is
@@            Coverage Diff             @@
##           sample-ldf    #2588      +/-   ##
===============================================
- Coverage      85.50%    35.32%   -50.18%
===============================================
  Files             22        22
  Lines           1456      1503       +47
===============================================
- Hits            1245       531      -714
- Misses           211       972      +761
Many of the changes in sghmc.jl are quite similar to the ones in this file, so I added some comments explaining.
- struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
-     logdensity::L
+ struct DynamicNUTSState{V<:DynamicPPL.AbstractVarInfo,C,M,S}
The sampler state, traditionally, has included the LogDensityFunction as a field so that it doesn't need to be re-constructed on each iteration from the model + varinfo. This is no longer necessary because the LDF is itself an argument to AbstractMCMC.step.
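Since the log density now travels as an argument rather than living in the state, the pattern is roughly the following toy sketch (plain Julia, hypothetical names, not the PR's actual code):

using Random

# The state keeps only what must persist across iterations.
struct ToyState{V<:AbstractVector}
    position::V
end

# The log density arrives as an argument on every call, so the state
# does not need to cache it.
function toy_step(rng::Random.AbstractRNG, logdensity, state::ToyState)
    proposal = state.position .+ 0.1 .* randn(rng, length(state.position))  # illustrative random walk
    lp = logdensity(proposal)
    return (proposal, lp), ToyState(proposal)
end

ld(x) = -sum(abs2, x) / 2  # standard-normal log density, up to a constant
(draw, lp), st = toy_step(Random.default_rng(), ld, ToyState(zeros(2)))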
- # Ensure that initial sample is in unconstrained space.
- if !DynamicPPL.islinked(vi)
-     vi = DynamicPPL.link!!(vi, model)
-     vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
- end
-
- # Define log-density function.
- ℓ = DynamicPPL.LogDensityFunction(
-     model,
-     vi,
-     DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
-     adtype=spl.alg.adtype,
- )
All of this stuff is now handled inside AbstractMCMC.sample(), so there's no longer a need to duplicate this code inside every initialstep method.
  function AbstractMCMC.bundle_samples(
      ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
-     model::AbstractModel,
+     model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
      spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
      state,
      chain_type::Type{MCMCChains.Chains};
This signature is not super ideal, but it minimises breakage right now. Eventually when everything is fixed we can change the Union to just LDF.
- # Handle setting `nadapts` and `discard_initial`
- function AbstractMCMC.sample(
-     rng::AbstractRNG,
-     model::DynamicPPL.Model,
-     sampler::Sampler{<:AdaptiveHamiltonian},
-     N::Integer;
-     chain_type=DynamicPPL.default_chain_type(sampler),
-     resume_from=nothing,
-     initial_state=DynamicPPL.loadstate(resume_from),
-     progress=PROGRESS[],
-     nadapts=sampler.alg.n_adapts,
-     discard_adapt=true,
-     discard_initial=-1,
-     kwargs...,
- )
The purpose of this overload was purely to modify the kwargs to sample. I ditched it in favour of adding a new hook, update_sample_kwargs, which does the same thing without abusing multiple dispatch. I believe the new function behaves identically; it's quite hard to prove this, but separating it into its own function does let us write unit tests for it, which is another benefit.
(Note that overloading sample() for individual samplers like this is quite precarious, because we can't recursively call AbstractMCMC.sample without ending up in infinite recursion -- it has to call mcmcsample. So there's no way to 'extend' this with extra behaviour by, e.g., calling another method of sample before calling mcmcsample.)
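For concreteness, here is a hedged sketch of what such a hook might look like for adaptive HMC. The signature and body are assumptions based on the removed overload above, not the PR's exact definition:

# Hypothetical sketch: translate the sampler-specific defaults into the
# generic kwargs that mcmcsample understands.
function update_sample_kwargs(spl::Sampler{<:AdaptiveHamiltonian}, N::Integer, kwargs)
    nadapts = get(kwargs, :nadapts, spl.alg.n_adapts)
    discard_adapt = get(kwargs, :discard_adapt, true)
    # Drop the adaptation phase unless the caller asked otherwise.
    discard_initial = get(kwargs, :discard_initial, discard_adapt ? nadapts : 0)
    return merge(kwargs, (; nadapts, discard_initial))
end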
- for alg in (:HMC, :HMCDA, :NUTS)
-     @eval getmetricT(::$alg{<:Any,metricT}) where {metricT} = metricT
- end
+ getmetricT(::HMC{<:Any,metricT}) where {metricT} = metricT
+ getmetricT(::HMCDA{<:Any,metricT}) where {metricT} = metricT
+ getmetricT(::NUTS{<:Any,metricT}) where {metricT} = metricT
Metaprogramming is cool and all, but this wasn't really necessary, imo.
  @testset "$(alg)" for alg in algs
      # Construct a HMC state by taking a single step
+     vi = DynamicPPL.VarInfo(gdemo_default)
+     vi = DynamicPPL.link(vi, gdemo_default)
+     ldf = LogDensityFunction(gdemo_default, vi; adtype=Turing.DEFAULT_ADTYPE)
      spl = Sampler(alg)
-     hmc_state = DynamicPPL.initialstep(
-         Random.default_rng(), gdemo_default, spl, DynamicPPL.VarInfo(gdemo_default)
-     )[2]
+     _, hmc_state = AbstractMCMC.step(Random.default_rng(), ldf, spl)
Finally, I think this test reveals one drawback of the current proposal: it becomes more annoying to directly call the AbstractMCMC interface. Let's say we want to benchmark the first step of a given sampler (for example, we were doing this the other day on the Gibbs sampler). Previously, we'd do:
rng = Random.default_rng()
model = ...
spl = ...
@be AbstractMCMC.step(rng, model, spl)
Now, we have to do:
rng = Random.default_rng()
model = ...
vi = link(VarInfo(model), model)
ldf = LogDensityFunction(model, vi; adtype=AutoForwardDiff())
spl = ...
@be AbstractMCMC.step(rng, ldf, spl)
I think this is a fairly small price to pay: the occasions where we reach directly for the AbstractMCMC interface are quite few, and the code simplification is more important. But I thought it was worth noting.
This would be less problematic if we introduced more convenient constructors for LDF (TuringLang/DynamicPPL.jl#863), so it might be worth keeping that in mind.
This PR moves in the general direction of #2555.
It will take a long time to get everything to work, so I am trying to do this incrementally.
TODO
- sample interface
- Hamiltonians satisfy the InferenceAlgorithm interface

Summary
The fundamental idea (see #2555) is that we want

sample(::Union{Model, LogDensityFunction}, ::Union{InferenceAlgorithm, Sampler{<:InferenceAlgorithm}}, N)

to always forward to

sample(::LogDensityFunction, ::Sampler{<:InferenceAlgorithm}, N)

(along the way, we construct the LDF if we need to, and also construct the sampler if we need to).
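A hedged sketch of that funnel, using the LDFCompatible types described below (illustrative only; the real methods live in src/mcmc/abstractmcmc.jl):

# Each convenience method normalises one argument and recurses.
AbstractMCMC.sample(model::DynamicPPL.Model, alg::LDFCompatibleAlgorithm, N::Integer; kwargs...) =
    AbstractMCMC.sample(model, Sampler(alg), N; kwargs...)   # wrap the algorithm

AbstractMCMC.sample(ldf::DynamicPPL.LogDensityFunction, alg::LDFCompatibleAlgorithm, N::Integer; kwargs...) =
    AbstractMCMC.sample(ldf, Sampler(alg), N; kwargs...)     # LDF supplied directly

# The Model + Sampler method constructs the LDF (linking the varinfo and
# choosing the adtype as needed) and then calls the LDF + Sampler method,
# which is the only one that ever invokes mcmcsample.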
Then, the concrete AbstractMCMC interface functions (i.e. mcmcsample, step) do not ever see a Model; they only see an LDF.

The future of DynamicPPL.initialstep

Note that this allows us to sidestep (and eventually, completely remove) DynamicPPL.initialstep. The reason that function exists is that AbstractMCMC.step used to do two things: first, generate the VarInfo that would eventually go into the LDF, and second, call initialstep (which was sampler-specific). Since the VarInfo generation bit is now handled in the LDF construction, instead of having an extra function we can just go back to implementing the two basic AbstractMCMC.step methods, which is a nice bonus.
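As a hedged sketch of that end state (signatures illustrative, bodies elided; exact argument types may differ from the PR):

# First step: no state argument. The initial position comes from ldf.varinfo,
# which sample() has already linked if the sampler requires it.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    ldf::DynamicPPL.LogDensityFunction,
    spl::Sampler{<:Hamiltonian};
    kwargs...,
)
    # ... initialise and return (transition, state)
end

# Subsequent steps: advance from the previous state; no VarInfo regeneration.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    ldf::DynamicPPL.LogDensityFunction,
    spl::Sampler{<:Hamiltonian},
    state;
    kwargs...,
)
    # ... one transition from `state`, return (transition, state)
end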
Changes in this PR

Changing this all at once is bound to be not only impossible to do but also impossible to review. Thus, I've decided to (try to) implement this in a few stages. This PR is probably the one that makes the most sweeping changes, and also establishes the interface required. It:
Establishes the desired method dispatch behaviour for sample (see src/mcmc/abstractmcmc.jl). Because we aren't ready to extend this to every sampler and inference algorithm yet, these methods dispatch only on LDFCompatibleAlgorithm or LDFCompatibleSampler, which are defined at the top of the file. The idea is that we'll add samplers as we go along, and one day we'll eventually be ready to remove this type and just use InferenceAlgorithm.

When automatically constructing the LDF, there are a few things that we need to know to construct it properly: namely, whether the sampler requires unconstrained parameters, and which AD backend (if any) it uses. This PR therefore also introduces interface functions that all (LDF-compatible) samplers must conform to, namely requires_unconstrained_space(::AbstractSampler) and get_adtype(::AbstractSampler). Sensible defaults of true and nothing are given. Note that these functions were already floating around the Turing codebase, so all I've really done is to bring them together and actually write docstrings for them. (A sketch of these hooks follows below.)

Finally, there is an update_sample_kwargs function which samplers can use as a hook to modify the keyword arguments sent to sample(). See comments below for more details.
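A hedged sketch of those hooks, with the defaults as described above; the override shown is an assumption based on how HMC stores its adtype, not necessarily where the PR defines it:

# Defaults: assume the sampler wants unconstrained space and needs no AD type.
requires_unconstrained_space(::AbstractSampler) = true
get_adtype(::AbstractSampler) = nothing

# A gradient-based sampler overrides get_adtype to expose its AD backend.
get_adtype(spl::Sampler{<:Hamiltonian}) = spl.alg.adtype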
Fortunately

This doesn't actually require any changes to DynamicPPL, which I found to be a huge relief!
It's likely that some of the code in this PR will eventually be moved to DynamicPPL, as it has no non-DynamicPPL dependencies. But that can be handled very easily at a later stage, once we're confident that this all works.
Unfortunately

Changing the interface one sampler at a time completely breaks Gibbs, because for Gibbs to work, it requires all of its component samplers to be updated. So we may have to live with the Gibbs tests being broken for a while, and rely on me promising that I'll fix it at some point in time. In this PR, I've disabled the Gibbs tests that live outside test/mcmc/gibbs.jl.

Because I don't know how long this will take me, I don't even want to merge this into breaking, as I don't want to have a new release held up by the fact that only half the changes have been done. I've created a new base branch, sample-ldf, to collect all the work on this. When we're happy with it, we can merge that into breaking.