Skip to content

Commit

Permalink
More docs
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jun 10, 2024
1 parent c42cb6b commit 1054068
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 11 deletions.
4 changes: 4 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# API reference

```@meta
CollapsedDocStrings = true
```

## Public

```@autodocs
Expand Down
16 changes: 11 additions & 5 deletions src/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
Abstract supertype for differentiable parametric expectations `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`, whose value and derivative are approximated with Monte-Carlo averages.
# Subtypes
- [`Reinforce`](@ref)
- [`Reparametrization`](@ref)
# Calling behavior
(F::DifferentiableExpectation)(θ...; kwargs...)
Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples.
# Type parameters
- `threaded::Bool`: specifies whether the sampling should be performed in parallel
Expand Down Expand Up @@ -68,11 +79,6 @@ function samples_from_presamples(
end
end

"""
(F::DifferentiableExpectation)(θ...; kwargs...)
Return a Monte-Carlo average `(1/s) ∑f(xᵢ)` where the `xᵢ ∼ p(θ)` are iid samples.
"""
function (F::DifferentiableExpectation{threaded})(θ...; kwargs...) where {threaded}
ys = samples(F, θ...; kwargs...)
y = if threaded
Expand Down
12 changes: 11 additions & 1 deletion src/reparametrization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ struct TransformedDistribution{D,T}
transformation::T
end

"""
rand(rng, dist::TransformedDistribution)
Sample from `dist` by applying `dist.transformation` to `dist.base_dist`.
"""
function Random.rand(rng::AbstractRNG, dist::TransformedDistribution)
(; base_dist, transformation) = dist
return transformation(rand(rng, base_dist))
end

"""
reparametrize(dist)
Expand Down Expand Up @@ -42,7 +52,7 @@ Differentiable parametric expectation `F : θ -> 𝔼[f(X)]` where `X ∼ p(θ)`
```jldoctest
using DifferentiableExpectations, Distributions, Zygote
F = Reparametrization(exp, Normal; nb_samples=10^3)
F = Reparametrization(exp, Normal; nb_samples=10^4)
F_true(μ, σ) = mean(LogNormal(μ, σ))
μ, σ = 0.5, 1,0
Expand Down
4 changes: 2 additions & 2 deletions test/expectation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ true_std(μ, σ) = std(LogNormal(μ, σ))
end
end;

@testset "Multivariate LogNormal" begin
@testset "Threaded: $threaded" for threaded in (false, true)
@testset verbose = true "Multivariate LogNormal" begin
@testset verbose = true "Threaded: $threaded" for threaded in (false, true)
@testset "$(nameof(typeof(F)))" for F in [
Reinforce(
vec_exp_with_kwargs,
Expand Down
13 changes: 13 additions & 0 deletions test/reparametrization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using DifferentiableExpectations: reparametrize
using Distributions
using StableRNGs
using Test

rng = StableRNG(63)

@testset "Univariate Normal" begin
dist = Normal(2.0, 1.0)
transformed_dist = reparametrize(dist)
@test mean([rand(rng, transformed_dist) for _ in 1:(10^4)]) mean(dist) rtol = 1e-1
@test std([rand(rng, transformed_dist) for _ in 1:(10^4)]) std(dist) rtol = 1e-1
end
9 changes: 6 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ using Zygote
Documenter.doctest(DifferentiableExpectations)
end
end
@testset verbose = true "Expectation" begin
include("expectation.jl")
end
@testset "Distribution" begin
include("distribution.jl")
end
@testset verbose = true "Reparametrization" begin
include("reparametrization.jl")
end
@testset verbose = true "Expectation" begin
include("expectation.jl")
end
end

0 comments on commit 1054068

Please sign in to comment.