From e9b960924d614d54f1b8b6009c2bf5c3fc325ee4 Mon Sep 17 00:00:00 2001
From: Ray Kim
Date: Thu, 13 Jun 2024 22:39:21 +0100
Subject: [PATCH] improve `RepGradELBO` to not redefine integrand

---
Review notes (text in this section is ignored by `git am`):

* `init(rng, obj::RepGradELBO, problem, params, restructure)` now builds the
  negative-ELBO closure `obj_integrand` once and returns it as the objective
  state. `estimate_gradient!` differentiates `state` and hands the same
  `state` back instead of `nothing`.
* The closure captures the `rng`, `obj`, `problem` and `restructure` that were
  passed to `init`. In the body shown in the last hunk, `estimate_gradient!`
  no longer reads its own `rng`, `obj`, `prob` or `restructure` arguments.
  NOTE(review): confirm callers never pass a different RNG or problem per step.
* `stop_gradient(x) = x` and `estimate_objective_restructure` are introduced
  here but have no call site in this patch.
* This patch removes `estimate_entropy_maybe_stl`, the only visible caller of
  `maybe_stop_entropy_score`. Presumably entropy estimators that need a
  stopped score are meant to use `stop_gradient` instead -- TODO confirm.
* NOTE(review): this copy was re-flowed from a whitespace-collapsed source.
  Indentation follows the 4-space Julia convention and runs of interior
  spaces could not be recovered, so context lines may differ from the tree
  in horizontal whitespace (consider `git apply --ignore-whitespace`).

 src/AdvancedVI.jl                  |  7 ++--
 src/objectives/elbo/repgradelbo.jl | 52 ++++++++++++++++++------------
 2 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 08b1b71d..632bb182 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -37,6 +37,9 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di
 """
 function value_and_gradient! end
 
+stop_gradient(x) = x
+
+
 # Update for gradient descent step
 """
     update_variational_params!(family_type, opt_st, params, restructure, grad)
@@ -92,9 +95,9 @@ This function needs to be implemented only if `obj` is stateful.
 init(
     ::Random.AbstractRNG,
     ::AbstractVariationalObjective,
-    ::Any
     ::Any,
-    ::Any
+    ::Any,
+    ::Any,
 ) = nothing
 
 """
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index 2d95d076..d860a600 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -46,24 +46,18 @@ function Base.show(io::IO, obj::RepGradELBO)
     print(io, ")")
 end
 
-function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop)
-    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
-    estimate_entropy(entropy_estimator, samples, q_maybe_stop)
-end
-
 function estimate_energy_with_samples(prob, samples)
     mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
 end
 
 """
-    reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)
+    reparam_with_entropy(rng, q, n_samples, ent_est)
 
 Draw `n_samples` from `q` and compute its entropy.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
 - `q`: Variational approximation.
-- `q_stop`: `q` but with its gradient stopped.
 - `n_samples::Int`: Number of Monte Carlo samples
 - `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
 
@@ -72,10 +66,10 @@ Draw `n_samples` from `q` and compute its entropy.
 - `entropy`: An estimate (or exact value) of the differential entropy of `q`.
 """
 function reparam_with_entropy(
-    rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
+    rng::Random.AbstractRNG, q, n_samples::Int, ent_est::AbstractEntropyEstimator
 )
     samples = rand(rng, q, n_samples)
-    entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
+    entropy = estimate_entropy(ent_est, samples, q)
     samples, entropy
 end
 
@@ -86,7 +80,7 @@ function estimate_objective(
     prob;
     n_samples::Int = obj.n_samples
 )
-    samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy)
+    samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy)
     energy = estimate_energy_with_samples(prob, samples)
     energy + entropy
 end
@@ -94,6 +88,31 @@ end
 estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
     estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
 
+function init(rng::Random.AbstractRNG, obj::RepGradELBO, problem, params, restructure)
+    function obj_integrand(params′)
+        q = restructure(params′)
+        samples, entropy = reparam_with_entropy(rng, q, obj.n_samples, obj.entropy)
+        energy = estimate_energy_with_samples(problem, samples)
+        elbo = energy + entropy
+        -elbo
+    end
+    obj_integrand
+end
+
+function estimate_objective_restructure(
+    rng::Random.AbstractRNG,
+    obj::RepGradELBO,
+    params,
+    restructure,
+    prob;
+    n_samples::Int = obj.n_samples
+)
+    q = restructure(params)
+    samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy)
+    energy = estimate_energy_with_samples(prob, samples)
+    energy + entropy
+end
+
 function estimate_gradient!(
     rng ::Random.AbstractRNG,
     obj ::RepGradELBO,
@@ -104,18 +123,11 @@ function estimate_gradient!(
     restructure,
     state,
 )
-    q_stop = restructure(λ)
-    function f(λ′)
-        q = restructure(λ′)
-        samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
-        energy = estimate_energy_with_samples(prob, samples)
-        elbo = energy + entropy
-        -elbo
-    end
-    value_and_gradient!(adtype, f, λ, out)
+    obj_integrand = state
+    value_and_gradient!(adtype, obj_integrand, λ, out)
 
     nelbo = DiffResults.value(out)
     stat = (elbo=-nelbo,)
 
-    out, nothing, stat
+    out, state, stat
 end