Skip to content

Commit

Permalink
improve RepGradELBO to not redefine integrand
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Jun 13, 2024
1 parent 0632e26 commit e9b9609
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
7 changes: 5 additions & 2 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di
"""
function value_and_gradient! end

stop_gradient(x) = x


# Update for gradient descent step
"""
update_variational_params!(family_type, opt_st, params, restructure, grad)
Expand Down Expand Up @@ -92,9 +95,9 @@ This function needs to be implemented only if `obj` is stateful.
init(
::Random.AbstractRNG,
::AbstractVariationalObjective,
::Any
::Any,
::Any
::Any,
::Any,
) = nothing

"""
Expand Down
52 changes: 32 additions & 20 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,18 @@ function Base.show(io::IO, obj::RepGradELBO)
print(io, ")")
end

function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end

function estimate_energy_with_samples(prob, samples)
mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
end

"""
reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)
reparam_with_entropy(rng, q, n_samples, ent_est)
Draw `n_samples` from `q` and compute its entropy.
# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `q`: Variational approximation.
- `q_stop`: `q` but with its gradient stopped.
- `n_samples::Int`: Number of Monte Carlo samples
- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
Expand All @@ -72,10 +66,10 @@ Draw `n_samples` from `q` and compute its entropy.
- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
"""
function reparam_with_entropy(
rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
rng::Random.AbstractRNG, q, n_samples::Int, ent_est::AbstractEntropyEstimator
)
samples = rand(rng, q, n_samples)
entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
entropy = estimate_entropy(ent_est, samples, q)
samples, entropy
end

Expand All @@ -86,14 +80,39 @@ function estimate_objective(
prob;
n_samples::Int = obj.n_samples
)
samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy)
samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy)
energy = estimate_energy_with_samples(prob, samples)
energy + entropy
end

estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
estimate_objective(Random.default_rng(), obj, q, prob; n_samples)

function init(rng::Random.AbstractRNG, obj::RepGradELBO, problem, params, restructure)
function obj_integrand(params′)
q = restructure(params′)
samples, entropy = reparam_with_entropy(rng, q, obj.n_samples, obj.entropy)
energy = estimate_energy_with_samples(problem, samples)
elbo = energy + entropy
-elbo
end
obj_integrand
end

function estimate_objective_restructure(
rng::Random.AbstractRNG,
obj::RepGradELBO,
params,
restructure,
prob;
n_samples::Int = obj.n_samples
)
q = restructure(params)
samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy)
energy = estimate_energy_with_samples(prob, samples)
energy + entropy
end

function estimate_gradient!(
rng ::Random.AbstractRNG,
obj ::RepGradELBO,
Expand All @@ -104,18 +123,11 @@ function estimate_gradient!(
restructure,
state,
)
q_stop = restructure(λ)
function f(λ′)
q = restructure(λ′)
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
energy = estimate_energy_with_samples(prob, samples)
elbo = energy + entropy
-elbo
end
value_and_gradient!(adtype, f, λ, out)
obj_integrand = state
value_and_gradient!(adtype, obj_integrand, λ, out)

nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)

out, nothing, stat
out, state, stat
end

0 comments on commit e9b9609

Please sign in to comment.