From a05f128deb8e519b2ecbdc729498d57e3ecc86cd Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 11:20:23 +0000 Subject: [PATCH 01/12] Rework Gibbs constructors, and remove the dead test/experimental/gibbs.jl --- src/mcmc/gibbs.jl | 58 ++++---- test/experimental/gibbs.jl | 271 ----------------------------------- test/mcmc/ess.jl | 4 +- test/mcmc/gibbs.jl | 59 ++++---- test/mcmc/mh.jl | 6 +- test/skipped/explicit_ret.jl | 2 +- 6 files changed, 67 insertions(+), 333 deletions(-) delete mode 100644 test/experimental/gibbs.jl diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 0f2c78ebe..fe818c4a8 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -292,11 +292,33 @@ function set_selector(x::RepeatSampler) end set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0)) +to_varname(vn::VarName) = vn +to_varname(s::Symbol) = VarName{s}() + +to_varname_list(x::Union{VarName,Symbol}) = [to_varname(x)] +# Any other value is assumed to be an iterable of VarNames and Symbols. +to_varname_list(t) = map(to_varname, collect(t)) + """ Gibbs A type representing a Gibbs sampler. +# Constructors + +`Gibbs` needs to be given a set of pairs of variable names and samplers. Instead of a single +variable name per sampler, one can also give an iterable of variables, all of which are +sampled by the same component sampler. + +Each variable name can be given as either a `Symbol` or a `VarName`. + +Some examples of valid constructors are: +```julia +Gibbs(:x => NUTS(), :y => MH()) +Gibbs(@varname(x) => NUTS(), @varname(y) => MH()) +Gibbs((@varname(x), :y) => NUTS(), :z => MH()) +``` + # Fields $(TYPEDFIELDS) """ @@ -310,35 +332,24 @@ struct Gibbs{V,A} <: InferenceAlgorithm if length(varnames) != length(samplers) throw(ArgumentError("Number of varnames and samplers must match.")) end + for spl in samplers if !isgibbscomponent(spl) msg = "All samplers must be valid Gibbs components, $(spl) is not." throw(ArgumentError(msg)) end end + + # Ensure that samplers have the same selector, and that varnames are lists of + # VarNames. + samplers = map(set_selector ∘ drop_space, samplers) + varnames = map(to_varname_list, varnames) return new{typeof(varnames),typeof(samplers)}(varnames, samplers) end end -to_varname(vn::VarName) = vn -to_varname(s::Symbol) = VarName{s}() -# Any other value is assumed to be an iterable. -to_varname(t) = map(to_varname, collect(t)) - -# NamedTuple -Gibbs(; algs...) = Gibbs(NamedTuple(algs)) -function Gibbs(algs::NamedTuple) - return Gibbs(map(to_varname, keys(algs)), map(set_selector ∘ drop_space, values(algs))) -end - -# AbstractDict -function Gibbs(algs::AbstractDict) - return Gibbs( - map(to_varname, collect(keys(algs))), map(set_selector ∘ drop_space, values(algs)) - ) -end function Gibbs(algs::Pair...) - return Gibbs(map(to_varname ∘ first, algs), map(set_selector ∘ drop_space ∘ last, algs)) + return Gibbs(map(first, algs), map(last, algs)) end # The below two constructors only provide backwards compatibility with the constructor of @@ -384,10 +395,9 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} states::S end -_maybevec(x) = vec(x) # assume it's iterable -_maybevec(x::Tuple) = [x...] -_maybevec(x::VarName) = [x] -_maybevec(x::Symbol) = [x] +_maybecollect(x) = collect(x) # assume it's iterable +_maybecollect(x::VarName) = [x] +_maybecollect(x::Symbol) = [x] varinfo(state::GibbsState) = state.vi @@ -412,7 +422,7 @@ function DynamicPPL.initialstep( # Initialise each component sampler in turn, collect all their states. 
states = [] for (varnames_local, sampler_local) in zip(varnames, samplers) - varnames_local = _maybevec(varnames_local) + varnames_local = _maybecollect(varnames_local) # Get the initial values for this component sampler. initial_params_local = if initial_params === nothing nothing @@ -463,7 +473,7 @@ function AbstractMCMC.step( # Take the inner step. sampler_local = samplers[index] state_local = states[index] - varnames_local = _maybevec(varnames[index]) + varnames_local = _maybecollect(varnames[index]) vi, new_state_local = gibbs_step_inner( rng, model, varnames_local, sampler_local, state_local, vi; kwargs... ) diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl deleted file mode 100644 index 70546350d..000000000 --- a/test/experimental/gibbs.jl +++ /dev/null @@ -1,271 +0,0 @@ -module ExperimentalGibbsTests - -using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo -using ..NumericalTests: check_MoGtest_default, check_MoGtest_default_z_vector, check_gdemo, - check_numerical, two_sample_test -using DynamicPPL -using Random -using Test -using Turing -using Turing.Inference: AdvancedHMC, AdvancedMH -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff - -function check_transition_varnames( - transition::Turing.Inference.Transition, - parent_varnames -) - transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val - [first(vn_and_val)] - end - # Varnames in `transition` should be subsumed by those in `vns`. - for vn in transition_varnames - @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) - end -end - -const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ - Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe_literal)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe_literal)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, -} -has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false -has_dot_assume(::Model) = true - -@testset "Gibbs using `condition`" begin - @testset "Demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - # Run one sampler on variables starting with `s` and another on variables starting with `m`. - vns_s = filter(vns) do vn - DynamicPPL.getsym(vn) == :s - end - vns_m = filter(vns) do vn - DynamicPPL.getsym(vn) == :m - end - - samplers = [ - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => NUTS(), - ), - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => HMC(0.01, 4), - ) - ] - - if !has_dot_assume(model) - # Add in some MH samplers, which are not compatible with `.~`. - append!( - samplers, - [ - Turing.Experimental.Gibbs( - vns_s => HMC(0.01, 4), - vns_m => MH(), - ), - Turing.Experimental.Gibbs( - vns_s => MH(), - vns_m => HMC(0.01, 4), - ) - ] - ) - end - - @testset "$sampler" for sampler in samplers - # Check that taking steps performs as expected. 
- rng = Random.default_rng() - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - check_transition_varnames(transition, vns) - end - end - - @testset "comparison with 'gold-standard' samples" begin - num_iterations = 1_000 - thinning = 10 - num_chains = 4 - - # Determine initial parameters to make comparison as fair as possible. - posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) - initial_params = DynamicPPL.TestUtils.update_values!!( - DynamicPPL.VarInfo(model), - posterior_mean, - DynamicPPL.TestUtils.varnames(model), - )[:] - initial_params = fill(initial_params, num_chains) - - # Sampler to use for Gibbs components. - sampler_inner = HMC(0.1, 32) - sampler = Turing.Experimental.Gibbs( - vns_s => sampler_inner, - vns_m => sampler_inner, - ) - Random.seed!(42) - chain = sample( - model, - sampler, - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - discard_initial=1_000, - thinning=thinning - ) - - # "Ground truth" samples. - # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL. - Random.seed!(42) - chain_true = sample( - model, - NUTS(), - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - thinning=thinning, - ) - - # Perform KS test to ensure that the chains are similar. - xs = Array(chain) - xs_true = Array(chain_true) - for i = 1:size(xs, 2) - @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) - # Let's make sure that the significance level is not too low by - # checking that the KS test fails for some simple transformations. - # TODO: Replace the heuristic below with closed-form implementations - # of the targets, once they are implemented in DynamicPPL. 
- @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) - end - end - end - end - - @testset "multiple varnames" begin - rng = Random.default_rng() - - @testset "with both `s` and `m` as random" begin - model = gdemo(1.5, 2.0) - vns = (@varname(s), @varname(m)) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - - # `sample` - Random.seed!(42) - chain = sample(model, alg, 10_000; progress=false) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) - end - - @testset "without `m` as random" begin - model = gdemo(1.5, 2.0) | (m=7 / 6,) - vns = (@varname(s),) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - end - end - - @testset "CSMC + ESS" begin - rng = Random.default_rng() - model = MoGtest_default - alg = Turing.Experimental.Gibbs( - (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! - Random.seed!(42) - chain = sample(MoGtest_default, alg, 1000; progress=false) - check_MoGtest_default(chain, atol = 0.2) - end - - @testset "CSMC + ESS (usage of implicit varname)" begin - rng = Random.default_rng() - model = MoGtest_default_z_vector - alg = Turing.Experimental.Gibbs( - @varname(z) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z[1]), @varname(z[2]), @varname(z[3]), @varname(z[4]), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! 
- Random.seed!(42) - chain = sample(model, alg, 1000; progress=false) - check_MoGtest_default_z_vector(chain, atol = 0.2) - end - - @testset "externsalsampler" begin - @model function demo_gibbs_external() - m1 ~ Normal() - m2 ~ Normal() - - -1 ~ Normal(m1, 1) - +1 ~ Normal(m1 + m2, 1) - - return (; m1, m2) - end - - model = demo_gibbs_external() - samplers_inner = [ - externalsampler(AdvancedMH.RWMH(1)), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoForwardDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff(compile=true)), - ] - @testset "$(sampler_inner)" for sampler_inner in samplers_inner - sampler = Turing.Experimental.Gibbs( - @varname(m1) => sampler_inner, - @varname(m2) => sampler_inner, - ) - Random.seed!(42) - chain = sample(model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0) - check_numerical(chain, [:m1, :m2], [-0.2, 0.6], atol=0.1) - end - end -end - -end diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 6db469b76..5533d11d7 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -40,7 +40,7 @@ using Turing c3 = sample(demodot_default, s1, N) c4 = sample(demodot_default, s2, N) - s3 = Gibbs(; m=ESS(), s=MH()) + s3 = Gibbs(:m => ESS(), :s => MH()) c5 = sample(gdemo_default, s3, N) end @@ -59,7 +59,7 @@ using Turing end @testset "gdemo with CSMC + ESS" begin - alg = Gibbs(; s=CSMC(15), m=ESS()) + alg = Gibbs(:s => CSMC(15), :m => ESS()) chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 2000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 43bdcdbb8..1d7208b43 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -130,6 +130,9 @@ end @test_throws ArgumentError Gibbs( @varname(s) => SGLD(; stepsize=PolynomialStepsize(0.25)) ) + # Values that we don't know how to convert to VarNames. + @test_throws MethodError Gibbs(1 => NUTS()) + @test_throws MethodError Gibbs("x" => NUTS()) end # Test that the samplers are being called in the correct order, on the correct target @@ -227,14 +230,14 @@ end nuts = NUTS() # Sample with all sorts of combinations of samplers and targets. sampler = Gibbs( - (@varname(s),) => AlgWrapper(mh), + @varname(s) => AlgWrapper(mh), (@varname(s), @varname(m)) => AlgWrapper(mh), - (@varname(m),) => AlgWrapper(pg), - (@varname(xs),) => AlgWrapper(hmc), - (@varname(ys),) => AlgWrapper(nuts), - (@varname(ys),) => AlgWrapper(nuts), + @varname(m) => AlgWrapper(pg), + @varname(xs) => AlgWrapper(hmc), + @varname(ys) => AlgWrapper(nuts), + @varname(ys) => AlgWrapper(nuts), (@varname(xs), @varname(ys)) => AlgWrapper(hmc), - (@varname(s),) => AlgWrapper(mh), + @varname(s) => AlgWrapper(mh), ) chain = sample(test_model(-1), sampler, 2) @@ -300,17 +303,12 @@ end # Two variables being sampled by one sampler. s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) s2 = Gibbs((@varname(s), :m) => PG(10)) - # One variable per sampler, using the keyword arg interface. - s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) - # As above but using a Dict of VarNames. - s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) # As above but different samplers and using kwargs. 
- s5 = Gibbs(; s=CSMC(3), m=HMCDA(200, 0.65, 0.15; adtype=adbackend)) - s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) - s7 = Gibbs(Dict((:s, @varname(m)) => PG(10))) + s3 = Gibbs(:s => CSMC(3), :m => HMCDA(200, 0.65, 0.15; adtype=adbackend)) + s4 = Gibbs(@varname(s) => HMC(0.1, 5; adtype=adbackend), @varname(m) => ESS()) # Multiple instnaces of the same sampler. This implements running, in this case, # 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs. - s8 = begin + s5 = begin hmc = HMC(0.1, 5; adtype=adbackend) pg = PG(10) vns = @varname(s) @@ -318,23 +316,20 @@ end Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) end # Same thing but using RepeatSampler. - s9 = Gibbs( + s6 = Gibbs( @varname(s) => RepeatSampler(HMC(0.1, 5; adtype=adbackend), 3), @varname(m) => RepeatSampler(PG(10), 2), ) - for s in (s1, s2, s3, s4, s5, s6, s7, s8, s9) + for s in (s1, s2, s3, s4, s5, s6) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end - sample(gdemo_default, s1, N) - sample(gdemo_default, s2, N) - sample(gdemo_default, s3, N) - sample(gdemo_default, s4, N) - sample(gdemo_default, s5, N) - sample(gdemo_default, s6, N) - sample(gdemo_default, s7, N) - sample(gdemo_default, s8, N) - sample(gdemo_default, s9, N) + @test sample(gdemo_default, s1, N) isa MCMCChains.Chains + @test sample(gdemo_default, s2, N) isa MCMCChains.Chains + @test sample(gdemo_default, s3, N) isa MCMCChains.Chains + @test sample(gdemo_default, s4, N) isa MCMCChains.Chains + @test sample(gdemo_default, s5, N) isa MCMCChains.Chains + @test sample(gdemo_default, s6, N) isa MCMCChains.Chains g = Turing.Sampler(s3, gdemo_default) @test sample(gdemo_default, g, N) isa MCMCChains.Chains @@ -344,7 +339,7 @@ end # posterior mean. @testset "Gibbs inference" begin @testset "CSMC and HMC on gdemo" begin - alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) + alg = Gibbs(:s => CSMC(15), :m => HMC(0.2, 4; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:m], [7 / 6]; atol=0.15) # Be more relaxed with the tolerance of the variance. @@ -352,13 +347,13 @@ end end @testset "MH and HMCDA on gdemo" begin - alg = Gibbs(; s=MH(), m=HMCDA(200, 0.65, 0.3; adtype=adbackend)) + alg = Gibbs(:s => MH(), :m => HMCDA(200, 0.65, 0.3; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end @testset "CSMC and ESS on gdemo" begin - alg = Gibbs(; s=CSMC(15), m=ESS()) + alg = Gibbs(:s => CSMC(15), :m => ESS()) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end @@ -435,7 +430,7 @@ end return nothing end - alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) + alg = Gibbs(:s => MH(), :m => HMC(0.2, 4; adtype=adbackend)) sample(model, alg, 100; callback=callback) end @@ -463,11 +458,11 @@ end num_samples = 10_000 model = imm(Random.randn(num_zs), 1.0) # https://github.com/TuringLang/Turing.jl/issues/1725 - # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); + # sample(model, Gibbs(:z => MH(), :m => HMC(0.01, 4)), 100); chn = sample( StableRNG(23), model, - Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), + Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), num_samples, ) # The number of m variables that have a non-zero value in a sample. @@ -527,7 +522,7 @@ end # TODO(mhauru) This is broken because of # https://github.com/TuringLang/DynamicPPL.jl/issues/700. 
@test_broken ( - sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100); + sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), 100); true ) end diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index 3823c2986..7d5a841d4 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -34,7 +34,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) c2 = sample(gdemo_default, s2, N) c3 = sample(gdemo_default, s3, N) - s4 = Gibbs(; m=MH(), s=MH()) + s4 = Gibbs(:m => MH(), :s => MH()) c4 = sample(gdemo_default, s4, N) # s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) @@ -69,7 +69,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) end @testset "gdemo_default with MH-within-Gibbs" begin - alg = Gibbs(; m=MH(), s=MH()) + alg = Gibbs(:m => MH(), :s => MH()) chain = sample( StableRNG(seed), gdemo_default, alg, 10_000; discard_initial, initial_params ) @@ -177,7 +177,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # with small-valued VC matrix to check if we only see very small steps vc_μ = convert(Array, 1e-4 * I(2)) vc_σ = convert(Array, 1e-4 * I(2)) - alg_small = Gibbs(; μ=MH((:μ, vc_μ)), σ=MH((:σ, vc_σ))) + alg_small = Gibbs(:μ => MH((:μ, vc_μ)), :σ => MH((:σ, vc_σ))) alg_big = MH() chn_small = sample(StableRNG(seed), mod, alg_small, 1_000) chn_big = sample(StableRNG(seed), mod, alg_big, 1_000) diff --git a/test/skipped/explicit_ret.jl b/test/skipped/explicit_ret.jl index 2dabc09bd..7472e8c31 100644 --- a/test/skipped/explicit_ret.jl +++ b/test/skipped/explicit_ret.jl @@ -12,7 +12,7 @@ end mf = test_ex_rt() for alg in - [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(; x=PG(20, 1), y=HMC(0.2, 3))] + [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(:x => PG(20, 1), :y => HMC(0.2, 3))] chn = sample(mf, alg) @test mean(chn[:x]) ≈ 10.0 atol = 0.2 @test mean(chn[:y]) ≈ 5.0 atol = 0.2 From 4c79067080be99c9e977225249f89786767ee3fc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 11:20:52 +0000 Subject: [PATCH 02/12] Update HISTORY.md --- HISTORY.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index ff50fb779..011d42414 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,13 +4,11 @@ 0.36.0 introduces a new Gibbs sampler. It's been included in several previous releases as `Turing.Experimental.Gibbs`, but now takes over the old Gibbs sampler, which gets removed completely. -The new Gibbs sampler supports the same user-facing interface as the old one. However, given -that the internals of it having been completely rewritten in a very different manner, there -may be accidental breakage that we haven't anticipated. Please report any you find. +The new Gibbs sampler currently supports the same user-facing interface as the old one, but the old constructors have been deprecated, and will be removed in the future. Also, given that the internals have been completely rewritten in a very different manner, there may be accidental breakage that we haven't anticipated. Please report any you find. `GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is breaking. -The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. 
The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable. +The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by mapping symbols, `VarName`s, or iterables of `VarName`s to samplers, e.g. `Gibbs(x=>HMC(), y=>MH())`, `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`, `Gibbs((:x, :y) => NUTS(), :z => MH())`. This allows more granular specification of which sampler to use for which variable. Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))` has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced at this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`. From 27f4e22ac6d3a4cc4b0fc0613f7bf6d96f3fcbe8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 11:34:50 +0000 Subject: [PATCH 03/12] Clarify docstring --- src/mcmc/gibbs.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index fe818c4a8..86433484c 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -319,6 +319,9 @@ Gibbs(@varname(x) => NUTS(), @varname(y) => MH()) Gibbs((@varname(x), :y) => NUTS(), :z => MH()) ``` +Currently only variable names without indexing are supported, so for instance +`Gibbs(@varname(x[1]) => NUTS())` does not work. This will hopefully change in the future. + # Fields $(TYPEDFIELDS) """ From 1579786dad53ba8b3411cc123bd057dc20fc9e3d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 11:35:12 +0000 Subject: [PATCH 04/12] Remove unnecessary _maybecollect in gibbs.jl --- src/mcmc/gibbs.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 86433484c..1c89dafb5 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -398,10 +398,6 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} states::S end -_maybecollect(x) = collect(x) # assume it's iterable -_maybecollect(x::VarName) = [x] -_maybecollect(x::Symbol) = [x] - varinfo(state::GibbsState) = state.vi function DynamicPPL.initialstep( @@ -425,7 +421,6 @@ function DynamicPPL.initialstep( # Initialise each component sampler in turn, collect all their states. states = [] for (varnames_local, sampler_local) in zip(varnames, samplers) - varnames_local = _maybecollect(varnames_local) # Get the initial values for this component sampler. initial_params_local = if initial_params === nothing nothing @@ -476,7 +471,6 @@ function AbstractMCMC.step( # Take the inner step. sampler_local = samplers[index] state_local = states[index] - varnames_local = _maybecollect(varnames[index]) vi, new_state_local = gibbs_step_inner( rng, model, varnames_local, sampler_local, state_local, vi; kwargs... 
) From 6bb784d5f723213bbdf364d253f412b3eafd6c4c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 11:50:22 +0000 Subject: [PATCH 05/12] Fix a bug --- src/mcmc/gibbs.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 1c89dafb5..5622f4f34 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -471,6 +471,7 @@ function AbstractMCMC.step( # Take the inner step. sampler_local = samplers[index] state_local = states[index] + varnames_local = varnames[index] vi, new_state_local = gibbs_step_inner( rng, model, varnames_local, sampler_local, state_local, vi; kwargs... ) From ded6c508997a69ed93bc9eba81c3081aee7c2d3a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 14:02:14 +0000 Subject: [PATCH 06/12] Fix more Gibbs constructors in tests --- src/mcmc/gibbs.jl | 3 +++ test/dynamicppl/compiler.jl | 6 +++--- test/mcmc/Inference.jl | 12 ++++++------ test/mcmc/hmc.jl | 10 ++++++---- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 5622f4f34..0567cd282 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -379,6 +379,9 @@ function Gibbs(algs::InferenceAlgorithm...) return Gibbs(varnames, map(set_selector ∘ drop_space, algs)) end +# This disambiguates a method ambiguity. To be removed when the above deprecated one is. +Gibbs() = Gibbs([], []) + function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) algs = Iterators.map(first, algs_with_iters) iters = Iterators.map(last, algs_with_iters) diff --git a/test/dynamicppl/compiler.jl b/test/dynamicppl/compiler.jl index 7939c7beb..7f5726614 100644 --- a/test/dynamicppl/compiler.jl +++ b/test/dynamicppl/compiler.jl @@ -54,7 +54,7 @@ const gdemo_default = gdemo_d() smc = SMC() pg = PG(10) - gibbs = Gibbs(; p=HMC(0.2, 3), x=PG(10)) + gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10)) chn_s = sample(testbb(obs), smc, 1000) chn_p = sample(testbb(obs), pg, 2000) @@ -81,7 +81,7 @@ const gdemo_default = gdemo_d() return s, m end - gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8)) + gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8)) chain = sample(fggibbstest(xs), gibbs, 2) end @testset "new grammar" begin @@ -177,7 +177,7 @@ const gdemo_default = gdemo_d() end @testset "sample" begin - alg = Gibbs(; m=HMC(0.2, 3), s=PG(10)) + alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10)) chn = sample(gdemo_default, alg, 1000) end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 9356fbcc1..da29e7708 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -31,8 +31,8 @@ using Turing PG(10), IS(), MH(), - Gibbs(; s=PG(3), m=HMC(0.4, 8; adtype=adbackend)), - Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()), + Gibbs(:s => PG(3), :m => HMC(0.4, 8; adtype=adbackend)), + Gibbs(:s => HMC(0.1, 5; adtype=adbackend), :m => ESS()), ) for sampler in samplers Random.seed!(5) @@ -81,7 +81,7 @@ using Turing @testset "chain save/resume" begin alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend) alg2 = PG(20) - alg3 = Gibbs(; s=PG(30), m=HMC(0.2, 4; adtype=adbackend)) + alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4; adtype=adbackend)) chn1 = sample(StableRNG(seed), gdemo_default, alg1, 2_000; save_state=true) check_gdemo(chn1) @@ -260,7 +260,7 @@ using Turing smc = SMC() pg = PG(10) - gibbs = Gibbs(; p=HMC(0.2, 3; adtype=adbackend), x=PG(10)) + gibbs = Gibbs(:p => HMC(0.2, 3; adtype=adbackend), :x => PG(10)) chn_s = sample(StableRNG(seed), testbb(obs), smc, 200) chn_p = sample(StableRNG(seed), testbb(obs), pg, 200) @@ -288,7 +288,7 @@ using 
Turing return s, m end - gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8; adtype=adbackend)) + gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8; adtype=adbackend)) chain = sample(StableRNG(seed), fggibbstest(xs), gibbs, 2) end @@ -415,7 +415,7 @@ using Turing end @testset "sample" begin - alg = Gibbs(; m=HMC(0.2, 3; adtype=adbackend), s=PG(10)) + alg = Gibbs(:m => HMC(0.2, 3; adtype=adbackend), :s => PG(10)) chn = sample(StableRNG(seed), gdemo_default, alg, 10) end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 47ff73b1c..d45846f3d 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -146,7 +146,9 @@ using Turing # explicitly specifying the seeds here. @testset "hmcda+gibbs inference" begin Random.seed!(12345) - alg = Gibbs(; s=PG(20), m=HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend)) + alg = Gibbs( + :s => PG(20), :m => HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend) + ) res = sample(StableRNG(123), gdemo_default, alg, 3000; discard_initial=1000) check_gdemo(res) end @@ -199,9 +201,9 @@ using Turing end @testset "AHMC resize" begin - alg1 = Gibbs(; m=PG(10), s=NUTS(100, 0.65; adtype=adbackend)) - alg2 = Gibbs(; m=PG(10), s=HMC(0.1, 3; adtype=adbackend)) - alg3 = Gibbs(; m=PG(10), s=HMCDA(100, 0.65, 0.3; adtype=adbackend)) + alg1 = Gibbs(:m => PG(10), :s => NUTS(100, 0.65; adtype=adbackend)) + alg2 = Gibbs(:m => PG(10), :s => HMC(0.1, 3; adtype=adbackend)) + alg3 = Gibbs(:m => PG(10), :s => HMCDA(100, 0.65, 0.3; adtype=adbackend)) @test sample(StableRNG(seed), gdemo_default, alg1, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg2, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg3, 10) isa Chains From 3faac5320d2bbbf274cc9c682e2a9071829e52e4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 17:52:42 +0000 Subject: [PATCH 07/12] Improve HISTORY.md note Co-authored-by: Penelope Yong --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 011d42414..64be69106 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -8,7 +8,7 @@ The new Gibbs sampler currently supports the same user-facing interface as the o `GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is breaking. -The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by mapping symbols, `VarName`s, or iterables of `VarName`s to samplers, e.g. `Gibbs(x=>HMC(), y=>MH())`, `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`, `Gibbs((:x, :y) => NUTS(), :z => MH())`. This allows more granular specification of which sampler to use for which variable. +The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by mapping symbols, `VarName`s, or iterables thereof to samplers, e.g. `Gibbs(x=>HMC(), y=>MH())`, `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`, `Gibbs((:x, :y) => NUTS(), :z => MH())`. This allows more granular specification of which sampler to use for which variable. 
Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))` has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced at this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`. From 7b308ceb18616c1cc270398945d2d90fe13bcea6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 17:56:20 +0000 Subject: [PATCH 08/12] Apply proposals from code review --- src/mcmc/gibbs.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 0567cd282..84c96125e 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -293,11 +293,11 @@ end set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0)) to_varname(vn::VarName) = vn -to_varname(s::Symbol) = VarName{s}() +to_varname(s::Symbol) = VarName(s) to_varname_list(x::Union{VarName,Symbol}) = [to_varname(x)] # Any other value is assumed to be an iterable of VarNames and Symbols. -to_varname_list(t) = map(to_varname, collect(t)) +to_varname_list(t) = collect(map(to_varname, t)) """ Gibbs @@ -357,7 +357,8 @@ end # The below two constructors only provide backwards compatibility with the constructor of # the old Gibbs sampler. They are deprecated and will be removed in the future. -function Gibbs(algs::InferenceAlgorithm...) +function Gibbs(alg1::InferenceAlgorithm, other_algs::InferenceAlgorithm...) + algs = [alg1, other_algs...] varnames = map(algs) do alg space = getspace(alg) if (space isa VarName) @@ -379,9 +380,6 @@ function Gibbs(algs::InferenceAlgorithm...) return Gibbs(varnames, map(set_selector ∘ drop_space, algs)) end -# This disambiguates a method ambiguity. To be removed when the above deprecated one is. -Gibbs() = Gibbs([], []) - function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) 
algs = Iterators.map(first, algs_with_iters) iters = Iterators.map(last, algs_with_iters) From eb37329e04d0cf26ede36f9c5b4306575aceb945 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 18:00:18 +0000 Subject: [PATCH 09/12] Add type bounds to Gibbs type parameters --- src/mcmc/gibbs.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 84c96125e..d72135874 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -325,7 +325,8 @@ Currently only variable names without indexing are supported, so for instance # Fields $(TYPEDFIELDS) """ -struct Gibbs{V,A} <: InferenceAlgorithm +struct Gibbs{V<:AbstractVector{<:AbstractVector{<:VarName}},A<:AbstractVector} <: + InferenceAlgorithm "varnames representing variables for each sampler" varnames::V "samplers for each entry in `varnames`" From 46ef2f990d325a3ecda1b87beeb61cbc13b8d803 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 18:02:49 +0000 Subject: [PATCH 10/12] Style improvements to gibbs.jl --- src/mcmc/gibbs.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index d72135874..068c24a00 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -292,12 +292,9 @@ function set_selector(x::RepeatSampler) end set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0)) -to_varname(vn::VarName) = vn -to_varname(s::Symbol) = VarName(s) - -to_varname_list(x::Union{VarName,Symbol}) = [to_varname(x)] +to_varname_list(x::Union{VarName,Symbol}) = [VarName(x)] # Any other value is assumed to be an iterable of VarNames and Symbols. -to_varname_list(t) = collect(map(to_varname, t)) +to_varname_list(t) = collect(map(VarName, t)) """ Gibbs From a983ad17a1b0d5917fa4e54cd70ff8aefb3546af Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 7 Jan 2025 18:09:17 +0000 Subject: [PATCH 11/12] Fix method ambiguity --- src/mcmc/gibbs.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 068c24a00..52971e08d 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -378,7 +378,11 @@ function Gibbs(alg1::InferenceAlgorithm, other_algs::InferenceAlgorithm...) return Gibbs(varnames, map(set_selector ∘ drop_space, algs)) end -function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) +function Gibbs( + alg_with_iters1::Tuple{<:InferenceAlgorithm,Int}, + other_algs_with_iters::Tuple{<:InferenceAlgorithm,Int}..., +) + algs_with_iters = [alg_with_iters1, other_algs_with_iters...] algs = Iterators.map(first, algs_with_iters) iters = Iterators.map(last, algs_with_iters) algs_duplicated = Iterators.flatten(( From 4f0b970febf569892a32dfb3d76407e1df8a6c6a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 10:18:52 +0000 Subject: [PATCH 12/12] Modify type signature of Gibbs --- src/mcmc/gibbs.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 52971e08d..ada5f611b 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -322,8 +322,10 @@ Currently only variable names without indexing are supported, so for instance # Fields $(TYPEDFIELDS) """ -struct Gibbs{V<:AbstractVector{<:AbstractVector{<:VarName}},A<:AbstractVector} <: +struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <: InferenceAlgorithm + # TODO(mhauru) Revisit whether A should have a fixed element type once + # InferenceAlgorithm/Sampler types have been cleaned up. 
"varnames representing variables for each sampler" varnames::V "samplers for each entry in `varnames`" @@ -343,9 +345,9 @@ struct Gibbs{V<:AbstractVector{<:AbstractVector{<:VarName}},A<:AbstractVector} < # Ensure that samplers have the same selector, and that varnames are lists of # VarNames. - samplers = map(set_selector ∘ drop_space, samplers) - varnames = map(to_varname_list, varnames) - return new{typeof(varnames),typeof(samplers)}(varnames, samplers) + samplers = tuple(map(set_selector ∘ drop_space, samplers)...) + varnames = tuple(map(to_varname_list, varnames)...) + return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers) end end