Skip to content

Commit

Permalink
Add analytical formulas for KL of `Binomial`, `NegativeBinomial`, and `Geometric` (#1541)
Browse files Browse the repository at this point in the history

* Put numerical fallback into separate function

* Add method for `Binomial`

* Add method for `Geometric`

* Add method for `NegativeBinomial`

* Bump version

* Fix type instability

Co-authored-by: David Widmann <[email protected]>

* Use `invoke` instead of explicit fallback

Co-authored-by: David Widmann <[email protected]>

* Remove unnecessary fallback

Co-authored-by: David Widmann <[email protected]>

* Check for equality

Co-authored-by: David Widmann <[email protected]>

* Update src/univariate/discrete/negativebinomial.jl

Co-authored-by: David Widmann <[email protected]>

* Add tests for edge cases of `Geometric`

* Use numerically stabler computation

* Optimize KL for `Binomial`s with different n

* Add more tests

* Fix edge case

* Match return type of `res`

* Partially implement suggestion from review

* Fix typo

* Bump version

Co-authored-by: David Widmann <[email protected]>

* Update test/ref/readme.md

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
simsurace and devmotion authored May 6, 2022
1 parent a161b17 commit dd6ae8f
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Distributions"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
authors = ["JuliaStats"]
version = "0.25.57"
version = "0.25.58"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
18 changes: 18 additions & 0 deletions src/univariate/discrete/binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,24 @@ function entropy(d::Binomial; approx::Bool=false)
end
end

function kldivergence(p::Binomial, q::Binomial; kwargs...)
    # Closed-form KL divergence between two binomial distributions.
    # `kwargs` are accepted for interface compatibility with the generic
    # numerical fallback but are not used by the analytical branches.
    m = ntrials(p)
    n = ntrials(q)
    pp = succprob(p)
    pq = succprob(q)
    # Per-trial Bernoulli contribution, computed up front so every branch
    # below returns a value of the same (promoted) float type.
    kl = m * kldivergence(Bernoulli{typeof(pp)}(pp), Bernoulli{typeof(pq)}(pq))
    if m > n
        # The support of p is not contained in the support of q.
        return oftype(kl, Inf)
    elseif m == n
        # Guard: with zero trials both distributions are the point mass at 0,
        # but `kl` may be NaN (0 * Inf), so return an explicit zero.
        return iszero(m) ? zero(kl) : kl
    else
        # m < n: pull some terms out of the expectation to make this more
        # efficient.
        kl += logfactorial(m) - logfactorial(n) - (n - m) * log1p(-pq)
        kl += expectation(k -> logfactorial(n - k) - logfactorial(m - k), p)
        return kl
    end
end

#### Evaluation & Sampling

Expand Down
12 changes: 12 additions & 0 deletions src/univariate/discrete/geometric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ kurtosis(d::Geometric) = 6 + abs2(d.p) / (1 - d.p)

entropy(d::Geometric) = (-xlogx(succprob(d)) - xlogx(failprob(d))) / d.p

function kldivergence(p::Geometric, q::Geometric)
    # Analytical KL divergence between two geometric distributions with
    # success probabilities `pp` and `pq`.
    pp = succprob(p)
    pq = succprob(q)
    # Identical success probabilities: divergence is exactly zero; the
    # division fixes the promoted float type of the result.
    pp == pq && return zero(float(pp / pq))
    # Degenerate p (point mass at 0): only the k = 0 term contributes.
    isone(pp) && return -log(pq / pp)
    # General case: E_p[log p(X) - log q(X)] in closed form, using log1p
    # for numerical stability near success probability one.
    return log(pp) - log(pq) + (inv(pp) - one(pp)) * (log1p(-pp) - log1p(-pq))
end


### Evaluations

Expand Down
10 changes: 10 additions & 0 deletions src/univariate/discrete/negativebinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ kurtosis(d::NegativeBinomial{T}) where {T} = (p = succprob(d); T(6) / d.r + (p *

mode(d::NegativeBinomial{T}) where {T} = (p = succprob(d); floor(Int,(one(T) - p) * (d.r - one(T)) / p))

function kldivergence(p::NegativeBinomial, q::NegativeBinomial; kwargs...)
    if p.r != q.r
        # There does not appear to be an analytical formula when the `r`
        # parameters differ, so defer to the generic numerical approximation.
        return invoke(
            kldivergence,
            Tuple{UnivariateDistribution{Discrete},UnivariateDistribution{Discrete}},
            p, q; kwargs...,
        )
    end
    # Equal r: the KL divergence is r times that of the associated
    # geometric distributions.
    return p.r * kldivergence(Geometric(succprob(p)), Geometric(succprob(q)))
end


#### Evaluation & Sampling

Expand Down
35 changes: 35 additions & 0 deletions test/functionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ end
q = Beta(3, 5)
test_kl(p, q)
end
@testset "Binomial" begin
    p = Binomial(3, 0.3)
    q = Binomial(3, 0.5)
    test_kl(p, q)
    # Zero-trial distributions are both point masses at 0 ⇒ zero divergence.
    @test iszero(kldivergence(Binomial(0, 0), Binomial(0, 1)))
    @test iszero(kldivergence(Binomial(0, 0.5), Binomial(0, 0.3)))
    # Support of p not contained in support of q ⇒ infinite divergence.
    @test isinf(kldivergence(Binomial(4, 0.3), Binomial(2, 0.3)))
    @test isinf(kldivergence(Binomial(3, 0), Binomial(3, 1)))
    @test isinf(kldivergence(Binomial(3, 0), Binomial(5, 1)))
    # Equal trial counts reduce to n × the Bernoulli divergence.
    @test kldivergence(p, q) ≈ 3 * kldivergence(Bernoulli(0.3), Bernoulli(0.5))
end
@testset "Categorical" begin
@test kldivergence(Categorical([0.0, 0.1, 0.9]), Categorical([0.1, 0.1, 0.8])) 0
@test kldivergence(Categorical([0.0, 0.1, 0.9]), Categorical([0.1, 0.1, 0.8]))
Expand Down Expand Up @@ -84,6 +95,24 @@ end
q = Gamma(3.0, 2.0)
test_kl(p, q)
end
@testset "Geometric" begin
    p = Geometric(0.3)
    q = Geometric(0.4)
    test_kl(p, q)

    # Edge cases at the extremes of the success-probability range.
    x1 = nextfloat(0.0)
    x2 = prevfloat(1.0)
    p1 = Geometric(x1)
    p2 = Geometric(x2)
    @test iszero(kldivergence(p2, p2))
    @test iszero(kldivergence(p1, p1))
    @test isinf(kldivergence(p1, p2))
    @test kldivergence(p2, p1) ≈ -log(x1)
    @test isinf(kldivergence(p1, Geometric(0.5)))
    @test kldivergence(p2, Geometric(0.5)) ≈ -log(0.5)
    @test kldivergence(Geometric(0.5), p2) ≈ 2*log(0.5) - log(1-x2)
    @test kldivergence(Geometric(0.5), p1) ≈ 2*log(0.5) - log(x1)
end
@testset "InverseGamma" begin
p = InverseGamma(2.0, 1.0)
q = InverseGamma(3.0, 2.0)
Expand All @@ -106,6 +135,12 @@ end
test_kl(p, q)
@test kldivergence(p, q) kldivergence(Normal(0, 1), Normal(0.5, 0.5))
end
@testset "NegativeBinomial" begin
    p = NegativeBinomial(3, 0.3)
    q = NegativeBinomial(3, 0.5)
    test_kl(p, q)
    # Equal r parameters reduce to r × the geometric divergence.
    @test kldivergence(p, q) ≈ 3 * kldivergence(Geometric(0.3), Geometric(0.5))
end
@testset "Normal" begin
p = Normal(0, 1)
q = Normal(0.5, 0.5)
Expand Down
2 changes: 1 addition & 1 deletion test/ref/readme.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# R References for Distributions.jl

We relies on the distribution-related functions provided by
We rely on the distribution-related functions provided by
[R](https://www.r-project.org) and a number of packages in
[CRAN](https://cran.r-project.org) to generate references
for verifying the correctness of our implementations.
Expand Down

2 comments on commit dd6ae8f

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/59783

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.25.58 -m "<description of version>" dd6ae8f4eac304f404b0069540a6c3bb1c667f92
git push origin v0.25.58

Please sign in to comment.