Skip to content

Commit

Permalink
Add MvLogitNormal (#1774)
Browse files Browse the repository at this point in the history
* Create MvLogitNormal

* Add MvLogitNormal to docs

* Simplify constructors

* Fix conversions

* Rearrange code

* Fix computation of  -Inf

* Add meanform and canonform

* Add back type constructor

* Add MvLogitNormal tests

* Update and test show method

* Fix testset name

* Fix for older Julia versions

* Restrict testing of `show` method to newer versions

* Add kldivergence tests

* Improve documentation

* Remove constructor with type and AbstractMvNormal params

* Update show method

* Update docstring

* Remove reference to Dirichlet

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
sethaxen and devmotion authored Sep 28, 2023
1 parent b21e515 commit cd5d4cc
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/src/multivariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Multinomial
Distributions.AbstractMvNormal
MvNormal
MvNormalCanon
MvLogitNormal
MvLogNormal
Dirichlet
Product
Expand Down
1 change: 1 addition & 0 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ export
Logistic,
LogNormal,
LogUniform,
MvLogitNormal,
LogitNormal,
MatrixBeta,
MatrixFDist,
Expand Down
140 changes: 140 additions & 0 deletions src/multivariate/mvlogitnormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
MvLogitNormal{<:AbstractMvNormal}
The [multivariate logit-normal distribution](https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization)
is a multivariate generalization of [`LogitNormal`](@ref) capable of handling correlations
between variables.
If ``\\mathbf{y} \\sim \\mathrm{MvNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma})`` is a
length ``d-1`` vector, then
```math
\\mathbf{x} = \\operatorname{softmax}\\left(\\begin{bmatrix}\\mathbf{y} \\\\ 0 \\end{bmatrix}\\right) \\sim \\mathrm{MvLogitNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma})
```
is a length ``d`` probability vector.
```julia
MvLogitNormal(μ, Σ) # MvLogitNormal with y ~ MvNormal(μ, Σ)
MvLogitNormal(MvNormal(μ, Σ)) # same as above
MvLogitNormal(MvNormalCanon(μ, J)) # MvLogitNormal with y ~ MvNormalCanon(μ, J)
```
# Fields
- `normal::AbstractMvNormal`: contains the ``d-1``-dimensional distribution of ``y``
"""
struct MvLogitNormal{D<:AbstractMvNormal} <: ContinuousMultivariateDistribution
normal::D
MvLogitNormal{D}(normal::D) where {D<:AbstractMvNormal} = new{D}(normal)
end
MvLogitNormal(d::AbstractMvNormal) = MvLogitNormal{typeof(d)}(d)
MvLogitNormal(args...) = MvLogitNormal(MvNormal(args...))

function Base.show(io::IO, d::MvLogitNormal; indent::String=" ")
print(io, distrname(d))
println(io, "(")
normstr = strip(sprint(show, d.normal; context=IOContext(io)))
normstr = replace(normstr, "\n" => "\n$indent")
print(io, indent)
println(io, normstr)
println(io, ")")
end

# Conversions

function convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal) where {D}
return MvLogitNormal(convert(D, d.normal))
end
Base.convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal{D}) where {D} = d

meanform(d::MvLogitNormal{<:MvNormalCanon}) = MvLogitNormal(meanform(d.normal))
canonform(d::MvLogitNormal{<:MvNormal}) = MvLogitNormal(canonform(d.normal))

# Properties

length(d::MvLogitNormal) = length(d.normal) + 1
Base.eltype(::Type{<:MvLogitNormal{D}}) where {D} = eltype(D)
Base.eltype(d::MvLogitNormal) = eltype(d.normal)
params(d::MvLogitNormal) = params(d.normal)
@inline partype(d::MvLogitNormal) = partype(d.normal)

location(d::MvLogitNormal) = mean(d.normal)
minimum(d::MvLogitNormal) = fill(zero(eltype(d)), length(d))
maximum(d::MvLogitNormal) = fill(oneunit(eltype(d)), length(d))

function insupport(d::MvLogitNormal, x::AbstractVector{<:Real})
return length(d) == length(x) && all((0), x) && sum(x) 1
end

# Evaluation

function _logpdf(d::MvLogitNormal, x::AbstractVector{<:Real})
if !insupport(d, x)
return oftype(logpdf(d.normal, _inv_softmax1(abs.(x))), -Inf)
else
return logpdf(d.normal, _inv_softmax1(x)) - sum(log, x)
end
end

function gradlogpdf(d::MvLogitNormal, x::AbstractVector{<:Real})
y = _inv_softmax1(x)
∂y = gradlogpdf(d.normal, y)
∂x = (vcat(∂y, -sum(∂y)) .- 1) ./ x
return ∂x
end

# Statistics

kldivergence(p::MvLogitNormal, q::MvLogitNormal) = kldivergence(p.normal, q.normal)

# Sampling

function _rand!(rng::AbstractRNG, d::MvLogitNormal, x::AbstractVecOrMat{<:Real})
y = @views _drop1(x)
rand!(rng, d.normal, y)
_softmax1!(x, y)
return x
end

# Fitting

function fit_mle(::Type{MvLogitNormal{D}}, x::AbstractMatrix{<:Real}; kwargs...) where {D}
y = similar(x, size(x, 1) - 1, size(x, 2))
map(_inv_softmax1!, eachcol(y), eachcol(x))
normal = fit_mle(D, y; kwargs...)
return MvLogitNormal(normal)
end
function fit_mle(::Type{MvLogitNormal}, x::AbstractMatrix{<:Real}; kwargs...)
return fit_mle(MvLogitNormal{MvNormal}, x; kwargs...)
end

# Utility

function _softmax1!(x::AbstractVector, y::AbstractVector)
u = max(0, maximum(y))
_drop1(x) .= exp.(y .- u)
x[end] = exp(-u)
LinearAlgebra.normalize!(x, 1)
return x
end
function _softmax1!(x::AbstractMatrix, y::AbstractMatrix)
map(_softmax1!, eachcol(x), eachcol(y))
return x
end

_drop1(x::AbstractVector) = @views x[firstindex(x, 1):(end - 1)]
_drop1(x::AbstractMatrix) = @views x[firstindex(x, 1):(end - 1), :]

_last1(x::AbstractVector) = x[end]
_last1(x::AbstractMatrix) = @views x[end, :]

function _inv_softmax1!(y::AbstractVecOrMat, x::AbstractVecOrMat)
x₋ = _drop1(x)
xd = _last1(x)
@. y = log(x₋) - log(xd)
return y
end
function _inv_softmax1(x::AbstractVecOrMat)
y = similar(_drop1(x))
_inv_softmax1!(y, x)
return y
end
1 change: 1 addition & 0 deletions src/multivariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ for fname in ["dirichlet.jl",
"jointorderstatistics.jl",
"mvnormal.jl",
"mvnormalcanon.jl",
"mvlogitnormal.jl",
"mvlognormal.jl",
"mvtdist.jl",
"product.jl", # deprecated
Expand Down
158 changes: 158 additions & 0 deletions test/multivariate/mvlogitnormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Tests on Multivariate Logit-Normal distributions
using Distributions
using ForwardDiff
using LinearAlgebra
using Random
using Test

####### Core testing procedure

function test_mvlogitnormal(d::MvLogitNormal; nsamples::Int=10^6)
@test d.normal isa AbstractMvNormal
dnorm = d.normal

@testset "properties" begin
@test length(d) == length(dnorm) + 1
@test params(d) == params(dnorm)
@test partype(d) == partype(dnorm)
@test eltype(d) == eltype(dnorm)
@test eltype(typeof(d)) == eltype(typeof(dnorm))
@test location(d) == mean(dnorm)
@test minimum(d) == fill(0, length(d))
@test maximum(d) == fill(1, length(d))
@test insupport(d, normalize(rand(length(d)), 1))
@test !insupport(d, normalize(rand(length(d) + 1), 1))
@test !insupport(d, rand(length(d)))
x = rand(length(d) - 1)
x = vcat(x, -sum(x))
@test !insupport(d, x)
end

@testset "conversions" begin
@test convert(typeof(d), d) === d
T = partype(d) <: Float64 ? Float32 : Float64
if dnorm isa MvNormal
@test convert(MvLogitNormal{MvNormal{T}}, d).normal ==
convert(MvNormal{T}, dnorm)
@test partype(convert(MvLogitNormal{MvNormal{T}}, d)) <: T
@test canonform(d) isa MvLogitNormal{<:MvNormalCanon}
@test canonform(d).normal == canonform(dnorm)
elseif dnorm isa MvNormalCanon
@test convert(MvLogitNormal{MvNormalCanon{T}}, d).normal ==
convert(MvNormalCanon{T}, dnorm)
@test partype(convert(MvLogitNormal{MvNormalCanon{T}}, d)) <: T
@test meanform(d) isa MvLogitNormal{<:MvNormal}
@test meanform(d).normal == meanform(dnorm)
end
end

@testset "sampling" begin
X = rand(d, nsamples)
Y = @views log.(X[1:(end - 1), :]) .- log.(X[end, :]')
Ymean = vec(mean(Y; dims=2))
Ycov = cov(Y; dims=2)
for i in 1:(length(d) - 1)
@test isapprox(
Ymean[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8
)
end
for i in 1:(length(d) - 1), j in 1:(length(d) - 1)
@test isapprox(
Ycov[i, j],
cov(dnorm)[i, j],
atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20,
)
end
end

@testset "fitting" begin
X = rand(d, nsamples)
dfit = fit_mle(MvLogitNormal, X)
dfit_norm = dfit.normal
for i in 1:(length(d) - 1)
@test isapprox(
mean(dfit_norm)[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8
)
end
for i in 1:(length(d) - 1), j in 1:(length(d) - 1)
@test isapprox(
cov(dfit_norm)[i, j],
cov(dnorm)[i, j],
atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20,
)
end
@test fit_mle(MvLogitNormal{IsoNormal}, X) isa MvLogitNormal{<:IsoNormal}
end

@testset "evaluation" begin
X = rand(d, nsamples)
for i in 1:min(100, nsamples)
@test @inferred(logpdf(d, X[:, i])) log(pdf(d, X[:, i]))
if dnorm isa MvNormal
@test @inferred(gradlogpdf(d, X[:, i]))
ForwardDiff.gradient(x -> logpdf(d, x), X[:, i])
end
end
@test logpdf(d, X) log.(pdf(d, X))
@test isequal(logpdf(d, zeros(length(d))), -Inf)
@test isequal(logpdf(d, ones(length(d))), -Inf)
@test isequal(pdf(d, zeros(length(d))), 0)
@test isequal(pdf(d, ones(length(d))), 0)
end
end

@testset "Results MvLogitNormal consistent with univariate LogitNormal" begin
μ = randn()
σ = rand()
d = MvLogitNormal([μ], fill^2, 1, 1))
duni = LogitNormal(μ, σ)
@test location(d) [location(duni)]
x = normalize(rand(2), 1)
@test logpdf(d, x) logpdf(duni, x[1])
@test pdf(d, x) pdf(duni, x[1])
@test (Random.seed!(9274); rand(d)[1]) (Random.seed!(9274); rand(duni))
end

###### General Testing

@testset "MvLogitNormal tests" begin
mvnorm_params = [
(randn(5), I * rand()),
(randn(4), Diagonal(rand(4))),
(Diagonal(rand(6)),),
(randn(5), exp(Symmetric(randn(5, 5)))),
(exp(Symmetric(randn(5, 5))),),
]
@testset "wraps MvNormal" begin
@testset "$(typeof(prms))" for prms in mvnorm_params
d = MvLogitNormal(prms...)
@test d == MvLogitNormal(MvNormal(prms...))
test_mvlogitnormal(d; nsamples=10^4)
end
end
@testset "wraps MvNormalCanon" begin
@testset "$(typeof(prms))" for prms in mvnorm_params
d = MvLogitNormal(MvNormalCanon(prms...))
test_mvlogitnormal(d; nsamples=10^4)
end
end

@testset "kldivergence" begin
d1 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5))))
d2 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5))))
@test kldivergence(d1, d2) kldivergence(d1.normal, d2.normal)
end

VERSION v"1.8" && @testset "show" begin
d = MvLogitNormal([1.0, 2.0, 3.0], Diagonal([4.0, 5.0, 6.0]))
@test sprint(show, d) === """
MvLogitNormal{DiagNormal}(
DiagNormal(
dim: 3
μ: [1.0, 2.0, 3.0]
Σ: [4.0 0.0 0.0; 0.0 5.0 0.0; 0.0 0.0 6.0]
)
)
"""
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const tests = [
"univariate/continuous/uniform",
"univariate/continuous/lognormal",
"multivariate/mvnormal",
"multivariate/mvlogitnormal",
"multivariate/mvlognormal",
"types", # extra file compared to /src
"utils",
Expand Down

0 comments on commit cd5d4cc

Please sign in to comment.