-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
142 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
FixedAtomsProbabilityDistribution | ||
A probability distribution with finite support and fixed atoms. | ||
Whenever its expectation is differentiated, only the weights are considered active, whereas the atoms are considered constant. | ||
# Fields | ||
$(TYPEDFIELDS) | ||
""" | ||
struct FixedAtomsProbabilityDistribution{threaded,A,W} | ||
atoms::Vector{A} | ||
weights::Vector{W} | ||
|
||
function FixedAtomsProbabilityDistribution( | ||
atoms::Vector{A}, weights::Vector{W}; threaded::Bool=false | ||
) where {A,W} | ||
if isempty(atoms) | ||
throw(ArgumentError("`atoms` must be non-empty.")) | ||
elseif length(atoms) != length(weights) | ||
throw(DimensionMismatch("`atoms` and `weights` must have the same length.")) | ||
elseif !isapprox(sum(weights), one(W); atol=1e-4) | ||
throw(ArgumentError("`weights` must be normalized to `1`.")) | ||
end | ||
return new{threaded,A,W}(atoms, weights) | ||
end | ||
end | ||
|
||
Base.length(dist::FixedAtomsProbabilityDistribution) = length(dist.atoms) | ||
|
||
function Random.rand(rng::AbstractRNG, dist::FixedAtomsProbabilityDistribution) | ||
(; atoms, weights) = dist | ||
return StatsBase.sample(rng, atoms, StatsBase.Weights(weights)) | ||
end | ||
|
||
function Base.map(f, dist::FixedAtomsProbabilityDistribution{threaded}) where {threaded} | ||
(; atoms, weights) = dist | ||
new_atoms = if threaded | ||
tmap(f, atoms) | ||
else | ||
map(f, atoms) | ||
end | ||
return FixedAtomsProbabilityDistribution(new_atoms, weights) | ||
end | ||
|
||
function Statistics.mean(dist::FixedAtomsProbabilityDistribution{threaded}) where {threaded} | ||
(; atoms, weights) = dist | ||
if threaded | ||
return tmapreduce(*, +, weights, atoms) | ||
else | ||
return mapreduce(*, +, weights, atoms) | ||
end | ||
end | ||
|
||
function Statistics.mean(f, dist::FixedAtomsProbabilityDistribution) | ||
return mean(map(f, dist)) | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(mean), f, dist::FixedAtomsProbabilityDistribution{threaded} | ||
) where {threaded} | ||
(; atoms, weights) = dist | ||
new_atoms = if threaded | ||
tmap(f, atoms) | ||
else | ||
map(f, atoms) | ||
end | ||
|
||
function expectation_pullback(de) | ||
d_atoms = NoTangent() | ||
d_weights = if threaded | ||
tmap(Base.Fix1(dot, de), new_atoms) | ||
else | ||
map(Base.Fix1(dot, de), new_atoms) | ||
end | ||
d_dist = Tangent{FixedAtomsProbabilityDistribution}(; | ||
atoms=d_atoms, weights=d_weights | ||
) | ||
return NoTangent(), NoTangent(), d_dist | ||
end | ||
|
||
e = mean(FixedAtomsProbabilityDistribution(new_atoms, weights)) | ||
return e, expectation_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(mean), dist::FixedAtomsProbabilityDistribution{threaded} | ||
) where {threaded} | ||
e, pb = rrule(mean, identity, dist) | ||
function pb_nof(de) | ||
p = pb(de) | ||
return p[1], p[3] | ||
end | ||
return e, pb_nof | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
using ChainRulesCore | ||
using Distributions | ||
using DifferentiableExpectations | ||
using LinearAlgebra | ||
using Random | ||
using StableRNGs | ||
using Statistics | ||
using Test | ||
using Zygote | ||
|
||
rng = StableRNG(63) | ||
|
||
dist = FixedAtomsProbabilityDistribution([2.0, 3.0], [0.3, 0.7]) | ||
|
||
@test length(dist) == 2 | ||
|
||
@test mean(dist) ≈ 2.7 | ||
@test mean(abs2, dist) ≈ 7.5 | ||
@test mean([rand(rng, dist) for _ in 1:(10^5)]) ≈ 2.7 rtol = 0.1 | ||
@test mean(abs2, [rand(rng, dist) for _ in 1:(10^5)]) ≈ 7.5 rtol = 0.1 | ||
|
||
@test map(abs2, dist).weights == dist.weights | ||
@test map(abs2, dist).atoms == [4, 9] | ||
|
||
@test only(gradient(mean, dist)).atoms === nothing | ||
@test only(gradient(mean, dist)).weights == [2, 3] | ||
|
||
@test last(gradient(mean, abs2, dist)).atoms === nothing | ||
@test last(gradient(mean, abs2, dist)).weights == [4, 9] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters