diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index c743950..74e08f7 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -1 +1,3 @@
-style = "blue"
\ No newline at end of file
+style = "blue"
+format_docstrings = true
+format_markdown = true
\ No newline at end of file
diff --git a/Project.toml b/Project.toml
index 806aece..e21fb70 100644
--- a/Project.toml
+++ b/Project.toml
@@ -28,6 +28,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -35,4 +36,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Aqua", "Documenter", "JET", "JuliaFormatter", "Random", "StableRNGs", "Statistics", "Test", "Zygote"]
+test = ["Aqua", "Documenter", "JET", "JuliaFormatter", "LinearAlgebra", "Random", "StableRNGs", "Statistics", "Test", "Zygote"]
diff --git a/README.md b/README.md
index be022c3..e11f4e3 100644
--- a/README.md
+++ b/README.md
@@ -15,4 +15,4 @@ It allows the computation of approximate derivatives with respect to $\theta$ th
 
 For more details, refer to the following paper:
 
-> [Monte-Carlo Gradient Estimation in Machine Learning](https://www.jmlr.org/papers/v21/19-346.html), Mohamed et al. (2020)
\ No newline at end of file
+> [Monte-Carlo Gradient Estimation in Machine Learning](https://www.jmlr.org/papers/v21/19-346.html), Mohamed et al. (2020)
diff --git a/src/DifferentiableExpectations.jl b/src/DifferentiableExpectations.jl
index 7f6cee1..77306ca 100644
--- a/src/DifferentiableExpectations.jl
+++ b/src/DifferentiableExpectations.jl
@@ -18,7 +18,7 @@ using ChainRulesCore: rrule_via_ad, unthunk
 using DensityInterface: logdensityof
-using Distributions: Distribution
+using Distributions: Distribution, gradlogpdf
 using DocStringExtensions
 using LinearAlgebra: dot
 using OhMyThreads: tmap, treduce, tmapreduce
@@ -32,7 +32,6 @@ include("reparametrization.jl")
 include("pushforward.jl")
 
 export DifferentiableExpectation
-export samples, distribution
 export REINFORCE
 
 end # module DifferentiableExpectations
diff --git a/src/abstract.jl b/src/abstract.jl
index f0a7b1f..5754e4d 100644
--- a/src/abstract.jl
+++ b/src/abstract.jl
@@ -1,74 +1,78 @@
 """
-    DifferentiableExpectation{threaded,D}
+    DifferentiableExpectation{threaded}
 
 Abstract supertype for differentiable parametric expectations `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`, whose value and derivative are approximated with Monte-Carlo averages.
 
 # Type parameters
 
-- `threaded::Bool`: specifies whether the sampling should be performed in parallel (with OhMyThreads.jl)
-- `D::Type`: the type of the probability distribution, such that calling `D(θ...)` generates a sampleable object corresponding to `p(θ)`
+  - `threaded::Bool`: specifies whether the sampling should be performed in parallel (with OhMyThreads.jl)
 
 # Required fields
 
-- `f`: the function applied inside the expectation
-- `rng`: the random number generator
-- `nb_samples`: the number of Monte-Carlo samples
+  - `f`: the function applied inside the expectation
+  - `dist_constructor`: the constructor of the probability distribution, such that calling `dist_constructor(θ...)` generates an object corresponding to `p(θ)`
+  - `rng`: the random number generator
+  - `nb_samples`: the number of Monte-Carlo samples
 """
-abstract type DifferentiableExpectation{threaded,D} end
+abstract type DifferentiableExpectation{threaded} end
 
 """
-    distribution(F::DifferentiableExpectation, θ...)
+    presamples(F::DifferentiableExpectation, θ...)
 
-Create a sampleable object `p(θ)`.
+Return a vector `[x₁, ..., xₛ]` or matrix `[x₁ ... xₛ]` where the `xᵢ ∼ p(θ)` are iid samples.
 """
-function distribution(::DifferentiableExpectation{threaded,D}, θ...) where {threaded,D}
-    return D(θ...)
+function presamples(F::DifferentiableExpectation, θ...)
+    (; dist_constructor, rng, nb_samples) = F
+    dist = dist_constructor(θ...)
+    xs = rand(rng, dist, nb_samples)  # TODO: parallelize?
+    return xs
 end
 
 """
-    (F::DifferentiableExpectation)(θ...)
+    samples(F::DifferentiableExpectation, θ...; kwargs...)
 
-Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples.
+Return a vector `[f(x₁), ..., f(xₛ)]` where the `xᵢ ∼ p(θ)` are iid samples.
 """
-function (F::DifferentiableExpectation{threaded})(θ...) where {threaded}
-    dist = distribution(F, θ...)
-    _sample(_) = F.f(rand(F.rng, dist))
-    y = if threaded
-        tmapmean(_sample, 1:(F.nb_samples))
+function samples(F::DifferentiableExpectation{threaded}, θ...; kwargs...) where {threaded}
+    xs = presamples(F, θ...)
+    return samples_from_presamples(F, xs; kwargs...)
+end
+
+function samples_from_presamples(
+    F::DifferentiableExpectation{threaded}, xs::AbstractVector; kwargs...
+) where {threaded}
+    (; f) = F
+    fk = FixKwargs(f, kwargs)
+    if threaded
+        return tmap(fk, xs)
     else
-        mean(_sample, 1:(F.nb_samples))
+        return map(fk, xs)
     end
-    return y
 end
 
-"""
-    pre_samples(F::DifferentiableExpectation, θ...)
-
-Return a vector `[x₁, ..., xₛ]` where the `xᵢ ∼ p(θ)` are iid samples.
-"""
-function pre_samples(F::DifferentiableExpectation{threaded}, θ...) where {threaded}
-    dist = distribution(F, θ...)
-    _pre_sample(_) = rand(F.rng, dist)
-    xs = if threaded
-        tmap(_pre_sample, 1:(F.nb_samples))
+function samples_from_presamples(
+    F::DifferentiableExpectation{threaded}, xs::AbstractMatrix; kwargs...
+) where {threaded}
+    (; f) = F
+    fk = FixKwargs(f, kwargs)
+    if threaded
+        return tmap(fk, eachcol(xs))
     else
-        map(_pre_sample, 1:(F.nb_samples))
+        return map(fk, eachcol(xs))
    end
-    return xs
 end
 
 """
-    samples(F::DifferentiableExpectation, θ...)
+    (F::DifferentiableExpectation)(θ...; kwargs...)
 
-Return a vector `[f(x₁), ..., f(xₛ)]` where the `xᵢ ∼ p(θ)` are iid samples.
+Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples.
 """
-function samples(F::DifferentiableExpectation{threaded}, θ...) where {threaded}
-    dist = distribution(F, θ...)
-    _sample(_) = F.f(rand(F.rng, dist))
-    ys = if threaded
-        map(_sample, 1:(F.nb_samples)) # TODO: tmap fails here
+function (F::DifferentiableExpectation{threaded})(θ...; kwargs...) where {threaded}
+    ys = samples(F, θ...; kwargs...)
+    y = if threaded
+        tmean(ys)
     else
-        map(_sample, 1:(F.nb_samples))
+        mean(ys)
     end
-    return ys
+    return y
 end
diff --git a/src/reinforce.jl b/src/reinforce.jl
index 1a94d41..15e0806 100644
--- a/src/reinforce.jl
+++ b/src/reinforce.jl
@@ -1,5 +1,5 @@
 """
-    REINFORCE <: DifferentiableExpectation
+    REINFORCE{threaded} <: DifferentiableExpectation{threaded}
 
 Differentiable parametric expectation `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)` using the REINFORCE (or score function) gradient estimator:
 ```
@@ -8,7 +8,14 @@ Differentiable parametric expectation `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`
 
 # Constructor
 
-    REINFORCE(; f, dist_type::Type, rng::AbstractRNG, nb_samples::Integer, threaded::Bool)
+    REINFORCE(
+        f,
+        dist_constructor,
+        dist_logdensity_grad=nothing;
+        rng=Random.default_rng(),
+        nb_samples=1,
+        threaded=false
+    )
 
 # Fields
@@ -18,41 +25,80 @@ $(TYPEDFIELDS)
 
 - [`DifferentiableExpectation`](@ref)
 """
-struct REINFORCE{threaded,D,F,R<:AbstractRNG} <: DifferentiableExpectation{threaded,D}
+struct REINFORCE{threaded,F,D,G,R<:AbstractRNG} <: DifferentiableExpectation{threaded}
     f::F
+    dist_constructor::D
+    dist_logdensity_grad::G
     rng::R
     nb_samples::Int
 end
 
-"""
-    REINFORCE(
-        ::Type{D}, f;
-        rng::AbstractRNG=default_rng(),
-        nb_samples::Integer=1,
-        threaded::Bool=false
-    )
-
-Constructor for [`REINFORCE`](@ref).
-"""
 function REINFORCE(
-    ::Type{D}, f::F; rng::R=default_rng(), nb_samples=1, threaded=false
-) where {F,D,R}
-    return REINFORCE{threaded,D,F,R}(f, rng, nb_samples)
+    f::F,
+    dist_constructor::D,
+    dist_logdensity_grad::G=nothing;
+    rng::R=default_rng(),
+    nb_samples=1,
+    threaded=false,
+) where {F,D,G,R}
+    return REINFORCE{threaded,F,D,G,R}(
+        f, dist_constructor, dist_logdensity_grad, rng, nb_samples
+    )
 end
 
 function logdensity_grad(rc::RuleConfig, F::REINFORCE{threaded}, x, θ...) where {threaded}
-    _logdensity_partial(_θ...) = logdensityof(distribution(F, _θ...), x)
-    l, pullback = rrule_via_ad(rc, _logdensity_partial, θ...)
-    dθ = Base.tail(pullback(one(l)))
+    (; dist_constructor, dist_logdensity_grad) = F
+    if !isnothing(dist_logdensity_grad)
+        dθ = dist_logdensity_grad(x, θ...)
+    else
+        # TODO: add Distributions.gradlogpdf
+        _logdensity_partial(_θ...) = logdensityof(dist_constructor(_θ...), x)
+        l, pullback = rrule_via_ad(rc, _logdensity_partial, θ...)
+        dθ = Base.tail(pullback(one(l)))
+    end
    return dθ
 end
 
-function ChainRulesCore.rrule(rc::RuleConfig, F::REINFORCE{threaded}, θ...) where {threaded}
-    (; nb_samples) = F
+function logdensity_grads_from_presamples(
+    rc::RuleConfig,
+    F::DifferentiableExpectation{threaded},
+    xs::AbstractVector,
+    θ...;
+    kwargs...,
+) where {threaded}
     _logdensity_grad_partial(x) = logdensity_grad(rc, F, x, θ...)
-    xs = pre_samples(F, θ...)
-    ys = threaded ? tmap(F.f, xs) : map(F.f, xs)
-    gs = threaded ? tmap(_logdensity_grad_partial, xs) : map(_logdensity_grad_partial, xs)
+    if threaded
+        return tmap(_logdensity_grad_partial, xs)
+    else
+        return map(_logdensity_grad_partial, xs)
+    end
+end
+
+function logdensity_grads_from_presamples(
+    rc::RuleConfig,
+    F::DifferentiableExpectation{threaded},
+    xs::AbstractMatrix,
+    θ...;
+    kwargs...,
+) where {threaded}
+    _logdensity_grad_partial(x) = logdensity_grad(rc, F, x, θ...)
+    if threaded
+        return tmap(_logdensity_grad_partial, eachcol(xs))
+    else
+        return map(_logdensity_grad_partial, eachcol(xs))
+    end
+end
+
+function ChainRulesCore.rrule(
+    rc::RuleConfig, F::REINFORCE{threaded}, θ...; kwargs...
+) where {threaded}
+    project_θ = ProjectTo(θ)
+
+    (; nb_samples) = F
+    xs = presamples(F, θ...)
+    ys = samples_from_presamples(F, xs; kwargs...)
+    gs = logdensity_grads_from_presamples(rc, F, xs, θ...)
+
     function REINFORCE_pullback(dy_thunked)
         dy = unthunk(dy_thunked)
         dF = @not_implemented(
@@ -64,8 +110,10 @@ function ChainRulesCore.rrule(rc::RuleConfig, F::REINFORCE{threaded}, θ...) whe
         else
             mapreduce(_single_sample_pullback, .+, gs, ys) ./ nb_samples
         end
-        return (dF, dθ...)
+        dθ_proj = project_θ(dθ)
+        return (dF, dθ_proj...)
     end
+
     y = threaded ? tmean(ys) : mean(ys)
     return y, REINFORCE_pullback
 end
diff --git a/src/utils.jl b/src/utils.jl
index dca358c..b277376 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,3 +1,12 @@
+struct FixKwargs{F,K}
+    f::F
+    kwargs::K
+end
+
+function (fk::FixKwargs)(args...)
+    return fk.f(args...; fk.kwargs...)
+end
+
 function tmapmean(f, args...)
     return tmapreduce(f, +, args...) / length(first(args))
 end
diff --git a/test/reinforce.jl b/test/reinforce.jl
index c9f9108..bfaea40 100644
--- a/test/reinforce.jl
+++ b/test/reinforce.jl
@@ -1,31 +1,60 @@
 using Distributions
 using DifferentiableExpectations
+using DifferentiableExpectations: samples
+using LinearAlgebra
 using Random
 using StableRNGs
 using Statistics
 using Test
 using Zygote
 
+exp_with_kwargs(x; correct=false) = correct ? exp(x) : sin(x)
+vec_exp_with_kwargs(x; correct=false) = exp_with_kwargs.(x; correct)
+
 @testset "Univariate LogNormal" begin
     for threaded in (false, true)
-        F = REINFORCE(Normal, exp; rng=StableRNG(63), nb_samples=10^5, threaded=threaded)
+        F = REINFORCE(
+            exp_with_kwargs, Normal; rng=StableRNG(63), nb_samples=10^4, threaded=threaded
+        )
+
         μ, σ = 2.0, 1.0
         true_mean(μ, σ) = mean(LogNormal(μ, σ))
         true_std(μ, σ) = std(LogNormal(μ, σ))
 
-        @test distribution(F, μ, σ) == Normal(μ, σ)
-        @test F(μ, σ) ≈ true_mean(μ, σ) rtol = 0.1
-        @test std(samples(F, μ, σ)) ≈ true_std(μ, σ) rtol = 0.1
+        @test F.dist_constructor(μ, σ) == Normal(μ, σ)
+        @test F(μ, σ; correct=true) ≈ true_mean(μ, σ) rtol = 0.1
+        @test std(samples(F, μ, σ; correct=true)) ≈ true_std(μ, σ) rtol = 0.1
 
-        ∇mean_est = gradient(F, μ, σ)
+        ∇mean_est = gradient((μ, σ) -> F(μ, σ; correct=true), μ, σ)
         ∇mean_true = gradient(true_mean, μ, σ)
 
-        ∇std_est = gradient((_μ, _σ) -> std(samples(F, _μ, _σ)), μ, σ)
-        ∇std_true = gradient(true_std, μ, σ)
+        @test ∇mean_est[1] ≈ ∇mean_true[1] rtol = 0.2
+        @test ∇mean_est[2] ≈ ∇mean_true[2] rtol = 0.2
+    end
+end
+
+@testset "Multivariate LogNormal" begin
+    for threaded in (false, true)
+        F = REINFORCE(
+            vec_exp_with_kwargs,
+            (μ, σ) -> MvNormal(μ, Diagonal(σ .^ 2));
+            rng=StableRNG(63),
+            nb_samples=10^4,
+            threaded=threaded,
+        )
+
+        dim = 2
+        μ, σ = randn(dim), rand(dim)
+        true_mean(μ, σ) = mean.(LogNormal.(μ, σ))
+        true_std(μ, σ) = std.(LogNormal.(μ, σ))
+
+        @test F.dist_constructor(μ, σ) == MvNormal(μ, Diagonal(σ .^ 2))
+        @test F(μ, σ; correct=true) ≈ true_mean(μ, σ) rtol = 0.1
+
+        ∂mean_est = jacobian((μ, σ) -> F(μ, σ; correct=true), μ, σ)
+        ∂mean_true = jacobian(true_mean, μ, σ)
 
-        for i in 1:2
-            @test ∇mean_est[i] ≈ ∇mean_true[i] rtol = 0.2
-            @test ∇std_est[i] ≈ ∇std_true[i] rtol = 0.2
-        end
+        @test ∂mean_est[1] ≈ ∂mean_true[1] rtol = 0.2
+        @test ∂mean_est[2] ≈ ∂mean_true[2] rtol = 0.2
     end
 end
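A minimal usage sketch of the new constructor-based API, mirroring `test/reinforce.jl`. The `normal_score` helper and the numeric values are introduced here for illustration only and are not part of the diff:

```julia
using DifferentiableExpectations
using Distributions
using Zygote: gradient

# F(μ, σ) estimates 𝔼[exp(X)] with X ∼ Normal(μ, σ),
# i.e. the mean of LogNormal(μ, σ), which is exp(μ + σ²/2).
F = REINFORCE(exp, Normal; nb_samples=10^4)

μ, σ = 2.0, 1.0
F(μ, σ)            # Monte-Carlo estimate, ≈ 12.18 for these parameters
gradient(F, μ, σ)  # REINFORCE estimate of (∂F/∂μ, ∂F/∂σ) via the custom rrule

# Hypothetical third argument: an analytic score function that bypasses AD
# through logdensityof. For Normal(μ, σ):
#   ∂μ log p(x) = (x - μ)/σ²,  ∂σ log p(x) = ((x - μ)² - σ²)/σ³.
normal_score(x, μ, σ) = ((x - μ) / σ^2, ((x - μ)^2 - σ^2) / σ^3)
F2 = REINFORCE(exp, Normal, normal_score; nb_samples=10^4)
gradient(F2, μ, σ)
```

The score-function path only changes how `∇log p(x; θ)` is computed inside the pullback; the forward value is the same Monte-Carlo average either way.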