Skip to content

Commit

Permalink
Implement variance reduction for Reinforce
Browse files Browse the repository at this point in the history
  • Loading branch information
BatyLeo committed Jun 25, 2024
1 parent 652da15 commit cc87254
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 11 deletions.
17 changes: 14 additions & 3 deletions docs/src/background.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,21 @@ And the vector-Jacobian product:
\partial F(\theta)^\top v = \mathbb{E}_{p(\theta)} \left[(f(X)^\top v) \nabla_\theta \log p(X, \theta)\right]
```

### Variance reduction
### Variance reduction (k > 1)

!!! warning
Work in progress.
The Reinforce estimator has high variance, which can be reduced by using a baseline (see this [paper](https://openreview.net/pdf?id=r1lgTGL5DE)) as follows:
```math
\begin{aligned}
\partial F(\theta) &\simeq \frac{1}{k}\sum_{i=1}^k f(x_i) \nabla_\theta\log p(x_i, \theta)^\top\\
& \simeq \frac{1}{k}\sum_{i=1}^k \left(f(x_i) - \frac{1}{k - 1}\sum_{j\neq i} f(x_j)\right) \nabla_\theta\log p(x_i, \theta)^\top\\
& = \frac{1}{k - 1}\sum_{i=1}^k \left(f(x_i) - \frac{1}{k}\sum_{j=1}^k f(x_j)\right) \nabla_\theta\log p(x_i, \theta)^\top
\end{aligned}
```

This gives the following vector-Jacobian product:
```math
\partial F(\theta)^\top v \simeq \frac{1}{k - 1}\sum_{i=1}^k \left(\left(f(x_i) - \frac{1}{k}\sum_{j=1}^k f(x_j)\right)^\top v\right) \nabla_\theta\log p(x_i, \theta)^\top
```

## Reparametrization

Expand Down
33 changes: 25 additions & 8 deletions src/reinforce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ $(TYPEDFIELDS)
- [`DifferentiableExpectation`](@ref)
"""
struct Reinforce{threaded,F,D,G,R<:AbstractRNG,S<:Union{Int,Nothing}} <:
struct Reinforce{threaded,variance_reduction,F,D,G,R<:AbstractRNG,S<:Union{Int,Nothing}} <:
DifferentiableExpectation{threaded}
"function applied inside the expectation"
f::F
Expand All @@ -59,11 +59,13 @@ struct Reinforce{threaded,F,D,G,R<:AbstractRNG,S<:Union{Int,Nothing}} <:
seed::S
end

function Base.show(io::IO, rep::Reinforce{threaded}) where {threaded}
function Base.show(
io::IO, rep::Reinforce{threaded,variance_reduction}
) where {threaded,variance_reduction}
(; f, dist_constructor, dist_logdensity_grad, rng, nb_samples) = rep
return print(
io,
"Reinforce{$threaded}($f, $dist_constructor, $dist_logdensity_grad, $rng, $nb_samples)",
"Reinforce{$threaded,$variance_reduction}($f, $dist_constructor, $dist_logdensity_grad, $rng, $nb_samples)",
)
end

Expand All @@ -74,9 +76,10 @@ function Reinforce(
rng::R=default_rng(),
nb_samples=1,
threaded=false,
variance_reduction=true,
seed::S=nothing,
) where {F,D,G,R,S}
return Reinforce{threaded,F,D,G,R,S}(
return Reinforce{threaded,variance_reduction,F,D,G,R,S}(
f, dist_constructor, dist_logdensity_grad, rng, nb_samples, seed
)
end
Expand All @@ -96,8 +99,8 @@ function dist_logdensity_grad(
end

function ChainRulesCore.rrule(
rc::RuleConfig, F::Reinforce{threaded}, θ...; kwargs...
) where {threaded}
rc::RuleConfig, F::Reinforce{threaded,variance_reduction}, θ...; kwargs...
) where {threaded,variance_reduction}
project_θ = ProjectTo(θ)

(; nb_samples) = F
Expand All @@ -111,16 +114,30 @@ function ChainRulesCore.rrule(
map(_dist_logdensity_grad_partial, xs)
end

ys_with_baseline = if variance_reduction && nb_samples > 1
y_sum = threaded ? tmean(ys) : mean(ys)
map(ys) do yᵢ
yᵢ .- y_sum
end
else
ys
end
K = if variance_reduction && nb_samples > 1
nb_samples - 1
else
nb_samples
end

function pullback_Reinforce(dy_thunked)
dy = unthunk(dy_thunked)
dF = @not_implemented(
"The fields of the `Reinforce` object are considered constant."
)
_single_sample_pullback(g, y) = g .* dot(y, dy)
= if threaded
tmapreduce(_single_sample_pullback, .+, gs, ys) ./ nb_samples
tmapreduce(_single_sample_pullback, .+, gs, ys_with_baseline) ./ K
else
mapreduce(_single_sample_pullback, .+, gs, ys) ./ nb_samples
mapreduce(_single_sample_pullback, .+, gs, ys_with_baseline) ./ K
end
dθ_proj = project_θ(dθ)
return (dF, dθ_proj...)
Expand Down
28 changes: 28 additions & 0 deletions test/expectation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,31 @@ end;
end
end
end

@testset "Variance reduction" begin
    # Across several seeds, check that enabling the baseline variance reduction
    # produces a Jacobian estimate closer to the truth than the plain
    # Reinforce estimator with the same seed and sample budget.
    for seed in 1:10
        rng = StableRNG(seed)
        # Identity function under X ~ MvNormal(θ, I): E[f(X)] = θ,
        # so the true Jacobian of θ ↦ E[f(X)] is the identity matrix.
        f(x) = x
        dist_constructor(θ) = MvNormal(θ, I)
        n = 10
        θ = randn(rng, n)
        r = Reinforce(
            f, dist_constructor; rng=rng, nb_samples=100, seed=seed, variance_reduction=true
        )
        r_no_variance_reduction = Reinforce(
            f,
            dist_constructor;
            rng=rng,
            nb_samples=100,
            seed=seed,
            variance_reduction=false,
        )

        J_reduced_variance = jacobian(r, θ)[1]
        J_no_reduced_variance = jacobian(r_no_variance_reduction, θ)[1]
        J_true = Matrix(I, n, n)

        # Mean absolute error between an estimate and the reference.
        # (Renamed from `mape`: no division by the reference values is
        # performed, so this is an MAE, not a mean absolute *percentage* error.)
        mae(x::AbstractArray, y::AbstractArray) = mean(abs.(x .- y))
        @test mae(J_reduced_variance, J_true) < mae(J_no_reduced_variance, J_true)
    end
end

0 comments on commit cc87254

Please sign in to comment.