sample with LogDensityFunction: part 1 - hmc.jl, sghmc.jl, DynamicHMCExt #2588


Open
wants to merge 13 commits into base: sample-ldf

Conversation

@penelopeysm (Member) commented Jun 8, 2025

This PR moves in the general direction of #2555.

It will take a long time to get everything to work, so I am trying to do this incrementally.

TODO

  • Fix existing Hamiltonian tests
  • Disable Gibbs tests (specifically, Gibbs tests that use Hamiltonians)
  • Add new tests for the new sample interface
  • Add new tests to make sure that the Hamiltonians satisfy the InferenceAlgorithm interface

Summary

The fundamental idea (see #2555) is that we want

sample(::Union{Model, LogDensityFunction}, ::Union{InferenceAlgorithm, Sampler{<:InferenceAlgorithm}}, N)

to always forward to

sample(::LogDensityFunction, ::Sampler{<:InferenceAlgorithm}, N)

(along the way, we construct the LDF if we need to, and also construct the sampler if we need to).

Then, the concrete AbstractMCMC interface functions (i.e., mcmcsample, step) do not ever see a Model, they only see an LDF.
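
Schematically (writing as if inside Turing.Inference, and omitting the rng-accepting variants and the LDFCompatible* restriction described further down), the forwarding looks something like this:

function AbstractMCMC.sample(
    model::DynamicPPL.Model, alg::InferenceAlgorithm, N::Integer; kwargs...
)
    spl = Sampler(alg)
    # Build the LDF from the model; linking and adtype handling are elided here
    # (see the interface functions introduced below).
    ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model))
    return AbstractMCMC.sample(ldf, spl, N; kwargs...)
end

function AbstractMCMC.sample(
    ldf::DynamicPPL.LogDensityFunction, alg::InferenceAlgorithm, N::Integer; kwargs...
)
    return AbstractMCMC.sample(ldf, Sampler(alg), N; kwargs...)
end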

The future of DynamicPPL.initialstep

Note that this allows us to sidestep (and eventually remove entirely) DynamicPPL.initialstep. That function exists because AbstractMCMC.step used to do two things: first, generate the VarInfo that would eventually go into the LDF, and second, call initialstep (which was sampler-specific). Since the VarInfo generation is now handled during LDF construction, we no longer need the extra function and can go back to implementing the two basic AbstractMCMC.step methods, which is a nice bonus.
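
For a hypothetical placeholder algorithm MyAlg (not a real sampler; again written as if inside Turing.Inference), the resulting pair of methods would look roughly like:

struct MyAlg <: InferenceAlgorithm end  # hypothetical placeholder

# First step: no state argument. Everything that initialstep used to set up
# (model, VarInfo, adtype) is already carried by `ldf`.
function AbstractMCMC.step(
    rng::Random.AbstractRNG, ldf::DynamicPPL.LogDensityFunction, spl::Sampler{<:MyAlg}; kwargs...
)
    transition, state = nothing, nothing  # placeholders
    return transition, state
end

# Subsequent steps: continue from `state`.
function AbstractMCMC.step(
    rng::Random.AbstractRNG, ldf::DynamicPPL.LogDensityFunction, spl::Sampler{<:MyAlg}, state; kwargs...
)
    transition, new_state = nothing, state  # placeholders
    return transition, new_state
end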

Changes in this PR

Changing all of this at once would be impossible to implement, let alone review. Thus, I've decided to (try to) implement it in a few stages. This PR is probably the one that makes the most sweeping changes, and it also establishes the required interface. It:

  1. Establishes the desired method dispatch behaviour for sample (see src/mcmc/abstractmcmc.jl). Because we aren't ready to extend this to every sampler and inference algorithm yet, these methods dispatch only on LDFCompatibleAlgorithm or LDFCompatibleSampler, which are defined at the top of the file. The idea is that we'll add samplers as we go along, and eventually we'll be able to remove these types and just use InferenceAlgorithm.

  2. When automatically constructing the LDF, there are a few things that we need to know to construct it properly:

    • Does the VarInfo need to be linked?
    • Does the LDF need to be constructed with an adtype?

    This PR therefore also introduces interface functions that all (LDF-compatible) samplers must conform to, namely requires_unconstrained_space(::AbstractSampler) and get_adtype(::AbstractSampler), with sensible defaults of true and nothing respectively. Note that these functions were already floating around the Turing codebase, so all I've really done is bring them together and actually write docstrings for them (see the sketch after this list).

  3. Finally, there is an update_sample_kwargs function which samplers can use as a hook to modify the keyword arguments sent to sample(). See comments below for more details.
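
Schematically, the interface and the automatic LDF construction look something like this (simplified; the exact signature of update_sample_kwargs and the construction helper below are illustrative rather than copied from the source):

# Defaults: a sampler that doesn't say otherwise gets a linked VarInfo and no adtype.
requires_unconstrained_space(::AbstractMCMC.AbstractSampler) = true
get_adtype(::AbstractMCMC.AbstractSampler) = nothing
# Hook allowing a sampler to adjust the keyword arguments passed to sample()
# (the signature shown here is illustrative).
update_sample_kwargs(::AbstractMCMC.AbstractSampler, N::Integer, kwargs) = kwargs

# Roughly how sample() can then construct the LDF automatically:
function build_ldf(model::DynamicPPL.Model, spl::AbstractMCMC.AbstractSampler)  # illustrative helper
    vi = DynamicPPL.VarInfo(model)
    if requires_unconstrained_space(spl)
        vi = DynamicPPL.link(vi, model)
    end
    return DynamicPPL.LogDensityFunction(model, vi; adtype=get_adtype(spl))
end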

Fortunately

This doesn't actually require any changes to DynamicPPL, which I found to be a huge relief!

It's likely that some of the code in this PR will eventually be moved to DynamicPPL, since it doesn't have any non-DynamicPPL dependencies. But that can be handled very easily at a later stage, once we're confident that this all works.

Unfortunately

Changing the interface one sampler at a time completely breaks Gibbs, because for Gibbs to work, it requires all of its component samplers to be updated. So we may have to live with the Gibbs tests being broken for a while, and rely on me promising that I'll fix it at some point in time. In this PR, I've disabled the Gibbs tests that live outside test/mcmc/gibbs.jl.

Because I don't know how long this will take me, I don't even want to merge this into breaking, as I don't want to have a new release held up by the fact that only half the changes have been done. I've created a new base branch, sample-ldf, to collect all the work on this. When we're happy with it, we can merge that into breaking.

@penelopeysm penelopeysm changed the base branch from sample-ldf to main June 8, 2025 13:43
@penelopeysm penelopeysm changed the base branch from main to sample-ldf June 8, 2025 13:43
github-actions bot (Contributor) commented Jun 8, 2025

Turing.jl documentation for PR #2588 is available at:
https://TuringLang.github.io/Turing.jl/previews/PR2588/

@penelopeysm penelopeysm changed the title sample with LogDensityFunction: part 1 - HMC sample with LogDensityFunction: part 1 - hmc.jl + sghmc.jl Jun 8, 2025
@penelopeysm penelopeysm changed the title sample with LogDensityFunction: part 1 - hmc.jl + sghmc.jl sample with LogDensityFunction: part 1 - hmc.jl, sghmc.jl, DynamicHMCExt Jun 8, 2025
codecov bot commented Jun 8, 2025

Codecov Report

Attention: Patch coverage is 59.15493% with 58 lines in your changes missing coverage. Please review.

Project coverage is 35.32%. Comparing base (e84aec1) to head (a527fd8).

Files with missing lines     Patch %   Missing lines
src/mcmc/abstractmcmc.jl      63.79%   21 ⚠️
src/mcmc/sghmc.jl              0.00%   16 ⚠️
ext/TuringDynamicHMCExt.jl     0.00%    7 ⚠️
src/mcmc/algorithm.jl         54.54%    5 ⚠️
src/mcmc/hmc.jl               82.14%    5 ⚠️
src/mcmc/gibbs.jl              0.00%    2 ⚠️
src/mcmc/Inference.jl         83.33%    1 ⚠️
src/mcmc/mh.jl                 0.00%    1 ⚠️

❗ There is a different number of reports uploaded between BASE (e84aec1) and HEAD (a527fd8): HEAD has 22 fewer uploads than BASE (28 uploads on BASE vs 6 on HEAD).
Additional details and impacted files
@@               Coverage Diff               @@
##           sample-ldf    #2588       +/-   ##
===============================================
- Coverage       85.50%   35.32%   -50.18%     
===============================================
  Files              22       22               
  Lines            1456     1503       +47     
===============================================
- Hits             1245      531      -714     
- Misses            211      972      +761     


@penelopeysm (Member, Author) commented:
Many of the changes in sghmc.jl are quite similar to the ones in this file, so I added some comments here explaining them.

struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
logdensity::L
struct DynamicNUTSState{V<:DynamicPPL.AbstractVarInfo,C,M,S}
@penelopeysm (Member, Author) commented:
The sampler state, traditionally, has included the LogDensityFunction as a field so that it doesn't need to be re-constructed on each iteration from the model + varinfo. This is no longer necessary because the LDF is itself an argument to AbstractMCMC.step.
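
Concretely, inside step the density (and the gradient, if the LDF was constructed with an adtype) can be evaluated directly from the ldf argument via the LogDensityProblems interface:

# θ is the current parameter vector, e.g. taken from the sampler state.
lp = LogDensityProblems.logdensity(ldf, θ)
lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, θ)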

Comment on lines -58 to -70
# Ensure that initial sample is in unconstrained space.
if !DynamicPPL.islinked(vi)
vi = DynamicPPL.link!!(vi, model)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
end

# Define log-density function.
= DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
adtype=spl.alg.adtype,
)
@penelopeysm (Member, Author) commented:
All of this stuff is now handled inside AbstractMCMC.sample(), so there's no longer a need to duplicate this code inside every initialstep method.

Comment on lines 245 to 250
function AbstractMCMC.bundle_samples(
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
state,
chain_type::Type{MCMCChains.Chains};
@penelopeysm (Member, Author) commented:
This signature is not super ideal, but it minimises breakage right now. Eventually when everything is fixed we can change the Union to just LDF.

Comment on lines -85 to -99
# Handle setting `nadapts` and `discard_initial`
function AbstractMCMC.sample(
rng::AbstractRNG,
model::DynamicPPL.Model,
sampler::Sampler{<:AdaptiveHamiltonian},
N::Integer;
chain_type=DynamicPPL.default_chain_type(sampler),
resume_from=nothing,
initial_state=DynamicPPL.loadstate(resume_from),
progress=PROGRESS[],
nadapts=sampler.alg.n_adapts,
discard_adapt=true,
discard_initial=-1,
kwargs...,
)
@penelopeysm (Member, Author) commented:
The purpose of this overload was purely to modify the kwargs passed to sample. I ditched it in favour of a new hook, update_sample_kwargs, which should do the same thing without abusing multiple dispatch. It's hard to prove that the behaviour is exactly the same, but separating it into its own function does allow us to write unit tests for it, which is another benefit.

(Note that overloading sample() for individual samplers like this is quite precarious because we can't recursively call AbstractMCMC.sample or we will end up with infinite recursion -- it has to call mcmcsample. So, there's no way to 'extend' this with extra behaviour by e.g. calling another method of sample before calling mcmcsample.)
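
For illustration, such a hook for AdaptiveHamiltonian could look roughly like this, mirroring the defaults in the old overload above (the signature and the exact resolution logic are a sketch, not the final code; kwargs is assumed to be a NamedTuple):

function update_sample_kwargs(spl::Sampler{<:AdaptiveHamiltonian}, N::Integer, kwargs)
    nadapts = get(kwargs, :nadapts, spl.alg.n_adapts)
    discard_adapt = get(kwargs, :discard_adapt, true)
    discard_initial = get(kwargs, :discard_initial, -1)
    # By convention, a negative discard_initial means "decide based on discard_adapt".
    if discard_initial < 0
        discard_initial = discard_adapt ? nadapts : 0
    end
    return (; kwargs..., nadapts, discard_adapt, discard_initial)
end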

Comment on lines -470 to +436
for alg in (:HMC, :HMCDA, :NUTS)
@eval getmetricT(::$alg{<:Any,metricT}) where {metricT} = metricT
end
getmetricT(::HMC{<:Any,metricT}) where {metricT} = metricT
getmetricT(::HMCDA{<:Any,metricT}) where {metricT} = metricT
getmetricT(::NUTS{<:Any,metricT}) where {metricT} = metricT
@penelopeysm (Member, Author) commented:
Metaprogramming is cool and all, but this wasn't really necessary, imo.

Comment on lines 298 to +304
@testset "$(alg)" for alg in algs
# Construct a HMC state by taking a single step
vi = DynamicPPL.VarInfo(gdemo_default)
vi = DynamicPPL.link(vi, gdemo_default)
ldf = LogDensityFunction(gdemo_default, vi; adtype=Turing.DEFAULT_ADTYPE)
spl = Sampler(alg)
hmc_state = DynamicPPL.initialstep(
Random.default_rng(), gdemo_default, spl, DynamicPPL.VarInfo(gdemo_default)
)[2]
_, hmc_state = AbstractMCMC.step(Random.default_rng(), ldf, spl)
@penelopeysm (Member, Author) commented Jun 8, 2025:
Finally, I think this test reveals one drawback of the current proposal: it becomes more annoying to directly call the AbstractMCMC interface. Let's say we want to benchmark the first step of a given sampler (for example, we were doing this the other day on the Gibbs sampler). Previously, we'd do:

rng = Random.default_rng()
model = ...
spl = ...
@be AbstractMCMC.step(rng, model, spl)

Now, we have to do:

rng = Random.default_rng()
model = ...
vi = link(VarInfo(model), model)
ldf = LogDensityFunction(model, vi; adtype=AutoForwardDiff())
spl = ...
@be AbstractMCMC.step(rng, ldf, spl)

I think this is a fairly small price to pay: the occasions where we reach directly for the AbstractMCMC interface are quite rare, and the code simplification is more important. But I thought it was worth noting.

This would be less problematic if we introduced more convenient constructors for LDF: TuringLang/DynamicPPL.jl#863 so it might be worth keeping that in mind.
