diff --git a/HISTORY.md b/HISTORY.md index 6b7247c8d..59030e600 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,74 @@ **Breaking** +### `.~` right hand side must be a univariate distribution + +Previously we allowed statements like + +```julia +x .~ [Normal(), Gamma()] +``` + +where the right hand side of a `.~` was an array of distributions, and ones like + +```julia +x .~ MvNormal(fill(0.0, 2), I) +``` + +where the right hand side was a multivariate distribution. + +These are no longer allowed. The only things allowed on the right hand side of a `.~` statement are univariate distributions, such as + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +x .~ Normal() +``` + +The reasons for this are internal code simplification and the fact that broadcasting where both sides are multidimensional but of different dimensions is typically confusing to read. + +If the right hand side and the left hand side have the same dimension, one can simply use `~`. Arrays of distributions can be replaced with `product_distribution`. So instead of + +```julia +x .~ [Normal(), Gamma()] +x .~ Normal.(y) +x .~ MvNormal(fill(0.0, 2), I) +``` + +do + +```julia +x ~ product_distribution([Normal(), Gamma()]) +x ~ product_distribution(Normal.(y)) +x ~ MvNormal(fill(0.0, 2), I) +``` + +This is often more performant as well. Note that using `~` rather than `.~` slightly changes the internal storage format: with `.~`, each `x[i]` is stored as a separate variable, whereas with `~`, `x` is stored as a single multivariate variable. In most cases this does not change anything for the user, but if it does cause issues, e.g. if you are dealing with `VarInfo` objects directly and need to keep the old behavior, you can always expand into a loop, such as + +```julia +dists = Normal.(y) +for i in 1:length(dists) + x[i] ~ dists[i] +end +``` + +Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. For example, + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +x .~ MvNormal(fill(0, 2), I) +``` + +should be replaced with something like + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +for i in 1:3, j in 1:4 + x[:, i, j] ~ MvNormal(fill(0, 2), I) +end +``` + +This release also completely rewrites the internal implementation of `.~`: from now on, all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side. + ### Remove indexing by samplers This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular, @@ -14,7 +82,7 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `unflatten` no longer accepts a sampler as an argument - `eltype(::VarInfo)` no longer accepts a sampler as an argument - `keys(::VarInfo)` no longer accepts a sampler as an argument - - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. + - `VarInfo(::VarInfo, ::Sampler, ::AbstractVector)` no longer accepts the sampler argument. 
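+For reference, a minimal sketch of the sampler-free forms of these calls (the model and variable names here are illustrative only):
+
+```julia
+using DynamicPPL, Distributions
+
+@model demo() = x ~ Normal()
+vi = VarInfo(demo())
+
+keys(vi)                              # all `VarName`s, no sampler argument
+vals = vi[:]                          # flat vector of all values
+vi2 = DynamicPPL.unflatten(vi, vals)  # rebuild a `VarInfo` from the flat vector
+```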
### Reverse prefixing order diff --git a/Project.toml b/Project.toml index 38382f98f..bb67cfb07 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -35,7 +34,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] @@ -44,7 +42,6 @@ DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] -DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] ADTypes = "1" @@ -74,5 +71,4 @@ OrderedCollections = "1" Random = "1.6" Requires = "1" Test = "1.6" -ZygoteRules = "0.2" julia = "1.10" diff --git a/docs/src/api.md b/docs/src/api.md index 6c58264fe..f463c50ef 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -447,10 +447,8 @@ DynamicPPL.Experimental.is_suitable_varinfo ```@docs tilde_assume -dot_tilde_assume ``` ```@docs tilde_observe -dot_tilde_observe ``` diff --git a/ext/DynamicPPLZygoteRulesExt.jl b/ext/DynamicPPLZygoteRulesExt.jl deleted file mode 100644 index 78831fdc4..000000000 --- a/ext/DynamicPPLZygoteRulesExt.jl +++ /dev/null @@ -1,25 +0,0 @@ -module DynamicPPLZygoteRulesExt - -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL, Distributions - using ZygoteRules: ZygoteRules -else - using ..DynamicPPL: DynamicPPL, Distributions - using ..ZygoteRules: ZygoteRules -end - -# https://github.com/TuringLang/Turing.jl/issues/1595 -ZygoteRules.@adjoint function DynamicPPL.dot_observe( - spl::Union{DynamicPPL.SampleFromPrior,DynamicPPL.SampleFromUniform}, - dists::AbstractArray{<:Distributions.Distribution}, - value::AbstractArray, - vi, -) - function dot_observe_fallback(spl, dists, value, vi) - DynamicPPL.increment_num_produce!(vi) - return sum(map(Distributions.loglikelihood, dists, value)), vi - end - return ZygoteRules.pullback(dot_observe_fallback, __context__, spl, dists, value, vi) -end - -end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 55e1f7e88..0559da3ef 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -101,13 +101,9 @@ export AbstractVarInfo, PrefixContext, ConditionContext, assume, - dot_assume, observe, - dot_observe, tilde_assume, tilde_observe, - dot_tilde_assume, - dot_tilde_observe, # Pseudo distributions NamedDist, NoDist, diff --git a/src/compiler.jl b/src/compiler.jl index 8743641af..0bf5608b9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -161,7 +161,16 @@ Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` other """ isliteral(e) = false isliteral(::Number) = true -isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args) +function isliteral(e::Expr) + # In the special case that the expression is of the form `abc[blahblah]`, we consider it + # to be a literal if `abc` is a literal. This is necessary for cases like + # [1.0, 2.0][idx...] ~ Normal() + # which are generated when turning `.~` expressions into loops over `~` expressions. 
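+    # For example, with this extension `isliteral(:([1.0, 2.0][1]))` returns `true`,
+    # while `isliteral(:(x[1]))` still returns `false` because `x` itself is not a literal.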
+ if e.head == :ref + return isliteral(e.args[1]) + end + return !isempty(e.args) && all(isliteral, e.args) +end """ check_tilde_rhs(x) @@ -172,18 +181,40 @@ Check if the right-hand side `x` of a `~` is a `Distribution` or an array of function check_tilde_rhs(@nospecialize(x)) return throw( ArgumentError( - "the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s", + "the right-hand side of a `~` must be a `Distribution`, an array of `Distribution`s, or a submodel", ), ) end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x +check_tilde_rhs(x::Model) = x check_tilde_rhs(x::ReturnedModelWrapper) = x function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} model = check_tilde_rhs(x.model) return Sampleable{typeof(model),AutoPrefix}(model) end +""" + check_dot_tilde_rhs(x) + +Check if the right-hand side `x` of a `.~` is a `UnivariateDistribution`, then return `x`. +""" +function check_dot_tilde_rhs(@nospecialize(x)) + return throw( + ArgumentError("the right-hand side of a `.~` must be a `UnivariateDistribution`") + ) +end +function check_dot_tilde_rhs(::AbstractArray{<:Distribution}) + msg = """ + As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ + Please use `product_distribution` instead, or write a loop if necessary. \ + See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \ + details.\ + """ + return throw(ArgumentError(msg)) +end +check_dot_tilde_rhs(x::UnivariateDistribution) = x + """ unwrap_right_vn(right, vn) @@ -356,11 +387,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn) args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde - return Base.remove_linenums!( - generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), - ), + return generate_mainbody!( + mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn ) end @@ -368,12 +396,25 @@ function generate_mainbody!(mod, found, expr::Expr, warn) args_tilde = getargs_tilde(expr) if args_tilde !== nothing L, R = args_tilde - return Base.remove_linenums!( - generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), - ), - ) + # Check for a ~ b --> c + args_longrightarrow = getargs_longrightarrow(R) + if args_longrightarrow !== nothing + M, R = args_longrightarrow + return Base.remove_linenums!( + generate_tilde_longrightarrow( + generate_mainbody!(mod, found, L, warn), + generate_mainbody!(mod, found, M, warn), + generate_mainbody!(mod, found, R, warn), + ), + ) + else + return Base.remove_linenums!( + generate_tilde( + generate_mainbody!(mod, found, L, warn), + generate_mainbody!(mod, found, R, warn), + ), + ) + end end # Modify the assignment operators. @@ -439,7 +480,7 @@ function generate_tilde(left, right) if $(DynamicPPL.isfixed(left, vn)) $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) elseif $isassumption - $(generate_tilde_assume(left, dist, vn)) + $(generate_tilde_assume(left, dist, vn, nothing)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) @@ -458,12 +499,48 @@ function generate_tilde(left, right) end end -function generate_tilde_assume(left, right, vn) +""" + generate_tilde_longrightarrow(left, middle, right) + +Generate the expression that replaces `left ~ middle --> right` in the model body. 
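+
+For example, in a model body, `x ~ inner() --> y` evaluates the submodel `inner()`, binds the values of its random variables (as an `OrderedDict` keyed by `VarName`) to `x`, and binds the submodel's return value to `y` (see `tilde_assume!!`).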
+""" +function generate_tilde_longrightarrow(left, middle, right) + isliteral(left) && error("Observing `a` is not supported in `a ~ b --> c`") # TODO + + @gensym vn isassumption model retval + + return quote + $model = $middle + $vn = $(DynamicPPL.resolve_varnames)( + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $model + ) + $isassumption = $(DynamicPPL.isassumption(left, vn)) + if $(DynamicPPL.isfixed(left, vn)) + error("Fixing `a` is not supported in `a ~ b --> c`") # TODO + elseif $isassumption + $(generate_tilde_assume(left, model, vn, right)) + else + error("Observing `a` is not supported in `a ~ b --> c`") # TODO + end + end +end + +function generate_tilde_assume(left, dist_or_model, vn, maybe_right) # HACK: Because the Setfield.jl macro does not support assignment # with multiple arguments on the LHS, we need to capture the return-values # and then update the LHS variables one by one. + @gensym value - expr = :($left = $value) + + has_right = maybe_right !== nothing + expr = if has_right + :(($left, $maybe_right) = $value) + else + :($left = $value) + end + + # TODO(penelopeysm): What does this line even do? Not sure if I need to modify it for the + # a ~ b --> c case. if left isa Expr expr = AbstractPPL.drop_escape( Accessors.setmacro(BangBang.prefermutation, expr; overwrite=true) @@ -473,8 +550,11 @@ function generate_tilde_assume(left, right, vn) return quote $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( __context__, - $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($dist_or_model), $vn + )..., __varinfo__, + $has_right, ) $expr $value @@ -487,56 +567,16 @@ end Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) - isliteral(left) && return generate_tilde_literal(left, right) - - # Otherwise it is determined by the model or its value, - # if the LHS represents an observation - @gensym vn isassumption value + @gensym dist left_axes idx return quote - $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right - ) - $isassumption = $(DynamicPPL.isassumption(left, vn)) - if $(DynamicPPL.isfixed(left, vn)) - $left .= $(DynamicPPL.getfixed_nested)(__context__, $vn) - elseif $isassumption - $(generate_dot_tilde_assume(left, right, vn)) - else - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn) - end - - $value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $(maybe_view(left)), - $vn, - __varinfo__, - ) - $value + $dist = DynamicPPL.check_dot_tilde_rhs($right) + $left_axes = axes($left) + for $idx in Iterators.product($left_axes...) + $left[$idx...] ~ $dist end end end -function generate_dot_tilde_assume(left, right, vn) - # We don't need to use `Setfield.@set` here since - # `.=` is always going to be inplace + needs `left` to - # be something that supports `.=`. - @gensym value - return quote - $value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( - __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn - )..., - __varinfo__, - ) - $left .= $value - $value - end -end - # Note that we cannot use `MacroTools.isdef` because # of https://github.com/FluxML/MacroTools.jl/issues/154. 
""" diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 462012676..013dd8d69 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -95,24 +95,141 @@ end """ tilde_assume!!(context, right, vn, vi) - Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value and updated `vi`. By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_assume!!(context, right, vn, vi) - return if is_rhs_model(right) +function tilde_assume!!(context, dist_or_model, vn, vi, has_right) + if dist_or_model isa DynamicPPL.Model + # Forbid things like x.a ~ submodel or x[i] ~ submodel + # TODO(penelopeysm): This restriction is not really necessary and could + # be hurtful (say if someone wants to evaluate a submodel in a loop). + # It is not very difficult to lift this restriction, we just have to + # let `prefix` and `unprefix` handle cases with both a sym + optic, + # instead of just the sym as it is right now. + getoptic(vn) !== identity && + error("cannot use e.g. x.a ~ submodel, lhs must be a single identifier") + # Evaluate the inner model with the appropriate context + # NOTE: usage of _evaluate!! instead of evaluate!! is intentional. The + # version without the underscore resets logp before evaluation. + retval, vi = DynamicPPL._evaluate!!( + dist_or_model, vi, PrefixContext{getsym(vn)}(context) + ) + + #= + NOTE(penelopeysm): Why do we use OrderedDict as the output type here? + Didn't we want to use NamedTuple? + + Well, it turns out that values_as(vi, NamedTuple) has one annoying problem. + + Consider the following model: + + using DynamicPPL, Distributions + @model function inner() + x = (a=1, b=2) + x.a ~ Normal() + x.b ~ Normal() + end + values_as(VarInfo(inner()), NamedTuple) + + Now, the varinfo contains the varnames `@varname(x.a)` and `@varname(x.b)` + (with the correct representation, i.e. it knows that `a` is a field of `x` + and `b` is a field of `x`). So, you might expect to get this from values_as(): + + (x = (a = f1, b = f2),) + + where `f1` and `f2` are the values sampled for `x.a` and `x.b`, respectively. + If this were the case, it would then be quite easy to insert some code into + the compiler that looked like + + retval = values_as(VarInfo(inner()), NamedTuple) + x = retval.x + + Unfortunately, that's not how values_as works. We actually get this: + + (var"x.a" = f1, var"x.b" = f2) + + The fundamental reason for this is because the varinfo does not store any + information about the full structure of `x`. For example, it doesn't know if + `x` is a NamedTuple or a struct, and it doesn't know what other fields/keys `x` + might possibly have. So, it doesn't attempt to reconstruct the full structure + of `x` when converting to a NamedTuple. Instead, it just converts the varnames + into strings that can be used as keys in the NamedTuple. + + This inability to reproduce the correct structure of internal variables needs + to be fixed before we can consider using NamedTuple as the output type. I have + opened an issue here: https://github.com/TuringLang/DynamicPPL.jl/issues/814 + + However, my suspicion is that it cannot be fixed. The only way to be completely + safe is to stick to using a dictionary structure (it doesn't necessarily have + to be OrderedDict, but the rest of DynamicPPL uses it so we may as well stick + with it). 
+ + The good news about OrderedDict is that it is a perfectly natural way to + represent the result of a model. In particular, we have the following parallels: + + UnivariateDistribution ==> Float + MultivariateDistribution ==> Vector{Float} + MatrixDistribution ==> Matrix{Float} + Model ==> OrderedDict(VarName => Any) + + where the right-hand side type represents the value obtained by sampling from + something on the left-hand side. Furthermore, in much the same way we can + calculate + + logpdf(Normal(), 1.0) + + we already have all the machinery needed to calculate + + logpdf(model, dict), + + and thus the implementation of submodel `observe` should not be very onerous. + In fact, I think it basically boils down to wrapping `dict` in a + ConditionContext and calling exactly the same code as we do here. + + Note that the same cannot be said of NamedTuple: we cannot, in general, + calculate + + logpdf(model, nt) + + or condition on a NamedTuple, because of the reasons described above. In + fact, writing this makes me think that we should really just get rid of all + NamedTuple stuff internally. It would substantially reduce the number of + headaches we get about models with non-trivial variable structures. + =# + + # Get all the keys that have the correct symbol + new_keys = collect(filter(k -> getsym(k) == getsym(vn), keys(vi))) + new_values_prefixed = values_as(subset(vi, new_keys), OrderedDict) + # Get rid of the prefix + # TODO(penelopeysm): Note that this does not yet work correctly for + # nested submodels (see the failing tests). To deal with that + # correctly, we have to also take into account any prefixes that have + # been applied in the _current_ parent context. + new_values_unprefixed = OrderedDict(( + unprefix_outer_layer(vn) => val for (vn, val) in new_values_prefixed + )) + return if has_right + (new_values_unprefixed, retval), vi + else + new_values_unprefixed, vi + end + elseif is_rhs_model(dist_or_model) # Prefix the variables using the `vn`. - rand_like!!( - right, - should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context, + return rand_like!!( + dist_or_model, + if should_auto_prefix(dist_or_model) + PrefixContext{Symbol(vn)}(context) + else + context + end, vi, ) else - value, logp, vi = tilde_assume(context, right, vn, vi) - value, acclogp_assume!!(context, vi, logp) + value, logp, vi = tilde_assume(context, dist_or_model, vn, vi) + return value, acclogp_assume!!(context, vi, logp) end end @@ -258,384 +375,3 @@ function observe(right::Distribution, left, vi) increment_num_produce!(vi) return Distributions.loglikelihood(right, left), vi end - -# .~ functions - -# assume -""" - dot_tilde_assume(context::SamplingContext, right, left, vn, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value for a context -associated with a sampler. - -Falls back to -```julia -dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, vi) -``` -""" -function dot_tilde_assume(context::SamplingContext, right, left, vn, vi) - return dot_tilde_assume( - context.rng, context.context, context.sampler, right, left, vn, vi - ) -end - -# `DefaultContext` -function dot_tilde_assume(context::AbstractContext, args...) - return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...) -end -function dot_tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) 
- return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), rng, context, args...) -end - -function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) - return dot_assume(right, left, vns, vi) -end -function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi) - return dot_assume(rng, sampler, right, vns, left, vi) -end - -function dot_tilde_assume(::IsParent, context::AbstractContext, args...) - return dot_tilde_assume(childcontext(context), args...) -end -function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...) - return dot_tilde_assume(rng, childcontext(context), args...) -end - -function dot_tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, left, vns, vi -) - return dot_assume(rng, sampler, right, vns, left, vi) -end - -# `LikelihoodContext` -function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(nodist(right), left, vn, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi -) - return dot_assume(rng, sampler, nodist(right), vn, left, vi) -end - -# `PrefixContext` -function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) - return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) -end - -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi -) - return dot_tilde_assume( - rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi - ) -end - -""" - dot_tilde_assume!!(context, right, left, vn, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value and updated `vi`. - -Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. -""" -function dot_tilde_assume!!(context, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`.~` with a model on the right-hand side is not supported; please use `~`" - ), - ) - value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp_assume!!(context, vi, logp) -end - -# `dot_assume` -function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, -) - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. 
- r = vi[vns, dist] - lp = sum(zip(vns, eachcol(r))) do (vn, ri) - return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) - end - return r, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - vns::AbstractVector{<:VarName}, - var::AbstractMatrix, - vi::AbstractVarInfo, -) - @assert length(dist) == size(var, 1) - r = get_and_set_val!(rng, vi, vns, dist, spl) - lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - return r, lp, vi -end - -function dot_assume( - dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi -) - r = getindex.((vi,), vns, (dist,)) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) - return r, lp, vi -end - -function dot_assume( - dists::AbstractArray{<:Distribution}, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi, -) - r = getindex.((vi,), vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) - return r, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::AbstractVarInfo, -) - r = get_and_set_val!(rng, vi, vns, dists, spl) - # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) - return r, lp, vi -end -function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement" - ) -end - -# HACK: These methods are only used in the `get_and_set_val!` methods below. -# FIXME: Remove these. -function _link_broadcast_new(vi, vn, dist, r) - b = to_linked_internal_transform(vi, vn, dist) - return b(r) -end - -function _maybe_invlink_broadcast(vi, vn, dist) - xvec = getindex_internal(vi, vn) - b = from_maybe_linked_internal_transform(vi, vn, dist) - return b(xvec) -end - -function get_and_set_val!( - rng, - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractVector{<:VarName}, - dist::MultivariateDistribution, - spl::Union{SampleFromPrior,SampleFromUniform}, -) - n = length(vns) - if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if - # that's okay. - unset_flag!(vi, vns[1], "del", true) - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) - setindex!!(vi, f_link_maybe(r[:, i]), vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - r = vi[vns, dist] - end - else - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - if istrans(vi) - ri_linked = _link_broadcast_new(vi, vn, dist, r[:, i]) - push!!(vi, vn, ri_linked, dist, spl) - # `push!!` sets the trans-flag to `false` by default. - settrans!!(vi, true, vn) - else - push!!(vi, vn, r[:, i], dist, spl) - end - end - end - return r -end - -function get_and_set_val!( - rng, - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractArray{<:VarName}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - spl::Union{SampleFromPrior,SampleFromUniform}, -) - if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. 
- if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if - # that's okay. - unset_flag!(vi, vns[1], "del", true) - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - for i in eachindex(vns) - vn = vns[i] - dist = dists isa AbstractArray ? dists[i] : dists - f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) - setindex!!(vi, f_link_maybe(r[i]), vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - rs = _maybe_invlink_broadcast.((vi,), vns, dists) - r = reshape(rs, size(vns)) - end - else - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - # TODO: This will inefficient since it will allocate an entire vector. - # We could either: - # 1. Figure out the broadcast size and use a `foreach`. - # 2. Define an anonymous function which returns `nothing`, which - # we then broadcast. This will allocate a vector of `nothing` though. - if istrans(vi) - push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,)) - # NOTE: Need to add the correction. - # FIXME: This is not great. - acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) - # `push!!` sets the trans-flag to `false` by default. - settrans!!.((vi,), true, vns) - else - push!!.((vi,), vns, r, dists, (spl,)) - end - end - return r -end - -function set_val!( - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractVector{<:VarName}, - dist::MultivariateDistribution, - val::AbstractMatrix, -) - @assert size(val, 2) == length(vns) - foreach(enumerate(vns)) do (i, vn) - setindex!!(vi, val[:, i], vn) - end - return val -end -function set_val!( - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractArray{<:VarName}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - val::AbstractArray, -) - @assert size(val) == size(vns) - foreach(CartesianIndices(val)) do ind - setindex!!(vi, tovec(val[ind]), vns[ind]) - end - return val -end - -# observe -""" - dot_tilde_observe(context::SamplingContext, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value for a context associated with a sampler. - -Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. -""" -function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.context, context.sampler, right, left, vi) -end - -# Leaf contexts -function dot_tilde_observe(context::AbstractContext, args...) - return dot_tilde_observe(NodeTrait(tilde_observe, context), context, args...) -end -dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...) -function dot_tilde_observe(::IsParent, context::AbstractContext, args...) - return dot_tilde_observe(childcontext(context), args...) 
-end - -dot_tilde_observe(::PriorContext, right, left, vi) = 0, vi -dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi - -# `MiniBatchContext` -function dot_tilde_observe(context::MiniBatchContext, right, left, vi) - logp, vi = dot_tilde_observe(context.context, right, left, vi) - return context.loglike_scalar * logp, vi -end - -# `PrefixContext` -function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vi) -end - -""" - dot_tilde_observe!!(context, right, left, vname, vi) - -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value and updated `vi`. - -Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the information about variable -name and indices; if needed, these can be accessed through this function, though. -""" -function dot_tilde_observe!!(context, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - return dot_tilde_observe!!(context, right, left, vi) -end - -""" - dot_tilde_observe!!(context, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value and updated `vi`. - -Falls back to `dot_tilde_observe(context, right, left, vi)`. -""" -function dot_tilde_observe!!(context, right, left, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - logp, vi = dot_tilde_observe(context, right, left, vi) - return left, acclogp_observe!!(context, vi, logp) -end - -# Falls back to non-sampler definition. 
-function dot_observe(::AbstractSampler, dist, value, vi) - return dot_observe(dist, value, vi) -end -function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value), vi -end -function dot_observe(dists::Distribution, value::AbstractArray, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(dists, value), vi -end -function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) - increment_num_produce!(vi) - return sum(Distributions.loglikelihood.(dists, value)), vi -end diff --git a/src/contexts.jl b/src/contexts.jl index 0b4633283..ff75ff412 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -259,21 +259,33 @@ function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} return PrefixContext{Prefix}(child) end -const PREFIX_SEPARATOR = Symbol(".") +function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym} + return VarName{sym}() +end +function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} + return optic_to_vn(o.outer) +end +function optic_to_vn( + o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}} +) where {Outer,sym} + return VarName{sym}(o.outer) +end +function optic_to_vn(@nospecialize(o)) + return error("optic_to_vn failed with optic $o") +end -@generated function PrefixContext{PrefixOuter}( - context::PrefixContext{PrefixInner} -) where {PrefixOuter,PrefixInner} - return :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - context.context - )) +function unprefix_outer_layer(vn::VarName{sym}) where {sym} + return optic_to_vn(getoptic(vn)) end function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - vn_prefixed_inner = prefix(childcontext(ctx), vn) - return VarName{Symbol(Prefix, PREFIX_SEPARATOR, getsym(vn_prefixed_inner))}( - getoptic(vn_prefixed_inner) - ) + optic = getoptic(vn) + new_optic = if optic === identity + Accessors.PropertyLens{Sym}() + else + Base.ComposedFunction(optic, Accessors.PropertyLens{Sym}()) + end + return VarName{Symbol(Prefix)}(new_optic) end prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn) prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 43b5054d5..328fe6983 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -113,52 +113,6 @@ function Base.show(io::IO, stmt::ObserveStmt) return print(io, ")") end -Base.@kwdef struct DotAssumeStmt <: Stmt - varname - left - right - value - logp - varinfo = nothing -end - -function Base.show(io::IO, stmt::DotAssumeStmt) - io = add_io_context(io) - print(io, " assume: ") - show_varname(io, stmt.varname) - print(io, " = ") - print(io, stmt.left) - print(io, " .~ ") - show_right(io, stmt.right) - print(io, " ") - print(io, RESULT_SYMBOL) - print(io, " ") - print(io, stmt.value) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") -end - -Base.@kwdef struct DotObserveStmt <: Stmt - left - right - logp - varinfo = nothing -end - -function Base.show(io::IO, stmt::DotObserveStmt) - io = add_io_context(io) - print(io, "observe: ") - print(io, stmt.left) - print(io, " .~ ") - show_right(io, stmt.right) - print(io, " ") - print(io, RESULT_SYMBOL) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") -end - # Some utility methods for extracting information from a trace. 
""" varnames_in_trace(trace) @@ -168,24 +122,14 @@ Return all the varnames present in the trace. varnames_in_trace(trace::AbstractVector) = mapreduce(varnames_in_stmt, vcat, trace) varnames_in_stmt(stmt::AssumeStmt) = [stmt.varname] -function varnames_in_stmt(stmt::DotAssumeStmt) - return stmt.varname isa VarName ? [stmt.varname] : stmt.varname -end varnames_in_stmt(::ObserveStmt) = [] -varnames_in_stmt(::DotObserveStmt) = [] function distributions_in_trace(trace::AbstractVector) return mapreduce(distributions_in_stmt, vcat, trace) end distributions_in_stmt(stmt::AssumeStmt) = [stmt.right] -function distributions_in_stmt(stmt::DotAssumeStmt) - return stmt.right isa AbstractArray ? vec(stmt.right) : [stmt.right] -end distributions_in_stmt(stmt::ObserveStmt) = [stmt.right] -function distributions_in_stmt(stmt::DotObserveStmt) - return stmt.right isa AbstractArray ? vec(stmt.right) : [stmt.right] -end """ DebugContext <: AbstractContext @@ -382,95 +326,6 @@ function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, v return logp, vi end -# dot-assume -function record_pre_dot_tilde_assume!(context::DebugContext, vn, left, right, varinfo) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - error( - "Variable $(vn) has missing has missing value(s)!\n" * - "Usage of `missing` is not supported for dotted syntax, such as " * - "`@. x ~ dist` or `x .~ dist`", - ) - end - - # TODO: Can we do without the memory allocation here? - record_varname!.(broadcast_safe(context), vn, broadcast_safe(right)) - - # Check that `left` does not contain any `` - return nothing -end - -function record_post_dot_tilde_assume!( - context::DebugContext, vns, left, right, value, logp, varinfo -) - stmt = DotAssumeStmt(; - varname=vns, - left=left, - right=right, - value=value, - logp=logp, - varinfo=context.record_varinfo ? deepcopy(varinfo) : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - - return nothing -end - -function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi) - record_pre_dot_tilde_assume!(context, vn, left, right, vi) - value, logp, vi = DynamicPPL.dot_tilde_assume( - childcontext(context), right, left, vn, vi - ) - record_post_dot_tilde_assume!(context, vn, left, right, value, logp, vi) - return value, logp, vi -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi -) - record_pre_dot_tilde_assume!(context, vn, left, right, vi) - value, logp, vi = DynamicPPL.dot_tilde_assume( - rng, childcontext(context), sampler, right, left, vn, vi - ) - record_post_dot_tilde_assume!(context, vn, left, right, value, logp, vi) - return value, logp, vi -end - -# dot-observe -function record_pre_dot_tilde_observe!(context::DebugContext, left, right, vi) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - # TODO: Once `observe` statements receive `vn`, refer to this in the - # error message. - error( - "Encountered missing value(s) in observe!\n" * - "Usage of `missing` is not supported for dotted syntax, such as " * - "`@. x ~ dist` or `x .~ dist`", - ) - end -end - -function record_post_dot_tilde_observe!(context::DebugContext, left, right, logp, vi) - stmt = DotObserveStmt(; - left=left, - right=right, - logp=logp, - varinfo=context.record_varinfo ? 
deepcopy(vi) : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing -end -function DynamicPPL.dot_tilde_observe(context::DebugContext, right, left, vi) - record_pre_dot_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.dot_tilde_observe(childcontext(context), right, left, vi) - record_post_dot_tilde_observe!(context, left, right, logp, vi) - return logp, vi -end - _conditioned_varnames(d::AbstractDict) = keys(d) _conditioned_varnames(d) = map(sym -> VarName{sym}(), keys(d)) function conditioned_varnames(context) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index dd5aeeb04..0f312fa2c 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -39,11 +39,6 @@ function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi) return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) end -function DynamicPPL.dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi) - setprior!(context, vn, right) - return DynamicPPL.dot_tilde_assume(childcontext(context), right, left, vn, vi) -end - """ extract_priors([rng::Random.AbstractRNG, ]model::Model) diff --git a/src/model.jl b/src/model.jl index 3601d77fd..0fb18f463 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,5 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstactContext} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 8c18163e3..cb9ea4894 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -100,52 +100,6 @@ function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, v return left, acclogp!!(vi, logp) end -function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) - # Defer literal `observe` to child-context. - return dot_tilde_observe!!(context.context, right, left, vi) -end -function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. - if !(_include_likelihood(context)) - return dot_tilde_observe!!(context.context, right, left, vn, vi) - end - - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `dot_tilde_observe!`. - - # We want to treat `.~` as a collection of independent observations, - # hence we need the `logp` for each of them. Broadcasting the univariate - # `tilde_observe` does exactly this. - logps = _pointwise_tilde_observe(context.context, right, left, vi) - - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. - _, _, vns = unwrap_right_left_vns(right, left, vn) - for (vn, logp) in zip(vns, logps) - # Track loglikelihood value. - push!(context, vn, logp) - end - - return left, acclogp!!(vi, sum(logps)) -end - -# FIXME: This is really not a good approach since it needs to stay in sync with -# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. -function _pointwise_tilde_observe(context, right, left, vi) - # We need to drop the `vi` returned. 
- return broadcast(right, left) do r, l - return first(tilde_observe(context, r, l, vi)) - end -end - -function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo -) - # We need to drop the `vi` returned. - return map(eachcol(left)) do l - return first(tilde_observe(context, right, l, vi)) - end -end - # Note on submodels (penelopeysm) # # We don't need to overload tilde_observe!! for Sampleables (yet), because it @@ -174,44 +128,6 @@ function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) return value, acclogp!!(vi, logp) end -function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) - !_include_prior(context) && - return (dot_tilde_assume!!(context.context, right, left, vns, vi)) - value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) - # Track loglikelihood values. - for (vn, logp) in zip(vns, logps) - push!(context, vn, logp) - end - return value, acclogp!!(vi, sum(logps)) -end - -function _pointwise_tilde_assume(context, right, left, vns, vi) - # We need to drop the `vi` returned. - values_and_logps = broadcast(right, left, vns) do r, l, vn - # HACK(torfjelde): This drops the `vi` returned, which means the `vi` is not updated - # in case of immutable varinfos. But a) atm we're only using mutable varinfos for this, - # and b) even if the variables aren't stored in the vi correctly, we're not going to use - # this vi for anything downstream anyways, i.e. I don't see a case where this would matter - # for this particular use case. - val, logp, _ = tilde_assume(context, r, vn, vi) - return val, logp - end - return map(first, values_and_logps), map(last, values_and_logps) -end -function _pointwise_tilde_assume( - context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi -) - # We need to drop the `vi` returned. - values_and_logps = map(eachcol(left), vns) do l, vn - val, logp, _ = tilde_assume(context, right, vn, vi) - return val, logp - end - # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent. - # But this also means that we need to first flatten the entire `values` component before recombining. - values = recombine(right, mapreduce(vec ∘ first, vcat, values_and_logps), length(vns)) - return values, map(last, values_and_logps) -end - """ pointwise_logdensities(model::Model, chain::Chains, keytype = String) @@ -357,7 +273,7 @@ end """ pointwise_loglikelihoods(model, chain[, keytype, context]) - + Compute the pointwise log-likelihoods of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the likelihood terms. diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 07296c3f7..324390394 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -487,57 +487,6 @@ function assume( return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::SimpleOrThreadSafeSimple, -) - f = (vn, dist) -> init(rng, dist, spl) - value = f.(vns, dists) - - # Transform if we're working in transformed space. 
- value_raw = if dists isa Distribution - to_maybe_linked_internal.((vi,), vns, (dists,), value) - else - to_maybe_linked_internal.((vi,), vns, dists, value) - end - - # Update `vi` - vi = BangBang.setindex!!(vi, value_raw, vns) - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, istrans.((vi,), vns))) - return value, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - vns::AbstractVector{<:VarName}, - var::AbstractMatrix, - vi::SimpleOrThreadSafeSimple, -) - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - - # r = get_and_set_val!(rng, vi, vns, dist, spl) - n = length(vns) - value = init(rng, dist, spl, n) - - # Update `vi`. - for (vn, val) in zip(vns, eachcol(value)) - val_linked = to_maybe_linked_internal(vi, vn, dist, val) - vi = BangBang.setindex!!(vi, val_linked, vn) - end - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi))) - return value, lp, vi -end - # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 93bb02d3b..5150be64b 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -26,22 +26,10 @@ function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, v value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) return value, logp * context.mod, vi end -function DynamicPPL.dot_tilde_assume( - context::TestLogModifyingChildContext, right, left, vn, vi -) - value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) - return value, logp * context.mod, vi -end function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) return logp * context.mod, vi end -function DynamicPPL.dot_tilde_observe( - context::TestLogModifyingChildContext, right, left, vi -) - logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) - return logp * context.mod, vi -end # Dummy context to test nested behaviors. 
struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index c506e1ba3..e29614982 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -186,31 +186,29 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp end -@model function demo_dot_assume_dot_observe( - x=[1.5, 2.0], ::Type{TV}=Vector{Float64} -) where {TV} +@model function demo_dot_assume_observe(x=[1.5, 2.0], ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` s = TV(undef, length(x)) m = TV(undef, length(x)) s .~ InverseGamma(2, 3) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) x ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_dot_observe)}, s, m + model::Model{typeof(demo_dot_assume_observe)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] +function varnames(model::Model{typeof(demo_dot_assume_observe)}) + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function demo_assume_index_observe( @@ -276,7 +274,7 @@ end s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) for i in eachindex(x) x[i] ~ Normal(m[i], sqrt(s[i])) end @@ -295,7 +293,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end # Using vector of `length` 1 here so the posterior of `m` is the same @@ -355,7 +353,7 @@ end s = TV(undef, 2) m = TV(undef, 2) s .~ InverseGamma(2, 3) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) @@ -376,7 +374,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function demo_assume_observe_literal() @@ -431,7 +429,7 @@ end s = TV(undef, 2) s .~ InverseGamma(2, 3) m = TV(undef, 2) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) return s, m end @@ -460,7 +458,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - return 
[@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function _likelihood_multivariate_observe(s, m, x) @@ -473,7 +471,7 @@ end s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) # Submodel likelihood # With to_submodel, we have to have a left-hand side variable to @@ -494,76 +492,39 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end -@model function demo_dot_assume_dot_observe_matrix( +@model function demo_dot_assume_observe_matrix_index( x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} ) where {TV} s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) - # Dotted observe for `Matrix`. - x .~ MvNormal(m, Diagonal(s)) + x[:, 1] ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) - return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) -end -function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m -) - return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) -end -function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] -end - -@model function demo_dot_assume_matrix_dot_observe_matrix( - x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} -) where {TV} - n = length(x) - d = length(x) ÷ 2 - s = TV(undef, d, 2) - s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) - s_vec = vec(s) - m ~ MvNormal(zeros(n), Diagonal(s_vec)) - - # Dotted observe for `Matrix`. 
- x .~ MvNormal(m, Diagonal(s_vec)) - - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) -end -function logprior_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m -) - n = length(model.args.x) - s_vec = vec(s) - return loglikelihood(InverseGamma(2, 3), s_vec) + - logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) -end function loglikelihood_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m ) - return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - s = zeros(1, 2) # used for varname concretization only - return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] +function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)}) + return [@varname(s[1]), @varname(s[2]), @varname(m)] end -@model function demo_assume_matrix_dot_observe_matrix( +@model function demo_assume_matrix_observe_matrix_index( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} n = length(x) @@ -572,33 +533,32 @@ end s_vec = vec(s) m ~ MvNormal(zeros(n), Diagonal(s_vec)) - # Dotted observe for `Matrix`. - x .~ MvNormal(m, Diagonal(s_vec)) + x[:, 1] ~ MvNormal(m, Diagonal(s_vec)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m) +function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m) n = length(model.args.x) s_vec = vec(s) return loglikelihood(InverseGamma(2, 3), s_vec) + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) end function loglikelihood_true( - model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m ) return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}) return [@varname(s), @varname(m)] end const DemoModels = Union{ - Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_dot_assume_observe)}, Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, @@ -609,9 +569,8 @@ const DemoModels = Union{ Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_dot_observe_matrix)}, - Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, - Model{typeof(demo_assume_matrix_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_observe_matrix_index)}, + Model{typeof(demo_assume_matrix_observe_matrix_index)}, } const UnivariateAssumeDemoModels = Union{ @@ -637,7 +596,7 @@ function 
rand_prior_true(rng::Random.AbstractRNG, model::UnivariateAssumeDemoMod end const MultivariateAssumeDemoModels = Union{ - Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_dot_assume_observe)}, Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, @@ -645,8 +604,7 @@ const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_dot_observe_matrix)}, - Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_observe_matrix_index)}, } function posterior_mean(model::MultivariateAssumeDemoModels) # Get some containers to fill. @@ -699,7 +657,7 @@ function rand_prior_true(rng::Random.AbstractRNG, model::MultivariateAssumeDemoM end const MatrixvariateAssumeDemoModels = Union{ - Model{typeof(demo_assume_matrix_dot_observe_matrix)} + Model{typeof(demo_assume_matrix_observe_matrix_index)} } function posterior_mean(model::MatrixvariateAssumeDemoModels) # Get some containers to fill. @@ -786,7 +744,7 @@ And for the multivariate one (the latter one): """ const DEMO_MODELS = ( - demo_dot_assume_dot_observe(), + demo_dot_assume_observe(), demo_assume_index_observe(), demo_assume_multivariate_observe(), demo_dot_assume_observe_index(), @@ -797,7 +755,6 @@ const DEMO_MODELS = ( demo_assume_observe_literal(), demo_assume_submodel_observe_index_literal(), demo_dot_assume_observe_submodel(), - demo_dot_assume_dot_observe_matrix(), - demo_dot_assume_matrix_dot_observe_matrix(), - demo_assume_matrix_dot_observe_matrix(), + demo_dot_assume_observe_matrix_index(), + demo_assume_matrix_observe_matrix_index(), ) diff --git a/src/transforming.jl b/src/transforming.jl index 1a26d212f..0239725ae 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -30,67 +30,6 @@ function tilde_assume( return r, lp, setindex!!(vi, r_transformed, vn) end -function dot_tilde_assume( - ::DynamicTransformationContext{isinverse}, - dist::Distribution, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi, -) where {isinverse} - r = getindex.((vi,), vns, (dist,)) - b = link_transform(dist) - - is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" - is_trans = first(is_trans_uniques) - if is_trans - @assert isinverse "Trying to link already transformed variables" - else - @assert !isinverse "Trying to invlink non-transformed variables" - end - - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? r : b.(r) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, (!isinverse,))) - return r, lp, setindex!!(vi, r_transformed, vns) -end - -function dot_tilde_assume( - ::DynamicTransformationContext{isinverse}, - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, -) where {isinverse} - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - r = vi[vns, dist] - - # Compute `logpdf` with logabsdet-jacobian correction. - lp = sum(zip(vns, eachcol(r))) do (vn, ri) - return Bijectors.logpdf_with_trans(dist, ri, !isinverse) - end - - # Transform _all_ values. 
- is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" - is_trans = first(is_trans_uniques) - if is_trans - @assert isinverse "Trying to link already transformed variables" - else - @assert !isinverse "Trying to invlink non-transformed variables" - end - - b = link_transform(dist) - for (vn, ri) in zip(vns, eachcol(r)) - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - vi = setindex!!(vi, isinverse ? ri : b(ri), vn) - end - - return r, lp, vi -end - function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end diff --git a/src/utils.jl b/src/utils.jl index d64f6dc66..995755b2b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -152,6 +152,19 @@ function getargs_tilde(expr::Expr) end end +""" + getargs_longrightarrow(x) + +Return the arguments `L` and `R`, if `x` is an expression of the form `L --> R`, or `nothing` otherwise. +""" +getargs_longrightarrow(x) = nothing +function getargs_longrightarrow(expr::Expr) + return MacroTools.@match expr begin + (L_ --> R_) => (L, R) + x_ => nothing + end +end + """ getargs_assignment(x) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4cef5fa4e..d3bfd697a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -90,29 +90,6 @@ function tilde_assume( return value, logp, vi end -# `dot_tilde_assume` -function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi) - value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) - - # Save the value. - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi -) - value, logp, vi = dot_tilde_assume( - rng, childcontext(context), sampler, right, left, vn, vi - ) - # Save the value.
- _right, _left, _vns = unwrap_right_left_vns(right, left, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end - """ values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) diff --git a/test/Project.toml b/test/Project.toml index c7583c672..f3aeb5ec6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,58 +1,8 @@ [deps] -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" -AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" -Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1" -AbstractMCMC = "5" -AbstractPPL = "0.10.1" -Accessors = "0.1" -Bijectors = "0.15.1" -Combinatorics = "1" -Compat = "4.3.0" -DifferentiationInterface = "0.6" Distributions = "0.25" -DistributionsAD = "0.6.3" -Documenter = "1" -EnzymeCore = "0.6 - 0.8" -ForwardDiff = "0.10.12" -JET = "0.9" -LogDensityProblems = "2" -LogDensityProblemsAD = "1.7.0" -MCMCChains = "6.0.4" -MacroTools = "0.5.6" -Mooncake = "0.4.59" -OrderedCollections = "1" -ReverseDiff = "1" -StableRNGs = "1" -Tracker = "0.2.23" -Zygote = "0.6" julia = "1.6" diff --git a/test/ad.jl b/test/ad.jl deleted file mode 100644 index 17981cf2a..000000000 --- a/test/ad.jl +++ /dev/null @@ -1,76 +0,0 @@ -@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - f = DynamicPPL.LogDensityFunction(m) - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - f = DynamicPPL.LogDensityFunction(m, varinfo) - - # use ForwardDiff result as reference - ad_forwarddiff_f = LogDensityProblemsAD.ADgradient( - ADTypes.AutoForwardDiff(; chunksize=0), f - ) - # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0 - # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489 - θ = convert(Vector{Float64}, varinfo[:]) - logp, ref_grad 
= LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ) - - @testset "$adtype" for adtype in [ - ADTypes.AutoReverseDiff(; compile=false), - ADTypes.AutoReverseDiff(; compile=true), - ADTypes.AutoMooncake(; config=nothing), - ] - # Mooncake can't currently handle something that is going on in - # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now. - if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo - @test_broken 1 == 0 - else - ad_f = LogDensityProblemsAD.ADgradient(adtype, f) - _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) - @test grad ≈ ref_grad - end - end - end - end - - @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin - # Failing model - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - # measurement model - y ~ MvNormal(x, σ^2 * I) - return x - end - model = state_space(y, length(t)) - - # Dummy sampling algorithm for testing. The test case can only be replicated - # with a custom sampler, it doesn't work with SampleFromPrior(). We need to - # overload assume so that model evaluation doesn't fail due to a lack - # of implementation - struct MyEmptyAlg end - DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = () - DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = - DynamicPPL.assume(dist, vn, vi) - - # Compiling the ReverseDiff tape used to fail here - spl = Sampler(MyEmptyAlg()) - vi = VarInfo(model) - ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl)) - @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any - end -end diff --git a/test/compat/ad.jl b/test/compat/ad.jl deleted file mode 100644 index f76ce6f6e..000000000 --- a/test/compat/ad.jl +++ /dev/null @@ -1,57 +0,0 @@ -@testset "ad.jl" begin - @testset "logp" begin - # Hand-written log probabilities for vector `x = [s, m]`. - function logp_gdemo_default(x) - s = x[1] - m = x[2] - dist = Normal(m, sqrt(s)) - - return logpdf(InverseGamma(2, 3), s) + - logpdf(Normal(0, sqrt(s)), m) + - logpdf(dist, 1.5) + - logpdf(dist, 2.0) - end - - test_model_ad(gdemo_default, logp_gdemo_default) - - @model function wishart_ad() - return v ~ Wishart(7, [1 0.5; 0.5 1]) - end - - # Hand-written log probabilities for `x = [v]`. - function logp_wishart_ad(x) - dist = Wishart(7, [1 0.5; 0.5 1]) - return logpdf(dist, reshape(x, 2, 2)) - end - - test_model_ad(wishart_ad(), logp_wishart_ad) - end - - # https://github.com/TuringLang/Turing.jl/issues/1595 - @testset "dot_observe" begin - function f_dot_observe(x) - logp, _ = DynamicPPL.dot_observe( - SampleFromPrior(), [Normal(), Normal(-1.0, 2.0)], x, VarInfo() - ) - return logp - end - function f_dot_observe_manual(x) - return logpdf(Normal(), x[1]) + logpdf(Normal(-1.0, 2.0), x[2]) - end - - # Manual computation of the gradient. 
- x = randn(2) - val = f_dot_observe_manual(x) - grad = ForwardDiff.gradient(f_dot_observe_manual, x) - - @test ForwardDiff.gradient(f_dot_observe, x) ≈ grad - - y, back = Tracker.forward(f_dot_observe, x) - @test Tracker.data(y) ≈ val - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(f_dot_observe, x) - @test y ≈ val - @test back(1)[1] ≈ grad - end -end diff --git a/test/compiler.jl b/test/compiler.jl deleted file mode 100644 index 051eba618..000000000 --- a/test/compiler.jl +++ /dev/null @@ -1,730 +0,0 @@ -macro custom(expr) - (Meta.isexpr(expr, :call, 3) && expr.args[1] === :~) || error("incorrect macro usage") - quote - $(esc(expr.args[2])) = 0.0 - end -end - -macro mymodel1(ex) - # check if expression was modified by the DynamicPPL "compiler" - if ex == :(y ~ Uniform()) - return esc(:(x ~ Normal())) - else - return esc(:(z ~ Exponential())) - end -end - -struct MyModelStruct{T} - x::T -end -Base.:~(x, y::MyModelStruct) = y.x -macro mymodel2(ex) - # check if expression was modified by the DynamicPPL "compiler" - if ex == :(y ~ Uniform()) - # Just returns 42 - return :(4 ~ MyModelStruct(42)) - else - return :(return -1) - end -end - -# Used to test sampling of immutable types. -struct MyCoolStruct{T} - a::T -end - -module Issue537 end - -@testset "compiler.jl" begin - @testset "model macro" begin - @model function testmodel_comp(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) - - return x, y - end - @test length(methods(testmodel_comp)) == 2 - testmodel_comp(1.0, 1.2) - - # check if drawing from the prior works - @model function testmodel01(x=missing) - x ~ Normal() - return x - end - @test length(methods(testmodel01)) == 4 - f0_mm = testmodel01() - @test mean(f0_mm() for _ in 1:1000) ≈ 0.0 atol = 0.1 - - # Test #544 - @model function testmodel02(x=missing) - if x === missing - x = Vector{Float64}(undef, 2) - end - x[1] ~ Normal() - x[2] ~ Normal() - return x - end - @test length(methods(testmodel02)) == 4 - f0_mm = testmodel02() - @test all(x -> isapprox(x, 0; atol=0.1), mean(f0_mm() for _ in 1:1000)) - - @model function testmodel03(x=missing) - x ~ Bernoulli(0.5) - return x - end - f01_mm = testmodel03() - @test length(methods(testmodel03)) == 4 - @test mean(f01_mm() for _ in 1:1000) ≈ 0.5 atol = 0.1 - - # test if we get the correct return values - @model function testmodel1(x1, x2) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - - x1 ~ Normal(m, sqrt(s)) - x2 ~ Normal(m, sqrt(s)) - - return x1, x2 - end - @test length(methods(testmodel1)) == 2 - f1_mm = testmodel1(1.0, 10.0) - @test f1_mm() == (1, 10) - - # alternatives with keyword arguments - testmodel1kw(; x1, x2) = testmodel1(x1, x2) - f1_mm = testmodel1kw(; x1=1.0, x2=10.0) - @test f1_mm() == (1, 10) - - @model function testmodel2(; x1, x2) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - - x1 ~ Normal(m, sqrt(s)) - x2 ~ Normal(m, sqrt(s)) - - return x1, x2 - end - @test length(methods(testmodel2)) == 2 - f1_mm = testmodel2(; x1=1.0, x2=10.0) - @test f1_mm() == (1, 10) - - @info "Testing the compiler's ability to catch bad models..." - - # Test for assertions in observe statements. 
- @model function brokentestmodel_observe1(x1, x2) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - - x1 ~ Normal(m, sqrt(s)) - x2 ~ x1 + 2 - - return x1, x2 - end - - btest = brokentestmodel_observe1(1.0, 2.0) - @test_throws ArgumentError btest() - - @model function brokentestmodel_observe2(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - - x = Vector{Float64}(undef, 2) - x ~ [Normal(m, sqrt(s)), 2.0] - - return x - end - - btest = brokentestmodel_observe2([1.0, 2.0]) - @test_throws ArgumentError btest() - - # Test for assertions in assume statements. - @model function brokentestmodel_assume1() - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - - x1 ~ Normal(m, sqrt(s)) - x2 ~ x1 + 2 - - return x1, x2 - end - - btest = brokentestmodel_assume1() - @test_throws ArgumentError btest() - - @model function brokentestmodel_assume2() - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - - x = Vector{Float64}(undef, 2) - x ~ [Normal(m, sqrt(s)), 2.0] - - return x - end - - btest = brokentestmodel_assume2() - @test_throws ArgumentError btest() - - # Test missing input arguments - @model function testmodel_missing1(x) - x ~ Bernoulli(0.5) - return x - end - @test_throws MethodError testmodel_missing1() - - # Test missing initialization for vector observation turned parameter - @model function testmodel_missing2(x) - x[1] ~ Bernoulli(0.5) - return x - end - @test_throws MethodError testmodel_missing2(missing)() - - # Test use of internal names - @model function testmodel_missing3(x) - x[1] ~ Bernoulli(0.5) - global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler - global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng - global lp = getlogp(__varinfo__) - return x - end - model = testmodel_missing3([1.0]) - varinfo = VarInfo(model) - @test getlogp(varinfo) == lp - @test varinfo_ isa AbstractVarInfo - @test model_ === model - @test context_ isa SamplingContext - @test rng_ isa Random.AbstractRNG - - # disable warnings - @model function testmodel_missing4(x) - x[1] ~ Bernoulli(0.5) - global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler - global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng - global lp = getlogp(__varinfo__) - return x - end false - lpold = lp - model = testmodel_missing4([1.0]) - varinfo = VarInfo(model) - @test getlogp(varinfo) == lp == lpold - - # test DPPL#61 - @model function testmodel_missing5(z) - m ~ Normal() - z[1:end] ~ MvNormal(fill(m, length(z)), I) - return m - end - model = testmodel_missing5(rand(10)) - @test all(z -> isapprox(z, 0; atol=0.2), mean(model() for _ in 1:1000)) - - # test Turing#1464 - @model function gdemo(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - end - x = [1.0, missing] - VarInfo(gdemo(x)) - @test ismissing(x[2]) - - # https://github.com/TuringLang/Turing.jl/issues/1464#issuecomment-731153615 - vi = VarInfo(gdemo(x)) - @test haskey(vi.metadata, :x) - vi = VarInfo(gdemo(x)) - @test haskey(vi.metadata, :x) - - # Non-array variables - @model function testmodel_nonarray(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in 1:(length(x.a) - 1) - x.a[i] ~ Normal(m, √s) - end - - # Dynamic indexing - x.a[end] ~ Normal(100.0, 1.0) - - # Immutable set - y.a ~ Normal() - - # Dotted - z = Vector{Float64}(undef, 3) - z[1:2] .~ Normal() - z[end:end] .~ Normal() - - return (; s=s, m=m, x=x, y=y, z=z) - end - - m_nonarray = testmodel_nonarray( - MyCoolStruct([missing, missing]), 
MyCoolStruct(missing) - ) - result = m_nonarray() - @test !any(ismissing, result.x.a) - @test result.y.a !== missing - @test result.x.a[end] > 10 - - # Ensure that we can work with `Vector{Real}(undef, N)` which is the - # reason why we're using `BangBang.prefermutation` in `src/compiler.jl` - # rather than the default from Setfield.jl. - # Related: https://github.com/jw3126/Setfield.jl/issues/157 - @model function vdemo() - x = Vector{Real}(undef, 10) - for i in eachindex(x) - x[i] ~ Normal(0, sqrt(4)) - end - - return x - end - x = vdemo()() - @test all((isassigned(x, i) for i in eachindex(x))) - end - @testset "nested model" begin - function makemodel(p) - @model function testmodel(x) - x[1] ~ Bernoulli(p) - global lp = getlogp(__varinfo__) - return x - end - return testmodel - end - model = makemodel(0.5)([1.0]) - varinfo = VarInfo(model) - @test getlogp(varinfo) == lp - end - @testset "user-defined variable name" begin - @model f1() = x ~ NamedDist(Normal(), :y) - @model f2() = x ~ NamedDist(Normal(), @varname(y[2][:, 1])) - @model f3() = x ~ NamedDist(Normal(), @varname(y[1])) - vi1 = VarInfo(f1()) - vi2 = VarInfo(f2()) - vi3 = VarInfo(f3()) - @test haskey(vi1.metadata, :y) - @test first(Base.keys(vi1.metadata.y)) == @varname(y) - @test haskey(vi2.metadata, :y) - @test first(Base.keys(vi2.metadata.y)) == @varname(y[2][:, 1]) - @test haskey(vi3.metadata, :y) - @test first(Base.keys(vi3.metadata.y)) == @varname(y[1]) - - # Conditioning - f1_c = f1() | (y=1,) - f2_c = f2() | NamedTuple((Symbol(@varname(y[2][:, 1])) => 1,)) - f3_c = f3() | NamedTuple((Symbol(@varname(y[1])) => 1,)) - @test f1_c() == 1 - # TODO(torfjelde): We need conditioning for `Dict`. - @test_broken f2_c() == 1 - @test_broken f3_c() == 1 - @test_broken getlogp(VarInfo(f1_c)) == - getlogp(VarInfo(f2_c)) == - getlogp(VarInfo(f3_c)) - end - @testset "custom tilde" begin - @model demo() = begin - $(@custom m ~ Normal()) - return m - end - model = demo() - @test all(iszero(model()) for _ in 1:1000) - end - @testset "docstring" begin - "This is a test" - @model function demo(x) - m ~ Normal() - return x ~ Normal(m, 1) - end - - s = @doc(demo) - @test string(s) == "This is a test\n" - - # Verify that adding docstring didn't completely break execution of model - m = demo(0.0) - @test m() isa Float64 - end - @testset "type annotations" begin - @model function demo_without(x) - return x ~ Normal() - end - @test isempty(VarInfo(demo_without(0.0))) - - @model function demo_with(x::Real) - return x ~ Normal() - end - @test isempty(VarInfo(demo_with(0.0))) - end - - @testset "macros within model" begin - # Macro expansion - @model function demo1() - @mymodel1(y ~ Uniform()) - end - - @test haskey(VarInfo(demo1()), @varname(x)) - - # Interpolation - # Will fail if: - # 1. Compiler expands `y ~ Uniform()` before expanding the macros - # => returns -1. - # 2. `@mymodel` is expanded before entire `@model` has been - # expanded => errors since `MyModelStruct` is not a distribution, - # and hence `tilde_observe` errors. - @model function demo2() - return $(@mymodel2(y ~ Uniform())) - end - @test demo2()() == 42 - end - - @testset "to_submodel" begin - # No prefix, 1 level. - @model function demo1(x) - return x ~ Normal() - end - @model function demo2(x, y) - _ignore ~ to_submodel(demo1(x), false) - return y ~ Uniform() - end - # No observation. - m = demo2(missing, missing) - vi = VarInfo(m) - ks = keys(vi) - @test @varname(x) ∈ ks - @test @varname(y) ∈ ks - - # Observation in top-level. 
- m = demo2(missing, 1.0) - vi = VarInfo(m) - ks = keys(vi) - @test @varname(x) ∈ ks - @test @varname(y) ∉ ks - - # Observation in nested model. - m = demo2(1000.0, missing) - vi = VarInfo(m) - ks = keys(vi) - @test @varname(x) ∉ ks - @test @varname(y) ∈ ks - - # Observe all. - m = demo2(1000.0, 0.5) - vi = VarInfo(m) - ks = keys(vi) - @test isempty(ks) - - # Check values makes sense. - @model function demo3(x, y) - _ignore ~ to_submodel(demo1(x), false) - return y ~ Normal(x) - end - m = demo3(1000.0, missing) - # Mean of `y` should be close to 1000. - @test abs(mean([VarInfo(m)[@varname(y)] for i in 1:10]) - 1000) ≤ 10 - - # Prefixed submodels and usage of submodel return values. - @model function demo_return(x) - x ~ Normal() - return x - end - @model function demo_useval(x, y) - sub1 ~ to_submodel(demo_return(x)) - sub2 ~ to_submodel(demo_return(y)) - return z ~ Normal(sub1 + sub2 + 100, 1.0) - end - m = demo_useval(missing, missing) - vi = VarInfo(m) - ks = keys(vi) - @test VarName{Symbol("sub1.x")}() ∈ ks - @test VarName{Symbol("sub2.x")}() ∈ ks - @test @varname(z) ∈ ks - @test abs(mean([VarInfo(m)[@varname(z)] for i in 1:10]) - 100) ≤ 10 - - # AR1 model. Dynamic prefixing. - @model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV} - η ~ MvNormal(zeros(num_steps), I) - δ = sqrt(1 - α^2) - x = TV(undef, num_steps) - x[1] = η[1] - @inbounds for t in 2:num_steps - x[t] = @. α * x[t - 1] + δ * η[t] - end - return @. μ + σ * x - end - - @model function demo(y) - α ~ Uniform() - μ ~ Normal() - σ ~ truncated(Normal(), 0, Inf) - num_steps = length(y[1]) - num_obs = length(y) - @inbounds for i in 1:num_obs - x ~ to_submodel(prefix(AR1(num_steps, α, μ, σ), "ar1_$i"), false) - y[i] ~ MvNormal(x, 0.01 * I) - end - end - - ys = [randn(10), randn(10)] - m = demo(ys) - vi = VarInfo(m) - - for k in [:α, :μ, :σ, Symbol("ar1_1.η"), Symbol("ar1_2.η")] - @test VarName{k}() ∈ keys(vi) - end - end - - @testset "check_tilde_rhs" begin - @test_throws ArgumentError DynamicPPL.check_tilde_rhs(randn()) - - x = Normal() - @test DynamicPPL.check_tilde_rhs(x) === x - - x = [Laplace(), Normal(), MvNormal(zeros(3), I)] - @test DynamicPPL.check_tilde_rhs(x) === x - end - @testset "isliteral" begin - @test DynamicPPL.isliteral(:([1.0])) - @test DynamicPPL.isliteral(:([[1.0], 1.0])) - @test DynamicPPL.isliteral(:((1.0, 1.0))) - - @test !(DynamicPPL.isliteral(:([x]))) - @test !(DynamicPPL.isliteral(:([[x], 1.0]))) - @test !(DynamicPPL.isliteral(:((x, 1.0)))) - end - - @testset "array literals" begin - # Verify that we indeed can parse this. 
- @test @model(function array_literal_model() - # `assume` and literal `observe` - m ~ MvNormal(zeros(2), I) - return [10.0, 10.0] ~ MvNormal(m, 0.25 * I) - end) isa Function - - @model function array_literal_model2() - # `assume` and literal `observe` - m ~ MvNormal(zeros(2), I) - return [10.0, 10.0] ~ MvNormal(m, 0.25 * I) - end - - @test array_literal_model2()() == [10.0, 10.0] - end - - # https://github.com/TuringLang/DynamicPPL.jl/issues/260 - @testset "anonymous function" begin - error = ArgumentError("anonymous functions without name are not supported") - @test_throws LoadError(@__FILE__, (@__LINE__) + 1, error) @macroexpand begin - @model function (x) - return x ~ Normal() - end - end - @test_throws LoadError(@__FILE__, (@__LINE__) + 1, error) @macroexpand begin - model = @model(x -> (x ~ Normal())) - end - end - - @testset "dispatching with model" begin - f(x) = false - - @model demo() = x ~ Normal() - @test !f(demo()) - f(::Model{typeof(demo)}) = true - @test f(demo()) - - # Leads to re-definition of `demo` and trait is not affected. - @test length(methods(demo)) == 2 - @model demo() = x ~ Normal() - @test length(methods(demo)) == 2 - @test f(demo()) - - # Ensure we can specialize on arguments. - @model demo(x) = x ~ Normal() - @test length(methods(demo)) == 4 - @test f(demo(1.0)) - f(::Model{typeof(demo),(:x,)}) = false - @test !f(demo(1.0)) - @test f(demo()) # should still be `true` - - # Set it to `false` again. - f(::Model{typeof(demo),()}) = false - @test !f(demo()) - end - - @testset "return value" begin - # Make sure that a return-value of `x = 1` isn't combined into - # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. - @model empty_model() = return x = 1 - empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) - @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} - - # Even if the return-value is `AbstractVarInfo`, we should return - # a `Tuple` with `AbstractVarInfo` in the second component too. - @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) - @test svi == SimpleVarInfo() - if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi - else - @test retval == svi - end - - # We should not be altering return-values other than at top-level. - @model function demo() - # If we also replaced this `return` inside of `f`, then the - # final `return` would be include `__varinfo__`. - f(x) = return x^2 - return f(1.0) - end - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) - @test retval isa Float64 - - @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) - - # Return-value when using `to_submodel` - @model inner() = x ~ Normal() - @model function outer() - return _ignore ~ to_submodel(inner()) - end - @test outer()() isa Real - - # Edge-cases. - # `return` in the last statement. - # Ref: issue #511. - @model function demo_ret_in_last_stmt(x::Bool) - # Two different values not supporting `iterate`. 
- if x - return Val(1) - else - return Val(2) - end - end - - model_true = demo_ret_in_last_stmt(true) - @test model_true() === Val(1) - - model_false = demo_ret_in_last_stmt(false) - @test model_false() === Val(2) - - # `return` with `return` - @model function demo_ret_with_ret() - return begin - return Val(1) - Val(2) - end - end - @test demo_ret_with_ret()() === Val(1) - end - - @testset "issue #368: hasmissing dispatch" begin - @test !DynamicPPL.hasmissing(typeof(Union{}[])) - - # (nested) arrays with `Missing` eltypes - @test DynamicPPL.hasmissing(Vector{Union{Missing,Float64}}) - @test DynamicPPL.hasmissing(Matrix{Union{Missing,Real}}) - @test DynamicPPL.hasmissing(Vector{Matrix{Union{Missing,Float32}}}) - - # no `Missing` - @test !DynamicPPL.hasmissing(Vector{Float64}) - @test !DynamicPPL.hasmissing(Matrix{Real}) - @test !DynamicPPL.hasmissing(Vector{Matrix{Float32}}) - end - - @testset "issue #393: anonymous argument with type parameter" begin - @model f_393(::Val{ispredict}=Val(false)) where {ispredict} = ispredict ? 0 : 1 - @test f_393()() == 1 - @test f_393(Val(true))() == 0 - end - - @testset "splatting of args and kwargs" begin - @model function f_splat_test_1(x; y::T=1, kwargs...) where {T} - x ~ Normal(y, 1) - return x, y, T, NamedTuple(kwargs) - end - - # Non-empty `kwargs...`. - res = f_splat_test_1(1; z=2, w=3)() - @test res == (1, 1, Int, (z=2, w=3)) - - # Empty `kwargs...`. - res = f_splat_test_1(1)() - @test res == (1, 1, Int, NamedTuple()) - - @model function f_splat_test_2(x, args...; y::T=1, kwargs...) where {T} - x ~ Normal(y, 1) - return x, args, y, T, NamedTuple(kwargs) - end - - # Non-empty `args...` and non-empty `kwargs...`. - res = f_splat_test_2(1, 2, 3; z=2, w=3)() - @test res == (1, (2, 3), 1, Int, (z=2, w=3)) - - # Empty `args...` and empty `kwargs...`. - res = f_splat_test_2(1)() - @test res == (1, (), 1, Int, NamedTuple()) - end - - @testset "issue #537: model with logging" begin - # Make sure `Module` is valid to put in a model. - @model demo_with_module() = Issue537 - model = demo_with_module() - @test model() === Issue537 - - # And one explicit test for logging so know that is working. - @model demo_with_logging() = @info "hi" - model = demo_with_logging() - @test model() === nothing - # Make sure that the log message is present. - @test_logs (:info, "hi") model() - end - - @testset ":= (tracked values)" begin - @model function demo_tracked() - x ~ Normal() - y := 100 + x - return (; x, y) - end - @model function demo_tracked_submodel() - return vals ~ to_submodel(demo_tracked(), false) - end - for model in [demo_tracked(), demo_tracked_submodel()] - # Make sure it's runnable and `y` is present in the return-value. - @test model() isa NamedTuple{(:x, :y)} - - # `VarInfo` should only contain `x`. - varinfo = VarInfo(model) - @test haskey(varinfo, @varname(x)) - @test !haskey(varinfo, @varname(y)) - - # While `values_as_in_model` should contain both `x` and `y`, if - # include_colon_eq is set to `true`. - values = values_as_in_model(model, true, deepcopy(varinfo)) - @test haskey(values, @varname(x)) - @test haskey(values, @varname(y)) - - # And if include_colon_eq is set to `false`, then `values` should - # only contain `x`. 
- values = values_as_in_model(model, false, deepcopy(varinfo)) - @test haskey(values, @varname(x)) - @test !haskey(values, @varname(y)) - end - end - - @testset "signature parsing + TypeWrap" begin - @model function demo_typewrap( - a, b=1, ::Type{T1}=Float64; c, d=2, t::Type{T2}=Int - ) where {T1,T2} - return (; a, b, c, d, t) - end - - model = demo_typewrap(1; c=2) - res = model() - @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}()) - end -end diff --git a/test/context_implementations.jl b/test/context_implementations.jl deleted file mode 100644 index 8a795320d..000000000 --- a/test/context_implementations.jl +++ /dev/null @@ -1,72 +0,0 @@ -@testset "context_implementations.jl" begin - # https://github.com/TuringLang/DynamicPPL.jl/issues/129 - @testset "#129" begin - @model function test(x) - μ ~ MvNormal(zeros(2), 4 * I) - z = Vector{Int}(undef, length(x)) - z .~ Categorical.(fill([0.5, 0.5], length(x))) - for i in 1:length(x) - x[i] ~ Normal(μ[z[i]], 0.1) - end - end - - test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext()) - end - - # https://github.com/TuringLang/DynamicPPL.jl/issues/28#issuecomment-829223577 - @testset "dot tilde: arrays of distributions" begin - @testset "assume" begin - @model function test(x, size) - y = Array{Float64,length(size)}(undef, size...) - y .~ Normal.(x) - return y, getlogp(__varinfo__) - end - - for ysize in ((2,), (2, 3), (2, 3, 4)) - for x in ( - # scalar - randn(), - # drop trailing dimensions - ntuple(i -> randn(ysize[1:i]), length(ysize))..., - # singleton dimensions - ntuple( - i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), - length(ysize), - )..., - ) - model = test(x, ysize) - y, lp = model() - @test lp ≈ sum(logpdf.(Normal.(x), y)) - - ys = [first(model()) for _ in 1:10_000] - @test norm(mean(ys) .- x, Inf) < 0.1 - @test norm(std(ys) .- 1, Inf) < 0.1 - end - end - end - - @testset "observe" begin - @model function test(x, y) - return y .~ Normal.(x) - end - - for ysize in ((2,), (2, 3), (2, 3, 4)) - for x in ( - # scalar - randn(), - # drop trailing dimensions - ntuple(i -> randn(ysize[1:i]), length(ysize))..., - # singleton dimensions - ntuple( - i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), - length(ysize), - )..., - ) - y = randn(ysize) - z = logjoint(test(x, y), VarInfo()) - @test z ≈ sum(logpdf.(Normal.(x), y)) - end - end - end - end -end diff --git a/test/contexts.jl b/test/contexts.jl deleted file mode 100644 index faa831cc1..000000000 --- a/test/contexts.jl +++ /dev/null @@ -1,334 +0,0 @@ -using Test, DynamicPPL, Accessors -using DynamicPPL: - leafcontext, - setleafcontext, - childcontext, - setchildcontext, - AbstractContext, - NodeTrait, - IsLeaf, - IsParent, - PointwiseLogdensityContext, - contextual_isassumption, - ConditionContext, - decondition_context, - hasconditioned, - getconditioned, - hasconditioned_nested, - getconditioned_nested - -using EnzymeCore - -# TODO: Should we maybe put this in DPPL itself? 
-function Base.iterate(context::AbstractContext) - if NodeTrait(context) isa IsLeaf - return nothing - end - - return context, context -end -function Base.iterate(_::AbstractContext, context::AbstractContext) - return _iterate(NodeTrait(context), context) -end -_iterate(::IsLeaf, context) = nothing -function _iterate(::IsParent, context) - child = childcontext(context) - return child, child -end - -Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() -Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() - -""" - remove_prefix(vn::VarName) - -Return `vn` but now with the prefix removed. -""" -function remove_prefix(vn::VarName) - return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( - getoptic(vn) - ) -end - -@testset "contexts.jl" begin - child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] - - parent_contexts = [ - DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - SamplingContext(), - MiniBatchContext(DefaultContext(), 0.0), - PrefixContext{:x}(DefaultContext()), - PointwiseLogdensityContext(), - ConditionContext((x=1.0,)), - ConditionContext( - (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) - ), - ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), - ConditionContext((x=[1.0, missing],)), - ] - - contexts = vcat(child_contexts, parent_contexts) - - @testset "$(context)" for context in contexts - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - DynamicPPL.TestUtils.test_context(context, model) - end - end - - @testset "contextual_isassumption" begin - @testset "$context" for context in contexts - # Any `context` should return `true` by default. - @test contextual_isassumption(context, VarName{gensym(:x)}()) - - if any(Base.Fix2(isa, ConditionContext), context) - # We have a `ConditionContext` among us. - # Let's first extract the conditioned variables. - conditioned_values = DynamicPPL.conditioned(context) - - for (sym, val) in pairs(conditioned_values) - vn = VarName{sym}() - - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = remove_prefix(vn) - - # Let's check elementwise. - for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - if getoptic(vn_child)(val) === missing - @test contextual_isassumption(context, vn_child) - else - @test !contextual_isassumption(context, vn_child) - end - end - end - end - end - end - - @testset "getconditioned_nested & hasconditioned_nested" begin - @testset "$context" for context in contexts - fake_vn = VarName{gensym(:x)}() - @test !hasconditioned_nested(context, fake_vn) - @test_throws ErrorException getconditioned_nested(context, fake_vn) - - if any(Base.Fix2(isa, ConditionContext), context) - # `ConditionContext` specific. - - # Let's first extract the conditioned variables. - conditioned_values = DynamicPPL.conditioned(context) - - for (sym, val) in pairs(conditioned_values) - vn = VarName{sym}() - - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = remove_prefix(vn) - - for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - # `vn_child` should be in `context`. 
- @test hasconditioned_nested(context, vn_child) - # Value should be the same as extracted above. - @test getconditioned_nested(context, vn_child) === - getoptic(vn_child)(val) - end - end - end - end - end - - @testset "PrefixContext" begin - @testset "prefixing" begin - ctx = @inferred PrefixContext{:a}( - PrefixContext{:b}( - PrefixContext{:c}( - PrefixContext{:d}( - PrefixContext{:e}(PrefixContext{:f}(DefaultContext())) - ), - ), - ), - ) - vn = VarName{:x}() - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) - - vn = VarName{:x}(((1,),)) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) - end - - @testset "nested within arbitrary context stacks" begin - vn = @varname(x[1]) - ctx1 = PrefixContext{:a}(DefaultContext()) - ctx2 = SamplingContext(ctx1) - ctx3 = PrefixContext{:b}(ctx2) - ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) - vn_prefixed1 = prefix(ctx1, vn) - vn_prefixed2 = prefix(ctx2, vn) - vn_prefixed3 = prefix(ctx3, vn) - vn_prefixed4 = prefix(ctx4, vn) - @test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x") - @test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x") - @test DynamicPPL.getsym(vn_prefixed3) == Symbol("b.a.x") - @test DynamicPPL.getsym(vn_prefixed4) == Symbol("b.a.x") - @test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed4) === DynamicPPL.getoptic(vn) - end - - context = DynamicPPL.PrefixContext{:prefix}(SamplingContext()) - @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - # Sample with the context. - varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(model, varinfo, context) - # Extract the resulting symbols. - vns_varinfo_syms = Set(map(DynamicPPL.getsym, keys(varinfo))) - - # Extract the ground truth symbols. - vns_syms = Set([ - Symbol("prefix", DynamicPPL.PREFIX_SEPARATOR, DynamicPPL.getsym(vn)) for - vn in DynamicPPL.TestUtils.varnames(model) - ]) - - # Check that all variables are prefixed correctly. 
- @test vns_syms == vns_varinfo_syms - end - end - - @testset "SamplingContext" begin - context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()) - @test context isa SamplingContext - - # convenience constructors - @test SamplingContext() == context - @test SamplingContext(Random.default_rng()) == context - @test SamplingContext(SampleFromPrior()) == context - @test SamplingContext(DefaultContext()) == context - @test SamplingContext(Random.default_rng(), SampleFromPrior()) == context - @test SamplingContext(Random.default_rng(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) - end - - @testset "ConditionContext" begin - @testset "Nesting" begin - @testset "NamedTuple" begin - n1 = (x=1, y=2) - n2 = (x=3,) - # Values from outer context should override inner one - ctx1 = ConditionContext(n1, ConditionContext(n2)) - @test ctx1.values == (x=1, y=2) - # Check that the two ConditionContexts are collapsed - @test childcontext(ctx1) isa DefaultContext - # Then test the nesting the other way round - ctx2 = ConditionContext(n2, ConditionContext(n1)) - @test ctx2.values == (x=3, y=2) - @test childcontext(ctx2) isa DefaultContext - end - - @testset "Dict" begin - # Same tests as NamedTuple above - d1 = Dict(@varname(x) => 1, @varname(y) => 2) - d2 = Dict(@varname(x) => 3) - ctx1 = ConditionContext(d1, ConditionContext(d2)) - @test ctx1.values == Dict(@varname(x) => 1, @varname(y) => 2) - @test childcontext(ctx1) isa DefaultContext - ctx2 = ConditionContext(d2, ConditionContext(d1)) - @test ctx2.values == Dict(@varname(x) => 3, @varname(y) => 2) - @test childcontext(ctx2) isa DefaultContext - end - end - - @testset "decondition_context" begin - @testset "NamedTuple" begin - ctx = ConditionContext((x=1, y=2, z=3)) - # Decondition all variables - @test decondition_context(ctx) isa DefaultContext - # Decondition only some variables - dctx = decondition_context(ctx, :x) - @test dctx isa ConditionContext - @test dctx.values == (y=2, z=3) - dctx = decondition_context(ctx, :y, :z) - @test dctx isa ConditionContext - @test dctx.values == (x=1,) - # Decondition all variables manually - @test decondition_context(ctx, :x, :y, :z) isa DefaultContext - end - - @testset "Dict" begin - ctx = ConditionContext( - Dict(@varname(x) => 1, @varname(y) => 2, @varname(z) => 3) - ) - # Decondition all variables - @test decondition_context(ctx) isa DefaultContext - # Decondition only some variables - dctx = decondition_context(ctx, @varname(x)) - @test dctx isa ConditionContext - @test dctx.values == Dict(@varname(y) => 2, @varname(z) => 3) - dctx = decondition_context(ctx, @varname(y), @varname(z)) - @test dctx isa ConditionContext - @test dctx.values == Dict(@varname(x) => 1) - # Decondition all variables manually - @test decondition_context(ctx, @varname(x), @varname(y), @varname(z)) isa - DefaultContext - end - - @testset "Nesting" begin - ctx = ConditionContext( - (x=1, y=2), ConditionContext(Dict(@varname(a) => 3, @varname(b) => 4)) - ) - # Decondition an outer variable - dctx = decondition_context(ctx, :x) - @test dctx.values == (y=2,) - @test childcontext(dctx).values == Dict(@varname(a) => 3, @varname(b) => 4) - # Decondition an inner variable - dctx = decondition_context(ctx, @varname(a)) - @test dctx.values == (x=1, y=2) - @test childcontext(dctx).values == Dict(@varname(b) => 4) - # Try deconditioning 
everything - dctx = decondition_context(ctx) - @test dctx isa DefaultContext - end - end - end - - @testset "FixedContext" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - retval = model() - s, m = retval.s, retval.m - - # Keword approach. - model_fixed = DynamicPPL.fix(model; s=s) - @test model_fixed().s == s - @test model_fixed().m != m - # A fixed variable should not contribute at all to the logjoint. - # Assuming `condition` is correctly implemented, the following should hold. - @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) - - # Positional approach. - model_fixed = DynamicPPL.fix(model, (; s)) - @test model_fixed().s == s - @test model_fixed().m != m - @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) - - # Pairs approach. - model_fixed = DynamicPPL.fix(model, @varname(s) => s) - @test model_fixed().s == s - @test model_fixed().m != m - @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) - - # Dictionary approach. - model_fixed = DynamicPPL.fix(model, Dict(@varname(s) => s)) - @test model_fixed().s == s - @test model_fixed().m != m - @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) - end - end -end diff --git a/test/debug_utils.jl b/test/debug_utils.jl deleted file mode 100644 index d4f6601f5..000000000 --- a/test/debug_utils.jl +++ /dev/null @@ -1,214 +0,0 @@ -@testset "check_model" begin - @testset "context interface" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - context = DynamicPPL.DebugUtils.DebugContext(model) - DynamicPPL.TestUtils.test_context(context, model) - end - end - - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - issuccess, trace = check_model_and_trace(model) - # These models should all work. - @test issuccess - - # Check that the trace contains all the variables in the model. - varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace) - for vn in DynamicPPL.TestUtils.varnames(model) - @test vn in varnames_in_trace - end - - # Quick checks for `show` of trace. - @test occursin("assume: ", string(trace)) - @test occursin("observe: ", string(trace)) - - # All these models should have static constraints. - @test DynamicPPL.has_static_constraints(model) - end - - @testset "multiple usage of same variable" begin - @testset "simple" begin - @model function buggy_demo_model() - x ~ Normal() - x ~ Normal() - return y ~ Normal() - end - buggy_model = buggy_demo_model() - - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) - @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) - end - - @testset "submodel" begin - @model ModelInner() = x ~ Normal() - @model function ModelOuterBroken() - # Without automatic prefixing => `x` s used twice. - z ~ to_submodel(ModelInner(), false) - return x ~ Normal() - end - model = ModelOuterBroken() - @test_throws ErrorException check_model(model; error_on_failure=true) - - @model function ModelOuterWorking() - # With automatic prefixing => `x` is not duplicated. 
- z ~ to_submodel(ModelInner()) - x ~ Normal() - return z - end - model = ModelOuterWorking() - @test check_model(model; error_on_failure=true) - - # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 - @model function ModelOuterWorking2() - x1 ~ to_submodel(prefix(ModelInner(), :a), false) - x2 ~ to_submodel(prefix(ModelInner(), :b), false) - return (x1, x2) - end - model = ModelOuterWorking2() - @test check_model(model; error_on_failure=true) - end - - @testset "subsumes (x then x[1])" begin - @model function buggy_subsumes_demo_model() - x = Vector{Float64}(undef, 2) - x ~ MvNormal(zeros(2), I) - x[1] ~ Normal() - return nothing - end - buggy_model = buggy_subsumes_demo_model() - - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) - @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) - end - - @testset "subsumes (x[1] then x)" begin - @model function buggy_subsumes_demo_model() - x = Vector{Float64}(undef, 2) - x[1] ~ Normal() - x ~ MvNormal(zeros(2), I) - return nothing - end - buggy_model = buggy_subsumes_demo_model() - - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) - @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) - end - - @testset "subsumes (x.a then x)" begin - @model function buggy_subsumes_demo_model() - x = (a=nothing,) - x.a ~ Normal() - x ~ Normal() - return nothing - end - buggy_model = buggy_subsumes_demo_model() - - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) - @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) - end - end - - @testset "incorrect use of condition" begin - @testset "missing in multivariate" begin - @model function demo_missing_in_multivariate(x) - return x ~ MvNormal(zeros(length(x)), I) - end - model = demo_missing_in_multivariate([1.0, missing]) - @test_throws ErrorException check_model(model) - end - - @testset "condition both in args and context" begin - @model function demo_condition_both_in_args_and_context(x) - return x ~ Normal() - end - model = demo_condition_both_in_args_and_context(1.0) - for vals in [ - (x=2.0,), - OrderedDict(@varname(x) => 2.0), - OrderedDict(@varname(x[1]) => 2.0), - ] - conditioned_model = DynamicPPL.condition(model, vals) - @test_throws ErrorException check_model( - conditioned_model; error_on_failure=true - ) - end - end - end - - @testset "printing statements" begin - @testset "assume" begin - @model demo_assume() = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_assume()) - @test isuccess - @test startswith(string(trace), " assume: x ~ Normal") - end - - @testset "observe" begin - @model demo_observe(x) = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_observe(1.0)) - @test isuccess - @test occursin(r"observe: \d+\.\d+ ~ Normal", string(trace)) - end - end - - @testset "comparing multiple traces" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - issuccess_1, trace_1 = check_model_and_trace(model) - issuccess_2, trace_2 = check_model_and_trace(model) - @test issuccess_1 && issuccess_2 - - # Should have the same varnames present. 
- varnames_1 = DynamicPPL.DebugUtils.varnames_in_trace(trace_1) - varnames_2 = DynamicPPL.DebugUtils.varnames_in_trace(trace_2) - @info varnames_1 == varnames_2 - - # But will have different distributions. - dists_1 = DynamicPPL.DebugUtils.distributions_in_trace(trace_1) - dists_2 = DynamicPPL.DebugUtils.distributions_in_trace(trace_2) - @test dists_1[1] == dists_2[1] - @test dists_1[2] != dists_2[2] - - @test !DynamicPPL.has_static_constraints(model) - end - - @testset "vector with `undef`" begin - # Source: https://github.com/TuringLang/Turing.jl/pull/2218 - @model function demo_undef(ns...) - x = Array{Real}(undef, ns...) - @. x ~ Normal(0, 2) - end - for ns in [(2,), (2, 2), (2, 2, 2)] - model = demo_undef(ns...) - @test check_model(model; error_on_failure=true) - end - end - - @testset "model_warntype & model_codetyped" begin - @model demo_without_kwargs(x) = y ~ Normal(x, 1) - @model demo_with_kwargs(x; z=1) = y ~ Normal(x, z) - - for model in [demo_without_kwargs(1.0), demo_with_kwargs(1.0)] - codeinfo, retype = DynamicPPL.DebugUtils.model_typed(model) - @test codeinfo isa Core.CodeInfo - @test retype <: Tuple - - # Just make sure the following is runnable. - @test DynamicPPL.DebugUtils.model_warntype(model) isa Any - end - end -end diff --git a/test/deprecated.jl b/test/deprecated.jl deleted file mode 100644 index f12217983..000000000 --- a/test/deprecated.jl +++ /dev/null @@ -1,57 +0,0 @@ -@testset "deprecated" begin - @testset "@submodel" begin - @testset "is deprecated" begin - @model inner() = x ~ Normal() - @model outer() = @submodel x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer()() - ) - - @model outer_with_prefix() = @submodel prefix = "sub" x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... 
model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer_with_prefix()() - ) - end - - @testset "prefixing still works correctly" begin - @model inner() = x ~ Normal() - @model function outer() - a = @submodel inner() - b = @submodel prefix = "sub" inner() - return a, b - end - @test outer()() isa Tuple{Float64,Float64} - vi = VarInfo(outer()) - @test @varname(x) in keys(vi) - @test @varname(var"sub.x") in keys(vi) - end - - @testset "logp is still accumulated properly" begin - @model inner_assume() = x ~ Normal() - @model inner_observe(x, y) = y ~ Normal(x) - @model function outer(b) - a = @submodel inner_assume() - @submodel inner_observe(a, b) - end - y_val = 1.0 - model = outer(y_val) - @test model() == y_val - - x_val = 1.5 - vi = VarInfo(outer(y_val)) - DynamicPPL.setindex!!(vi, x_val, @varname(x)) - @test logprior(model, vi) ≈ logpdf(Normal(), x_val) - @test loglikelihood(model, vi) ≈ logpdf(Normal(x_val), y_val) - @test logjoint(model, vi) ≈ - logpdf(Normal(), x_val) + logpdf(Normal(x_val), y_val) - end - end -end diff --git a/test/distribution_wrappers.jl b/test/distribution_wrappers.jl deleted file mode 100644 index 8bb692783..000000000 --- a/test/distribution_wrappers.jl +++ /dev/null @@ -1,13 +0,0 @@ -@testset "distribution_wrappers.jl" begin - d = Normal() - nd = DynamicPPL.NoDist(d) - - # Smoke test - rand(nd) - - # Actual tests - @test minimum(nd) == -Inf - @test maximum(nd) == Inf - @test logpdf(nd, 15.0) == 0 - @test Bijectors.logpdf_with_trans(nd, 30, true) == 0 -end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl deleted file mode 100644 index 8de28046b..000000000 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -@testset "tag" begin - for chunksize in (nothing, 0, 1, 10) - ad = ADTypes.AutoForwardDiff(; chunksize=chunksize) - standardtag = if !isdefined(Base, :get_extension) - DynamicPPL.DynamicPPLForwardDiffExt.standardtag - else - Base.get_extension(DynamicPPL, :DynamicPPLForwardDiffExt).standardtag - end - @test standardtag(ad) - for tag in (false, 0, 1) - @test !standardtag(AutoForwardDiff(; chunksize=chunksize, tag=tag)) - end - end -end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl deleted file mode 100644 index 933bfb1d1..000000000 --- a/test/ext/DynamicPPLJETExt.jl +++ /dev/null @@ -1,94 +0,0 @@ -@testset "DynamicPPLJETExt.jl" begin - @testset "determine_suitable_varinfo" begin - @model function demo1() - x ~ Bernoulli() - if x - y ~ Normal() - else - z ~ Normal() - end - end - model = demo1() - @test DynamicPPL.Experimental.determine_suitable_varinfo(model) isa - DynamicPPL.UntypedVarInfo - - @model demo2() = x ~ Normal() - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa - DynamicPPL.TypedVarInfo - - @model function demo3() - # Just making sure that nothing strange happens when type inference fails. - x = Vector(undef, 1) - x[1] ~ Bernoulli() - if x[1] - y ~ Normal() - else - z ~ Normal() - end - end - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo3()) isa - DynamicPPL.UntypedVarInfo - - # Evaluation works (and it would even do so in practice), but sampling - # fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. 
- @model function demo4() - x ~ Bernoulli() - if x - y ~ Normal() - else - y ~ Cauchy() # different distibution, but same transformation - end - end - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa - DynamicPPL.UntypedVarInfo - - # In this model, the type error occurs in the user code rather than in DynamicPPL. - @model function demo5() - x ~ Normal() - xs = Any[] - push!(xs, x) - # `sum(::Vector{Any})` can potentially error unless the dynamic manages to resolve the - # correct `zero` method. As a result, this code will run, but JET will raise this is an issue. - return sum(xs) - end - # Should pass if we're only checking the tilde statements. - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa - DynamicPPL.TypedVarInfo - # Should fail if we're including errors in the model body. - @test DynamicPPL.Experimental.determine_suitable_varinfo( - demo5(); only_ddpl=false - ) isa DynamicPPL.UntypedVarInfo - end - - @testset "demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - # Use debug logging below. - varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation and sampling - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f_eval, argtypes_eval) - - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo, DynamicPPL.SamplingContext() - ) - JET.test_call(f_sample, argtypes_sample) - # For our demo models, they should all result in typed. - is_typed = varinfo isa DynamicPPL.TypedVarInfo - @test is_typed - # If the test failed, check why it didn't infer a typed varinfo - if !is_typed - typed_vi = VarInfo(model) - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, typed_vi - ) - JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, typed_vi, DynamicPPL.SamplingContext() - ) - JET.test_call(f_sample, argtypes_sample) - end - end - end -end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl deleted file mode 100644 index 3ba5edfe1..000000000 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ /dev/null @@ -1,11 +0,0 @@ -@testset "DynamicPPLMCMCChainsExt" begin - @model demo() = x ~ Normal() - model = demo() - - chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) - chain_generated = @test_nowarn returned(model, chain) - @test size(chain_generated) == (1000, 1) - @test mean(chain_generated) ≈ 0 atol = 0.1 -end - -# test for `predict` is in `test/model.jl` diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl deleted file mode 100644 index 986057da0..000000000 --- a/test/ext/DynamicPPLMooncakeExt.jl +++ /dev/null @@ -1,5 +0,0 @@ -@testset "DynamicPPLMooncakeExt" begin - Mooncake.TestUtils.test_rule( - StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true - ) -end diff --git a/test/independence.jl b/test/independence.jl deleted file mode 100644 index a4a834a61..000000000 --- a/test/independence.jl +++ /dev/null @@ -1,11 +0,0 @@ -@testset "Turing independence" begin - @model coinflip(y) = begin - p ~ Beta(1, 1) - N = length(y) - for i in 1:N - y[i] ~ Bernoulli(p) - end - end - model = coinflip([1, 1, 0]) - model(SampleFromPrior(), LikelihoodContext()) -end diff --git a/test/linking.jl b/test/linking.jl deleted file mode 
100644 index d424a9c2d..000000000 --- a/test/linking.jl +++ /dev/null @@ -1,204 +0,0 @@ -using Bijectors - -# Simple transformations which alters the "dimension" of the variable. -struct TrilToVec{S} - size::S -end - -struct TrilFromVec{S} - size::S -end - -Bijectors.inverse(f::TrilToVec) = TrilFromVec(f.size) -Bijectors.inverse(f::TrilFromVec) = TrilToVec(f.size) - -function (v::TrilToVec)(x) - mask = tril(trues(v.size)) - return vec(x[mask]) -end -function (v::TrilFromVec)(y) - mask = tril(trues(v.size)) - x = similar(y, v.size) - x[mask] .= y - return LowerTriangular(x) -end - -# Just some dummy values so we can make sure that the log-prob computation -# has been altered correctly. -Bijectors.with_logabsdet_jacobian(f::TrilToVec, x) = (f(x), log(eltype(x)(2))) -Bijectors.with_logabsdet_jacobian(f::TrilFromVec, x) = (f(x), -eltype(x)(log(2))) - -# Dummy example. -struct MyMatrixDistribution <: ContinuousMatrixDistribution - dim::Int -end - -Base.size(d::MyMatrixDistribution) = (d.dim, d.dim) -function Distributions._rand!( - rng::Random.AbstractRNG, d::MyMatrixDistribution, x::AbstractMatrix{<:Real} -) - return randn!(rng, x) -end -function Distributions._logpdf(::MyMatrixDistribution, x::AbstractMatrix{<:Real}) - return -sum(abs2, LowerTriangular(x)) / 2 -end - -# Skip reconstruction in the inverse-map since it's no longer needed. -function DynamicPPL.from_linked_vec_transform(dist::MyMatrixDistribution) - return TrilFromVec((dist.dim, dist.dim)) -end - -# Specify the link-transform to use. -Bijectors.bijector(dist::MyMatrixDistribution) = TrilToVec((dist.dim, dist.dim)) -function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, istrans::Bool) - lp = logpdf(dist, x) - if istrans - lp = lp - logabsdetjac(bijector(dist), x) - end - - return lp -end - -@testset "Linking (mutable=$mutable)" for mutable in [false, true] - @testset "simple matrix distribution" begin - # Just making sure the transformations are okay. - x = randn(3, 3) - f = TrilToVec((3, 3)) - f_inv = inverse(f) - y = f(x) - @test y isa AbstractVector - @test f_inv(f(x)) == LowerTriangular(x) - - # Within a model. - dist = MyMatrixDistribution(3) - @model demo() = m ~ dist - model = demo() - - example_values = rand(NamedTuple, model) - vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),)) - @testset "$(short_varinfo_name(vi))" for vi in vis - # Evaluate once to ensure we have `logp` value. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - vi_linked = if mutable - DynamicPPL.link!!(deepcopy(vi), model) - else - DynamicPPL.link(vi, model) - end - # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) - @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) - # Linked one should be working with a lower-dimensional representation. - @test length(vi_linked[:]) < length(vi[:]) - @test length(vi_linked[:]) == length(y) - # Invlinked. 
- vi_invlinked = if mutable - DynamicPPL.invlink!!(deepcopy(vi_linked), model) - else - DynamicPPL.invlink(vi_linked, model) - end - @test length(vi_invlinked[:]) == length(vi[:]) - @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) - @test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi) - end - end - - @testset "LKJCholesky" begin - @testset "uplo=$uplo" for uplo in ['L', 'U'] - @model demo_lkj(d) = x ~ LKJCholesky(d, 1.0, uplo) - @testset "d=$d" for d in [2, 3, 5] - model = demo_lkj(d) - dist = LKJCholesky(d, 1.0, uplo) - values_original = rand(NamedTuple, model) - vis = DynamicPPL.TestUtils.setup_varinfos( - model, values_original, (@varname(x),) - ) - @testset "$(short_varinfo_name(vi))" for vi in vis - val = vi[@varname(x), dist] - # Ensure that `reconstruct` works as intended. - @test val isa Cholesky - @test val.uplo == uplo - - @test length(vi[:]) == d^2 - lp = logpdf(dist, val) - lp_model = logjoint(model, vi) - @test lp_model ≈ lp - # Linked. - vi_linked = if mutable - DynamicPPL.link!!(deepcopy(vi), model) - else - DynamicPPL.link(vi, model) - end - @test length(vi_linked[:]) == d * (d - 1) ÷ 2 - # Should now include the log-absdet-jacobian correction. - @test !(getlogp(vi_linked) ≈ lp) - # Invlinked. - vi_invlinked = if mutable - DynamicPPL.invlink!!(deepcopy(vi_linked), model) - else - DynamicPPL.invlink(vi_linked, model) - end - @test length(vi_invlinked[:]) == d^2 - @test getlogp(vi_invlinked) ≈ lp - end - end - end - end - - # Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504 - @testset "Dirichlet" begin - @model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0) - @testset "d=$d" for d in [2, 3, 5] - model = demo_dirichlet(d) - example_values = rand(NamedTuple, model) - vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),)) - @testset "$(short_varinfo_name(vi))" for vi in vis - lp = logpdf(Dirichlet(d, 1.0), vi[:]) - @test length(vi[:]) == d - lp_model = logjoint(model, vi) - @test lp_model ≈ lp - # Linked. - vi_linked = if mutable - DynamicPPL.link!!(deepcopy(vi), model) - else - DynamicPPL.link(vi, model) - end - @test length(vi_linked[:]) == d - 1 - # Should now include the log-absdet-jacobian correction. - @test !(getlogp(vi_linked) ≈ lp) - # Invlinked. - vi_invlinked = if mutable - DynamicPPL.invlink!!(deepcopy(vi_linked), model) - else - DynamicPPL.invlink(vi_linked, model) - end - @test length(vi_invlinked[:]) == d - @test getlogp(vi_invlinked) ≈ lp - end - end - end - - # Related: https://github.com/TuringLang/Turing.jl/issues/2190 - @testset "High-dim Dirichlet" begin - @model function demo_highdim_dirichlet(ns...) - return x ~ filldist(Dirichlet(ones(2)), ns...) - end - @testset "ns=$ns" for ns in [ - (3,), - # TODO: Uncomment once we have https://github.com/TuringLang/Bijectors.jl/pull/304 - # (3, 4), (3, 4, 5) - ] - model = demo_highdim_dirichlet(ns...) - example_values = rand(NamedTuple, model) - vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),)) - @testset "$(short_varinfo_name(vi))" for vi in vis - # Linked. 
- vi_linked = if mutable - DynamicPPL.link!!(deepcopy(vi), model) - else - DynamicPPL.link(vi, model) - end - @test length(vi_linked[:]) == prod(ns) - end - end - end -end diff --git a/test/lkj.jl b/test/lkj.jl deleted file mode 100644 index d581cd21b..000000000 --- a/test/lkj.jl +++ /dev/null @@ -1,61 +0,0 @@ -function pd_from_triangular(X::AbstractMatrix, uplo::Char) - # Pre-allocation fixes a problem with abstract element types in Julia 1.10 - # Ref https://github.com/TuringLang/DynamicPPL.jl/pull/570#issue-2092729916 - out = similar(X, Base.promote_op(*, eltype(X), eltype(X))) - if uplo === 'U' - mul!(out, UpperTriangular(X)', UpperTriangular(X)) - else - mul!(out, LowerTriangular(X), LowerTriangular(X)') - end - return out -end - -@model lkj_prior_demo() = x ~ LKJ(2, 1) -@model lkj_chol_prior_demo(uplo) = x ~ LKJCholesky(2, 1, uplo) - -# Same for both distributions -target_mean = vec(Matrix{Float64}(I, 2, 2)) - -_lkj_atol = 0.05 - -@testset "Sample from x ~ LKJ(2, 1)" begin - model = lkj_prior_demo() - # `SampleFromPrior` will sample in constrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = - _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) - @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = - _lkj_atol - end -end - -@testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L'] - model = lkj_chol_prior_demo(uplo) - # `SampleFromPrior` will sample in unconstrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - # Build correlation matrix from factor - corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) - pd_from_triangular(M, uplo) - end - @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. 
- @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) - # Build correlation matrix from factor - corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) - pd_from_triangular(M, uplo) - end - @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol - end -end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl deleted file mode 100644 index beda767e6..000000000 --- a/test/logdensityfunction.jl +++ /dev/null @@ -1,36 +0,0 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff - -@testset "`getmodel` and `setmodel`" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ℓ = DynamicPPL.LogDensityFunction(model) - @test DynamicPPL.getmodel(ℓ) == model - @test DynamicPPL.setmodel(ℓ, model).model == model - - # ReverseDiff related - ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false)) - @test DynamicPPL.getmodel(∇ℓ) == model - @test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) == - model - ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true)) - new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff()) - @test DynamicPPL.getmodel(new_∇ℓ) == model - # HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape` - @test new_∇ℓ.compiledtape != ∇ℓ.compiledtape - end -end - -@testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - - @testset "$(varinfo)" for varinfo in varinfos - logdensity = DynamicPPL.LogDensityFunction(model, varinfo) - θ = varinfo[:] - @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) - @test LogDensityProblems.dimension(logdensity) == length(θ) - end - end -end diff --git a/test/model.jl b/test/model.jl deleted file mode 100644 index 256ada0ad..000000000 --- a/test/model.jl +++ /dev/null @@ -1,609 +0,0 @@ -# some functors (#367) -struct MyModel - a::Int -end -@model function (f::MyModel)(x) - m ~ Normal(f.a, 1) - return x ~ Normal(m, 1) -end -struct MyZeroModel end -@model function (::MyZeroModel)(x) - m ~ Normal(0, 1) - return x ~ Normal(m, 1) -end - -innermost_distribution_type(d::Distribution) = typeof(d) -function innermost_distribution_type(d::Distributions.ReshapedDistribution) - return innermost_distribution_type(d.dist) -end -function innermost_distribution_type(d::Distributions.Product) - dists = map(innermost_distribution_type, d.v) - if any(!=(dists[1]), dists) - error("Cannot extract innermost distribution type from $d") - end - - return dists[1] -end - -is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false -is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true -is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true - -const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() - -@testset "model.jl" begin - @testset "convenience functions" begin - model = GDEMO_DEFAULT - - # sample from model and extract variables - vi = VarInfo(model) - s = vi[@varname(s)] - m = vi[@varname(m)] - - # extract log pdf of variable object - lp = getlogp(vi) - - # log prior probability - lprior = logprior(model, vi) - @test lprior ≈ logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) - - # 
log likelihood - llikelihood = loglikelihood(model, vi) - @test llikelihood ≈ loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0]) - - # log joint probability - ljoint = logjoint(model, vi) - @test ljoint ≈ lprior + llikelihood - @test ljoint ≈ lp - - #### logprior, logjoint, loglikelihood for MCMC chains #### - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 200 - chain = make_chain_from_prior(model, N) - logpriors = logprior(model, chain) - loglikelihoods = loglikelihood(model, chain) - logjoints = logjoint(model, chain) - - # Construct mapping of varname symbols to varname-parent symbols. - # Here, varname_leaves is used to ensure compatibility with the - # variables stored in the chain - var_info = VarInfo(model) - chain_sym_map = Dict{Symbol,Symbol}() - for vn_parent in keys(var_info) - sym = DynamicPPL.getsym(vn_parent) - vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent]) - for vn_child in vn_children - chain_sym_map[Symbol(vn_child)] = sym - end - end - - # compare them with true values - for i in 1:N - samples_dict = Dict() - for chain_key in keys(chain) - value = chain[i, chain_key, 1] - key = chain_sym_map[chain_key] - existing_value = get(samples_dict, key, Float64[]) - push!(existing_value, value) - samples_dict[key] = existing_value - end - samples = (; samples_dict...) - samples = modify_value_representation(samples) # `modify_value_representation` defined in test/test_util.jl - @test logpriors[i] ≈ - DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m]) - @test loglikelihoods[i] ≈ DynamicPPL.TestUtils.loglikelihood_true( - model, samples[:s], samples[:m] - ) - @test logjoints[i] ≈ - DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m]) - end - end - end - - @testset "model de/conditioning" begin - @model function demo_condition() - x ~ Normal() - return y ~ Normal(x) - end - model = demo_condition() - - # Test that different syntaxes work and give the same underlying ConditionContext - @testset "conditioning NamedTuple" begin - expected_values = (y=2,) - @test condition(model, (y=2,)).context.values == expected_values - @test condition(model; y=2).context.values == expected_values - @test condition(model; y=2).context.values == expected_values - @test (model | (y=2,)).context.values == expected_values - conditioned_model = condition(model, (y=2,)) - @test keys(VarInfo(conditioned_model)) == [@varname(x)] - end - @testset "conditioning AbstractDict" begin - expected_values = Dict(@varname(y) => 2) - @test condition(model, Dict(@varname(y) => 2)).context.values == expected_values - @test condition(model, @varname(y) => 2).context.values == expected_values - @test (model | (@varname(y) => 2,)).context.values == expected_values - conditioned_model = condition(model, Dict(@varname(y) => 2)) - @test keys(VarInfo(conditioned_model)) == [@varname(x)] - end - - @testset "deconditioning" begin - conditioned_model = condition(model, (y=2,)) - deconditioned_model = decondition(conditioned_model) - @test keys(VarInfo(deconditioned_model)) == [@varname(x), @varname(y)] - end - end - - @testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin - @model function multiple_types(x) - ns ~ filldist(Normal(0, 2.0), 3) - m ~ Uniform(0, 1) - return x ~ Normal(m, 1) - end - model = multiple_types(1) - chain = make_chain_from_prior(model, 10) - loglikelihood(model, chain) - logprior(model, chain) - logjoint(model, chain) - end - - @testset "rng" begin - model = GDEMO_DEFAULT - - for sampler in (SampleFromPrior(), 
SampleFromUniform()) - for i in 1:10 - Random.seed!(100 + i) - vi = VarInfo() - model(Random.default_rng(), vi, sampler) - vals = DynamicPPL.getall(vi) - - Random.seed!(100 + i) - vi = VarInfo() - model(Random.default_rng(), vi, sampler) - @test DynamicPPL.getall(vi) == vals - end - end - end - - @testset "defaults without VarInfo, Sampler, and Context" begin - model = GDEMO_DEFAULT - - Random.seed!(100) - retval = model() - - Random.seed!(100) - retval2 = model(Random.default_rng()) - @test retval2.s == retval.s - @test retval2.m == retval.m - end - - @testset "nameof" begin - @model function test1(x) - m ~ Normal(0, 1) - return x ~ Normal(m, 1) - end - @model test2(x) = begin - m ~ Normal(0, 1) - x ~ Normal(m, 1) - end - function test3 end - @model function (::typeof(test3))(x) - m ~ Normal(0, 1) - return x ~ Normal(m, 1) - end - function test4 end - @model function (a::typeof(test4))(x) - m ~ Normal(0, 1) - return x ~ Normal(m, 1) - end - - @test nameof(test1(rand())) == :test1 - @test nameof(test2(rand())) == :test2 - @test nameof(test3(rand())) == :test3 - @test nameof(test4(rand())) == :test4 - - # callables - @test nameof(MyModel(3)(rand())) == Symbol("MyModel(3)") - @test nameof(MyZeroModel()(rand())) == Symbol("MyZeroModel()") - end - - @testset "Internal methods" begin - model = GDEMO_DEFAULT - - # sample from model and extract variables - vi = VarInfo(model) - - # Second component of return-value of `evaluate!!` should - # be a `DynamicPPL.AbstractVarInfo`. - evaluate_retval = DynamicPPL.evaluate!!(model, vi, DefaultContext()) - @test evaluate_retval[2] isa DynamicPPL.AbstractVarInfo - - # Should not return `AbstractVarInfo` when we call the model. - call_retval = model() - @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) - end - - @testset "Dynamic constraints, Metadata" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - spl = SampleFromPrior() - vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) - vi = link!!(vi, model) - - for i in 1:10 - # Sample with large variations. 
- r_raw = randn(length(vi[:])) * 10 - DynamicPPL.setall!(vi, r_raw) - @test vi[@varname(m)] == r_raw[1] - @test vi[@varname(x)] != r_raw[2] - model(vi) - end - end - - @testset "Dynamic constraints, VectorVarInfo" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - for i in 1:10 - vi = VarInfo(model) - @test vi[@varname(x)] >= vi[@varname(m)] - end - end - - @testset "rand" begin - model = GDEMO_DEFAULT - - Random.seed!(1776) - s, m = model() - sample_namedtuple = (; s=s, m=m) - sample_dict = OrderedDict(@varname(s) => s, @varname(m) => m) - - # With explicit RNG - @test rand(Random.seed!(1776), model) == sample_namedtuple - @test rand(Random.seed!(1776), NamedTuple, model) == sample_namedtuple - @test rand(Random.seed!(1776), Dict, model) == sample_dict - - # Without explicit RNG - Random.seed!(1776) - @test rand(model) == sample_namedtuple - Random.seed!(1776) - @test rand(NamedTuple, model) == sample_namedtuple - Random.seed!(1776) - @test rand(OrderedDict, model) == sample_dict - end - - @testset "default arguments" begin - @model test_defaults(x, n=length(x)) = x ~ MvNormal(zeros(n), I) - @test length(test_defaults(missing, 2)()) == 2 - end - - @testset "missing kwarg" begin - @model test_missing_kwarg(; x=missing) = x ~ Normal(0, 1) - @test :x in keys(rand(test_missing_kwarg())) - end - - @testset "extract priors" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - priors = extract_priors(model) - - # We know that any variable starting with `s` should have `InverseGamma` - # and any variable starting with `m` should have `Normal`. - for (vn, prior) in priors - if DynamicPPL.getsym(vn) == :s - @test innermost_distribution_type(prior) <: InverseGamma - elseif DynamicPPL.getsym(vn) == :m - @test innermost_distribution_type(prior) <: Union{Normal,MvNormal} - else - error("Unexpected variable name: $vn") - end - end - end - end - - @testset "TestUtils" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - x = DynamicPPL.TestUtils.rand_prior_true(model) - # `rand_prior_true` should return a `NamedTuple`. - @test x isa NamedTuple - - # `rand` with a `AbstractDict` should have `varnames` as keys. - x_rand_dict = rand(OrderedDict, model) - for vn in DynamicPPL.TestUtils.varnames(model) - @test haskey(x_rand_dict, vn) - end - # `rand` with a `NamedTuple` should have `map(Symbol, varnames)` as keys. - x_rand_nt = rand(NamedTuple, model) - for vn in DynamicPPL.TestUtils.varnames(model) - @test haskey(x_rand_nt, Symbol(vn)) - end - - # Ensure log-probability computations are implemented. - @test logprior(model, x) ≈ DynamicPPL.TestUtils.logprior_true(model, x...) - @test loglikelihood(model, x) ≈ - DynamicPPL.TestUtils.loglikelihood_true(model, x...) - @test logjoint(model, x) ≈ DynamicPPL.TestUtils.logjoint_true(model, x...) - @test logjoint(model, x) != - DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) - # Ensure `varnames` is implemented. - vi = last( - DynamicPPL.evaluate!!( - model, SimpleVarInfo(OrderedDict()), SamplingContext() - ), - ) - @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) - # Ensure `posterior_mean` is implemented. - @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) - end - end - - @testset "returned() on `LKJCholesky`" begin - n = 10 - d = 2 - model = DynamicPPL.TestUtils.demo_lkjchol(d) - xs = [model().x for _ in 1:n] - - # Extract varnames and values. 
- vns_and_vals_xs = map( - collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs - ) - vns = map(first, first(vns_and_vals_xs)) - vals = map(vns_and_vals_xs) do vns_and_vals - map(last, vns_and_vals) - end - - # Construct the chain. - syms = map(Symbol, vns) - vns_to_syms = OrderedDict{VarName,Any}(zip(vns, syms)) - - chain = MCMCChains.Chains( - permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,) - ) - - # Test! - results = returned(model, chain) - for (x_true, result) in zip(xs, results) - @test x_true.UL == result.x.UL - end - - # With variables that aren't in the `model`. - vns_to_syms_with_extra = let d = deepcopy(vns_to_syms) - d[@varname(y)] = :y - d - end - vals_with_extra = map(enumerate(vals)) do (i, v) - vcat(v, i) - end - chain_with_extra = MCMCChains.Chains( - permutedims(stack(vals_with_extra)), - vcat(syms, [:y]); - info=(varname_to_symbol=vns_to_syms_with_extra,), - ) - # Test! - results = returned(model, chain_with_extra) - for (x_true, result) in zip(xs, results) - @test x_true.UL == result.x.UL - end - end - - if VERSION >= v"1.8" - @testset "Type stability of models" begin - models_to_test = [ - DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) - ] - context = DefaultContext() - @testset "$(model.f)" for model in models_to_test - vns = DynamicPPL.TestUtils.varnames(model) - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = filter( - is_typed_varinfo, - DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), - ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo, context)) - true - end - - varinfo_linked = DynamicPPL.link(varinfo, model) - @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo_linked, context)) - true - end - end - end - end - end - - @testset "values_as_in_model" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # We can set the include_colon_eq arg to false because none of - # the demo models contain :=. The behaviour when - # include_colon_eq is true is tested in test/compiler.jl - realizations = values_as_in_model(model, false, varinfo) - # Ensure that all variables are found. - vns_found = collect(keys(realizations)) - @test vns ∩ vns_found == vns ∪ vns_found - # Ensure that the values are the same. - for vn in vns - @test realizations[vn] == varinfo[vn] - end - end - end - - @testset "Prefixing" begin - @model inner() = x ~ Normal() - - @model function outer_auto_prefix() - a ~ to_submodel(inner(), true) - b ~ to_submodel(inner(), true) - return nothing - end - @model function outer_manual_prefix() - a ~ to_submodel(prefix(inner(), :a), false) - b ~ to_submodel(prefix(inner(), :b), false) - return nothing - end - - for model in (outer_auto_prefix(), outer_manual_prefix()) - vi = VarInfo(model) - vns = Set(keys(values_as_in_model(model, false, vi))) - @test vns == Set([@varname(var"a.x"), @varname(var"b.x")]) - end - end - end - - @testset "Erroneous model call" begin - # Calling a model with the wrong arguments used to lead to infinite recursion, see - # https://github.com/TuringLang/Turing.jl/issues/2182. This guards against it. 
- @model function a_model(x) - m ~ Normal(0, 1) - x ~ Normal(m, 1) - return nothing - end - instance = a_model(1.0) - # `instance` should be called with rng, context, etc., but one may easily get - # confused and call it the way you are meant to call `a_model`. - @test_throws MethodError instance(1.0) - end - - @testset "Product distribution with changing support" begin - @model function product_dirichlet() - return x ~ product_distribution(fill(Dirichlet(ones(4)), 2, 3)) - end - model = product_dirichlet() - - varinfos = [ - DynamicPPL.untyped_varinfo(model), - DynamicPPL.typed_varinfo(model), - DynamicPPL.typed_simple_varinfo(model), - DynamicPPL.untyped_simple_varinfo(model), - ] - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - varinfo_linked = DynamicPPL.link(varinfo, model) - varinfo_linked_result = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext()) - ) - @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result) - end - end - - @testset "predict" begin - @testset "with MCMCChains.Chains" begin - @model function linear_reg(x, y, σ=0.1) - β ~ Normal(0, 1) - for i in eachindex(y) - y[i] ~ Normal(β * x[i], σ) - end - # Insert a := block to test that it is not included in predictions - return σ2 := σ^2 - end - - # Construct a chain with 'sampled values' of β - ground_truth_β = 2 - β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β]) - - # Generate predictions from that chain - xs_test = [10 + 0.1, 10 + 2 * 0.1] - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) - predictions = DynamicPPL.predict(m_lin_reg_test, β_chain) - - # Also test a vectorized model - @model function linear_reg_vec(x, y, σ=0.1) - β ~ Normal(0, 1) - return y ~ MvNormal(β .* x, σ^2 * I) - end - m_lin_reg_test_vec = linear_reg_vec(xs_test, missing) - - @testset "variables in chain" begin - # Note that this also checks that variables on the lhs of :=, - # such as σ2, are not included in the resulting chain - @test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")]) - end - - @testset "accuracy" begin - ys_pred = vec(mean(Array(group(predictions, :y)); dims=1)) - @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 - @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 - end - - @testset "ensure that rng is respected" begin - rng = MersenneTwister(42) - predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2]) - predictions2 = DynamicPPL.predict( - MersenneTwister(42), m_lin_reg_test, β_chain[1:2] - ) - @test all(Array(predictions1) .== Array(predictions2)) - end - - @testset "accuracy on vectorized model" begin - predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, β_chain) - ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1)) - - @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 - @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 - end - - @testset "prediction from multiple chains" begin - # Normal linreg model - multiple_β_chain = MCMCChains.Chains( - reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] - ) - predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) - @test size(multiple_β_chain, 3) == size(predictions, 3) - - for chain_idx in MCMCChains.chains(multiple_β_chain) - ys_pred = vec( - mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1) - ) - @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 - @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 - end - - # Vectorized linreg model 
- predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, multiple_β_chain) - - for chain_idx in MCMCChains.chains(multiple_β_chain) - ys_pred_vec = vec( - mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1) - ) - @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 - @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 - end - end - end - - @testset "with AbstractVector{<:AbstractVarInfo}" begin - @model function linear_reg(x, y, σ=0.1) - β ~ Normal(1, 1) - for i in eachindex(y) - y[i] ~ Normal(β * x[i], σ) - end - end - - ground_truth_β = 2.0 - # the data will be ignored, as we are generating samples from the prior - xs_train = 1:0.1:10 - ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) - m_lin_reg = linear_reg(xs_train, ys_train) - chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000] - - # chain is generated from the prior - @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 - - xs_test = [10 + 0.1, 10 + 2 * 0.1] - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) - predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain) - - @test size(predicted_vis) == size(chain) - @test Set(keys(predicted_vis[1])) == - Set([@varname(β), @varname(y[1]), @varname(y[2])]) - # because β samples are from the prior, the std will be larger - @test mean([ - predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[1] rtol = 0.1 - @test mean([ - predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[2] rtol = 0.1 - end - end -end diff --git a/test/model_utils.jl b/test/model_utils.jl deleted file mode 100644 index 720ae55aa..000000000 --- a/test/model_utils.jl +++ /dev/null @@ -1,20 +0,0 @@ -@testset "model_utils.jl" begin - @testset "value_iterator_from_chain" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS - # Check that the values generated by value_iterator_from_chain - # match the values in the original chain - chain = make_chain_from_prior(model, 10) - for (i, d) in enumerate(value_iterator_from_chain(model, chain)) - for vn in keys(d) - val = DynamicPPL.getvalue(d, vn) - # Because value_iterator_from_chain groups varnames with - # the same parent symbol, we have to ungroup them here - for vn_leaf in DynamicPPL.varname_leaves(vn, val) - val_leaf = DynamicPPL.getvalue(d, vn_leaf) - @test val_leaf == chain[i, Symbol(vn_leaf), 1] - end - end - end - end - end -end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl deleted file mode 100644 index 5c0b2e090..000000000 --- a/test/pointwise_logdensities.jl +++ /dev/null @@ -1,101 +0,0 @@ -@testset "logdensities_likelihoods.jl" begin - mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) - mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - - # Instantiate a `VarInfo` with the example values. - vi = VarInfo(model) - for vn in DynamicPPL.TestUtils.varnames(model) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true( - model, example_values... - ) - logprior_true = logprior(model, vi) - - # Compute the pointwise loglikelihoods. - lls = pointwise_loglikelihoods(model, vi) - if isempty(lls) - # One of the models with literal observations, so we'll set this to 0 for subsequent comparisons. 
- loglikelihood_true = 0.0 - else - @test [:x] == unique(DynamicPPL.getsym.(keys(lls))) - loglikelihood_sum = sum(sum, values(lls)) - @test loglikelihood_sum ≈ loglikelihood_true - end - - # Compute the pointwise logdensities of the priors. - lps_prior = pointwise_prior_logdensities(model, vi) - @test :x ∉ DynamicPPL.getsym.(keys(lps_prior)) - logp = sum(sum, values(lps_prior)) - @test logp ≈ logprior_true - - # Compute both likelihood and logdensity of prior - # using the default DefaultContext - lps = pointwise_logdensities(model, vi) - logp = sum(sum, values(lps)) - @test logp ≈ (logprior_true + loglikelihood_true) - - # Test that modifications of Setup are picked up - lps = pointwise_logdensities(model, vi, mod_ctx2) - logp = sum(sum, values(lps)) - @test logp ≈ (logprior_true + loglikelihood_true) * 1.2 * 1.4 - end -end - -@testset "pointwise_logdensities chain" begin - # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, - # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just - # to ensure that we don't accidentally break the version on `Chains`. - model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() - # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced - # an impl of this for containers. - # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. - vns = DynamicPPL.TestUtils.varnames(model) - # Get some random `NamedTuple` samples from the prior. - num_iters = 3 - vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters] - # Concatenate the vector representations and create a `Chains` from it. - vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) - chain = Chains(permutedims(vals_arr), map(Symbol, vns)) - - # Compute the different pointwise logdensities. - logjoints_pointwise = pointwise_logdensities(model, chain) - logpriors_pointwise = pointwise_prior_logdensities(model, chain) - loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain) - - # Check that they contain the correct variables. - @test all(string(vn) in keys(logjoints_pointwise) for vn in vns) - @test all(string(vn) in keys(logpriors_pointwise) for vn in vns) - @test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise)) - @test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns) - @test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise)) - - # Get the sum of the logjoints for each of the iterations. - logjoints = [ - sum(logjoints_pointwise[vn][idx] for vn in keys(logjoints_pointwise)) for - idx in 1:num_iters - ] - logpriors = [ - sum(logpriors_pointwise[vn][idx] for vn in keys(logpriors_pointwise)) for - idx in 1:num_iters - ] - loglikelihoods = [ - sum(loglikelihoods_pointwise[vn][idx] for vn in keys(loglikelihoods_pointwise)) for - idx in 1:num_iters - ] - - for (val, logjoint, logprior, loglikelihood) in - zip(vals, logjoints, logpriors, loglikelihoods) - # Compare true logjoint with the one obtained from `pointwise_logdensities`. - logjoint_true = DynamicPPL.TestUtils.logjoint_true(model, val...) - logprior_true = DynamicPPL.TestUtils.logprior_true(model, val...) - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(model, val...) 
- - @test logjoint ≈ logjoint_true - @test logprior ≈ logprior_true - @test loglikelihood ≈ loglikelihood_true - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 29a148789..473420dfd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,101 +1,122 @@ -using Accessors -using ADTypes -using DynamicPPL -using AbstractMCMC -using AbstractPPL -using Bijectors -using DifferentiationInterface using Distributions -using DistributionsAD -using Documenter -using ForwardDiff -using LogDensityProblems, LogDensityProblemsAD -using MacroTools -using MCMCChains -using Mooncake: Mooncake -using StableRNGs -using Tracker -using ReverseDiff -using Zygote -using Compat - -using Distributed -using LinearAlgebra -using Pkg +using DynamicPPL using Random -using Serialization using Test -using Distributions -using LinearAlgebra # Diagonal - -using JET: JET - -using Combinatorics: combinations -using OrderedCollections: OrderedSet - -using DynamicPPL: getargs_dottilde, getargs_tilde, Selector -const GROUP = get(ENV, "GROUP", "All") Random.seed!(100) -include("test_util.jl") +@testset verbose = true "submodel tests" begin + @testset "sanity check with original models" begin + @model f() = x ~ Normal() + model = f() + vi = VarInfo(model) + # check parent varinfo + @test Set(keys(vi)) == Set([@varname(x)]) + @test vi[@varname(x)] isa Float64 + # check logp + @test DynamicPPL.getlogp(vi) ≈ logpdf(Normal(), vi[@varname(x)]) + end -@testset verbose = true "DynamicPPL.jl" begin - # The tests are split into two groups so that CI can run in parallel. The - # groups are chosen to make both groups take roughly the same amount of - # time, but beyond that there is no particular reason for the split. - if GROUP == "All" || GROUP == "Group1" - include("utils.jl") - include("compiler.jl") - include("varnamedvector.jl") - include("varinfo.jl") - include("simple_varinfo.jl") - include("model.jl") - include("sampler.jl") - include("independence.jl") - include("distribution_wrappers.jl") - include("logdensityfunction.jl") - include("linking.jl") - include("serialization.jl") - include("pointwise_logdensities.jl") - include("lkj.jl") - include("deprecated.jl") + @testset "submodel - assume - no rhs" begin + @model function g() + a ~ Normal() + return "foo" + end + @model function f() + x ~ Normal() + lhs ~ g() + return (__varinfo__, lhs) + end + model = f() + (vi, lhs) = model() + # Check parent model varinfo + @test Set(keys(vi)) == Set([@varname(x), @varname(lhs.a)]) + @test vi[@varname(x)] isa Float64 + @test vi[@varname(lhs.a)] isa Float64 + # check the lhs of submodel tilde + @test lhs isa OrderedDict + @test lhs[@varname(a)] isa Float64 + @test lhs[@varname(a)] == vi[@varname(lhs.a)] + # check logp accumulated correctly + @test DynamicPPL.getlogp(vi) ≈ + logpdf(Normal(), vi[@varname(x)]) + logpdf(Normal(), vi[@varname(lhs.a)]) end - if GROUP == "All" || GROUP == "Group2" - include("contexts.jl") - include("context_implementations.jl") - include("threadsafe.jl") - include("debug_utils.jl") - @testset "compat" begin - include(joinpath("compat", "ad.jl")) + @testset "submodel - assume - with rhs" begin + @model function g() + a ~ Normal() + return "foo" end - @testset "extensions" begin - include("ext/DynamicPPLMCMCChainsExt.jl") - include("ext/DynamicPPLJETExt.jl") + @model function f() + x ~ Normal() + lhs ~ g() --> rhs + return (__varinfo__, lhs, rhs) end - @testset "ad" begin - include("ext/DynamicPPLForwardDiffExt.jl") - include("ext/DynamicPPLMooncakeExt.jl") - include("ad.jl") + model = f() + (vi, lhs, 
rhs) = model() + # Check parent model varinfo + @test Set(keys(vi)) == Set([@varname(x), @varname(lhs.a)]) + @test vi[@varname(x)] isa Float64 + @test vi[@varname(lhs.a)] isa Float64 + # check the lhs of submodel tilde + @test lhs isa OrderedDict + @test lhs[@varname(a)] isa Float64 + @test lhs[@varname(a)] == vi[@varname(lhs.a)] + # check the rhs + @test rhs == "foo" + # check logp accumulated correctly + @test DynamicPPL.getlogp(vi) ≈ + logpdf(Normal(), vi[@varname(x)]) + logpdf(Normal(), vi[@varname(lhs.a)]) + end + + @testset "submodel - assume - nested with rhs" begin + # OK, this is getting a bit confusing, so I added some annotations. + @model function h() + q ~ Normal() + return "bar" end - @testset "prob and logprob macro" begin - @test_throws ErrorException prob"..." - @test_throws ErrorException logprob"..." + @model function g() + p ~ Normal() + a ~ h() --> b + # Here, `a` should be an OrderedDict with a single key, `q` + # `b` should be "bar" + return ("foo", a, b) end - @testset "doctests" begin - DocMeta.setdocmeta!( - DynamicPPL, - :DocTestSetup, - :(using DynamicPPL, Distributions); - recursive=true, - ) - doctestfilters = [ - # Ignore the source of a warning in the doctest output, since this is dependent on host. - # This is a line that starts with "└ @ " and ends with the line number. - r"└ @ .+:[0-9]+", - ] - doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) + @model function f() + x ~ Normal() + lhs ~ g() --> rhs + # Here, `lhs` should be an OrderedDict with two keys, `p` and `a` + # lhs[`p`] should be a Float64, and lhs[`a`] should itself be an + # OrderedDict with a single key `q`. + # `rhs` should be the return value of g, i.e. a 3-tuple + # ("foo", OrderedDict(`q` -> Float64), "bar") + return (__varinfo__, lhs, rhs) end + + model = f() + (vi, lhs, rhs) = model() + # Check parent model varinfo + @test Set(keys(vi)) == Set([@varname(x), @varname(lhs.p), @varname(lhs.a.q)]) + @test vi[@varname(x)] isa Float64 + @test vi[@varname(lhs.p)] isa Float64 + @test vi[@varname(lhs.a.q)] isa Float64 + # check the lhs of submodel tilde + @test lhs isa OrderedDict + @test lhs[@varname(p)] isa Float64 + @test lhs[@varname(p)] == vi[@varname(lhs.p)] + @test_throws KeyError lhs[@varname(a)][@varname(q)] isa Float64 + @test_throws KeyError lhs[@varname(a)][@varname(q)] == vi[@varname(lhs.a.q)] + # check the rhs of submodel tilde + (foo, a, bar) = rhs + @test foo == "foo" + @test a isa OrderedDict + @test_throws KeyError a[@varname(q)] isa Float64 + @test_throws KeyError a[@varname(q)] == vi[@varname(lhs.a.q)] + @test bar == "bar" + # check logp accumulated correctly + @test DynamicPPL.getlogp(vi) ≈ + logpdf(Normal(), vi[@varname(x)]) + + logpdf(Normal(), vi[@varname(lhs.p)]) + + logpdf(Normal(), vi[@varname(lhs.a.q)]) end end diff --git a/test/sampler.jl b/test/sampler.jl deleted file mode 100644 index 50111b1fd..000000000 --- a/test/sampler.jl +++ /dev/null @@ -1,207 +0,0 @@ -@testset "sampler.jl" begin - @testset "SampleFromPrior and SampleUniform" begin - @model function gdemo(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(2.0, sqrt(s)) - x ~ Normal(m, sqrt(s)) - return y ~ Normal(m, sqrt(s)) - end - - model = gdemo(1.0, 2.0) - N = 1_000 - - chains = sample(model, SampleFromPrior(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 - - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. 
- @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 - - chains = sample(model, SampleFromUniform(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # `m` is Gaussian, i.e. no transformation is used, so it - # should have a mean equal to its prior, i.e. 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 - - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 - end - - @testset "init" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 1000 - chain_init = sample(model, SampleFromUniform(), N; progress=false) - - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") - end - end - end - end - - @testset "Initial parameters" begin - # dummy algorithm that just returns initial value and does not perform any sampling - abstract type OnlyInitAlg end - struct OnlyInitAlgDefault <: OnlyInitAlg end - struct OnlyInitAlgUniform <: OnlyInitAlg end - function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::Model, - ::Sampler{<:OnlyInitAlg}, - vi::AbstractVarInfo; - kwargs..., - ) - return vi, nothing - end - DynamicPPL.getspace(::Sampler{<:OnlyInitAlg}) = () - - # initial samplers - DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() - @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() - - for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) - # model with one variable: initialization p = 0.2 - @model function coinflip() - p ~ Beta(1, 1) - return 10 ~ Binomial(25, p) - end - model = coinflip() - sampler = Sampler(alg) - lptrue = logpdf(Binomial(25, 0.2), 10) - let inits = (; p=0.2) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) - @test chain[1].metadata.p.vals == [0.2] - @test getlogp(chain[1]) == lptrue - - # parallel sampling - chains = sample( - model, - sampler, - MCMCThreads(), - 1, - 10; - initial_params=fill(inits, 10), - progress=false, - ) - for c in chains - @test c[1].metadata.p.vals == [0.2] - @test getlogp(c[1]) == lptrue - end - end - - # model with two variables: initialization s = 4, m = -1 - @model function twovars() - s ~ InverseGamma(2, 3) - return m ~ Normal(0, sqrt(s)) - end - model = twovars() - lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - for inits in ([4, -1], (; s=4, m=-1)) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) - @test chain[1].metadata.s.vals == [4] - @test chain[1].metadata.m.vals == [-1] - @test getlogp(chain[1]) == lptrue - - # parallel sampling - chains = sample( - model, - sampler, - MCMCThreads(), - 1, - 10; - initial_params=fill(inits, 10), - progress=false, - ) - for c in chains - @test c[1].metadata.s.vals == [4] - @test c[1].metadata.m.vals == [-1] - @test getlogp(c[1]) == lptrue - end - end - - # set only m = -1 - for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1)) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) - @test 
!ismissing(chain[1].metadata.s.vals[1]) - @test chain[1].metadata.m.vals == [-1] - - # parallel sampling - chains = sample( - model, - sampler, - MCMCThreads(), - 1, - 10; - initial_params=fill(inits, 10), - progress=false, - ) - for c in chains - @test !ismissing(c[1].metadata.s.vals[1]) - @test c[1].metadata.m.vals == [-1] - end - end - - # specify `initial_params=nothing` - Random.seed!(1234) - chain1 = sample(model, sampler, 1; progress=false) - Random.seed!(1234) - chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) - @test_throws DimensionMismatch sample( - model, sampler, 1; progress=false, initial_params=zeros(10) - ) - @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals - @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals - - # parallel sampling - Random.seed!(1234) - chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false) - Random.seed!(1234) - chains2 = sample( - model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false - ) - for (c1, c2) in zip(chains1, chains2) - @test c1[1].metadata.m.vals == c2[1].metadata.m.vals - @test c1[1].metadata.s.vals == c2[1].metadata.s.vals - end - end - - @testset "error handling" begin - # https://github.com/TuringLang/Turing.jl/issues/2452 - @model function constrained_uniform(n) - Z ~ Uniform(10, 20) - X = Vector{Float64}(undef, n) - for i in 1:n - X[i] ~ Uniform(0, Z) - end - end - - n = 2 - initial_z = 15 - initial_x = [0.2, 0.5] - model = constrained_uniform(n) - vi = VarInfo(model) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], model - ) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), model - ) - end - end -end diff --git a/test/serialization.jl b/test/serialization.jl deleted file mode 100644 index a2d9abb36..000000000 --- a/test/serialization.jl +++ /dev/null @@ -1,53 +0,0 @@ -@testset "serialization.jl" begin - @testset "saving and loading" begin - # Save model. - file = joinpath(mktempdir(), "gdemo_default.jls") - serialize(file, gdemo_default) - - # Sample from deserialized model. - gdemo_default_copy = deserialize(file) - samples = [gdemo_default_copy() for _ in 1:1_000] - samples_s = first.(samples) - samples_m = last.(samples) - - @test mean(samples_s) ≈ 3 atol = 0.2 - @test mean(samples_m) ≈ 0 atol = 0.15 - end - @testset "pmap" begin - # Add worker processes. - pids = addprocs() - @info "serialization test: using $(nworkers()) processes" - - # Load packages on all processes. - @everywhere begin - using DynamicPPL - using Distributions - end - - # Define model on all proceses. - @everywhere @model function model() - return m ~ Normal(0, 1) - end - - # Generate `Model` objects on all processes. - models = pmap(_ -> model(), 1:100) - @test models isa Vector{<:Model} - @test length(models) == 100 - - # Sample from model on all processes. - n = 1_000 - samples1 = pmap(_ -> model()(), 1:n) - m = model() - samples2 = pmap(_ -> m(), 1:n) - - for samples in (samples1, samples2) - @test samples isa Vector{Float64} - @test length(samples) == n - @test mean(samples) ≈ 0 atol = 0.15 - @test std(samples) ≈ 1 atol = 0.1 - end - - # Remove processes - rmprocs(pids...) 
- end -end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl deleted file mode 100644 index 137c791c2..000000000 --- a/test/simple_varinfo.jl +++ /dev/null @@ -1,319 +0,0 @@ -@testset "simple_varinfo.jl" begin - @testset "constructor & indexing" begin - @testset "NamedTuple" begin - svi = SimpleVarInfo(; m=1.0) - @test getlogp(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(; m=[1.0]) - @test getlogp(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(; m=(a=[1.0],)) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - - svi = SimpleVarInfo{Float32}(; m=1.0) - @test getlogp(svi) isa Float32 - - svi = SimpleVarInfo((m=1.0,), 1.0) - @test getlogp(svi) == 1.0 - end - - @testset "Dict" begin - svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) - @test getlogp(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) - @test getlogp(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(Dict(@varname(m) => (a=[1.0],))) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - - svi = SimpleVarInfo(Dict(@varname(m.a) => [1.0])) - # Now we only have a variable `m.a` which is subsumed by `m`, - # but we can't guarantee that we have the "entire" `m`. - @test !haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - end - - @testset "VarNamedVector" begin - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0)) - @test getlogp(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0])) - @test getlogp(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m.a) => [1.0])) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - # The implementation of haskey and getvalue fo VarNamedVector is incomplete, the - # next test is here to remind of us that. - svi = SimpleVarInfo( - push!!(DynamicPPL.VarNamedVector(), @varname(m.a.b) => [1.0]) - ) - @test_broken !haskey(svi, @varname(m.a.b.c.d)) - end - end - - @testset "link!! & invlink!! 
on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS - values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - @testset "$(typeof(vi))" for vi in ( - SimpleVarInfo(Dict()), - SimpleVarInfo(values_constrained), - SimpleVarInfo(DynamicPPL.VarNamedVector()), - VarInfo(model), - ) - for vn in DynamicPPL.TestUtils.varnames(model) - vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) - end - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - lp_orig = getlogp(vi) - - # `link!!` - vi_linked = link!!(deepcopy(vi), model) - lp_linked = getlogp(vi_linked) - values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, values_constrained... - ) - # Should result in the correct logjoint. - @test lp_linked ≈ lp_linked_true - # Should be approx. the same as the "lazy" transformation. - @test logjoint(model, vi_linked) ≈ lp_linked - - # TODO: Should not `VarInfo` also error here? The current implementation - # only warns and acts as a no-op. - if vi isa SimpleVarInfo - @test_throws AssertionError link!!(vi_linked, model) - end - - # `invlink!!` - vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_invlinked = getlogp(vi_invlinked) - lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( - model, values_constrained... - ) - # Should result in the correct logjoint. - @test lp_invlinked ≈ lp_invlinked_true - # Should be approx. the same as the "lazy" transformation. - @test logjoint(model, vi_invlinked) ≈ lp_invlinked - - # Should result in same values. - @test all( - DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_invlinked, vn)) ≈ - DynamicPPL.tovec(get(values_constrained, vn)) for - vn in DynamicPPL.TestUtils.varnames(model) - ) - end - end - - @testset "SimpleVarInfo on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix() - - # We might need to pre-allocate for the variable `m`, so we need - # to see whether this is the case. - svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) - svi_dict = SimpleVarInfo(VarInfo(model), Dict) - vnv = DynamicPPL.VarNamedVector() - for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model)) - vnv = push!!(vnv, VarName{k}() => v) - end - svi_vnv = SimpleVarInfo(vnv) - - @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( - svi_nt, - svi_dict, - svi_vnv, - DynamicPPL.settrans!!(deepcopy(svi_nt), true), - DynamicPPL.settrans!!(deepcopy(svi_dict), true), - DynamicPPL.settrans!!(deepcopy(svi_vnv), true), - ) - # RandOM seed is set in each `@testset`, so we need to sample - # a new realization for `m` here. - retval = model() - - ### Sampling ### - # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) - - # Realization for `m` should be different wp. 1. - for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_new[vn] != get(retval, vn) - end - - # Logjoint should be non-zero wp. 1. - @test getlogp(svi_new) != 0 - - ### Evaluation ### - values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - if DynamicPPL.istrans(svi) - _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( - model, values_eval_constrained... - ) - values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, values_eval_constrained... - ) - # Make sure that these two computation paths provide the same - # transformed values. 
- @test values_eval == _values_prior - else - logpri_true = DynamicPPL.TestUtils.logprior_true( - model, values_eval_constrained... - ) - logπ_true = DynamicPPL.TestUtils.logjoint_true( - model, values_eval_constrained... - ) - values_eval = values_eval_constrained - end - - # No logabsdet-jacobian correction needed for the likelihood. - loglik_true = DynamicPPL.TestUtils.loglikelihood_true( - model, values_eval_constrained... - ) - - # Update the realizations in `svi_new`. - svi_eval = svi_new - for vn in DynamicPPL.TestUtils.varnames(model) - svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) - end - - # Reset the logp field. - svi_eval = DynamicPPL.resetlogp!!(svi_eval) - - # Compute `logjoint` using the varinfo. - logπ = logjoint(model, svi_eval) - logpri = logprior(model, svi_eval) - loglik = loglikelihood(model, svi_eval) - - # Values should not have changed. - for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_eval[vn] == get(values_eval, vn) - end - - # Compare log-probability computations. - @test logpri ≈ logpri_true - @test loglik ≈ loglik_true - @test logπ ≈ logπ_true - end - end - - @testset "Dynamic constraints" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - - # Initialize. - svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate!!(model, svi_nt, SamplingContext())) - svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate!!(model, svi_vnv, SamplingContext())) - - for svi in (svi_nt, svi_vnv) - # Sample with large variations in unconstrained space. - for i in 1:10 - for vn in keys(svi) - svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) - end - retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) - @test retval.m == svi[@varname(m)] # `m` is unconstrained - @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` - - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.m, retval.x - ) - - # Realizations from model should all be equal to the unconstrained realization. - for vn in DynamicPPL.TestUtils.varnames(model) - @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 - end - - # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogp(svi) - @test lp ≈ lp_true - end - end - end - - @testset "Static transformation" begin - model = DynamicPPL.TestUtils.demo_static_transformation() - - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)] - ) - @testset "$(short_varinfo_name(vi))" for vi in varinfos - # Initialize varinfo and link. - vi_linked = DynamicPPL.link!!(vi, model) - - # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. - @test !DynamicPPL.istrans( - DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) - ) - - # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate!!(model, deepcopy(vi), SamplingContext())) - @test !DynamicPPL.istrans(vi_result) - - # Set the values to something that is out of domain if we're in constrained space. 
- for vn in keys(vi) - vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) - end - - retval, vi_linked_result = DynamicPPL.evaluate!!( - model, deepcopy(vi_linked), DefaultContext() - ) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ - DynamicPPL.tovec(retval.s) # `s` is unconstrained in original - @test DynamicPPL.tovec( - DynamicPPL.getindex_internal(vi_linked_result, @varname(s)) - ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result - - # `m` should not be transformed. - @test vi_linked[@varname(m)] == retval.m - @test vi_linked_result[@varname(m)] == retval.m - - # Compare to truth. - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.s, retval.m - ) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ - DynamicPPL.tovec(retval_unconstrained.s) - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ - DynamicPPL.tovec(retval_unconstrained.m) - - # The resulting varinfo should hold the correct logp. - lp = getlogp(vi_linked_result) - @test lp ≈ lp_true - end - end -end diff --git a/test/test_util.jl b/test/test_util.jl deleted file mode 100644 index 27a68456c..000000000 --- a/test/test_util.jl +++ /dev/null @@ -1,111 +0,0 @@ -# default model -@model function gdemo_d() - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - 1.5 ~ Normal(m, sqrt(s)) - 2.0 ~ Normal(m, sqrt(s)) - return s, m -end -const gdemo_default = gdemo_d() - -function test_model_ad(model, logp_manual) - vi = VarInfo(model) - x = DynamicPPL.getall(vi) - - # Log probabilities using the model. - ℓ = DynamicPPL.LogDensityFunction(model, vi) - logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ) - - # Check that both functions return the same values. - lp = logp_manual(x) - @test logp_model(x) ≈ lp - - # Gradients based on the manual implementation. - grad = ForwardDiff.gradient(logp_manual, x) - - y, back = Tracker.forward(logp_manual, x) - @test Tracker.data(y) ≈ lp - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(logp_manual, x) - @test y ≈ lp - @test back(1)[1] ≈ grad - - # Gradients based on the model. - @test ForwardDiff.gradient(logp_model, x) ≈ grad - - y, back = Tracker.forward(logp_model, x) - @test Tracker.data(y) ≈ lp - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(logp_model, x) - @test y ≈ lp - @test back(1)[1] ≈ grad -end - -""" - short_varinfo_name(vi::AbstractVarInfo) - -Return string representing a short description of `vi`. 
-""" -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = - "threadsafe($(short_varinfo_name(vi.varinfo)))" -function short_varinfo_name(vi::TypedVarInfo) - DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" - return "TypedVarInfo" -end -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" -short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) - return "SimpleVarInfo{<:VarNamedVector}" -end - -# convenient functions for testing model.jl -# function to modify the representation of values based on their length -function modify_value_representation(nt::NamedTuple) - modified_nt = NamedTuple() - for (key, value) in zip(keys(nt), values(nt)) - if length(value) == 1 # Scalar value - modified_value = value[1] - else # Non-scalar value - modified_value = value - end - modified_nt = merge(modified_nt, (key => modified_value,)) - end - return modified_nt -end - -""" - make_chain_from_prior([rng,] model, n_iters) - -Construct an MCMCChains.Chains object by sampling from the prior of `model` for -`n_iters` iterations. -""" -function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) - # Sample from the prior - varinfos = [VarInfo(rng, model) for _ in 1:n_iters] - # Extract all varnames found in any dictionary. Doing it this way guards - # against the possibility of having different varnames in different - # dictionaries, e.g. for models that have dynamic variables / array sizes - varnames = OrderedSet{VarName}() - # Convert each varinfo into an OrderedDict of vns => params. - # We have to use varname_and_value_leaves so that each parameter is a scalar - dicts = map(varinfos) do t - vals = DynamicPPL.values_as(t, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) - tuples = mapreduce(collect, vcat, iters) - push!(varnames, map(first, tuples)...) 
- OrderedDict(tuples) - end - # Convert back to list - varnames = collect(varnames) - # Construct matrix of values - vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] - # Construct and return the Chains object - return Chains(vals, varnames) -end -function make_chain_from_prior(model::Model, n_iters::Int) - return make_chain_from_prior(Random.default_rng(), model, n_iters) -end diff --git a/test/threadsafe.jl b/test/threadsafe.jl deleted file mode 100644 index 72c439db8..000000000 --- a/test/threadsafe.jl +++ /dev/null @@ -1,117 +0,0 @@ -@testset "threadsafe.jl" begin - @testset "constructor" begin - vi = VarInfo(gdemo_default) - threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) - - @test threadsafe_vi.varinfo === vi - @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))} - @test length(threadsafe_vi.logps) == Threads.nthreads() - @test all(iszero(x[]) for x in threadsafe_vi.logps) - end - - # TODO: Add more tests of the public API - @testset "API" begin - vi = VarInfo(gdemo_default) - threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) - - lp = getlogp(vi) - @test getlogp(threadsafe_vi) == lp - - acclogp!!(threadsafe_vi, 42) - @test threadsafe_vi.logps[Threads.threadid()][] == 42 - @test getlogp(vi) == lp - @test getlogp(threadsafe_vi) == lp + 42 - - resetlogp!!(threadsafe_vi) - @test iszero(getlogp(vi)) - @test iszero(getlogp(threadsafe_vi)) - @test all(iszero(x[]) for x in threadsafe_vi.logps) - - setlogp!!(threadsafe_vi, 42) - @test getlogp(vi) == 42 - @test getlogp(threadsafe_vi) == 42 - @test all(iszero(x[]) for x in threadsafe_vi.logps) - end - - @testset "model" begin - println("Peforming threading tests with $(Threads.nthreads()) threads") - - x = rand(10_000) - - @model function wthreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) - Threads.@threads for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) - end - end - - vi = VarInfo() - wthreads(x)(vi) - lp_w_threads = getlogp(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("With `@threads`:") - println(" default:") - @time wthreads(x)(vi) - - # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) - @test getlogp(vi) ≈ lp_w_threads - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - - println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) - - @model function wothreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) - for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) - end - end - - vi = VarInfo() - wothreads(x)(vi) - lp_wo_threads = getlogp(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("Without `@threads`:") - println(" default:") - @time wothreads(x)(vi) - - @test lp_w_threads ≈ lp_wo_threads - - # Ensure that we use `VarInfo`. 
- DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) - @test getlogp(vi) ≈ lp_w_threads - @test vi_ isa VarInfo - - println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) - end -end diff --git a/test/utils.jl b/test/utils.jl deleted file mode 100644 index d683f132d..000000000 --- a/test/utils.jl +++ /dev/null @@ -1,74 +0,0 @@ -@testset "utils.jl" begin - @testset "addlogprob!" begin - @model function testmodel() - global lp_before = getlogp(__varinfo__) - @addlogprob!(42) - return global lp_after = getlogp(__varinfo__) - end - - model = testmodel() - varinfo = VarInfo(model) - @test iszero(lp_before) - @test getlogp(varinfo) == lp_after == 42 - end - - @testset "getargs_dottilde" begin - # Some things that are not expressions. - @test getargs_dottilde(:x) === nothing - @test getargs_dottilde(1.0) === nothing - @test getargs_dottilde([1.0, 2.0, 4.0]) === nothing - - # Some expressions. - @test getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing - @test getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) - @test getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) - @test getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) - @test getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing - @test getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing - @test getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing - end - - @testset "getargs_tilde" begin - # Some things that are not expressions. - @test getargs_tilde(:x) === nothing - @test getargs_tilde(1.0) === nothing - @test getargs_tilde([1.0, 2.0, 4.0]) === nothing - - # Some expressions. - @test getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) - @test getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing - @test getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing - @test getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing - @test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing - @test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing - end - - @testset "tovec" begin - dist = LKJCholesky(2, 1) - x = rand(dist) - @test DynamicPPL.tovec(x) == vec(x.UL) - end - - @testset "unique_syms" begin - vns = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) - @inferred DynamicPPL.unique_syms(vns) - @inferred DynamicPPL.unique_syms(()) - @test DynamicPPL.unique_syms(vns) == (:x, :y, :z) - @test DynamicPPL.unique_syms(()) == () - end - - @testset "group_varnames_by_symbol" begin - vns_tuple = ( - @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]) - ) - vns_vec = collect(vns_tuple) - vns_nt = (; - x=[@varname(x), @varname(x.a)], - y=[@varname(y[1]), @varname(y[2])], - z=[@varname(z[15])], - ) - vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] - @inferred DynamicPPL.group_varnames_by_symbol(vns_tuple) - @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt - end -end diff --git a/test/varinfo.jl b/test/varinfo.jl deleted file mode 100644 index d689a1bf4..000000000 --- a/test/varinfo.jl +++ /dev/null @@ -1,1010 +0,0 @@ -function check_varinfo_keys(varinfo, vns) - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, - # since `keys(varinfo_merged)` only contains `VarName` with `identity`. 
- # So we just check that the original keys are present. - for vn in vns - # Should have all the original keys. - @test haskey(varinfo, vn) - end - else - vns_varinfo = keys(varinfo) - # Should be equivalent. - @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) - end -end - -""" -Return the value of `vn` in `vi`. If one doesn't exist, sample and set it. -""" -function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) - if !haskey(vi, vn) - r = rand(dist) - push!!(vi, vn, r, dist) - r - elseif DynamicPPL.is_flagged(vi, vn, "del") - DynamicPPL.unset_flag!(vi, vn, "del") - r = rand(dist) - vi[vn] = DynamicPPL.tovec(r) - DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) - r - else - vi[vn] - end -end - -@testset "varinfo.jl" begin - @testset "TypedVarInfo with Metadata" begin - @model gdemo(x, y) = begin - s ~ InverseGamma(2, 3) - m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) - end - model = gdemo(1.0, 2.0) - - vi = VarInfo(DynamicPPL.Metadata()) - model(vi, SampleFromUniform()) - tvi = TypedVarInfo(vi) - - meta = vi.metadata - for f in fieldnames(typeof(tvi.metadata)) - fmeta = getfield(tvi.metadata, f) - for vn in fmeta.vns - @test tvi[vn] == vi[vn] - ind = meta.idcs[vn] - tind = fmeta.idcs[vn] - @test meta.dists[ind] == fmeta.dists[tind] - @test meta.orders[ind] == fmeta.orders[tind] - for flag in keys(meta.flags) - @test meta.flags[flag][ind] == fmeta.flags[flag][tind] - end - range = meta.ranges[ind] - trange = fmeta.ranges[tind] - @test all(meta.vals[range] .== fmeta.vals[trange]) - end - end - end - - @testset "Base" begin - # Test Base functions: - # string, Symbol, ==, hash, in, keys, haskey, isempty, push!!, empty!!, - # getindex, setindex!, getproperty, setproperty! - csym = gensym() - vn1 = @varname x[1][2] - @test string(vn1) == "x[1][2]" - @test Symbol(vn1) == Symbol("x[1][2]") - - vn2 = @varname x[1][2] - @test vn2 == vn1 - @test hash(vn2) == hash(vn1) - - function test_base!!(vi_original) - vi = empty!!(vi_original) - @test getlogp(vi) == 0 - @test isempty(vi[:]) - - vn = @varname x - dist = Normal(0, 1) - r = rand(dist) - - @test isempty(vi) - @test ~haskey(vi, vn) - @test !(vn in keys(vi)) - vi = push!!(vi, vn, r, dist) - @test ~isempty(vi) - @test haskey(vi, vn) - @test vn in keys(vi) - - @test length(vi[vn]) == 1 - @test vi[vn] == r - vi = DynamicPPL.setindex!!(vi, 2 * r, vn) - @test vi[vn] == 2 * r - - # TODO(mhauru) Implement these functions for other VarInfo types too. 
- if vi isa DynamicPPL.VectorVarInfo - delete!(vi, vn) - @test isempty(vi) - vi = push!!(vi, vn, r, dist) - end - - vi = empty!!(vi) - @test isempty(vi) - vi = push!!(vi, vn, r, dist) - @test ~isempty(vi) - end - - vi = VarInfo() - test_base!!(vi) - test_base!!(TypedVarInfo(vi)) - test_base!!(SimpleVarInfo()) - test_base!!(SimpleVarInfo(Dict())) - test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) - end - - @testset "get/set/acc/resetlogp" begin - function test_varinfo_logp!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 - vi = DynamicPPL.setlogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 1.0 - vi = DynamicPPL.acclogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 2.0 - vi = DynamicPPL.resetlogp!!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 - end - - vi = VarInfo() - test_varinfo_logp!(vi) - test_varinfo_logp!(TypedVarInfo(vi)) - test_varinfo_logp!(SimpleVarInfo()) - test_varinfo_logp!(SimpleVarInfo(Dict())) - test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) - end - - @testset "flags" begin - # Test flag setting: - # is_flagged, set_flag!, unset_flag! - function test_varinfo!(vi) - vn_x = @varname x - dist = Normal(0, 1) - r = rand(dist) - - push!!(vi, vn_x, r, dist) - - # del is set by default - @test !is_flagged(vi, vn_x, "del") - - set_flag!(vi, vn_x, "del") - @test is_flagged(vi, vn_x, "del") - - unset_flag!(vi, vn_x, "del") - @test !is_flagged(vi, vn_x, "del") - end - vi = VarInfo(DynamicPPL.Metadata()) - test_varinfo!(vi) - test_varinfo!(empty!!(TypedVarInfo(vi))) - end - - @testset "push!! to TypedVarInfo" begin - vn_x = @varname x - vn_y = @varname y - untyped_vi = VarInfo() - untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) - typed_vi = TypedVarInfo(untyped_vi) - typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) - @test typed_vi[vn_x] == 1.0 - @test typed_vi[vn_y] == 2.0 - end - - @testset "setval! & setval_and_resample!" begin - @model function testmodel(x) - n = length(x) - s ~ truncated(Normal(), 0, Inf) - m ~ MvNormal(zeros(n), I) - return x ~ MvNormal(m, s^2 * I) - end - - @model function testmodel_univariate(x, ::Type{TV}=Vector{Float64}) where {TV} - n = length(x) - s ~ truncated(Normal(), 0, Inf) - - m = TV(undef, n) - for i in eachindex(m) - m[i] ~ Normal() - end - - for i in eachindex(x) - x[i] ~ Normal(m[i], s) - end - end - - x = randn(5) - model_mv = testmodel(x) - model_uv = testmodel_univariate(x) - - for model in [model_uv, model_mv] - m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) - s_vns = @varname(s) - - vi_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata() - ) - vi_untyped = VarInfo(DynamicPPL.Metadata()) - vi_vnv = VarInfo(DynamicPPL.VarNamedVector()) - vi_vnv_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.VarNamedVector() - ) - model(vi_untyped, SampleFromPrior()) - model(vi_vnv, SampleFromPrior()) - - model_name = model == model_uv ? "univariate" : "multivariate" - @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ - vi_untyped, vi_typed, vi_vnv, vi_vnv_typed - ] - Random.seed!(23) - vicopy = deepcopy(vi) - - ### `setval` ### - # TODO(mhauru) The interface here seems inconsistent between Metadata and - # VarNamedVector. I'm lazy to fix it though, because I think we need to - # rework it soon anyway. - if vi in [vi_vnv, vi_vnv_typed] - DynamicPPL.setval!(vicopy, zeros(5), m_vns) - else - DynamicPPL.setval!(vicopy, (m=zeros(5),)) - end - # Setting `m` fails for univariate due to limitations of `setval!` - # and `setval_and_resample!`. 
See docstring of `setval!` for more info. - if model == model_uv && vi in [vi_untyped, vi_typed] - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] == vi[s_vns] - - # Ordering is NOT preserved => fails for multivariate model. - DynamicPPL.setval!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] == vi[s_vns] - - DynamicPPL.setval!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - DynamicPPL.setval!(vicopy, (s=42,)) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. - continue - end - - if vi in [vi_vnv, vi_vnv_typed] - # `setval_and_resample!` works differently for `VarNamedVector`: All - # values will be resampled when model(vicopy) is called. Hence the below - # tests are not applicable. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - # Ordering is NOT preserved. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] != vi[s_vns] - - # Correct ordering. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 - end - end - - # https://github.com/TuringLang/DynamicPPL.jl/issues/250 - @model function demo() - return x ~ filldist(MvNormal([1, 100], I), 2) - end - - vi = VarInfo(demo()) - vals_prev = vi.metadata.x.vals - ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] - DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals - - DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals - end - - @testset "setval! on chain" begin - # Define a helper function - """ - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - - Test `setval!` on `model` and `chain`. - - Worth noting that this only supports models containing symbols of the forms - `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. 
- """ - function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - θ_old = var_info[:] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[:] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end - end - - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - chain = make_chain_from_prior(model, 10) - # A simple way of checking that the computation is determinstic: run twice and compare. - res1 = returned(model, MCMCChains.get_sections(chain, :parameters)) - res2 = returned(model, MCMCChains.get_sections(chain, :parameters)) - @test all(res1 .== res2) - test_setval!(model, MCMCChains.get_sections(chain, :parameters)) - end - end - - @testset "link!! and invlink!!" begin - @model gdemo(a, b, ::Type{T}=Float64) where {T} = begin - s ~ InverseGamma(2, 3) - m ~ Uniform(0, 2) - x = Vector{T}(undef, length(a)) - x .~ Normal(m, sqrt(s)) - y = Vector{T}(undef, length(a)) - for i in eachindex(y) - y[i] ~ Normal(m, sqrt(s)) - end - a .~ Normal(m, sqrt(s)) - for i in eachindex(b) - b[i] ~ Normal(x[i] * y[i], sqrt(s)) - end - end - model = gdemo([1.0, 1.5], [2.0, 2.5]) - - # Check that instantiating the model does not perform linking - vi = VarInfo() - meta = vi.metadata - model(vi, SampleFromUniform()) - @test all(x -> !istrans(vi, x), meta.vns) - - # Check that linking and invlinking set the `trans` flag accordingly - v = copy(meta.vals) - vi = link!!(vi, model) - @test all(x -> istrans(vi, x), meta.vns) - vi = invlink!!(vi, model) - @test all(x -> !istrans(vi, x), meta.vns) - @test meta.vals ≈ v atol = 1e-10 - - # Check that linking and invlinking preserves the values - vi = TypedVarInfo(vi) - meta = vi.metadata - v_s = copy(meta.s.vals) - v_m = copy(meta.m.vals) - v_x = copy(meta.x.vals) - v_y = copy(meta.y.vals) - - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - vi = link!!(vi, model) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) - vi = invlink!!(vi, model) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 - - # Transform only one variable - all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns) - for vn in [ - @varname(s), - @varname(m), - @varname(x), - @varname(y), - @varname(x[2]), - @varname(y[2]) - ] - target_vns = filter(x -> subsumes(vn, x), all_vns) - other_vns = filter(x -> !subsumes(vn, x), all_vns) - @test !isempty(target_vns) - @test !isempty(other_vns) - vi = link!!(vi, (vn,), model) - @test all(x -> istrans(vi, x), target_vns) - @test all(x -> !istrans(vi, x), other_vns) - vi = invlink!!(vi, (vn,), model) - @test all(x -> !istrans(vi, x), all_vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 - @test meta.x.vals ≈ v_x atol = 1e-10 - @test meta.y.vals ≈ v_y atol = 1e-10 - end - end - - @testset "istrans" begin - @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) - model = demo_constrained() - vn = @varname(x) - dist = 
truncated(Normal(), 0, Inf) - - ### `VarInfo` - # Need to run once since we can't specify that we want to _sample_ - # in the unconstrained space for `VarInfo` without having `vn` - # present in the `varinfo`. - ## `UntypedVarInfo` - vi = VarInfo() - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - - ## `TypedVarInfo` - vi = VarInfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - - ### `SimpleVarInfo` - ## `SimpleVarInfo{<:NamedTuple}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - - ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - - ## `SimpleVarInfo{<:VarNamedVector}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - end - - @testset "values_as" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - - # Set up the different instances of `AbstractVarInfo` with the desired values. - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, example_values, vns; include_threadsafe=true - ) - @testset "$(short_varinfo_name(vi))" for vi in varinfos - # Just making sure. - DynamicPPL.TestUtils.test_values(vi, example_values, vns) - - @testset "NamedTuple" begin - vals = values_as(vi, NamedTuple) - for vn in vns - if haskey(vals, Symbol(vn)) - # Assumed to be of form `(var"m[1]" = 1.0, ...)`. - @test getindex(vals, Symbol(vn)) == getindex(vi, vn) - else - # Assumed to be of form `(m = [1.0, ...], ...)`. - @test get(vals, vn) == getindex(vi, vn) - end - end - end - - @testset "OrderedDict" begin - vals = values_as(vi, OrderedDict) - # All varnames in `vns` should be subsumed by one of `keys(vals)`. - @test all(vns) do vn - any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)) - end - # Iterate over `keys(vals)` because we might have scenarios such as - # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is - # the varname present in `vns`, not `@varname(m)`. 
- for vn in keys(vals) - @test getindex(vals, vn) == getindex(vi, vn) - end - end - end - end - end - - @testset "unflatten + linking" begin - @testset "Model: $(model.f)" for model in [ - DynamicPPL.TestUtils.demo_one_variable_multiple_constraints(), - DynamicPPL.TestUtils.demo_lkjchol(), - ] - @testset "mutating=$mutating" for mutating in [false, true] - value_true = DynamicPPL.TestUtils.rand_prior_true(model) - varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, value_true, varnames; include_threadsafe=true - ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: this is broken since we'll end up trying to set - # - # varinfo[@varname(x[4:5])] = [x[4],] - # - # upon linking (since `x[4:5]` will be projected onto a 1-dimensional - # space). In the case of `SimpleVarInfo{<:NamedTuple}`, this results in - # calling `setindex!!(varinfo.values, [x[4],], @varname(x[4:5]))`, which - # in turn attempts to call `setindex!(varinfo.values.x, [x[4],], 4:5)`, - # i.e. a vector of length 1 (`[x[4],]`) being assigned to 2 indices (`4:5`). - @test_broken false - continue - end - - if DynamicPPL.has_varnamedvector(varinfo) && mutating - # NOTE: Can't handle mutating `link!` and `invlink!` `VarNamedVector`. - @test_broken false - continue - end - - # Evaluate the model once to update the logp of the varinfo. - varinfo = last(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())) - - varinfo_linked = if mutating - DynamicPPL.link!!(deepcopy(varinfo), model) - else - DynamicPPL.link(varinfo, model) - end - for vn in keys(varinfo) - @test DynamicPPL.istrans(varinfo_linked, vn) - end - @test length(varinfo[:]) > length(varinfo_linked[:]) - varinfo_linked_unflattened = DynamicPPL.unflatten( - varinfo_linked, varinfo_linked[:] - ) - @test length(varinfo_linked_unflattened[:]) == length(varinfo_linked[:]) - - lp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...) - value_linked_true, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, value_true... - ) - - lp = logjoint(model, varinfo) - @test lp ≈ lp_true - @test getlogp(varinfo) ≈ lp_true - lp_linked = getlogp(varinfo_linked) - @test lp_linked ≈ lp_linked_true - - # TODO: Compare values once we are no longer working with `NamedTuple` for - # the true values, e.g. `value_true`. - - if !mutating - # This is also compatible with invlinking of unflattened varinfo. - varinfo_invlinked = DynamicPPL.invlink( - varinfo_linked_unflattened, model - ) - @test length(varinfo_invlinked[:]) == length(varinfo[:]) - @test getlogp(varinfo_invlinked) ≈ lp_true - end - end - end - end - end - - @testset "subset" begin - @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV} - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - x = TV(undef, 2) - x[1] ~ Normal(m, sqrt(s)) - x[2] ~ Normal(m, sqrt(s)) - return (; s, m, x) - end - model = demo_subsetting_varinfo() - vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] - - # `VarInfo` supports, effectively, arbitrary subsetting. - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, model(), vns; include_threadsafe=true - ) - varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter( - Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos - ) - - # `VarInfo` supports subsetting using, basically, arbitrary varnames. 
- vns_supported_standard = [ - [@varname(s)], - [@varname(m)], - [@varname(x[1])], - [@varname(x[2])], - [@varname(s), @varname(m)], - [@varname(s), @varname(x[1])], - [@varname(s), @varname(x[2])], - [@varname(m), @varname(x[1])], - [@varname(m), @varname(x[2])], - [@varname(x[1]), @varname(x[2])], - [@varname(s), @varname(m), @varname(x[1])], - [@varname(s), @varname(m), @varname(x[2])], - [@varname(s), @varname(x[1]), @varname(x[2])], - [@varname(m), @varname(x[1]), @varname(x[2])], - ] - - # Patterns requiring `subsumes`. - vns_supported_with_subsumes = [ - [@varname(s), @varname(x)] => [@varname(s), @varname(x[1]), @varname(x[2])], - [@varname(m), @varname(x)] => [@varname(m), @varname(x[1]), @varname(x[2])], - [@varname(s), @varname(m), @varname(x)] => - [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], - ] - - # `SimpleVarInfo` only supports subsetting using the varnames as they appear - # in the model. - vns_supported_simple = filter(∈(vns), vns_supported_standard) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # All variables. - check_varinfo_keys(varinfo, vns) - - # Added a `convert` to make the naming of the testsets a bit more readable. - # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, - ## i.e. `VarName{sym}()` without any indexing, etc. - vns_supported = - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && - values_as(varinfo) isa NamedTuple - vns_supported_simple - else - vns_supported_standard - end - - @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in - vns_supported - varinfo_subset = subset(varinfo, VarName[]) - @test isempty(varinfo_subset) - end - - @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in - vns_supported - varinfo_subset = subset(varinfo, vns_subset) - # Should now only contain the variables in `vns_subset`. - check_varinfo_keys(varinfo_subset, vns_subset) - # Values should be the same. - @test [varinfo_subset[vn] for vn in vns_subset] == [varinfo[vn] for vn in vns_subset] - - # `merge` with the original. - varinfo_merged = merge(varinfo, varinfo_subset) - vns_merged = keys(varinfo_merged) - # Should be equivalent. - check_varinfo_keys(varinfo_merged, vns) - # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] - end - - @testset "$(convert(Vector{VarName}, vns_subset))" for ( - vns_subset, vns_target - ) in vns_supported_with_subsumes - varinfo_subset = subset(varinfo, vns_subset) - # Should now only contain the variables in `vns_subset`. - check_varinfo_keys(varinfo_subset, vns_target) - # Values should be the same. - @test [varinfo_subset[vn] for vn in vns_target] == [varinfo[vn] for vn in vns_target] - - # `merge` with the original. - varinfo_merged = merge(varinfo, varinfo_subset) - vns_merged = keys(varinfo_merged) - # Should be equivalent. - check_varinfo_keys(varinfo_merged, vns) - # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] - end - end - - # For certain varinfos we should have errors. - # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `identity`. 
- varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] - @testset "$(short_varinfo_name(varinfo)): failure cases" begin - @test_throws ArgumentError subset( - varinfo, [@varname(s), @varname(m), @varname(x[1])] - ) - end - end - - @testset "merge" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, - DynamicPPL.TestUtils.rand_prior_true(model), - vns; - include_threadsafe=true, - ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - @testset "with itself" begin - # Merging itself should be a no-op. - varinfo_merged = merge(varinfo, varinfo) - # Varnames should be unchanged. - check_varinfo_keys(varinfo_merged, vns) - # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] - end - - @testset "with itself (3-argument version)" begin - # Merging itself should be a no-op. - varinfo_merged = merge(varinfo, varinfo, varinfo) - # Varnames should be unchanged. - check_varinfo_keys(varinfo_merged, vns) - # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] - end - - @testset "with empty" begin - # Empty is 1st argument. - # Merging with an empty `VarInfo` should be a no-op. - varinfo_merged = merge(empty!!(deepcopy(varinfo)), varinfo) - # Varnames should be unchanged. - check_varinfo_keys(varinfo_merged, vns) - # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] - - # Empty is 2nd argument. - # Merging with an empty `VarInfo` should be a no-op. - varinfo_merged = merge(varinfo, empty!!(deepcopy(varinfo))) - # Varnames should be unchanged. - check_varinfo_keys(varinfo_merged, vns) - # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] - end - - @testset "with different value" begin - x = DynamicPPL.TestUtils.rand_prior_true(model) - varinfo_changed = DynamicPPL.TestUtils.update_values!!( - deepcopy(varinfo), x, vns - ) - # After `merge`, we should have the same values as `x`. - varinfo_merged = merge(varinfo, varinfo_changed) - DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns) - end - end - end - - @testset "different models" begin - @model function demo_merge_different_y() - x ~ Uniform() - return y ~ Normal() - end - @model function demo_merge_different_z() - x ~ Normal() - return z ~ Normal() - end - model_left = demo_merge_different_y() - model_right = demo_merge_different_z() - - varinfo_left = VarInfo(model_left) - varinfo_right = VarInfo(model_right) - varinfo_right = DynamicPPL.settrans!!(varinfo_right, true, @varname(x)) - - varinfo_merged = merge(varinfo_left, varinfo_right) - vns = [@varname(x), @varname(y), @varname(z)] - check_varinfo_keys(varinfo_merged, vns) - - # Right has precedence. - @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] - @test DynamicPPL.istrans(varinfo_merged, @varname(x)) - end - end - - # The below used to error, testing to avoid regression. 
- @testset "merge different dimensions" begin - vn = @varname(x) - vi_single = VarInfo() - vi_single = push!!(vi_single, vn, 1.0, Normal()) - vi_double = VarInfo() - vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) - @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] - @test merge(vi_double, vi_single)[vn] == 1.0 - end - - @testset "sampling from linked varinfo" begin - # `~` - @model function demo(n=1) - x = Vector(undef, n) - for i in eachindex(x) - x[i] ~ Exponential() - end - return x - end - model1 = demo(1) - varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. - model2 = demo(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) - for vn in [@varname(x[1]), @varname(x[2])] - @test DynamicPPL.istrans(varinfo2, vn) - end - - # `.~` - @model function demo_dot(n=1) - x ~ Exponential() - if n > 1 - y = Vector(undef, n - 1) - y .~ Exponential() - end - return x - end - model1 = demo_dot(1) - varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. - model2 = demo_dot(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) - for vn in [@varname(x), @varname(y[1])] - @test DynamicPPL.istrans(varinfo2, vn) - end - end - - # NOTE: It is not yet clear if this is something we want from all varinfo types. - # Hence, we only test the `VarInfo` types here. - @testset "vector_getranges for `VarInfo`" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - nt = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, nt, vns; include_threadsafe=true - ) - # Only keep `VarInfo` types. - varinfos = filter( - Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos - ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - x = values_as(varinfo, Vector) - - # Let's just check all the subsets of `vns`. - @testset "$(convert(Vector{Any},vns_subset))" for vns_subset in - combinations(vns) - ranges = DynamicPPL.vector_getranges(varinfo, vns_subset) - @test length(ranges) == length(vns_subset) - for (r, vn) in zip(ranges, vns_subset) - @test x[r] == DynamicPPL.tovec(varinfo[vn]) - end - end - - # Let's try some failure cases. - @test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[] - # Non-existent variables. - @test_throws KeyError DynamicPPL.vector_getranges( - varinfo, [VarName{gensym("vn")}()] - ) - @test_throws KeyError DynamicPPL.vector_getranges( - varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()] - ) - # Duplicate variables. 
- ranges_duplicated = DynamicPPL.vector_getranges(varinfo, repeat(vns, 2)) - @test x[reduce(vcat, ranges_duplicated)] == repeat(x, 2) - end - end - end - - @testset "orders" begin - @model empty_model() = x = 1 - - csym = gensym() # unique per model - vn_z1 = @varname z[1] - vn_z2 = @varname z[2] - vn_z3 = @varname z[3] - vn_z4 = @varname z[4] - vn_a1 = @varname a[1] - vn_a2 = @varname a[2] - vn_b = @varname b - - vi = DynamicPPL.VarInfo() - dists = [Categorical([0.7, 0.3]), Normal()] - - # First iteration, variables are added to vi - # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1]) - randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2]) - randr(vi, vn_z2, dists[1]) - randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1]) - @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] - @test DynamicPPL.get_num_produce(vi) == 3 - - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del!(vi) - @test DynamicPPL.is_flagged(vi, vn_z1, "del") - @test DynamicPPL.is_flagged(vi, vn_a1, "del") - @test DynamicPPL.is_flagged(vi, vn_z2, "del") - @test DynamicPPL.is_flagged(vi, vn_a2, "del") - @test DynamicPPL.is_flagged(vi, vn_z3, "del") - - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1]) - randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1]) - randr(vi, vn_a2, dists[2]) - @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] - @test DynamicPPL.get_num_produce(vi) == 3 - - vi = empty!!(DynamicPPL.TypedVarInfo(vi)) - # First iteration, variables are added to vi - # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1]) - randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2]) - randr(vi, vn_z2, dists[1]) - randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1]) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 2] - @test vi.metadata.b.orders == [2] - @test DynamicPPL.get_num_produce(vi) == 3 - - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del!(vi) - @test DynamicPPL.is_flagged(vi, vn_z1, "del") - @test DynamicPPL.is_flagged(vi, vn_a1, "del") - @test DynamicPPL.is_flagged(vi, vn_z2, "del") - @test DynamicPPL.is_flagged(vi, vn_a2, "del") - @test DynamicPPL.is_flagged(vi, vn_z3, "del") - - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1]) - randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1]) - randr(vi, vn_a2, dists[2]) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 3] - @test vi.metadata.b.orders == [2] - @test DynamicPPL.get_num_produce(vi) == 3 - end -end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl deleted file mode 100644 index bd3f5553f..000000000 --- a/test/varnamedvector.jl +++ /dev/null @@ -1,626 +0,0 @@ -replace_sym(vn::VarName, sym_new::Symbol) = VarName{sym_new}(vn.lens) - -increase_size_for_test(x::Real) = [x] -increase_size_for_test(x::AbstractArray) = repeat(x, 2) - -decrease_size_for_test(x::Real) = x -decrease_size_for_test(x::AbstractVector) = first(x) -decrease_size_for_test(x::AbstractArray) = first(eachslice(x; dims=1)) - -function 
need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - if isconcretetype(eltype(vnv.varnames)) - # If the container is concrete, we need to make sure that the varname types match. - # E.g. if `vnv.varnames` has `eltype` `VarName{:x, IndexLens{Tuple{Int64}}}` then - # we need `vn` to also be of this type. - # => If the varname types don't match, we need to relax the container type. - return any(keys(vnv)) do vn_present - typeof(vn_present) !== typeof(val) - end - end - - return false -end -function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_varnames_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - if isconcretetype(eltype(vnv.vals)) - return promote_type(eltype(vnv.vals), eltype(val)) != eltype(vnv.vals) - end - - return false -end -function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_values_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - return if isconcretetype(eltype(vnv.transforms)) - # If the container is concrete, we need to make sure that the sizes match. - # => If the sizes don't match, we need to relax the container type. - any(keys(vnv)) do vn_present - size(vnv[vn_present]) != size(val) - end - elseif eltype(vnv.transforms) !== Any - # If it's not concrete AND it's not `Any`, then we should just make it `Any`. - true - else - # Otherwise, it's `Any`, so we don't need to relax the container type. - false - end -end -function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_transforms_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -""" - relax_container_types(vnv::VarNamedVector, vn::VarName, val) - relax_container_types(vnv::VarNamedVector, vns, val) - -Relax the container types of `vnv` if necessary to accommodate `vn` and `val`. - -This attempts to avoid unnecessary container type relaxations by checking whether -the container types of `vnv` are already compatible with `vn` and `val`. - -# Notes -For example, if `vn` is not compatible with the current keys in `vnv`, then -the underlying types will be changed to `VarName` to accommodate `vn`. - -Similarly: -- If `val` is not compatible with the current values in `vnv`, then - the underlying value type will be changed to `Real`. -- If `val` requires a transformation that is not compatible with the current - transformations type in `vnv`, then the underlying transformation type will - be changed to `Any`. 
-""" -function relax_container_types(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - return relax_container_types(vnv, [vn], [val]) -end -function relax_container_types(vnv::DynamicPPL.VarNamedVector, vns, vals) - if need_varnames_relaxation(vnv, vns, vals) - varname_to_index_new = convert(OrderedDict{VarName,Int}, vnv.varname_to_index) - varnames_new = convert(Vector{VarName}, vnv.varnames) - else - varname_to_index_new = vnv.varname_to_index - varnames_new = vnv.varnames - end - - transforms_new = if need_transforms_relaxation(vnv, vns, vals) - convert(Vector{Any}, vnv.transforms) - else - vnv.transforms - end - - vals_new = if need_values_relaxation(vnv, vns, vals) - convert(Vector{Real}, vnv.vals) - else - vnv.vals - end - - return DynamicPPL.VarNamedVector( - varname_to_index_new, - varnames_new, - vnv.ranges, - vals_new, - transforms_new, - vnv.is_unconstrained, - vnv.num_inactive, - ) -end - -@testset "VarNamedVector" begin - # Test element-related operations: - # - `getindex` - # - `setindex!` - # - `push!` - # - `update!` - # - `insert!` - # - `reset!` - # - `_internal!` versions of the above - # - !! versions of the above - # - # And these are all be tested for different types of values: - # - scalar - # - vector - # - matrix - - # Test operations on `VarNamedVector`: - # - `empty!` - # - `iterate` - # - `convert` to - # - `AbstractDict` - test_pairs = OrderedDict( - @varname(x[1]) => rand(), - @varname(x[2]) => rand(2), - @varname(x[3]) => rand(2, 3), - @varname(y[1]) => rand(), - @varname(y[2]) => rand(2), - @varname(y[3]) => rand(2, 3), - @varname(z[1]) => rand(1:10), - @varname(z[2]) => rand(1:10, 2), - @varname(z[3]) => rand(1:10, 2, 3), - ) - test_vns = collect(keys(test_pairs)) - test_vals = collect(values(test_pairs)) - - @testset "constructor: no args" begin - # Empty. - vnv = DynamicPPL.VarNamedVector() - @test isempty(vnv) - @test eltype(vnv) == Real - - # Empty with types. - vnv = DynamicPPL.VarNamedVector{VarName,Float64}() - @test isempty(vnv) - @test eltype(vnv) == Float64 - end - - test_varnames_iter = combinations(test_vns, 2) - @testset "$(vn_left) and $(vn_right)" for (vn_left, vn_right) in test_varnames_iter - val_left = test_pairs[vn_left] - val_right = test_pairs[vn_right] - vnv_base = DynamicPPL.VarNamedVector([vn_left, vn_right], [val_left, val_right]) - - # We'll need the transformations later. - # TODO: Should we test other transformations than just `ReshapeTransform`? - from_vec_left = DynamicPPL.from_vec_transform(val_left) - from_vec_right = DynamicPPL.from_vec_transform(val_right) - to_vec_left = inverse(from_vec_left) - to_vec_right = inverse(from_vec_right) - - # Compare to alternative constructors. - vnv_from_dict = DynamicPPL.VarNamedVector( - OrderedDict(vn_left => val_left, vn_right => val_right) - ) - @test vnv_base == vnv_from_dict - - # We want the types of fields such as `varnames` and `transforms` to specialize - # whenever possible + some functionality, e.g. `push!`, is only sensible - # if the underlying containers can support it. 
- # Expected behavior - should_have_restricted_varname_type = typeof(vn_left) == typeof(vn_right) - should_have_restricted_transform_type = size(val_left) == size(val_right) - # Actual behavior - has_restricted_transform_type = isconcretetype(eltype(vnv_base.transforms)) - has_restricted_varname_type = isconcretetype(eltype(vnv_base.varnames)) - - @testset "type specialization" begin - @test !should_have_restricted_varname_type || has_restricted_varname_type - @test !should_have_restricted_transform_type || has_restricted_transform_type - end - - @test eltype(vnv_base) == promote_type(eltype(val_left), eltype(val_right)) - @test DynamicPPL.length_internal(vnv_base) == length(val_left) + length(val_right) - @test length(vnv_base) == 2 - - @test !isempty(vnv_base) - - @testset "empty!" begin - vnv = deepcopy(vnv_base) - empty!(vnv) - @test isempty(vnv) - end - - @testset "similar" begin - vnv = similar(vnv_base) - @test isempty(vnv) - @test typeof(vnv) == typeof(vnv_base) - end - - @testset "getindex" begin - # With `VarName` index. - @test vnv_base[vn_left] == val_left - @test vnv_base[vn_right] == val_right - end - - @testset "getindex_internal" begin - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_left) == - to_vec_left(val_left) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_right) == - to_vec_right(val_right) - end - - @testset "getindex_internal with Ints" begin - for (i, val) in enumerate(to_vec_left(val_left)) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, i) == val - end - offset = length(to_vec_left(val_left)) - for (i, val) in enumerate(to_vec_right(val_right)) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, offset + i) == val - end - end - - @testset "update!" begin - vnv = deepcopy(vnv_base) - DynamicPPL.update!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.update!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update!!" begin - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.update!!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.update!!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update_internal!" begin - vnv = deepcopy(vnv_base) - DynamicPPL.update_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.update_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update_internal!!" begin - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.update_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.update_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "delete!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - @test !haskey(vnv, vn_left) - @test haskey(vnv, vn_right) - delete!(vnv, vn_right) - @test !haskey(vnv, vn_right) - end - - @testset "insert!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - DynamicPPL.insert!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.insert!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert!!" 
begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - vnv = DynamicPPL.insert!!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.insert!!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert_internal!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - DynamicPPL.insert_internal!( - vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left - ) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.insert_internal!( - vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert_internal!!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - vnv = DynamicPPL.insert_internal!!( - vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left - ) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.insert_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "merge" begin - # When there are no inactive entries, `merge` on itself result in the same. - @test merge(vnv_base, vnv_base) == vnv_base - - # Merging with empty should result in the same. - @test merge(vnv_base, similar(vnv_base)) == vnv_base - @test merge(similar(vnv_base), vnv_base) == vnv_base - - # With differences. - vnv_left_only = deepcopy(vnv_base) - delete!(vnv_left_only, vn_right) - vnv_right_only = deepcopy(vnv_base) - delete!(vnv_right_only, vn_left) - - # `(x,)` and `(x, y)` should be `(x, y)`. - @test merge(vnv_left_only, vnv_base) == vnv_base - # `(x, y)` and `(x,)` should be `(x, y)`. - @test merge(vnv_base, vnv_left_only) == vnv_base - # `(x, y)` and `(y,)` should be `(x, y)`. - @test merge(vnv_base, vnv_right_only) == vnv_base - # `(y,)` and `(x, y)` should be `(y, x)`. - vnv_merged = merge(vnv_right_only, vnv_base) - @test vnv_merged != vnv_base - @test collect(keys(vnv_merged)) == [vn_right, vn_left] - end - - @testset "push!" begin - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - vnv_copy = deepcopy(vnv) - push!(vnv, (vn => val)) - @test vnv[vn] == val - end - end - - @testset "setindex_internal!" begin - # Not setting the transformation. - vnv = deepcopy(vnv_base) - DynamicPPL.setindex_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.setindex_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) - @test vnv[vn_right] == val_right .+ 100 - - # Explicitly setting the transformation. - increment(x) = x .+ 10 - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_left), typeof(increment)) - DynamicPPL.setindex_internal!( - vnv, to_vec_left(val_left .+ 100), vn_left, increment - ) - @test vnv[vn_left] == to_vec_left(val_left .+ 110) - - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_right), typeof(increment)) - DynamicPPL.setindex_internal!( - vnv, to_vec_right(val_right .+ 100), vn_right, increment - ) - @test vnv[vn_right] == to_vec_right(val_right .+ 110) - - # Adding new values. 
- vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - from_vec_vn = DynamicPPL.from_vec_transform(val) - to_vec_vn = inverse(from_vec_vn) - DynamicPPL.setindex_internal!(vnv, to_vec_vn(val), vn, from_vec_vn) - @test vnv[vn] == val - end - end - - @testset "setindex_internal! with Ints" begin - vnv = deepcopy(vnv_base) - for i in 1:DynamicPPL.length_internal(vnv_base) - DynamicPPL.setindex_internal!(vnv, i, i) - end - for i in 1:DynamicPPL.length_internal(vnv_base) - @test DynamicPPL.getindex_internal(vnv, i) == i - end - end - - @testset "setindex_internal!!" begin - # Not setting the transformation. - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.setindex_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right - ) - @test vnv[vn_right] == val_right .+ 100 - - # Explicitly setting the transformation. - # Note that unlike with setindex_internal!, we don't need loosen_types!! here. - increment(x) = x .+ 10 - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.setindex_internal!!( - vnv, to_vec_left(val_left .+ 100), vn_left, increment - ) - @test vnv[vn_left] == to_vec_left(val_left .+ 110) - - vnv = DynamicPPL.setindex_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right, increment - ) - @test vnv[vn_right] == to_vec_right(val_right .+ 110) - - # Adding new values. - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - from_vec_vn = DynamicPPL.from_vec_transform(val) - to_vec_vn = inverse(from_vec_vn) - vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_vn(val), vn, from_vec_vn) - @test vnv[vn] == val - end - end - - @testset "setindex! and reset!" begin - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - expected_length = if haskey(vnv, vn) - # If it's already present, the resulting length will be unchanged. - DynamicPPL.length_internal(vnv) - else - DynamicPPL.length_internal(vnv) + length(val) - end - - vnv[vn] = val .+ 1 - x = DynamicPPL.getindex_internal(vnv, :) - @test vnv[vn] == val .+ 1 - @test DynamicPPL.length_internal(vnv) == expected_length - @test length(x) == DynamicPPL.length_internal(vnv) - @test all( - DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x) - ) - - # There should be no redundant values in the underlying vector. - @test !DynamicPPL.has_inactive(vnv) - end - - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn (increased size)" for vn in test_vns - val_original = test_pairs[vn] - val = increase_size_for_test(val_original) - vn_already_present = haskey(vnv, vn) - expected_length = if vn_already_present - # If it's already present, the resulting length will be altered. - DynamicPPL.length_internal(vnv) + length(val) - length(val_original) - else - DynamicPPL.length_internal(vnv) + length(val) - end - - # Have to use reset!, because setindex! doesn't support decreasing size. 
-                DynamicPPL.reset!(vnv, val .+ 1, vn)
-                x = DynamicPPL.getindex_internal(vnv, :)
-                @test vnv[vn] == val .+ 1
-                @test DynamicPPL.length_internal(vnv) == expected_length
-                @test length(x) == DynamicPPL.length_internal(vnv)
-                @test all(
-                    DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x)
-                )
-            end
-
-            vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
-            @testset "$vn (decreased size)" for vn in test_vns
-                val_original = test_pairs[vn]
-                val = decrease_size_for_test(val_original)
-                vn_already_present = haskey(vnv, vn)
-                expected_length = if vn_already_present
-                    # If it's already present, the resulting length will be altered.
-                    DynamicPPL.length_internal(vnv) + length(val) - length(val_original)
-                else
-                    DynamicPPL.length_internal(vnv) + length(val)
-                end
-
-                # Have to use reset!, because setindex! doesn't support decreasing size.
-                DynamicPPL.reset!(vnv, val .+ 1, vn)
-                x = DynamicPPL.getindex_internal(vnv, :)
-                @test vnv[vn] == val .+ 1
-                @test DynamicPPL.length_internal(vnv) == expected_length
-                @test length(x) == DynamicPPL.length_internal(vnv)
-                @test all(
-                    DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x)
-                )
-            end
-        end
-    end
-
-    @testset "growing and shrinking" begin
-        @testset "deterministic" begin
-            n = 5
-            vn = @varname(x)
-            vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true]))
-            @test !DynamicPPL.has_inactive(vnv)
-            # Growing should not create inactive ranges.
-            for i in 1:n
-                x = fill(true, i)
-                DynamicPPL.update_internal!(vnv, x, vn, identity)
-                @test !DynamicPPL.has_inactive(vnv)
-            end
-
-            # Same size should not create inactive ranges.
-            x = fill(true, n)
-            DynamicPPL.update_internal!(vnv, x, vn, identity)
-            @test !DynamicPPL.has_inactive(vnv)
-
-            # Shrinking should create inactive ranges.
-            for i in (n - 1):-1:1
-                x = fill(true, i)
-                DynamicPPL.update_internal!(vnv, x, vn, identity)
-                @test DynamicPPL.has_inactive(vnv)
-                @test DynamicPPL.num_inactive(vnv, vn) == n - i
-            end
-        end
-
-        @testset "random" begin
-            n = 5
-            vn = @varname(x)
-            vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true]))
-            @test !DynamicPPL.has_inactive(vnv)
-
-            # Insert a bunch of random-length vectors.
-            for i in 1:100
-                x = fill(true, rand(1:n))
-                DynamicPPL.update!(vnv, x, vn)
-            end
-            # Should never be allocating more than `n` elements.
-            @test DynamicPPL.num_allocated(vnv, vn) ≤ n
-
-            # If we compaticfy, then it should always be the same size as just inserted.
-            for i in 1:10
-                x = fill(true, rand(1:n))
-                DynamicPPL.update!(vnv, x, vn)
-                DynamicPPL.contiguify!(vnv)
-                @test DynamicPPL.num_allocated(vnv, vn) == length(x)
-            end
-        end
-    end
-
-    @testset "subset" begin
-        vnv = DynamicPPL.VarNamedVector(test_pairs)
-        @test subset(vnv, test_vns) == vnv
-        @test subset(vnv, VarName[]) == DynamicPPL.VarNamedVector()
-        @test merge(subset(vnv, test_vns[1:3]), subset(vnv, test_vns[4:end])) == vnv
-
-        # Test that subset preserves transformations and unconstrainedness.
-        vn = @varname(t[1])
-        vns = vcat(test_vns, [vn])
-        vnv = DynamicPPL.setindex_internal!!(vnv, [2.0], vn, x -> x .^ 2)
-        DynamicPPL.settrans!(vnv, true, @varname(t[1]))
-        @test vnv[@varname(t[1])] == [4.0]
-        @test istrans(vnv, @varname(t[1]))
-        @test subset(vnv, vns) == vnv
-    end
-end
-
-@testset "VarInfo + VarNamedVector" begin
-    models = DynamicPPL.TestUtils.DEMO_MODELS
-    @testset "$(model.f)" for model in models
-        # NOTE: Need to set random seed explicitly to avoid using the same seed
-        # for initialization as for sampling in the inner testset below.
-        Random.seed!(42)
-        value_true = DynamicPPL.TestUtils.rand_prior_true(model)
-        vns = DynamicPPL.TestUtils.varnames(model)
-        varnames = DynamicPPL.TestUtils.varnames(model)
-        varinfos = DynamicPPL.TestUtils.setup_varinfos(
-            model, value_true, varnames; include_threadsafe=false
-        )
-        # Filter out those which are not based on `VarNamedVector`.
-        varinfos = filter(DynamicPPL.has_varnamedvector, varinfos)
-        # Get the true log joint.
-        logp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...)
-
-        @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
-            # Need to make sure we're using a different random seed from the
-            # one used in the above call to `rand_prior_true`.
-            Random.seed!(43)
-
-            # Are values correct?
-            DynamicPPL.TestUtils.test_values(varinfo, value_true, vns)
-
-            # Is evaluation correct?
-            varinfo_eval = last(
-                DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext())
-            )
-            # Log density should be the same.
-            @test getlogp(varinfo_eval) ≈ logp_true
-            # Values should be the same.
-            DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns)
-
-            # Is sampling correct?
-            varinfo_sample = last(
-                DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext())
-            )
-            # Log density should be different.
-            @test getlogp(varinfo_sample) != getlogp(varinfo)
-            # Values should be different.
-            DynamicPPL.TestUtils.test_values(
-                varinfo_sample, value_true, vns; compare=!isequal
-            )
-        end
-    end
-end