Skip to content

Commit

Permalink
Add kwargs and multidimensional support
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jun 10, 2024
1 parent 03e2318 commit 6e2b0bf
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 82 deletions.
4 changes: 3 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
style = "blue"
style = "blue"
format_docstrings = true
format_markdown = true
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Documenter", "JET", "JuliaFormatter", "Random", "StableRNGs", "Statistics", "Test", "Zygote"]
test = ["Aqua", "Documenter", "JET", "JuliaFormatter", "LinearAlgebra", "Random", "StableRNGs", "Statistics", "Test", "Zygote"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ It allows the computation of approximate derivatives with respect to $\theta$ th

For more details, refer to the following paper:

> [Monte-Carlo Gradient Estimation in Machine Learning](https://www.jmlr.org/papers/v21/19-346.html), Mohamed et al. (2020)
> [Monte-Carlo Gradient Estimation in Machine Learning](https://www.jmlr.org/papers/v21/19-346.html), Mohamed et al. (2020)
3 changes: 1 addition & 2 deletions src/DifferentiableExpectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using ChainRulesCore:
rrule_via_ad,
unthunk
using DensityInterface: logdensityof
using Distributions: Distribution
using Distributions: Distribution, gradlogpdf
using DocStringExtensions
using LinearAlgebra: dot
using OhMyThreads: tmap, treduce, tmapreduce
Expand All @@ -32,7 +32,6 @@ include("reparametrization.jl")
include("pushforward.jl")

export DifferentiableExpectation
export samples, distribution
export REINFORCE

end # module DifferentiableExpectations
86 changes: 45 additions & 41 deletions src/abstract.jl
Original file line number Diff line number Diff line change
@@ -1,74 +1,78 @@
"""
DifferentiableExpectation{threaded,D}
DifferentiableExpectation{threaded}
Abstract supertype for differentiable parametric expectations `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`, whose value and derivative are approximated with Monte-Carlo averages.
# Type parameters
- `threaded::Bool`: specifies whether the sampling should be performed in parallel (with OhMyThreads.jl)
- `D::Type`: the type of the probability distribution, such that calling `D(θ...)` generates a sampleable object corresponding to `p(θ)`
- `threaded::Bool`: specifies whether the sampling should be performed in parallel (with OhMyThreads.jl)
# Required fields
- `f`: the function applied inside the expectation
- `rng`: the random number generator
- `nb_samples`: the number of Monte-Carlo samples
- `f`: the function applied inside the expectation
- `dist_constructor`: the constructor of the probability distribution, such that calling `dist_constructor(θ...)` generates an object corresponding to `p(θ)`
- `rng`: the random number generator
- `nb_samples`: the number of Monte-Carlo samples
"""
abstract type DifferentiableExpectation{threaded,D} end
abstract type DifferentiableExpectation{threaded} end

"""
distribution(F::DifferentiableExpectation, θ...)
presamples(F::DifferentiableExpectation, θ...)
Create a sampleable object `p(θ)`.
Return a vector `[x₁, ..., xₛ]` or matrix `[x₁ ... xₛ]` where the `xᵢ ∼ p(θ)` are iid samples.
"""
function distribution(::DifferentiableExpectation{threaded,D}, θ...) where {threaded,D}
return D(θ...)
function presamples(F::DifferentiableExpectation, θ...)
(; dist_constructor, rng, nb_samples) = F
dist = dist_constructor(θ...)
xs = rand(rng, dist, nb_samples) # TODO: parallelize?
return xs
end

"""
(F::DifferentiableExpectation)(θ...)
samples(F::DifferentiableExpectation, θ...; kwargs...)
Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples.
Return a vector `[f(x₁), ..., f(xₛ)]` where the `xᵢ ∼ p(θ)` are iid samples.
"""
function (F::DifferentiableExpectation{threaded})(θ...) where {threaded}
dist = distribution(F, θ...)
_sample(_) = F.f(rand(F.rng, dist))
y = if threaded
tmapmean(_sample, 1:(F.nb_samples))
function samples(F::DifferentiableExpectation{threaded}, θ...; kwargs...) where {threaded}
xs = presamples(F, θ...)
return samples_from_presamples(F, xs; kwargs...)
end

function samples_from_presamples(
F::DifferentiableExpectation{threaded}, xs::AbstractVector; kwargs...
) where {threaded}
(; f) = F
fk = FixKwargs(f, kwargs)
if threaded
return tmap(fk, xs)
else
mean(_sample, 1:(F.nb_samples))
return map(fk, xs)
end
return y
end

"""
pre_samples(F::DifferentiableExpectation, θ...)
Return a vector `[x₁, ..., xₛ]` where the `xᵢ ∼ p(θ)` are iid samples.
"""
function pre_samples(F::DifferentiableExpectation{threaded}, θ...) where {threaded}
dist = distribution(F, θ...)
_pre_sample(_) = rand(F.rng, dist)
xs = if threaded
tmap(_pre_sample, 1:(F.nb_samples))
function samples_from_presamples(
F::DifferentiableExpectation{threaded}, xs::AbstractMatrix; kwargs...
) where {threaded}
(; f) = F
fk = FixKwargs(f, kwargs)
if threaded
return tmap(fk, eachcol(xs))
else
map(_pre_sample, 1:(F.nb_samples))
return map(fk, eachcol(xs))
end
return xs
end

"""
samples(F::DifferentiableExpectation, θ...)
(F::DifferentiableExpectation)(θ...; kwargs...)
Return a vector `[f(x₁), ..., f(xₛ)]` where the `xᵢ ∼ p(θ)` are iid samples.
Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples.
"""
function samples(F::DifferentiableExpectation{threaded}, θ...) where {threaded}
dist = distribution(F, θ...)
_sample(_) = F.f(rand(F.rng, dist))
ys = if threaded
map(_sample, 1:(F.nb_samples)) # TODO: tmap fails here
function (F::DifferentiableExpectation{threaded})(θ...; kwargs...) where {threaded}
ys = samples(F, θ...; kwargs...)
y = if threaded
tmean(ys)
else
map(_sample, 1:(F.nb_samples))
mean(ys)
end
return ys
return y
end
98 changes: 73 additions & 25 deletions src/reinforce.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
REINFORCE <: DifferentiableExpectation
REINFORCE{threaded} <: DifferentiableExpectation{threaded}
Differentiable parametric expectation `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)` using the REINFORCE (or score function) gradient estimator:
```
Expand All @@ -8,7 +8,14 @@ Differentiable parametric expectation `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`
# Constructor
REINFORCE(; f, dist_type::Type, rng::AbstractRNG, nb_samples::Integer, threaded::Bool)
REINFORCE(
f,
dist_constructor,
dist_logdensity_grad=nothing;
rng=Random.default_rng(),
nb_samples=1,
threaded=false
)
# Fields
Expand All @@ -18,41 +25,80 @@ $(TYPEDFIELDS)
- [`DifferentiableExpectation`](@ref)
"""
struct REINFORCE{threaded,D,F,R<:AbstractRNG} <: DifferentiableExpectation{threaded,D}
struct REINFORCE{threaded,F,D,G,R<:AbstractRNG} <: DifferentiableExpectation{threaded}
f::F
dist_constructor::D
dist_logdensity_grad::G
rng::R
nb_samples::Int
end

"""
REINFORCE(
::Type{D}, f;
rng::AbstractRNG=default_rng(),
nb_samples::Integer=1,
threaded::Bool=false
)
Constructor for [`REINFORCE`](@ref).
"""
function REINFORCE(
::Type{D}, f::F; rng::R=default_rng(), nb_samples=1, threaded=false
) where {F,D,R}
return REINFORCE{threaded,D,F,R}(f, rng, nb_samples)
f::F,
dist_constructor::D,
dist_logdensity_grad::G=nothing;
rng::R=default_rng(),
nb_samples=1,
threaded=false,
) where {F,D,G,R}
return REINFORCE{threaded,F,D,G,R}(
f, dist_constructor, dist_logdensity_grad, rng, nb_samples
)
end

function logdensity_grad(rc::RuleConfig, F::REINFORCE{threaded}, x, θ...) where {threaded}
_logdensity_partial(_θ...) = logdensityof(distribution(F, _θ...), x)
l, pullback = rrule_via_ad(rc, _logdensity_partial, θ...)
dθ = Base.tail(pullback(one(l)))
(; dist_constructor, dist_logdensity_grad) = F
if !isnothing(dist_logdensity_grad)
dθ = dist_logdensity_grad(x, θ...)
else
# TODO: add Distributions.gradlogpdf
_logdensity_partial(_θ...) = logdensityof(dist_constructor(_θ...), x)
l, pullback = rrule_via_ad(rc, _logdensity_partial, θ...)
dθ = Base.tail(pullback(one(l)))
end
return dθ
end

function ChainRulesCore.rrule(rc::RuleConfig, F::REINFORCE{threaded}, θ...) where {threaded}
(; nb_samples) = F
function logdensity_grads_from_presamples(
rc::RuleConfig,
F::DifferentiableExpectation{threaded},
xs::AbstractVector,
θ...;
kwargs...,
) where {threaded}
_logdensity_grad_partial(x) = logdensity_grad(rc, F, x, θ...)
xs = pre_samples(F, θ...)
ys = threaded ? tmap(F.f, xs) : map(F.f, xs)
gs = threaded ? tmap(_logdensity_grad_partial, xs) : map(_logdensity_grad_partial, xs)
if threaded
return tmap(_logdensity_grad_partial, xs)
else
return map(_logdensity_grad_partial, xs)
end
end

"""
    logdensity_grads_from_presamples(rc, F, xs::AbstractMatrix, θ...; kwargs...)

Apply `logdensity_grad(rc, F, x, θ...)` to every column `x` of `xs` and return the
collected results.

The columns are processed with `tmap` when the `threaded` type parameter of `F` is
`true`, and with `map` otherwise. Keyword arguments are accepted but not used by this
method.
"""
function logdensity_grads_from_presamples(
    rc::RuleConfig,
    F::DifferentiableExpectation{threaded},
    xs::AbstractMatrix,
    θ...;
    kwargs...,
) where {threaded}
    # Close over everything except the sample itself.
    grad_at(x) = logdensity_grad(rc, F, x, θ...)
    columns = eachcol(xs)
    return threaded ? tmap(grad_at, columns) : map(grad_at, columns)
end

function ChainRulesCore.rrule(
rc::RuleConfig, F::REINFORCE{threaded}, θ...; kwargs...
) where {threaded}
project_θ = ProjectTo(θ)

(; nb_samples) = F
xs = presamples(F, θ...)
ys = samples_from_presamples(F, xs; kwargs...)
gs = logdensity_grads_from_presamples(rc, F, xs, θ...)

function REINFORCE_pullback(dy_thunked)
dy = unthunk(dy_thunked)
dF = @not_implemented(
Expand All @@ -64,8 +110,10 @@ function ChainRulesCore.rrule(rc::RuleConfig, F::REINFORCE{threaded}, θ...) whe
else
mapreduce(_single_sample_pullback, .+, gs, ys) ./ nb_samples
end
return (dF, dθ...)
dθ_proj = project_θ(dθ)
return (dF, dθ_proj...)
end

y = threaded ? tmean(ys) : mean(ys)
return y, REINFORCE_pullback
end
9 changes: 9 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
"""
    FixKwargs(f, kwargs)

Callable wrapper that stores a function `f` together with a collection of keyword
arguments `kwargs`.

Calling the wrapper with positional arguments forwards them to `f` and splats the
stored `kwargs` as keyword arguments, i.e. `fk(args...) == fk.f(args...; fk.kwargs...)`.
"""
struct FixKwargs{F,K}
    f::F
    kwargs::K
end

# Forward positional arguments unchanged and re-attach the stored keywords.
(fk::FixKwargs)(args...) = fk.f(args...; fk.kwargs...)

"""
    tmapmean(f, args...)

Return `tmapreduce(f, +, args...)` divided by the length of the first collection in
`args`, i.e. the average of `f` over the given collection(s).
"""
function tmapmean(f, args...)
    total = tmapreduce(f, +, args...)
    nb_terms = length(first(args))
    return total / nb_terms
end
Expand Down
51 changes: 40 additions & 11 deletions test/reinforce.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,60 @@
using Distributions
using DifferentiableExpectations
using DifferentiableExpectations: samples
using LinearAlgebra
using Random
using StableRNGs
using Statistics
using Test
using Zygote

# Test helper: behaves like `exp` only when the keyword `correct=true` is passed, and
# like `sin` otherwise, so a test using `exp` reference values only passes if the
# keyword actually reaches the function.
function exp_with_kwargs(x; correct=false)
    if correct
        return exp(x)
    else
        return sin(x)
    end
end

# Elementwise version of `exp_with_kwargs`, forwarding the `correct` keyword.
function vec_exp_with_kwargs(x; correct=false)
    return exp_with_kwargs.(x; correct=correct)
end

@testset "Univariate LogNormal" begin
for threaded in (false, true)
F = REINFORCE(Normal, exp; rng=StableRNG(63), nb_samples=10^5, threaded=threaded)
F = REINFORCE(
exp_with_kwargs, Normal; rng=StableRNG(63), nb_samples=10^4, threaded=threaded
)

μ, σ = 2.0, 1.0
true_mean(μ, σ) = mean(LogNormal(μ, σ))
true_std(μ, σ) = std(LogNormal(μ, σ))

@test distribution(F, μ, σ) == Normal(μ, σ)
@test F(μ, σ) ≈ true_mean(μ, σ) rtol = 0.1
@test std(samples(F, μ, σ)) ≈ true_std(μ, σ) rtol = 0.1
@test F.dist_constructor(μ, σ) == Normal(μ, σ)
@test F(μ, σ; correct=true) ≈ true_mean(μ, σ) rtol = 0.1
@test std(samples(F, μ, σ; correct=true)) ≈ true_std(μ, σ) rtol = 0.1

∇mean_est = gradient(F, μ, σ)
∇mean_est = gradient((μ, σ) -> F(μ, σ; correct=true), μ, σ)
∇mean_true = gradient(true_mean, μ, σ)

∇std_est = gradient((_μ, _σ) -> std(samples(F, _μ, _σ)), μ, σ)
∇std_true = gradient(true_std, μ, σ)
@test ∇mean_est[1] ≈ ∇mean_true[1] rtol = 0.2
@test ∇mean_est[2] ≈ ∇mean_true[2] rtol = 0.2
end
end

@testset "Multivariate LogNormal" begin
for threaded in (false, true)
F = REINFORCE(
vec_exp_with_kwargs,
(μ, σ) -> MvNormal(μ, Diagonal(σ .^ 2));
rng=StableRNG(63),
nb_samples=10^4,
threaded=threaded,
)

dim = 2
μ, σ = randn(dim), rand(dim)
true_mean(μ, σ) = mean.(LogNormal.(μ, σ))
true_std(μ, σ) = std.(LogNormal.(μ, σ))

@test F.dist_constructor(μ, σ) == MvNormal(μ, Diagonal(σ .^ 2))
@test F(μ, σ; correct=true) ≈ true_mean(μ, σ) rtol = 0.1

∂mean_est = jacobian((μ, σ) -> F(μ, σ; correct=true), μ, σ)
∂mean_true = jacobian(true_mean, μ, σ)

for i in 1:2
@test ∇mean_est[i] ≈ ∇mean_true[i] rtol = 0.2
@test ∇std_est[i] ≈ ∇std_true[i] rtol = 0.2
end
@test ∂mean_est[1] ≈ ∂mean_true[1] rtol = 0.2
@test ∂mean_est[2] ≈ ∂mean_true[2] rtol = 0.2
end
end

0 comments on commit 6e2b0bf

Please sign in to comment.