Commit: More tests
gdalle committed Jun 10, 2024
1 parent a517845 commit 4918648
Showing 10 changed files with 60 additions and 68 deletions.
3 changes: 1 addition & 2 deletions .JuliaFormatter.toml
@@ -1,3 +1,2 @@
 style = "blue"
-format_docstrings = true
-format_markdown = true
+format_docstrings = true
4 changes: 2 additions & 2 deletions README.md
@@ -15,7 +15,7 @@ F(\theta) = \mathbb{E}_{p(\theta)}[f(X)]
 
 The following estimators are implemented:
 
-- [REINFORCE](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.20)
-- [Reparametrization](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.56)
+- [REINFORCE](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.20)
+- [Reparametrization](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.56)
 
 > Warning: this package is experimental, use at your own risk and expect frequent breaking releases.
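For orientation, a minimal usage sketch pieced together from this commit's tests: the `Reinforce(f, dist_constructor; rng, nb_samples)` signature and differentiation via Zygote are assumptions inferred from `test/expectation.jl`, not documented API.

```julia
using DifferentiableExpectations, Distributions, StableRNGs, Zygote

# F(μ, σ) estimates E[exp(X)] with X ~ Normal(μ, σ), i.e. mean(LogNormal(μ, σ))
F = Reinforce(exp, Normal; rng=StableRNG(63), nb_samples=10^4)
F(0.5, 1.0)            # Monte Carlo estimate of the expectation
gradient(F, 0.5, 1.0)  # REINFORCE estimate of (∂F/∂μ, ∂F/∂σ)
```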
8 changes: 4 additions & 4 deletions docs/src/background.md
@@ -15,7 +15,7 @@ Most of the math below is taken from [mohamedMonteCarloGradient2020](@citet).
 
 ## REINFORCE
 
-### Basics
+### Principle
 
 The REINFORCE estimator is derived with the help of the identity ``\nabla \log u = \nabla u / u``:

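The identity leads to the estimator ``\nabla_\theta \mathbb{E}[f(X)] = \mathbb{E}[f(X) \, \nabla_\theta \log p(X; \theta)]``. A standalone sketch for a Normal location parameter (plain Julia, not this package's API; `ForwardDiff` is used here only as an illustrative way to get the score):

```julia
using Distributions, ForwardDiff, Statistics

f(x) = exp(x)
μ, σ = 0.5, 1.0
xs = rand(Normal(μ, σ), 10^5)
# REINFORCE: weight each sample by the score ∇μ log p(x; μ, σ)
score(x) = ForwardDiff.derivative(m -> logpdf(Normal(m, σ), x), μ)
grad_est = mean(f(x) * score(x) for x in xs)
# Compare with the exact derivative d/dμ mean(LogNormal(μ, σ))
grad_exact = exp(μ + σ^2 / 2)
```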
@@ -45,12 +45,11 @@ And the vector-Jacobian product:
 ### Variance reduction
 
 !!! warning
-
     Work in progress.
 
 ## Reparametrization
 
-### Basics
+### Trick
 
 The reparametrization trick assumes that we can rewrite the random variable ``X \sim p(\theta)`` as ``X = g(Z, \theta)``, where ``Z \sim q`` is another random variable whose distribution does not depend on ``\theta``.

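A matching standalone sketch of the pathwise estimator ``\nabla_\theta \mathbb{E}[f(g(Z, \theta))] = \mathbb{E}[\nabla_\theta f(g(Z, \theta))]`` for the same Normal example (again illustrative, not this package's API):

```julia
using ForwardDiff, Statistics

f(x) = exp(x)
μ, σ = 0.5, 1.0
zs = randn(10^5)  # Z ~ N(0, 1), independent of (μ, σ)
# Reparametrization: X = μ + σZ, so we differentiate through the sample path
grad_est = mean(ForwardDiff.derivative(m -> f(m + σ * z), μ) for z in zs)
```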
@@ -80,7 +79,8 @@ And the vector-Jacobian product:
 
 The following reparametrizations are implemented:
 
-- Univariate Gaussian: ``X \sim \mathcal{N}(\mu, \sigma^2)`` is equivalent to ``X = \mu + \sigma Z`` with ``Z \sim \mathcal{N}(0, 1)``.
+- Univariate Normal: ``X \sim \mathcal{N}(\mu, \sigma^2)`` is equivalent to ``X = \mu + \sigma Z`` with ``Z \sim \mathcal{N}(0, 1)``.
+- Multivariate Normal: ``X \sim \mathcal{N}(\mu, \Sigma)`` is equivalent to ``X = \mu + L Z`` with ``Z \sim \mathcal{N}(0, I)`` and ``L L^\top = \Sigma``. The matrix ``L`` can be obtained by Cholesky decomposition of ``\Sigma``.
 
 ## Bibliography
 
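The multivariate Cholesky reparametrization added above can be sanity-checked empirically; a standalone sketch in plain Julia:

```julia
using LinearAlgebra, Statistics

μ, Σ = [1.0, 2.0], [2.0 0.3; 0.3 1.0]
L = cholesky(Σ).L
# Draw X = μ + LZ with Z ~ N(0, I); columns of xs are samples
xs = reduce(hcat, [μ + L * randn(2) for _ in 1:10^5])
# The sample covariance should approximate Σ up to Monte Carlo error
isapprox(cov(xs; dims=2), Σ; rtol=0.05)
```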
6 changes: 3 additions & 3 deletions src/DifferentiableExpectations.jl
@@ -20,12 +20,12 @@ using ChainRulesCore:
     rrule_via_ad,
     unthunk
 using DensityInterface: logdensityof
-using Distributions: Distribution, Normal
+using Distributions: Distribution, MvNormal, Normal
 using DocStringExtensions
-using LinearAlgebra: dot
+using LinearAlgebra: Diagonal, cholesky, dot
 using OhMyThreads: tmap, treduce, tmapreduce
 using Random: Random, AbstractRNG, default_rng
-using Statistics: Statistics, mean, std
+using Statistics: Statistics, cov, mean, std
 using StatsBase: StatsBase
 
 include("utils.jl")
14 changes: 1 addition & 13 deletions src/abstract.jl
@@ -41,7 +41,7 @@ Return a vector `[x₁, ..., xₛ]` or matrix `[x₁ ... xₛ]` where the `xᵢ
 function presamples(F::DifferentiableExpectation, θ...)
     (; dist_constructor, rng, nb_samples) = F
     dist = dist_constructor(θ...)
-    xs = rand(rng, dist, nb_samples)  # TODO: parallelize?
+    xs = maybe_eachcol(rand(rng, dist, nb_samples))
     return xs
 end
 
@@ -67,18 +67,6 @@ function samples_from_presamples(
     end
 end
 
-function samples_from_presamples(
-    F::DifferentiableExpectation{threaded}, xs::AbstractMatrix; kwargs...
-) where {threaded}
-    (; f) = F
-    fk = FixKwargs(f, kwargs)
-    if threaded
-        return tmap(fk, eachcol(xs))
-    else
-        return map(fk, eachcol(xs))
-    end
-end
-
 function (F::DifferentiableExpectation{threaded})(θ...; kwargs...) where {threaded}
     ys = samples(F, θ...; kwargs...)
     y = if threaded
37 changes: 6 additions & 31 deletions src/reinforce.jl
@@ -90,36 +90,6 @@ function dist_logdensity_grad(
     return
 end
 
-function logdensity_grads_from_presamples(
-    rc::RuleConfig,
-    F::DifferentiableExpectation{threaded},
-    xs::AbstractVector,
-    θ...;
-    kwargs...,
-) where {threaded}
-    _dist_logdensity_grad_partial(x) = dist_logdensity_grad(rc, F, x, θ...)
-    if threaded
-        return tmap(_dist_logdensity_grad_partial, xs)
-    else
-        return map(_dist_logdensity_grad_partial, xs)
-    end
-end
-
-function logdensity_grads_from_presamples(
-    rc::RuleConfig,
-    F::DifferentiableExpectation{threaded},
-    xs::AbstractMatrix,
-    θ...;
-    kwargs...,
-) where {threaded}
-    _dist_logdensity_grad_partial(x) = dist_logdensity_grad(rc, F, x, θ...)
-    if threaded
-        return tmap(_dist_logdensity_grad_partial, eachcol(xs))
-    else
-        return map(_dist_logdensity_grad_partial, eachcol(xs))
-    end
-end
-
 function ChainRulesCore.rrule(
     rc::RuleConfig, F::Reinforce{threaded}, θ...; kwargs...
 ) where {threaded}
@@ -129,7 +99,12 @@ function ChainRulesCore.rrule(
     xs = presamples(F, θ...)
     ys = samples_from_presamples(F, xs; kwargs...)
 
-    gs = logdensity_grads_from_presamples(rc, F, xs, θ...)
+    _dist_logdensity_grad_partial(x) = dist_logdensity_grad(rc, F, x, θ...)
+    gs = if threaded
+        tmap(_dist_logdensity_grad_partial, xs)
+    else
+        map(_dist_logdensity_grad_partial, xs)
+    end
 
     function pullback_Reinforce(dy_thunked)
         dy = unthunk(dy_thunked)
17 changes: 15 additions & 2 deletions src/reparametrization.jl
@@ -39,6 +39,15 @@ function reparametrize(dist::Normal{T}) where {T}
     return TransformedDistribution(base_dist, transformation)
 end
 
+function reparametrize(dist::MvNormal{T}) where {T}
+    n = length(dist)
+    base_dist = MvNormal(fill(zero(T), n), Diagonal(fill(one(T), n)))
+    μ, Σ = mean(dist), cov(dist)
+    C = cholesky(Σ)
+    transformation(z) = μ .+ C.L * z
+    return TransformedDistribution(base_dist, transformation)
+end
+
 """
     Reparametrization{threaded} <: DifferentiableExpectation{threaded}
@@ -114,8 +123,12 @@ function ChainRulesCore.rrule(
     (; f, dist_constructor, rng, nb_samples) = F
     dist = dist_constructor(θ...)
     transformed_dist = reparametrize(dist)
-    zs = rand(rng, transformed_dist.base_dist, nb_samples)
-    xs = transformed_dist.transformation.(zs)
+    zs = maybe_eachcol(rand(rng, transformed_dist.base_dist, nb_samples))
+    xs = if threaded
+        tmap(transformed_dist.transformation, zs)
+    else
+        map(transformed_dist.transformation, zs)
+    end
     ys = samples_from_presamples(F, xs; kwargs...)
 
     function h(z, θ)
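A hedged sketch of how the new `MvNormal` method might be exercised, using the field names visible in this diff (`base_dist`, `transformation`); whether `reparametrize` is exported is an assumption, the new tests below call it directly:

```julia
using DifferentiableExpectations: reparametrize  # export status assumed
using Distributions

dist = MvNormal([2.0, 3.0], [2.0 0.01; 0.01 1.0])
td = reparametrize(dist)
z = rand(td.base_dist)    # z ~ N(0, I)
x = td.transformation(z)  # x = μ + Lz with LL' = cov(dist)
```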
3 changes: 3 additions & 0 deletions src/utils.jl
@@ -10,3 +10,6 @@ end
 function tmean(args...)
     return treduce(+, args...) / length(first(args))
 end
+
+maybe_eachcol(x::AbstractVector) = x
+maybe_eachcol(x::AbstractMatrix) = eachcol(x)
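The point of `maybe_eachcol` is that `rand(rng, dist, n)` returns a `Vector` for univariate distributions but a `Matrix` with one sample per column for multivariate ones; the helper normalizes both to an iterable of samples, which is what lets the matrix-specific methods above be deleted. A standalone illustration (only Distributions.jl behavior is assumed):

```julia
using Distributions, LinearAlgebra

maybe_eachcol(x::AbstractVector) = x
maybe_eachcol(x::AbstractMatrix) = eachcol(x)

maybe_eachcol(rand(Normal(), 3))                               # 3-element Vector, returned as-is
maybe_eachcol(rand(MvNormal(zeros(2), Diagonal(ones(2))), 3))  # 2×3 Matrix → 3 column views
```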
29 changes: 18 additions & 11 deletions test/expectation.jl
@@ -13,13 +13,13 @@ vec_exp_with_kwargs(x; correct=false) = exp_with_kwargs.(x; correct)
 
 normal_logdensity_grad(x, θ...) = gradient((_θ...) -> logpdf(Normal(_θ...), x), θ...)
 
-μ, σ = 0.5, 1.0
-true_mean(μ, σ) = mean(LogNormal(μ, σ))
-true_std(μ, σ) = std(LogNormal(μ, σ))
-∇mean_true = gradient(true_mean, μ, σ)
-
 @testset verbose = true "Univariate LogNormal" begin
-    @testset verbose = true "Threaded: $threaded" for threaded in (false, true)
+    μ, σ = 0.5, 1.0
+    true_mean(μ, σ) = mean(LogNormal(μ, σ))
+    true_std(μ, σ) = std(LogNormal(μ, σ))
+    ∇mean_true = gradient(true_mean, μ, σ)
+
+    @testset verbose = true "Threaded: $threaded" for threaded in (false,) #
         @testset "$(nameof(typeof(F)))" for F in [
             Reinforce(
                 exp_with_kwargs,
@@ -59,6 +59,11 @@ true_std(μ, σ) = std(LogNormal(μ, σ))
 end;
 
 @testset verbose = true "Multivariate LogNormal" begin
+    μ, σ = [2.0, 3.0], [1.0, 0.5]
+    true_mean(μ, σ) = mean.(LogNormal.(μ, σ))
+    true_std(μ, σ) = std.(LogNormal.(μ, σ))
+    ∂mean_true = jacobian(true_mean, μ, σ)
+
     @testset verbose = true "Threaded: $threaded" for threaded in (false, true)
         @testset "$(nameof(typeof(F)))" for F in [
             Reinforce(
@@ -68,16 +73,18 @@ end;
                 nb_samples=10^5,
                 threaded=threaded,
             ),
+            # Reparametrization(
+            #     vec_exp_with_kwargs,
+            #     (μ, σ) -> MvNormal(μ, Diagonal(σ .^ 2));
+            #     rng=StableRNG(63),
+            #     nb_samples=10^5,
+            #     threaded=threaded,
+            # ),
         ]
-            μ, σ = [2.0, 3.0], [1.0, 0.5]
-            true_mean(μ, σ) = mean.(LogNormal.(μ, σ))
-            true_std(μ, σ) = std.(LogNormal.(μ, σ))
 
             @test F.dist_constructor(μ, σ) == MvNormal(μ, Diagonal(σ .^ 2))
             @test F(μ, σ; correct=true) ≈ true_mean(μ, σ) rtol = 0.1
 
             ∂mean_est = jacobian((μ, σ) -> F(μ, σ; correct=true), μ, σ)
-            ∂mean_true = jacobian(true_mean, μ, σ)
-
             @test ∂mean_est[1] ≈ ∂mean_true[1] rtol = 0.1
             @test ∂mean_est[2] ≈ ∂mean_true[2] rtol = 0.1
7 changes: 7 additions & 0 deletions test/reparametrization.jl
@@ -11,3 +11,10 @@ rng = StableRNG(63)
     @test mean([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ mean(dist) rtol = 1e-1
     @test std([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ std(dist) rtol = 1e-1
 end
+
+@testset "Multivariate Normal" begin
+    dist = MvNormal([2.0, 3.0], [2.0 0.01; 0.01 1.0])
+    transformed_dist = reparametrize(dist)
+    @test mean([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ mean(dist) rtol = 1e-1
+    @test cov([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ cov(dist) rtol = 1e-1
+end
