From 7642642c6357b981dc1351ac395c6a1744d94e85 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 10 Jun 2024 11:17:31 +0200 Subject: [PATCH] Add fixed atoms distribution --- Project.toml | 3 + src/DifferentiableExpectations.jl | 9 ++- src/distribution.jl | 96 +++++++++++++++++++++++++++++++ test/distribution.jl | 29 ++++++++++ test/reinforce.jl | 9 ++- test/runtests.jl | 3 + 6 files changed, 142 insertions(+), 7 deletions(-) create mode 100644 src/distribution.jl create mode 100644 test/distribution.jl diff --git a/Project.toml b/Project.toml index e21fb70..f4c482a 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] ChainRulesCore = "1.23" @@ -20,7 +21,9 @@ Distributions = "0.25" DocStringExtensions = "0.9" LinearAlgebra = "1" OhMyThreads = "0.5" +Random = "1" Statistics = "1" +StatsBase = "0.34" julia = "1.10" [extras] diff --git a/src/DifferentiableExpectations.jl b/src/DifferentiableExpectations.jl index 77306ca..01559fe 100644 --- a/src/DifferentiableExpectations.jl +++ b/src/DifferentiableExpectations.jl @@ -14,7 +14,9 @@ using ChainRulesCore: NoTangent, ProjectTo, RuleConfig, + Tangent, @not_implemented, + rrule, rrule_via_ad, unthunk using DensityInterface: logdensityof @@ -22,16 +24,19 @@ using Distributions: Distribution, gradlogpdf using DocStringExtensions using LinearAlgebra: dot using OhMyThreads: tmap, treduce, tmapreduce -using Random: AbstractRNG, default_rng -using Statistics: mean +using Random: Random, AbstractRNG, default_rng +using Statistics: Statistics, mean +using StatsBase: StatsBase include("utils.jl") include("abstract.jl") include("reinforce.jl") include("reparametrization.jl") +include("distribution.jl") include("pushforward.jl") export DifferentiableExpectation export REINFORCE +export FixedAtomsProbabilityDistribution end # module DifferentiableExpectations diff --git a/src/distribution.jl b/src/distribution.jl new file mode 100644 index 0000000..b3f6955 --- /dev/null +++ b/src/distribution.jl @@ -0,0 +1,96 @@ +""" + FixedAtomsProbabilityDistribution + +A probability distribution with finite support and fixed atoms. + +Whenever its expectation is differentiated, only the weights are considered active, whereas the atoms are considered constant. + +# Fields + +$(TYPEDFIELDS) +""" +struct FixedAtomsProbabilityDistribution{threaded,A,W} + atoms::Vector{A} + weights::Vector{W} + + function FixedAtomsProbabilityDistribution( + atoms::Vector{A}, weights::Vector{W}; threaded::Bool=false + ) where {A,W} + if isempty(atoms) + throw(ArgumentError("`atoms` must be non-empty.")) + elseif length(atoms) != length(weights) + throw(DimensionMismatch("`atoms` and `weights` must have the same length.")) + elseif !isapprox(sum(weights), one(W); atol=1e-4) + throw(ArgumentError("`weights` must be normalized to `1`.")) + end + return new{threaded,A,W}(atoms, weights) + end +end + +Base.length(dist::FixedAtomsProbabilityDistribution) = length(dist.atoms) + +function Random.rand(rng::AbstractRNG, dist::FixedAtomsProbabilityDistribution) + (; atoms, weights) = dist + return StatsBase.sample(rng, atoms, StatsBase.Weights(weights)) +end + +function Base.map(f, dist::FixedAtomsProbabilityDistribution{threaded}) where {threaded} + (; atoms, weights) = dist + new_atoms = if threaded + tmap(f, atoms) + else + map(f, atoms) + end + return FixedAtomsProbabilityDistribution(new_atoms, weights) +end + +function Statistics.mean(dist::FixedAtomsProbabilityDistribution{threaded}) where {threaded} + (; atoms, weights) = dist + if threaded + return tmapreduce(*, +, weights, atoms) + else + return mapreduce(*, +, weights, atoms) + end +end + +function Statistics.mean(f, dist::FixedAtomsProbabilityDistribution) + return mean(map(f, dist)) +end + +function ChainRulesCore.rrule( + ::typeof(mean), f, dist::FixedAtomsProbabilityDistribution{threaded} +) where {threaded} + (; atoms, weights) = dist + new_atoms = if threaded + tmap(f, atoms) + else + map(f, atoms) + end + + function expectation_pullback(de) + d_atoms = NoTangent() + d_weights = if threaded + tmap(Base.Fix1(dot, de), new_atoms) + else + map(Base.Fix1(dot, de), new_atoms) + end + d_dist = Tangent{FixedAtomsProbabilityDistribution}(; + atoms=d_atoms, weights=d_weights + ) + return NoTangent(), NoTangent(), d_dist + end + + e = mean(FixedAtomsProbabilityDistribution(new_atoms, weights)) + return e, expectation_pullback +end + +function ChainRulesCore.rrule( + ::typeof(mean), dist::FixedAtomsProbabilityDistribution{threaded} +) where {threaded} + e, pb = rrule(mean, identity, dist) + function pb_nof(de) + p = pb(de) + return p[1], p[3] + end + return e, pb_nof +end diff --git a/test/distribution.jl b/test/distribution.jl new file mode 100644 index 0000000..b2fcd6c --- /dev/null +++ b/test/distribution.jl @@ -0,0 +1,29 @@ +using ChainRulesCore +using Distributions +using DifferentiableExpectations +using LinearAlgebra +using Random +using StableRNGs +using Statistics +using Test +using Zygote + +rng = StableRNG(63) + +dist = FixedAtomsProbabilityDistribution([2.0, 3.0], [0.3, 0.7]) + +@test length(dist) == 2 + +@test mean(dist) ≈ 2.7 +@test mean(abs2, dist) ≈ 7.5 +@test mean([rand(rng, dist) for _ in 1:(10^5)]) ≈ 2.7 rtol = 0.1 +@test mean(abs2, [rand(rng, dist) for _ in 1:(10^5)]) ≈ 7.5 rtol = 0.1 + +@test map(abs2, dist).weights == dist.weights +@test map(abs2, dist).atoms == [4, 9] + +@test only(gradient(mean, dist)).atoms === nothing +@test only(gradient(mean, dist)).weights == [2, 3] + +@test last(gradient(mean, abs2, dist)).atoms === nothing +@test last(gradient(mean, abs2, dist)).weights == [4, 9] diff --git a/test/reinforce.jl b/test/reinforce.jl index bfaea40..c1b1569 100644 --- a/test/reinforce.jl +++ b/test/reinforce.jl @@ -39,12 +39,11 @@ end vec_exp_with_kwargs, (μ, σ) -> MvNormal(μ, Diagonal(σ .^ 2)); rng=StableRNG(63), - nb_samples=10^4, + nb_samples=10^5, threaded=threaded, ) - dim = 2 - μ, σ = randn(dim), rand(dim) + μ, σ = [2.0, 3.0], [1.0, 0.5] true_mean(μ, σ) = mean.(LogNormal.(μ, σ)) true_std(μ, σ) = std.(LogNormal.(μ, σ)) @@ -54,7 +53,7 @@ end ∂mean_est = jacobian((μ, σ) -> F(μ, σ; correct=true), μ, σ) ∂mean_true = jacobian(true_mean, μ, σ) - @test ∂mean_est[1] ≈ ∂mean_true[1] rtol = 0.2 - @test ∂mean_est[2] ≈ ∂mean_true[2] rtol = 0.2 + @test ∂mean_est[1] ≈ ∂mean_true[1] rtol = 0.1 + @test ∂mean_est[2] ≈ ∂mean_true[2] rtol = 0.1 end end diff --git a/test/runtests.jl b/test/runtests.jl index 6534aab..33b3ecf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,9 @@ using Zygote @testset "Reparametrization" begin include("reparametrization.jl") end + @testset "Distribution" begin + include("distribution.jl") + end @testset "Pushforward" begin include("pushforward.jl") end