diff --git a/docs/src/multivariate.md b/docs/src/multivariate.md index 2ae5ef59b..c4e7c1764 100644 --- a/docs/src/multivariate.md +++ b/docs/src/multivariate.md @@ -55,6 +55,7 @@ Multinomial Distributions.AbstractMvNormal MvNormal MvNormalCanon +MvLogitNormal MvLogNormal Dirichlet Product diff --git a/src/Distributions.jl b/src/Distributions.jl index 1e2d580c5..7e4f6bd62 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -122,6 +122,7 @@ export Logistic, LogNormal, LogUniform, + MvLogitNormal, LogitNormal, MatrixBeta, MatrixFDist, diff --git a/src/multivariate/mvlogitnormal.jl b/src/multivariate/mvlogitnormal.jl new file mode 100644 index 000000000..0d60ddf65 --- /dev/null +++ b/src/multivariate/mvlogitnormal.jl @@ -0,0 +1,140 @@ +""" + MvLogitNormal{<:AbstractMvNormal} + +The [multivariate logit-normal distribution](https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization) +is a multivariate generalization of [`LogitNormal`](@ref) capable of handling correlations +between variables. + +If ``\\mathbf{y} \\sim \\mathrm{MvNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma})`` is a +length ``d-1`` vector, then +```math +\\mathbf{x} = \\operatorname{softmax}\\left(\\begin{bmatrix}\\mathbf{y} \\\\ 0 \\end{bmatrix}\\right) \\sim \\mathrm{MvLogitNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma}) +``` +is a length ``d`` probability vector. + +```julia +MvLogitNormal(μ, Σ) # MvLogitNormal with y ~ MvNormal(μ, Σ) +MvLogitNormal(MvNormal(μ, Σ)) # same as above +MvLogitNormal(MvNormalCanon(μ, J)) # MvLogitNormal with y ~ MvNormalCanon(μ, J) +``` + +# Fields + +- `normal::AbstractMvNormal`: contains the ``d-1``-dimensional distribution of ``y`` +""" +struct MvLogitNormal{D<:AbstractMvNormal} <: ContinuousMultivariateDistribution + normal::D + MvLogitNormal{D}(normal::D) where {D<:AbstractMvNormal} = new{D}(normal) +end +MvLogitNormal(d::AbstractMvNormal) = MvLogitNormal{typeof(d)}(d) +MvLogitNormal(args...) = MvLogitNormal(MvNormal(args...)) + +function Base.show(io::IO, d::MvLogitNormal; indent::String=" ") + print(io, distrname(d)) + println(io, "(") + normstr = strip(sprint(show, d.normal; context=IOContext(io))) + normstr = replace(normstr, "\n" => "\n$indent") + print(io, indent) + println(io, normstr) + println(io, ")") +end + +# Conversions + +function convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal) where {D} + return MvLogitNormal(convert(D, d.normal)) +end +Base.convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal{D}) where {D} = d + +meanform(d::MvLogitNormal{<:MvNormalCanon}) = MvLogitNormal(meanform(d.normal)) +canonform(d::MvLogitNormal{<:MvNormal}) = MvLogitNormal(canonform(d.normal)) + +# Properties + +length(d::MvLogitNormal) = length(d.normal) + 1 +Base.eltype(::Type{<:MvLogitNormal{D}}) where {D} = eltype(D) +Base.eltype(d::MvLogitNormal) = eltype(d.normal) +params(d::MvLogitNormal) = params(d.normal) +@inline partype(d::MvLogitNormal) = partype(d.normal) + +location(d::MvLogitNormal) = mean(d.normal) +minimum(d::MvLogitNormal) = fill(zero(eltype(d)), length(d)) +maximum(d::MvLogitNormal) = fill(oneunit(eltype(d)), length(d)) + +function insupport(d::MvLogitNormal, x::AbstractVector{<:Real}) + return length(d) == length(x) && all(≥(0), x) && sum(x) ≈ 1 +end + +# Evaluation + +function _logpdf(d::MvLogitNormal, x::AbstractVector{<:Real}) + if !insupport(d, x) + return oftype(logpdf(d.normal, _inv_softmax1(abs.(x))), -Inf) + else + return logpdf(d.normal, _inv_softmax1(x)) - sum(log, x) + end +end + +function gradlogpdf(d::MvLogitNormal, x::AbstractVector{<:Real}) + y = _inv_softmax1(x) + ∂y = gradlogpdf(d.normal, y) + ∂x = (vcat(∂y, -sum(∂y)) .- 1) ./ x + return ∂x +end + +# Statistics + +kldivergence(p::MvLogitNormal, q::MvLogitNormal) = kldivergence(p.normal, q.normal) + +# Sampling + +function _rand!(rng::AbstractRNG, d::MvLogitNormal, x::AbstractVecOrMat{<:Real}) + y = @views _drop1(x) + rand!(rng, d.normal, y) + _softmax1!(x, y) + return x +end + +# Fitting + +function fit_mle(::Type{MvLogitNormal{D}}, x::AbstractMatrix{<:Real}; kwargs...) where {D} + y = similar(x, size(x, 1) - 1, size(x, 2)) + map(_inv_softmax1!, eachcol(y), eachcol(x)) + normal = fit_mle(D, y; kwargs...) + return MvLogitNormal(normal) +end +function fit_mle(::Type{MvLogitNormal}, x::AbstractMatrix{<:Real}; kwargs...) + return fit_mle(MvLogitNormal{MvNormal}, x; kwargs...) +end + +# Utility + +function _softmax1!(x::AbstractVector, y::AbstractVector) + u = max(0, maximum(y)) + _drop1(x) .= exp.(y .- u) + x[end] = exp(-u) + LinearAlgebra.normalize!(x, 1) + return x +end +function _softmax1!(x::AbstractMatrix, y::AbstractMatrix) + map(_softmax1!, eachcol(x), eachcol(y)) + return x +end + +_drop1(x::AbstractVector) = @views x[firstindex(x, 1):(end - 1)] +_drop1(x::AbstractMatrix) = @views x[firstindex(x, 1):(end - 1), :] + +_last1(x::AbstractVector) = x[end] +_last1(x::AbstractMatrix) = @views x[end, :] + +function _inv_softmax1!(y::AbstractVecOrMat, x::AbstractVecOrMat) + x₋ = _drop1(x) + xd = _last1(x) + @. y = log(x₋) - log(xd) + return y +end +function _inv_softmax1(x::AbstractVecOrMat) + y = similar(_drop1(x)) + _inv_softmax1!(y, x) + return y +end diff --git a/src/multivariates.jl b/src/multivariates.jl index 7a6f926a6..56d91233c 100644 --- a/src/multivariates.jl +++ b/src/multivariates.jl @@ -115,6 +115,7 @@ for fname in ["dirichlet.jl", "jointorderstatistics.jl", "mvnormal.jl", "mvnormalcanon.jl", + "mvlogitnormal.jl", "mvlognormal.jl", "mvtdist.jl", "product.jl", # deprecated diff --git a/test/multivariate/mvlogitnormal.jl b/test/multivariate/mvlogitnormal.jl new file mode 100644 index 000000000..cb53c9b37 --- /dev/null +++ b/test/multivariate/mvlogitnormal.jl @@ -0,0 +1,158 @@ +# Tests on Multivariate Logit-Normal distributions +using Distributions +using ForwardDiff +using LinearAlgebra +using Random +using Test + +####### Core testing procedure + +function test_mvlogitnormal(d::MvLogitNormal; nsamples::Int=10^6) + @test d.normal isa AbstractMvNormal + dnorm = d.normal + + @testset "properties" begin + @test length(d) == length(dnorm) + 1 + @test params(d) == params(dnorm) + @test partype(d) == partype(dnorm) + @test eltype(d) == eltype(dnorm) + @test eltype(typeof(d)) == eltype(typeof(dnorm)) + @test location(d) == mean(dnorm) + @test minimum(d) == fill(0, length(d)) + @test maximum(d) == fill(1, length(d)) + @test insupport(d, normalize(rand(length(d)), 1)) + @test !insupport(d, normalize(rand(length(d) + 1), 1)) + @test !insupport(d, rand(length(d))) + x = rand(length(d) - 1) + x = vcat(x, -sum(x)) + @test !insupport(d, x) + end + + @testset "conversions" begin + @test convert(typeof(d), d) === d + T = partype(d) <: Float64 ? Float32 : Float64 + if dnorm isa MvNormal + @test convert(MvLogitNormal{MvNormal{T}}, d).normal == + convert(MvNormal{T}, dnorm) + @test partype(convert(MvLogitNormal{MvNormal{T}}, d)) <: T + @test canonform(d) isa MvLogitNormal{<:MvNormalCanon} + @test canonform(d).normal == canonform(dnorm) + elseif dnorm isa MvNormalCanon + @test convert(MvLogitNormal{MvNormalCanon{T}}, d).normal == + convert(MvNormalCanon{T}, dnorm) + @test partype(convert(MvLogitNormal{MvNormalCanon{T}}, d)) <: T + @test meanform(d) isa MvLogitNormal{<:MvNormal} + @test meanform(d).normal == meanform(dnorm) + end + end + + @testset "sampling" begin + X = rand(d, nsamples) + Y = @views log.(X[1:(end - 1), :]) .- log.(X[end, :]') + Ymean = vec(mean(Y; dims=2)) + Ycov = cov(Y; dims=2) + for i in 1:(length(d) - 1) + @test isapprox( + Ymean[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8 + ) + end + for i in 1:(length(d) - 1), j in 1:(length(d) - 1) + @test isapprox( + Ycov[i, j], + cov(dnorm)[i, j], + atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20, + ) + end + end + + @testset "fitting" begin + X = rand(d, nsamples) + dfit = fit_mle(MvLogitNormal, X) + dfit_norm = dfit.normal + for i in 1:(length(d) - 1) + @test isapprox( + mean(dfit_norm)[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8 + ) + end + for i in 1:(length(d) - 1), j in 1:(length(d) - 1) + @test isapprox( + cov(dfit_norm)[i, j], + cov(dnorm)[i, j], + atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20, + ) + end + @test fit_mle(MvLogitNormal{IsoNormal}, X) isa MvLogitNormal{<:IsoNormal} + end + + @testset "evaluation" begin + X = rand(d, nsamples) + for i in 1:min(100, nsamples) + @test @inferred(logpdf(d, X[:, i])) ≈ log(pdf(d, X[:, i])) + if dnorm isa MvNormal + @test @inferred(gradlogpdf(d, X[:, i])) ≈ + ForwardDiff.gradient(x -> logpdf(d, x), X[:, i]) + end + end + @test logpdf(d, X) ≈ log.(pdf(d, X)) + @test isequal(logpdf(d, zeros(length(d))), -Inf) + @test isequal(logpdf(d, ones(length(d))), -Inf) + @test isequal(pdf(d, zeros(length(d))), 0) + @test isequal(pdf(d, ones(length(d))), 0) + end +end + +@testset "Results MvLogitNormal consistent with univariate LogitNormal" begin + μ = randn() + σ = rand() + d = MvLogitNormal([μ], fill(σ^2, 1, 1)) + duni = LogitNormal(μ, σ) + @test location(d) ≈ [location(duni)] + x = normalize(rand(2), 1) + @test logpdf(d, x) ≈ logpdf(duni, x[1]) + @test pdf(d, x) ≈ pdf(duni, x[1]) + @test (Random.seed!(9274); rand(d)[1]) ≈ (Random.seed!(9274); rand(duni)) +end + +###### General Testing + +@testset "MvLogitNormal tests" begin + mvnorm_params = [ + (randn(5), I * rand()), + (randn(4), Diagonal(rand(4))), + (Diagonal(rand(6)),), + (randn(5), exp(Symmetric(randn(5, 5)))), + (exp(Symmetric(randn(5, 5))),), + ] + @testset "wraps MvNormal" begin + @testset "$(typeof(prms))" for prms in mvnorm_params + d = MvLogitNormal(prms...) + @test d == MvLogitNormal(MvNormal(prms...)) + test_mvlogitnormal(d; nsamples=10^4) + end + end + @testset "wraps MvNormalCanon" begin + @testset "$(typeof(prms))" for prms in mvnorm_params + d = MvLogitNormal(MvNormalCanon(prms...)) + test_mvlogitnormal(d; nsamples=10^4) + end + end + + @testset "kldivergence" begin + d1 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5)))) + d2 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5)))) + @test kldivergence(d1, d2) ≈ kldivergence(d1.normal, d2.normal) + end + + VERSION ≥ v"1.8" && @testset "show" begin + d = MvLogitNormal([1.0, 2.0, 3.0], Diagonal([4.0, 5.0, 6.0])) + @test sprint(show, d) === """ + MvLogitNormal{DiagNormal}( + DiagNormal( + dim: 3 + μ: [1.0, 2.0, 3.0] + Σ: [4.0 0.0 0.0; 0.0 5.0 0.0; 0.0 0.0 6.0] + ) + ) + """ + end +end diff --git a/test/runtests.jl b/test/runtests.jl index cb724278b..ce3f16b79 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,7 @@ const tests = [ "univariate/continuous/uniform", "univariate/continuous/lognormal", "multivariate/mvnormal", + "multivariate/mvlogitnormal", "multivariate/mvlognormal", "types", # extra file compared to /src "utils",