-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add kwargs and multidimensional support
- Loading branch information
Showing
8 changed files
with
174 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
style = "blue" | ||
style = "blue" | ||
format_docstrings = true | ||
format_markdown = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,74 +1,78 @@ | ||
""" | ||
DifferentiableExpectation{threaded,D} | ||
DifferentiableExpectation{threaded} | ||
Abstract supertype for differentiable parametric expectations `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`, whose value and derivative are approximated with Monte-Carlo averages. | ||
# Type parameters | ||
- `threaded::Bool`: specifies whether the sampling should be performed in parallel (with OhMyThreads.jl) | ||
- `D::Type`: the type of the probability distribution, such that calling `D(θ...)` generates a sampleable object corresponding to `p(θ)` | ||
- `threaded::Bool`: specifies whether the sampling should be performed in parallel (with OhMyThreads.jl) | ||
# Required fields | ||
- `f`: the function applied inside the expectation | ||
- `rng`: the random number generator | ||
- `nb_samples`: the number of Monte-Carlo samples | ||
- `f`: the function applied inside the expectation | ||
- `dist_constructor`: the constructor of the probability distribution, such that calling `dist_constructor(θ...)` generates an object corresponding to `p(θ)` | ||
- `rng`: the random number generator | ||
- `nb_samples`: the number of Monte-Carlo samples | ||
""" | ||
abstract type DifferentiableExpectation{threaded,D} end | ||
abstract type DifferentiableExpectation{threaded} end | ||
|
||
""" | ||
distribution(F::DifferentiableExpectation, θ...) | ||
presamples(F::DifferentiableExpectation, θ...) | ||
Create a sampleable object `p(θ)`. | ||
Return a vector `[x₁, ..., xₛ]` or a matrix `[x₁ ... xₛ]` (one sample per column) where the `xᵢ ∼ p(θ)` are iid samples. | ||
""" | ||
function distribution(::DifferentiableExpectation{threaded,D}, θ...) where {threaded,D} | ||
return D(θ...) | ||
function presamples(F::DifferentiableExpectation, θ...) | ||
(; dist_constructor, rng, nb_samples) = F | ||
dist = dist_constructor(θ...) | ||
xs = rand(rng, dist, nb_samples) # TODO: parallelize? | ||
return xs | ||
end | ||
|
||
""" | ||
(F::DifferentiableExpectation)(θ...) | ||
samples(F::DifferentiableExpectation, θ...; kwargs...) | ||
Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples. | ||
Return a vector `[f(x₁), ..., f(xₛ)]` where the `xᵢ ∼ p(θ)` are iid samples. | ||
""" | ||
function (F::DifferentiableExpectation{threaded})(θ...) where {threaded} | ||
dist = distribution(F, θ...) | ||
_sample(_) = F.f(rand(F.rng, dist)) | ||
y = if threaded | ||
tmapmean(_sample, 1:(F.nb_samples)) | ||
function samples(F::DifferentiableExpectation{threaded}, θ...; kwargs...) where {threaded} | ||
xs = presamples(F, θ...) | ||
return samples_from_presamples(F, xs; kwargs...) | ||
end | ||
|
||
function samples_from_presamples( | ||
F::DifferentiableExpectation{threaded}, xs::AbstractVector; kwargs... | ||
) where {threaded} | ||
(; f) = F | ||
fk = FixKwargs(f, kwargs) | ||
if threaded | ||
return tmap(fk, xs) | ||
else | ||
mean(_sample, 1:(F.nb_samples)) | ||
return map(fk, xs) | ||
end | ||
return y | ||
end | ||
|
||
""" | ||
pre_samples(F::DifferentiableExpectation, θ...) | ||
Return a vector `[x₁, ..., xₛ]` where the `xᵢ ∼ p(θ)` are iid samples. | ||
""" | ||
function pre_samples(F::DifferentiableExpectation{threaded}, θ...) where {threaded} | ||
dist = distribution(F, θ...) | ||
_pre_sample(_) = rand(F.rng, dist) | ||
xs = if threaded | ||
tmap(_pre_sample, 1:(F.nb_samples)) | ||
function samples_from_presamples( | ||
F::DifferentiableExpectation{threaded}, xs::AbstractMatrix; kwargs... | ||
) where {threaded} | ||
(; f) = F | ||
fk = FixKwargs(f, kwargs) | ||
if threaded | ||
return tmap(fk, eachcol(xs)) | ||
else | ||
map(_pre_sample, 1:(F.nb_samples)) | ||
return map(fk, eachcol(xs)) | ||
end | ||
return xs | ||
end | ||
|
||
""" | ||
samples(F::DifferentiableExpectation, θ...) | ||
(F::DifferentiableExpectation)(θ...; kwargs...) | ||
Return a vector `[f(x₁), ..., f(xₛ)]` where the `xᵢ ∼ p(θ)` are iid samples. | ||
Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples. | ||
""" | ||
function samples(F::DifferentiableExpectation{threaded}, θ...) where {threaded} | ||
dist = distribution(F, θ...) | ||
_sample(_) = F.f(rand(F.rng, dist)) | ||
ys = if threaded | ||
map(_sample, 1:(F.nb_samples)) # TODO: tmap fails here | ||
function (F::DifferentiableExpectation{threaded})(θ...; kwargs...) where {threaded} | ||
ys = samples(F, θ...; kwargs...) | ||
y = if threaded | ||
tmean(ys) | ||
else | ||
map(_sample, 1:(F.nb_samples)) | ||
mean(ys) | ||
end | ||
return ys | ||
return y | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,60 @@ | ||
using Distributions | ||
using DifferentiableExpectations | ||
using DifferentiableExpectations: samples | ||
using LinearAlgebra | ||
using Random | ||
using StableRNGs | ||
using Statistics | ||
using Test | ||
using Zygote | ||
|
||
exp_with_kwargs(x; correct=false) = correct ? exp(x) : sin(x) | ||
vec_exp_with_kwargs(x; correct=false) = exp_with_kwargs.(x; correct) | ||
|
||
@testset "Univariate LogNormal" begin | ||
for threaded in (false, true) | ||
F = REINFORCE(Normal, exp; rng=StableRNG(63), nb_samples=10^5, threaded=threaded) | ||
F = REINFORCE( | ||
exp_with_kwargs, Normal; rng=StableRNG(63), nb_samples=10^4, threaded=threaded | ||
) | ||
|
||
μ, σ = 2.0, 1.0 | ||
true_mean(μ, σ) = mean(LogNormal(μ, σ)) | ||
true_std(μ, σ) = std(LogNormal(μ, σ)) | ||
|
||
@test distribution(F, μ, σ) == Normal(μ, σ) | ||
@test F(μ, σ) ≈ true_mean(μ, σ) rtol = 0.1 | ||
@test std(samples(F, μ, σ)) ≈ true_std(μ, σ) rtol = 0.1 | ||
@test F.dist_constructor(μ, σ) == Normal(μ, σ) | ||
@test F(μ, σ; correct=true) ≈ true_mean(μ, σ) rtol = 0.1 | ||
@test std(samples(F, μ, σ; correct=true)) ≈ true_std(μ, σ) rtol = 0.1 | ||
|
||
∇mean_est = gradient(F, μ, σ) | ||
∇mean_est = gradient((μ, σ) -> F(μ, σ; correct=true), μ, σ) | ||
∇mean_true = gradient(true_mean, μ, σ) | ||
|
||
∇std_est = gradient((_μ, _σ) -> std(samples(F, _μ, _σ)), μ, σ) | ||
∇std_true = gradient(true_std, μ, σ) | ||
@test ∇mean_est[1] ≈ ∇mean_true[1] rtol = 0.2 | ||
@test ∇mean_est[2] ≈ ∇mean_true[2] rtol = 0.2 | ||
end | ||
end | ||
|
||
@testset "Multivariate LogNormal" begin | ||
for threaded in (false, true) | ||
F = REINFORCE( | ||
vec_exp_with_kwargs, | ||
(μ, σ) -> MvNormal(μ, Diagonal(σ .^ 2)); | ||
rng=StableRNG(63), | ||
nb_samples=10^4, | ||
threaded=threaded, | ||
) | ||
|
||
dim = 2 | ||
μ, σ = randn(dim), rand(dim) | ||
true_mean(μ, σ) = mean.(LogNormal.(μ, σ)) | ||
true_std(μ, σ) = std.(LogNormal.(μ, σ)) | ||
|
||
@test F.dist_constructor(μ, σ) == MvNormal(μ, Diagonal(σ .^ 2)) | ||
@test F(μ, σ; correct=true) ≈ true_mean(μ, σ) rtol = 0.1 | ||
|
||
∂mean_est = jacobian((μ, σ) -> F(μ, σ; correct=true), μ, σ) | ||
∂mean_true = jacobian(true_mean, μ, σ) | ||
|
||
for i in 1:2 | ||
@test ∇mean_est[i] ≈ ∇mean_true[i] rtol = 0.2 | ||
@test ∇std_est[i] ≈ ∇std_true[i] rtol = 0.2 | ||
end | ||
@test ∂mean_est[1] ≈ ∂mean_true[1] rtol = 0.2 | ||
@test ∂mean_est[2] ≈ ∂mean_true[2] rtol = 0.2 | ||
end | ||
end |