diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..0d9f0f8d --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,2 @@ +style="blue" +format_markdown = true diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index a227fdf2..f66e0ea2 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -20,7 +20,7 @@ function AdvancedVI.update_variational_params!( opt_st, params, restructure, - grad + grad, ) opt_st, params = Optimisers.update!(opt_st, params, grad) q = restructure(params) @@ -32,18 +32,18 @@ function AdvancedVI.update_variational_params!( params, _ = Optimisers.destructure(q) - opt_st, params + return opt_st, params end function AdvancedVI.reparam_with_entropy( - rng ::Random.AbstractRNG, - q ::Bijectors.TransformedDistribution, - q_stop ::Bijectors.TransformedDistribution, + rng::Random.AbstractRNG, + q::Bijectors.TransformedDistribution, + q_stop::Bijectors.TransformedDistribution, n_samples::Int, - ent_est ::AdvancedVI.AbstractEntropyEstimator + ent_est::AdvancedVI.AbstractEntropyEstimator, ) - transform = q.transform - q_unconst = q.dist + transform = q.transform + q_unconst = q.dist q_unconst_stop = q_stop.dist # Draw samples and compute entropy of the uncontrained distribution @@ -58,14 +58,14 @@ function AdvancedVI.reparam_with_entropy( samples_and_logjac = mapreduce( AdvancedVI.catsamples_and_acc, Iterators.drop(unconstr_iter, 1); - init=(reshape(samples_init, (:,1)), logjac_init) + init=(reshape(samples_init, (:, 1)), logjac_init), ) do sample with_logabsdet_jacobian(transform, sample) end samples = first(samples_and_logjac) - logjac = last(samples_and_logjac)/n_samples + logjac = last(samples_and_logjac) / n_samples entropy = unconst_entropy + logjac - samples, entropy + return samples, entropy end end diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 03e3a6a1..45b3c547 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -11,46 +11,39 @@ else using ..AdvancedVI: ADTypes, DiffResults end - -AdvancedVI.restructure_ad_forward( - ::ADTypes.AutoEnzyme, restructure, params -) = restructure(params)::typeof(restructure.model) +function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params) + return restructure(params)::typeof(restructure.model) +end function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoEnzyme, - f, - x ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult + ::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult ) Enzyme.API.runtimeActivity!(true) ∇x = DiffResults.gradient(out) fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, - f, - Enzyme.Active, - Enzyme.Duplicated(x, ∇x) + Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x) ) DiffResults.value!(out, y) return out end function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoEnzyme, + ::ADTypes.AutoEnzyme, f, - x ::AbstractVector{<:Real}, + x::AbstractVector{<:Real}, aux, - out ::DiffResults.MutableDiffResult + out::DiffResults.MutableDiffResult, ) Enzyme.API.runtimeActivity!(true) ∇x = DiffResults.gradient(out) fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( Enzyme.ReverseWithPrimal, - f, + Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x), - Enzyme.Const(aux) + Enzyme.Const(aux), ) DiffResults.value!(out, y) return out diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index a8afd031..6904fa7a 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -14,10 +14,10 @@ end getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize function AdvancedVI.value_and_gradient!( - ad ::ADTypes.AutoForwardDiff, + ad::ADTypes.AutoForwardDiff, f, - x ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, ) chunk_size = getchunksize(ad) config = if isnothing(chunk_size) @@ -30,13 +30,13 @@ function AdvancedVI.value_and_gradient!( end function AdvancedVI.value_and_gradient!( - ad ::ADTypes.AutoForwardDiff, + ad::ADTypes.AutoForwardDiff, f, - x ::AbstractVector, - aux, - out::DiffResults.MutableDiffResult + x::AbstractVector, + aux, + out::DiffResults.MutableDiffResult, ) - AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) + return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) end end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index 392f5cea..9cde91a1 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -16,7 +16,7 @@ function AdvancedVI.value_and_gradient!( ad::ADTypes.AutoReverseDiff, f, x::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult + out::DiffResults.MutableDiffResult, ) tp = ReverseDiff.GradientTape(f, x) ReverseDiff.gradient!(out, tp, x) @@ -28,9 +28,9 @@ function AdvancedVI.value_and_gradient!( f, x::AbstractVector{<:Real}, aux, - out::DiffResults.MutableDiffResult + out::DiffResults.MutableDiffResult, ) - AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) + return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) end end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index 806c08e4..2cdd8392 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -14,10 +14,7 @@ else end function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoZygote, - f, - x::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult + ::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult ) y, back = Zygote.pullback(f, x) ∇x = back(one(y)) @@ -31,9 +28,9 @@ function AdvancedVI.value_and_gradient!( f, x::AbstractVector{<:Real}, aux, - out::DiffResults.MutableDiffResult + out::DiffResults.MutableDiffResult, ) - AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) + return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) end end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index a7f0b6aa..7cf9519e 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -77,8 +77,9 @@ Instead, the return values should be used. """ function update_variational_params! end -update_variational_params!(::Type, opt_st, params, restructure, grad) = - Optimisers.update!(opt_st, params, grad) +function update_variational_params!(::Type, opt_st, params, restructure, grad) + return Optimisers.update!(opt_st, params, grad) +end # estimators """ @@ -105,13 +106,7 @@ This function needs to be implemented only if `obj` is stateful. - `params`: Initial variational parameters. - `restructure`: Function that reconstructs the variational approximation from `λ`. """ -init( - ::Random.AbstractRNG, - ::AbstractVariationalObjective, - ::Any, - ::Any, - ::Any, -) = nothing +init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any) = nothing """ estimate_objective([rng,] obj, q, prob; kwargs...) @@ -135,7 +130,6 @@ function estimate_objective end export estimate_objective - """ estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state) @@ -176,25 +170,16 @@ Estimate the entropy of `q`. """ function estimate_entropy end -export - RepGradELBO, - ClosedFormEntropy, - StickingTheLandingEntropy, - MonteCarloEntropy +export RepGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy include("objectives/elbo/entropy.jl") include("objectives/elbo/repgradelbo.jl") - # Variational Families -export - MvLocationScale, - MeanFieldGaussian, - FullRankGaussian +export MvLocationScale, MeanFieldGaussian, FullRankGaussian include("families/location_scale.jl") - # Optimization Routine function optimize end @@ -204,7 +189,6 @@ export optimize include("utils.jl") include("optimize.jl") - # optional dependencies if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base using Requires @@ -231,4 +215,3 @@ end end end - diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index f86257db..1aab2e71 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -13,22 +13,21 @@ represented as follows: z = scale*u + location ``` """ -struct MvLocationScale{ - S, D <: ContinuousDistribution, L, E <: Real -} <: ContinuousMultivariateDistribution - location ::L - scale ::S - dist ::D +struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <: + ContinuousMultivariateDistribution + location::L + scale::S + dist::D scale_eps::E end function MvLocationScale( - location ::AbstractVector{T}, - scale ::AbstractMatrix{T}, - dist ::ContinuousDistribution; - scale_eps::T = sqrt(eps(T)) -) where {T <: Real} - MvLocationScale(location, scale, dist, scale_eps) + location::AbstractVector{T}, + scale::AbstractMatrix{T}, + dist::ContinuousDistribution; + scale_eps::T=sqrt(eps(T)), +) where {T<:Real} + return MvLocationScale(location, scale, dist, scale_eps) end Functors.@functor MvLocationScale (location, scale) @@ -38,23 +37,21 @@ Functors.@functor MvLocationScale (location, scale) # `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD # is very inefficient. # begin -struct RestructureMeanField{S <: Diagonal, D, L} - model::MvLocationScale{S, D, L} +struct RestructureMeanField{S<:Diagonal,D,L} + model::MvLocationScale{S,D,L} end function (re::RestructureMeanField)(flat::AbstractVector) - n_dims = div(length(flat), 2) + n_dims = div(length(flat), 2) location = first(flat, n_dims) - scale = Diagonal(last(flat, n_dims)) - MvLocationScale(location, scale, re.model.dist, re.model.scale_eps) + scale = Diagonal(last(flat, n_dims)) + return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps) end -function Optimisers.destructure( - q::MvLocationScale{<:Diagonal, D, L} -) where {D, L} +function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L} @unpack location, scale, dist = q flat = vcat(location, diag(scale)) - flat, RestructureMeanField(q) + return flat, RestructureMeanField(q) end # end @@ -62,61 +59,63 @@ Base.length(q::MvLocationScale) = length(q.location) Base.size(q::MvLocationScale) = size(q.location) -Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D) +Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D) function StatsBase.entropy(q::MvLocationScale) - @unpack location, scale, dist = q + @unpack location, scale, dist = q n_dims = length(location) # `convert` is necessary because `entropy` is not type stable upstream - n_dims*convert(eltype(location), entropy(dist)) + logdet(scale) + return n_dims * convert(eltype(location), entropy(dist)) + logdet(scale) end function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) + return sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) end function Distributions.rand(q::MvLocationScale) @unpack location, scale, dist = q n_dims = length(location) - scale*rand(dist, n_dims) + location + return scale * rand(dist, n_dims) + location end function Distributions.rand( - rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int -) where {S, D, L} + rng::AbstractRNG, q::MvLocationScale{S,D,L}, num_samples::Int +) where {S,D,L} @unpack location, scale, dist = q n_dims = length(location) - scale*rand(rng, dist, n_dims, num_samples) .+ location + return scale * rand(rng, dist, n_dims, num_samples) .+ location end # This specialization improves AD performance of the sampling path function Distributions.rand( - rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int -) where {L, D} + rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int +) where {L,D} @unpack location, scale, dist = q - n_dims = length(location) + n_dims = length(location) scale_diag = diag(scale) - scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location + return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location end -function Distributions._rand!(rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real}) +function Distributions._rand!( + rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real} +) @unpack location, scale, dist = q rand!(rng, dist, x) - x[:] = scale*x + x[:] = scale * x return x .+= location end Distributions.mean(q::MvLocationScale) = q.location -function Distributions.var(q::MvLocationScale) +function Distributions.var(q::MvLocationScale) C = q.scale - Diagonal(C*C') + return Diagonal(C * C') end function Distributions.cov(q::MvLocationScale) C = q.scale - Hermitian(C*C') + return Hermitian(C * C') end """ @@ -132,13 +131,11 @@ Construct a Gaussian variational approximation with a dense covariance matrix. - `check_args`: Check the conditioning of the initial scale (default: `true`). """ function FullRankGaussian( - μ::AbstractVector{T}, - L::LinearAlgebra.AbstractTriangular{T}; - scale_eps::T = sqrt(eps(T)) -) where {T <: Real} + μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=sqrt(eps(T)) +) where {T<:Real} @assert minimum(diag(L)) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) - MvLocationScale(μ, L, q_base, scale_eps) + return MvLocationScale(μ, L, q_base, scale_eps) end """ @@ -154,13 +151,11 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix - `check_args`: Check the conditioning of the initial scale (default: `true`). """ function MeanFieldGaussian( - μ::AbstractVector{T}, - L::Diagonal{T}; - scale_eps::T = sqrt(eps(T)), -) where {T <: Real} + μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=sqrt(eps(T)) +) where {T<:Real} @assert minimum(diag(L)) ≥ sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) - MvLocationScale(μ, L, q_base, scale_eps) + return MvLocationScale(μ, L, q_base, scale_eps) end function update_variational_params!( @@ -176,5 +171,5 @@ function update_variational_params!( params, _ = Optimisers.destructure(q) - opt_st, params + return opt_st, params end diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 973d54e2..210b49ca 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -12,7 +12,7 @@ struct ClosedFormEntropy <: AbstractEntropyEstimator end maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q function estimate_entropy(::ClosedFormEntropy, ::Any, q) - entropy(q) + return entropy(q) end """ @@ -31,9 +31,7 @@ struct MonteCarloEntropy <: AbstractEntropyEstimator end maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop function estimate_entropy( - ::Union{MonteCarloEntropy, StickingTheLandingEntropy}, - mc_samples::AbstractMatrix, - q + ::Union{MonteCarloEntropy,StickingTheLandingEntropy}, mc_samples::AbstractMatrix, q ) mean(eachcol(mc_samples)) do mc_sample -logpdf(q, mc_sample) diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 2daf880b..b8d73eaa 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -28,31 +28,32 @@ This computes the evidence lower-bound (ELBO) through the formulation: Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. """ -struct RepGradELBO{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective - entropy ::EntropyEst +struct RepGradELBO{EntropyEst<:AbstractEntropyEstimator} <: AbstractVariationalObjective + entropy::EntropyEst n_samples::Int end -RepGradELBO( - n_samples::Int; - entropy ::AbstractEntropyEstimator = ClosedFormEntropy() -) = RepGradELBO(entropy, n_samples) +function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy()) + return RepGradELBO(entropy, n_samples) +end function Base.show(io::IO, obj::RepGradELBO) print(io, "RepGradELBO(entropy=") print(io, obj.entropy) print(io, ", n_samples=") print(io, obj.n_samples) - print(io, ")") + return print(io, ")") end -function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) +function estimate_entropy_maybe_stl( + entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop +) q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) - estimate_entropy(entropy_estimator, samples, q_maybe_stop) + return estimate_entropy(entropy_estimator, samples, q_maybe_stop) end function estimate_energy_with_samples(prob, samples) - mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) + return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) end """ @@ -71,31 +72,24 @@ Draw `n_samples` from `q` and compute its entropy. - `entropy`: An estimate (or exact value) of the differential entropy of `q`. """ function reparam_with_entropy( - rng ::Random.AbstractRNG, - q, - q_stop, - n_samples::Int, - ent_est ::AbstractEntropyEstimator + rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) - samples, entropy + return samples, entropy end function estimate_objective( - rng::Random.AbstractRNG, - obj::RepGradELBO, - q, - prob; - n_samples::Int = obj.n_samples + rng::Random.AbstractRNG, obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samples ) samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy) energy = estimate_energy_with_samples(prob, samples) - energy + entropy + return energy + entropy end -estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = - estimate_objective(Random.default_rng(), obj, q, prob; n_samples) +function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samples) + return estimate_objective(Random.default_rng(), obj, q, prob; n_samples) +end function estimate_repgradelbo_ad_forward(params′, aux) @unpack rng, obj, problem, adtype, restructure, q_stop = aux @@ -103,14 +97,14 @@ function estimate_repgradelbo_ad_forward(params′, aux) samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) energy = estimate_energy_with_samples(problem, samples) elbo = energy + entropy - -elbo + return -elbo end function estimate_gradient!( - rng ::Random.AbstractRNG, - obj ::RepGradELBO, + rng::Random.AbstractRNG, + obj::RepGradELBO, adtype::ADTypes.AbstractADType, - out ::DiffResults.MutableDiffResult, + out::DiffResults.MutableDiffResult, prob, params, restructure, @@ -118,17 +112,15 @@ function estimate_gradient!( ) q_stop = restructure(params) aux = ( - rng = rng, - adtype = adtype, - obj = obj, - problem = prob, - restructure = restructure, - q_stop = q_stop - ) - value_and_gradient!( - adtype, estimate_repgradelbo_ad_forward, params, aux, out + rng=rng, + adtype=adtype, + obj=obj, + problem=prob, + restructure=restructure, + q_stop=q_stop, ) + value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out) nelbo = DiffResults.value(out) - stat = (elbo=-nelbo,) - out, nothing, stat + stat = (elbo=-nelbo,) + return out, nothing, stat end diff --git a/src/optimize.jl b/src/optimize.jl index e5fe374d..5bef7eec 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -45,37 +45,40 @@ Otherwise, just return `nothing`. """ function optimize( - rng ::Random.AbstractRNG, + rng::Random.AbstractRNG, problem, - objective ::AbstractVariationalObjective, + objective::AbstractVariationalObjective, q_init, - max_iter ::Int, + max_iter::Int, objargs...; - adtype ::ADTypes.AbstractADType, - optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), - show_progress::Bool = true, - state_init ::NamedTuple = NamedTuple(), - callback = nothing, - prog = ProgressMeter.Progress( - max_iter; - desc = "Optimizing", - barlen = 31, - showspeed = true, - enabled = show_progress + adtype::ADTypes.AbstractADType, + optimizer::Optimisers.AbstractRule=Optimisers.Adam(), + show_progress::Bool=true, + state_init::NamedTuple=NamedTuple(), + callback=nothing, + prog=ProgressMeter.Progress( + max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress ), ) params, restructure = Optimisers.destructure(deepcopy(q_init)) - opt_st = maybe_init_optimizer(state_init, optimizer, params) - obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure) + opt_st = maybe_init_optimizer(state_init, optimizer, params) + obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - stats = NamedTuple[] + stats = NamedTuple[] - for t = 1:max_iter + for t in 1:max_iter stat = (iteration=t,) grad_buf, obj_st, stat′ = estimate_gradient!( - rng, objective, adtype, grad_buf, problem, - params, restructure, obj_st, objargs... + rng, + objective, + adtype, + grad_buf, + problem, + params, + restructure, + obj_st, + objargs..., ) stat = merge(stat, stat′) @@ -85,13 +88,16 @@ function optimize( ) if !isnothing(callback) - stat′ = callback( - ; stat, restructure, params=params, gradient=grad, - state=(optimizer=opt_st, objective=obj_st) + stat′ = callback(; + stat, + restructure, + params=params, + gradient=grad, + state=(optimizer=opt_st, objective=obj_st), ) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end - + @debug "Iteration $t" stat... pm_next!(prog, stat) @@ -99,24 +105,18 @@ function optimize( end state = (optimizer=opt_st, objective=obj_st) stats = map(identity, stats) - restructure(params), stats, state + return restructure(params), stats, state end function optimize( problem, objective::AbstractVariationalObjective, q_init, - max_iter ::Int, + max_iter::Int, objargs...; - kwargs... + kwargs..., ) - optimize( - Random.default_rng(), - problem, - objective, - q_init, - max_iter, - objargs...; - kwargs... + return optimize( + Random.default_rng(), problem, objective, q_init, max_iter, objargs...; kwargs... ) end diff --git a/src/utils.jl b/src/utils.jl index 3ae59a78..c504513d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,12 +1,10 @@ function pm_next!(pm, stats::NamedTuple) - ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) + return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) end function maybe_init_optimizer( - state_init::NamedTuple, - optimizer ::Optimisers.AbstractRule, - params + state_init::NamedTuple, optimizer::Optimisers.AbstractRule, params ) if haskey(state_init, :optimizer) state_init.optimizer @@ -17,11 +15,11 @@ end function maybe_init_objective( state_init::NamedTuple, - rng ::Random.AbstractRNG, - objective ::AbstractVariationalObjective, + rng::Random.AbstractRNG, + objective::AbstractVariationalObjective, problem, params, - restructure + restructure, ) if haskey(state_init, :objective) state_init.objective @@ -33,11 +31,9 @@ end eachsample(samples::AbstractMatrix) = eachcol(samples) function catsamples_and_acc( - state_curr::Tuple{<:AbstractArray, <:Real}, - state_new ::Tuple{<:AbstractVector, <:Real} + state_curr::Tuple{<:AbstractArray,<:Real}, state_new::Tuple{<:AbstractVector,<:Real} ) - x = hcat(first(state_curr), first(state_new)) + x = hcat(first(state_curr), first(state_new)) ∑y = last(state_curr) + last(state_new) return (x, ∑y) end - diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index c139facc..e743f437 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -1,36 +1,35 @@ @testset "inference RepGradELBO DistributionsAD" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype ∈ [Float64, Float32], - (modelname, modelconstr) ∈ Dict( - :Normal => normal_meanfield, - ), + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in Dict(:Normal => normal_meanfield), n_montecarlo in [1, 10], (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), - :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()), + :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), + :RepGradELBOStickingTheLanding => + RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) ∈ Dict( - :ForwarDiff => AutoForwardDiff(), + (adbackname, adtype) in Dict( + :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), ) seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats - T = 1000 - η = 1e-3 + T = 1000 + η = 1e-3 opt = Optimisers.Descent(realtype(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), # where ρ = 1 - ημ, μ is the strong convexity constant. - contraction_rate = 1 - η*strong_convexity + contraction_rate = 1 - η * strong_convexity μ0 = zeros(realtype, n_dims) L0 = Diagonal(ones(realtype, n_dims)) @@ -39,17 +38,21 @@ @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) q, stats, _ = optimize( - rng, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = mean(q) - L = sqrt(cov(q)) + μ = mean(q) + L = sqrt(cov(q)) Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ contraction_rate^(T/2)*Δλ0 + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -57,20 +60,28 @@ @testset "determinism" begin rng = StableRNG(seed) q, stats, _ = optimize( - rng, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = mean(q) - L = sqrt(cov(q)) + μ = mean(q) + L = sqrt(cov(q)) rng_repl = StableRNG(seed) q, stats, _ = optimize( - rng_repl, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng_repl, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) μ_repl = mean(q) L_repl = sqrt(cov(q)) @@ -79,4 +90,3 @@ end end end - diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index cb255226..2fcaa421 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -1,59 +1,62 @@ @testset "inference RepGradELBO VILocationScale" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype in [Float64, Float32], - (modelname, modelconstr) in Dict( - :Normal=> normal_meanfield, - :Normal=> normal_fullrank, - ), + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in + Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), n_montecarlo in [1, 10], (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), - :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()), + :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), + :RepGradELBOStickingTheLanding => + RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), (adbackname, adtype) in Dict( - :ForwarDiff => AutoForwardDiff(), + :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), ) seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats - T = 1000 - η = 1e-3 + T = 1000 + η = 1e-3 opt = Optimisers.Descent(realtype(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), # where ρ = 1 - ημ, μ is the strong convexity constant. - contraction_rate = 1 - η*strong_convexity + contraction_rate = 1 - η * strong_convexity q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) else - L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) FullRankGaussian(zeros(realtype, n_dims), L0) end @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) q, stats, _ = optimize( - rng, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = q.location - L = q.scale + μ = q.location + L = q.scale Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ contraction_rate^(T/2)*Δλ0 + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -61,20 +64,28 @@ @testset "determinism" begin rng = StableRNG(seed) q, stats, _ = optimize( - rng, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = q.location - L = q.scale + μ = q.location + L = q.scale rng_repl = StableRNG(seed) q, stats, _ = optimize( - rng_repl, model, objective, q0, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng_repl, + model, + objective, + q0, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) μ_repl = q.location L_repl = q.scale @@ -83,4 +94,3 @@ end end end - diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index d56af333..83bc858f 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -1,41 +1,41 @@ @testset "inference RepGradELBO VILocationScale Bijectors" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype in [Float64, Float32], - (modelname, modelconstr) in Dict( - :NormalLogNormalMeanField => normallognormal_meanfield, - ), + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in + Dict(:NormalLogNormalMeanField => normallognormal_meanfield), n_montecarlo in [1, 10], (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), - :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()), + :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo), + :RepGradELBOStickingTheLanding => + RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), (adbackname, adtype) in Dict( - :ForwarDiff => AutoForwardDiff(), + :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Enzyme => AutoEnzyme(), ) seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = modelconstr(rng, realtype) @unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats - T = 1000 - η = 1e-3 + T = 1000 + η = 1e-3 opt = Optimisers.Descent(realtype(η)) - b = Bijectors.bijector(model) - b⁻¹ = inverse(b) - μ0 = Zeros(realtype, n_dims) - L0 = Diagonal(Ones(realtype, n_dims)) + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + μ0 = Zeros(realtype, n_dims) + L0 = Diagonal(Ones(realtype, n_dims)) q0_η = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) else - L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) FullRankGaussian(zeros(realtype, n_dims), L0) end q0_z = Bijectors.transformed(q0_η, b⁻¹) @@ -43,22 +43,26 @@ # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), # where ρ = 1 - ημ, μ is the strong convexity constant. - contraction_rate = 1 - η*strong_convexity + contraction_rate = 1 - η * strong_convexity @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) q, stats, _ = optimize( - rng, model, objective, q0_z, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0_z, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = q.dist.location - L = q.dist.scale + μ = q.dist.location + L = q.dist.scale Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ contraction_rate^(T/2)*Δλ0 + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -66,20 +70,28 @@ @testset "determinism" begin rng = StableRNG(seed) q, stats, _ = optimize( - rng, model, objective, q0_z, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng, + model, + objective, + q0_z, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) - μ = q.dist.location - L = q.dist.scale + μ = q.dist.location + L = q.dist.scale rng_repl = StableRNG(seed) q, stats, _ = optimize( - rng_repl, model, objective, q0_z, T; - optimizer = opt, - show_progress = PROGRESS, - adtype = adtype, + rng_repl, + model, + objective, + q0_z, + T; + optimizer=opt, + show_progress=PROGRESS, + adtype=adtype, ) μ_repl = q.dist.location L_repl = q.dist.scale diff --git a/test/interface/ad.jl b/test/interface/ad.jl index faa1f01c..380c2b9b 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -2,40 +2,40 @@ using Test @testset "ad" begin - @testset "$(adname)" for (adname, adsymbol) ∈ Dict( - :ForwardDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), - ) + @testset "$(adname)" for (adname, adsymbol) in Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), + ) D = 10 A = randn(D, D) λ = randn(D) grad_buf = DiffResults.GradientResult(λ) - f(λ′) = λ′'*A*λ′ / 2 + f(λ′) = λ′' * A * λ′ / 2 AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) - @test ∇ ≈ (A + A')*λ/2 - @test f ≈ λ'*A*λ / 2 + @test ∇ ≈ (A + A') * λ / 2 + @test f ≈ λ' * A * λ / 2 end - @testset "$(adname) with auxiliary input" for (adname, adsymbol) ∈ Dict( - :ForwardDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), - ) + @testset "$(adname) with auxiliary input" for (adname, adsymbol) in Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), + ) D = 10 A = randn(D, D) λ = randn(D) b = randn(D) grad_buf = DiffResults.GradientResult(λ) - f(λ′, aux) = λ′'*A*λ′ / 2 + dot(aux.b, λ′) + f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) AdvancedVI.value_and_gradient!(adsymbol, f, λ, (b=b,), grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) - @test ∇ ≈ (A + A')*λ/2 + b - @test f ≈ λ'*A*λ / 2 + dot(b, λ) + @test ∇ ≈ (A + A') * λ / 2 + b + @test f ≈ λ' * A * λ / 2 + dot(b, λ) end end diff --git a/test/interface/location_scale.jl b/test/interface/location_scale.jl index 7a129018..dcc3369d 100644 --- a/test/interface/location_scale.jl +++ b/test/interface/location_scale.jl @@ -1,22 +1,21 @@ @testset "interface LocationScale" begin - @testset "$(string(covtype)) $(basedist) $(realtype)" for - basedist = [:gaussian], - covtype = [:meanfield, :fullrank], - realtype = [Float32, Float64] + @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian], + covtype in [:meanfield, :fullrank], + realtype in [Float32, Float64] - n_dims = 10 + n_dims = 10 n_montecarlo = 1000_000 μ = randn(realtype, n_dims) L = if covtype == :fullrank - tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular + LowerTriangular(tril(I + ones(realtype, n_dims, n_dims) / 2)) else Diagonal(ones(realtype, n_dims)) end - Σ = L*L' + Σ = L * L' - q = if covtype == :fullrank && basedist == :gaussian + q = if covtype == :fullrank && basedist == :gaussian FullRankGaussian(μ, L) elseif covtype == :meanfield && basedist == :gaussian MeanFieldGaussian(μ, L) @@ -31,13 +30,13 @@ @testset "logpdf" begin z = rand(q) - @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) - @test eltype(logpdf(q, z)) == realtype + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol = realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype end @testset "entropy" begin @test eltype(entropy(q)) == realtype - @test entropy(q) ≈ entropy(q_true) + @test entropy(q) ≈ entropy(q_true) end @testset "length" begin @@ -46,37 +45,41 @@ @testset "statistics" begin @testset "mean" begin - @test eltype(mean(q)) == realtype - @test mean(q) == μ + @test eltype(mean(q)) == realtype + @test mean(q) == μ end @testset "var" begin - @test eltype(var(q)) == realtype - @test var(q) ≈ Diagonal(Σ) + @test eltype(var(q)) == realtype + @test var(q) ≈ Diagonal(Σ) end @testset "cov" begin - @test eltype(cov(q)) == realtype - @test cov(q) ≈ Σ + @test eltype(cov(q)) == realtype + @test cov(q) ≈ Σ end end @testset "sampling" begin @testset "rand" begin - z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) + z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) z_sample_ref = rand(StableRNG(1), q) @test z_sample_ref == rand(StableRNG(1), q) end @testset "rand batch" begin - z_samples = rand(q, n_montecarlo) + z_samples = rand(q, n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) samples_ref = rand(StableRNG(1), q, n_montecarlo) @test samples_ref == rand(StableRNG(1), q, n_montecarlo) @@ -84,16 +87,18 @@ @testset "rand! AbstractVector" begin res = map(1:n_montecarlo) do _ - z_sample = Array{realtype}(undef, n_dims) + z_sample = Array{realtype}(undef, n_dims) z_sample_ret = rand!(q, z_sample) (z_sample, z_sample_ret) end - z_samples = mapreduce(first, hcat, res) + z_samples = mapreduce(first, hcat, res) z_samples_ret = mapreduce(last, hcat, res) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) z_sample_ref = Array{realtype}(undef, n_dims) rand!(StableRNG(1), q, z_sample_ref) @@ -104,12 +109,14 @@ end @testset "rand! AbstractMatrix" begin - z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) z_samples_ret = rand!(q, z_samples) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) rand!(StableRNG(1), q, z_samples_ref) @@ -123,44 +130,44 @@ @testset "Diagonal destructure" begin n_dims = 10 - μ = zeros(n_dims) - L = ones(n_dims) - q = MeanFieldGaussian(μ, L |> Diagonal) - λ, re = Optimisers.destructure(q) + μ = zeros(n_dims) + L = ones(n_dims) + q = MeanFieldGaussian(μ, Diagonal(L)) + λ, re = Optimisers.destructure(q) - @test length(λ) == 2*n_dims - @test q == re(λ) + @test length(λ) == 2 * n_dims + @test q == re(λ) end end @testset "scale positive definite projection" begin - @testset "$(string(covtype)) $(realtype) $(bijector)" for - covtype = [:meanfield, :fullrank], - realtype = [Float32, Float64], - bijector = [nothing, :identity] + @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in + [:meanfield, :fullrank], + realtype in [Float32, Float64], + bijector in [nothing, :identity] d = 5 μ = zeros(realtype, d) ϵ = sqrt(realtype(0.5)) q = if covtype == :fullrank - L = LowerTriangular(Matrix{realtype}(I,d,d)) + L = LowerTriangular(Matrix{realtype}(I, d, d)) FullRankGaussian(μ, L; scale_eps=ϵ) elseif covtype == :meanfield L = Diagonal(ones(realtype, d)) MeanFieldGaussian(μ, L; scale_eps=ϵ) end - q_trans = if isnothing(bijector) + q_trans = if isnothing(bijector) q else Bijectors.TransformedDistribution(q, identity) end g = deepcopy(q) - λ, re = Optimisers.destructure(q) + λ, re = Optimisers.destructure(q) grad, _ = Optimisers.destructure(g) - opt_st = Optimisers.setup(Descent(one(realtype)), λ) - _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) - q′ = re(λ′) + opt_st = Optimisers.setup(Descent(one(realtype)), λ) + _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + q′ = re(λ′) @test all(diag(var(q′)) .≥ ϵ^2) end end diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index ea60b764..eb006c98 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -3,7 +3,7 @@ using Test @testset "interface optimize" begin seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) T = 1000 modelstats = normal_meanfield(rng, Float64) @@ -11,64 +11,54 @@ using Test @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats # Global Test Configurations - q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) obj = RepGradELBO(10) - adtype = AutoForwardDiff() + adtype = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) - rng = StableRNG(seed) + rng = StableRNG(seed) q_ref, stats_ref, _ = optimize( - rng, model, obj, q0, T; - optimizer, - show_progress = false, - adtype, + rng, model, obj, q0, T; optimizer, show_progress=false, adtype ) @testset "default_rng" begin - optimize( - model, obj, q0, T; - optimizer, - show_progress = false, - adtype, - ) + optimize(model, obj, q0, T; optimizer, show_progress=false, adtype) end @testset "callback" begin - rng = StableRNG(seed) + rng = StableRNG(seed) test_values = rand(rng, T) - callback(; stat, args...) = (test_value = test_values[stat.iteration],) + callback(; stat, args...) = (test_value=test_values[stat.iteration],) - rng = StableRNG(seed) + rng = StableRNG(seed) _, stats, _ = optimize( - rng, model, obj, q0, T; - show_progress = false, - adtype, - callback + rng, model, obj, q0, T; show_progress=false, adtype, callback ) - @test [stat.test_value for stat ∈ stats] == test_values + @test [stat.test_value for stat in stats] == test_values end @testset "warm start" begin - rng = StableRNG(seed) + rng = StableRNG(seed) - T_first = div(T,2) - T_last = T - T_first + T_first = div(T, 2) + T_last = T - T_first q_first, _, state = optimize( - rng, model, obj, q0, T_first; - optimizer, - show_progress = false, - adtype + rng, model, obj, q0, T_first; optimizer, show_progress=false, adtype ) q, stats, _ = optimize( - rng, model, obj, q_first, T_last; + rng, + model, + obj, + q_first, + T_last; optimizer, - show_progress = false, - state_init = state, - adtype + show_progress=false, + state_init=state, + adtype, ) @test q == q_ref end diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 579aba78..5fec46ff 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -3,7 +3,7 @@ using Test @testset "interface RepGradELBO" begin seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = normal_meanfield(rng, Float64) @@ -11,25 +11,25 @@ using Test q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) - obj = RepGradELBO(10) - rng = StableRNG(seed) + obj = RepGradELBO(10) + rng = StableRNG(seed) elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4) @testset "determinism" begin - rng = StableRNG(seed) + rng = StableRNG(seed) elbo = estimate_objective(rng, obj, q0, model; n_samples=10^4) @test elbo == elbo_ref end @testset "default_rng" begin elbo = estimate_objective(obj, q0, model; n_samples=10^4) - @test elbo ≈ elbo_ref rtol=0.1 + @test elbo ≈ elbo_ref rtol = 0.1 end end @testset "interface RepGradELBO STL variance reduction" begin seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) + rng = StableRNG(seed) modelstats = normal_meanfield(rng, Float64) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats @@ -38,11 +38,10 @@ end ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote(), - ADTypes.AutoEnzyme() + ADTypes.AutoEnzyme(), ] q_true = MeanFieldGaussian( - Vector{eltype(μ_true)}(μ_true), - Diagonal(Vector{eltype(L_true)}(diag(L_true))) + Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) params, re = Optimisers.destructure(q_true) obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) @@ -53,6 +52,6 @@ end ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out ) grad = DiffResults.gradient(out) - @test norm(grad) ≈ 0 atol=1e-5 + @test norm(grad) ≈ 0 atol = 1e-5 end end diff --git a/test/models/normal.jl b/test/models/normal.jl index 59cf0043..9fc6ae38 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -6,40 +6,40 @@ end function LogDensityProblems.logdensity(model::TestNormal, θ) @unpack μ, Σ = model - logpdf(MvNormal(μ, Σ), θ) + return logpdf(MvNormal(μ, Σ), θ) end function LogDensityProblems.dimension(model::TestNormal) - length(model.μ) + return length(model.μ) end function LogDensityProblems.capabilities(::Type{<:TestNormal}) - LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{0}() end function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 σ0 = realtype(0.3) - μ = Fill(realtype(5), n_dims) - L = Matrix(σ0*I, n_dims, n_dims) - Σ = L*L' |> Hermitian + μ = Fill(realtype(5), n_dims) + L = Matrix(σ0 * I, n_dims, n_dims) + Σ = Hermitian(L * L') model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0))) - TestModel(model, μ, LowerTriangular(L), n_dims, 1/σ0^2, false) + return TestModel(model, μ, LowerTriangular(L), n_dims, 1 / σ0^2, false) end function normal_meanfield(rng::Random.AbstractRNG, realtype::Type) n_dims = 5 σ0 = realtype(0.3) - μ = Fill(realtype(5), n_dims) - σ = Fill(σ0, n_dims) + μ = Fill(realtype(5), n_dims) + σ = Fill(σ0, n_dims) - model = TestNormal(μ, Diagonal(σ.^2)) + model = TestNormal(μ, Diagonal(σ .^ 2)) - L = σ |> Diagonal + L = Diagonal(σ) - TestModel(model, μ, L, n_dims, 1/σ0^2, true) + return TestModel(model, μ, L, n_dims, 1 / σ0^2, true) end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 54adcd48..176aab2f 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -1,54 +1,55 @@ -struct NormalLogNormal{MX,SX,MY,SY} - μ_x::MX - σ_x::SX - μ_y::MY - Σ_y::SY -end - -function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) -end - -function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 -end - -function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() -end - -function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model - Bijectors.Stacked( - Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) -end - -function normallognormal_fullrank(::Random.AbstractRNG, realtype::Type) - n_y_dims = 5 +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + return length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + return LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + return Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:(1 + length(μ_y))], + ) +end + +function normallognormal_fullrank(::Random.AbstractRNG, realtype::Type) + n_y_dims = 5 σ0 = realtype(0.3) - μ = Fill(realtype(5.0), n_y_dims+1) - L = Matrix(σ0*I, n_y_dims+1, n_y_dims+1) - Σ = L*L' |> Hermitian + μ = Fill(realtype(5.0), n_y_dims + 1) + L = Matrix(σ0 * I, n_y_dims + 1, n_y_dims + 1) + Σ = Hermitian(L * L') model = NormalLogNormal( - μ[1], L[1,1], μ[2:end], PDMat(Σ[2:end,2:end], Cholesky(L[2:end,2:end], 'L', 0)) + μ[1], L[1, 1], μ[2:end], PDMat(Σ[2:end, 2:end], Cholesky(L[2:end, 2:end], 'L', 0)) ) - TestModel(model, μ, LowerTriangular(L), n_y_dims+1, 1/σ0^2, false) -end + return TestModel(model, μ, LowerTriangular(L), n_y_dims + 1, 1 / σ0^2, false) +end -function normallognormal_meanfield(::Random.AbstractRNG, realtype::Type) - n_y_dims = 5 +function normallognormal_meanfield(::Random.AbstractRNG, realtype::Type) + n_y_dims = 5 σ0 = realtype(0.3) - μ = Fill(realtype(5), n_y_dims + 1) - σ = Fill(σ0, n_y_dims + 1) - L = Diagonal(σ) + μ = Fill(realtype(5), n_y_dims + 1) + σ = Fill(σ0, n_y_dims + 1) + L = Diagonal(σ) - model = NormalLogNormal(μ[1], σ[1], μ[2:end], Diagonal(σ[2:end].^2)) + model = NormalLogNormal(μ[1], σ[1], μ[2:end], Diagonal(σ[2:end] .^ 2)) - TestModel(model, μ, L, n_y_dims+1, 1/σ0^2, true) -end + return TestModel(model, μ, L, n_y_dims + 1, 1 / σ0^2, true) +end diff --git a/test/runtests.jl b/test/runtests.jl index 29850b65..a827dbf9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,6 @@ if GROUP == "All" || GROUP == "Interface" include("interface/location_scale.jl") end - const PROGRESS = haskey(ENV, "PROGRESS") if GROUP == "All" || GROUP == "Inference"