diff --git a/docs/src/api.md b/docs/src/api.md index 2428148..908cbee 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,5 +1,9 @@ # API reference +```@meta +CollapsedDocStrings = true +``` + ## Public ```@autodocs diff --git a/src/abstract.jl b/src/abstract.jl index 465f9cc..885312b 100644 --- a/src/abstract.jl +++ b/src/abstract.jl @@ -3,6 +3,17 @@ Abstract supertype for differentiable parametric expectations `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`, whose value and derivative are approximated with Monte-Carlo averages. +# Subtypes + + - [`Reinforce`](@ref) + - [`Reparametrization`](@ref) + +# Calling behavior + + (F::DifferentiableExpectation)(θ...; kwargs...) + +Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples. + # Type parameters - `threaded::Bool`: specifies whether the sampling should be performed in parallel @@ -68,11 +79,6 @@ function samples_from_presamples( end end -""" - (F::DifferentiableExpectation)(θ...; kwargs...) - -Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples. -""" function (F::DifferentiableExpectation{threaded})(θ...; kwargs...) where {threaded} ys = samples(F, θ...; kwargs...) y = if threaded diff --git a/src/reparametrization.jl b/src/reparametrization.jl index 8e8cbf8..4a263d6 100644 --- a/src/reparametrization.jl +++ b/src/reparametrization.jl @@ -14,6 +14,16 @@ struct TransformedDistribution{D,T} transformation::T end +""" + rand(rng, dist::TransformedDistribution) + +Sample from `dist` by applying `dist.transformation` to `dist.base_dist`. +""" +function Random.rand(rng::AbstractRNG, dist::TransformedDistribution) + (; base_dist, transformation) = dist + return transformation(rand(rng, base_dist)) +end + """ reparametrize(dist) @@ -42,7 +52,7 @@ Differentiable parametric expectation `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)` ```jldoctest using DifferentiableExpectations, Distributions, Zygote -F = Reparametrization(exp, Normal; nb_samples=10^3) +F = Reparametrization(exp, Normal; nb_samples=10^4) F_true(μ, σ) = mean(LogNormal(μ, σ)) μ, σ = 0.5, 1,0 diff --git a/test/expectation.jl b/test/expectation.jl index 49fa817..22f36b2 100644 --- a/test/expectation.jl +++ b/test/expectation.jl @@ -48,8 +48,8 @@ true_std(μ, σ) = std(LogNormal(μ, σ)) end end; -@testset "Multivariate LogNormal" begin - @testset "Threaded: $threaded" for threaded in (false, true) +@testset verbose = true "Multivariate LogNormal" begin + @testset verbose = true "Threaded: $threaded" for threaded in (false, true) @testset "$(nameof(typeof(F)))" for F in [ Reinforce( vec_exp_with_kwargs, diff --git a/test/reparametrization.jl b/test/reparametrization.jl new file mode 100644 index 0000000..1c9f055 --- /dev/null +++ b/test/reparametrization.jl @@ -0,0 +1,13 @@ +using DifferentiableExpectations: reparametrize +using Distributions +using StableRNGs +using Test + +rng = StableRNG(63) + +@testset "Univariate Normal" begin + dist = Normal(2.0, 1.0) + transformed_dist = reparametrize(dist) + @test mean([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ mean(dist) rtol = 1e-1 + @test std([rand(rng, transformed_dist) for _ in 1:(10^4)]) ≈ std(dist) rtol = 1e-1 +end diff --git a/test/runtests.jl b/test/runtests.jl index d47c249..029ca4b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,10 +27,13 @@ using Zygote Documenter.doctest(DifferentiableExpectations) end end - @testset verbose = true "Expectation" begin - include("expectation.jl") - end @testset "Distribution" begin include("distribution.jl") end + @testset verbose = true "Reparametrization" begin + include("reparametrization.jl") + end + @testset verbose = true "Expectation" begin + include("expectation.jl") + end end