From 2ec2455758e8c7545b69975db51c0b275c63ce32 Mon Sep 17 00:00:00 2001
From: Ray Kim
Date: Tue, 25 Jun 2024 21:35:10 +0100
Subject: [PATCH 01/15] add tapir support, update interface to support stateful AD

---
 Project.toml                                  |  3 ++
 ext/AdvancedVIForwardDiffExt.jl               | 16 +++++----
 ext/AdvancedVIReverseDiffExt.jl               | 14 ++++----
 ext/AdvancedVIZygoteExt.jl                    | 14 ++++----
 src/AdvancedVI.jl                             | 33 ++++++++++++-------
 src/objectives/elbo/repgradelbo.jl            | 19 +++++++++--
 src/optimize.jl                               |  4 ++-
 src/utils.jl                                  |  3 +-
 test/Project.toml                             |  2 ++
 test/inference/repgradelbo_distributionsad.jl |  1 +
 test/inference/repgradelbo_locationscale.jl   |  1 +
 .../repgradelbo_locationscale_bijectors.jl    |  3 +-
 test/interface/ad.jl                          |  9 +++--
 test/interface/repgradelbo.jl                 | 14 ++++----
 test/runtests.jl                              |  2 +-
 15 files changed, 92 insertions(+), 46 deletions(-)

diff --git a/Project.toml b/Project.toml
index 9d30e740..515310bb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -25,6 +25,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extensions]
@@ -32,6 +33,7 @@ AdvancedVIBijectorsExt = "Bijectors"
 AdvancedVIEnzymeExt = "Enzyme"
 AdvancedVIForwardDiffExt = "ForwardDiff"
 AdvancedVIReverseDiffExt = "ReverseDiff"
+AdvancedVITapirExt = "Tapir"
 AdvancedVIZygoteExt = "Zygote"
 
 [compat]
@@ -55,6 +57,7 @@ Requires = "1.0"
 ReverseDiff = "1.15.1"
 SimpleUnPack = "1.1.0"
 StatsBase = "0.32, 0.33, 0.34"
+Tapir = "0.2.23"
 Zygote = "0.6.63"
 julia = "1.6"
diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl
index a8afd031..bf0da458 100644
--- a/ext/AdvancedVIForwardDiffExt.jl
+++ b/ext/AdvancedVIForwardDiffExt.jl
@@ -14,10 +14,11 @@ end
 getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize
 
 function AdvancedVI.value_and_gradient!(
-    ad ::ADTypes.AutoForwardDiff,
+    ad ::ADTypes.AutoForwardDiff,
+    ::Any,
     f,
-    x  ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
+    x  ::AbstractVector,
+    out::DiffResults.MutableDiffResult
 )
     chunk_size = getchunksize(ad)
     config = if isnothing(chunk_size)
@@ -30,13 +31,14 @@ function AdvancedVI.value_and_gradient!(
-    ad ::ADTypes.AutoForwardDiff,
+    ad ::ADTypes.AutoForwardDiff,
+    st_ad,
     f,
-    x  ::AbstractVector,
+    x  ::AbstractVector,
     aux,
-    out::DiffResults.MutableDiffResult
+    out::DiffResults.MutableDiffResult
 )
-    AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
+    AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
 end
 
 end
diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl
index 392f5cea..cde187b5 100644
--- a/ext/AdvancedVIReverseDiffExt.jl
+++ b/ext/AdvancedVIReverseDiffExt.jl
@@ -13,9 +13,10 @@ end
 
 # ReverseDiff without compiled tape
 function AdvancedVI.value_and_gradient!(
-    ad::ADTypes.AutoReverseDiff,
+    ::ADTypes.AutoReverseDiff,
+    ::Any,
     f,
-    x::AbstractVector{<:Real},
+    x  ::AbstractVector{<:Real},
     out::DiffResults.MutableDiffResult
 )
     tp = ReverseDiff.GradientTape(f, x)
@@ -24,13 +25,14 @@ function AdvancedVI.value_and_gradient!(
-    ad::ADTypes.AutoReverseDiff,
+    ad ::ADTypes.AutoReverseDiff,
+    st_ad,
     f,
-    x::AbstractVector{<:Real},
+    x  ::AbstractVector{<:Real},
     aux,
-    out::DiffResults.MutableDiffResult
+    out ::DiffResults.MutableDiffResult
 )
-    AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
+    AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
 end
 
 end
diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl
index 806c08e4..f60b31fb 100644
--- a/ext/AdvancedVIZygoteExt.jl
+++ b/ext/AdvancedVIZygoteExt.jl
@@ -14,9 +14,10 @@ else
 end
 
 function AdvancedVI.value_and_gradient!(
-    ::ADTypes.AutoZygote,
+    ::ADTypes.AutoZygote,
+    ::Any,
     f,
-    x::AbstractVector{<:Real},
+    x  ::AbstractVector{<:Real},
     out::DiffResults.MutableDiffResult
 )
     y, back = Zygote.pullback(f, x)
@@ -27,13 +28,14 @@ end
 
 function AdvancedVI.value_and_gradient!(
-    ad::ADTypes.AutoZygote,
+    ad ::ADTypes.AutoZygote,
+    st_ad,
     f,
-    x::AbstractVector{<:Real},
+    x  ::AbstractVector{<:Real},
     aux,
-    out::DiffResults.MutableDiffResult
+    out ::DiffResults.MutableDiffResult
 )
-    AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
+    AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
 end
 
 end
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 7a09030b..8423a312 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -25,14 +25,15 @@ using StatsBase
 # derivatives
 """
-    value_and_gradient!(ad, f, x, out)
-    value_and_gradient!(ad, f, x, aux, out)
+    value_and_gradient!(adtype, ad_st, f, x, out)
+    value_and_gradient!(adtype, ad_st, f, x, aux, out)
 
-Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
+Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation (AD) backend `adtype` and store the result in `out`.
 `f` may receive auxiliary input as `f(x,aux)`.
 
 # Arguments
-- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `adtype::ADTypes.AbstractADType`: AD backend.
+- `ad_st`: State used by the AD backend. (This will often be pre-compiled tapes/caches.)
 - `f`: Function subject to differentiation.
 - `x`: The point to evaluate the gradient.
 - `aux`: Auxiliary input passed to `f`.
@@ -41,18 +42,22 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
 function value_and_gradient! end
 
 """
-    stop_gradient(x)
+    init_adbackend(adtype, f, x)
+    init_adbackend(adtype, f, x, aux)
 
-Stop the gradient from propagating to `x` if the selected ad backend supports it.
-Otherwise, it is equivalent to `identity`.
+Initialize the AD backend and set up any necessary state.
 
 # Arguments
-- `x`: Input
+- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `f`: Function subject to differentiation.
+- `x`: The point to evaluate the gradient.
+- `aux`: Auxiliary input passed to `f`.
 
 # Returns
-- `x`: Same value as the input.
+- `ad_st`: State of the AD backend. (This will often be pre-compiled tapes/caches.)
 """
-function stop_gradient end
+init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any) = nothing
+init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any, ::Any) = nothing
 
 # Update for gradient descent step
 """
@@ -95,14 +100,17 @@ If the estimator is stateful, it can implement `init` to initialize the state.
 abstract type AbstractVariationalObjective end
 
 """
-    init(rng, obj, prob, params, restructure)
+    init(rng, obj, adtype, prob, params, restructure)
 
-Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
+Initialize a state of the variational objective `obj`.
 This function needs to be implemented only if `obj` is stateful.
+The state of the AD backend `adtype` shall also be initialized here.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
 - `obj::AbstractVariationalObjective`: Variational objective.
+- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
 - `params`: Initial variational parameters.
 - `restructure`: Function that reconstructs the variational approximation from `λ`.
 """
 init(
     ::Random.AbstractRNG,
     ::AbstractVariationalObjective,
     ::Any,
     ::Any,
     ::Any,
+    ::Any,
 ) = nothing
 
 """
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index 27a937e8..e0f5de40 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -106,6 +106,20 @@ function estimate_repgradelbo_ad_forward(params′, aux)
     -elbo
 end
 
+function init(
+    rng        ::Random.AbstractRNG,
+    obj        ::RepGradELBO,
+    adtype     ::ADTypes.AbstractADType,
+    prob,
+    params,
+    restructure,
+)
+    q_stop = restructure(params)
+    aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
+    ad_st = init_adbackend(adtype, estimate_repgradelbo_ad_forward, params, aux)
+    (ad_st=ad_st,)
+end
+
 function estimate_gradient!(
     rng       ::Random.AbstractRNG,
     obj       ::RepGradELBO,
@@ -117,11 +131,12 @@ function estimate_gradient!(
     state,
 )
     q_stop = restructure(params)
+    ad_st = state.ad_st
     aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
     value_and_gradient!(
-        adtype, estimate_repgradelbo_ad_forward, params, aux, out
+        adtype, ad_st, estimate_repgradelbo_ad_forward, params, aux, out
     )
     nelbo = DiffResults.value(out)
     stat  = (elbo=-nelbo,)
-    out, nothing, stat
+    out, state, stat
 end
diff --git a/src/optimize.jl b/src/optimize.jl
index e5fe374d..659f3d16 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -66,7 +66,9 @@ function optimize(
 )
     params, restructure = Optimisers.destructure(deepcopy(q_init))
     opt_st = maybe_init_optimizer(state_init, optimizer, params)
-    obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
+    obj_st = maybe_init_objective(
+        state_init, rng, adtype, objective, problem, params, restructure
+    )
     grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
     stats = NamedTuple[]
 
diff --git a/src/utils.jl b/src/utils.jl
index 3ae59a78..fbfdc330 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -18,6 +18,7 @@ end
 function maybe_init_objective(
     state_init::NamedTuple,
     rng       ::Random.AbstractRNG,
+    adtype    ::ADTypes.AbstractADType,
     objective ::AbstractVariationalObjective,
     problem,
     params,
@@ -26,7 +27,7 @@ function maybe_init_objective(
     if haskey(state_init, :objective)
         state_init.objective
     else
-        init(rng, objective, problem, params, restructure)
+        init(rng, objective, adtype, problem, params, restructure)
     end
 end
diff --git a/test/Project.toml b/test/Project.toml
index a0dba17f..b25ddfd5 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -17,6 +17,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -38,6 +39,7 @@ ReverseDiff = "1.15.1"
 SimpleUnPack = "1.1.0"
 StableRNGs = "1.0.0"
 Statistics = "1"
+Tapir = "0.2.23"
 Test = "1"
 Tracker = "0.2.20"
 Zygote = "0.6.63"
 julia = "1.6"
diff --git
a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 38dbc6e3..201f69b2 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -14,6 +14,7 @@ :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), + :Tapir => AutoTapir(), #:Enzyme => AutoEnzyme(), ) diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 73846675..b46d4d4d 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -14,6 +14,7 @@ (adbackname, adtype) in Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), + :Tapir => AutoTapir(safe_mode=false), :Zygote => AutoZygote(), #:Enzyme => AutoEnzyme(), ) diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 1f62df0f..5f04c25a 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -13,7 +13,8 @@ (adbackname, adtype) in Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), - #:Zygote => AutoZygote(), + :Zygote => AutoZygote(), + :Tapir => AutoTapir(safe_mode=false), #:Enzyme => AutoEnzyme(), ) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index be4ca34e..55e7cdbd 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -2,18 +2,21 @@ using Test @testset "ad" begin - @testset "$(adname)" for (adname, adsymbol) ∈ Dict( + @testset "$(adname)" for (adname, adtype) ∈ Dict( :ForwardDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), + :Tapir => AutoTapir(), # :Enzyme => AutoEnzyme() # Currently not tested against ) D = 10 A = randn(D, D) λ = randn(D) - grad_buf = DiffResults.GradientResult(λ) f(λ′) = λ′'*A*λ′ / 2 - AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) + + ad_st = AdvancedVI.init_adbackend(adtype, f, λ) + grad_buf = DiffResults.GradientResult(λ) + AdvancedVI.value_and_gradient!(adtype, ad_st, f, λ, grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) @test ∇ ≈ (A + A')*λ/2 diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index ac9bfeca..58e440ce 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -34,7 +34,7 @@ end modelstats = normal_meanfield(rng, Float64) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - @testset for ad in [ + @testset for adtype in [ ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() @@ -44,12 +44,14 @@ end Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) params, re = Optimisers.destructure(q_true) - obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) - out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - - aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) + obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) + out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) + ad_st = AdvancedVI.init_adbackend( + adtype, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux + ) AdvancedVI.value_and_gradient!( - ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out + adtype, ad_st, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out ) grad = DiffResults.gradient(out) @test norm(grad) ≈ 0 atol=1e-5 
diff --git a/test/runtests.jl b/test/runtests.jl index 3bd13144..c85840b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,7 +18,7 @@ using DistributionsAD using LogDensityProblems using Optimisers using ADTypes -using ForwardDiff, ReverseDiff, Zygote +using ForwardDiff, ReverseDiff, Zygote, Tapir using AdvancedVI From 6549258de26fdf714637fa8cb92f478f46cc3b8c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 25 Jun 2024 21:37:43 +0100 Subject: [PATCH 02/15] add missing Tapir extra --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 515310bb..c35b4b3d 100644 --- a/Project.toml +++ b/Project.toml @@ -67,6 +67,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" From 767274c354c342e49a824f83f5f57bd161b42af6 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 25 Jun 2024 21:43:47 +0100 Subject: [PATCH 03/15] update tighten julia version requirement for Tapir --- .github/workflows/CI.yml | 4 ++-- Project.toml | 2 +- test/Project.toml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0302271a..224f81d4 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,8 +19,8 @@ jobs: fail-fast: false matrix: version: - - '1' - - '1.6' + - '1.7' + - '1.10' os: - ubuntu-latest - macOS-latest diff --git a/Project.toml b/Project.toml index c35b4b3d..85f9e13c 100644 --- a/Project.toml +++ b/Project.toml @@ -59,7 +59,7 @@ SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" Tapir = "0.2.23" Zygote = "0.6.63" -julia = "1.6" +julia = "1.7" [extras] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" diff --git a/test/Project.toml b/test/Project.toml index b25ddfd5..d64d07e5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,4 +43,4 @@ Tapir = "0.2.23" Test = "1" Tracker = "0.2.20" Zygote = "0.6.63" -julia = "1.6" +julia = "1.7" From 28174c98a2967a245f5655aafc52958436cd1816 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 12 Jul 2024 22:15:29 +0100 Subject: [PATCH 04/15] add Tapir extension --- ext/AdvancedVITapirExt.jl | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 ext/AdvancedVITapirExt.jl diff --git a/ext/AdvancedVITapirExt.jl b/ext/AdvancedVITapirExt.jl new file mode 100644 index 00000000..3794789b --- /dev/null +++ b/ext/AdvancedVITapirExt.jl @@ -0,0 +1,47 @@ + +module AdvancedVITapirExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using Tapir +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..Tapir +end + +AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x) = Tapir.build_rrule(f, x) + +function AdvancedVI.value_and_gradient!( + ::ADTypes.AutoTapir, + st_ad, + f, + x ::AbstractVector{<:Real}, + out ::DiffResults.MutableDiffResult +) + rule = st_ad + y, g = Tapir.value_and_gradient!!(rule, f, x) + DiffResults.value!(out, y) + DiffResults.gradient!(out, last(g)) + return out +end + +AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x, aux) = Tapir.build_rrule(f, x, aux) + +function AdvancedVI.value_and_gradient!( + ::ADTypes.AutoTapir, + st_ad, + f, + x ::AbstractVector{<:Real}, + aux, + out ::DiffResults.MutableDiffResult +) + 
rule = st_ad + y, g = Tapir.value_and_gradient!!(rule, f, x, aux) + DiffResults.value!(out, y) + DiffResults.gradient!(out, last(g)) + return out +end + +end From a3c7e10ce08718848deb18e41b0d00bd34c21328 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 9 Aug 2024 15:35:50 +0100 Subject: [PATCH 05/15] Update test/interface/ad.jl --- test/interface/ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index f6f91f76..63a837a4 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -7,7 +7,7 @@ using Test :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Tapir => AutoTapir(), - :Enzyme => AutoEnzyme() + #:Enzyme => AutoEnzyme() ) D = 10 A = randn(D, D) From 86ed16175c8a0e1959a6dd12d3ab03819ed1f0f5 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 9 Aug 2024 15:36:18 +0100 Subject: [PATCH 06/15] Update .github/workflows/CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 224f81d4..21b3b6a3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: version: - - '1.7' + #- '1.7' - '1.10' os: - ubuntu-latest From 7887306ad68e51fd09855001b4b05fd7f9b1dda8 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 9 Aug 2024 16:11:29 +0100 Subject: [PATCH 07/15] formatting --- src/AdvancedVI.jl | 33 ++---- src/families/location_scale.jl | 93 +++++++-------- src/objectives/elbo/entropy.jl | 6 +- src/objectives/elbo/repgradelbo.jl | 68 +++++------ src/optimize.jl | 70 ++++++------ src/utils.jl | 20 ++-- test/inference/repgradelbo_distributionsad.jl | 76 +++++++------ test/inference/repgradelbo_locationscale.jl | 78 +++++++------ .../repgradelbo_locationscale_bijectors.jl | 84 ++++++++------ test/interface/ad.jl | 20 ++-- test/interface/location_scale.jl | 107 ++++++++++-------- test/interface/optimize.jl | 56 ++++----- test/interface/repgradelbo.jl | 27 ++--- test/models/normal.jl | 28 ++--- test/models/normallognormal.jl | 87 +++++++------- test/runtests.jl | 1 - 16 files changed, 422 insertions(+), 432 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 8423a312..70f7b253 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -56,7 +56,7 @@ Initialize the AD backend and setup states necessary. # Returns - `ad_st`: State of the AD backend. (This will often be pre-compiled tapes/caches.) """ -init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any) = nothing +init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any) = nothing init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any, ::Any) = nothing # Update for gradient descent step @@ -83,8 +83,9 @@ Instead, the return values should be used. """ function update_variational_params! end -update_variational_params!(::Type, opt_st, params, restructure, grad) = - Optimisers.update!(opt_st, params, grad) +function update_variational_params!(::Type, opt_st, params, restructure, grad) + return Optimisers.update!(opt_st, params, grad) +end # estimators """ @@ -114,14 +115,8 @@ The state of the AD backend `adtype` shall also be initialized here. - `params`: Initial variational parameters. - `restructure`: Function that reconstructs the variational approximation from `λ`. 
""" -init( - ::Random.AbstractRNG, - ::AbstractVariationalObjective, - ::Any, - ::Any, - ::Any, - ::Any, -) = nothing +init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any, ::Any) = + nothing """ estimate_objective([rng,] obj, q, prob; kwargs...) @@ -145,7 +140,6 @@ function estimate_objective end export estimate_objective - """ estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state) @@ -186,25 +180,16 @@ Estimate the entropy of `q`. """ function estimate_entropy end -export - RepGradELBO, - ClosedFormEntropy, - StickingTheLandingEntropy, - MonteCarloEntropy +export RepGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy include("objectives/elbo/entropy.jl") include("objectives/elbo/repgradelbo.jl") - # Variational Families -export - MvLocationScale, - MeanFieldGaussian, - FullRankGaussian +export MvLocationScale, MeanFieldGaussian, FullRankGaussian include("families/location_scale.jl") - # Optimization Routine function optimize end @@ -214,7 +199,6 @@ export optimize include("utils.jl") include("optimize.jl") - # optional dependencies if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base using Requires @@ -241,4 +225,3 @@ end end end - diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 66ea5cdb..552e6b93 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -13,22 +13,21 @@ represented as follows: z = scale*u + location ``` """ -struct MvLocationScale{ - S, D <: ContinuousDistribution, L, E <: Real -} <: ContinuousMultivariateDistribution - location ::L - scale ::S - dist ::D +struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <: + ContinuousMultivariateDistribution + location::L + scale::S + dist::D scale_eps::E end function MvLocationScale( - location ::AbstractVector{T}, - scale ::AbstractMatrix{T}, - dist ::ContinuousDistribution; - scale_eps::T = sqrt(eps(T)) -) where {T <: Real} - MvLocationScale(location, scale, dist, scale_eps) + location::AbstractVector{T}, + scale::AbstractMatrix{T}, + dist::ContinuousDistribution; + scale_eps::T=sqrt(eps(T)), +) where {T<:Real} + return MvLocationScale(location, scale, dist, scale_eps) end Functors.@functor MvLocationScale (location, scale) @@ -38,23 +37,21 @@ Functors.@functor MvLocationScale (location, scale) # `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD # is very inefficient. 
# begin -struct RestructureMeanField{S <: Diagonal, D, L} - q::MvLocationScale{S, D, L} +struct RestructureMeanField{S<:Diagonal,D,L} + q::MvLocationScale{S,D,L} end function (re::RestructureMeanField)(flat::AbstractVector) - n_dims = div(length(flat), 2) + n_dims = div(length(flat), 2) location = first(flat, n_dims) - scale = Diagonal(last(flat, n_dims)) - MvLocationScale(location, scale, re.q.dist, re.q.scale_eps) + scale = Diagonal(last(flat, n_dims)) + return MvLocationScale(location, scale, re.q.dist, re.q.scale_eps) end -function Optimisers.destructure( - q::MvLocationScale{<:Diagonal, D, L} -) where {D, L} +function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L} @unpack location, scale, dist = q flat = vcat(location, diag(scale)) - flat, RestructureMeanField(q) + return flat, RestructureMeanField(q) end # end @@ -62,61 +59,63 @@ Base.length(q::MvLocationScale) = length(q.location) Base.size(q::MvLocationScale) = size(q.location) -Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D) +Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D) function StatsBase.entropy(q::MvLocationScale) - @unpack location, scale, dist = q + @unpack location, scale, dist = q n_dims = length(location) # `convert` is necessary because `entropy` is not type stable upstream - n_dims*convert(eltype(location), entropy(dist)) + logdet(scale) + return n_dims * convert(eltype(location), entropy(dist)) + logdet(scale) end function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) + return sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) end function Distributions.rand(q::MvLocationScale) @unpack location, scale, dist = q n_dims = length(location) - scale*rand(dist, n_dims) + location + return scale * rand(dist, n_dims) + location end function Distributions.rand( - rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int -) where {S, D, L} + rng::AbstractRNG, q::MvLocationScale{S,D,L}, num_samples::Int +) where {S,D,L} @unpack location, scale, dist = q n_dims = length(location) - scale*rand(rng, dist, n_dims, num_samples) .+ location + return scale * rand(rng, dist, n_dims, num_samples) .+ location end # This specialization improves AD performance of the sampling path function Distributions.rand( - rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int -) where {L, D} + rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int +) where {L,D} @unpack location, scale, dist = q - n_dims = length(location) + n_dims = length(location) scale_diag = diag(scale) - scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location + return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location end -function Distributions._rand!(rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real}) +function Distributions._rand!( + rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real} +) @unpack location, scale, dist = q rand!(rng, dist, x) - x[:] = scale*x + x[:] = scale * x return x .+= location end Distributions.mean(q::MvLocationScale) = q.location -function Distributions.var(q::MvLocationScale) +function Distributions.var(q::MvLocationScale) C = q.scale - Diagonal(C*C') + return Diagonal(C * C') end function Distributions.cov(q::MvLocationScale) C = q.scale - Hermitian(C*C') + return Hermitian(C * C') end """ @@ -132,13 +131,11 @@ Construct a Gaussian variational 
approximation with a dense covariance matrix. - `check_args`: Check the conditioning of the initial scale (default: `true`). """ function FullRankGaussian( - μ::AbstractVector{T}, - L::LinearAlgebra.AbstractTriangular{T}; - scale_eps::T = sqrt(eps(T)) -) where {T <: Real} + μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=sqrt(eps(T)) +) where {T<:Real} @assert minimum(diag(L)) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) - MvLocationScale(μ, L, q_base, scale_eps) + return MvLocationScale(μ, L, q_base, scale_eps) end """ @@ -154,13 +151,11 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix - `check_args`: Check the conditioning of the initial scale (default: `true`). """ function MeanFieldGaussian( - μ::AbstractVector{T}, - L::Diagonal{T}; - scale_eps::T = sqrt(eps(T)), -) where {T <: Real} + μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=sqrt(eps(T)) +) where {T<:Real} @assert minimum(diag(L)) ≥ sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) - MvLocationScale(μ, L, q_base, scale_eps) + return MvLocationScale(μ, L, q_base, scale_eps) end function update_variational_params!( @@ -176,5 +171,5 @@ function update_variational_params!( params, _ = Optimisers.destructure(q) - opt_st, params + return opt_st, params end diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 973d54e2..210b49ca 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -12,7 +12,7 @@ struct ClosedFormEntropy <: AbstractEntropyEstimator end maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q function estimate_entropy(::ClosedFormEntropy, ::Any, q) - entropy(q) + return entropy(q) end """ @@ -31,9 +31,7 @@ struct MonteCarloEntropy <: AbstractEntropyEstimator end maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop function estimate_entropy( - ::Union{MonteCarloEntropy, StickingTheLandingEntropy}, - mc_samples::AbstractMatrix, - q + ::Union{MonteCarloEntropy,StickingTheLandingEntropy}, mc_samples::AbstractMatrix, q ) mean(eachcol(mc_samples)) do mc_sample -logpdf(q, mc_sample) diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index e0f5de40..d6b54ce7 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -28,31 +28,32 @@ This computes the evidence lower-bound (ELBO) through the formulation: Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. 
""" -struct RepGradELBO{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective - entropy ::EntropyEst +struct RepGradELBO{EntropyEst<:AbstractEntropyEstimator} <: AbstractVariationalObjective + entropy::EntropyEst n_samples::Int end -RepGradELBO( - n_samples::Int; - entropy ::AbstractEntropyEstimator = ClosedFormEntropy() -) = RepGradELBO(entropy, n_samples) +function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy()) + return RepGradELBO(entropy, n_samples) +end function Base.show(io::IO, obj::RepGradELBO) print(io, "RepGradELBO(entropy=") print(io, obj.entropy) print(io, ", n_samples=") print(io, obj.n_samples) - print(io, ")") + return print(io, ")") end -function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) +function estimate_entropy_maybe_stl( + entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop +) q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) - estimate_entropy(entropy_estimator, samples, q_maybe_stop) + return estimate_entropy(entropy_estimator, samples, q_maybe_stop) end function estimate_energy_with_samples(prob, samples) - mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) + return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) end """ @@ -71,31 +72,24 @@ Draw `n_samples` from `q` and compute its entropy. - `entropy`: An estimate (or exact value) of the differential entropy of `q`. """ function reparam_with_entropy( - rng ::Random.AbstractRNG, - q, - q_stop, - n_samples::Int, - ent_est ::AbstractEntropyEstimator + rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) - samples, entropy + return samples, entropy end function estimate_objective( - rng::Random.AbstractRNG, - obj::RepGradELBO, - q, - prob; - n_samples::Int = obj.n_samples + rng::Random.AbstractRNG, obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samples ) samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy) energy = estimate_energy_with_samples(prob, samples) - energy + entropy + return energy + entropy end -estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = - estimate_objective(Random.default_rng(), obj, q, prob; n_samples) +function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samples) + return estimate_objective(Random.default_rng(), obj, q, prob; n_samples) +end function estimate_repgradelbo_ad_forward(params′, aux) @unpack rng, obj, problem, restructure, q_stop = aux @@ -103,13 +97,13 @@ function estimate_repgradelbo_ad_forward(params′, aux) samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) energy = estimate_energy_with_samples(problem, samples) elbo = energy + entropy - -elbo + return -elbo end function init( - rng ::Random.AbstractRNG, - obj ::RepGradELBO, - adtype ::ADTypes.AbstractADType, + rng::Random.AbstractRNG, + obj::RepGradELBO, + adtype::ADTypes.AbstractADType, prob, params, restructure, @@ -117,26 +111,24 @@ function init( q_stop = restructure(params) aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop) ad_st = init_adbackend(adtype, estimate_repgradelbo_ad_forward, params, aux) - (ad_st=ad_st,) + return (ad_st=ad_st,) end function estimate_gradient!( - rng ::Random.AbstractRNG, - obj ::RepGradELBO, + rng::Random.AbstractRNG, + obj::RepGradELBO, 
adtype::ADTypes.AbstractADType, - out ::DiffResults.MutableDiffResult, + out::DiffResults.MutableDiffResult, prob, params, restructure, state, ) q_stop = restructure(params) - ad_st = state.ad_st + ad_st = state.ad_st aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop) - value_and_gradient!( - adtype, ad_st, estimate_repgradelbo_ad_forward, params, aux, out - ) + value_and_gradient!(adtype, ad_st, estimate_repgradelbo_ad_forward, params, aux, out) nelbo = DiffResults.value(out) - stat = (elbo=-nelbo,) - out, state, stat + stat = (elbo=-nelbo,) + return out, state, stat end diff --git a/src/optimize.jl b/src/optimize.jl index 659f3d16..ea001678 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -45,39 +45,42 @@ Otherwise, just return `nothing`. """ function optimize( - rng ::Random.AbstractRNG, + rng::Random.AbstractRNG, problem, - objective ::AbstractVariationalObjective, + objective::AbstractVariationalObjective, q_init, - max_iter ::Int, + max_iter::Int, objargs...; - adtype ::ADTypes.AbstractADType, - optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), - show_progress::Bool = true, - state_init ::NamedTuple = NamedTuple(), - callback = nothing, - prog = ProgressMeter.Progress( - max_iter; - desc = "Optimizing", - barlen = 31, - showspeed = true, - enabled = show_progress + adtype::ADTypes.AbstractADType, + optimizer::Optimisers.AbstractRule=Optimisers.Adam(), + show_progress::Bool=true, + state_init::NamedTuple=NamedTuple(), + callback=nothing, + prog=ProgressMeter.Progress( + max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress ), ) params, restructure = Optimisers.destructure(deepcopy(q_init)) - opt_st = maybe_init_optimizer(state_init, optimizer, params) - obj_st = maybe_init_objective( + opt_st = maybe_init_optimizer(state_init, optimizer, params) + obj_st = maybe_init_objective( state_init, rng, adtype, objective, problem, params, restructure ) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - stats = NamedTuple[] + stats = NamedTuple[] - for t = 1:max_iter + for t in 1:max_iter stat = (iteration=t,) grad_buf, obj_st, stat′ = estimate_gradient!( - rng, objective, adtype, grad_buf, problem, - params, restructure, obj_st, objargs... + rng, + objective, + adtype, + grad_buf, + problem, + params, + restructure, + obj_st, + objargs..., ) stat = merge(stat, stat′) @@ -87,13 +90,16 @@ function optimize( ) if !isnothing(callback) - stat′ = callback( - ; stat, restructure, params=params, gradient=grad, - state=(optimizer=opt_st, objective=obj_st) + stat′ = callback(; + stat, + restructure, + params=params, + gradient=grad, + state=(optimizer=opt_st, objective=obj_st), ) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end - + @debug "Iteration $t" stat... pm_next!(prog, stat) @@ -101,24 +107,18 @@ function optimize( end state = (optimizer=opt_st, objective=obj_st) stats = map(identity, stats) - restructure(params), stats, state + return restructure(params), stats, state end function optimize( problem, objective::AbstractVariationalObjective, q_init, - max_iter ::Int, + max_iter::Int, objargs...; - kwargs... + kwargs..., ) - optimize( - Random.default_rng(), - problem, - objective, - q_init, - max_iter, - objargs...; - kwargs... + return optimize( + Random.default_rng(), problem, objective, q_init, max_iter, objargs...; kwargs... 
) end diff --git a/src/utils.jl b/src/utils.jl index fbfdc330..d9ce4bad 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,12 +1,10 @@ function pm_next!(pm, stats::NamedTuple) - ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) + return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) end function maybe_init_optimizer( - state_init::NamedTuple, - optimizer ::Optimisers.AbstractRule, - params + state_init::NamedTuple, optimizer::Optimisers.AbstractRule, params ) if haskey(state_init, :optimizer) state_init.optimizer @@ -17,12 +15,12 @@ end function maybe_init_objective( state_init::NamedTuple, - rng ::Random.AbstractRNG, - adtype ::ADTypes.AbstractADType, - objective ::AbstractVariationalObjective, + rng::Random.AbstractRNG, + adtype::ADTypes.AbstractADType, + objective::AbstractVariationalObjective, problem, params, - restructure + restructure, ) if haskey(state_init, :objective) state_init.objective @@ -34,11 +32,9 @@ end eachsample(samples::AbstractMatrix) = eachcol(samples) function catsamples_and_acc( - state_curr::Tuple{<:AbstractArray, <:Real}, - state_new ::Tuple{<:AbstractVector, <:Real} + state_curr::Tuple{<:AbstractArray,<:Real}, state_new::Tuple{<:AbstractVector,<:Real} ) - x = hcat(first(state_curr), first(state_new)) + x = hcat(first(state_curr), first(state_new)) ∑y = last(state_curr) + last(state_new) return (x, ∑y) end - diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 201f69b2..2ae332a3 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -1,37 +1,36 @@ @testset "inference RepGradELBO DistributionsAD" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype ∈ [Float64, Float32], - (modelname, modelconstr) ∈ Dict( - :Normal=> normal_meanfield, - ), + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in Dict(:Normal => normal_meanfield), n_montecarlo in [1, 10], (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), - :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()), + :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), + :RepGradELBOStickingTheLanding => + RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) ∈ Dict( - :ForwarDiff => AutoForwardDiff(), + (adbackname, adtype) in Dict( + :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Tapir => AutoTapir(), + :Zygote => AutoZygote(), + :Tapir => AutoTapir(), #:Enzyme => AutoEnzyme(), ) seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats - T = 1000 - η = 1e-3 + T = 1000 + η = 1e-3 opt = Optimisers.Descent(realtype(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), # where ρ = 1 - ημ, μ is the strong convexity constant. 
- contraction_rate = 1 - η*strong_convexity + contraction_rate = 1 - η * strong_convexity μ0 = Zeros(realtype, n_dims) L0 = Diagonal(Ones(realtype, n_dims)) @@ -40,17 +39,21 @@ @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) q, stats, _ = optimize( - rng, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = mean(q) - L = sqrt(cov(q)) + μ = mean(q) + L = sqrt(cov(q)) Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ contraction_rate^(T/2)*Δλ0 + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -58,20 +61,28 @@ @testset "determinism" begin rng = StableRNG(seed) q, stats, _ = optimize( - rng, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = mean(q) - L = sqrt(cov(q)) + μ = mean(q) + L = sqrt(cov(q)) rng_repl = StableRNG(seed) q, stats, _ = optimize( - rng_repl, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng_repl, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) μ_repl = mean(q) L_repl = sqrt(cov(q)) @@ -80,4 +91,3 @@ end end end - diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index b46d4d4d..a0490ac4 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -1,60 +1,63 @@ @testset "inference RepGradELBO VILocationScale" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype in [Float64, Float32], - (modelname, modelconstr) in Dict( - :Normal=> normal_meanfield, - :Normal=> normal_fullrank, - ), + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in + Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), n_montecarlo in [1, 10], (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), - :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()), + :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), + :RepGradELBOStickingTheLanding => + RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), (adbackname, adtype) in Dict( - :ForwarDiff => AutoForwardDiff(), + :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), - :Tapir => AutoTapir(safe_mode=false), - :Zygote => AutoZygote(), + :Tapir => AutoTapir(; safe_mode=false), + :Zygote => AutoZygote(), #:Enzyme => AutoEnzyme(), ) seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats - T = 1000 - η = 1e-3 + T = 1000 + η = 1e-3 opt = Optimisers.Descent(realtype(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), # where ρ = 1 - ημ, μ is the strong convexity constant. 
- contraction_rate = 1 - η*strong_convexity + contraction_rate = 1 - η * strong_convexity q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) else - L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) FullRankGaussian(zeros(realtype, n_dims), L0) end @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) q, stats, _ = optimize( - rng, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = q.location - L = q.scale + μ = q.location + L = q.scale Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ contraction_rate^(T/2)*Δλ0 + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -62,20 +65,28 @@ @testset "determinism" begin rng = StableRNG(seed) q, stats, _ = optimize( - rng, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = q.location - L = q.scale + μ = q.location + L = q.scale rng_repl = StableRNG(seed) q, stats, _ = optimize( - rng_repl, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng_repl, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) μ_repl = q.location L_repl = q.scale @@ -84,4 +95,3 @@ end end end - diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 5f04c25a..b6ed0d53 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -1,42 +1,42 @@ @testset "inference RepGradELBO VILocationScale Bijectors" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype in [Float64, Float32], - (modelname, modelconstr) in Dict( - :NormalLogNormalMeanField => normallognormal_meanfield, - ), + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in + Dict(:NormalLogNormalMeanField => normallognormal_meanfield), n_montecarlo in [1, 10], (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), - :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()), + :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), + :RepGradELBOStickingTheLanding => + RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), (adbackname, adtype) in Dict( - :ForwarDiff => AutoForwardDiff(), + :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Tapir => AutoTapir(safe_mode=false), + :Zygote => AutoZygote(), + :Tapir => AutoTapir(; safe_mode=false), #:Enzyme => AutoEnzyme(), ) seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats - T = 1000 - η = 1e-3 + T = 1000 + η = 1e-3 opt = Optimisers.Descent(realtype(η)) - b = Bijectors.bijector(model) - b⁻¹ = inverse(b) - μ0 = Zeros(realtype, n_dims) - L0 = Diagonal(Ones(realtype, n_dims)) + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + μ0 = 
Zeros(realtype, n_dims) + L0 = Diagonal(Ones(realtype, n_dims)) q0_η = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) else - L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) FullRankGaussian(zeros(realtype, n_dims), L0) end q0_z = Bijectors.transformed(q0_η, b⁻¹) @@ -44,22 +44,26 @@ # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), # where ρ = 1 - ημ, μ is the strong convexity constant. - contraction_rate = 1 - η*strong_convexity + contraction_rate = 1 - η * strong_convexity @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) q, stats, _ = optimize( - rng, model, objective, q0_z, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0_z, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = q.dist.location - L = q.dist.scale + μ = q.dist.location + L = q.dist.scale Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ contraction_rate^(T/2)*Δλ0 + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -67,20 +71,28 @@ @testset "determinism" begin rng = StableRNG(seed) q, stats, _ = optimize( - rng, model, objective, q0_z, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0_z, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = q.dist.location - L = q.dist.scale + μ = q.dist.location + L = q.dist.scale rng_repl = StableRNG(seed) q, stats, _ = optimize( - rng_repl, model, objective, q0_z, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng_repl, + model, + objective, + q0_z, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) μ_repl = q.dist.location L_repl = q.dist.scale diff --git a/test/interface/ad.jl b/test/interface/ad.jl index 63a837a4..3ec4e638 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -2,24 +2,24 @@ using Test @testset "ad" begin - @testset "$(adname)" for (adname, adtype) ∈ Dict( - :ForwardDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Tapir => AutoTapir(), - #:Enzyme => AutoEnzyme() - ) + @testset "$(adname)" for (adname, adtype) in Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Tapir => AutoTapir(), + #:Enzyme => AutoEnzyme() + ) D = 10 A = randn(D, D) λ = randn(D) - f(λ′) = λ′'*A*λ′ / 2 + f(λ′) = λ′' * A * λ′ / 2 ad_st = AdvancedVI.init_adbackend(adtype, f, λ) grad_buf = DiffResults.GradientResult(λ) AdvancedVI.value_and_gradient!(adtype, ad_st, f, λ, grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) - @test ∇ ≈ (A + A')*λ/2 - @test f ≈ λ'*A*λ / 2 + @test ∇ ≈ (A + A') * λ / 2 + @test f ≈ λ' * A * λ / 2 end end diff --git a/test/interface/location_scale.jl b/test/interface/location_scale.jl index 7a129018..dcc3369d 100644 --- a/test/interface/location_scale.jl +++ b/test/interface/location_scale.jl @@ -1,22 +1,21 @@ @testset "interface LocationScale" begin - @testset "$(string(covtype)) $(basedist) $(realtype)" for - basedist = [:gaussian], - covtype = [:meanfield, :fullrank], - realtype = [Float32, Float64] + @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian], + covtype in [:meanfield, :fullrank], + realtype in [Float32, Float64] - n_dims = 10 + n_dims = 10 n_montecarlo = 
1000_000 μ = randn(realtype, n_dims) L = if covtype == :fullrank - tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular + LowerTriangular(tril(I + ones(realtype, n_dims, n_dims) / 2)) else Diagonal(ones(realtype, n_dims)) end - Σ = L*L' + Σ = L * L' - q = if covtype == :fullrank && basedist == :gaussian + q = if covtype == :fullrank && basedist == :gaussian FullRankGaussian(μ, L) elseif covtype == :meanfield && basedist == :gaussian MeanFieldGaussian(μ, L) @@ -31,13 +30,13 @@ @testset "logpdf" begin z = rand(q) - @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) - @test eltype(logpdf(q, z)) == realtype + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol = realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype end @testset "entropy" begin @test eltype(entropy(q)) == realtype - @test entropy(q) ≈ entropy(q_true) + @test entropy(q) ≈ entropy(q_true) end @testset "length" begin @@ -46,37 +45,41 @@ @testset "statistics" begin @testset "mean" begin - @test eltype(mean(q)) == realtype - @test mean(q) == μ + @test eltype(mean(q)) == realtype + @test mean(q) == μ end @testset "var" begin - @test eltype(var(q)) == realtype - @test var(q) ≈ Diagonal(Σ) + @test eltype(var(q)) == realtype + @test var(q) ≈ Diagonal(Σ) end @testset "cov" begin - @test eltype(cov(q)) == realtype - @test cov(q) ≈ Σ + @test eltype(cov(q)) == realtype + @test cov(q) ≈ Σ end end @testset "sampling" begin @testset "rand" begin - z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) + z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) z_sample_ref = rand(StableRNG(1), q) @test z_sample_ref == rand(StableRNG(1), q) end @testset "rand batch" begin - z_samples = rand(q, n_montecarlo) + z_samples = rand(q, n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) samples_ref = rand(StableRNG(1), q, n_montecarlo) @test samples_ref == rand(StableRNG(1), q, n_montecarlo) @@ -84,16 +87,18 @@ @testset "rand! 
AbstractVector" begin res = map(1:n_montecarlo) do _ - z_sample = Array{realtype}(undef, n_dims) + z_sample = Array{realtype}(undef, n_dims) z_sample_ret = rand!(q, z_sample) (z_sample, z_sample_ret) end - z_samples = mapreduce(first, hcat, res) + z_samples = mapreduce(first, hcat, res) z_samples_ret = mapreduce(last, hcat, res) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) z_sample_ref = Array{realtype}(undef, n_dims) rand!(StableRNG(1), q, z_sample_ref) @@ -104,12 +109,14 @@ end @testset "rand! AbstractMatrix" begin - z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) z_samples_ret = rand!(q, z_samples) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) rand!(StableRNG(1), q, z_samples_ref) @@ -123,44 +130,44 @@ @testset "Diagonal destructure" begin n_dims = 10 - μ = zeros(n_dims) - L = ones(n_dims) - q = MeanFieldGaussian(μ, L |> Diagonal) - λ, re = Optimisers.destructure(q) + μ = zeros(n_dims) + L = ones(n_dims) + q = MeanFieldGaussian(μ, Diagonal(L)) + λ, re = Optimisers.destructure(q) - @test length(λ) == 2*n_dims - @test q == re(λ) + @test length(λ) == 2 * n_dims + @test q == re(λ) end end @testset "scale positive definite projection" begin - @testset "$(string(covtype)) $(realtype) $(bijector)" for - covtype = [:meanfield, :fullrank], - realtype = [Float32, Float64], - bijector = [nothing, :identity] + @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in + [:meanfield, :fullrank], + realtype in [Float32, Float64], + bijector in [nothing, :identity] d = 5 μ = zeros(realtype, d) ϵ = sqrt(realtype(0.5)) q = if covtype == :fullrank - L = LowerTriangular(Matrix{realtype}(I,d,d)) + L = LowerTriangular(Matrix{realtype}(I, d, d)) FullRankGaussian(μ, L; scale_eps=ϵ) elseif covtype == :meanfield L = Diagonal(ones(realtype, d)) MeanFieldGaussian(μ, L; scale_eps=ϵ) end - q_trans = if isnothing(bijector) + q_trans = if isnothing(bijector) q else Bijectors.TransformedDistribution(q, identity) end g = deepcopy(q) - λ, re = Optimisers.destructure(q) + λ, re = Optimisers.destructure(q) grad, _ = Optimisers.destructure(g) - opt_st = Optimisers.setup(Descent(one(realtype)), λ) - _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) - q′ = re(λ′) + opt_st = Optimisers.setup(Descent(one(realtype)), λ) + _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + q′ = re(λ′) @test all(diag(var(q′)) .≥ ϵ^2) end end diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index ea60b764..eb006c98 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -3,7 +3,7 @@ using Test 
@testset "interface optimize" begin seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) T = 1000 modelstats = normal_meanfield(rng, Float64) @@ -11,64 +11,54 @@ using Test @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats # Global Test Configurations - q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) obj = RepGradELBO(10) - adtype = AutoForwardDiff() + adtype = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) - rng = StableRNG(seed) + rng = StableRNG(seed) q_ref, stats_ref, _ = optimize( - rng, model, obj, q0, T; - optimizer, - show_progress = false, - adtype, + rng, model, obj, q0, T; optimizer, show_progress=false, adtype ) @testset "default_rng" begin - optimize( - model, obj, q0, T; - optimizer, - show_progress = false, - adtype, - ) + optimize(model, obj, q0, T; optimizer, show_progress=false, adtype) end @testset "callback" begin - rng = StableRNG(seed) + rng = StableRNG(seed) test_values = rand(rng, T) - callback(; stat, args...) = (test_value = test_values[stat.iteration],) + callback(; stat, args...) = (test_value=test_values[stat.iteration],) - rng = StableRNG(seed) + rng = StableRNG(seed) _, stats, _ = optimize( - rng, model, obj, q0, T; - show_progress = false, - adtype, - callback + rng, model, obj, q0, T; show_progress=false, adtype, callback ) - @test [stat.test_value for stat ∈ stats] == test_values + @test [stat.test_value for stat in stats] == test_values end @testset "warm start" begin - rng = StableRNG(seed) + rng = StableRNG(seed) - T_first = div(T,2) - T_last = T - T_first + T_first = div(T, 2) + T_last = T - T_first q_first, _, state = optimize( - rng, model, obj, q0, T_first; - optimizer, - show_progress = false, - adtype + rng, model, obj, q0, T_first; optimizer, show_progress=false, adtype ) q, stats, _ = optimize( - rng, model, obj, q_first, T_last; + rng, + model, + obj, + q_first, + T_last; optimizer, - show_progress = false, - state_init = state, - adtype + show_progress=false, + state_init=state, + adtype, ) @test q == q_ref end diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 58e440ce..31a967fb 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -3,7 +3,7 @@ using Test @testset "interface RepGradELBO" begin seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = normal_meanfield(rng, Float64) @@ -11,42 +11,39 @@ using Test q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) - obj = RepGradELBO(10) - rng = StableRNG(seed) + obj = RepGradELBO(10) + rng = StableRNG(seed) elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4) @testset "determinism" begin - rng = StableRNG(seed) + rng = StableRNG(seed) elbo = estimate_objective(rng, obj, q0, model; n_samples=10^4) @test elbo == elbo_ref end @testset "default_rng" begin elbo = estimate_objective(obj, q0, model; n_samples=10^4) - @test elbo ≈ elbo_ref rtol=0.1 + @test elbo ≈ elbo_ref rtol = 0.1 end end @testset "interface RepGradELBO STL variance reduction" begin seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = normal_meanfield(rng, Float64) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats @testset for adtype in [ - ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(), - ADTypes.AutoZygote() + ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() ] q_true = MeanFieldGaussian( - 
Vector{eltype(μ_true)}(μ_true), - Diagonal(Vector{eltype(L_true)}(diag(L_true))) + Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) params, re = Optimisers.destructure(q_true) - obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) - out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) + obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) + out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) ad_st = AdvancedVI.init_adbackend( adtype, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux ) @@ -54,6 +51,6 @@ end adtype, ad_st, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out ) grad = DiffResults.gradient(out) - @test norm(grad) ≈ 0 atol=1e-5 + @test norm(grad) ≈ 0 atol = 1e-5 end end diff --git a/test/models/normal.jl b/test/models/normal.jl index 5c3e22e8..3aa95f4c 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -6,42 +6,42 @@ end function LogDensityProblems.logdensity(model::TestNormal, θ) @unpack μ, Σ = model - logpdf(MvNormal(μ, Σ), θ) + return logpdf(MvNormal(μ, Σ), θ) end function LogDensityProblems.dimension(model::TestNormal) - length(model.μ) + return length(model.μ) end function LogDensityProblems.capabilities(::Type{<:TestNormal}) - LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{0}() end function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 σ0 = realtype(0.3) - μ = Fill(realtype(5), n_dims) - L = Matrix(σ0*I, n_dims, n_dims) - Σ = L*L' |> Hermitian + μ = Fill(realtype(5), n_dims) + L = Matrix(σ0 * I, n_dims, n_dims) + Σ = Hermitian(L * L') model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0))) - TestModel(model, μ, LowerTriangular(L), n_dims, 1/σ0^2, false) + return TestModel(model, μ, LowerTriangular(L), n_dims, 1 / σ0^2, false) end function normal_meanfield(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 σ0 = realtype(0.3) - μ = Fill(realtype(5), n_dims) - #randn(rng, realtype, n_dims) - σ = Fill(σ0, n_dims) - #log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + μ = Fill(realtype(5), n_dims) + #randn(rng, realtype, n_dims) + σ = Fill(σ0, n_dims) + #log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - model = TestNormal(μ, Diagonal(σ.^2)) + model = TestNormal(μ, Diagonal(σ .^ 2)) - L = σ |> Diagonal + L = Diagonal(σ) - TestModel(model, μ, L, n_dims, 1/σ0^2, true) + return TestModel(model, μ, L, n_dims, 1 / σ0^2, true) end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 54adcd48..176aab2f 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -1,54 +1,55 @@ -struct NormalLogNormal{MX,SX,MY,SY} - μ_x::MX - σ_x::SX - μ_y::MY - Σ_y::SY -end - -function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) -end - -function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 -end - -function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() -end - -function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model - Bijectors.Stacked( - Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) -end - -function normallognormal_fullrank(::Random.AbstractRNG, realtype::Type) - n_y_dims 
= 5 +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + return length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + return LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + return Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:(1 + length(μ_y))], + ) +end + +function normallognormal_fullrank(::Random.AbstractRNG, realtype::Type) + n_y_dims = 5 σ0 = realtype(0.3) - μ = Fill(realtype(5.0), n_y_dims+1) - L = Matrix(σ0*I, n_y_dims+1, n_y_dims+1) - Σ = L*L' |> Hermitian + μ = Fill(realtype(5.0), n_y_dims + 1) + L = Matrix(σ0 * I, n_y_dims + 1, n_y_dims + 1) + Σ = Hermitian(L * L') model = NormalLogNormal( - μ[1], L[1,1], μ[2:end], PDMat(Σ[2:end,2:end], Cholesky(L[2:end,2:end], 'L', 0)) + μ[1], L[1, 1], μ[2:end], PDMat(Σ[2:end, 2:end], Cholesky(L[2:end, 2:end], 'L', 0)) ) - TestModel(model, μ, LowerTriangular(L), n_y_dims+1, 1/σ0^2, false) -end + return TestModel(model, μ, LowerTriangular(L), n_y_dims + 1, 1 / σ0^2, false) +end -function normallognormal_meanfield(::Random.AbstractRNG, realtype::Type) - n_y_dims = 5 +function normallognormal_meanfield(::Random.AbstractRNG, realtype::Type) + n_y_dims = 5 σ0 = realtype(0.3) - μ = Fill(realtype(5), n_y_dims + 1) - σ = Fill(σ0, n_y_dims + 1) - L = Diagonal(σ) + μ = Fill(realtype(5), n_y_dims + 1) + σ = Fill(σ0, n_y_dims + 1) + L = Diagonal(σ) - model = NormalLogNormal(μ[1], σ[1], μ[2:end], Diagonal(σ[2:end].^2)) + model = NormalLogNormal(μ[1], σ[1], μ[2:end], Diagonal(σ[2:end] .^ 2)) - TestModel(model, μ, L, n_y_dims+1, 1/σ0^2, true) -end + return TestModel(model, μ, L, n_y_dims + 1, 1 / σ0^2, true) +end diff --git a/test/runtests.jl b/test/runtests.jl index 479e596a..4011a05d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,6 @@ if GROUP == "All" || GROUP == "Interface" include("interface/location_scale.jl") end - const PROGRESS = haskey(ENV, "PROGRESS") if GROUP == "All" || GROUP == "Inference" From 3a3a829bfc6eaf053e64929e2c27081e6d91a7b8 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 9 Aug 2024 16:15:47 +0100 Subject: [PATCH 08/15] format ext/ --- ext/AdvancedVIBijectorsExt.jl | 22 +++++++++++----------- ext/AdvancedVIEnzymeExt.jl | 10 ++-------- ext/AdvancedVIForwardDiffExt.jl | 18 +++++++++--------- ext/AdvancedVIReverseDiffExt.jl | 16 ++++++++-------- ext/AdvancedVITapirExt.jl | 12 ++++++------ ext/AdvancedVIZygoteExt.jl | 16 ++++++++-------- 6 files changed, 44 insertions(+), 50 deletions(-) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index a227fdf2..f66e0ea2 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -20,7 +20,7 @@ function AdvancedVI.update_variational_params!( opt_st, params, restructure, - grad + grad, ) opt_st, params = Optimisers.update!(opt_st, params, grad) q = restructure(params) @@ -32,18 +32,18 @@ function AdvancedVI.update_variational_params!( params, _ = Optimisers.destructure(q) - opt_st, params + return opt_st, params end function AdvancedVI.reparam_with_entropy( - rng ::Random.AbstractRNG, - q ::Bijectors.TransformedDistribution, - q_stop 
::Bijectors.TransformedDistribution, + rng::Random.AbstractRNG, + q::Bijectors.TransformedDistribution, + q_stop::Bijectors.TransformedDistribution, n_samples::Int, - ent_est ::AdvancedVI.AbstractEntropyEstimator + ent_est::AdvancedVI.AbstractEntropyEstimator, ) - transform = q.transform - q_unconst = q.dist + transform = q.transform + q_unconst = q.dist q_unconst_stop = q_stop.dist # Draw samples and compute entropy of the uncontrained distribution @@ -58,14 +58,14 @@ function AdvancedVI.reparam_with_entropy( samples_and_logjac = mapreduce( AdvancedVI.catsamples_and_acc, Iterators.drop(unconstr_iter, 1); - init=(reshape(samples_init, (:,1)), logjac_init) + init=(reshape(samples_init, (:, 1)), logjac_init), ) do sample with_logabsdet_jacobian(transform, sample) end samples = first(samples_and_logjac) - logjac = last(samples_and_logjac)/n_samples + logjac = last(samples_and_logjac) / n_samples entropy = unconst_entropy + logjac - samples, entropy + return samples, entropy end end diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 29a0a695..55f22897 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -12,18 +12,12 @@ else end function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoEnzyme, - f, - θ::AbstractVector{T}, - out::DiffResults.MutableDiffResult, + ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult ) where {T<:Real} ∇θ = DiffResults.gradient(out) fill!(∇θ, zero(T)) _, y = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, - f, - Enzyme.Active, - Enzyme.Duplicated(θ, ∇θ), + Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ) ) DiffResults.value!(out, y) return out diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index bf0da458..d9d255be 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -14,11 +14,11 @@ end getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize function AdvancedVI.value_and_gradient!( - ad ::ADTypes.AutoForwardDiff, - ::Any, + ad::ADTypes.AutoForwardDiff, + ::Any, f, - x ::AbstractVector, - out ::DiffResults.MutableDiffResult + x::AbstractVector, + out::DiffResults.MutableDiffResult, ) chunk_size = getchunksize(ad) config = if isnothing(chunk_size) @@ -31,14 +31,14 @@ function AdvancedVI.value_and_gradient!( end function AdvancedVI.value_and_gradient!( - ad ::ADTypes.AutoForwardDiff, + ad::ADTypes.AutoForwardDiff, st_ad, f, - x ::AbstractVector, - aux, - out ::DiffResults.MutableDiffResult + x::AbstractVector, + aux, + out::DiffResults.MutableDiffResult, ) - AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out) + return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out) end end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index cde187b5..c0a87502 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -13,11 +13,11 @@ end # ReverseDiff without compiled tape function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoReverseDiff, - ::Any, + ::ADTypes.AutoReverseDiff, + ::Any, f, - x ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, ) tp = ReverseDiff.GradientTape(f, x) ReverseDiff.gradient!(out, tp, x) @@ -25,14 +25,14 @@ function AdvancedVI.value_and_gradient!( end function AdvancedVI.value_and_gradient!( - ad ::ADTypes.AutoReverseDiff, + ad::ADTypes.AutoReverseDiff, st_ad, f, - x ::AbstractVector{<:Real}, + 
x::AbstractVector{<:Real}, aux, - out ::DiffResults.MutableDiffResult + out::DiffResults.MutableDiffResult, ) - AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out) + return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out) end end diff --git a/ext/AdvancedVITapirExt.jl b/ext/AdvancedVITapirExt.jl index 3794789b..14025fbc 100644 --- a/ext/AdvancedVITapirExt.jl +++ b/ext/AdvancedVITapirExt.jl @@ -14,11 +14,11 @@ end AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x) = Tapir.build_rrule(f, x) function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoTapir, + ::ADTypes.AutoTapir, st_ad, f, - x ::AbstractVector{<:Real}, - out ::DiffResults.MutableDiffResult + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, ) rule = st_ad y, g = Tapir.value_and_gradient!!(rule, f, x) @@ -30,12 +30,12 @@ end AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x, aux) = Tapir.build_rrule(f, x, aux) function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoTapir, + ::ADTypes.AutoTapir, st_ad, f, - x ::AbstractVector{<:Real}, + x::AbstractVector{<:Real}, aux, - out ::DiffResults.MutableDiffResult + out::DiffResults.MutableDiffResult, ) rule = st_ad y, g = Tapir.value_and_gradient!!(rule, f, x, aux) diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index f60b31fb..3a477a24 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -14,11 +14,11 @@ else end function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoZygote, - ::Any, + ::ADTypes.AutoZygote, + ::Any, f, - x ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, ) y, back = Zygote.pullback(f, x) ∇x = back(one(y)) @@ -28,14 +28,14 @@ function AdvancedVI.value_and_gradient!( end function AdvancedVI.value_and_gradient!( - ad ::ADTypes.AutoZygote, + ad::ADTypes.AutoZygote, st_ad, f, - x ::AbstractVector{<:Real}, + x::AbstractVector{<:Real}, aux, - out ::DiffResults.MutableDiffResult + out::DiffResults.MutableDiffResult, ) - AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out) + return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out) end end From cba9568ecaeb0bb2ca11dc07d1c6fda8a0fe1f49 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 12 Aug 2024 12:56:12 +0100 Subject: [PATCH 09/15] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 85f9e13c..e124b371 100644 --- a/Project.toml +++ b/Project.toml @@ -57,7 +57,7 @@ Requires = "1.0" ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" -Tapir = "0.2.23" +Tapir = "0.2.33" Zygote = "0.6.63" julia = "1.7" From d9573de832d8a8f378ea81fa97a73c668ddbf26f Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:51:41 +0100 Subject: [PATCH 10/15] Update Project.toml Co-authored-by: Will Tebbutt --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e124b371..c1581c18 100644 --- a/Project.toml +++ b/Project.toml @@ -57,7 +57,7 @@ Requires = "1.0" ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" -Tapir = "0.2.33" +Tapir = "0.2.34" Zygote = "0.6.63" julia = "1.7" From 5664ec1151889c31e97d6eb1e473355a101934ca Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 21 Aug 2024 20:37:54 +0100 Subject: [PATCH 11/15] Update ext/AdvancedVITapirExt.jl --- 
ext/AdvancedVITapirExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/AdvancedVITapirExt.jl b/ext/AdvancedVITapirExt.jl index 14025fbc..b2e00af9 100644 --- a/ext/AdvancedVITapirExt.jl +++ b/ext/AdvancedVITapirExt.jl @@ -40,7 +40,7 @@ function AdvancedVI.value_and_gradient!( rule = st_ad y, g = Tapir.value_and_gradient!!(rule, f, x, aux) DiffResults.value!(out, y) - DiffResults.gradient!(out, last(g)) + DiffResults.gradient!(out, g[2]) return out end From dd0fc0b43a8ff383725859b230b919092f1259c1 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 21 Aug 2024 21:27:54 +0100 Subject: [PATCH 12/15] Update test/inference/repgradelbo_distributionsad.jl --- test/inference/repgradelbo_distributionsad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 59e57971..a4bc86fe 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -13,7 +13,7 @@ :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Tapir => AutoTapir(), + :Tapir => AutoTapir(; safe_mode=false), #:Enzyme => AutoEnzyme(), ) From 826d3da50714166ca05622013f2f4aeafeeb2b16 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 21 Aug 2024 21:28:27 +0100 Subject: [PATCH 13/15] Update test/interface/ad.jl --- test/interface/ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index 3ec4e638..00fac213 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -6,7 +6,7 @@ using Test :ForwardDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Tapir => AutoTapir(), + :Tapir => AutoTapir(; safe_mode=false), #:Enzyme => AutoEnzyme() ) D = 10 From ae8d6d42a0ac89631a2d583a9701d8e624e7c14d Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 22 Aug 2024 08:58:46 +0100 Subject: [PATCH 14/15] add back init_backend --- src/AdvancedVI.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 6a9f568d..8b9d7921 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -41,6 +41,21 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif """ function value_and_gradient! end +""" + init_adbackend(adtype, f, x) + init_adbackend(adtype, f, x, aux) +Initialize the AD backend and setup necessary states. +# Arguments +- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. +- `f`: Function subject to differentiation. +- `x`: The point to evaluate the gradient. +- `aux`: Auxiliary input passed to `f`. +# Returns +- `ad_st`: State of the AD backend. (This will often be pre-compiled tapes/caches.) 
+""" +init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any) = nothing +init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any, ::Any) = nothing + """ restructure_ad_forward(adtype, restructure, params) From 99a91b343e95f51e6cb1762dc569212adaddfa83 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:02:22 +0100 Subject: [PATCH 15/15] fix formatting --- test/interface/repgradelbo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 6e92f19f..afaba148 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -39,7 +39,7 @@ end ADTypes.AutoReverseDiff(), ADTypes.AutoZygote(), ADTypes.AutoEnzyme(), - ADTypes.AutoTapir(false) + ADTypes.AutoTapir(false), ] q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true)))
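# A minimal usage sketch of the two-step stateful AD interface introduced in this
# series, assuming AdvancedVI at the state of these patches with ForwardDiff loaded
# so that the AdvancedVIForwardDiffExt extension is active. `f_example`, `aux`, and
# the array sizes below are illustrative placeholders, not names taken from the
# patches; the updated tests exercise the same pattern through
# `AdvancedVI.estimate_repgradelbo_ad_forward`.
using ADTypes, DiffResults, ForwardDiff
using AdvancedVI

# Function of the form f(x, aux) expected by the auxiliary-input methods.
f_example(x, aux) = sum(abs2, x .- aux.shift)

x   = randn(5)
aux = (shift=ones(5),)
out = DiffResults.DiffResult(zero(eltype(x)), similar(x))

adtype = ADTypes.AutoForwardDiff()

# Build the backend state once. The generic fallback returns `nothing` unless the
# backend's extension defines a specialized method; the Tapir extension, for
# instance, builds a reusable rule via `Tapir.build_rrule` so that tape/rule
# construction is separated from repeated gradient evaluation.
ad_st = AdvancedVI.init_adbackend(adtype, f_example, x, aux)

# Reuse the state on every evaluation; the value and gradient are stored in `out`.
AdvancedVI.value_and_gradient!(adtype, ad_st, f_example, x, aux, out)
value = DiffResults.value(out)
grad  = DiffResults.gradient(out)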