From 0632e26b4d2d00848d33a0cf034b4f770c2ff6f3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jun 2024 21:56:38 +0100 Subject: [PATCH 1/5] update interface for objective initialization --- src/AdvancedVI.jl | 7 ++++--- src/optimize.jl | 2 +- src/utils.jl | 11 +++++++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 7c7a1fc8..08b1b71d 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -78,7 +78,7 @@ If the estimator is stateful, it can implement `init` to initialize the state. abstract type AbstractVariationalObjective end """ - init(rng, obj, λ, restructure) + init(rng, obj, prob, params, restructure) Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. This function needs to be implemented only if `obj` is stateful. @@ -86,13 +86,14 @@ This function needs to be implemented only if `obj` is stateful. # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `obj::AbstractVariationalObjective`: Variational objective. -- `λ`: Initial variational parameters. +- `params`: Initial variational parameters. - `restructure`: Function that reconstructs the variational approximation from `λ`. 
""" init( ::Random.AbstractRNG, ::AbstractVariationalObjective, - ::AbstractVector, + ::Any + ::Any, ::Any ) = nothing diff --git a/src/optimize.jl b/src/optimize.jl index 4eb6644a..e5fe374d 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -66,7 +66,7 @@ function optimize( ) params, restructure = Optimisers.destructure(deepcopy(q_init)) opt_st = maybe_init_optimizer(state_init, optimizer, params) - obj_st = maybe_init_objective(state_init, rng, objective, params, restructure) + obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) stats = NamedTuple[] diff --git a/src/utils.jl b/src/utils.jl index 98b79b2d..a8039b22 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,19 +6,22 @@ end function maybe_init_optimizer( state_init::NamedTuple, optimizer ::Optimisers.AbstractRule, - params ::AbstractVector + params ) - haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, params) + haskey(state_init, :optimizer) ? + state_init.optimizer : Optimisers.setup(optimizer, params) end function maybe_init_objective( state_init::NamedTuple, rng ::Random.AbstractRNG, objective ::AbstractVariationalObjective, - params ::AbstractVector, + problem, + params, restructure ) - haskey(state_init, :objective) ? state_init.objective : init(rng, objective, params, restructure) + haskey(state_init, :objective) ? 
+ state_init.objective : init(rng, objective, params, problem, restructure) end eachsample(samples::AbstractMatrix) = eachcol(samples) From e9b960924d614d54f1b8b6009c2bf5c3fc325ee4 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jun 2024 22:39:21 +0100 Subject: [PATCH 2/5] improve `RepGradELBO` to not redefine integrand --- src/AdvancedVI.jl | 7 ++-- src/objectives/elbo/repgradelbo.jl | 52 ++++++++++++++++++------------ 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 08b1b71d..632bb182 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -37,6 +37,9 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di """ function value_and_gradient! end +stop_gradient(x) = x + + # Update for gradient descent step """ update_variational_params!(family_type, opt_st, params, restructure, grad) @@ -92,9 +95,9 @@ This function needs to be implemented only if `obj` is stateful. init( ::Random.AbstractRNG, ::AbstractVariationalObjective, - ::Any ::Any, - ::Any + ::Any, + ::Any, ) = nothing """ diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 2d95d076..d860a600 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -46,24 +46,18 @@ function Base.show(io::IO, obj::RepGradELBO) print(io, ")") end -function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) - q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) - estimate_entropy(entropy_estimator, samples, q_maybe_stop) -end - function estimate_energy_with_samples(prob, samples) mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) end """ - reparam_with_entropy(rng, q, q_stop, n_samples, ent_est) + reparam_with_entropy(rng, q, n_samples, ent_est) Draw `n_samples` from `q` and compute its entropy. # Arguments - `rng::Random.AbstractRNG`: Random number generator. 
- `q`: Variational approximation. -- `q_stop`: `q` but with its gradient stopped. - `n_samples::Int`: Number of Monte Carlo samples - `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.) @@ -72,10 +66,10 @@ Draw `n_samples` from `q` and compute its entropy. - `entropy`: An estimate (or exact value) of the differential entropy of `q`. """ function reparam_with_entropy( - rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator + rng::Random.AbstractRNG, q, n_samples::Int, ent_est::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) - entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) + entropy = estimate_entropy(ent_est, samples, q) samples, entropy end @@ -86,7 +80,7 @@ function estimate_objective( prob; n_samples::Int = obj.n_samples ) - samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy) + samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy) energy = estimate_energy_with_samples(prob, samples) energy + entropy end @@ -94,6 +88,31 @@ end estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = estimate_objective(Random.default_rng(), obj, q, prob; n_samples) +function init(rng::Random.AbstractRNG, obj::RepGradELBO, problem, params, restructure) + function obj_integrand(params′) + q = restructure(params′) + samples, entropy = reparam_with_entropy(rng, q, obj.n_samples, obj.entropy) + energy = estimate_energy_with_samples(problem, samples) + elbo = energy + entropy + -elbo + end + obj_integrand +end + +function estimate_objective_restructure( + rng::Random.AbstractRNG, + obj::RepGradELBO, + params, + restructure, + prob; + n_samples::Int = obj.n_samples +) + q = restructure(params) + samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy) + energy = estimate_energy_with_samples(prob, samples) + energy + entropy +end + function estimate_gradient!( rng ::Random.AbstractRNG, obj ::RepGradELBO, @@ -104,18 +123,11 @@ 
function estimate_gradient!( restructure, state, ) - q_stop = restructure(λ) - function f(λ′) - q = restructure(λ′) - samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) - energy = estimate_energy_with_samples(prob, samples) - elbo = energy + entropy - -elbo - end - value_and_gradient!(adtype, f, λ, out) + obj_integrand = state + value_and_gradient!(adtype, obj_integrand, λ, out) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - out, nothing, stat + out, state, stat end From bcbe11d251a762c3d1c360bb1c5e1dc158922115 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jun 2024 22:58:01 +0100 Subject: [PATCH 3/5] fix bug --- src/utils.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index a8039b22..3ae59a78 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,8 +8,11 @@ function maybe_init_optimizer( optimizer ::Optimisers.AbstractRule, params ) - haskey(state_init, :optimizer) ? - state_init.optimizer : Optimisers.setup(optimizer, params) + if haskey(state_init, :optimizer) + state_init.optimizer + else + Optimisers.setup(optimizer, params) + end end function maybe_init_objective( @@ -20,8 +23,11 @@ function maybe_init_objective( params, restructure ) - haskey(state_init, :objective) ? 
- state_init.objective : init(rng, objective, params, problem, restructure) + if haskey(state_init, :objective) + state_init.objective + else + init(rng, objective, problem, params, restructure) + end end eachsample(samples::AbstractMatrix) = eachcol(samples) From 00493693f22b44a5f01f472e7636ab0e03ea7c31 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jun 2024 23:10:33 +0100 Subject: [PATCH 4/5] fix Bijectors ext to match new `reparam_with_entropy` interface --- ext/AdvancedVIBijectorsExt.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 4a88d6fb..950b5f15 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -38,17 +38,15 @@ end function AdvancedVI.reparam_with_entropy( rng ::Random.AbstractRNG, q ::Bijectors.TransformedDistribution, - q_stop ::Bijectors.TransformedDistribution, n_samples::Int, ent_est ::AdvancedVI.AbstractEntropyEstimator ) - transform = q.transform - q_unconst = q.dist - q_unconst_stop = q_stop.dist + transform = q.transform + q_unconst = q.dist # Draw samples and compute entropy of the uncontrained distribution unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy( - rng, q_unconst, q_unconst_stop, n_samples, ent_est + rng, q_unconst, n_samples, ent_est ) # Apply bijector to samples while estimating its jacobian From ef3c312f004ecf5eb14aa82878aac88037b88bc9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jun 2024 01:04:20 +0100 Subject: [PATCH 5/5] revert removal of q_stop, add auxiliary arg to `value_and_gradient!` --- ext/AdvancedVIBijectorsExt.jl | 8 +++-- ext/AdvancedVIForwardDiffExt.jl | 23 +++++++++--- ext/AdvancedVIReverseDiffExt.jl | 19 ++++++++-- ext/AdvancedVIZygoteExt.jl | 23 +++++++++--- src/AdvancedVI.jl | 22 +++++++++--- src/objectives/elbo/repgradelbo.jl | 58 ++++++++++++++---------------- test/Project.toml | 3 +- test/interface/repgradelbo.jl | 29 +++++++++++++++ 8 files changed, 132 
insertions(+), 53 deletions(-) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 950b5f15..a227fdf2 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -38,15 +38,17 @@ end function AdvancedVI.reparam_with_entropy( rng ::Random.AbstractRNG, q ::Bijectors.TransformedDistribution, + q_stop ::Bijectors.TransformedDistribution, n_samples::Int, ent_est ::AdvancedVI.AbstractEntropyEstimator ) - transform = q.transform - q_unconst = q.dist + transform = q.transform + q_unconst = q.dist + q_unconst_stop = q_stop.dist # Draw samples and compute entropy of the uncontrained distribution unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy( - rng, q_unconst, n_samples, ent_est + rng, q_unconst, q_unconst_stop, n_samples, ent_est ) # Apply bijector to samples while estimating its jacobian diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index 5949bdf8..a8afd031 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -14,16 +14,29 @@ end getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} + ad ::ADTypes.AutoForwardDiff, + f, + x ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) chunk_size = getchunksize(ad) config = if isnothing(chunk_size) - ForwardDiff.GradientConfig(f, θ) + ForwardDiff.GradientConfig(f, x) else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size)) end - ForwardDiff.gradient!(out, f, θ, config) + ForwardDiff.gradient!(out, f, x, config) return out end +function AdvancedVI.value_and_gradient!( + ad ::ADTypes.AutoForwardDiff, + f, + x ::AbstractVector, + aux, + out::DiffResults.MutableDiffResult +) + AdvancedVI.value_and_gradient!(ad, x′ 
-> f(x′, aux), x, out) +end + end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index 520cd9ff..392f5cea 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -13,11 +13,24 @@ end # ReverseDiff without compiled tape function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult + ad::ADTypes.AutoReverseDiff, + f, + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - tp = ReverseDiff.GradientTape(f, θ) - ReverseDiff.gradient!(out, tp, θ) + tp = ReverseDiff.GradientTape(f, x) + ReverseDiff.gradient!(out, tp, x) return out end +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoReverseDiff, + f, + x::AbstractVector{<:Real}, + aux, + out::DiffResults.MutableDiffResult +) + AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) +end + end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index 7b8f8817..806c08e4 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -4,21 +4,36 @@ module AdvancedVIZygoteExt if isdefined(Base, :get_extension) using AdvancedVI using AdvancedVI: ADTypes, DiffResults + using ChainRulesCore using Zygote else using ..AdvancedVI using ..AdvancedVI: ADTypes, DiffResults + using ..ChainRulesCore using ..Zygote end function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult + ::ADTypes.AutoZygote, + f, + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - y, back = Zygote.pullback(f, θ) - ∇θ = back(one(y)) + y, back = Zygote.pullback(f, x) + ∇x = back(one(y)) DiffResults.value!(out, y) - DiffResults.gradient!(out, only(∇θ)) + DiffResults.gradient!(out, only(∇x)) return out end +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoZygote, + f, + x::AbstractVector{<:Real}, + aux, + out::DiffResults.MutableDiffResult +) + 
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) +end + end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 632bb182..7a09030b 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -25,20 +25,34 @@ using StatsBase # derivatives """ - value_and_gradient!(ad, f, θ, out) + value_and_gradient!(ad, f, x, out) + value_and_gradient!(ad, f, x, aux, out) -Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`. +Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`. +`f` may receive auxiliary input as `f(x,aux)`. # Arguments - `ad::ADTypes.AbstractADType`: Automatic differentiation backend. - `f`: Function subject to differentiation. -- `θ`: The point to evaluate the gradient. +- `x`: The point to evaluate the gradient. +- `aux`: Auxiliary input passed to `f`. - `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. """ function value_and_gradient! end -stop_gradient(x) = x +""" + stop_gradient(x) + +Stop the gradient from propagating to `x` if the selected ad backend supports it. +Otherwise, it is equivalent to `identity`. + +# Arguments +- `x`: Input +# Returns +- `x`: Same value as the input. 
+""" +function stop_gradient end # Update for gradient descent step """ diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index d860a600..27a937e8 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -46,6 +46,11 @@ function Base.show(io::IO, obj::RepGradELBO) print(io, ")") end +function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) + q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) + estimate_entropy(entropy_estimator, samples, q_maybe_stop) +end + function estimate_energy_with_samples(prob, samples) mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) end @@ -66,10 +71,14 @@ Draw `n_samples` from `q` and compute its entropy. - `entropy`: An estimate (or exact value) of the differential entropy of `q`. """ function reparam_with_entropy( - rng::Random.AbstractRNG, q, n_samples::Int, ent_est::AbstractEntropyEstimator + rng ::Random.AbstractRNG, + q, + q_stop, + n_samples::Int, + ent_est ::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) - entropy = estimate_entropy(ent_est, samples, q) + entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) samples, entropy end @@ -80,7 +89,7 @@ function estimate_objective( prob; n_samples::Int = obj.n_samples ) - samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy) + samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy) energy = estimate_energy_with_samples(prob, samples) energy + entropy end @@ -88,29 +97,13 @@ end estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = estimate_objective(Random.default_rng(), obj, q, prob; n_samples) -function init(rng::Random.AbstractRNG, obj::RepGradELBO, problem, params, restructure) - function obj_integrand(params′) - q = restructure(params′) - samples, entropy = reparam_with_entropy(rng, q, obj.n_samples, obj.entropy) - energy = 
estimate_energy_with_samples(problem, samples) - elbo = energy + entropy - -elbo - end - obj_integrand -end - -function estimate_objective_restructure( - rng::Random.AbstractRNG, - obj::RepGradELBO, - params, - restructure, - prob; - n_samples::Int = obj.n_samples -) - q = restructure(params) - samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy) - energy = estimate_energy_with_samples(prob, samples) - energy + entropy +function estimate_repgradelbo_ad_forward(params′, aux) + @unpack rng, obj, problem, restructure, q_stop = aux + q = restructure(params′) + samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) + energy = estimate_energy_with_samples(problem, samples) + elbo = energy + entropy + -elbo end function estimate_gradient!( @@ -119,15 +112,16 @@ function estimate_gradient!( adtype::ADTypes.AbstractADType, out ::DiffResults.MutableDiffResult, prob, - λ, + params, restructure, state, ) - obj_integrand = state - value_and_gradient!(adtype, obj_integrand, λ, out) - + q_stop = restructure(params) + aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop) + value_and_gradient!( + adtype, estimate_repgradelbo_ad_forward, params, aux, out + ) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - - out, state, stat + out, nothing, stat end diff --git a/test/Project.toml b/test/Project.toml index f3acedea..a0dba17f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,9 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -26,7 +26,6 @@ ADTypes = "0.2.1, 1" 
Bijectors = "0.13" Distributions = "0.25.100" DistributionsAD = "0.6.45" -Enzyme = "0.12" FillArrays = "1.6.1" ForwardDiff = "0.10.36" Functors = "0.4.5" diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 61ff0111..ac9bfeca 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -26,3 +26,32 @@ using Test @test elbo ≈ elbo_ref rtol=0.1 end end + +@testset "interface RepGradELBO STL variance reduction" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = normal_meanfield(rng, Float64) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + @testset for ad in [ + ADTypes.AutoForwardDiff(), + ADTypes.AutoReverseDiff(), + ADTypes.AutoZygote() + ] + q_true = MeanFieldGaussian( + Vector{eltype(μ_true)}(μ_true), + Diagonal(Vector{eltype(L_true)}(diag(L_true))) + ) + params, re = Optimisers.destructure(q_true) + obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) + out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + + aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) + AdvancedVI.value_and_gradient!( + ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out + ) + grad = DiffResults.gradient(out) + @test norm(grad) ≈ 0 atol=1e-5 + end +end