@@ -46,24 +46,18 @@ function Base.show(io::IO, obj::RepGradELBO)
46
46
print (io, " )" )
47
47
end
48
48
49
- function estimate_entropy_maybe_stl (entropy_estimator:: AbstractEntropyEstimator , samples, q, q_stop)
50
- q_maybe_stop = maybe_stop_entropy_score (entropy_estimator, q, q_stop)
51
- estimate_entropy (entropy_estimator, samples, q_maybe_stop)
52
- end
53
-
54
49
function estimate_energy_with_samples (prob, samples)
55
50
mean (Base. Fix1 (LogDensityProblems. logdensity, prob), eachsample (samples))
56
51
end
57
52
58
53
"""
59
- reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)
54
+ reparam_with_entropy(rng, q, n_samples, ent_est)
60
55
61
56
Draw `n_samples` from `q` and compute its entropy.
62
57
63
58
# Arguments
64
59
- `rng::Random.AbstractRNG`: Random number generator.
65
60
- `q`: Variational approximation.
66
- - `q_stop`: `q` but with its gradient stopped.
67
61
- `n_samples::Int`: Number of Monte Carlo samples
68
62
- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
69
63
@@ -72,10 +66,10 @@ Draw `n_samples` from `q` and compute its entropy.
72
66
- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
73
67
"""
74
68
function reparam_with_entropy (
75
- rng:: Random.AbstractRNG , q, q_stop, n_samples:: Int , ent_est:: AbstractEntropyEstimator
69
+ rng:: Random.AbstractRNG , q, n_samples:: Int , ent_est:: AbstractEntropyEstimator
76
70
)
77
71
samples = rand (rng, q, n_samples)
78
- entropy = estimate_entropy_maybe_stl (ent_est, samples, q, q_stop )
72
+ entropy = estimate_entropy (ent_est, samples, q)
79
73
samples, entropy
80
74
end
81
75
@@ -86,14 +80,39 @@ function estimate_objective(
86
80
prob;
87
81
n_samples:: Int = obj. n_samples
88
82
)
89
- samples, entropy = reparam_with_entropy (rng, q, q, n_samples, obj. entropy)
83
+ samples, entropy = reparam_with_entropy (rng, q, n_samples, obj. entropy)
90
84
energy = estimate_energy_with_samples (prob, samples)
91
85
energy + entropy
92
86
end
93
87
94
88
estimate_objective (obj:: RepGradELBO , q, prob; n_samples:: Int = obj. n_samples) =
95
89
estimate_objective (Random. default_rng (), obj, q, prob; n_samples)
96
90
91
+ function init (rng:: Random.AbstractRNG , obj:: RepGradELBO , problem, params, restructure)
92
+ function obj_integrand (params′)
93
+ q = restructure (params′)
94
+ samples, entropy = reparam_with_entropy (rng, q, obj. n_samples, obj. entropy)
95
+ energy = estimate_energy_with_samples (problem, samples)
96
+ elbo = energy + entropy
97
+ - elbo
98
+ end
99
+ obj_integrand
100
+ end
101
+
102
+ function estimate_objective_restructure (
103
+ rng:: Random.AbstractRNG ,
104
+ obj:: RepGradELBO ,
105
+ params,
106
+ restructure,
107
+ prob;
108
+ n_samples:: Int = obj. n_samples
109
+ )
110
+ q = restructure (params)
111
+ samples, entropy = reparam_with_entropy (rng, q, n_samples, obj. entropy)
112
+ energy = estimate_energy_with_samples (prob, samples)
113
+ energy + entropy
114
+ end
115
+
97
116
function estimate_gradient! (
98
117
rng :: Random.AbstractRNG ,
99
118
obj :: RepGradELBO ,
@@ -104,18 +123,11 @@ function estimate_gradient!(
104
123
restructure,
105
124
state,
106
125
)
107
- q_stop = restructure (λ)
108
- function f (λ′)
109
- q = restructure (λ′)
110
- samples, entropy = reparam_with_entropy (rng, q, q_stop, obj. n_samples, obj. entropy)
111
- energy = estimate_energy_with_samples (prob, samples)
112
- elbo = energy + entropy
113
- - elbo
114
- end
115
- value_and_gradient! (adtype, f, λ, out)
126
+ obj_integrand = state
127
+ value_and_gradient! (adtype, obj_integrand, λ, out)
116
128
117
129
nelbo = DiffResults. value (out)
118
130
stat = (elbo= - nelbo,)
119
131
120
- out, nothing , stat
132
+ out, state , stat
121
133
end
0 commit comments