Skip to content

Commit e9b9609

Browse files
committed
improve RepGradELBO to not redefine integrand
1 parent 0632e26 commit e9b9609

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

src/AdvancedVI.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di
3737
"""
3838
function value_and_gradient! end
3939

40+
stop_gradient(x) = x
41+
42+
4043
# Update for gradient descent step
4144
"""
4245
update_variational_params!(family_type, opt_st, params, restructure, grad)
@@ -92,9 +95,9 @@ This function needs to be implemented only if `obj` is stateful.
9295
init(
9396
::Random.AbstractRNG,
9497
::AbstractVariationalObjective,
95-
::Any
9698
::Any,
97-
::Any
99+
::Any,
100+
::Any,
98101
) = nothing
99102

100103
"""

src/objectives/elbo/repgradelbo.jl

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,18 @@ function Base.show(io::IO, obj::RepGradELBO)
4646
print(io, ")")
4747
end
4848

49-
function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop)
50-
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
51-
estimate_entropy(entropy_estimator, samples, q_maybe_stop)
52-
end
53-
5449
function estimate_energy_with_samples(prob, samples)
5550
mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
5651
end
5752

5853
"""
59-
reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)
54+
reparam_with_entropy(rng, q, n_samples, ent_est)
6055
6156
Draw `n_samples` from `q` and compute its entropy.
6257
6358
# Arguments
6459
- `rng::Random.AbstractRNG`: Random number generator.
6560
- `q`: Variational approximation.
66-
- `q_stop`: `q` but with its gradient stopped.
6761
- `n_samples::Int`: Number of Monte Carlo samples
6862
- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
6963
@@ -72,10 +66,10 @@ Draw `n_samples` from `q` and compute its entropy.
7266
- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
7367
"""
7468
function reparam_with_entropy(
75-
rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
69+
rng::Random.AbstractRNG, q, n_samples::Int, ent_est::AbstractEntropyEstimator
7670
)
7771
samples = rand(rng, q, n_samples)
78-
entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
72+
entropy = estimate_entropy(ent_est, samples, q)
7973
samples, entropy
8074
end
8175

@@ -86,14 +80,39 @@ function estimate_objective(
8680
prob;
8781
n_samples::Int = obj.n_samples
8882
)
89-
samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy)
83+
samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy)
9084
energy = estimate_energy_with_samples(prob, samples)
9185
energy + entropy
9286
end
9387

9488
estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
9589
estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
9690

91+
function init(rng::Random.AbstractRNG, obj::RepGradELBO, problem, params, restructure)
92+
function obj_integrand(params′)
93+
q = restructure(params′)
94+
samples, entropy = reparam_with_entropy(rng, q, obj.n_samples, obj.entropy)
95+
energy = estimate_energy_with_samples(problem, samples)
96+
elbo = energy + entropy
97+
-elbo
98+
end
99+
obj_integrand
100+
end
101+
102+
function estimate_objective_restructure(
103+
rng::Random.AbstractRNG,
104+
obj::RepGradELBO,
105+
params,
106+
restructure,
107+
prob;
108+
n_samples::Int = obj.n_samples
109+
)
110+
q = restructure(params)
111+
samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy)
112+
energy = estimate_energy_with_samples(prob, samples)
113+
energy + entropy
114+
end
115+
97116
function estimate_gradient!(
98117
rng ::Random.AbstractRNG,
99118
obj ::RepGradELBO,
@@ -104,18 +123,11 @@ function estimate_gradient!(
104123
restructure,
105124
state,
106125
)
107-
q_stop = restructure(λ)
108-
function f(λ′)
109-
q = restructure(λ′)
110-
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
111-
energy = estimate_energy_with_samples(prob, samples)
112-
elbo = energy + entropy
113-
-elbo
114-
end
115-
value_and_gradient!(adtype, f, λ, out)
126+
obj_integrand = state
127+
value_and_gradient!(adtype, obj_integrand, λ, out)
116128

117129
nelbo = DiffResults.value(out)
118130
stat = (elbo=-nelbo,)
119131

120-
out, nothing, stat
132+
out, state, stat
121133
end

0 commit comments

Comments
 (0)