diff --git a/Project.toml b/Project.toml index 4abd2a6..cb287d0 100644 --- a/Project.toml +++ b/Project.toml @@ -10,13 +10,15 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1.23" Distributions = "0.25" DocStringExtensions = "0.9" -OhMyThreads = "0.5" LinearAlgebra = "1" +OhMyThreads = "0.5" +Statistics = "1" julia = "1.10" [extras] @@ -26,6 +28,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] @@ -36,5 +39,6 @@ test = [ "JuliaFormatter", "Random", "StableRNGs", + "Statistics", "Test", ] diff --git a/src/DifferentiableExpectations.jl b/src/DifferentiableExpectations.jl index 91e50ab..53d8f78 100644 --- a/src/DifferentiableExpectations.jl +++ b/src/DifferentiableExpectations.jl @@ -3,12 +3,16 @@ module DifferentiableExpectations using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad using Distributions: Distribution using DocStringExtensions -using OhMyThreads: tmap -using Random: AbstractRNG +using OhMyThreads: tmap, treduce, tmapreduce +using Random: AbstractRNG, default_rng +using Statistics: mean include("abstract.jl") include("reinforce.jl") include("reparametrization.jl") include("pushforward.jl") +export distribution, samples +export REINFORCE + end # module DifferentiableExpectations diff --git a/src/abstract.jl b/src/abstract.jl index b47cd43..ce79090 100644 --- a/src/abstract.jl +++ b/src/abstract.jl @@ -3,11 +3,50 @@ Abstract supertype for expectation wrappers. -The type parameter `threaded` is a `Bool` stating whether or not the Monte-Carlo samples should be computed in parallel. +The type parameter `threaded` is a `Bool` stating whether the Monte-Carlo samples should be computed in parallel. +""" +abstract type AbstractExpectation{threaded} end + +""" + distribution(e::AbstractExpectation, θ...) -Implementing subtypes must have the following fields: +Build the sampling distribution for `e` based on parameters `\theta`. +""" +function distribution end -- `rng::AbstractRNG`: the random number generator -- `nb_samples::Integer`: the number of samples to draw """ -abstract type AbstractExpectation{threaded} end + (e::AbstractExpectation)(θ...) + +Return the Monte-Carlo average of the function represented by `e` over several samples of `distribution(e, θ...)`. +""" +function (e::AbstractExpectation{threaded})(θ...) where {threaded} + dist = distribution(e, θ...) + s = if threaded + tmapreduce(+, 1:(e.nb_samples)) do _ + e.f(rand(e.rng, dist)) + end + else + mapreduce(+, 1:(e.nb_samples)) do _ + e.f(rand(e.rng, dist)) + end + end + return s / e.nb_samples +end + +""" + samples(e::AbstractExpectation, θ...) + +Return the values of the function represented by `e` for several samples of `distribution(e, θ...)`. +""" +function samples(e::AbstractExpectation{threaded}, θ...) where {threaded} + dist = distribution(e, θ...) + if threaded + return tmap(1:(e.nb_samples)) do _ + e.f(rand(e.rng, dist)) + end + else + return map(1:(e.nb_samples)) do _ + e.f(rand(e.rng, dist)) + end + end +end diff --git a/src/reinforce.jl b/src/reinforce.jl index 8b13789..254e124 100644 --- a/src/reinforce.jl +++ b/src/reinforce.jl @@ -1 +1,13 @@ +struct REINFORCE{t,F,D,R<:AbstractRNG} <: AbstractExpectation{t} + f::F + rng::R + nb_samples::Int +end +function REINFORCE(; + f::F, dist_type::Type{D}, rng::R=default_rng(), nb_samples=1, threaded=true +) where {F,D,R} + return REINFORCE{threaded,F,D,R}(f, rng, nb_samples) +end + +distribution(::REINFORCE{threaded,F,D}, θ...) where {threaded,F,D} = D(θ...) diff --git a/test/reinforce.jl b/test/reinforce.jl index 8b13789..bc2ba85 100644 --- a/test/reinforce.jl +++ b/test/reinforce.jl @@ -1 +1,12 @@ +using Distributions +using DifferentiableExpectations +using Random +using StableRNGs +using Statistics +using Test +e = REINFORCE(; f=exp, dist_type=Normal, rng=StableRNG(63), nb_samples=10^3, threaded=false) +μ, σ = 2.0, 1.0 +@test distribution(e, μ, σ) == Normal(μ, σ) +@test e(μ, σ) ≈ mean(LogNormal(μ, σ)) rtol = 0.1 +@test std(samples(e, μ, σ)) ≈ std(LogNormal(μ, σ)) rtol = 0.1