Skip to content

Commit

Permalink
Add fixed atoms distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jun 10, 2024
1 parent 6e2b0bf commit 7642642
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 7 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
ChainRulesCore = "1.23"
Expand All @@ -20,7 +21,9 @@ Distributions = "0.25"
DocStringExtensions = "0.9"
LinearAlgebra = "1"
OhMyThreads = "0.5"
Random = "1"
Statistics = "1"
StatsBase = "0.34"
julia = "1.10"

[extras]
Expand Down
9 changes: 7 additions & 2 deletions src/DifferentiableExpectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,29 @@ using ChainRulesCore:
NoTangent,
ProjectTo,
RuleConfig,
Tangent,
@not_implemented,
rrule,
rrule_via_ad,
unthunk
using DensityInterface: logdensityof
using Distributions: Distribution, gradlogpdf
using DocStringExtensions
using LinearAlgebra: dot
using OhMyThreads: tmap, treduce, tmapreduce
using Random: AbstractRNG, default_rng
using Statistics: mean
using Random: Random, AbstractRNG, default_rng
using Statistics: Statistics, mean
using StatsBase: StatsBase

include("utils.jl")
include("abstract.jl")
include("reinforce.jl")
include("reparametrization.jl")
include("distribution.jl")
include("pushforward.jl")

export DifferentiableExpectation
export REINFORCE
export FixedAtomsProbabilityDistribution

end # module DifferentiableExpectations
96 changes: 96 additions & 0 deletions src/distribution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
FixedAtomsProbabilityDistribution
A probability distribution with finite support and fixed atoms.
Whenever its expectation is differentiated, only the weights are considered active, whereas the atoms are considered constant.
# Fields
$(TYPEDFIELDS)
"""
struct FixedAtomsProbabilityDistribution{threaded,A,W}
atoms::Vector{A}
weights::Vector{W}

function FixedAtomsProbabilityDistribution(
atoms::Vector{A}, weights::Vector{W}; threaded::Bool=false
) where {A,W}
if isempty(atoms)
throw(ArgumentError("`atoms` must be non-empty."))
elseif length(atoms) != length(weights)
throw(DimensionMismatch("`atoms` and `weights` must have the same length."))
elseif !isapprox(sum(weights), one(W); atol=1e-4)
throw(ArgumentError("`weights` must be normalized to `1`."))
end
return new{threaded,A,W}(atoms, weights)
end
end

Base.length(dist::FixedAtomsProbabilityDistribution) = length(dist.atoms)

function Random.rand(rng::AbstractRNG, dist::FixedAtomsProbabilityDistribution)
(; atoms, weights) = dist
return StatsBase.sample(rng, atoms, StatsBase.Weights(weights))
end

function Base.map(f, dist::FixedAtomsProbabilityDistribution{threaded}) where {threaded}
(; atoms, weights) = dist
new_atoms = if threaded
tmap(f, atoms)
else
map(f, atoms)
end
return FixedAtomsProbabilityDistribution(new_atoms, weights)
end

function Statistics.mean(dist::FixedAtomsProbabilityDistribution{threaded}) where {threaded}
(; atoms, weights) = dist
if threaded
return tmapreduce(*, +, weights, atoms)
else
return mapreduce(*, +, weights, atoms)
end
end

function Statistics.mean(f, dist::FixedAtomsProbabilityDistribution)
return mean(map(f, dist))
end

function ChainRulesCore.rrule(
::typeof(mean), f, dist::FixedAtomsProbabilityDistribution{threaded}
) where {threaded}
(; atoms, weights) = dist
new_atoms = if threaded
tmap(f, atoms)
else
map(f, atoms)
end

function expectation_pullback(de)
d_atoms = NoTangent()
d_weights = if threaded
tmap(Base.Fix1(dot, de), new_atoms)
else
map(Base.Fix1(dot, de), new_atoms)
end
d_dist = Tangent{FixedAtomsProbabilityDistribution}(;
atoms=d_atoms, weights=d_weights
)
return NoTangent(), NoTangent(), d_dist
end

e = mean(FixedAtomsProbabilityDistribution(new_atoms, weights))
return e, expectation_pullback
end

function ChainRulesCore.rrule(
::typeof(mean), dist::FixedAtomsProbabilityDistribution{threaded}
) where {threaded}
e, pb = rrule(mean, identity, dist)
function pb_nof(de)
p = pb(de)
return p[1], p[3]
end
return e, pb_nof
end
29 changes: 29 additions & 0 deletions test/distribution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using ChainRulesCore
using Distributions
using DifferentiableExpectations
using LinearAlgebra
using Random
using StableRNGs
using Statistics
using Test
using Zygote

rng = StableRNG(63)

dist = FixedAtomsProbabilityDistribution([2.0, 3.0], [0.3, 0.7])

@test length(dist) == 2

@test mean(dist) 2.7
@test mean(abs2, dist) 7.5
@test mean([rand(rng, dist) for _ in 1:(10^5)]) 2.7 rtol = 0.1
@test mean(abs2, [rand(rng, dist) for _ in 1:(10^5)]) 7.5 rtol = 0.1

@test map(abs2, dist).weights == dist.weights
@test map(abs2, dist).atoms == [4, 9]

@test only(gradient(mean, dist)).atoms === nothing
@test only(gradient(mean, dist)).weights == [2, 3]

@test last(gradient(mean, abs2, dist)).atoms === nothing
@test last(gradient(mean, abs2, dist)).weights == [4, 9]
9 changes: 4 additions & 5 deletions test/reinforce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,11 @@ end
vec_exp_with_kwargs,
(μ, σ) -> MvNormal(μ, Diagonal.^ 2));
rng=StableRNG(63),
nb_samples=10^4,
nb_samples=10^5,
threaded=threaded,
)

dim = 2
μ, σ = randn(dim), rand(dim)
μ, σ = [2.0, 3.0], [1.0, 0.5]
true_mean(μ, σ) = mean.(LogNormal.(μ, σ))
true_std(μ, σ) = std.(LogNormal.(μ, σ))

Expand All @@ -54,7 +53,7 @@ end
∂mean_est = jacobian((μ, σ) -> F(μ, σ; correct=true), μ, σ)
∂mean_true = jacobian(true_mean, μ, σ)

@test ∂mean_est[1] ∂mean_true[1] rtol = 0.2
@test ∂mean_est[2] ∂mean_true[2] rtol = 0.2
@test ∂mean_est[1] ∂mean_true[1] rtol = 0.1
@test ∂mean_est[2] ∂mean_true[2] rtol = 0.1
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ using Zygote
@testset "Reparametrization" begin
include("reparametrization.jl")
end
@testset "Distribution" begin
include("distribution.jl")
end
@testset "Pushforward" begin
include("pushforward.jl")
end
Expand Down

0 comments on commit 7642642

Please sign in to comment.