Skip to content

Commit

Permalink
REINFORCE
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed May 5, 2024
1 parent a01829a commit 8bb5526
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 8 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1.23"
Distributions = "0.25"
DocStringExtensions = "0.9"
OhMyThreads = "0.5"
LinearAlgebra = "1"
OhMyThreads = "0.5"
Statistics = "1"
julia = "1.10"

[extras]
Expand All @@ -26,6 +28,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand All @@ -36,5 +39,6 @@ test = [
"JuliaFormatter",
"Random",
"StableRNGs",
"Statistics",
"Test",
]
8 changes: 6 additions & 2 deletions src/DifferentiableExpectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ module DifferentiableExpectations
using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad
using Distributions: Distribution
using DocStringExtensions
using OhMyThreads: tmap
using Random: AbstractRNG
using OhMyThreads: tmap, treduce, tmapreduce
using Random: AbstractRNG, default_rng
using Statistics: mean

include("abstract.jl")
include("reinforce.jl")
include("reparametrization.jl")
include("pushforward.jl")

export distribution, samples
export REINFORCE

end # module DifferentiableExpectations
49 changes: 44 additions & 5 deletions src/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,50 @@
Abstract supertype for expectation wrappers.
The type parameter `threaded` is a `Bool` stating whether or not the Monte-Carlo samples should be computed in parallel.
The type parameter `threaded` is a `Bool` stating whether the Monte-Carlo samples should be computed in parallel.
"""
abstract type AbstractExpectation{threaded} end

"""
distribution(e::AbstractExpectation, θ...)
Implementing subtypes must have the following fields:
Build the sampling distribution for `e` based on parameters `\theta`.
"""
function distribution end

- `rng::AbstractRNG`: the random number generator
- `nb_samples::Integer`: the number of samples to draw
"""
abstract type AbstractExpectation{threaded} end
(e::AbstractExpectation)(θ...)
Return the Monte-Carlo average of the function represented by `e` over several samples of `distribution(e, θ...)`.
"""
function (e::AbstractExpectation{threaded})(θ...) where {threaded}
dist = distribution(e, θ...)
s = if threaded
tmapreduce(+, 1:(e.nb_samples)) do _
e.f(rand(e.rng, dist))
end
else
mapreduce(+, 1:(e.nb_samples)) do _
e.f(rand(e.rng, dist))
end
end
return s / e.nb_samples
end

"""
samples(e::AbstractExpectation, θ...)
Return the values of the function represented by `e` for several samples of `distribution(e, θ...)`.
"""
function samples(e::AbstractExpectation{threaded}, θ...) where {threaded}
dist = distribution(e, θ...)
if threaded
return tmap(1:(e.nb_samples)) do _
e.f(rand(e.rng, dist))
end
else
return map(1:(e.nb_samples)) do _
e.f(rand(e.rng, dist))
end
end
end
12 changes: 12 additions & 0 deletions src/reinforce.jl
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
struct REINFORCE{t,F,D,R<:AbstractRNG} <: AbstractExpectation{t}
f::F
rng::R
nb_samples::Int
end

function REINFORCE(;
f::F, dist_type::Type{D}, rng::R=default_rng(), nb_samples=1, threaded=true
) where {F,D,R}
return REINFORCE{threaded,F,D,R}(f, rng, nb_samples)
end

distribution(::REINFORCE{threaded,F,D}, θ...) where {threaded,F,D} = D...)
11 changes: 11 additions & 0 deletions test/reinforce.jl
Original file line number Diff line number Diff line change
@@ -1 +1,12 @@
using Distributions
using DifferentiableExpectations
using Random
using StableRNGs
using Statistics
using Test

e = REINFORCE(; f=exp, dist_type=Normal, rng=StableRNG(63), nb_samples=10^3, threaded=false)
μ, σ = 2.0, 1.0
@test distribution(e, μ, σ) == Normal(μ, σ)
@test e(μ, σ) mean(LogNormal(μ, σ)) rtol = 0.1
@test std(samples(e, μ, σ)) std(LogNormal(μ, σ)) rtol = 0.1

0 comments on commit 8bb5526

Please sign in to comment.