diff --git a/docs/src/background.md b/docs/src/background.md
index 5d12cc9..26c9fef 100644
--- a/docs/src/background.md
+++ b/docs/src/background.md
@@ -42,10 +42,21 @@ And the vector-Jacobian product:
 \partial F(\theta)^\top v = \mathbb{E}_{p(\theta)} \left[(f(X)^\top v) \nabla_\theta \log p(X, \theta)\right]
 ```
 
-### Variance reduction
+### Variance reduction (k > 1)
 
-!!! warning
-    Work in progress.
+The Reinforce estimator has high variance, which can be reduced by using a baseline (see this [paper](https://openreview.net/pdf?id=r1lgTGL5DE)) as follows:
+```math
+\begin{aligned}
+\partial F(\theta) &\simeq \frac{1}{k}\sum_{i=1}^k f(x_i) \nabla_\theta\log p(x_i, \theta)^\top\\
+& \simeq \frac{1}{k}\sum_{i=1}^k \left(f(x_i) - \frac{1}{k - 1}\sum_{j\neq i} f(x_j)\right) \nabla_\theta\log p(x_i, \theta)^\top\\
+& = \frac{1}{k - 1}\sum_{i=1}^k \left(f(x_i) - \frac{1}{k}\sum_{j=1}^k f(x_j)\right) \nabla_\theta\log p(x_i, \theta)^\top
+\end{aligned}
+```
+
+This gives the following vector-Jacobian product:
+```math
+\partial F(\theta)^\top v \simeq \frac{1}{k - 1}\sum_{i=1}^k \left(\left(f(x_i) - \frac{1}{k}\sum_{j=1}^k f(x_j)\right)^\top v\right) \nabla_\theta\log p(x_i, \theta)^\top
+```
 
 ## Reparametrization
 
diff --git a/src/reinforce.jl b/src/reinforce.jl
index e5f4c3b..ab995ce 100644
--- a/src/reinforce.jl
+++ b/src/reinforce.jl
@@ -43,7 +43,7 @@ $(TYPEDFIELDS)
 
 - [`DifferentiableExpectation`](@ref)
 """
-struct Reinforce{threaded,F,D,G,R<:AbstractRNG,S<:Union{Int,Nothing}} <:
+struct Reinforce{threaded,variance_reduction,F,D,G,R<:AbstractRNG,S<:Union{Int,Nothing}} <:
        DifferentiableExpectation{threaded}
     "function applied inside the expectation"
     f::F
@@ -59,11 +59,13 @@ struct Reinforce{threaded,F,D,G,R<:AbstractRNG,S<:Union{Int,Nothing}} <:
     seed::S
 end
 
-function Base.show(io::IO, rep::Reinforce{threaded}) where {threaded}
+function Base.show(
+    io::IO, rep::Reinforce{threaded,variance_reduction}
+) where {threaded,variance_reduction}
     (; f, dist_constructor, dist_logdensity_grad, rng, nb_samples) = rep
     return print(
         io,
-        "Reinforce{$threaded}($f, $dist_constructor, $dist_logdensity_grad, $rng, $nb_samples)",
+        "Reinforce{$threaded,$variance_reduction}($f, $dist_constructor, $dist_logdensity_grad, $rng, $nb_samples)",
     )
 end
 
@@ -74,9 +76,10 @@ function Reinforce(
     rng::R=default_rng(),
     nb_samples=1,
     threaded=false,
+    variance_reduction=true,
     seed::S=nothing,
 ) where {F,D,G,R,S}
-    return Reinforce{threaded,F,D,G,R,S}(
+    return Reinforce{threaded,variance_reduction,F,D,G,R,S}(
         f, dist_constructor, dist_logdensity_grad, rng, nb_samples, seed
     )
 end
@@ -96,8 +99,8 @@ function dist_logdensity_grad(
 end
 
 function ChainRulesCore.rrule(
-    rc::RuleConfig, F::Reinforce{threaded}, θ...; kwargs...
-) where {threaded}
+    rc::RuleConfig, F::Reinforce{threaded,variance_reduction}, θ...; kwargs...
+) where {threaded,variance_reduction}
     project_θ = ProjectTo(θ)
 
     (; nb_samples) = F
@@ -111,6 +114,20 @@ function ChainRulesCore.rrule(
         map(_dist_logdensity_grad_partial, xs)
     end
 
+    ys_with_baseline = if variance_reduction && nb_samples > 1
+        y_mean = threaded ? tmean(ys) : mean(ys)
+        map(ys) do yᵢ
+            yᵢ .- y_mean
+        end
+    else
+        ys
+    end
+    K = if variance_reduction && nb_samples > 1
+        nb_samples - 1
+    else
+        nb_samples
+    end
+
     function pullback_Reinforce(dy_thunked)
         dy = unthunk(dy_thunked)
         dF = @not_implemented(
@@ -118,9 +135,9 @@ function ChainRulesCore.rrule(
         )
         _single_sample_pullback(g, y) = g .* dot(y, dy)
         dθ = if threaded
-            tmapreduce(_single_sample_pullback, .+, gs, ys) ./ nb_samples
+            tmapreduce(_single_sample_pullback, .+, gs, ys_with_baseline) ./ K
         else
-            mapreduce(_single_sample_pullback, .+, gs, ys) ./ nb_samples
+            mapreduce(_single_sample_pullback, .+, gs, ys_with_baseline) ./ K
         end
         dθ_proj = project_θ(dθ)
         return (dF, dθ_proj...)
diff --git a/test/expectation.jl b/test/expectation.jl
index 6c36f15..866fd6f 100644
--- a/test/expectation.jl
+++ b/test/expectation.jl
@@ -96,3 +96,31 @@ end;
         end
     end
 end
+
+@testset "Variance reduction" begin
+    for seed in 1:10
+        rng = StableRNG(seed)
+        f(x) = x
+        dist_constructor(θ) = MvNormal(θ, I)
+        n = 10
+        θ = randn(rng, n)
+        r = Reinforce(
+            f, dist_constructor; rng=rng, nb_samples=100, seed=seed, variance_reduction=true
+        )
+        r_no_variance_reduction = Reinforce(
+            f,
+            dist_constructor;
+            rng=rng,
+            nb_samples=100,
+            seed=seed,
+            variance_reduction=false,
+        )
+
+        J_reduced_variance = jacobian(r, θ)[1]
+        J_no_reduced_variance = jacobian(r_no_variance_reduction, θ)[1]
+        J_true = Matrix(I, n, n)
+
+        mae(x::AbstractArray, y::AbstractArray) = mean(abs.(x .- y))
+        @test mae(J_reduced_variance, J_true) < mae(J_no_reduced_variance, J_true)
+    end
+end
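
A quick numerical check of the derivation in `docs/src/background.md`: the leave-one-out form (second line of the `aligned` block) and the centered form (third line) are the same quantity. The snippet below is a standalone sketch, not part of the patch; `k`, `y`, and `g` are toy stand-ins for the sample count, the values `f(xᵢ)`, and the score gradients `∇θ log p(xᵢ, θ)`.

```julia
using Statistics: mean

k = 5
y = [randn(3) for _ in 1:k]  # stand-ins for f(xᵢ)
g = [randn(2) for _ in 1:k]  # stand-ins for ∇θ log p(xᵢ, θ)

# Leave-one-out baseline: each sample subtracts the mean of the other k - 1.
loo_form = sum(1:k) do i
    baseline = (sum(y) .- y[i]) ./ (k - 1)
    (y[i] .- baseline) * g[i]'
end ./ k

# Centered form: subtract the overall mean and divide by k - 1 instead of k.
centered_form = sum(1:k) do i
    (y[i] .- mean(y)) * g[i]'
end ./ (k - 1)

@assert loo_form ≈ centered_form
```

The equality holds because `f(xᵢ)` minus the mean of the other samples equals `k / (k - 1)` times `f(xᵢ)` minus the overall mean, which is exactly the rescaling between the two forms.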
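
The new branch in `pullback_Reinforce` applies the vector-Jacobian product formula directly: each score gradient is weighted by the inner product of the centered sample with the cotangent, and the sum is scaled by `1 / K` with `K = nb_samples - 1`. A minimal sketch of that computation, with the same toy stand-ins plus a hypothetical cotangent `v`:

```julia
using LinearAlgebra: dot
using Statistics: mean

k = 5
y = [randn(3) for _ in 1:k]  # stand-ins for f(xᵢ)
g = [randn(2) for _ in 1:k]  # stand-ins for ∇θ log p(xᵢ, θ)
v = randn(3)                 # cotangent vector

# Variance-reduced VJP: center each sample by the overall mean, take the
# scalar product with v, and weight the corresponding score gradient.
y_mean = mean(y)
vjp = sum(1:k) do i
    g[i] .* dot(y[i] .- y_mean, v)
end ./ (k - 1)
```

When `variance_reduction=false` or `nb_samples == 1`, the patch falls back to the plain estimator: `ys` stays uncentered and `K = nb_samples`.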