From 061acbe7fab5fac3e1af6fb43fd69938b9e8a5a4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 5 Mar 2025 10:34:04 +0000 Subject: [PATCH 1/6] Release 0.36 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a9463a821..a1bf65fd5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.35.0" +version = "0.36.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From fc323985ee744a38ce3013915150f0891d4af3ab Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 28 Mar 2025 17:01:39 +0000 Subject: [PATCH 2/6] AbstractPPL 0.11 + change prefixing behaviour (#830) * AbstractPPL 0.11; change prefixing behaviour * Use DynamicPPL.prefix rather than overloading --- HISTORY.md | 49 ++++++++++++++++ Project.toml | 2 +- docs/Project.toml | 1 + docs/src/api.md | 2 +- src/DynamicPPL.jl | 5 +- src/contexts.jl | 26 ++++----- src/debug_utils.jl | 2 +- src/model.jl | 58 ++++++------------- src/submodel_macro.jl | 16 +++--- src/utils.jl | 8 ++- test/Project.toml | 2 +- test/compiler.jl | 11 ++-- test/contexts.jl | 126 +++++++++++++++++++++++------------------- test/debug_utils.jl | 4 +- test/deprecated.jl | 2 +- test/model.jl | 6 +- 16 files changed, 180 insertions(+), 140 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 3ea8071f3..cd2757edc 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,54 @@ # DynamicPPL Changelog +## 0.36.0 + +**Breaking changes** + +### VarName prefixing behaviour + +The way in which VarNames in submodels are prefixed has been changed. +This is best explained through an example. +Consider this model and submodel: + +```julia +using DynamicPPL, Distributions +@model inner() = x ~ Normal() +@model outer() = a ~ to_submodel(inner()) +``` + +In previous versions, the inner variable `x` would be saved as `a.x`. +However, this was represented as a single symbol `Symbol("a.x")`: + +```julia +julia> dump(keys(VarInfo(outer()))[1]) +VarName{Symbol("a.x"), typeof(identity)} + optic: identity (function of type typeof(identity)) +``` + +Now, the inner variable is stored as a field `x` on the VarName `a`: + +```julia +julia> dump(keys(VarInfo(outer()))[1]) +VarName{:a, Accessors.PropertyLens{:x}} + optic: Accessors.PropertyLens{:x} (@o _.x) +``` + +In practice, this means that if you are trying to condition a variable in the submodel, you now need to use + +```julia +outer() | (@varname(a.x) => 1.0,) +``` + +instead of either of these (which would have worked previously) + +```julia +outer() | (@varname(var"a.x") => 1.0,) +outer() | (a.x=1.0,) +``` + +If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. +(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) + ## 0.35.5 Several internal methods have been removed: diff --git a/Project.toml b/Project.toml index d5185d727..516dee26e 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.10.1" +AbstractPPL = "0.11" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" diff --git a/docs/Project.toml b/docs/Project.toml index fa57f2c1c..40a719e03 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/docs/src/api.md b/docs/src/api.md index 9c8249c97..2f6376f5d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -149,7 +149,7 @@ In the past, one would instead embed sub-models using [`@submodel`](@ref), which In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing: ```@docs -prefix +DynamicPPL.prefix ``` Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 50fe0edc7..9f45718c5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -21,6 +21,9 @@ using DocStringExtensions using Random: Random +# For extending +import AbstractPPL: predict + # TODO: Remove these when it's possible. import Bijectors: link, invlink @@ -39,8 +42,6 @@ import Base: keys, haskey -import AbstractPPL: predict - # VarInfo export AbstractVarInfo, VarInfo, diff --git a/src/contexts.jl b/src/contexts.jl index a54c60374..58ac612b8 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -260,25 +260,21 @@ function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} return PrefixContext{Prefix}(child) end -const PREFIX_SEPARATOR = Symbol(".") - -@generated function PrefixContext{PrefixOuter}( - context::PrefixContext{PrefixInner} -) where {PrefixOuter,PrefixInner} - return :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - context.context - )) -end +""" + prefix(ctx::AbstractContext, vn::VarName) +Apply the prefixes in the context `ctx` to the variable name `vn`. +""" function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - vn_prefixed_inner = prefix(childcontext(ctx), vn) - return VarName{Symbol(Prefix, PREFIX_SEPARATOR, getsym(vn_prefixed_inner))}( - getoptic(vn_prefixed_inner) - ) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Symbol(Prefix)}()) +end +function prefix(ctx::AbstractContext, vn::VarName) + return prefix(NodeTrait(ctx), ctx, vn) end -prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn) prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn) +function prefix(::IsParent, ctx::AbstractContext, vn::VarName) + return prefix(childcontext(ctx), vn) +end """ prefix(model::Model, x) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 328fe6983..529092e8e 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -183,7 +183,7 @@ function DynamicPPL.setchildcontext(context::DebugContext, child) end function record_varname!(context::DebugContext, varname::VarName, dist) - prefixed_varname = prefix(context, varname) + prefixed_varname = DynamicPPL.prefix(context, varname) if haskey(context.varnames_seen, prefixed_varname) if context.error_on_failure error("varname $prefixed_varname used multiple times in model") diff --git a/src/model.jl b/src/model.jl index a0451b1b6..b4d5f6bb7 100644 --- a/src/model.jl +++ b/src/model.jl @@ -243,7 +243,7 @@ julia> model() ≠ 1.0 true julia> # To condition the variable inside `demo_inner` we need to refer to it as `inner.m`. - conditioned_model = model | (var"inner.m" = 1.0, ); + conditioned_model = model | (@varname(inner.m) => 1.0, ); julia> conditioned_model() 1.0 @@ -255,15 +255,6 @@ julia> conditioned_model_fail() ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported [...] ``` - -And similarly when using `Dict`: - -```jldoctest condition -julia> conditioned_model_dict = model | (@varname(var"inner.m") => 1.0); - -julia> conditioned_model_dict() -1.0 -``` """ function AbstractPPL.condition(model::Model, values...) # Positional arguments - need to handle cases carefully @@ -443,16 +434,16 @@ julia> conditioned(cm) julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: +1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0); + cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0); -julia> conditioned(cm).x +julia> conditioned(cm)[@varname(x)] 100.0 -julia> conditioned(cm).var"a.m" +julia> conditioned(cm)[@varname(a.m)] 1.0 julia> keys(VarInfo(cm)) # No variables are sampled @@ -583,7 +574,7 @@ julia> model = demo_outer(); julia> model() ≠ 1.0 true -julia> fixed_model = fix(model, var"inner.m" = 1.0, ); +julia> fixed_model = fix(model, (@varname(inner.m) => 1.0, )); julia> fixed_model() 1.0 @@ -599,24 +590,9 @@ julia> fixed_model() 2.0 ``` -And similarly when using `Dict`: - -```jldoctest fix -julia> fixed_model_dict = fix(model, @varname(var"inner.m") => 1.0); - -julia> fixed_model_dict() -1.0 - -julia> fixed_model_dict = fix(model, @varname(inner) => 2.0); - -julia> fixed_model_dict() -2.0 -``` - ## Difference from `condition` -A very similar functionality is also provided by [`condition`](@ref) which, -not surprisingly, _conditions_ variables instead of fixing them. The only +A very similar functionality is also provided by [`condition`](@ref). The only difference between fixing and conditioning is as follows: - `condition`ed variables are considered to be observations, and are thus included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref), @@ -798,16 +774,16 @@ julia> fixed(cm) julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: +1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. - cm = fix(contextualize(m, PrefixContext{:a}(fix(var"a.m"=1.0))), x=100.0); + cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0); -julia> fixed(cm).x +julia> fixed(cm)[@varname(x)] 100.0 -julia> fixed(cm).var"a.m" +julia> fixed(cm)[@varname(a.m)] 1.0 julia> keys(VarInfo(cm)) # <= no variables are sampled @@ -1365,7 +1341,7 @@ When we sample from the model `demo2(missing, 0.4)` random variable `x` will be ```jldoctest submodel-to_submodel julia> vi = VarInfo(demo2(missing, 0.4)); -julia> @varname(var\"a.x\") in keys(vi) +julia> @varname(a.x) in keys(vi) true ``` @@ -1379,7 +1355,7 @@ false We can check that the log joint probability of the model accumulated in `vi` is correct: ```jldoctest submodel-to_submodel -julia> x = vi[@varname(var\"a.x\")]; +julia> x = vi[@varname(a.x)]; julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true @@ -1417,10 +1393,10 @@ julia> @model function demo2(x, y, z) julia> vi = VarInfo(demo2(missing, missing, 0.4)); -julia> @varname(var"sub1.x") in keys(vi) +julia> @varname(sub1.x) in keys(vi) true -julia> @varname(var"sub2.x") in keys(vi) +julia> @varname(sub2.x) in keys(vi) true ``` @@ -1437,9 +1413,9 @@ false We can check that the log joint probability of the model accumulated in `vi` is correct: ```jldoctest submodel-to_submodel-prefix -julia> sub1_x = vi[@varname(var"sub1.x")]; +julia> sub1_x = vi[@varname(sub1.x)]; -julia> sub2_x = vi[@varname(var"sub2.x")]; +julia> sub2_x = vi[@varname(sub2.x)]; julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index e5a8e0617..f6b9c4479 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -96,10 +96,10 @@ julia> vi = VarInfo(demo2(missing, missing, 0.4)); │ caller = ip:0x0 └ @ Core :-1 -julia> @varname(var"sub1.x") in keys(vi) +julia> @varname(sub1.x) in keys(vi) true -julia> @varname(var"sub2.x") in keys(vi) +julia> @varname(sub2.x) in keys(vi) true ``` @@ -116,9 +116,9 @@ false We can check that the log joint probability of the model accumulated in `vi` is correct: ```jldoctest submodelprefix -julia> sub1_x = vi[@varname(var"sub1.x")]; +julia> sub1_x = vi[@varname(sub1.x)]; -julia> sub2_x = vi[@varname(var"sub2.x")]; +julia> sub2_x = vi[@varname(sub2.x)]; julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); @@ -157,7 +157,7 @@ julia> # Automatically determined from `a`. @model submodel_prefix_true() = @submodel prefix=true a = inner() submodel_prefix_true (generic function with 2 methods) -julia> @varname(var"a.x") in keys(VarInfo(submodel_prefix_true())) +julia> @varname(a.x) in keys(VarInfo(submodel_prefix_true())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -167,7 +167,7 @@ julia> # Using a static string. @model submodel_prefix_string() = @submodel prefix="my prefix" a = inner() submodel_prefix_string (generic function with 2 methods) -julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string())) +julia> @varname(var"my prefix".x) in keys(VarInfo(submodel_prefix_string())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -177,7 +177,7 @@ julia> # Using string interpolation. @model submodel_prefix_interpolation() = @submodel prefix="\$(nameof(inner()))" a = inner() submodel_prefix_interpolation (generic function with 2 methods) -julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation())) +julia> @varname(inner.x) in keys(VarInfo(submodel_prefix_interpolation())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -187,7 +187,7 @@ julia> # Or using some arbitrary expression. @model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner() submodel_prefix_expr (generic function with 2 methods) -julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr())) +julia> @varname(var"3".x) in keys(VarInfo(submodel_prefix_expr())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 diff --git a/src/utils.jl b/src/utils.jl index 50f9baf61..56c3d70af 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1285,14 +1285,18 @@ broadcast_safe(x) = x broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) +# Convert (x=1,) to Dict(@varname(x) => 1) +_nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) # Version of `merge` used by `conditioned` and `fixed` to handle # the scenario where we might try to merge a dict with an empty # tuple. # TODO: Maybe replace the default of returning `NamedTuple` with `nothing`? _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) -_merge(left::AbstractDict, right::NamedTuple{()}) = left -_merge(left::NamedTuple{()}, right::AbstractDict) = right +_merge(left::AbstractDict, ::NamedTuple{()}) = left +_merge(left::AbstractDict, right::NamedTuple) = merge(left, _nt_to_varname_dict(right)) +_merge(::NamedTuple{()}, right::AbstractDict) = right +_merge(left::NamedTuple, right::AbstractDict) = merge(_nt_to_varname_dict(left), right) """ unique_syms(vns::T) where {T<:NTuple{N,VarName}} diff --git a/test/Project.toml b/test/Project.toml index 9fa3fd872..79e6d129b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -32,7 +32,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.10.1" +AbstractPPL = "0.11" Accessors = "0.1" Aqua = "0.8" Bijectors = "0.15.1" diff --git a/test/compiler.jl b/test/compiler.jl index 3d3c6d9e3..a0286d405 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -481,8 +481,8 @@ module Issue537 end m = demo_useval(missing, missing) vi = VarInfo(m) ks = keys(vi) - @test VarName{Symbol("sub1.x")}() ∈ ks - @test VarName{Symbol("sub2.x")}() ∈ ks + @test @varname(sub1.x) ∈ ks + @test @varname(sub2.x) ∈ ks @test @varname(z) ∈ ks @test abs(mean([VarInfo(m)[@varname(z)] for i in 1:10]) - 100) ≤ 10 @@ -505,7 +505,7 @@ module Issue537 end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - x ~ to_submodel(prefix(AR1(num_steps, α, μ, σ), "ar1_$i"), false) + x ~ to_submodel(DynamicPPL.prefix(AR1(num_steps, α, μ, σ), "ar1_$i"), false) y[i] ~ MvNormal(x, 0.01 * I) end end @@ -514,8 +514,9 @@ module Issue537 end m = demo(ys) vi = VarInfo(m) - for k in [:α, :μ, :σ, Symbol("ar1_1.η"), Symbol("ar1_2.η")] - @test VarName{k}() ∈ keys(vi) + for vn in + [@varname(α), @varname(μ), @varname(σ), @varname(ar1_1.η), @varname(ar1_2.η)] + @test vn ∈ keys(vi) end end diff --git a/test/contexts.jl b/test/contexts.jl index faa831cc1..11e591f8f 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -39,44 +39,39 @@ end Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() -""" - remove_prefix(vn::VarName) - -Return `vn` but now with the prefix removed. -""" -function remove_prefix(vn::VarName) - return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( - getoptic(vn) +@testset "contexts.jl" begin + child_contexts = Dict( + :default => DefaultContext(), + :prior => PriorContext(), + :likelihood => LikelihoodContext(), ) -end -@testset "contexts.jl" begin - child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] - - parent_contexts = [ - DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - SamplingContext(), - MiniBatchContext(DefaultContext(), 0.0), - PrefixContext{:x}(DefaultContext()), - PointwiseLogdensityContext(), - ConditionContext((x=1.0,)), - ConditionContext( + parent_contexts = Dict( + :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), + :sampling => SamplingContext(), + :minibatch => MiniBatchContext(DefaultContext(), 0.0), + :prefix => PrefixContext{:x}(DefaultContext()), + :pointwiselogdensity => PointwiseLogdensityContext(), + :condition1 => ConditionContext((x=1.0,)), + :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), - ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), - ConditionContext((x=[1.0, missing],)), - ] + :condition3 => ConditionContext( + (x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(a.y) => 2.0))) + ), + :condition4 => ConditionContext((x=[1.0, missing],)), + ) - contexts = vcat(child_contexts, parent_contexts) + contexts = merge(child_contexts, parent_contexts) - @testset "$(context)" for context in contexts + @testset "$(name)" for (name, context) in contexts @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS DynamicPPL.TestUtils.test_context(context, model) end end @testset "contextual_isassumption" begin - @testset "$context" for context in contexts + @testset "$(name)" for (name, context) in contexts # Any `context` should return `true` by default. @test contextual_isassumption(context, VarName{gensym(:x)}()) @@ -85,14 +80,28 @@ end # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) - for (sym, val) in pairs(conditioned_values) - vn = VarName{sym}() + # The conditioned values might be a NamedTuple, or a Dict. + # We convert to a Dict for consistency + if conditioned_values isa NamedTuple + conditioned_values = Dict( + VarName{sym}() => val for (sym, val) in pairs(conditioned_values) + ) + end + for (vn, val) in pairs(conditioned_values) # We need to drop the prefix of `var` since in `contextual_isassumption` # it will be threaded through the `PrefixContext` before it reaches # `ConditionContext` with the conditioned variable. - vn_without_prefix = remove_prefix(vn) + vn_without_prefix = if getoptic(vn) isa PropertyLens + # Hacky: This assumes that there is exactly one level of prefixing + # that we need to undo. This is appropriate for the :condition3 + # test case above, but is not generally correct. + AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + else + vn + end + @show DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) # Let's check elementwise. for vn_child in DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) @@ -108,7 +117,7 @@ end end @testset "getconditioned_nested & hasconditioned_nested" begin - @testset "$context" for context in contexts + @testset "$name" for (name, context) in contexts fake_vn = VarName{gensym(:x)}() @test !hasconditioned_nested(context, fake_vn) @test_throws ErrorException getconditioned_nested(context, fake_vn) @@ -118,14 +127,26 @@ end # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) + # The conditioned values might be a NamedTuple, or a Dict. + # We convert to a Dict for consistency + if conditioned_values isa NamedTuple + conditioned_values = Dict( + VarName{sym}() => val for (sym, val) in pairs(conditioned_values) + ) + end - for (sym, val) in pairs(conditioned_values) - vn = VarName{sym}() - + for (vn, val) in pairs(conditioned_values) # We need to drop the prefix of `var` since in `contextual_isassumption` # it will be threaded through the `PrefixContext` before it reaches # `ConditionContext` with the conditioned variable. - vn_without_prefix = remove_prefix(vn) + vn_without_prefix = if getoptic(vn) isa PropertyLens + # Hacky: This assumes that there is exactly one level of prefixing + # that we need to undo. This is appropriate for the :condition3 + # test case above, but is not generally correct. + AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + else + vn + end for vn_child in DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) @@ -153,51 +174,42 @@ end ) vn = VarName{:x}() vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) + @test vn_prefixed == @varname(a.b.c.d.e.f.x) vn = VarName{:x}(((1,),)) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) + @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) end @testset "nested within arbitrary context stacks" begin vn = @varname(x[1]) ctx1 = PrefixContext{:a}(DefaultContext()) + @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) ctx2 = SamplingContext(ctx1) + @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext{:b}(ctx2) + @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) - vn_prefixed1 = prefix(ctx1, vn) - vn_prefixed2 = prefix(ctx2, vn) - vn_prefixed3 = prefix(ctx3, vn) - vn_prefixed4 = prefix(ctx4, vn) - @test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x") - @test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x") - @test DynamicPPL.getsym(vn_prefixed3) == Symbol("b.a.x") - @test DynamicPPL.getsym(vn_prefixed4) == Symbol("b.a.x") - @test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed4) === DynamicPPL.getoptic(vn) + @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end - context = DynamicPPL.PrefixContext{:prefix}(SamplingContext()) @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + prefix = :my_prefix + context = DynamicPPL.PrefixContext{prefix}(SamplingContext()) # Sample with the context. varinfo = DynamicPPL.VarInfo() DynamicPPL.evaluate!!(model, varinfo, context) - # Extract the resulting symbols. - vns_varinfo_syms = Set(map(DynamicPPL.getsym, keys(varinfo))) + # Extract the resulting varnames + vns_actual = Set(keys(varinfo)) - # Extract the ground truth symbols. - vns_syms = Set([ - Symbol("prefix", DynamicPPL.PREFIX_SEPARATOR, DynamicPPL.getsym(vn)) for + # Extract the ground truth varnames + vns_expected = Set([ + AbstractPPL.prefix(vn, VarName{prefix}()) for vn in DynamicPPL.TestUtils.varnames(model) ]) # Check that all variables are prefixed correctly. - @test vns_syms == vns_varinfo_syms + @test vns_actual == vns_expected end end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index d4f6601f5..cac52693e 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -63,8 +63,8 @@ # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 @model function ModelOuterWorking2() - x1 ~ to_submodel(prefix(ModelInner(), :a), false) - x2 ~ to_submodel(prefix(ModelInner(), :b), false) + x1 ~ to_submodel(DynamicPPL.prefix(ModelInner(), :a), false) + x2 ~ to_submodel(DynamicPPL.prefix(ModelInner(), :b), false) return (x1, x2) end model = ModelOuterWorking2() diff --git a/test/deprecated.jl b/test/deprecated.jl index f12217983..500d3eb7f 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -31,7 +31,7 @@ @test outer()() isa Tuple{Float64,Float64} vi = VarInfo(outer()) @test @varname(x) in keys(vi) - @test @varname(var"sub.x") in keys(vi) + @test @varname(sub.x) in keys(vi) end @testset "logp is still accumulated properly" begin diff --git a/test/model.jl b/test/model.jl index a863b6596..447a9ecaa 100644 --- a/test/model.jl +++ b/test/model.jl @@ -448,15 +448,15 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() return nothing end @model function outer_manual_prefix() - a ~ to_submodel(prefix(inner(), :a), false) - b ~ to_submodel(prefix(inner(), :b), false) + a ~ to_submodel(DynamicPPL.prefix(inner(), :a), false) + b ~ to_submodel(DynamicPPL.prefix(inner(), :b), false) return nothing end for model in (outer_auto_prefix(), outer_manual_prefix()) vi = VarInfo(model) vns = Set(keys(values_as_in_model(model, false, vi))) - @test vns == Set([@varname(var"a.x"), @varname(var"b.x")]) + @test vns == Set([@varname(a.x), @varname(b.x)]) end end end From cc5e581ee00424e88dfb25873ce54e69f8f65d8c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 28 Mar 2025 17:04:32 +0000 Subject: [PATCH 3/6] Remove VarInfo(VarInfo, params) (#870) --- HISTORY.md | 4 ++++ src/varinfo.jl | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index cd2757edc..a956bd188 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,10 @@ **Breaking changes** +### VarInfo constructor + +`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. + ### VarName prefixing behaviour The way in which VarNames in submodels are prefixed has been changed. diff --git a/src/varinfo.jl b/src/varinfo.jl index 0c033e504..94b1f1c07 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -100,8 +100,6 @@ const TypedVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } -# TODO: Remove this -@deprecate VarInfo(vi::VarInfo, x::AbstractVector) unflatten(vi, x) # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` From b9c368b500ed1e5904f2229e915d3cefddd45171 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Apr 2025 11:35:53 +0100 Subject: [PATCH 4/6] Unify `{untyped,typed}_{vector_,}varinfo` constructor functions (#879) * Unify {Untyped,Typed}{Vector,}VarInfo constructors * Update invocations * NTVarInfo * Fix tests * More fixes * Fixes * Fixes * Fixes * Use lowercase functions, don't deprecate VarInfo * Rewrite VarInfo docstring * Fix methods * Fix methods (really) --- HISTORY.md | 31 +- benchmarks/src/DynamicPPLBenchmarks.jl | 10 +- docs/src/api.md | 13 +- docs/src/internals/varinfo.md | 4 +- src/DynamicPPL.jl | 2 - src/abstract_varinfo.jl | 8 +- src/sampler.jl | 2 +- src/simple_varinfo.jl | 4 +- src/test_utils/contexts.jl | 2 +- src/test_utils/varinfo.jl | 10 +- src/varinfo.jl | 510 ++++++++++++++++--------- test/ext/DynamicPPLJETExt.jl | 8 +- test/model.jl | 18 +- test/simple_varinfo.jl | 2 +- test/test_util.jl | 18 +- test/varinfo.jl | 49 +-- 16 files changed, 436 insertions(+), 255 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index a956bd188..1af5c2ca3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,10 +4,25 @@ **Breaking changes** -### VarInfo constructor +### VarInfo constructors `VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. +The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. +If you were not using this argument (most likely), then there is no change needed. +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). + +The `UntypedVarInfo` constructor and type is no longer exported. +If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. + +The `TypedVarInfo` constructor and type is no longer exported. +The _type_ has been replaced with `DynamicPPL.NTVarInfo`. +The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. + +Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. +Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. +Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. + ### VarName prefixing behaviour The way in which VarNames in submodels are prefixed has been changed. @@ -53,6 +68,20 @@ outer() | (a.x=1.0,) If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. (This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) +**Other changes** + +While these are technically breaking, they are only internal changes and do not affect the public API. +The following four functions have been added and/or reworked to make it easier to construct VarInfos with different types of metadata: + + 1. `DynamicPPL.untyped_varinfo([rng, ]model[, sampler, context])` + 2. `DynamicPPL.typed_varinfo([rng, ]model[, sampler, context])` + 3. `DynamicPPL.untyped_vector_varinfo([rng, ]model[, sampler, context])` + 4. `DynamicPPL.typed_vector_varinfo([rng, ]model[, sampler, context])` + +The reason for this change is that there were several flavours of VarInfo. +Some, like `typed_varinfo`, were easy to construct because we had convenience methods for them; however, the others were more difficult. +This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing. + ## 0.35.5 Several internal methods have been removed: diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 4c73bf355..16338de2f 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -52,8 +52,8 @@ end Create a benchmark suite for `model` using the selected varinfo type and AD backend. Available varinfo choices: - • `:untyped` → uses `VarInfo()` - • `:typed` → uses `VarInfo(model)` + • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` + • `:typed` → uses `DynamicPPL.typed_varinfo(model)` • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) @@ -67,11 +67,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: suite = BenchmarkGroup() vi = if varinfo_choice == :untyped - vi = VarInfo() - model(rng, vi) - vi + DynamicPPL.untyped_varinfo(rng, model) elseif varinfo_choice == :typed - VarInfo(rng, model) + DynamicPPL.typed_varinfo(rng, model) elseif varinfo_choice == :simple_namedtuple SimpleVarInfo{Float64}(model(rng)) elseif varinfo_choice == :simple_dict diff --git a/docs/src/api.md b/docs/src/api.md index 2f6376f5d..f83a96886 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -291,18 +291,17 @@ AbstractVarInfo But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. -For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods: +#### `VarInfo` ```@docs -DynamicPPL.untyped_varinfo -DynamicPPL.typed_varinfo +VarInfo ``` -#### `VarInfo` - ```@docs -VarInfo -TypedVarInfo +DynamicPPL.untyped_varinfo +DynamicPPL.typed_varinfo +DynamicPPL.untyped_vector_varinfo +DynamicPPL.typed_vector_varinfo ``` One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/). diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index e6e1f2619..b04913aaf 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -227,13 +227,13 @@ Continuing from the example from the previous section, we can use a `VarInfo` wi ```@example varinfo-design # Type-unstable -varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped) +varinfo_untyped_vnv = DynamicPPL.untyped_vector_varinfo(varinfo_untyped) varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] ``` ```@example varinfo-design # Type-stable -varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed) +varinfo_typed_vnv = DynamicPPL.typed_vector_varinfo(varinfo_typed) varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9f45718c5..51fa53079 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -45,8 +45,6 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, - UntypedVarInfo, - TypedVarInfo, SimpleVarInfo, push!!, empty!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 44edaa4e9..f11b8a3ec 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -247,11 +247,11 @@ julia> values_as(SimpleVarInfo(data), Vector) 2.0 ``` -`TypedVarInfo`: +`VarInfo` with `NamedTuple` of `Metadata`: ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; @@ -273,11 +273,11 @@ julia> values_as(vi, Vector) 2.0 ``` -`UntypedVarInfo`: +`VarInfo` with `Metadata`: ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); + vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; diff --git a/src/sampler.jl b/src/sampler.jl index ff008cc93..49d910fec 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -86,7 +86,7 @@ function default_varinfo( context::AbstractContext, ) init_sampler = initialsampler(sampler) - return VarInfo(rng, model, init_sampler, context) + return typed_varinfo(rng, model, init_sampler, context) end function AbstractMCMC.sample( diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 064483ddd..abf14b8fc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -10,7 +10,7 @@ Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. $(FIELDS) # Notes -The major differences between this and `TypedVarInfo` are: +The major differences between this and `NTVarInfo` are: 1. `SimpleVarInfo` does not require linearization. 2. `SimpleVarInfo` can use more efficient bijectors. 3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either @@ -244,7 +244,7 @@ function SimpleVarInfo{T}( end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} +function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D} return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) end function SimpleVarInfo{T}( diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 5150be64b..7404a9af7 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -94,7 +94,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) # Typed varinfo. - varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped) + varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) end diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 6a655ded4..539872143 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -27,12 +27,10 @@ function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) # VarInfo - vi_untyped_metadata = VarInfo(DynamicPPL.Metadata()) - vi_untyped_vnv = VarInfo(DynamicPPL.VarNamedVector()) - model(vi_untyped_metadata) - model(vi_untyped_vnv) - vi_typed_metadata = DynamicPPL.TypedVarInfo(vi_untyped_metadata) - vi_typed_vnv = DynamicPPL.TypedVarInfo(vi_untyped_vnv) + vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) + vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) + vi_typed_metadata = DynamicPPL.typed_varinfo(model) + vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) diff --git a/src/varinfo.jl b/src/varinfo.jl index 94b1f1c07..360857ef7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -69,34 +69,91 @@ end ########### """ -``` -struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo - metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} -end -``` + struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo + metadata::Tmeta + logp::Base.RefValue{Tlogp} + num_produce::Base.RefValue{Int} + end + +A light wrapper over some kind of metadata. -A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of -`VarInfo`. If `vi isa VarInfo{<:Metadata}`, then only one `Metadata` instance is used -for all the sybmols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If -`vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each -symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows -for the type specialization of `vi` after the first sampling iteration when all the -symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`. +The type of the metadata can be one of a number of options. It may either be a +`Metadata` or a `VarNamedVector`, _or_, it may be a `NamedTuple` which maps +symbols to `Metadata` or `VarNamedVector` instances. Here, a _symbol_ refers +to a Julia variable and may consist of one or more `VarName`s which appear on +the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both +have the same symbol `x`. -Note: It is the user's responsibility to ensure that each "symbol" is visited at least -once whenever the model is called, regardless of any stochastic branching. Each symbol -refers to a Julia variable and can be a hierarchical array of many random variables, e.g. `x[1] ~ ...` and `x[2] ~ ...` both have the same symbol `x`. +Several type aliases are provided for these forms of VarInfos: +- `VarInfo{<:Metadata}` is `UntypedVarInfo` +- `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo` +- `VarInfo{<:NamedTuple}` is `NTVarInfo` + +The NamedTuple form, i.e. `NTVarInfo`, is useful for maintaining type stability +of model evaluation. However, the element type of NamedTuples are not contained +in its type itself: thus, there is no way to use the type system to determine +whether the elements of the NamedTuple are `Metadata` or `VarNamedVector`. + +Note that for NTVarInfo, it is the user's responsibility to ensure that each +symbol is visited at least once during model evaluation, regardless of any +stochastic branching. """ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo metadata::Tmeta logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} end -const VectorVarInfo = VarInfo{<:VarNamedVector} +VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) +""" + VarInfo([rng, ]model[, sampler, context]) + +Generate a `VarInfo` object for the given `model`, by evaluating it once using +the given `rng`, `sampler`, and `context`. + +!!! warning + + This function currently returns a `VarInfo` with its metadata field set to + a `NamedTuple` of `Metadata`. This is an implementation detail. In general, + this function may return any kind of object that satisfies the + `AbstractVarInfo` interface. If you require precise control over the type + of `VarInfo` returned, use the internal functions `untyped_varinfo`, + `typed_varinfo`, `untyped_vector_varinfo`, or `typed_vector_varinfo` + instead. +""" +function VarInfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return typed_varinfo(rng, model, sampler, context) +end +function VarInfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return VarInfo(Random.default_rng(), model, sampler, context) +end +function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return VarInfo(rng, model, SampleFromPrior(), context) +end +function VarInfo(model::Model, context::AbstractContext) + # No sampler, no rng + return VarInfo(Random.default_rng(), model, SampleFromPrior(), context) +end + +const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} -const TypedVarInfo = VarInfo{<:NamedTuple} +# TODO: NTVarInfo carries no information about the type of the actual metadata +# i.e. the elements of the NamedTuple. It could be Metadata or it could be +# VarNamedVector. +# Resolving this ambiguity would likely require us to replace NamedTuple with +# something which carried both its keys as well as its values' types as type +# parameters. +const NTVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } @@ -132,70 +189,245 @@ function metadata_to_varnamedvector(md::Metadata) ) end -function VectorVarInfo(vi::UntypedVarInfo) - md = metadata_to_varnamedvector(vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end - -function VectorVarInfo(vi::TypedVarInfo) - md = map(metadata_to_varnamedvector, vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end - function has_varnamedvector(vi::VarInfo) return vi.metadata isa VarNamedVector || - (vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) + (vi isa NTVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) end +######################## +# VarInfo constructors # +######################## + """ - untyped_varinfo(model[, context, metadata]) + untyped_varinfo([rng, ]model[, sampler, context, metadata]) -Return an untyped varinfo object for the given `model` and `context`. +Return a VarInfo object for the given `model` and `context`, which has just a +single `Metadata` as its metadata field. # Arguments -- `model::Model`: The model for which to create the varinfo object. -- `context::AbstractContext`: The context in which to evaluate the model. Default: `SamplingContext()`. -- `metadata::Union{Metadata,VarNamedVector}`: The metadata to use for the varinfo object. - Default: `Metadata()`. +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ +function untyped_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + varinfo = VarInfo(Metadata()) + context = SamplingContext(rng, sampler, context) + return last(evaluate!!(model, varinfo, context)) +end function untyped_varinfo( model::Model, - context::AbstractContext=SamplingContext(), - metadata::Union{Metadata,VarNamedVector}=Metadata(), + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) - varinfo = VarInfo(metadata) - return last( - evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context)) + # No rng + return untyped_varinfo(Random.default_rng(), model, sampler, context) +end +function untyped_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return untyped_varinfo(rng, model, SampleFromPrior(), context) +end +function untyped_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return untyped_varinfo(model, SampleFromPrior(), context) +end + +""" + typed_varinfo(vi::UntypedVarInfo) + +This function finds all the unique `sym`s from the instances of `VarName{sym}` found in +`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the +global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as +a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each +symbol. +""" +function typed_varinfo(vi::UntypedVarInfo) + meta = vi.metadata + new_metas = Metadata[] + # Symbols of all instances of `VarName{sym}` in `vi.vns` + syms_tuple = Tuple(syms(vi)) + for s in syms_tuple + # Find all indices in `vns` with symbol `s` + inds = findall(vn -> getsym(vn) === s, meta.vns) + n = length(inds) + # New `vns` + sym_vns = getindex.((meta.vns,), inds) + # New idcs + sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) + # New dists + sym_dists = getindex.((meta.dists,), inds) + # New orders + sym_orders = getindex.((meta.orders,), inds) + # New flags + sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) + + # Extract new ranges and vals + _ranges = getindex.((meta.ranges,), inds) + # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 + _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] + sym_ranges = Vector{eltype(_ranges)}(undef, n) + start = 0 + for i in 1:n + sym_ranges[i] = (start + 1):(start + length(_vals[i])) + start += length(_vals[i]) + end + sym_vals = foldl(vcat, _vals) + + push!( + new_metas, + Metadata( + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags + ), + ) + end + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple{syms_tuple}(Tuple(new_metas)) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end +function typed_varinfo(vi::NTVarInfo) + # This function preserves the behaviour of typed_varinfo(vi) where vi is + # already a NTVarInfo + has_varnamedvector(vi) && error( + "Cannot convert VarInfo with NamedTuple of VarNamedVector to VarInfo with NamedTuple of Metadata", ) + return vi +end +""" + typed_varinfo([rng, ]model[, sampler, context, metadata]) + +Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of +`Metadata` structs as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function typed_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return typed_varinfo(untyped_varinfo(rng, model, sampler, context)) +end +function typed_varinfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return typed_varinfo(Random.default_rng(), model, sampler, context) +end +function typed_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return typed_varinfo(rng, model, SampleFromPrior(), context) +end +function typed_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return typed_varinfo(model, SampleFromPrior(), context) end """ - typed_varinfo(model[, context, metadata]) + untyped_vector_varinfo([rng, ]model[, sampler, context, metadata]) -Return a typed varinfo object for the given `model`, `sampler` and `context`. +Return a VarInfo object for the given `model` and `context`, which has just a +single `VarNamedVector` as its metadata field. -This simply calls [`DynamicPPL.untyped_varinfo`](@ref) and converts the resulting -varinfo object to a typed varinfo object. +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function untyped_vector_varinfo(vi::UntypedVarInfo) + md = metadata_to_varnamedvector(vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end +function untyped_vector_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler, context)) +end +function untyped_vector_varinfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return untyped_vector_varinfo(Random.default_rng(), model, sampler, context) +end +function untyped_vector_varinfo( + rng::Random.AbstractRNG, model::Model, context::AbstractContext +) + # No sampler + return untyped_vector_varinfo(rng, model, SampleFromPrior(), context) +end +function untyped_vector_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return untyped_vector_varinfo(model, SampleFromPrior(), context) +end -See also: [`DynamicPPL.untyped_varinfo`](@ref) """ -typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...)) + typed_vector_varinfo([rng, ]model[, sampler, context, metadata]) -function VarInfo( +Return a VarInfo object for the given `model` and `context`, which has a +NamedTuple of `VarNamedVector`s as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function typed_vector_varinfo(vi::NTVarInfo) + md = map(metadata_to_varnamedvector, vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end +function typed_vector_varinfo(vi::UntypedVectorVarInfo) + new_metas = group_by_symbol(vi.metadata) + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple(new_metas) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end +function typed_vector_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), - metadata::Union{Metadata,VarNamedVector}=Metadata(), ) - return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler, context)) end -function VarInfo( - model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}... +function typed_vector_varinfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) - return VarInfo(Random.default_rng(), model, args...) + # No rng + return typed_vector_varinfo(Random.default_rng(), model, sampler, context) +end +function typed_vector_varinfo( + rng::Random.AbstractRNG, model::Model, context::AbstractContext +) + # No sampler + return typed_vector_varinfo(rng, model, SampleFromPrior(), context) +end +function typed_vector_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return typed_vector_varinfo(model, SampleFromPrior(), context) end """ @@ -204,7 +436,7 @@ end Return the length of the vector representation of `varinfo`. """ vector_length(varinfo::VarInfo) = length(varinfo.metadata) -vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) +vector_length(varinfo::NTVarInfo) = sum(length, varinfo.metadata) vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) @@ -241,11 +473,6 @@ end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) -# without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - return VarInfo(rng, model, SampleFromPrior(), context) -end - #### #### Internal functions #### @@ -500,7 +727,7 @@ setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val Return the metadata in `vi` that belongs to `vn`. """ getmetadata(vi::VarInfo, vn::VarName) = vi.metadata -getmetadata(vi::TypedVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) +getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) """ getidx(vi::VarInfo, vn::VarName) @@ -541,7 +768,7 @@ end Return the range corresponding to `varname` in the vector representation of `varinfo`. """ vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn) -function vector_getrange(vi::TypedVarInfo, vn::VarName) +function vector_getrange(vi::NTVarInfo, vn::VarName) offset = 0 for md in values(vi.metadata) # First, we need to check if `vn` is in `md`. @@ -563,8 +790,8 @@ Return the range corresponding to `varname` in the vector representation of `var function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName}) return map(Base.Fix1(vector_getrange, varinfo), varname) end -# Specialized version for `TypedVarInfo`. -function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) +# Specialized version for `NTVarInfo`. +function vector_getranges(varinfo::NTVarInfo, vns::Vector{<:VarName}) # TODO: Does it help if we _don't_ convert to a vector here? metadatas = collect(values(varinfo.metadata)) # Extract the offsets. @@ -624,7 +851,7 @@ end getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon()) # NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. # See for example https://github.com/JuliaLang/julia/pull/46381. -function getindex_internal(vi::TypedVarInfo, ::Colon) +function getindex_internal(vi::NTVarInfo, ::Colon) return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) end function getindex_internal(md::Metadata, ::Colon) @@ -684,10 +911,10 @@ settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) Returns a tuple of the unique symbols of random variables in `vi`. """ syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols -syms(vi::TypedVarInfo) = keys(vi.metadata) +syms(vi::NTVarInfo) = keys(vi.metadata) _getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) -_getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) +_getidcs(vi::NTVarInfo) = _getidcs(vi.metadata) @generated function _getidcs(metadata::NamedTuple{names}) where {names} exprs = [] @@ -702,12 +929,11 @@ end findinds(vnv::VarNamedVector) = 1:length(vnv.varnames) """ - all_varnames_grouped_by_symbol(vi::TypedVarInfo) + all_varnames_grouped_by_symbol(vi::NTVarInfo) Return a `NamedTuple` of the variables in `vi` grouped by symbol. """ -all_varnames_grouped_by_symbol(vi::TypedVarInfo) = - all_varnames_grouped_by_symbol(vi.metadata) +all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(vi.metadata) @generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names} expr = Expr(:tuple) @@ -745,73 +971,6 @@ end #### APIs for typed and untyped VarInfo #### -# VarInfo - -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) - -function TypedVarInfo(vi::VectorVarInfo) - new_metas = group_by_symbol(vi.metadata) - logp = getlogp(vi) - num_produce = get_num_produce(vi) - nt = NamedTuple(new_metas) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end - -""" - TypedVarInfo(vi::UntypedVarInfo) - -This function finds all the unique `sym`s from the instances of `VarName{sym}` found in -`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the -global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as -a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each -symbol. -""" -function TypedVarInfo(vi::UntypedVarInfo) - meta = vi.metadata - new_metas = Metadata[] - # Symbols of all instances of `VarName{sym}` in `vi.vns` - syms_tuple = Tuple(syms(vi)) - for s in syms_tuple - # Find all indices in `vns` with symbol `s` - inds = findall(vn -> getsym(vn) === s, meta.vns) - n = length(inds) - # New `vns` - sym_vns = getindex.((meta.vns,), inds) - # New idcs - sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) - # New dists - sym_dists = getindex.((meta.dists,), inds) - # New orders - sym_orders = getindex.((meta.orders,), inds) - # New flags - sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) - - # Extract new ranges and vals - _ranges = getindex.((meta.ranges,), inds) - # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 - _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] - sym_ranges = Vector{eltype(_ranges)}(undef, n) - start = 0 - for i in 1:n - sym_ranges[i] = (start + 1):(start + length(_vals[i])) - start += length(_vals[i]) - end - sym_vals = foldl(vcat, _vals) - - push!( - new_metas, - Metadata( - sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags - ), - ) - end - logp = getlogp(vi) - num_produce = get_num_produce(vi) - nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end -TypedVarInfo(vi::TypedVarInfo) = vi - function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) resetlogp!!(vi) @@ -834,8 +993,8 @@ Base.keys(vi::VarInfo) = Base.keys(vi.metadata) # HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly # on other methods in the codebase which requires `Vector{<:VarName}`. -Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] -@generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names} +Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[] +@generated function Base.keys(vi::NTVarInfo{<:NamedTuple{names}}) where {names} expr = Expr(:call) push!(expr.args, :vcat) @@ -898,7 +1057,7 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end -function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) @@ -952,13 +1111,13 @@ function _link!(vi::UntypedVarInfo, vns) end end -# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _link!(vi::TypedVarInfo, vns::VarNameTuple) +# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _link!(vi::NTVarInfo, vns::VarNameTuple) return _link!(vi, group_varnames_by_symbol(vns)) end -function _link!(vi::TypedVarInfo, vns::NamedTuple) +function _link!(vi::NTVarInfo, vns::NamedTuple) return _link!(vi.metadata, vi, vns) end @@ -1002,7 +1161,7 @@ end return expr end -function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1064,13 +1223,13 @@ function _invlink!(vi::UntypedVarInfo, vns) end end -# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) +# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _invlink!(vi::NTVarInfo, vns::VarNameTuple) return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::TypedVarInfo, vns::NamedTuple) +function _invlink!(vi::NTVarInfo, vns::NamedTuple) return _invlink!(vi.metadata, vi, vns) end @@ -1121,7 +1280,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) return vi end -function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function link(::DynamicTransformation, vi::NTVarInfo, model::Model) return _link(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1156,13 +1315,13 @@ function _link(model::Model, varinfo::VarInfo, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -# If we try to _link a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) +# If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _link(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) return _link(model, varinfo, group_varnames_by_symbol(vns)) end -function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) +function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md = _link_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) @@ -1257,7 +1416,7 @@ function _link_metadata!!( return metadata end -function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) return _invlink(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1297,13 +1456,13 @@ function _invlink(model::Model, varinfo::VarInfo, vns) ) end -# If we try to _invlink a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) +# If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _invlink(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) return _invlink(model, varinfo, group_varnames_by_symbol(vns)) end -function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) +function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) @@ -1394,7 +1553,7 @@ end # TODO(mhauru) The treatment of the case when some variables are linked and others are not # should be revised. It used to be the case that for UntypedVarInfo `islinked` returned -# whether the first variable was linked. For TypedVarInfo we did an OR over the first +# whether the first variable was linked. For NTVarInfo we did an OR over the first # variables under each symbol. We now more consistently use OR, but I'm not convinced this # is really the right thing to do. """ @@ -1538,7 +1697,7 @@ Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) Check whether `vn` has a value in `vi`. """ Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) -function Base.haskey(vi::TypedVarInfo, vn::VarName) +function Base.haskey(vi::NTVarInfo, vn::VarName) md_haskey = map(vi.metadata) do metadata haskey(metadata, vn) end @@ -1601,12 +1760,12 @@ the `VarInfo` `vi`, mutating if it makes sense. function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist" + elseif vi isa NTVarInfo + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" end sym = getsym(vn) - if vi isa TypedVarInfo && ~haskey(vi.metadata, sym) + if vi isa NTVarInfo && ~haskey(vi.metadata, sym) # The NamedTuple doesn't have an entry for this variable, let's add one. val = tovec(r) md = Metadata( @@ -1627,18 +1786,18 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) return vi end -function Base.push!(vi::VectorVarInfo, vn::VarName, val, args...) +function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) push!(getmetadata(vi, vn), vn, val, args...) return vi end -function Base.push!(vi::VectorVarInfo, pair::Pair, args...) +function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) vn, val = pair return push!(vi, vn, val, args...) end -# TODO(mhauru) push! can't be implemented in-place for TypedVarInfo if the symbol doesn't -# exist in the TypedVarInfo already. We could implement it in the cases where it it does +# TODO(mhauru) push! can't be implemented in-place for NTVarInfo if the symbol doesn't +# exist in the NTVarInfo already. We could implement it in the cases where it it does # exist, but that feels a bit pointless. I think we should rather rely on `push!!`. function Base.push!(meta::Metadata, vn, r, dist, num_produce) @@ -1760,7 +1919,7 @@ function set_retained_vns_del!(vi::UntypedVarInfo) end return nothing end -function set_retained_vns_del!(vi::TypedVarInfo) +function set_retained_vns_del!(vi::NTVarInfo) idcs = _getidcs(vi) return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end @@ -1821,12 +1980,12 @@ function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) return vi end -function _apply!(kernel!, vi::TypedVarInfo, values, keys) +function _apply!(kernel!, vi::NTVarInfo, values, keys) return _typed_apply!(kernel!, vi, vi.metadata, values, collect_maybe(keys)) end @generated function _typed_apply!( - kernel!, vi::TypedVarInfo, metadata::NamedTuple{names}, values, keys + kernel!, vi::NTVarInfo, metadata::NamedTuple{names}, values, keys ) where {names} updates = map(names) do n quote @@ -1963,7 +2122,8 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.VarInfo(rng, m, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()); # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. +julia> var_info = DynamicPPL.VarInfo(rng, m); + # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. julia> var_info[@varname(m)] -0.6702516921145671 @@ -2061,8 +2221,8 @@ function values_as( return ConstructionBase.constructorof(D)(iter) end -values_as(vi::VectorVarInfo, args...) = values_as(vi.metadata, args...) -values_as(vi::VectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) +values_as(vi::UntypedVectorVarInfo, args...) = values_as(vi.metadata, args...) +values_as(vi::UntypedVectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) function values_from_metadata(md::Metadata) return ( diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 933bfb1d1..86329a51d 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -14,7 +14,7 @@ @model demo2() = x ~ Normal() @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa - DynamicPPL.TypedVarInfo + DynamicPPL.NTVarInfo @model function demo3() # Just making sure that nothing strange happens when type inference fails. @@ -53,7 +53,7 @@ end # Should pass if we're only checking the tilde statements. @test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa - DynamicPPL.TypedVarInfo + DynamicPPL.NTVarInfo # Should fail if we're including errors in the model body. @test DynamicPPL.Experimental.determine_suitable_varinfo( demo5(); only_ddpl=false @@ -75,11 +75,11 @@ ) JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. - is_typed = varinfo isa DynamicPPL.TypedVarInfo + is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed # If the test failed, check why it didn't infer a typed varinfo if !is_typed - typed_vi = VarInfo(model) + typed_vi = DynamicPPL.typed_varinfo(model) f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, typed_vi ) diff --git a/test/model.jl b/test/model.jl index 447a9ecaa..dd5a35fe6 100644 --- a/test/model.jl +++ b/test/model.jl @@ -25,9 +25,9 @@ function innermost_distribution_type(d::Distributions.Product) return dists[1] end -is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false -is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true -is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true +is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false +is_type_stable_varinfo(varinfo::DynamicPPL.NTVarInfo) = true +is_type_stable_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @@ -233,8 +233,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, Metadata" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() - spl = SampleFromPrior() - vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) + vi = VarInfo(model) vi = link!!(vi, model) for i in 1:10 @@ -250,8 +249,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, VectorVarInfo" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() for i in 1:10 - vi = VarInfo(model) - @test vi[@varname(x)] >= vi[@varname(m)] + for vi_constructor in + [DynamicPPL.typed_vector_varinfo, DynamicPPL.untyped_vector_varinfo] + vi = vi_constructor(model) + @test vi[@varname(x)] >= vi[@varname(m)] + end end end @@ -400,7 +402,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = filter( - is_typed_varinfo, + is_type_stable_varinfo, DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 8e48814a4..aa3b592f7 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -92,7 +92,7 @@ SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), SimpleVarInfo(DynamicPPL.VarNamedVector()), - VarInfo(model), + DynamicPPL.typed_varinfo(model), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) diff --git a/test/test_util.jl b/test/test_util.jl index 87c69b5fe..902dd7230 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -33,14 +33,18 @@ end Return string representing a short description of `vi`. """ -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = - "threadsafe($(short_varinfo_name(vi.varinfo)))" -function short_varinfo_name(vi::TypedVarInfo) - DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" - return "TypedVarInfo" +function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) + return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" -short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" +function short_varinfo_name(vi::DynamicPPL.NTVarInfo) + return if DynamicPPL.has_varnamedvector(vi) + "TypedVectorVarInfo" + else + "TypedVarInfo" + end +end +short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) return "SimpleVarInfo{<:NamedTuple,<:Ref}" end diff --git a/test/varinfo.jl b/test/varinfo.jl index 74feb42f6..777917aa6 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -34,7 +34,7 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) end @testset "varinfo.jl" begin - @testset "TypedVarInfo with Metadata" begin + @testset "VarInfo with NT of Metadata" begin @model gdemo(x, y) = begin s ~ InverseGamma(2, 3) m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) @@ -43,9 +43,8 @@ end end model = gdemo(1.0, 2.0) - vi = VarInfo(DynamicPPL.Metadata()) - model(vi, SampleFromUniform()) - tvi = TypedVarInfo(vi) + vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata for f in fieldnames(typeof(tvi.metadata)) @@ -102,7 +101,7 @@ end @test vi[vn] == 2 * r # TODO(mhauru) Implement these functions for other VarInfo types too. - if vi isa DynamicPPL.VectorVarInfo + if vi isa DynamicPPL.UntypedVectorVarInfo delete!(vi, vn) @test isempty(vi) vi = push!!(vi, vn, r, dist) @@ -116,7 +115,7 @@ end vi = VarInfo() test_base!!(vi) - test_base!!(TypedVarInfo(vi)) + test_base!!(DynamicPPL.typed_varinfo(vi)) test_base!!(SimpleVarInfo()) test_base!!(SimpleVarInfo(Dict())) test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) @@ -135,7 +134,7 @@ end vi = VarInfo() test_varinfo_logp!(vi) - test_varinfo_logp!(TypedVarInfo(vi)) + test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) test_varinfo_logp!(SimpleVarInfo()) test_varinfo_logp!(SimpleVarInfo(Dict())) test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) @@ -160,17 +159,17 @@ end unset_flag!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end - vi = VarInfo(DynamicPPL.Metadata()) + vi = VarInfo() test_varinfo!(vi) - test_varinfo!(empty!!(TypedVarInfo(vi))) + test_varinfo!(empty!!(DynamicPPL.typed_varinfo(vi))) end - @testset "push!! to TypedVarInfo" begin + @testset "push!! to VarInfo with NT of Metadata" begin vn_x = @varname x vn_y = @varname y untyped_vi = VarInfo() untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) - typed_vi = TypedVarInfo(untyped_vi) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) @test typed_vi[vn_x] == 1.0 @test typed_vi[vn_y] == 2.0 @@ -206,16 +205,10 @@ end m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) s_vns = @varname(s) - vi_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata() - ) - vi_untyped = VarInfo(DynamicPPL.Metadata()) - vi_vnv = VarInfo(DynamicPPL.VarNamedVector()) - vi_vnv_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.VarNamedVector() - ) - model(vi_untyped, SampleFromPrior()) - model(vi_vnv, SampleFromPrior()) + vi_typed = DynamicPPL.typed_varinfo(model) + vi_untyped = DynamicPPL.untyped_varinfo(model) + vi_vnv = DynamicPPL.untyped_vector_varinfo(model) + vi_vnv_typed = DynamicPPL.typed_vector_varinfo(model) model_name = model == model_uv ? "univariate" : "multivariate" @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ @@ -405,7 +398,7 @@ end @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values - vi = TypedVarInfo(vi) + vi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) @@ -459,9 +452,9 @@ end # Need to run once since we can't specify that we want to _sample_ # in the unconstrained space for `VarInfo` without having `vn` # present in the `varinfo`. - ## `UntypedVarInfo` - vi = VarInfo() - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + + ## `untyped_varinfo` + vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -469,8 +462,8 @@ end x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - ## `TypedVarInfo` - vi = VarInfo(model) + ## `typed_varinfo` + vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -979,7 +972,7 @@ end @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @test DynamicPPL.get_num_produce(vi) == 3 - vi = empty!!(DynamicPPL.TypedVarInfo(vi)) + vi = empty!!(DynamicPPL.typed_varinfo(vi)) # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) From be2763633a8c47103cf4943d41d893849d22b6ec Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 16 Apr 2025 13:28:50 +0100 Subject: [PATCH 5/6] Link varinfo by default in AD testing utilities; make test suite run on linked varinfos (#890) * Link VarInfo by default * Tweak interface * Fix tests * Fix interface so that callers can inspect results * Document * Fix tests * Fix changelog * Test linked varinfos Closes #891 * Fix docstring + use AbstractFloat --- HISTORY.md | 12 +++++ docs/src/api.md | 1 + src/test_utils/ad.jl | 104 ++++++++++++++++++++++++++--------------- src/transforming.jl | 4 +- test/ad.jl | 18 +++---- test/simple_varinfo.jl | 6 --- 6 files changed, 91 insertions(+), 54 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index a21258ec0..a45644a64 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,18 @@ **Breaking changes** +### AD testing utilities + +`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. +To disable this, pass the `linked=false` keyword argument. +If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. +This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. +From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. + +### SimpleVarInfo linking / invlinking + +Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. + ### VarInfo constructors `VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. diff --git a/docs/src/api.md b/docs/src/api.md index 2c61f54fc..ec741c9ad 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -212,6 +212,7 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL ```@docs DynamicPPL.TestUtils.AD.run_ad DynamicPPL.TestUtils.AD.ADResult +DynamicPPL.TestUtils.AD.ADIncorrectException ``` ## Demo models diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 06c76df5e..d38915c12 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,19 +4,13 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo +using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: Random, Xoshiro using Statistics: median using Test: @test -export ADResult, run_ad - -# This function needed to work around the fact that different backends can -# return different AbstractArrays for the gradient. See -# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more -# context. -_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x) +export ADResult, run_ad, ADIncorrectException """ REFERENCE_ADTYPE @@ -27,33 +21,50 @@ it's the default AD backend used in Turing.jl. const REFERENCE_ADTYPE = AutoForwardDiff() """ - ADResult + ADIncorrectException{T<:AbstractFloat} + +Exception thrown when an AD backend returns an incorrect value or gradient. + +The type parameter `T` is the numeric type of the value and gradient. +""" +struct ADIncorrectException{T<:AbstractFloat} <: Exception + value_expected::T + value_actual::T + grad_expected::Vector{T} + grad_actual::Vector{T} +end + +""" + ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} Data structure to store the results of the AD correctness test. + +The type parameter `Tparams` is the numeric type of the parameters passed in; +`Tresult` is the type of the value and the gradient. """ -struct ADResult +struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} "The DynamicPPL model that was tested" model::Model "The VarInfo that was used" varinfo::AbstractVarInfo "The values at which the model was evaluated" - params::Vector{<:Real} + params::Vector{Tparams} "The AD backend that was tested" adtype::AbstractADType "The absolute tolerance for the value of logp" - value_atol::Real + value_atol::Tresult "The absolute tolerance for the gradient of logp" - grad_atol::Real + grad_atol::Tresult "The expected value of logp" - value_expected::Union{Nothing,Float64} + value_expected::Union{Nothing,Tresult} "The expected gradient of logp" - grad_expected::Union{Nothing,Vector{Float64}} + grad_expected::Union{Nothing,Vector{Tresult}} "The value of logp (calculated using `adtype`)" - value_actual::Union{Nothing,Real} + value_actual::Union{Nothing,Tresult} "The gradient of logp (calculated using `adtype`)" - grad_actual::Union{Nothing,Vector{Float64}} + grad_actual::Union{Nothing,Vector{Tresult}} "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself" - time_vs_primal::Union{Nothing,Float64} + time_vs_primal::Union{Nothing,Tresult} end """ @@ -64,26 +75,27 @@ end benchmark=false, value_atol=1e-6, grad_atol=1e-6, - varinfo::AbstractVarInfo=VarInfo(model), - params::Vector{<:Real}=varinfo[:], + varinfo::AbstractVarInfo=link(VarInfo(model), model), + params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, + expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult +### Description + Test the correctness and/or benchmark the AD backend `adtype` for the model `model`. Whether to test and benchmark is controlled by the `test` and `benchmark` keyword arguments. By default, `test` is `true` and `benchmark` is `false`. -Returns an [`ADResult`](@ref) object, which contains the results of the -test and/or benchmark. - Note that to run AD successfully you will need to import the AD backend itself. For example, to test with `AutoReverseDiff()` you will need to run `import ReverseDiff`. +### Arguments + There are two positional arguments, which absolutely must be provided: 1. `model` - The model being tested. @@ -96,7 +108,9 @@ Everything else is optional, and can be categorised into several groups: DynamicPPL contains several different types of VarInfo objects which change the way model evaluation occurs. If you want to use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to - using a `TypedVarInfo` generated from the model. + using a linked `TypedVarInfo` generated from the model. Here, _linked_ + means that the parameters in the VarInfo have been transformed to + unconstrained Euclidean space if they aren't already in that space. 2. _How to specify the parameters._ @@ -140,27 +154,40 @@ Everything else is optional, and can be categorised into several groups: By default, this function prints messages when it runs. To silence it, set `verbose=false`. + +### Returns / Throws + +Returns an [`ADResult`](@ref) object, which contains the results of the +test and/or benchmark. + +If `test` is `true` and the AD backend returns an incorrect value or gradient, an +`ADIncorrectException` is thrown. If a different error occurs, it will be +thrown as-is. """ function run_ad( model::Model, adtype::AbstractADType; - test=true, - benchmark=false, - value_atol=1e-6, - grad_atol=1e-6, - varinfo::AbstractVarInfo=VarInfo(model), - params::Vector{<:Real}=varinfo[:], + test::Bool=true, + benchmark::Bool=false, + value_atol::AbstractFloat=1e-6, + grad_atol::AbstractFloat=1e-6, + varinfo::AbstractVarInfo=link(VarInfo(model), model), + params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, reference_adtype::AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, + expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult + if isnothing(params) + params = varinfo[:] + end + params = map(identity, params) # Concretise + verbose && @info "Running AD on $(model.f) with $(adtype)\n" - params = map(identity, params) verbose && println(" params : $(params)") ldf = LogDensityFunction(model, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) - grad = _to_vec_f64(grad) + grad = collect(grad) verbose && println(" actual : $((value, grad))") if test @@ -172,10 +199,11 @@ function run_ad( expected_value_and_grad end verbose && println(" expected : $((value_true, grad_true))") - grad_true = _to_vec_f64(grad_true) - # Then compare - @test isapprox(value, value_true; atol=value_atol) - @test isapprox(grad, grad_true; atol=grad_atol) + grad_true = collect(grad_true) + + exc() = throw(ADIncorrectException(value, value_true, grad, grad_true)) + isapprox(value, value_true; atol=value_atol) || exc() + isapprox(grad, grad_true; atol=grad_atol) || exc() else value_true = nothing grad_true = nothing diff --git a/src/transforming.jl b/src/transforming.jl index 0239725ae..429562ec8 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -19,9 +19,9 @@ function tilde_assume( lp = Bijectors.logpdf_with_trans(right, r, !isinverse) if istrans(vi, vn) - @assert isinverse "Trying to link already transformed variables" + isinverse || @warn "Trying to link an already transformed variable ($vn)" else - @assert !isinverse "Trying to invlink non-transformed variables" + isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" end # Only transform if `!isinverse` since `vi[vn, right]` diff --git a/test/ad.jl b/test/ad.jl index 33d581228..69ab99e19 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -23,21 +23,23 @@ using DynamicPPL: LogDensityFunction varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - f = LogDensityFunction(m, varinfo) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff - ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) + ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype) ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" + @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" # Put predicates here to avoid long lines is_mooncake = adtype isa AutoMooncake is_1_10 = v"1.10" <= VERSION < v"1.11" is_1_11 = v"1.11" <= VERSION < v"1.12" - is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} + is_svi_vnv = + linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} + is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} # Mooncake doesn't work with several combinations of SimpleVarInfo. if is_mooncake && is_1_11 && is_svi_vnv @@ -56,12 +58,12 @@ using DynamicPPL: LogDensityFunction ref_ldf, adtype ) else - DynamicPPL.TestUtils.AD.run_ad( + @test DynamicPPL.TestUtils.AD.run_ad( m, adtype; - varinfo=varinfo, + varinfo=linked_varinfo, expected_value_and_grad=(ref_logp, ref_grad), - ) + ) isa Any end end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index aa3b592f7..380c24e7d 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -111,12 +111,6 @@ # Should be approx. the same as the "lazy" transformation. @test logjoint(model, vi_linked) ≈ lp_linked - # TODO: Should not `VarInfo` also error here? The current implementation - # only warns and acts as a no-op. - if vi isa SimpleVarInfo - @test_throws AssertionError link!!(vi_linked, model) - end - # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) lp_invlinked = getlogp(vi_invlinked) From ff5f2cba98aecac764288267f059c43e162729cb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 23 Apr 2025 12:09:31 +0100 Subject: [PATCH 6/6] Fix `condition` and `fix` in submodels (#892) * Fix conditioning in submodels * Simplify contextual_isassumption * Add documentation * Fix some tests * Add tests; fix a bunch of nested submodel issues * Fix fix as well * Fix doctests * Add unit tests for new functions * Add changelog entry * Update changelog Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Finish docs * Add a test for conditioning submodel via arguments * Clean new tests up a bit * Fix for VarNames with non-identity lenses * Apply suggestions from code review Co-authored-by: Markus Hauru * Apply suggestions from code review * Make PrefixContext contain a varname rather than symbol (#896) --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: Markus Hauru --- HISTORY.md | 98 +++++-- docs/Project.toml | 1 + docs/make.jl | 4 +- docs/src/api.md | 14 +- docs/src/internals/submodel_condition.md | 356 +++++++++++++++++++++++ src/compiler.jl | 60 ++-- src/context_implementations.jl | 42 ++- src/contexts.jl | 211 ++++++++++++-- src/model.jl | 71 ++--- src/submodel_macro.jl | 4 +- src/utils.jl | 9 +- test/contexts.jl | 273 +++++++++++------ test/runtests.jl | 1 + test/submodels.jl | 199 +++++++++++++ 14 files changed, 1109 insertions(+), 234 deletions(-) create mode 100644 docs/src/internals/submodel_condition.md create mode 100644 test/submodels.jl diff --git a/HISTORY.md b/HISTORY.md index a45644a64..ac3e40970 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,38 +4,25 @@ **Breaking changes** -### AD testing utilities - -`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. -To disable this, pass the `linked=false` keyword argument. -If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. -This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. -From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. +### Submodels: conditioning -### SimpleVarInfo linking / invlinking - -Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. +Variables in a submodel can now be conditioned and fixed in a correct way. +See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustration, but essentially it means you can now do this: -### VarInfo constructors - -`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. - -The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. -If you were not using this argument (most likely), then there is no change needed. -If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). - -The `UntypedVarInfo` constructor and type is no longer exported. -If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. - -The `TypedVarInfo` constructor and type is no longer exported. -The _type_ has been replaced with `DynamicPPL.NTVarInfo`. -The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. +```julia +@model function inner() + x ~ Normal() + return y ~ Normal() +end +@model function outer() + return a ~ to_submodel(inner() | (x=1.0,)) +end +``` -Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. -Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. -Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. +and the `a.x` variable will be correctly conditioned. +(Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.) -### VarName prefixing behaviour +### Submodel prefixing The way in which VarNames in submodels are prefixed has been changed. This is best explained through an example. @@ -77,9 +64,62 @@ outer() | (@varname(var"a.x") => 1.0,) outer() | (a.x=1.0,) ``` -If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. +In a similar way, if the variable on the left-hand side of your tilde statement is not just a single identifier, any fields or indices it accesses are now properly respected. +Consider the following setup: + +```julia +using DynamicPPL, Distributions +@model inner() = x ~ Normal() +@model function outer() + a = Vector{Float64}(undef, 1) + a[1] ~ to_submodel(inner()) + return a +end +``` + +In this case, the variable sampled is actually the `x` field of the first element of `a`: + +```julia +julia> only(keys(VarInfo(outer()))) == @varname(a[1].x) +true +``` + +Before this version, it used to be a single variable called `var"a[1].x"`. + +Note that if you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. (This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) +### AD testing utilities + +`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. +To disable this, pass the `linked=false` keyword argument. +If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. +This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. +From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. + +### SimpleVarInfo linking / invlinking + +Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. + +### VarInfo constructors + +`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. + +The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. +If you were not using this argument (most likely), then there is no change needed. +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). + +The `UntypedVarInfo` constructor and type is no longer exported. +If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. + +The `TypedVarInfo` constructor and type is no longer exported. +The _type_ has been replaced with `DynamicPPL.NTVarInfo`. +The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. + +Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. +Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. +Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. + **Other changes** While these are technically breaking, they are only internal changes and do not affect the public API. diff --git a/docs/Project.toml b/docs/Project.toml index 40a719e03..93f449308 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/docs/make.jl b/docs/make.jl index c69b72fb8..7984fa1d1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,7 +24,9 @@ makedocs(; format=Documenter.HTML(; size_threshold=2^10 * 400), modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], pages=[ - "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] + "Home" => "index.md", + "API" => "api.md", + "Internals" => ["internals/varinfo.md", "internals/submodel_condition.md"], ], checkdocs=:exports, doctest=false, diff --git a/docs/src/api.md b/docs/src/api.md index ec741c9ad..08522e2ce 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -78,9 +78,9 @@ decondition ## Fixing and unfixing -We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`fix`](@ref). +We can also _fix_ a collection of variables in a [`Model`](@ref) to certain values using [`DynamicPPL.fix`](@ref). -This might seem quite similar to the aforementioned [`condition`](@ref) and its siblings, +This is quite similar to the aforementioned [`condition`](@ref) and its siblings, but they are indeed different operations: - `condition`ed variables are considered to be _observations_, and are thus @@ -89,19 +89,19 @@ but they are indeed different operations: - `fix`ed variables are considered to be _constant_, and are thus not included in any log-probability computations. -The differences are more clearly spelled out in the docstring of [`fix`](@ref) below. +The differences are more clearly spelled out in the docstring of [`DynamicPPL.fix`](@ref) below. ```@docs -fix +DynamicPPL.fix DynamicPPL.fixed ``` -The difference between [`fix`](@ref) and [`condition`](@ref) is described in the docstring of [`fix`](@ref) above. +The difference between [`DynamicPPL.fix`](@ref) and [`DynamicPPL.condition`](@ref) is described in the docstring of [`DynamicPPL.fix`](@ref) above. -Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original meaning: +Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the variables to their original meaning: ```@docs -unfix +DynamicPPL.unfix ``` ## Predicting diff --git a/docs/src/internals/submodel_condition.md b/docs/src/internals/submodel_condition.md new file mode 100644 index 000000000..ecb9d452b --- /dev/null +++ b/docs/src/internals/submodel_condition.md @@ -0,0 +1,356 @@ +# How `PrefixContext` and `ConditionContext` interact + +```@meta +ShareDefaultModule = true +``` + +## PrefixContext + +`PrefixContext` is a context that, as the name suggests, prefixes all variables inside a model with a given symbol. +Thus, for example: + +```@example +using DynamicPPL, Distributions + +@model function f() + x ~ Normal() + return y ~ Normal() +end + +@model function g() + return a ~ to_submodel(f()) +end +``` + +inside the submodel `f`, the variables `x` and `y` become `a.x` and `a.y` respectively. +This is easiest to observe by running the model: + +```@example +vi = VarInfo(g()) +keys(vi) +``` + +!!! note + + In this case, where `to_submodel` is called without any other arguments, the prefix to be used is automatically inferred from the name of the variable on the left-hand side of the tilde. + We will return to the 'manual prefixing' case later. + +The phrase 'becoming' a different variable is a little underspecified: it is useful to pinpoint the exact location where the prefixing occurs, which is `tilde_assume`. +The method responsible for it is `tilde_assume(::PrefixContext, right, vn, vi)`: this attaches the prefix in the context to the `VarName` argument, before recursively calling `tilde_assume` with the new prefixed `VarName`. +This means that even though a statement `x ~ dist` still enters the tilde pipeline at the top level as `x`, if the model evaluation context contains a `PrefixContext`, any function from `tilde_assume` onwards will see `a.x` instead. + +## ConditionContext + +`ConditionContext` is a context which stores values of variables that are to be conditioned on. +These values may be stored as a `Dict` which maps `VarName`s to values, or alternatively as a `NamedTuple`. +The latter only works correctly if all `VarName`s are 'basic', in that they have an identity optic (i.e., something like `a.x` or `a[1]` is forbidden). +Because of this limitation, we will only use `Dict` in this example. + +!!! note + + If a `ConditionContext` with a `NamedTuple` encounters anything to do with a prefix, its internal `NamedTuple` is converted to a `Dict` anyway, so it is quite reasonable to ignore the `NamedTuple` case in this exposition. + +One can inspect the conditioning values with, for example: + +```@example +@model function d() + x ~ Normal() + return y ~ Normal() +end + +cond_model = d() | (@varname(x) => 1.0) +cond_ctx = cond_model.context +``` + +There are several internal functions that are used to determine whether a variable is conditioned, and if so, what its value is. + +```@example +DynamicPPL.hasconditioned_nested(cond_ctx, @varname(x)) +``` + +```@example +DynamicPPL.getconditioned_nested(cond_ctx, @varname(x)) +``` + +These functions are in turn used by the function `DynamicPPL.contextual_isassumption`, which is largely the same as `hasconditioned_nested`, but also checks whether the value is `missing` (in which case it isn't really conditioned). + +```@example +DynamicPPL.contextual_isassumption(cond_ctx, @varname(x)) +``` + +!!! note + + Notice that (neglecting `missing` values) the return value of `contextual_isassumption` is the _opposite_ of `hasconditioned_nested`, i.e. for a variable that _is_ conditioned on, `contextual_isassumption` returns `false`. + +If a variable `x` is conditioned on, then the effect of this is to set the value of `x` to the given value (while still including its contribution to the log probability density). +Since `x` is no longer a random variable, if we were to evaluate the model, we would find only one key in the `VarInfo`: + +```@example +keys(VarInfo(cond_model)) +``` + +## Joint behaviour: desiderata at the model level + +When paired together, these two contexts have the potential to cause substantial confusion: `PrefixContext` modifies the variable names that are seen, which may cause them to be out of sync with the values contained inside the `ConditionContext`. + +We begin by mentioning some high-level desiderata for their joint behaviour. +Take these models, for example: + +```@example +# We define a helper function to unwrap a layer of SamplingContext, to +# avoid cluttering the print statements. +unwrap_sampling_context(ctx::DynamicPPL.SamplingContext) = ctx.context +unwrap_sampling_context(ctx::DynamicPPL.AbstractContext) = ctx +@model function inner() + println("inner context: $(unwrap_sampling_context(__context__))") + x ~ Normal() + return y ~ Normal() +end + +@model function outer() + println("outer context: $(unwrap_sampling_context(__context__))") + return a ~ to_submodel(inner()) +end + +# 'Outer conditioning' +with_outer_cond = outer() | (@varname(a.x) => 1.0) + +# 'Inner conditioning' +inner_cond = inner() | (@varname(x) => 1.0) +@model function outer2() + println("outer context: $(unwrap_sampling_context(__context__))") + return a ~ to_submodel(inner_cond) +end +with_inner_cond = outer2() +``` + +We want that: + + 1. `keys(VarInfo(outer()))` should return `[a.x, a.y]`; + 2. `keys(VarInfo(with_outer_cond))` should return `[a.y]`; + 3. `keys(VarInfo(with_inner_cond))` should return `[a.y]`, + +**In other words, we can condition submodels either from the outside (point (2)) or from the inside (point (3)), and the variable name we use to specify the conditioning should match the level at which we perform the conditioning.** + +This is an incredibly salient point because it means that submodels can be treated as individual, opaque objects, and we can condition them without needing to know what it will be prefixed with, or the context in which that submodel is being used. +For example, this means we can reuse `inner_cond` in another model with a different prefix, and it will _still_ have its inner `x` value be conditioned, despite the prefix differing. + +!!! info + + In the current version of DynamicPPL, these criteria are all fulfilled. However, this was not the case in the past: in particular, point (3) was not fulfilled, and users had to condition the internal submodel with the prefixes that were used outside. (See [this GitHub issue](https://github.com/TuringLang/DynamicPPL.jl/issues/857) for more information; this issue was the direct motivation for this documentation page.) + +## Desiderata at the context level + +The above section describes how we expect conditioning and prefixing to behave from a user's perpective. +We now turn to the question of how we implement this in terms of DynamicPPL contexts. +We do not specify the implementation details here, but we will sketch out something resembling an API that will allow us to achieve the target behaviour. + +**Point (1)** does not involve any conditioning, only prefixing; it is therefore already satisfied by virtue of the `tilde_assume` method shown above. + +**Points (2) and (3)** are more tricky. +As the reader may surmise, the difference between them is the order in which the contexts are stacked. + +For the _outer_ conditioning case (point (2)), the `ConditionContext` will contain a `VarName` that is already prefixed. +When we enter the inner submodel, this `ConditionContext` has to be passed down and somehow combined with the `PrefixContext` that is created when we enter the submodel. +We make the claim here that the best way to do this is to nest the `PrefixContext` _inside_ the `ConditionContext`. +This is indeed what happens, as can be demonstrated by running the model. + +```@example +with_outer_cond(); +nothing; +``` + +!!! info + + The `; nothing` at the end is purely to circumvent a Documenter.jl quirk where stdout is only shown if the return value of the final statement is `nothing`. + If these documentation pages are moved to Quarto, it will be possible to remove this. + +For the _inner_ conditioning case (point (3)), the outer model is not run with any special context. +The inner model will itself contain a `ConditionContext` will contain a `VarName` that is not prefixed. +When we run the model, this `ConditionContext` should be then nested _inside_ a `PrefixContext` to form the final evaluation context. +Again, we can run the model to see this in action: + +```@example +with_inner_cond(); +nothing; +``` + +Putting all of the information so far together, what it means is that if we have these two inner contexts (taken from above): + +```@example +using DynamicPPL: PrefixContext, ConditionContext, DefaultContext + +inner_ctx_with_outer_cond = ConditionContext( + Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a)) +) +inner_ctx_with_inner_cond = PrefixContext( + @varname(a), ConditionContext(Dict(@varname(x) => 1.0)) +) +``` + +then we want both of these to be `true` (and thankfully, they are!): + +```@example +DynamicPPL.hasconditioned_nested(inner_ctx_with_outer_cond, @varname(a.x)) +``` + +```@example +DynamicPPL.hasconditioned_nested(inner_ctx_with_inner_cond, @varname(a.x)) +``` + +This allows us to finally specify our task as follows: + +(1) Given the correct arguments, we need to make sure that `hasconditioned_nested` and `getconditioned_nested` behave correctly. + +(2) We need to make sure that both the correct arguments are supplied. In order to do so: + + - (2a) We need to make sure that when evaluating a submodel, the context stack is arranged such that `PrefixContext` is applied _inside_ the parent model's context, but _outside_ the submodel's own context. + + - (2b) We also need to make sure that the `VarName` passed to it is prefixed correctly. + +## How do we do it? + +(1) `hasconditioned_nested` and `getconditioned_nested` accomplish this by first 'collapsing' the context stack, i.e. they go through the context stack, remove all `PrefixContext`s, and apply those prefixes to any conditioned variables below it in the stack. +Once the `PrefixContext`s have been removed, one can then iterate through the context stack and check if any of the `ConditionContext`s contain the variable, or get the value itself. +For more details the reader is encouraged to read the source code. + +(2a) We ensure that the context stack is correctly arranged by relying on the behaviour of `make_evaluate_args_and_kwargs`. +This function is called whenever a model (which itself contains a context) is evaluated with a separate ('external') context, and makes sure to arrange both of these contexts such that _the model's context is nested inside the external context_. +Thus, as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined with an external context to give the behaviour seen above. + +(2b) At first glance, it seems like `tilde_assume` can take care of the `VarName` prefixing for us (as described in the first section). +However, this is not actually the case: `contextual_isassumption`, which is the function that calls `hasconditioned_nested`, is much higher in the call stack than `tilde_assume` is. +So, we need to explicitly prefix it before passing it to `contextual_isassumption`. +This is done inside the `@model` macro, or technically, its subsidiary function `isassumption`. + +## Nested submodels + +Just in case the above wasn't complicated enough, we need to also be very careful when dealing with nested submodels, which have multiple layers of `PrefixContext`s which may be interspersed with `ConditionContext`s. +For example, in this series of nested submodels, + +```@example +@model function charlie() + x ~ Normal() + y ~ Normal() + return z ~ Normal() +end +@model function bravo() + return b ~ to_submodel(charlie() | (@varname(x) => 1.0)) +end +@model function alpha() + return a ~ to_submodel(bravo() | (@varname(b.y) => 1.0)) +end +``` + +we expect that the only variable to be sampled should be `z` inside `charlie`, or rather, `a.b.z` once it has been through the prefixes. + +```@example +keys(VarInfo(alpha())) +``` + +The general strategy that we adopt is similar to above. +Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be: + +```@example +big_ctx = PrefixContext( + @varname(a), + ConditionContext( + Dict(@varname(b.y) => 1.0), + PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))), + ), +) +``` + +We need several things to work correctly here: we need the `VarName` prefixing to behave correctly, and then we need to implement `hasconditioned_nested` and `getconditioned_nested` on the resulting prefixed `VarName`. +It turns out that the prefixing itself is enough to illustrate the most important point in this section, namely, the need to traverse the context stack in a _different direction_ to what most of DynamicPPL does. + +Let's work with a function called `myprefix(::AbstractContext, ::VarName)` (to avoid confusion with any existing DynamicPPL function). +We should like `myprefix(big_ctx, @varname(x))` to return `@varname(a.b.x)`. +Consider the following naive implementation, which mirrors a lot of code in the tilde-pipeline: + +```@example +using DynamicPPL: NodeTrait, IsLeaf, IsParent, childcontext, AbstractContext +using AbstractPPL: AbstractPPL + +function myprefix(ctx::DynamicPPL.AbstractContext, vn::VarName) + return myprefix(NodeTrait(ctx), ctx, vn) +end +function myprefix(::IsLeaf, ::AbstractContext, vn::VarName) + return vn +end +function myprefix(::IsParent, ctx::AbstractContext, vn::VarName) + return myprefix(childcontext(ctx), vn) +end +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) + # The functionality to actually manipulate the VarNames is in AbstractPPL + new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix) + # Then pass to the child context + return myprefix(childcontext(ctx), new_vn) +end + +myprefix(big_ctx, @varname(x)) +``` + +This implementation clearly is not correct, because it applies the _inner_ `PrefixContext` before the outer one. + +The right way to implement `myprefix` is to, essentially, reverse the order of two lines above: + +```@example +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) + # Pass to the child context first + new_vn = myprefix(childcontext(ctx), vn) + # Then apply this context's prefix + return AbstractPPL.prefix(new_vn, ctx.vn_prefix) +end + +myprefix(big_ctx, @varname(x)) +``` + +This is a much better result! +The implementation of related functions such as `hasconditioned_nested` and `getconditioned_nested`, under the hood, use a similar recursion scheme, so you will find that this is a common pattern when reading the source code of various prefixing-related functions. +When editing this code, it is worth being mindful of this as a potential source of incorrectness. + +!!! info + + If you have encountered left and right folds, the above discussion illustrates the difference between them: the wrong implementation of `myprefix` uses a left fold (which collects prefixes in the opposite order from which they are encountered), while the correct implementation uses a right fold. + +## Loose ends 1: Manual prefixing + +Sometimes users may want to manually prefix a model, for example: + +```@example +@model function inner_manual() + x ~ Normal() + return y ~ Normal() +end + +@model function outer_manual() + return _unused ~ to_submodel(prefix(inner_manual(), :a), false) +end +``` + +In this case, the `VarName` on the left-hand side of the tilde is not used, and the prefix is instead specified using the `prefix` function. + +The way to deal with this follows on from the previous discussion. +Specifically, we said that: + +> [...] as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined [...] + +When automatic prefixing is used, this application of `PrefixContext` occurs inside the `tilde_assume!!` method. +In the manual prefixing case, we need to make sure that `prefix(submodel::Model, ::Symbol)` does the same thing, i.e. it inserts a `PrefixContext` at the outermost layer of `submodel`'s context. +We can see that this is precisely what happens: + +```@example +@model f() = x ~ Normal() + +model = f() +prefixed_model = prefix(model, :a) + +(model.context, prefixed_model.context) +``` + +## Loose ends 2: FixedContext + +Finally, note that all of the above also applies to the interaction between `PrefixContext` and `FixedContext`, except that the functions have different names. +(`FixedContext` behaves the same way as `ConditionContext`, except that unlike conditioned variables, fixed variables do not contribute to the log probability density.) +This generally results in a large amount of code duplication, but the concepts that underlie both contexts are exactly the same. diff --git a/src/compiler.jl b/src/compiler.jl index 4771b0171..6f7489b8e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -53,7 +53,9 @@ function isassumption( vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), ) return quote - if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + if $(DynamicPPL.contextual_isassumption)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) # Considered an assumption by `__context__` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of @@ -87,67 +89,45 @@ isassumption(expr) = :(false) contextual_isassumption(context, vn) Return `true` if `vn` is considered an assumption by `context`. - -The default implementation for `AbstractContext` always returns `true`. """ -contextual_isassumption(::IsLeaf, context, vn) = true -function contextual_isassumption(::IsParent, context, vn) - return contextual_isassumption(childcontext(context), vn) -end function contextual_isassumption(context::AbstractContext, vn) - return contextual_isassumption(NodeTrait(context), context, vn) -end -function contextual_isassumption(context::ConditionContext, vn) - if hasconditioned(context, vn) - val = getconditioned(context, vn) + if hasconditioned_nested(context, vn) + val = getconditioned_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? if eltype(val) >: Missing && val === missing return true else return false end + else + return true end - - # We might have nested contexts, e.g. `ConditionContext{.., <:PrefixContext{..., <:ConditionContext}}` - # so we defer to `childcontext` if we haven't concluded that anything yet. - return contextual_isassumption(childcontext(context), vn) -end -function contextual_isassumption(context::PrefixContext, vn) - return contextual_isassumption(childcontext(context), prefix(context, vn)) end isfixed(expr, vn) = false -isfixed(::Union{Symbol,Expr}, vn) = :($(DynamicPPL.contextual_isfixed)(__context__, $vn)) +function isfixed(::Union{Symbol,Expr}, vn) + return :($(DynamicPPL.contextual_isfixed)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + )) +end """ contextual_isfixed(context, vn) Return `true` if `vn` is considered fixed by `context`. """ -contextual_isfixed(::IsLeaf, context, vn) = false -function contextual_isfixed(::IsParent, context, vn) - return contextual_isfixed(childcontext(context), vn) -end function contextual_isfixed(context::AbstractContext, vn) - return contextual_isfixed(NodeTrait(context), context, vn) -end -function contextual_isfixed(context::PrefixContext, vn) - return contextual_isfixed(childcontext(context), prefix(context, vn)) -end -function contextual_isfixed(context::FixedContext, vn) - if hasfixed(context, vn) - val = getfixed(context, vn) + if hasfixed_nested(context, vn) + val = getfixed_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? if eltype(val) >: Missing && val === missing return false else return true end + else + return false end - - # We might have nested contexts, e.g. `FixedContext{.., <:PrefixContext{..., <:FixedContext}}` - # so we defer to `childcontext` if we haven't concluded that anything yet. - return contextual_isfixed(childcontext(context), vn) end # If we're working with, say, a `Symbol`, then we're not going to `view`. @@ -467,13 +447,17 @@ function generate_tilde(left, right) ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) - $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) + $left = $(DynamicPPL.getfixed_nested)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) + $left = $(DynamicPPL.getconditioned_nested)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e4ba5d252..eb025dec8 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -85,12 +85,23 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig end function tilde_assume(context::PrefixContext, right, vn, vi) - return tilde_assume(context.context, right, prefix(context, vn), vi) + # Note that we can't use something like this here: + # new_vn = prefix(context, vn) + # return tilde_assume(childcontext(context), right, new_vn, vi) + # This is because `prefix` applies _all_ prefixes in a given context to a + # variable name. Thus, if we had two levels of nested prefixes e.g. + # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the + # first call would apply the prefix `a.b._`, and the recursive call + # would apply the prefix `b._`, resulting in `b.a.b._`. + # This is why we need a special function, `prefix_and_strip_contexts`. + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume(new_context, right, new_vn, vi) end function tilde_assume( rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi ) - return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume(rng, new_context, sampler, right, new_vn, vi) end """ @@ -104,12 +115,27 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) return if is_rhs_model(right) - # Prefix the variables using the `vn`. - rand_like!!( - right, - should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context, - vi, - ) + # Here, we apply the PrefixContext _not_ to the parent `context`, but + # to the context of the submodel being evaluated. This means that later= + # on in `make_evaluate_args_and_kwargs`, the context stack will be + # correctly arranged such that it goes like this: + # parent_context[1] -> parent_context[2] -> ... -> PrefixContext -> + # submodel_context[1] -> submodel_context[2] -> ... -> leafcontext + # See the docstring of `make_evaluate_args_and_kwargs`, and the internal + # DynamicPPL documentation on submodel conditioning, for more details. + # + # NOTE: This relies on the existence of `right.model.model`. Right now, + # the only thing that can return true for `is_rhs_model` is something + # (a `Sampleable`) that has a `model` field that itself (a + # `ReturnedModelWrapper`) has a `model` field. This may or may not + # change in the future. + if should_auto_prefix(right) + dppl_model = right.model.model # This isa DynamicPPL.Model + prefixed_submodel_context = PrefixContext(vn, dppl_model.context) + new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) + right = to_submodel(new_dppl_model, true) + end + rand_like!!(right, context, vi) else value, logp, vi = tilde_assume(context, right, vn, vi) value, acclogp_assume!!(context, vi, logp) diff --git a/src/contexts.jl b/src/contexts.jl index 58ac612b8..8ac085663 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -237,27 +237,34 @@ function setchildcontext(parent::MiniBatchContext, child) end """ - PrefixContext{Prefix}(context) + PrefixContext(vn::VarName[, context::AbstractContext]) + PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} Create a context that allows you to use the wrapped `context` when running the model and -adds the `Prefix` to all parameters. +prefixes all parameters with the VarName `vn`. + +`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. +If `context` is not provided, it defaults to `DefaultContext()`. This context is useful in nested models to ensure that the names of the parameters are unique. See also: [`to_submodel`](@ref) """ -struct PrefixContext{Prefix,C} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext + vn_prefix::Tvn context::C end -function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(context)}(context) +PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) +function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} + return PrefixContext(VarName{sym}(), context) end +PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} - return PrefixContext{Prefix}(child) +function setchildcontext(ctx::PrefixContext, child::AbstractContext) + return PrefixContext(ctx.vn_prefix, child) end """ @@ -265,8 +272,8 @@ end Apply the prefixes in the context `ctx` to the variable name `vn`. """ -function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Symbol(Prefix)}()) +function prefix(ctx::PrefixContext, vn::VarName) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end function prefix(ctx::AbstractContext, vn::VarName) return prefix(NodeTrait(ctx), ctx, vn) @@ -277,11 +284,52 @@ function prefix(::IsParent, ctx::AbstractContext, vn::VarName) end """ - prefix(model::Model, x) + prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + +Same as `prefix`, but additionally returns a new context stack that has all the +PrefixContexts removed. -Return `model` but with all random variables prefixed by `x`. +NOTE: This does _not_ modify any variables in any `ConditionContext` and +`FixedContext` that may be present in the context stack. This is because this +function is only used in `tilde_assume`, which is lower in the tilde-pipeline +than `contextual_isassumption` and `contextual_isfixed` (the functions which +actually use the `ConditionContext` and `FixedContext` values). Thus, by this +time, any `ConditionContext`s and `FixedContext`s present have already served +their purpose. -If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing. +If you call this function, you must therefore be careful to ensure that you _do +not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you +_do_ need to modify them, then you may need to use +`prefix_cond_and_fixed_variables` instead. +""" +function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + child_context = childcontext(ctx) + # vn_prefixed contains the prefixes from all lower levels + vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( + child_context, vn + ) + return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes +end +function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) + return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) +end +prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) + vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) + return vn, setchildcontext(ctx, new_ctx) +end + +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. # Examples @@ -291,17 +339,19 @@ julia> using DynamicPPL: prefix julia> @model demo() = x ~ Dirac(1) demo (generic function with 2 methods) -julia> rand(prefix(demo(), :my_prefix)) +julia> rand(prefix(demo(), @varname(my_prefix))) (var"my_prefix.x" = 1,) -julia> # One can also use `Val` to avoid runtime overheads. - rand(prefix(demo(), Val(:my_prefix))) +julia> rand(prefix(demo(), Val(:my_prefix))) (var"my_prefix.x" = 1,) ``` """ -prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context)) -function prefix(model::Model, ::Val{x}) where {x} - return contextualize(model, PrefixContext{Symbol(x)}(model.context)) +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) end """ @@ -370,7 +420,9 @@ Return value of `vn` in `context`. function getconditioned(context::AbstractContext, vn::VarName) return error("context $(context) does not contain value for $vn") end -getconditioned(context::ConditionContext, vn::VarName) = getvalue(context.values, vn) +function getconditioned(context::ConditionContext, vn::VarName) + return getvalue(context.values, vn) +end """ hasconditioned_nested(context, vn) @@ -388,7 +440,7 @@ function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(childcontext(context), prefix(context, vn)) + return hasconditioned_nested(collapse_prefix_stack(context), vn) end """ @@ -406,7 +458,7 @@ function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(childcontext(context), prefix(context, vn)) + return getconditioned_nested(collapse_prefix_stack(context), vn) end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) @@ -476,6 +528,9 @@ function conditioned(context::ConditionContext) # precedence over decendants of `context`. return _merge(context.values, conditioned(childcontext(context))) end +function conditioned(context::PrefixContext) + return conditioned(collapse_prefix_stack(context)) +end struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext values::Values @@ -539,7 +594,7 @@ function hasfixed_nested(::IsParent, context, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(childcontext(context), prefix(context, vn)) + return hasfixed_nested(collapse_prefix_stack(context), vn) end """ @@ -557,7 +612,7 @@ function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(childcontext(context), prefix(context, vn)) + return getfixed_nested(collapse_prefix_stack(context), vn) end function getfixed_nested(::IsParent, context, vn) return if hasfixed(context, vn) @@ -652,3 +707,113 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return _merge(context.values, fixed(childcontext(context))) end +function fixed(context::PrefixContext) + return fixed(collapse_prefix_stack(context)) +end + +""" + collapse_prefix_stack(context::AbstractContext) + +Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove +the `PrefixContext`s from the context stack. + +!!! note + If you are reading this docstring, you might probably be interested in a more +thorough explanation of how PrefixContext and ConditionContext / FixedContext +interact with one another, especially in the context of submodels. + The DynamicPPL documentation contains [a separate page on this +topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) +which explains this in much more detail. + +```jldoctest +julia> using DynamicPPL: collapse_prefix_stack + +julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); + +julia> collapse_prefix_stack(c1) +ConditionContext(Dict(a.x => 1), DefaultContext()) + +julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. + c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); + +julia> collapsed = collapse_prefix_stack(c2); + +julia> # `collapsed` really looks something like this: + # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) + # To avoid fragility arising from the order of the keys in the doctest, we test + # this indirectly: + collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] +(1, 2) +``` +""" +function collapse_prefix_stack(context::PrefixContext) + # Collapse the child context (thus applying any inner prefixes first) + collapsed = collapse_prefix_stack(childcontext(context)) + # Prefix any conditioned variables with the current prefix + # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. + # So is this function. In the worst case scenario, this is O(N^2) in the + # depth of the context stack. + return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) +end +function collapse_prefix_stack(context::AbstractContext) + return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) +end +collapse_prefix_stack(::IsLeaf, context) = context +function collapse_prefix_stack(::IsParent, context) + new_child_context = collapse_prefix_stack(childcontext(context)) + return setchildcontext(context, new_child_context) +end + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. + +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end diff --git a/src/model.jl b/src/model.jl index b4d5f6bb7..c7c4bdf57 100644 --- a/src/model.jl +++ b/src/model.jl @@ -425,29 +425,32 @@ julia> # Returns all the variables we have conditioned on + their values. conditioned(condition(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) -julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0); +julia> # Nested ones also work. + # (Note that `PrefixContext` also prefixes the variables of any + # ConditionContext that is _inside_ it; because of this, the type of the + # container has to be broadened to a `Dict`.) + cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0); -julia> conditioned(cm) -(x = 100.0, m = 1.0) +julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)]) +true -julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, - # `a.m` is treated as a random variable. +julia> # Since we conditioned on `a.m`, it is not treated as a random variable. + # However, `a.x` will still be a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: - a.m - -julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0); +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x -julia> conditioned(cm)[@varname(x)] -100.0 +julia> # We can also condition on `a.m` _outside_ of the PrefixContext: + cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); -julia> conditioned(cm)[@varname(a.m)] -1.0 +julia> conditioned(cm) +Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: + a.m => 1.0 -julia> keys(VarInfo(cm)) # No variables are sampled -VarName[] +julia> # Now `a.x` will be sampled. + keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x ``` """ conditioned(model::Model) = conditioned(model.context) @@ -765,29 +768,27 @@ julia> # Returns all the variables we have fixed on + their values. fixed(fix(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) -julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). - cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0); +julia> # The rest of this is the same as the `condition` example above. + cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); -julia> fixed(cm) -(x = 100.0, m = 1.0) - -julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, - # `a.m` is treated as a random variable. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: - a.m +julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) +true -julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. - cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0); +julia> keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x -julia> fixed(cm)[@varname(x)] -100.0 +julia> # We can also condition on `a.m` _outside_ of the PrefixContext: + cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); -julia> fixed(cm)[@varname(a.m)] -1.0 +julia> fixed(cm) +Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: + a.m => 1.0 -julia> keys(VarInfo(cm)) # <= no variables are sampled -VarName[] +julia> # Now `a.x` will be sampled. + keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x ``` """ fixed(model::Model) = fixed(model.context) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index f6b9c4479..5f1ec95ec 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -223,12 +223,12 @@ end prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) function prefix_submodel_context(prefix, ctx) # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext){$(Symbol)($(esc(prefix)))}($ctx)) + return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx)) end function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) # E.g. `prefix="asd"`. - return :($(PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx)) + return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx)) end function prefix_submodel_context(prefix::Bool, ctx) diff --git a/src/utils.jl b/src/utils.jl index 56c3d70af..71919480c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1286,7 +1286,10 @@ broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) # Convert (x=1,) to Dict(@varname(x) => 1) -_nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) +function to_varname_dict(nt::NamedTuple) + return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt)) +end +to_varname_dict(d::AbstractDict) = d # Version of `merge` used by `conditioned` and `fixed` to handle # the scenario where we might try to merge a dict with an empty # tuple. @@ -1294,9 +1297,9 @@ _nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) _merge(left::AbstractDict, ::NamedTuple{()}) = left -_merge(left::AbstractDict, right::NamedTuple) = merge(left, _nt_to_varname_dict(right)) +_merge(left::AbstractDict, right::NamedTuple) = merge(left, to_varname_dict(right)) _merge(::NamedTuple{()}, right::AbstractDict) = right -_merge(left::NamedTuple, right::AbstractDict) = merge(_nt_to_varname_dict(left), right) +_merge(left::NamedTuple, right::AbstractDict) = merge(to_varname_dict(left), right) """ unique_syms(vns::T) where {T<:NTuple{N,VarName}} diff --git a/test/contexts.jl b/test/contexts.jl index 11e591f8f..1ba099a37 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,5 @@ using Test, DynamicPPL, Accessors +using AbstractPPL: getoptic using DynamicPPL: leafcontext, setleafcontext, @@ -10,12 +11,18 @@ using DynamicPPL: IsParent, PointwiseLogdensityContext, contextual_isassumption, + FixedContext, ConditionContext, decondition_context, hasconditioned, getconditioned, + conditioned, + fixed, hasconditioned_nested, - getconditioned_nested + getconditioned_nested, + collapse_prefix_stack, + prefix_cond_and_fixed_variables, + getvalue using EnzymeCore @@ -50,14 +57,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), :minibatch => MiniBatchContext(DefaultContext(), 0.0), - :prefix => PrefixContext{:x}(DefaultContext()), + :prefix => PrefixContext(@varname(x)), :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), :condition3 => ConditionContext( - (x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(a.y) => 2.0))) + (x=1.0,), + PrefixContext(@varname(a), ConditionContext(Dict(@varname(y) => 2.0))), ), :condition4 => ConditionContext((x=[1.0, missing],)), ) @@ -70,91 +78,52 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "contextual_isassumption" begin - @testset "$(name)" for (name, context) in contexts - # Any `context` should return `true` by default. - @test contextual_isassumption(context, VarName{gensym(:x)}()) - - if any(Base.Fix2(isa, ConditionContext), context) - # We have a `ConditionContext` among us. - # Let's first extract the conditioned variables. - conditioned_values = DynamicPPL.conditioned(context) + @testset "extracting conditioned values" begin + # This testset tests `contextual_isassumption`, `getconditioned_nested`, and + # `hasconditioned_nested`. - # The conditioned values might be a NamedTuple, or a Dict. - # We convert to a Dict for consistency - if conditioned_values isa NamedTuple - conditioned_values = Dict( - VarName{sym}() => val for (sym, val) in pairs(conditioned_values) - ) - end - - for (vn, val) in pairs(conditioned_values) - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = if getoptic(vn) isa PropertyLens - # Hacky: This assumes that there is exactly one level of prefixing - # that we need to undo. This is appropriate for the :condition3 - # test case above, but is not generally correct. - AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) - else - vn - end - - @show DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - # Let's check elementwise. - for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - if getoptic(vn_child)(val) === missing - @test contextual_isassumption(context, vn_child) - else - @test !contextual_isassumption(context, vn_child) - end - end - end - end - end - end - - @testset "getconditioned_nested & hasconditioned_nested" begin - @testset "$name" for (name, context) in contexts + @testset "$(name)" for (name, context) in contexts + # If the varname doesn't exist, it should always be an assumption. fake_vn = VarName{gensym(:x)}() + @test contextual_isassumption(context, fake_vn) @test !hasconditioned_nested(context, fake_vn) @test_throws ErrorException getconditioned_nested(context, fake_vn) if any(Base.Fix2(isa, ConditionContext), context) - # `ConditionContext` specific. - + # We have a `ConditionContext` among us. # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) + # The conditioned values might be a NamedTuple, or a Dict. # We convert to a Dict for consistency - if conditioned_values isa NamedTuple - conditioned_values = Dict( - VarName{sym}() => val for (sym, val) in pairs(conditioned_values) - ) - end + conditioned_values = DynamicPPL.to_varname_dict(conditioned_values) + + # Extract all conditioned variables. We also use varname_leaves + # here to split up arrays which could potentially have some, + # but not all, elements being `missing`. + conditioned_vns = mapreduce( + p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second), + vcat, + pairs(conditioned_values), + ) - for (vn, val) in pairs(conditioned_values) - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = if getoptic(vn) isa PropertyLens - # Hacky: This assumes that there is exactly one level of prefixing - # that we need to undo. This is appropriate for the :condition3 - # test case above, but is not generally correct. - AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + # We can now loop over them to check which ones are missing. We use + # `getvalue` to handle the awkward case where sometimes + # `conditioned_values` contains the full Varname (e.g. `a.x`) and + # sometimes only the main symbol (e.g. it contains `x` when + # `vn` is `x[1]`) + for vn in conditioned_vns + val = DynamicPPL.getvalue(conditioned_values, vn) + # These VarNames are present in the conditioning values, so + # we should always be able to extract the value. + @test hasconditioned_nested(context, vn) + @test getconditioned_nested(context, vn) === val + # However, the return value of contextual_isassumption depends on + # whether the value is missing or not. + if ismissing(val) + @test contextual_isassumption(context, vn) else - vn - end - - for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - # `vn_child` should be in `context`. - @test hasconditioned_nested(context, vn_child) - # Value should be the same as extracted above. - @test getconditioned_nested(context, vn_child) === - getoptic(vn_child)(val) + @test !contextual_isassumption(context, vn) end end end @@ -163,39 +132,68 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "PrefixContext" begin @testset "prefixing" begin - ctx = @inferred PrefixContext{:a}( - PrefixContext{:b}( - PrefixContext{:c}( - PrefixContext{:d}( - PrefixContext{:e}(PrefixContext{:f}(DefaultContext())) + ctx = @inferred PrefixContext( + @varname(a), + PrefixContext( + @varname(b), + PrefixContext( + @varname(c), + PrefixContext( + @varname(d), + PrefixContext( + @varname(e), PrefixContext(@varname(f), DefaultContext()) + ), ), ), ), ) - vn = VarName{:x}() + vn = @varname(x) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x) - vn = VarName{:x}(((1,),)) + vn = @varname(x[1]) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) end @testset "nested within arbitrary context stacks" begin vn = @varname(x[1]) - ctx1 = PrefixContext{:a}(DefaultContext()) + ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) ctx2 = SamplingContext(ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) - ctx3 = PrefixContext{:b}(ctx2) + ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end + @testset "prefix_and_strip_contexts" begin + vn = @varname(x[1]) + ctx1 = PrefixContext(@varname(a)) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == DefaultContext() + + ctx2 = SamplingContext(PrefixContext(@varname(a))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == SamplingContext() + + ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == ConditionContext((a=1,)) + + ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == SamplingContext(ConditionContext((a=1,))) + end + @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - prefix = :my_prefix - context = DynamicPPL.PrefixContext{prefix}(SamplingContext()) + prefix_vn = @varname(my_prefix) + context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) # Sample with the context. varinfo = DynamicPPL.VarInfo() DynamicPPL.evaluate!!(model, varinfo, context) @@ -204,7 +202,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Extract the ground truth varnames vns_expected = Set([ - AbstractPPL.prefix(vn, VarName{prefix}()) for + AbstractPPL.prefix(vn, prefix_vn) for vn in DynamicPPL.TestUtils.varnames(model) ]) @@ -343,4 +341,103 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) end end + + @testset "PrefixContext + Condition/FixedContext interactions" begin + @testset "prefix_cond_and_fixed_variables" begin + c1 = ConditionContext((c=1, d=2)) + c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) + @test c1_prefixed isa ConditionContext + @test childcontext(c1_prefixed) isa DefaultContext + @test c1_prefixed.values[@varname(a.c)] == 1 + @test c1_prefixed.values[@varname(a.d)] == 2 + + c2 = FixedContext((f=1, g=2)) + c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) + @test c2_prefixed isa FixedContext + @test childcontext(c2_prefixed) isa DefaultContext + @test c2_prefixed.values[@varname(a.f)] == 1 + @test c2_prefixed.values[@varname(a.g)] == 2 + + c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2))) + c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) + c3_prefixed_child = childcontext(c3_prefixed) + @test c3_prefixed isa ConditionContext + @test c3_prefixed.values[@varname(a.c)] == 1 + @test c3_prefixed.values[@varname(a.d)] == 2 + @test c3_prefixed_child isa FixedContext + @test c3_prefixed_child.values[@varname(a.f)] == 1 + @test c3_prefixed_child.values[@varname(a.g)] == 2 + @test childcontext(c3_prefixed_child) isa DefaultContext + end + + @testset "collapse_prefix_stack" begin + # Utility function to make sure that there are no PrefixContexts in + # the context stack. + function has_no_prefixcontexts(ctx::AbstractContext) + return !(ctx isa PrefixContext) && ( + NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) + ) + end + + # Prefix -> Condition + c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) + c1 = collapse_prefix_stack(c1) + @test has_no_prefixcontexts(c1) + c1_vals = conditioned(c1) + @test length(c1_vals) == 2 + @test getvalue(c1_vals, @varname(a.c)) == 1 + @test getvalue(c1_vals, @varname(a.d)) == 2 + + # Condition -> Prefix + c2 = ConditionContext((c=1, d=2), PrefixContext(@varname(a))) + c2 = collapse_prefix_stack(c2) + @test has_no_prefixcontexts(c2) + c2_vals = conditioned(c2) + @test length(c2_vals) == 2 + @test getvalue(c2_vals, @varname(c)) == 1 + @test getvalue(c2_vals, @varname(d)) == 2 + + # Prefix -> Fixed + c3 = PrefixContext(@varname(a), FixedContext((f=1, g=2))) + c3 = collapse_prefix_stack(c3) + c3_vals = fixed(c3) + @test length(c3_vals) == 2 + @test length(c3_vals) == 2 + @test getvalue(c3_vals, @varname(a.f)) == 1 + @test getvalue(c3_vals, @varname(a.g)) == 2 + + # Fixed -> Prefix + c4 = FixedContext((f=1, g=2), PrefixContext(@varname(a))) + c4 = collapse_prefix_stack(c4) + @test has_no_prefixcontexts(c4) + c4_vals = fixed(c4) + @test length(c4_vals) == 2 + @test getvalue(c4_vals, @varname(f)) == 1 + @test getvalue(c4_vals, @varname(g)) == 2 + + # Prefix -> Condition -> Prefix -> Condition + c5 = PrefixContext( + @varname(a), + ConditionContext( + (c=1,), PrefixContext(@varname(b), ConditionContext((d=2,))) + ), + ) + c5 = collapse_prefix_stack(c5) + @test has_no_prefixcontexts(c5) + c5_vals = conditioned(c5) + @test length(c5_vals) == 2 + @test getvalue(c5_vals, @varname(a.c)) == 1 + @test getvalue(c5_vals, @varname(a.b.d)) == 2 + + # Prefix -> Condition -> Prefix -> Fixed + c6 = PrefixContext( + @varname(a), + ConditionContext((c=1,), PrefixContext(@varname(b), FixedContext((d=2,)))), + ) + c6 = collapse_prefix_stack(c6) + @test has_no_prefixcontexts(c6) + @test conditioned(c6) == Dict(@varname(a.c) => 1) + @test fixed(c6) == Dict(@varname(a.b.d) => 2) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 3473d5594..72f33f2d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -67,6 +67,7 @@ include("test_util.jl") include("threadsafe.jl") include("debug_utils.jl") include("deprecated.jl") + include("submodels.jl") end if GROUP == "All" || GROUP == "Group2" diff --git a/test/submodels.jl b/test/submodels.jl new file mode 100644 index 000000000..e79eed2c3 --- /dev/null +++ b/test/submodels.jl @@ -0,0 +1,199 @@ +module DPPLSubmodelTests + +using DynamicPPL +using Distributions +using Test + +@testset "submodels.jl" begin + @testset "$op with AbstractPPL API" for op in [condition, fix] + x_val = 1.0 + x_logp = op == condition ? logpdf(Normal(), x_val) : 0.0 + + @testset "Auto prefix" begin + @model function inner() + x ~ Normal() + y ~ Normal() + return (x, y) + end + @model function outer() + return a ~ to_submodel(inner()) + end + inner_op = op(inner(), (@varname(x) => x_val)) + @model function outer2() + return a ~ to_submodel(inner_op) + end + with_inner_op = outer2() + with_outer_op = op(outer(), (@varname(a.x) => x_val)) + + # No conditioning/fixing + @test Set(keys(VarInfo(outer()))) == Set([@varname(a.x), @varname(a.y)]) + + # With conditioning/fixing + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(a.y)]) + end + end + + @testset "No prefix" begin + @model function inner() + x ~ Normal() + y ~ Normal() + return (x, y) + end + @model function outer() + return a ~ to_submodel(inner(), false) + end + @model function outer2() + return a ~ to_submodel(inner_op, false) + end + with_inner_op = outer2() + inner_op = op(inner(), (@varname(x) => x_val)) + with_outer_op = op(outer(), (@varname(x) => x_val)) + + # No conditioning/fixing + @test Set(keys(VarInfo(outer()))) == Set([@varname(x), @varname(y)]) + + # With conditioning/fixing + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(y)]) + end + end + + @testset "Manual prefix" begin + @model function inner() + x ~ Normal() + y ~ Normal() + return (x, y) + end + @model function outer() + return a ~ to_submodel(prefix(inner(), :b), false) + end + inner_op = op(inner(), (@varname(x) => x_val)) + @model function outer2() + return a ~ to_submodel(prefix(inner_op, :b), false) + end + with_inner_op = outer2() + with_outer_op = op(outer(), (@varname(b.x) => x_val)) + + # No conditioning/fixing + @test Set(keys(VarInfo(outer()))) == Set([@varname(b.x), @varname(b.y)]) + + # With conditioning/fixing + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(b.y)]) + end + end + + @testset "Complex prefixes" begin + mutable struct P + a::Float64 + b::Float64 + end + @model function f() + x = Vector{Float64}(undef, 1) + x[1] ~ Normal() + y ~ Normal() + return x[1] + end + @model function g() + p = P(1.0, 2.0) + p.a ~ to_submodel(f()) + p.b ~ Normal() + return (p.a, p.b) + end + expected_vns = Set([@varname(p.a.x[1]), @varname(p.a.y), @varname(p.b)]) + @test Set(keys(VarInfo(g()))) == expected_vns + + # Check that we can condition/fix on any of them from the outside + for vn in expected_vns + op_g = op(g(), (vn => 1.0)) + vi = VarInfo(op_g) + @test Set(keys(vi)) == symdiff(expected_vns, Set([vn])) + end + end + + @testset "Nested submodels" begin + @model function f() + x ~ Normal() + return y ~ Normal() + end + @model function g() + return _unused ~ to_submodel(prefix(f(), :b), false) + end + @model function h() + return a ~ to_submodel(g()) + end + + # No conditioning + vi = VarInfo(h()) + @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) + @test getlogp(vi) == + logpdf(Normal(), vi[@varname(a.b.x)]) + + logpdf(Normal(), vi[@varname(a.b.y)]) + + # Conditioning/fixing at the top level + op_h = op(h(), (@varname(a.b.x) => x_val)) + + # Conditioning/fixing at the second level + op_g = op(g(), (@varname(b.x) => x_val)) + @model function h2() + return a ~ to_submodel(op_g) + end + + # Conditioning/fixing at the very bottom + op_f = op(f(), (@varname(x) => x_val)) + @model function g2() + return _unused ~ to_submodel(prefix(op_f, :b), false) + end + @model function h3() + return a ~ to_submodel(g2()) + end + + models = [("top", op_h), ("middle", h2()), ("bottom", h3())] + @testset "$name" for (name, model) in models + vi = VarInfo(model) + @test Set(keys(vi)) == Set([@varname(a.b.y)]) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + end + end + end + + @testset "conditioning via model arguments" begin + @model function f(x) + x ~ Normal() + return y ~ Normal() + end + @model function g(inner_x) + return a ~ to_submodel(f(inner_x)) + end + + vi = VarInfo(g(1.0)) + @test Set(keys(vi)) == Set([@varname(a.y)]) + + vi = VarInfo(g(missing)) + @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + end +end + +end