diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 74e08f7..a2fd94d 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,3 +1,2 @@ style = "blue" -format_docstrings = true -format_markdown = true \ No newline at end of file +format_docstrings = true \ No newline at end of file diff --git a/README.md b/README.md index 30d8848..99cf86d 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ F(\theta) = \mathbb{E}_{p(\theta)}[f(X)] The following estimators are implemented: - - [REINFORCE](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.20) - - [Reparametrization](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.56) +- [REINFORCE](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.20) +- [Reparametrization](https://jmlr.org/papers/volume21/19-346/19-346.pdf#section.56) > Warning: this package is experimental, use at your own risk and expect frequent breaking releases. diff --git a/docs/src/background.md b/docs/src/background.md index bbaa240..5d12cc9 100644 --- a/docs/src/background.md +++ b/docs/src/background.md @@ -15,7 +15,7 @@ Most of the math below is taken from [mohamedMonteCarloGradient2020](@citet). ## REINFORCE -### Basics +### Principle The REINFORCE estimator is derived with the help of the identity ``\nabla \log u = \nabla u / u``: @@ -45,12 +45,11 @@ And the vector-Jacobian product: ### Variance reduction !!! warning - Work in progress. ## Reparametrization -### Basics +### Trick The reparametrization trick assumes that we can rewrite the random variable ``X \sim p(\theta)`` as ``X = g(Z, \theta)``, where ``Z \sim q`` is another random variable whose distribution does not depend on ``\theta``. @@ -80,7 +79,8 @@ And the vector-Jacobian product: The following reparametrizations are implemented: - - Univariate Gaussian: ``X \sim \mathcal{N}(\mu, \sigma^2)`` is equivalent to ``X = \mu + \sigma Z`` with ``Z \sim \mathcal{N}(0, 1)``. +- Univariate Normal: ``X \sim \mathcal{N}(\mu, \sigma^2)`` is equivalent to ``X = \mu + \sigma Z`` with ``Z \sim \mathcal{N}(0, 1)``. +- Multivariate Normal: ``X \sim \mathcal{N}(\mu, \Sigma)`` is equivalent to ``X = \mu + L Z`` with ``Z \sim \mathcal{N}(0, I)`` and ``L L^\top = \Sigma``. The matrix ``L`` can be obtained by Cholesky decomposition of ``\Sigma``. ## Bibliography diff --git a/src/DifferentiableExpectations.jl b/src/DifferentiableExpectations.jl index be8c7dd..6ac5e2b 100644 --- a/src/DifferentiableExpectations.jl +++ b/src/DifferentiableExpectations.jl @@ -20,12 +20,12 @@ using ChainRulesCore: rrule_via_ad, unthunk using DensityInterface: logdensityof -using Distributions: Distribution, Normal +using Distributions: Distribution, MvNormal, Normal using DocStringExtensions -using LinearAlgebra: dot +using LinearAlgebra: Diagonal, cholesky, dot using OhMyThreads: tmap, treduce, tmapreduce using Random: Random, AbstractRNG, default_rng -using Statistics: Statistics, mean, std +using Statistics: Statistics, cov, mean, std using StatsBase: StatsBase include("utils.jl") diff --git a/src/abstract.jl b/src/abstract.jl index 885312b..9f96a01 100644 --- a/src/abstract.jl +++ b/src/abstract.jl @@ -41,7 +41,7 @@ Return a vector `[x₁, ..., xₛ]` or matrix `[x₁ ... xₛ]` where the `xᵢ function presamples(F::DifferentiableExpectation, θ...) (; dist_constructor, rng, nb_samples) = F dist = dist_constructor(θ...) - xs = rand(rng, dist, nb_samples) # TODO: parallelize? + xs = maybe_eachcol(rand(rng, dist, nb_samples)) return xs end @@ -67,18 +67,6 @@ function samples_from_presamples( end end -function samples_from_presamples( - F::DifferentiableExpectation{threaded}, xs::AbstractMatrix; kwargs... -) where {threaded} - (; f) = F - fk = FixKwargs(f, kwargs) - if threaded - return tmap(fk, eachcol(xs)) - else - return map(fk, eachcol(xs)) - end -end - function (F::DifferentiableExpectation{threaded})(θ...; kwargs...) where {threaded} ys = samples(F, θ...; kwargs...) y = if threaded diff --git a/src/reinforce.jl b/src/reinforce.jl index 4df4f3e..62fe7ca 100644 --- a/src/reinforce.jl +++ b/src/reinforce.jl @@ -90,36 +90,6 @@ function dist_logdensity_grad( return dθ end -function logdensity_grads_from_presamples( - rc::RuleConfig, - F::DifferentiableExpectation{threaded}, - xs::AbstractVector, - θ...; - kwargs..., -) where {threaded} - _dist_logdensity_grad_partial(x) = dist_logdensity_grad(rc, F, x, θ...) - if threaded - return tmap(_dist_logdensity_grad_partial, xs) - else - return map(_dist_logdensity_grad_partial, xs) - end -end - -function logdensity_grads_from_presamples( - rc::RuleConfig, - F::DifferentiableExpectation{threaded}, - xs::AbstractMatrix, - θ...; - kwargs..., -) where {threaded} - _dist_logdensity_grad_partial(x) = dist_logdensity_grad(rc, F, x, θ...) - if threaded - return tmap(_dist_logdensity_grad_partial, eachcol(xs)) - else - return map(_dist_logdensity_grad_partial, eachcol(xs)) - end -end - function ChainRulesCore.rrule( rc::RuleConfig, F::Reinforce{threaded}, θ...; kwargs... ) where {threaded} @@ -129,7 +99,12 @@ function ChainRulesCore.rrule( xs = presamples(F, θ...) ys = samples_from_presamples(F, xs; kwargs...) - gs = logdensity_grads_from_presamples(rc, F, xs, θ...) + _dist_logdensity_grad_partial(x) = dist_logdensity_grad(rc, F, x, θ...) + gs = if threaded + tmap(_dist_logdensity_grad_partial, xs) + else + map(_dist_logdensity_grad_partial, xs) + end function pullback_Reinforce(dy_thunked) dy = unthunk(dy_thunked) diff --git a/src/reparametrization.jl b/src/reparametrization.jl index 4a263d6..c6c9bbf 100644 --- a/src/reparametrization.jl +++ b/src/reparametrization.jl @@ -39,6 +39,15 @@ function reparametrize(dist::Normal{T}) where {T} return TransformedDistribution(base_dist, transformation) end +function reparametrize(dist::MvNormal{T}) where {T} + n = length(dist) + base_dist = MvNormal(fill(zero(T), n), Diagonal(fill(one(T), n))) + μ, Σ = mean(dist), cov(dist) + C = cholesky(Σ) + transformation(z) = μ .+ C.L * z + return TransformedDistribution(base_dist, transformation) +end + """ Reparametrization{threaded} <: DifferentiableExpectation{threaded} @@ -114,8 +123,12 @@ function ChainRulesCore.rrule( (; f, dist_constructor, rng, nb_samples) = F dist = dist_constructor(θ...) transformed_dist = reparametrize(dist) - zs = rand(rng, transformed_dist.base_dist, nb_samples) - xs = transformed_dist.transformation.(zs) + zs = maybe_eachcol(rand(rng, transformed_dist.base_dist, nb_samples)) + xs = if threaded + tmap(transformed_dist.transformation, zs) + else + map(transformed_dist.transformation, zs) + end ys = samples_from_presamples(F, xs; kwargs...) function h(z, θ) diff --git a/src/utils.jl b/src/utils.jl index bf5a590..868475c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,3 +10,6 @@ end function tmean(args...) return treduce(+, args...) / length(first(args)) end + +maybe_eachcol(x::AbstractVector) = x +maybe_eachcol(x::AbstractMatrix) = eachcol(x) diff --git a/test/expectation.jl b/test/expectation.jl index c76190e..06cb10f 100644 --- a/test/expectation.jl +++ b/test/expectation.jl @@ -13,13 +13,13 @@ vec_exp_with_kwargs(x; correct=false) = exp_with_kwargs.(x; correct) normal_logdensity_grad(x, θ...) = gradient((_θ...) -> logpdf(Normal(_θ...), x), θ...) -μ, σ = 0.5, 1.0 -true_mean(μ, σ) = mean(LogNormal(μ, σ)) -true_std(μ, σ) = std(LogNormal(μ, σ)) -∇mean_true = gradient(true_mean, μ, σ) - @testset verbose = true "Univariate LogNormal" begin - @testset verbose = true "Threaded: $threaded" for threaded in (false, true) + μ, σ = 0.5, 1.0 + true_mean(μ, σ) = mean(LogNormal(μ, σ)) + true_std(μ, σ) = std(LogNormal(μ, σ)) + ∇mean_true = gradient(true_mean, μ, σ) + + @testset verbose = true "Threaded: $threaded" for threaded in (false,) # @testset "$(nameof(typeof(F)))" for F in [ Reinforce( exp_with_kwargs, @@ -59,6 +59,11 @@ true_std(μ, σ) = std(LogNormal(μ, σ)) end; @testset verbose = true "Multivariate LogNormal" begin + μ, σ = [2.0, 3.0], [1.0, 0.5] + true_mean(μ, σ) = mean.(LogNormal.(μ, σ)) + true_std(μ, σ) = std.(LogNormal.(μ, σ)) + ∂mean_true = jacobian(true_mean, μ, σ) + @testset verbose = true "Threaded: $threaded" for threaded in (false, true) @testset "$(nameof(typeof(F)))" for F in [ Reinforce( @@ -68,16 +73,18 @@ end; nb_samples=10^5, threaded=threaded, ), + # Reparametrization( + # vec_exp_with_kwargs, + # (μ, σ) -> MvNormal(μ, Diagonal(σ .^ 2)); + # rng=StableRNG(63), + # nb_samples=10^5, + # threaded=threaded, + # ), ] - μ, σ = [2.0, 3.0], [1.0, 0.5] - true_mean(μ, σ) = mean.(LogNormal.(μ, σ)) - true_std(μ, σ) = std.(LogNormal.(μ, σ)) - @test F.dist_constructor(μ, σ) == MvNormal(μ, Diagonal(σ .^ 2)) @test F(μ, σ; correct=true) ≈ true_mean(μ, σ) rtol = 0.1 ∂mean_est = jacobian((μ, σ) -> F(μ, σ; correct=true), μ, σ) - ∂mean_true = jacobian(true_mean, μ, σ) @test ∂mean_est[1] ≈ ∂mean_true[1] rtol = 0.1 @test ∂mean_est[2] ≈ ∂mean_true[2] rtol = 0.1 diff --git a/test/reparametrization.jl b/test/reparametrization.jl index 1c9f055..6a125d4 100644 --- a/test/reparametrization.jl +++ b/test/reparametrization.jl @@ -11,3 +11,10 @@ rng = StableRNG(63) @test mean([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ mean(dist) rtol = 1e-1 @test std([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ std(dist) rtol = 1e-1 end + +@testset "Multivariate Normal" begin + dist = MvNormal([2.0, 3.0], [2.0 0.01; 0.01 1.0]) + transformed_dist = reparametrize(dist) + @test mean([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ mean(dist) rtol = 1e-1 + @test cov([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ cov(dist) rtol = 1e-1 +end