diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index d0e00b45f..745726d46 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -6,9 +6,7 @@ import_to_using = false # These ignores should be removed once the relevant PRs are merged/closed. ignore = [ # https://github.com/TuringLang/Turing.jl/pull/2231/files - "src/experimental/gibbs.jl", "src/mcmc/abstractmcmc.jl", - "test/experimental/gibbs.jl", "test/test_utils/numerical_tests.jl", # https://github.com/TuringLang/Turing.jl/pull/2218/files "src/mcmc/Inference.jl", diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 8de296e5e..770eab9a7 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -22,9 +22,8 @@ jobs: - "mcmc/hmc.jl" - "mcmc/abstractmcmc.jl" - "mcmc/Inference.jl" - - "experimental/gibbs.jl" - "mcmc/ess.jl" - - "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl" + - "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl mcmc/ess.jl" version: - '1.7' - '1' @@ -79,7 +78,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 # TODO: Use julia-actions/julia-runtest when test_args are supported # Custom calls of Pkg.test tend to miss features such as e.g. adjustments for CompatHelper PRs - # Ref https://github.com/julia-actions/julia-runtest/pull/73 + # Ref https://github.com/julia-actions/julia-runtest/pull/73 - name: Call Pkg.test run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test-args }} - uses: julia-actions/julia-processcoverage@v1 diff --git a/HISTORY.md b/HISTORY.md index 5b1cad0ed..11d08e12c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,19 @@ +# Release 0.35.0 + +## Breaking changes + +0.35.0 introduces a new Gibbs sampler. It has been included in several previous releases as `Turing.Experimental.Gibbs`, but it now replaces the old Gibbs sampler, which has been removed completely. + +The new Gibbs sampler supports the same user-facing interface as the old one. However, because its internals have been completely rewritten in a very different manner, there may be accidental breakage that we haven't anticipated. Please report any issues you find. + +`GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is a breaking change. + +The old Gibbs constructor was called with several subsamplers, and each subsampler's constructor took as arguments the symbols of the variables it was to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarName`s, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable. + +Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(:x), 2), (MH(:y), 1))`, has been deprecated. The new way to achieve the same effect is to list the same sampler multiple times, e.g. `hmc = HMC(); mh = MH(); Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => mh)`.
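To make the migration concrete, here is a minimal sketch contrasting the deprecated constructors with the new ones. The model, step sizes, and variable names (`x`, `y`) are hypothetical and chosen only to make the snippet self-contained; the constructor forms are the ones described in the release notes above.

```julia
using Turing

# A hypothetical two-variable model, just to have something to sample.
@model function demo()
    x ~ Normal(0, 1)
    y ~ Normal(x, 1)
    return nothing
end

# Old, now-deprecated constructors: each subsampler carried the variable symbols itself.
# alg_old       = Gibbs(HMC(0.1, 10, :x), MH(:y))
# alg_old_iters = Gibbs((HMC(0.1, 10, :x), 2), (MH(:y), 1))  # 2 HMC updates per MH update

# New constructors: variables (keywords or `VarName`s) are mapped to samplers.
alg_kw = Gibbs(; x=HMC(0.1, 10), y=MH())
alg_vn = Gibbs(@varname(x) => HMC(0.1, 10), @varname(y) => MH())

# Running one component more often than another: list the same sampler repeatedly.
hmc = HMC(0.1, 10)
alg_repeated = Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => MH())

chain = sample(demo(), alg_kw, 100; progress=false)
```

The keyword form only accepts plain symbols, while the `Pair` form also accepts full `VarName`s (and tuples of them), which is what enables the more granular sampler assignment mentioned above.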
+ # Release 0.33.0 ## Breaking changes diff --git a/Project.toml b/Project.toml index 5f5c86b04..e22c4f4ae 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.34.1" +version = "0.35.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/Turing.jl b/src/Turing.jl index 8dfb8df28..027c190a3 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -55,7 +55,6 @@ using .Variational include("optimisation/Optimisation.jl") using .Optimisation -include("experimental/Experimental.jl") include("deprecated.jl") # to be removed in the next minor version release ########### @@ -86,7 +85,6 @@ export @model, # modelling Emcee, ESS, Gibbs, - GibbsConditional, HMC, # Hamiltonian-like sampling SGLD, SGHMC, diff --git a/src/experimental/Experimental.jl b/src/experimental/Experimental.jl deleted file mode 100644 index 518538e6c..000000000 --- a/src/experimental/Experimental.jl +++ /dev/null @@ -1,16 +0,0 @@ -module Experimental - -using Random: Random -using AbstractMCMC: AbstractMCMC -using DynamicPPL: DynamicPPL, VarName -using Accessors: Accessors - -using DocStringExtensions: TYPEDFIELDS -using Distributions - -using ..Turing: Turing -using ..Turing.Inference: gibbs_rerun, InferenceAlgorithm - -include("gibbs.jl") - -end diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl deleted file mode 100644 index 596e6e283..000000000 --- a/src/experimental/gibbs.jl +++ /dev/null @@ -1,488 +0,0 @@ -# Basically like a `DynamicPPL.FixedContext` but -# 1. Hijacks the tilde pipeline to fix variables. -# 2. Computes the log-probability of the fixed variables. -# -# Purpose: avoid triggering resampling of variables we're conditioning on. -# - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. -# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to -# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable -# rather than only for the "true" observations. -# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline -# rather than the `observe` pipeline for the conditioned variables. -struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - values::Values - context::Ctx -end - -Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) - -DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::GibbsContext) = context.context -DynamicPPL.setchildcontext(context::GibbsContext, childcontext) = GibbsContext(context.values, childcontext) - -# has and get -has_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.hasvalue(context.values, vn) -function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(has_conditioned_gibbs, context), vns) -end - -get_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.getvalue(context.values, vn) -function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(get_conditioned_gibbs, context), vns) -end - -# Tilde pipeline -function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. 
- if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) -end - -function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) -end - -# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. -make_broadcastable(x) = x -make_broadcastable(dist::Distribution) = tuple(dist) - -# Need the following two methods to properly support broadcasting over columns. -broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) -function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) - return loglikelihood(dist, x) -end - -# Needed to support broadcasting over columns for `MultivariateDistribution`s. -reconstruct_getvalue(dist, x) = x -function reconstruct_getvalue( - dist::MultivariateDistribution, - x::AbstractVector{<:AbstractVector{<:Real}} -) - return reduce(hcat, x[2:end]; init=x[1]) -end - -function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi -) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi) -end - - -""" - preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) - -Returns the preferred value type for a variable with the given `varinfo`. -""" -preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict -preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple -function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) - # We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`. - namedtuple_compatible = all(varinfo.metadata) do md - eltype(md.vns) <: VarName{<:Any,typeof(identity)} - end - return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict -end - -""" - condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) - -Return a `GibbsContext` with the given values treated as conditioned. - -# Arguments -- `context::DynamicPPL.AbstractContext`: The context to condition. -- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. - If multiple values are provided, we recursively condition on each of them. 
-""" -condition_gibbs(context::DynamicPPL.AbstractContext) = context -# For `NamedTuple` and `AbstractDict` we just construct the context. -function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}) - return GibbsContext(values, context) -end -# If we get more than one argument, we just recurse. -function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) - return condition_gibbs( - condition_gibbs(context, value), - values... - ) -end - -# For `DynamicPPL.AbstractVarInfo` we just extract the values. -""" - condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) - -Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. -""" -function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) - return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) -end -function condition_gibbs( - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo, - varinfos::DynamicPPL.AbstractVarInfo... -) - return condition_gibbs(condition_gibbs(context, varinfo), varinfos...) -end -# Allow calling this on a `DynamicPPL.Model` directly. -function condition_gibbs(model::DynamicPPL.Model, values...) - return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) -end - - -""" - make_conditional_model(model, varinfo, varinfos) - -Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. - -# Examples -```julia-repl -julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); - -julia> # A separate varinfo for each variable in `model`. - varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); - -julia> # The varinfo we want to NOT condition on. - target_varinfo = first(varinfos); - -julia> # Results in a model with only `m` conditioned. - conditioned_model = Turing.Inference.make_conditional(model, target_varinfo, varinfos); - -julia> result = conditioned_model(); - -julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` -true - -julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` -true -``` -""" -function make_conditional(model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos) - # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - return condition_gibbs( - model, - filter(Base.Fix1(!==, target_varinfo), varinfos)... - ) -end -# Assumes the ones given are the ones to condition on. -function make_conditional(model::DynamicPPL.Model, varinfos) - return condition_gibbs( - model, - varinfos... - ) -end - -# HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` -# or an `AbstractInferenceAlgorithm`. -wrap_algorithm_maybe(x) = x -wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) - -""" - Gibbs - -A type representing a Gibbs sampler. - -# Fields -$(TYPEDFIELDS) -""" -struct Gibbs{V,A} <: InferenceAlgorithm - "varnames representing variables for each sampler" - varnames::V - "samplers for each entry in `varnames`" - samplers::A -end - -# NamedTuple -Gibbs(; algs...) 
= Gibbs(NamedTuple(algs)) -function Gibbs(algs::NamedTuple) - return Gibbs( - map(s -> VarName{s}(), keys(algs)), - map(wrap_algorithm_maybe, values(algs)), - ) -end - -# AbstractDict -function Gibbs(algs::AbstractDict) - return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe, values(algs))) -end -function Gibbs(algs::Pair...) - return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) -end - -# TODO: Remove when no longer needed. -DynamicPPL.getspace(::Gibbs) = () - -struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} - vi::V - states::S -end - -_maybevec(x) = vec(x) # assume it's iterable -_maybevec(x::Tuple) = [x...] -_maybevec(x::VarName) = [x] - -function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, - vi_base::DynamicPPL.AbstractVarInfo; - initial_params=nothing, - kwargs..., -) - alg = spl.alg - varnames = alg.varnames - samplers = alg.samplers - - # 1. Run the model once to get the varnames present + initial values to condition on. - vi_base = DynamicPPL.VarInfo(model) - - # Simple way of setting the initial parameters: set them in the `vi_base` - # if they are given so they propagate to the subset varinfos used by each sampler. - if initial_params !== nothing - vi_base = DynamicPPL.unflatten(vi_base, initial_params) - end - - # Create the varinfos for each sampler. - varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) - initial_params_all = if initial_params === nothing - fill(nothing, length(varnames)) - else - # Extract from the `vi_base`, which should have the values set correctly from above. - map(vi -> vi[:], varinfos) - end - - # 2. Construct a varinfo for every vn + sampler combo. - states_and_varinfos = map(samplers, varinfos, initial_params_all) do sampler_local, varinfo_local, initial_params_local - # Construct the conditional model. - model_local = make_conditional(model, varinfo_local, varinfos) - - # Take initial step. - new_state_local = last(AbstractMCMC.step( - rng, model_local, sampler_local; - # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. - # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. - initial_params=initial_params_local, - kwargs... - )) - - # Return the new state and the invlinked `varinfo`. - vi_local_state = Turing.Inference.varinfo(new_state_local) - vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) - DynamicPPL.invlink(vi_local_state, sampler_local, model_local) - else - vi_local_state - end - return (new_state_local, vi_local_state_linked) - end - - states = map(first, states_and_varinfos) - varinfos = map(last, states_and_varinfos) - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. 
- vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), - DynamicPPL.getlogp(last(varinfos)), - ) - - return Turing.Inference.Transition(model, vi), GibbsState(vi, states) -end - -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:Gibbs}, - state::GibbsState; - kwargs..., -) - alg = spl.alg - samplers = alg.samplers - states = state.states - varinfos = map(Turing.Inference.varinfo, state.states) - @assert length(samplers) == length(state.states) - - # TODO: move this into a recursive function so we can unroll when reasonable? - for index = 1:length(samplers) - # Take the inner step. - new_state_local, new_varinfo_local = gibbs_step_inner( - rng, - model, - samplers, - states, - varinfos, - index; - kwargs..., - ) - - # Update the `states` and `varinfos`. - states = Accessors.setindex(states, new_state_local, index) - varinfos = Accessors.setindex(varinfos, new_varinfo_local, index) - end - - # Combine the resulting varinfo objects. - # The last varinfo holds the correctly computed logp. - vi_base = state.vi - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!( - varinfos, - merge(vi_base, first(varinfos)), - firstindex(varinfos), - ) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), - DynamicPPL.getlogp(last(varinfos)), - ) - - return Turing.Inference.Transition(model, vi), GibbsState(vi, states) -end - -# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. -function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) - # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide - # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact - # same `selector` as before but now with `rerun` set to `true` if needed. - return Accessors.@set sampler.selector.rerun = true -end - -# Interface we need a sampler to implement to work as a component in a Gibbs sampler. -""" - gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - -Check if the log-probability of the destination model needs to be recomputed. - -Defaults to `true` -""" -function gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - return true -end - -# TODO: Remove `rng`? -function Turing.Inference.recompute_logprob!!( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler, - state -) - varinfo = Turing.Inference.varinfo(state) - # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, - # e.g. log-likelihood in the scenario of `ESS`. - # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. - sampler_rerun = make_rerun_sampler(model, sampler) - # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed - # `varinfo`, even if `varinfo` was linked. - varinfo_new = last(DynamicPPL.evaluate!!( - model, - varinfo, - # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. - DynamicPPL.SamplingContext(rng, sampler_rerun) - )) - # Update the state we're about to use if need be. - # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. 
- return Turing.Inference.gibbs_state(model, sampler, state, varinfo_new) -end - -function gibbs_step_inner( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - samplers, - states, - varinfos, - index; - kwargs..., -) - # Needs to do a a few things. - sampler_local = samplers[index] - state_local = states[index] - varinfo_local = varinfos[index] - - # Make sure that all `varinfos` are linked. - varinfos_invlinked = map(varinfos) do vi - # NOTE: This is immutable linking! - # TODO: Do we need the `istrans` check here or should we just always use `invlink`? - # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 - DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi - end - varinfo_local_invlinked = varinfos_invlinked[index] - - # 1. Create conditional model. - # Construct the conditional model. - # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, - # otherwise we're conditioning on values which are not in the support of the - # distributions. - model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) - - # Extract the previous sampler and state. - sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] - state_previous = states[index == 1 ? length(states) : index - 1] - - # 1. Re-run the sampler if needed. - if gibbs_requires_recompute_logprob( - model_local, - sampler_local, - sampler_previous, - state_local, - state_previous - ) - state_local = Turing.Inference.recompute_logprob!!( - rng, - model_local, - sampler_local, - state_local, - ) - end - - # 2. Take step with local sampler. - new_state_local = last( - AbstractMCMC.step( - rng, - model_local, - sampler_local, - state_local; - kwargs..., - ), - ) - - # 3. Extract the new varinfo. - # Return the resulting state and invlinked `varinfo`. - varinfo_local_state = Turing.Inference.varinfo(new_state_local) - varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) - DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) - else - varinfo_local_state - end - - # TODO: alternatively, we can return `states_new, varinfos_new, index_new` - return (new_state_local, varinfo_local_state_invlinked) -end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index b7bdf206b..495559871 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -46,7 +46,6 @@ export InferenceAlgorithm, ESS, Emcee, Gibbs, # classic sampling - GibbsConditional, HMC, SGLD, PolynomialStepsize, @@ -63,7 +62,6 @@ export InferenceAlgorithm, observe, dot_observe, predict, - isgibbscomponent, externalsampler ####################### @@ -526,22 +524,21 @@ end # Concrete algorithm implementations. # ####################################### +include("abstractmcmc.jl") include("ess.jl") include("hmc.jl") include("mh.jl") include("is.jl") include("particle_mcmc.jl") -include("gibbs_conditional.jl") include("gibbs.jl") include("sghmc.jl") include("emcee.jl") -include("abstractmcmc.jl") ################ # Typing tools # ################ -for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs, :Emcee) +for alg in (:SMC, :PG, :MH, :IS, :ESS, :Emcee) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index a350d2908..965c79706 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -27,6 +27,10 @@ function varinfo(state::TuringState) # TODO: Do we need to link here first? 
return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ) end +varinfo(state::AbstractVarInfo) = state +# TODO(mhauru) Could we have a type bound on the argument below, for documentation purposes? +varinfo(state) = state.vi + # NOTE: Only thing that depends on the underlying sampler. # Something similar should be part of AbstractMCMC at some point: diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 736845b67..571d694e3 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -1,101 +1,220 @@ -### -### Gibbs samplers / compositional samplers. -### +# Basically like a `DynamicPPL.FixedContext` but +# 1. Hijacks the tilde pipeline to fix variables. +# 2. Computes the log-probability of the fixed variables. +# +# Purpose: avoid triggering resampling of variables we're conditioning on. +# - Using standard `DynamicPPL.condition` results in conditioned variables being treated +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to +# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable +# rather than only for the "true" observations. +# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline +# rather than the `observe` pipeline for the conditioned variables. +struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext + values::Values + context::Ctx +end -""" - isgibbscomponent(alg) +Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) -Determine whether algorithm `alg` is allowed as a Gibbs component. -""" -isgibbscomponent(alg) = false +DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::GibbsContext) = context.context +function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) + return GibbsContext(context.values, childcontext) +end + +# has and get +function has_conditioned_gibbs(context::GibbsContext, vn::VarName) + return DynamicPPL.hasvalue(context.values, vn) +end +function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(has_conditioned_gibbs, context), vns) +end -isgibbscomponent(::ESS) = true -isgibbscomponent(::GibbsConditional) = true -isgibbscomponent(::Hamiltonian) = true -isgibbscomponent(::MH) = true -isgibbscomponent(::PG) = true +function get_conditioned_gibbs(context::GibbsContext, vn::VarName) + return DynamicPPL.getvalue(context.values, vn) +end +function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return map(Base.Fix1(get_conditioned_gibbs, context), vns) +end -const TGIBBS = Union{InferenceAlgorithm,GibbsConditional} +# Tilde pipeline +function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end -""" - Gibbs(algs...) + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) +end -Compositional MCMC interface. Gibbs sampling combines one or more -sampling algorithms, each of which samples from a different set of -variables in a model. 
+function DynamicPPL.tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi +) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end -Example: -```julia -@model function gibbs_example(x) - v1 ~ Normal(0,1) - v2 ~ Categorical(5) + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, vn, vi + ) end -# Use PG for a 'v2' variable, and use HMC for the 'v1' variable. -# Note that v2 is discrete, so the PG sampler is more appropriate -# than is HMC. -alg = Gibbs(HMC(0.2, 3, :v1), PG(20, :v2)) -``` +# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. +make_broadcastable(x) = x +make_broadcastable(dist::Distribution) = tuple(dist) -One can also pass the number of iterations for each Gibbs component using the following syntax: -- `alg = Gibbs((HMC(0.2, 3, :v1), n_hmc), (PG(20, :v2), n_pg))` -where `n_hmc` and `n_pg` are the number of HMC and PG iterations for each Gibbs iteration. +# Need the following two methods to properly support broadcasting over columns. +broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) +function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) + return loglikelihood(dist, x) +end -Tips: -- `HMC` and `NUTS` are fast samplers and can throw off particle-based -methods like Particle Gibbs. You can increase the effectiveness of particle sampling by including -more particles in the particle sampler. -""" -struct Gibbs{space,N,A<:NTuple{N,TGIBBS},B<:NTuple{N,Int}} <: InferenceAlgorithm - algs::A # component sampling algorithms - iterations::B - function Gibbs{space,N,A,B}( - algs::A, iterations::B - ) where {space,N,A<:NTuple{N,TGIBBS},B<:NTuple{N,Int}} - all(isgibbscomponent, algs) || - error("all algorithms have to support Gibbs sampling") - return new{space,N,A,B}(algs, iterations) +# Needed to support broadcasting over columns for `MultivariateDistribution`s. +reconstruct_getvalue(dist, x) = x +function reconstruct_getvalue( + dist::MultivariateDistribution, x::AbstractVector{<:AbstractVector{<:Real}} +) + return reduce(hcat, x[2:end]; init=x[1]) +end + +function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vns) + value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return value, broadcast_logpdf(right, value), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume( + DynamicPPL.childcontext(context), right, left, vns, vi + ) +end + +function DynamicPPL.dot_tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi +) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vns) + value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return value, broadcast_logpdf(right, value), vi end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi + ) end -function Gibbs(alg1::TGIBBS, algrest::Vararg{TGIBBS,N}) where {N} - algs = (alg1, algrest...) 
- iterations = ntuple(Returns(1), Val(N + 1)) - # obtain space for sampling algorithms - space = Tuple(union(getspace.(algs)...)) - return Gibbs{space,N + 1,typeof(algs),typeof(iterations)}(algs, iterations) +""" + preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) + +Returns the preferred value type for a variable with the given `varinfo`. +""" +preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict +preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple +function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) + # We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`. + namedtuple_compatible = all(varinfo.metadata) do md + eltype(md.vns) <: VarName{<:Any,typeof(identity)} + end + return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict end -function Gibbs(arg1::Tuple{<:TGIBBS,Int}, argrest::Vararg{Tuple{<:TGIBBS,Int},N}) where {N} - allargs = (arg1, argrest...) - algs = map(first, allargs) - iterations = map(last, allargs) - # obtain space for sampling algorithms - space = Tuple(union(getspace.(algs)...)) - return Gibbs{space,N + 1,typeof(algs),typeof(iterations)}(algs, iterations) +""" + condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) + +Return a `GibbsContext` with the given values treated as conditioned. + +# Arguments +- `context::DynamicPPL.AbstractContext`: The context to condition. +- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. + If multiple values are provided, we recursively condition on each of them. +""" +condition_gibbs(context::DynamicPPL.AbstractContext) = context +# For `NamedTuple` and `AbstractDict` we just construct the context. +function condition_gibbs( + context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict} +) + return GibbsContext(values, context) +end +# If we get more than one argument, we just recurse. +function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) + return condition_gibbs(condition_gibbs(context, value), values...) end +# For `DynamicPPL.AbstractVarInfo` we just extract the values. """ - GibbsState{V<:VarInfo, S<:Tuple{Vararg{Sampler}}} + condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) -Stores a `VarInfo` for use in sampling, and a `Tuple` of `Samplers` that -the `Gibbs` sampler iterates through for each `step!`. +Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. """ -struct GibbsState{V<:VarInfo,S<:Tuple{Vararg{Sampler}},T} - vi::V - samplers::S - states::T +function condition_gibbs( + context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo +) + return condition_gibbs( + context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)) + ) +end +function condition_gibbs( + context::DynamicPPL.AbstractContext, + varinfo::DynamicPPL.AbstractVarInfo, + varinfos::DynamicPPL.AbstractVarInfo..., +) + return condition_gibbs(condition_gibbs(context, varinfo), varinfos...) +end +# Allow calling this on a `DynamicPPL.Model` directly. +function condition_gibbs(model::DynamicPPL.Model, values...) + return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) end -# extract varinfo object from state """ - gibbs_varinfo(model, sampler, state) + make_conditional_model(model, varinfo, varinfos) + +Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. 
+ +# Examples +```julia-repl +julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); + +julia> # A separate varinfo for each variable in `model`. + varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); + +julia> # The varinfo we want to NOT condition on. + target_varinfo = first(varinfos); + +julia> # Results in a model with only `m` conditioned. + conditioned_model = make_conditional(model, target_varinfo, varinfos); + +julia> result = conditioned_model(); -Return the variables corresponding to the current `state` of the Gibbs component `sampler`. +julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` +true + +julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` +true +``` """ -gibbs_varinfo(model, sampler, state) = varinfo(state) -varinfo(state) = state.vi -varinfo(state::AbstractVarInfo) = state +function make_conditional( + model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos +) + # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. + return condition_gibbs(model, filter(Base.Fix1(!==, target_varinfo), varinfos)...) +end +# Assumes the ones given are the ones to condition on. +function make_conditional(model::DynamicPPL.Model, varinfos) + return condition_gibbs(model, varinfos...) +end + +# HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` +# or an `AbstractInferenceAlgorithm`. +wrap_algorithm_maybe(x) = x +wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) """ gibbs_state(model, sampler, state, varinfo) @@ -130,122 +249,298 @@ function gibbs_state( end """ - gibbs_rerun(prev_alg, alg) + Gibbs -Check if the model should be rerun to recompute the log density before sampling with the -Gibbs component `alg` and after sampling from Gibbs component `prev_alg`. +A type representing a Gibbs sampler. -By default, the function returns `true`. +# Fields +$(TYPEDFIELDS) """ -gibbs_rerun(prev_alg, alg) = true +struct Gibbs{V,A} <: InferenceAlgorithm + "varnames representing variables for each sampler" + varnames::V + "samplers for each entry in `varnames`" + samplers::A +end + +# NamedTuple +Gibbs(; algs...) = Gibbs(NamedTuple(algs)) +function Gibbs(algs::NamedTuple) + return Gibbs( + map(s -> VarName{s}(), keys(algs)), map(wrap_algorithm_maybe, values(algs)) + ) +end + +# AbstractDict +function Gibbs(algs::AbstractDict) + return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe, values(algs))) +end +function Gibbs(algs::Pair...) + return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) +end -# `vi.logp` already contains the log joint probability if the previous sampler -# used a `GibbsConditional` or one of the standard `Hamiltonian` algorithms -gibbs_rerun(::GibbsConditional, ::MH) = false -gibbs_rerun(::Hamiltonian, ::MH) = false +# The below two constructors only provide backwards compatibility with the constructor of +# the old Gibbs sampler. They are deprecated and will be removed in the future. +function Gibbs(algs::InferenceAlgorithm...) + varnames = map(algs) do alg + space = getspace(alg) + if (space isa VarName) + space + elseif (space isa Symbol) + VarName{space}() + else + tuple((s isa Symbol ? VarName{s}() : s for s in space)...) + end + end + msg = ( + "Specifying which sampler to use with which variable using syntax like " * + "`Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. " * + "Please use `Gibbs(; x=NUTS(), y=MH())` instead. 
If you want different iteration " * + "counts for different subsamplers, use e.g. " * + "`Gibbs(@varname(x) => NUTS(), @varname(x) => NUTS(), @varname(y) => MH())`" + ) + Base.depwarn(msg, :Gibbs) + return Gibbs(varnames, map(wrap_algorithm_maybe, algs)) +end + +function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) + algs = Iterators.map(first, algs_with_iters) + iters = Iterators.map(last, algs_with_iters) + algs_duplicated = Iterators.flatten(( + Iterators.repeated(alg, iter) for (alg, iter) in zip(algs, iters) + )) + # This calls the other deprecated constructor from above, hence no need for a depwarn + # here. + return Gibbs(algs_duplicated...) +end -# `vi.logp` already contains the log joint probability if the previous sampler -# used a `GibbsConditional` or a `MH` algorithm -gibbs_rerun(::MH, ::Hamiltonian) = false -gibbs_rerun(::GibbsConditional, ::Hamiltonian) = false +# TODO: Remove when no longer needed. +DynamicPPL.getspace(::Gibbs) = () -# do not have to recompute `vi.logp` since it is not used in `step` -gibbs_rerun(prev_alg, ::GibbsConditional) = false +struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} + vi::V + states::S +end -# Do not recompute `vi.logp` since it is reset anyway in `step` -gibbs_rerun(prev_alg, ::PG) = false +_maybevec(x) = vec(x) # assume it's iterable +_maybevec(x::Tuple) = [x...] +_maybevec(x::VarName) = [x] -# Initialize the Gibbs sampler. function DynamicPPL.initialstep( - rng::AbstractRNG, model::Model, spl::Sampler{<:Gibbs}, vi::AbstractVarInfo; kwargs... + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + vi_base::DynamicPPL.AbstractVarInfo; + initial_params=nothing, + kwargs..., ) - # TODO: Technically this only works for `VarInfo` or `ThreadSafeVarInfo{<:VarInfo}`. - # Should we enforce this? - - # Create tuple of samplers - algs = spl.alg.algs - i = 0 - samplers = map(algs) do alg - i += 1 - if i == 1 - prev_alg = algs[end] - else - prev_alg = algs[i - 1] - end - rerun = gibbs_rerun(prev_alg, alg) - selector = DynamicPPL.Selector(Symbol(typeof(alg)), rerun) - Sampler(alg, model, selector) - end + alg = spl.alg + varnames = alg.varnames + samplers = alg.samplers + + # 1. Run the model once to get the varnames present + initial values to condition on. + vi_base = DynamicPPL.VarInfo(rng, model) - # Add Gibbs to gids for all variables. - for sym in keys(vi.metadata) - vns = getfield(vi.metadata, sym).vns + # Simple way of setting the initial parameters: set them in the `vi_base` + # if they are given so they propagate to the subset varinfos used by each sampler. + if initial_params !== nothing + vi_base = DynamicPPL.unflatten(vi_base, initial_params) + end - for vn in vns - # update the gid for the Gibbs sampler - DynamicPPL.updategid!(vi, vn, spl) + # Create the varinfos for each sampler. + varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) + initial_params_all = if initial_params === nothing + fill(nothing, length(varnames)) + else + # Extract from the `vi_base`, which should have the values set correctly from above. + map(vi -> vi[:], varinfos) + end - # try to store each subsampler's gid in the VarInfo - for local_spl in samplers - DynamicPPL.updategid!(vi, vn, local_spl) - end + # 2. Construct a varinfo for every vn + sampler combo. + states_and_varinfos = map( + samplers, varinfos, initial_params_all + ) do sampler_local, varinfo_local, initial_params_local + # Construct the conditional model. 
+ model_local = make_conditional(model, varinfo_local, varinfos) + + # Take initial step. + new_state_local = last( + AbstractMCMC.step( + rng, + model_local, + sampler_local; + # FIXME: This will cause issues if the sampler expects initial params in unconstrained space. + # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc. + initial_params=initial_params_local, + kwargs..., + ), + ) + + # Return the new state and the invlinked `varinfo`. + vi_local_state = varinfo(new_state_local) + vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) + DynamicPPL.invlink(vi_local_state, sampler_local, model_local) + else + vi_local_state end + return (new_state_local, vi_local_state_linked) end - # Compute initial states of the local samplers. - states = map(samplers) do local_spl - # Recompute `vi.logp` if needed. - if local_spl.selector.rerun - vi = last( - DynamicPPL.evaluate!!( - model, vi, DynamicPPL.SamplingContext(rng, local_spl) - ), - ) - end + states = map(first, states_and_varinfos) + varinfos = map(last, states_and_varinfos) - # Compute initial state. - _, state = DynamicPPL.initialstep(rng, model, local_spl, vi; kwargs...) + # Update the base varinfo from the first varinfo and replace it. + varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1) + # Merge the updated initial varinfo with the rest of the varinfos + update the logp. + vi = DynamicPPL.setlogp!!( + reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) + ) - # Update `VarInfo` object. - vi = gibbs_varinfo(model, local_spl, state) + return Transition(model, vi), GibbsState(vi, states) +end - return state +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + state::GibbsState; + kwargs..., +) + alg = spl.alg + samplers = alg.samplers + states = state.states + varinfos = map(varinfo, state.states) + @assert length(samplers) == length(state.states) + + # TODO: move this into a recursive function so we can unroll when reasonable? + for index in 1:length(samplers) + # Take the inner step. + new_state_local, new_varinfo_local = gibbs_step_inner( + rng, model, samplers, states, varinfos, index; kwargs... + ) + + # Update the `states` and `varinfos`. + states = Accessors.setindex(states, new_state_local, index) + varinfos = Accessors.setindex(varinfos, new_varinfo_local, index) end - # Compute initial transition and state. - transition = Transition(model, vi) - state = GibbsState(vi, samplers, states) + # Combine the resulting varinfo objects. + # The last varinfo holds the correctly computed logp. + vi_base = state.vi + + # Update the base varinfo from the first varinfo and replace it. + varinfos_new = DynamicPPL.setindex!!( + varinfos, merge(vi_base, first(varinfos)), firstindex(varinfos) + ) + # Merge the updated initial varinfo with the rest of the varinfos + update the logp. + vi = DynamicPPL.setlogp!!( + reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)) + ) - return transition, state + return Transition(model, vi), GibbsState(vi, states) end -# Subsequent steps -function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:Gibbs}, state::GibbsState; kwargs... -) - # Iterate through each of the samplers. - vi = state.vi - samplers = state.samplers - states = map(samplers, spl.alg.iterations, state.states) do _sampler, iteration, _state - # Recompute `vi.logp` if needed. 
- if _sampler.selector.rerun - vi = last(DynamicPPL.evaluate!!(model, rng, vi, _sampler)) - end +# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. +function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) + # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide + # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact + # same `selector` as before but now with `rerun` set to `true` if needed. + return Accessors.@set sampler.selector.rerun = true +end - # Update state of current sampler with updated `VarInfo` object. - current_state = gibbs_state(model, _sampler, _state, vi) +# Interface we need a sampler to implement to work as a component in a Gibbs sampler. +""" + gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - # Step through the local sampler. - newstate = current_state - for _ in 1:iteration - _, newstate = AbstractMCMC.step(rng, model, _sampler, newstate; kwargs...) - end +Check if the log-probability of the destination model needs to be recomputed. + +Defaults to `true` +""" +function gibbs_requires_recompute_logprob( + model_dst, sampler_dst, sampler_src, state_dst, state_src +) + return true +end - # Update `VarInfo` object. - vi = gibbs_varinfo(model, _sampler, newstate) +# TODO: Remove `rng`? +function recompute_logprob!!( + rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, state +) + vi = varinfo(state) + # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, + # e.g. log-likelihood in the scenario of `ESS`. + # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. + sampler_rerun = make_rerun_sampler(model, sampler) + # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed + # `varinfo`, even if `varinfo` was linked. + vi_new = last( + DynamicPPL.evaluate!!( + model, + vi, + # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. + DynamicPPL.SamplingContext(rng, sampler_rerun), + ) + ) + # Update the state we're about to use if need be. + # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. + return gibbs_state(model, sampler, state, vi_new) +end + +function gibbs_step_inner( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + samplers, + states, + varinfos, + index; + kwargs..., +) + # Needs to do a a few things. + sampler_local = samplers[index] + state_local = states[index] + varinfo_local = varinfos[index] + + # Make sure that all `varinfos` are linked. + varinfos_invlinked = map(varinfos) do vi + # NOTE: This is immutable linking! + # TODO: Do we need the `istrans` check here or should we just always use `invlink`? + # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 + DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi + end + varinfo_local_invlinked = varinfos_invlinked[index] + + # 1. Create conditional model. + # Construct the conditional model. + # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, + # otherwise we're conditioning on values which are not in the support of the + # distributions. + model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) + + # Extract the previous sampler and state. + sampler_previous = samplers[index == 1 ? 
length(samplers) : index - 1] + state_previous = states[index == 1 ? length(states) : index - 1] + + # 1. Re-run the sampler if needed. + if gibbs_requires_recompute_logprob( + model_local, sampler_local, sampler_previous, state_local, state_previous + ) + state_local = recompute_logprob!!(rng, model_local, sampler_local, state_local) + end - return newstate + # 2. Take step with local sampler. + new_state_local = last( + AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...) + ) + + # 3. Extract the new varinfo. + # Return the resulting state and invlinked `varinfo`. + varinfo_local_state = varinfo(new_state_local) + varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) + DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) + else + varinfo_local_state end - return Transition(model, vi), GibbsState(vi, samplers, states) + # TODO: alternatively, we can return `states_new, varinfos_new, index_new` + return (new_state_local, varinfo_local_state_invlinked) end diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl deleted file mode 100644 index fda79315b..000000000 --- a/src/mcmc/gibbs_conditional.jl +++ /dev/null @@ -1,88 +0,0 @@ -""" - GibbsConditional(sym, conditional) - -A "pseudo-sampler" to manually provide analytical Gibbs conditionals to `Gibbs`. -`GibbsConditional(:x, cond)` will sample the variable `x` according to the conditional `cond`, which -must therefore be a function from a `NamedTuple` of the conditioned variables to a `Distribution`. - - -The `NamedTuple` that is passed in contains all random variables from the model in an unspecified -order, taken from the [`VarInfo`](@ref) object over which the model is run. Scalars and vectors are -stored in their respective shapes. The tuple also contains the value of the conditioned variable -itself, which can be useful, but using it creates something that is not a Gibbs sampler anymore (see -[here](https://github.com/TuringLang/Turing.jl/pull/1275#discussion_r434240387)). - -# Examples - -```julia -α_0 = 2.0 -θ_0 = inv(3.0) -x = [1.5, 2.0] -N = length(x) - -@model function inverse_gdemo(x) - λ ~ Gamma(α_0, θ_0) - σ = sqrt(1 / λ) - m ~ Normal(0, σ) - @. x ~ \$(Normal(m, σ)) -end - -# The conditionals can be formulated in terms of the following statistics: -x_bar = mean(x) # sample mean -s2 = var(x; mean=x_bar, corrected=false) # sample variance -m_n = N * x_bar / (N + 1) - -function cond_m(c) - λ_n = c.λ * (N + 1) - σ_n = sqrt(1 / λ_n) - return Normal(m_n, σ_n) -end - -function cond_λ(c) - α_n = α_0 + (N - 1) / 2 + 1 - β_n = s2 * N / 2 + c.m^2 / 2 + inv(θ_0) - return Gamma(α_n, inv(β_n)) -end - -m = inverse_gdemo(x) - -sample(m, Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)), 10) -``` -""" -struct GibbsConditional{S,C} - conditional::C - - function GibbsConditional(sym::Symbol, conditional::C) where {C} - return new{sym,C}(conditional) - end -end - -DynamicPPL.getspace(::GibbsConditional{S}) where {S} = (S,) - -function DynamicPPL.initialstep( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:GibbsConditional}, - vi::AbstractVarInfo; - kwargs..., -) - return nothing, vi -end - -function AbstractMCMC.step( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:GibbsConditional}, - vi::AbstractVarInfo; - kwargs..., -) - condvals = DynamicPPL.values_as(DynamicPPL.invlink(vi, model), NamedTuple) - conddist = spl.alg.conditional(condvals) - updated = rand(rng, conddist) - # Setindex allows only vectors in this case. 
- vi = setindex!!(vi, [updated;], spl) - # Update log joint probability. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - - return nothing, vi -end diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl deleted file mode 100644 index 0f0740f14..000000000 --- a/test/experimental/gibbs.jl +++ /dev/null @@ -1,270 +0,0 @@ -module ExperimentalGibbsTests - -using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo -using ..NumericalTests: check_MoGtest_default, check_MoGtest_default_z_vector, check_gdemo, - check_numerical, two_sample_test -using DynamicPPL -using Random -using Test -using Turing -using Turing.Inference: AdvancedHMC, AdvancedMH -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff - -function check_transition_varnames( - transition::Turing.Inference.Transition, - parent_varnames -) - transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val - [first(vn_and_val)] - end - # Varnames in `transition` should be subsumed by those in `vns`. - for vn in transition_varnames - @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) - end -end - -const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ - Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_literal_dot_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, -} -has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false -has_dot_assume(::Model) = true - -@testset "Gibbs using `condition`" begin - @testset "Demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - # Run one sampler on variables starting with `s` and another on variables starting with `m`. - vns_s = filter(vns) do vn - DynamicPPL.getsym(vn) == :s - end - vns_m = filter(vns) do vn - DynamicPPL.getsym(vn) == :m - end - - samplers = [ - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => NUTS(), - ), - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => HMC(0.01, 4), - ) - ] - - if !has_dot_assume(model) - # Add in some MH samplers, which are not compatible with `.~`. - append!( - samplers, - [ - Turing.Experimental.Gibbs( - vns_s => HMC(0.01, 4), - vns_m => MH(), - ), - Turing.Experimental.Gibbs( - vns_s => MH(), - vns_m => HMC(0.01, 4), - ) - ] - ) - end - - @testset "$sampler" for sampler in samplers - # Check that taking steps performs as expected. - rng = Random.default_rng() - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - check_transition_varnames(transition, vns) - end - end - - @testset "comparison with 'gold-standard' samples" begin - num_iterations = 1_000 - thinning = 10 - num_chains = 4 - - # Determine initial parameters to make comparison as fair as possible. - posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) - initial_params = DynamicPPL.TestUtils.update_values!!( - DynamicPPL.VarInfo(model), - posterior_mean, - DynamicPPL.TestUtils.varnames(model), - )[:] - initial_params = fill(initial_params, num_chains) - - # Sampler to use for Gibbs components. 
- sampler_inner = HMC(0.1, 32) - sampler = Turing.Experimental.Gibbs( - vns_s => sampler_inner, - vns_m => sampler_inner, - ) - Random.seed!(42) - chain = sample( - model, - sampler, - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - discard_initial=1_000, - thinning=thinning - ) - - # "Ground truth" samples. - # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL. - Random.seed!(42) - chain_true = sample( - model, - NUTS(), - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - thinning=thinning, - ) - - # Perform KS test to ensure that the chains are similar. - xs = Array(chain) - xs_true = Array(chain_true) - for i = 1:size(xs, 2) - @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) - # Let's make sure that the significance level is not too low by - # checking that the KS test fails for some simple transformations. - # TODO: Replace the heuristic below with closed-form implementations - # of the targets, once they are implemented in DynamicPPL. - @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) - end - end - end - end - - @testset "multiple varnames" begin - rng = Random.default_rng() - - @testset "with both `s` and `m` as random" begin - model = gdemo(1.5, 2.0) - vns = (@varname(s), @varname(m)) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - - # `sample` - Random.seed!(42) - chain = sample(model, alg, 10_000; progress=false) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) - end - - @testset "without `m` as random" begin - model = gdemo(1.5, 2.0) | (m=7 / 6,) - vns = (@varname(s),) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - end - end - - @testset "CSMC + ESS" begin - rng = Random.default_rng() - model = MoGtest_default - alg = Turing.Experimental.Gibbs( - (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! 
- Random.seed!(42) - chain = sample(MoGtest_default, alg, 1000; progress=false) - check_MoGtest_default(chain, atol = 0.2) - end - - @testset "CSMC + ESS (usage of implicit varname)" begin - rng = Random.default_rng() - model = MoGtest_default_z_vector - alg = Turing.Experimental.Gibbs( - @varname(z) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z[1]), @varname(z[2]), @varname(z[3]), @varname(z[4]), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! - Random.seed!(42) - chain = sample(model, alg, 1000; progress=false) - check_MoGtest_default_z_vector(chain, atol = 0.2) - end - - @testset "externsalsampler" begin - @model function demo_gibbs_external() - m1 ~ Normal() - m2 ~ Normal() - - -1 ~ Normal(m1, 1) - +1 ~ Normal(m1 + m2, 1) - - return (; m1, m2) - end - - model = demo_gibbs_external() - samplers_inner = [ - externalsampler(AdvancedMH.RWMH(1)), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoForwardDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff(compile=true)), - ] - @testset "$(sampler_inner)" for sampler_inner in samplers_inner - sampler = Turing.Experimental.Gibbs( - @varname(m1) => sampler_inner, - @varname(m2) => sampler_inner, - ) - Random.seed!(42) - chain = sample(model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0) - check_numerical(chain, [:m1, :m2], [-0.2, 0.6], atol=0.1) - end - end -end - -end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 15ec6149c..4a6e0e9a6 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -33,15 +33,15 @@ ADUtils.install_tapir && import Tapir PG(10), IS(), MH(), - Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), + Gibbs(; s=PG(3), m=HMC(0.4, 8; adtype=adbackend)), + Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()), ) else ( HMC(0.1, 7; adtype=adbackend), IS(), MH(), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), + Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()), ) end for sampler in samplers @@ -85,7 +85,7 @@ ADUtils.install_tapir && import Tapir alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend) alg2 = PG(20) - alg3 = Gibbs(PG(30, :s), HMC(0.2, 4, :m; adtype=adbackend)) + alg3 = Gibbs(; s=PG(30), m=HMC(0.2, 4; adtype=adbackend)) chn1 = sample(gdemo_default, alg1, 5000; save_state=true) check_gdemo(chn1) @@ -234,7 +234,7 @@ ADUtils.install_tapir && import Tapir smc = SMC() pg = PG(10) - gibbs = Gibbs(HMC(0.2, 3, :p; adtype=adbackend), PG(10, :x)) + gibbs = Gibbs(; p=HMC(0.2, 3; adtype=adbackend), x=PG(10)) chn_s = sample(testbb(obs), smc, 1000) chn_p = sample(testbb(obs), pg, 2000) @@ -261,7 +261,7 @@ ADUtils.install_tapir && import Tapir return s, m end - gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m; adtype=adbackend)) + gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8; adtype=adbackend)) chain = sample(fggibbstest(xs), gibbs, 2) end @testset "new grammar" begin @@ -367,7 +367,7 @@ ADUtils.install_tapir && import Tapir @test all(isone, res_pg[:x]) end @testset "sample" begin - alg = Gibbs(HMC(0.2, 3, :m; adtype=adbackend), PG(10, :s)) + alg = Gibbs(; m=HMC(0.2, 3; adtype=adbackend), s=PG(10)) chn = sample(gdemo_default, alg, 1000) 
end @testset "vectorization @." begin diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 0a1c23a9e..8d9697d9a 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -38,7 +38,7 @@ using Turing c3 = sample(demodot_default, s1, N) c4 = sample(demodot_default, s2, N) - s3 = Gibbs(ESS(:m), MH(:s)) + s3 = Gibbs(; m=ESS(), s=MH()) c5 = sample(gdemo_default, s3, N) end @@ -52,13 +52,17 @@ using Turing check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8]; atol=0.1) Random.seed!(100) - alg = Gibbs(CSMC(15, :s), ESS(:m)) + alg = Gibbs(; s=CSMC(15), m=ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) # MoGtest Random.seed!(125) - alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) + alg = Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => ESS(), + @varname(mu2) => ESS(), + ) chain = sample(MoGtest_default, alg, 6000) check_MoGtest_default(chain; atol=0.1) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 6868cb5e8..5082a5f4f 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -1,55 +1,111 @@ module GibbsTests -using ..Models: MoGtest_default, gdemo, gdemo_default -using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical +using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo, gdemo_default +using ..NumericalTests: + check_MoGtest_default, + check_MoGtest_default_z_vector, + check_gdemo, + check_numerical, + two_sample_test import ..ADUtils using Distributions: InverseGamma, Normal using Distributions: sample +using DynamicPPL: DynamicPPL using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff -using Test: @test, @testset +using Test: @test, @test_deprecated, @testset using Turing using Turing: Inference +using Turing.Inference: AdvancedHMC, AdvancedMH using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess ADUtils.install_tapir && import Tapir +function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames) + transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val + [first(vn_and_val)] + end + # Varnames in `transition` should be subsumed by those in `parent_varnames`. 
+ for vn in transition_varnames + @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) + end +end + +const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_literal_dot_observe)}, + DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, +} +has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false +has_dot_assume(::DynamicPPL.Model) = true + @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends - @testset "gibbs constructor" begin - N = 500 - s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) - s2 = Gibbs(PG(10, :s, :m)) - s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) - for s in (s1, s2, s3, s4, s5, s6) + @testset "Deprecated Gibbs constructors" begin + N = 10 + @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) + @test_deprecated s2 = Gibbs(PG(10, :s, :m)) + @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) + @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) + @test_deprecated s7 = Gibbs((HMC(0.1, 5, :s; adtype=adbackend), 2), (ESS(:m), 3)) + for s in (s1, s2, s3, s4, s5, s6, s7) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end - c1 = sample(gdemo_default, s1, N) - c2 = sample(gdemo_default, s2, N) - c3 = sample(gdemo_default, s3, N) - c4 = sample(gdemo_default, s4, N) - c5 = sample(gdemo_default, s5, N) - c6 = sample(gdemo_default, s6, N) + # Check that the samplers work despite using the deprecated constructor. 
+ sample(gdemo_default, s1, N) + sample(gdemo_default, s2, N) + sample(gdemo_default, s3, N) + sample(gdemo_default, s4, N) + sample(gdemo_default, s5, N) + sample(gdemo_default, s6, N) + sample(gdemo_default, s7, N) - # Test gid of each samplers g = Turing.Sampler(s3, gdemo_default) + @test sample(gdemo_default, g, N) isa MCMCChains.Chains + end + + @testset "Gibbs constructors" begin + N = 10 + s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5, :s, :m; adtype=adbackend)) + s2 = Gibbs((@varname(s), @varname(m)) => PG(10)) + s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) + s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) + s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend)) + s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) + s7 = Gibbs((@varname(s), @varname(m)) => PG(10)) + s8 = begin + hmc = HMC(0.1, 5; adtype=adbackend) + pg = PG(10) + vns = @varname(s) + vnm = @varname(m) + Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) + end + for s in (s1, s2, s3, s4, s5, s6, s7, s8) + @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" + end - _, state = AbstractMCMC.step(Random.default_rng(), gdemo_default, g) - @test state.samplers[1].selector != g.selector - @test state.samplers[2].selector != g.selector - @test state.samplers[1].selector != state.samplers[2].selector + sample(gdemo_default, s1, N) + sample(gdemo_default, s2, N) + sample(gdemo_default, s3, N) + sample(gdemo_default, s4, N) + sample(gdemo_default, s5, N) + sample(gdemo_default, s6, N) + sample(gdemo_default, s7, N) + sample(gdemo_default, s8, N) - # run sampler: progress logging should be disabled and - # it should return a Chains object + g = Turing.Sampler(s3, gdemo_default) @test sample(gdemo_default, g, N) isa MCMCChains.Chains end + @testset "gibbs inference" begin Random.seed!(100) - alg = Gibbs(CSMC(15, :s), HMC(0.2, 4, :m; adtype=adbackend)) + alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:m], [7 / 6]; atol=0.15) # Be more relaxed with the tolerance of the variance. @@ -57,11 +113,11 @@ ADUtils.install_tapir && import Tapir Random.seed!(100) - alg = Gibbs(MH(:s), HMC(0.2, 4, :m; adtype=adbackend)) + alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) - alg = Gibbs(CSMC(15, :s), ESS(:m)) + alg = Gibbs(; s=CSMC(15), m=ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) @@ -71,15 +127,17 @@ ADUtils.install_tapir && import Tapir Random.seed!(200) gibbs = Gibbs( - PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2; adtype=adbackend) + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15), + (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend), ) chain = sample(MoGtest_default, gibbs, 10_000) check_MoGtest_default(chain; atol=0.15) Random.seed!(200) for alg in [ - Gibbs((MH(:s), 2), (HMC(0.2, 4, :m; adtype=adbackend), 1)), - Gibbs((MH(:s), 1), (HMC(0.2, 4, :m; adtype=adbackend), 2)), + # The new syntax for specifying a sampler to run twice for one variable. 
+ Gibbs(@varname(s) => MH(), @varname(s) => MH(), @varname(m) => HMC(0.2, 4; adtype=adbackend)),
+ Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.2, 4; adtype=adbackend), @varname(m) => HMC(0.2, 4; adtype=adbackend)),
 ]
 chain = sample(gdemo(1.5, 2.0), alg, 10_000)
 check_gdemo(chain; atol=0.15)
@@ -113,9 +171,10 @@ ADUtils.install_tapir && import Tapir
 return nothing
 end
- alg = Gibbs(MH(:s), HMC(0.2, 4, :m; adtype=adbackend))
+ alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend))
 sample(model, alg, 100; callback=callback)
 end
+
 @testset "dynamic model" begin
 @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M}
 N = length(y)
@@ -136,10 +195,250 @@ ADUtils.install_tapir && import Tapir
 m[k] ~ Normal(1.0, 1.0)
 end
 end
- model = imm(randn(100), 1.0)
+ model = imm(Random.randn(100), 1.0)
 # https://github.com/TuringLang/Turing.jl/issues/1725
 # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100);
- sample(model, Gibbs(PG(10, :z), HMC(0.01, 4, :m; adtype=adbackend)), 100)
+ sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100)
+ end
+
+ @testset "Demo models" begin
+ @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
+ vns = DynamicPPL.TestUtils.varnames(model)
+ # Run one sampler on variables starting with `s` and another on variables starting with `m`.
+ vns_s = filter(vns) do vn
+ DynamicPPL.getsym(vn) == :s
+ end
+ vns_m = filter(vns) do vn
+ DynamicPPL.getsym(vn) == :m
+ end
+
+ samplers = [
+ Turing.Gibbs(vns_s => NUTS(), vns_m => NUTS()),
+ Turing.Gibbs(vns_s => NUTS(), vns_m => HMC(0.01, 4)),
+ ]
+
+ if !has_dot_assume(model)
+ # Add in some MH samplers, which are not compatible with `.~`.
+ append!(
+ samplers,
+ [
+ Turing.Gibbs(vns_s => HMC(0.01, 4), vns_m => MH()),
+ Turing.Gibbs(vns_s => MH(), vns_m => HMC(0.01, 4)),
+ ],
+ )
+ end
+
+ @testset "$sampler" for sampler in samplers
+ # Check that taking steps performs as expected.
+ rng = Random.default_rng()
+ transition, state = AbstractMCMC.step(
+ rng, model, DynamicPPL.Sampler(sampler)
+ )
+ check_transition_varnames(transition, vns)
+ for _ in 1:5
+ transition, state = AbstractMCMC.step(
+ rng, model, DynamicPPL.Sampler(sampler), state
+ )
+ check_transition_varnames(transition, vns)
+ end
+ end
+
+ # Run the Gibbs sampler and NUTS on the same model, compare statistics of the
+ # chains.
+ @testset "comparison with 'gold-standard' samples" begin
+ num_iterations = 1_000
+ thinning = 10
+ num_chains = 4
+
+ # Determine initial parameters to make comparison as fair as possible.
+ posterior_mean = DynamicPPL.TestUtils.posterior_mean(model)
+ initial_params = DynamicPPL.TestUtils.update_values!!(
+ DynamicPPL.VarInfo(model),
+ posterior_mean,
+ DynamicPPL.TestUtils.varnames(model),
+ )[:]
+ initial_params = fill(initial_params, num_chains)
+
+ # Sampler to use for Gibbs components.
+ sampler_inner = HMC(0.1, 32)
+ sampler = Turing.Gibbs(vns_s => sampler_inner, vns_m => sampler_inner)
+ Random.seed!(42)
+ chain = sample(
+ model,
+ sampler,
+ MCMCThreads(),
+ num_iterations,
+ num_chains;
+ progress=false,
+ initial_params=initial_params,
+ discard_initial=1_000,
+ thinning=thinning,
+ )
+
+ # "Ground truth" samples.
+ # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL.
+ Random.seed!(42)
+ chain_true = sample(
+ model,
+ NUTS(),
+ MCMCThreads(),
+ num_iterations,
+ num_chains;
+ progress=false,
+ initial_params=initial_params,
+ thinning=thinning,
+ )
+
+ # Perform KS test to ensure that the chains are similar.
+ xs = Array(chain) + xs_true = Array(chain_true) + for i in 1:size(xs, 2) + @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) + # Let's make sure that the significance level is not too low by + # checking that the KS test fails for some simple transformations. + # TODO: Replace the heuristic below with closed-form implementations + # of the targets, once they are implemented in DynamicPPL. + @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) + @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) + @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) + end + end + end + end + + @testset "multiple varnames" begin + rng = Random.default_rng() + + @testset "with both `s` and `m` as random" begin + model = gdemo(1.5, 2.0) + vns = (@varname(s), @varname(m)) + alg = Turing.Gibbs(vns => MH()) + + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # `sample` + Random.seed!(42) + chain = sample(model, alg, 10_000; progress=false) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) + end + + @testset "without `m` as random" begin + model = gdemo(1.5, 2.0) | (m=7 / 6,) + vns = (@varname(s),) + alg = Turing.Gibbs(vns => MH()) + + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + end + end + + @testset "CSMC + ESS" begin + rng = Random.default_rng() + model = MoGtest_default + alg = Turing.Gibbs( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => ESS(), + @varname(mu2) => ESS(), + ) + vns = ( + @varname(z1), + @varname(z2), + @varname(z3), + @varname(z4), + @varname(mu1), + @varname(mu2) + ) + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # Sample! + Random.seed!(42) + chain = sample(MoGtest_default, alg, 1000; progress=false) + check_MoGtest_default(chain; atol=0.2) + end + + @testset "CSMC + ESS (usage of implicit varname)" begin + rng = Random.default_rng() + model = MoGtest_default_z_vector + alg = Turing.Gibbs( + @varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS() + ) + vns = ( + @varname(z[1]), + @varname(z[2]), + @varname(z[3]), + @varname(z[4]), + @varname(mu1), + @varname(mu2) + ) + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ in 1:5 + transition, state = AbstractMCMC.step( + rng, model, DynamicPPL.Sampler(alg), state + ) + check_transition_varnames(transition, vns) + end + + # Sample! 
+ Random.seed!(42) + chain = sample(model, alg, 1000; progress=false) + check_MoGtest_default_z_vector(chain; atol=0.2) + end + + @testset "externsalsampler" begin + @model function demo_gibbs_external() + m1 ~ Normal() + m2 ~ Normal() + + -1 ~ Normal(m1, 1) + +1 ~ Normal(m1 + m2, 1) + + return (; m1, m2) + end + + model = demo_gibbs_external() + samplers_inner = [ + externalsampler(AdvancedMH.RWMH(1)), + externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoForwardDiff()), + externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff()), + externalsampler( + AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff(; compile=true) + ), + ] + @testset "$(sampler_inner)" for sampler_inner in samplers_inner + sampler = Turing.Gibbs( + @varname(m1) => sampler_inner, @varname(m2) => sampler_inner + ) + Random.seed!(42) + chain = sample( + model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0 + ) + check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1) + end end end diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl deleted file mode 100644 index 3f02c7594..000000000 --- a/test/mcmc/gibbs_conditional.jl +++ /dev/null @@ -1,172 +0,0 @@ -module GibbsConditionalTests - -using ..Models: gdemo, gdemo_default -using ..NumericalTests: check_gdemo, check_numerical -import ..ADUtils -using Clustering: Clustering -using Distributions: Categorical, InverseGamma, Normal, sample -using ForwardDiff: ForwardDiff -using LinearAlgebra: Diagonal, I -using Random: Random -using ReverseDiff: ReverseDiff -using StableRNGs: StableRNG -using StatsBase: counts -using StatsFuns: StatsFuns -using Test: @test, @testset -using Turing - -ADUtils.install_tapir && import Tapir - -@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in ADUtils.adbackends - Random.seed!(1000) - rng = StableRNG(123) - - @testset "gdemo" begin - # We consider the model - # ```math - # s ~ InverseGamma(2, 3) - # m ~ Normal(0, √s) - # xᵢ ~ Normal(m, √s), i = 1, …, N, - # ``` - # with ``N = 2`` observations ``x₁ = 1.5`` and ``x₂ = 2``. 
- - # The conditionals and posterior can be formulated in terms of the following statistics: - N = 2 - x_mean = 1.75 # sample mean ``∑ xᵢ / N`` - x_var = 0.0625 # sample variance ``∑ (xᵢ - x_bar)^2 / N`` - m_n = 3.5 / 3 # ``∑ xᵢ / (N + 1)`` - - # Conditional distribution - # ```math - # m | s, x ~ Normal(m_n, sqrt(s / (N + 1))) - # ``` - cond_m = let N = N, m_n = m_n - c -> Normal(m_n, sqrt(c.s / (N + 1))) - end - - # Conditional distribution - # ```math - # s | m, x ~ InverseGamma(2 + (N + 1) / 2, 3 + (m^2 + ∑ (xᵢ - m)^2) / 2) = - # InverseGamma(2 + (N + 1) / 2, 3 + m^2 / 2 + N / 2 * (x_var + (x_mean - m)^2)) - # ``` - cond_s = let N = N, x_mean = x_mean, x_var = x_var - c -> InverseGamma( - 2 + (N + 1) / 2, 3 + c.m^2 / 2 + N / 2 * (x_var + (x_mean - c.m)^2) - ) - end - - # Three Gibbs samplers: - # one for each variable fixed to the posterior mean - s_posterior_mean = 49 / 24 - sampler1 = Gibbs( - GibbsConditional(:m, cond_m), - GibbsConditional(:s, _ -> Normal(s_posterior_mean, 0)), - ) - chain = sample(rng, gdemo_default, sampler1, 10_000) - cond_m_mean = mean(cond_m((s=s_posterior_mean,))) - check_numerical(chain, [:m, :s], [cond_m_mean, s_posterior_mean]) - @test all(==(s_posterior_mean), chain[:s][2:end]) - - m_posterior_mean = 7 / 6 - sampler2 = Gibbs( - GibbsConditional(:m, _ -> Normal(m_posterior_mean, 0)), - GibbsConditional(:s, cond_s), - ) - chain = sample(rng, gdemo_default, sampler2, 10_000) - cond_s_mean = mean(cond_s((m=m_posterior_mean,))) - check_numerical(chain, [:m, :s], [m_posterior_mean, cond_s_mean]) - @test all(==(m_posterior_mean), chain[:m][2:end]) - - # and one for both using the conditional - sampler3 = Gibbs(GibbsConditional(:m, cond_m), GibbsConditional(:s, cond_s)) - chain = sample(rng, gdemo_default, sampler3, 10_000) - check_gdemo(chain) - end - - @testset "GMM" begin - Random.seed!(1000) - rng = StableRNG(123) - # We consider the model - # ```math - # μₖ ~ Normal(m, σ_μ), k = 1, …, K, - # zᵢ ~ Categorical(π), i = 1, …, N, - # xᵢ ~ Normal(μ_{zᵢ}, σₓ), i = 1, …, N, - # ``` - # with ``K = 2`` clusters, ``N = 20`` observations, and the following parameters: - K = 2 # number of clusters - π = fill(1 / K, K) # uniform cluster weights - m = 0.5 # prior mean of μₖ - σ²_μ = 4.0 # prior variance of μₖ - σ²_x = 0.01 # observation variance - N = 20 # number of observations - - # We generate data - μ_data = rand(rng, Normal(m, sqrt(σ²_μ)), K) - z_data = rand(rng, Categorical(π), N) - x_data = rand(rng, MvNormal(μ_data[z_data], σ²_x * I)) - - @model function mixture(x) - μ ~ $(MvNormal(fill(m, K), σ²_μ * I)) - z ~ $(filldist(Categorical(π), N)) - x ~ MvNormal(μ[z], $(σ²_x * I)) - return x - end - model = mixture(x_data) - - # Conditional distribution ``z | μ, x`` - # see http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixtures-and-gibbs.pdf - cond_z = let x = x_data, log_π = log.(π), σ_x = sqrt(σ²_x) - c -> begin - dists = map(x) do xi - logp = log_π .+ logpdf.(Normal.(c.μ, σ_x), xi) - return Categorical(StatsFuns.softmax!(logp)) - end - return arraydist(dists) - end - end - - # Conditional distribution ``μ | z, x`` - # see http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixtures-and-gibbs.pdf - cond_μ = let K = K, x_data = x_data, inv_σ²_μ = inv(σ²_μ), inv_σ²_x = inv(σ²_x) - c -> begin - # Convert cluster assignments to one-hot encodings - z_onehot = c.z .== (1:K)' - - # Count number of observations in each cluster - n = vec(sum(z_onehot; dims=1)) - - # Compute mean and variance of the conditional distribution - μ_var = @. 
inv(inv_σ²_x * n + inv_σ²_μ) - μ_mean = (z_onehot' * x_data) .* inv_σ²_x .* μ_var - - return MvNormal(μ_mean, Diagonal(μ_var)) - end - end - - estimate(chain, var) = dropdims(mean(Array(group(chain, var)); dims=1); dims=1) - function estimatez(chain, var, range) - z = Int.(Array(group(chain, var))) - return map(i -> findmax(counts(z[:, i], range))[2], 1:size(z, 2)) - end - - lμ_data, uμ_data = extrema(μ_data) - - # Compare three Gibbs samplers - sampler1 = Gibbs(GibbsConditional(:z, cond_z), GibbsConditional(:μ, cond_μ)) - sampler2 = Gibbs(GibbsConditional(:z, cond_z), MH(:μ)) - sampler3 = Gibbs(GibbsConditional(:z, cond_z), HMC(0.01, 7, :μ; adtype=adbackend)) - for sampler in (sampler1, sampler2, sampler3) - chain = sample(rng, model, sampler, 10_000) - - μ_hat = estimate(chain, :μ) - lμ_hat, uμ_hat = extrema(μ_hat) - @test isapprox([lμ_data, uμ_data], [lμ_hat, uμ_hat], atol=0.1) - - z_hat = estimatez(chain, :z, 1:2) - ari, _, _, _ = Clustering.randindex(z_data, Int.(z_hat)) - @test isapprox(ari, 1, atol=0.1) - end - end -end - -end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index dde977a6f..889be13c5 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -130,9 +130,9 @@ ADUtils.install_tapir && import Tapir @testset "hmcda inference" begin alg1 = HMCDA(500, 0.8, 0.015; adtype=adbackend) - # alg2 = Gibbs(HMCDA(200, 0.8, 0.35, :m; adtype=adbackend), HMC(0.25, 3, :s; adtype=adbackend)) + # alg2 = Gibbs(; m=HMCDA(200, 0.8, 0.35; adtype=adbackend), s=HMC(0.25, 3; adtype=adbackend)) - # alg3 = Gibbs(HMC(0.25, 3, :m; adtype=adbackend), PG(30, 3, :s)) + # alg3 = Gibbs(; m=HMC(0.25, 3; adtype=adbackend), s=PG(30, 3)) # alg3 = PG(50, 2000) res1 = sample(rng, gdemo_default, alg1, 3000) @@ -147,7 +147,7 @@ ADUtils.install_tapir && import Tapir @testset "hmcda+gibbs inference" begin rng = StableRNG(123) Random.seed!(12345) # particle samplers do not support user-provided `rng` yet - alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend)) + alg3 = Gibbs(; s=PG(20), m=HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend)) res3 = sample(rng, gdemo_default, alg3, 3000, discard_initial=1000) check_gdemo(res3) @@ -200,9 +200,9 @@ ADUtils.install_tapir && import Tapir @test size(c2, 1) == 500 end @testset "AHMC resize" begin - alg1 = Gibbs(PG(10, :m), NUTS(100, 0.65, :s; adtype=adbackend)) - alg2 = Gibbs(PG(10, :m), HMC(0.1, 3, :s; adtype=adbackend)) - alg3 = Gibbs(PG(10, :m), HMCDA(100, 0.65, 0.3, :s; adtype=adbackend)) + alg1 = Gibbs(; m=PG(10), s=NUTS(100, 0.65; adtype=adbackend)) + alg2 = Gibbs(; m=PG(10), s=HMC(0.1, 3; adtype=adbackend)) + alg3 = Gibbs(; m=PG(10), s=HMCDA(100, 0.65, 0.3; adtype=adbackend)) @test sample(rng, gdemo_default, alg1, 300) isa Chains @test sample(rng, gdemo_default, alg2, 300) isa Chains @test sample(rng, gdemo_default, alg3, 300) isa Chains diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index a01d3dc25..f454db5a0 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -32,7 +32,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) c2 = sample(gdemo_default, s2, N) c3 = sample(gdemo_default, s3, N) - s4 = Gibbs(MH(:m), MH(:s)) + s4 = Gibbs(; m=MH(), s=MH()) c4 = sample(gdemo_default, s4, N) # s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) @@ -62,14 +62,16 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) Random.seed!(125) # MH within Gibbs - alg = Gibbs(MH(:m), MH(:s)) + alg = Gibbs(; m=MH(), s=MH()) chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params) check_gdemo(chain; atol=0.1) 
Random.seed!(125) # MoGtest gibbs = Gibbs( - CSMC(15, :z1, :z2, :z3, :z4), MH((:mu1, GKernel(1)), (:mu2, GKernel(1))) + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => MH((:mu1, GKernel(1))), + @varname(mu2) => MH((:mu2, GKernel(1))), ) chain = sample( MoGtest_default, @@ -167,7 +169,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) vc_μ = convert(Array, 1e-4 * I(2)) vc_σ = convert(Array, 1e-4 * I(2)) - alg = Gibbs(MH((:μ, vc_μ)), MH((:σ, vc_σ))) + alg = Gibbs(; μ=MH((:μ, vc_μ)), σ=MH((:σ, vc_σ))) chn = sample( mod, diff --git a/test/runtests.jl b/test/runtests.jl index 1aa8bb635..ba9aafd2e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,7 +46,6 @@ end @timeit TIMEROUTPUT "inference" begin @testset "inference with samplers" begin @timeit_include("mcmc/gibbs.jl") - @timeit_include("mcmc/gibbs_conditional.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") @@ -65,10 +64,6 @@ end end end - @testset "experimental" begin - @timeit_include("experimental/gibbs.jl") - end - @testset "variational optimisers" begin @timeit_include("variational/optimisers.jl") end diff --git a/test/skipped/explicit_ret.jl b/test/skipped/explicit_ret.jl index c1340464f..2dabc09bd 100644 --- a/test/skipped/explicit_ret.jl +++ b/test/skipped/explicit_ret.jl @@ -12,7 +12,7 @@ end mf = test_ex_rt() for alg in - [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(PG(20, 1, :x), HMC(0.2, 3, :y))] + [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(; x=PG(20, 1), y=HMC(0.2, 3))] chn = sample(mf, alg) @test mean(chn[:x]) ≈ 10.0 atol = 0.2 @test mean(chn[:y]) ≈ 5.0 atol = 0.2