Commit

Implement variance reduction for Reinforce (#12)
* Implement variance reduction for Reinforce

* Cleanup Reinforce rrule, and better variance reduction tests

* Clean up

* Fix docs

* Fix

* Fix bibtex

* Fix tests

---------

Co-authored-by: Guillaume Dalle <[email protected]>
BatyLeo and gdalle authored Jun 26, 2024
1 parent 652da15 commit 87382ee
Showing 4 changed files with 76 additions and 16 deletions.
30 changes: 25 additions & 5 deletions docs/src/DiffExp.bib
@@ -2,17 +2,35 @@ @misc{blondelElementsDifferentiableProgramming2024
title = {The {{Elements}} of {{Differentiable Programming}}},
author = {Blondel, Mathieu and Roulet, Vincent},
year = {2024},
month = mar,
number = {arXiv:2403.14606},
eprint = {2403.14606},
primaryclass = {cs},
publisher = {arXiv},
doi = {10.48550/arXiv.2403.14606},
url = {http://arxiv.org/abs/2403.14606},
urldate = {2024-03-22},
abstract = {Artificial intelligence has recently experienced remarkable advances, fueled by large models, vast datasets, accelerated hardware, and, last but not least, the transformative power of differentiable programming. This new programming paradigm enables end-to-end differentiation of complex computer programs (including those with control flows and data structures), making gradient-based optimization of program parameters possible. As an emerging paradigm, differentiable programming builds upon several areas of computer science and applied mathematics, including automatic differentiation, graphical models, optimization and statistics. This book presents a comprehensive review of the fundamental concepts useful for differentiable programming. We adopt two main perspectives, that of optimization and that of probability, with clear analogies between the two. Differentiable programming is not merely the differentiation of programs, but also the thoughtful design of programs intended for differentiation. By making programs differentiable, we inherently introduce probability distributions over their execution, providing a means to quantify the uncertainty associated with program outputs.},
archiveprefix = {arxiv},
keywords = {diffexp,done,tracer},
file = {/home/gdalle/snap/zotero-snap/common/Zotero/storage/3KCV6KRG/Blondel_Roulet_2024_The Elements of Differentiable Programming.pdf;/home/gdalle/snap/zotero-snap/common/Zotero/storage/B7URG3VS/2403.html}
archiveprefix = {arXiv},
}
% == BibTeX quality report for blondelElementsDifferentiableProgramming2024:
% ? Title looks like it was stored in title-case in Zotero
@article{koolBuyREINFORCESamples2022,
title = {Buy 4 {{REINFORCE Samples}}, {{Get}} a {{Baseline}} for {{Free}}!},
author = {Kool, Wouter and van Hoof, Herke and Welling, Max},
year = {2022},
month = jul,
url = {https://openreview.net/forum?id=r1lgTGL5DE},
urldate = {2023-04-17},
abstract = {REINFORCE can be used to train models in structured prediction settings to directly optimize the test-time objective. However, the common case of sampling one prediction per datapoint (input) is data-inefficient. We show that by drawing multiple samples (predictions) per datapoint, we can learn with significantly less data, as we freely obtain a REINFORCE baseline to reduce variance. Additionally we derive a REINFORCE estimator with baseline, based on sampling without replacement. Combined with a recent technique to sample sequences without replacement using Stochastic Beam Search, this improves the training procedure for a sequence model that predicts the solution to the Travelling Salesman Problem.},
langid = {english},
language = {en},
}
% == BibTeX quality report for koolBuyREINFORCESamples2022:
% Missing required field 'journal'
% ? Title looks like it was stored in title-case in Zotero
% ? unused Library catalog ("openreview.net")
@article{mohamedMonteCarloGradient2020,
title = {Monte {{Carlo Gradient Estimation}} in {{Machine Learning}}},
@@ -24,7 +42,9 @@ @article{mohamedMonteCarloGradient2020
pages = {1--62},
issn = {1533-7928},
url = {http://jmlr.org/papers/v21/19-346.html},
urldate = {2022-10-21},
abstract = {This paper is a broad and accessible survey of the methods we have at our disposal for Monte Carlo gradient estimation in machine learning and across the statistical sciences: the problem of computing the gradient of an expectation of a function with respect to parameters defining the distribution that is integrated; the problem of sensitivity analysis. In machine learning research, this gradient problem lies at the core of many learning problems, in supervised, unsupervised and reinforcement learning. We will generally seek to rewrite such gradients in a form that allows for Monte Carlo estimation, allowing them to be easily and efficiently used and analysed. We explore three strategies---the pathwise, score function, and measure-valued gradient estimators---exploring their historical development, derivation, and underlying assumptions. We describe their use in other fields, show how they are related and can be combined, and expand on their possible generalisations. Wherever Monte Carlo gradient estimators have been derived and deployed in the past, important advances have followed. A deeper and more widely-held understanding of this problem will lead to further advances, and it is these advances that we wish to support.},
keywords = {diffexp,done},
file = {/home/gdalle/snap/zotero-snap/common/Zotero/storage/6KTY5IG4/Mohamed et al. - 2020 - Monte Carlo Gradient Estimation in Machine Learnin.pdf;/home/gdalle/snap/zotero-snap/common/Zotero/storage/IMI4JXES/mc_gradients.html}
}
% == BibTeX quality report for mohamedMonteCarloGradient2020:
% ? Title looks like it was stored in title-case in Zotero
% ? unused Library catalog ("jmlr.org")
18 changes: 16 additions & 2 deletions docs/src/background.md
@@ -44,8 +44,22 @@ And the vector-Jacobian product:

### Variance reduction

!!! warning
Work in progress.
The REINFORCE estimator has high variance, which can be reduced by using a baseline [koolBuyREINFORCESamples2022](@citep).
For $k > 1$ Monte Carlo samples, we have

```math
\begin{aligned}
\partial F(\theta) &\simeq \frac{1}{k}\sum_{i=1}^k f(x_i) \nabla_\theta\log p(x_i, \theta)^\top\\
& \simeq \frac{1}{k}\sum_{i=1}^k \left(f(x_i) - \frac{1}{k - 1}\sum_{j\neq i} f(x_j)\right) \nabla_\theta\log p(x_i, \theta)^\top\\
& = \frac{1}{k - 1}\sum_{i=1}^k \left(f(x_i) - \frac{1}{k}\sum_{j=1}^k f(x_j)\right) \nabla_\theta\log p(x_i, \theta)^\top
\end{aligned}
```

where the last equality uses the identity $f(x_i) - \frac{1}{k - 1}\sum_{j\neq i} f(x_j) = \frac{k}{k - 1}\left(f(x_i) - \frac{1}{k}\sum_{j=1}^k f(x_j)\right)$.

This gives the following vector-Jacobian product:

```math
\partial F(\theta)^\top v \simeq \frac{1}{k - 1}\sum_{i=1}^k \left(\left(f(x_i) - \frac{1}{k}\sum_{j=1}^k f(x_j)\right)^\top v\right) \nabla_\theta\log p(x_i, \theta)
```
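
For intuition, here is a minimal sketch of this baseline-corrected vector-Jacobian product for a `Normal(μ, σ)` distribution and a scalar-valued `f` (illustrative only, not part of the committed docs or the package API; the helper name `reinforce_vjp_with_baseline` and the closed-form log-density gradient are assumptions made for the example):

```julia
using Distributions, Statistics

# Sketch: baseline-corrected REINFORCE estimate of ∂F(θ)ᵀ v, with θ = (μ, σ).
function reinforce_vjp_with_baseline(f, v, μ, σ; k=100)
    xs = rand(Normal(μ, σ), k)  # Monte Carlo samples x_1, …, x_k
    ys = f.(xs)                 # f(x_i), scalar-valued here
    ȳ = mean(ys)                # overall sample mean, used as the baseline
    # Closed-form ∇_θ log p(x, θ) for the Normal distribution.
    logp_grad(x) = [(x - μ) / σ^2, ((x - μ)^2 - σ^2) / σ^3]
    # Centered values with the 1 / (k - 1) normalization from the derivation above.
    return sum(((ys[i] - ȳ) * v) .* logp_grad(xs[i]) for i in 1:k) ./ (k - 1)
end

# Example call with f = exp, v = 1.0, μ = 0.5, σ = 1.0.
reinforce_vjp_with_baseline(exp, 1.0, 0.5, 1.0; k=1_000)
```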

## Reparametrization

28 changes: 19 additions & 9 deletions src/reinforce.jl
@@ -43,7 +43,7 @@ $(TYPEDFIELDS)
- [`DifferentiableExpectation`](@ref)
"""
struct Reinforce{threaded,F,D,G,R<:AbstractRNG,S<:Union{Int,Nothing}} <:
struct Reinforce{threaded,variance_reduction,F,D,G,R<:AbstractRNG,S<:Union{Int,Nothing}} <:
DifferentiableExpectation{threaded}
"function applied inside the expectation"
f::F
Expand All @@ -59,11 +59,13 @@ struct Reinforce{threaded,F,D,G,R<:AbstractRNG,S<:Union{Int,Nothing}} <:
seed::S
end

function Base.show(io::IO, rep::Reinforce{threaded}) where {threaded}
function Base.show(
io::IO, rep::Reinforce{threaded,variance_reduction}
) where {threaded,variance_reduction}
(; f, dist_constructor, dist_logdensity_grad, rng, nb_samples) = rep
return print(
io,
"Reinforce{$threaded}($f, $dist_constructor, $dist_logdensity_grad, $rng, $nb_samples)",
"Reinforce{$threaded,$variance_reduction}($f, $dist_constructor, $dist_logdensity_grad, $rng, $nb_samples)",
)
end

@@ -74,9 +76,10 @@ function Reinforce(
rng::R=default_rng(),
nb_samples=1,
threaded=false,
variance_reduction=true,
seed::S=nothing,
) where {F,D,G,R,S}
return Reinforce{threaded,F,D,G,R,S}(
return Reinforce{threaded,variance_reduction,F,D,G,R,S}(
f, dist_constructor, dist_logdensity_grad, rng, nb_samples, seed
)
end
@@ -96,13 +99,14 @@ function dist_logdensity_grad(
end

function ChainRulesCore.rrule(
rc::RuleConfig, F::Reinforce{threaded}, θ...; kwargs...
) where {threaded}
rc::RuleConfig, F::Reinforce{threaded,variance_reduction}, θ...; kwargs...
) where {threaded,variance_reduction}
project_θ = ProjectTo(θ)

(; nb_samples) = F
xs = presamples(F, θ...)
ys = samples_from_presamples(F, xs; kwargs...)
y = threaded ? tmean(ys) : mean(ys)

_dist_logdensity_grad_partial(x) = dist_logdensity_grad(rc, F, x, θ...)
gs = if threaded
@@ -111,21 +115,27 @@
map(_dist_logdensity_grad_partial, xs)
end

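# With variance reduction (and more than one sample), center each yᵢ around the sample
# mean `y`; together with the divisor `K` below (then equal to `nb_samples - 1`), this
# reproduces the leave-one-out baseline derived in the docs.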
ys_with_baseline = if (variance_reduction && nb_samples > 1)
map(yi -> yi .- y, ys)
else
ys
end
K = nb_samples - (variance_reduction && nb_samples > 1)

function pullback_Reinforce(dy_thunked)
dy = unthunk(dy_thunked)
dF = @not_implemented(
"The fields of the `Reinforce` object are considered constant."
)
_single_sample_pullback(g, y) = g .* dot(y, dy)
dθ = if threaded
tmapreduce(_single_sample_pullback, .+, gs, ys) ./ nb_samples
tmapreduce(_single_sample_pullback, .+, gs, ys_with_baseline) ./ K
else
mapreduce(_single_sample_pullback, .+, gs, ys) ./ nb_samples
mapreduce(_single_sample_pullback, .+, gs, ys_with_baseline) ./ K
end
dθ_proj = project_θ(dθ)
return (dF, dθ_proj...)
end

y = threaded ? tmean(ys) : mean(ys)
return y, pullback_Reinforce
end
16 changes: 16 additions & 0 deletions test/expectation.jl
@@ -96,3 +96,19 @@ end;
end
end
end

@testset "Reinforce variance reduction" begin
μ, σ = 0.5, 1.0
seed = 63

r = Reinforce(exp, Normal; nb_samples=100, variance_reduction=true, rng=StableRNG(seed))
r_no_vr = Reinforce(
exp, Normal; nb_samples=100, variance_reduction=false, rng=StableRNG(seed)
)

grads = [gradient(r, μ, σ) for _ in 1:1000]
grads_no_vr = [gradient(r_no_vr, μ, σ) for _ in 1:1000]

@test var(first.(grads)) < var(first.(grads_no_vr))
@test var(last.(grads)) < var(last.(grads_no_vr))
end
