-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
additionally filter specific non-relevant warning messages with JET.test_package
- Loading branch information
Showing
15 changed files
with
279 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
```@meta | ||
CurrentModule = DistributionFits | ||
``` | ||
|
||
# Multivariate LogNormal distribution | ||
|
||
Can be fitted to a given mean, provided the Covariance of the underlying | ||
normal distribution. | ||
|
||
```@docs | ||
fit_mean_Σ(::Type{MvLogNormal}, mean::AbstractVector{T1}, Σ::AbstractMatrix{T2}) where {T1 <:Real,T2 <:Real} | ||
``` | ||
|
||
```jldoctest; output = false, setup = :(using DistributionFits) | ||
Σ = hcat([0.6,0.02],[0.02,0.7]) | ||
μ = [1.2,1.3] | ||
d = MvLogNormal(μ, Σ) | ||
d2 = fit_mean_Σ(MvLogNormal, mean(d), Σ) | ||
isapprox(d2, d, rtol = 1e6) | ||
# output | ||
true | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
fit_mean_Σ(::Type{<:Distribution}, mean, Σ) | ||
Fit a Distribution to mean and uncertainty quantificator Σ. | ||
The meaning of `Σ` depends on the type of distribution: | ||
- `MvLogNormal`: the Covariancematrix of the associated normal distribution | ||
- `LogNormal`: the scale parameter, i.e. the standard deviation at log-scale, `σ` | ||
""" | ||
function fit_mean_Σ(::Type{MvLogNormal}, mean::AbstractVector{T1}, Σ::AbstractMatrix{T2}) where {T1 <:Real,T2 <:Real} | ||
_T = promote_type(T1, T2) | ||
fit_mean_Σ(MvLogNormal{_T}, mean, Σ) | ||
end | ||
function fit_mean_Σ(::Type{MvLogNormal{T}}, mean::AbstractVector{T1}, Σ::AbstractMatrix{T2}) where {T, T1 <:Real,T2 <:Real} | ||
meanT = T1 == T ? mean : begin | ||
meanT = similar(mean, T) | ||
meanT .= mean | ||
end | ||
ΣT = T2 == T ? Σ : begin | ||
ΣT = similar(Σ, T) | ||
ΣT .= Σ | ||
end | ||
σ2 = diag(ΣT) | ||
μ = log.(meanT) .- σ2 ./ 2 | ||
MvLogNormal(μ, ΣT) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
##### Specific distributions ##### | ||
|
||
for fname in [ | ||
# "dirichlet.jl", | ||
# "multinomial.jl", | ||
# "dirichletmultinomial.jl", | ||
# "jointorderstatistics.jl", | ||
# "mvnormal.jl", | ||
# "mvnormalcanon.jl", | ||
# "mvlogitnormal.jl", | ||
"mvlognormal.jl", | ||
# "mvtdist.jl", | ||
# "vonmisesfisher.jl" | ||
] | ||
include(joinpath("multivariate", fname)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
using PDMats | ||
using DistributionFits | ||
using Test | ||
|
||
@testset "fit_mean_Σ" begin | ||
Σ = PDiagMat([0.6,0.7]) | ||
μ = [1.2,1.2] | ||
d = MvLogNormal(μ, Σ) | ||
mean(d) | ||
d2 = fit_mean_Σ(MvLogNormal, mean(d), params(d)[2]) | ||
@test d2 ≈ d rtol = 1e6 | ||
# | ||
# Float32 | ||
d2_f32 = fit_mean_Σ(MvLogNormal, Float32.(mean(d)), Float32.(params(d)[2])) | ||
@test d2 ≈ d rtol = 1e6 | ||
@test partype(d2_f32) === Float32 | ||
end; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
using DistributionFits | ||
using Test | ||
using Random: Random | ||
using LoggingExtras | ||
using Optim | ||
|
||
pkgdir = dirname(dirname(pathof(DistributionFits))) | ||
testdir = joinpath(pkgdir, "test") | ||
include(joinpath(testdir,"testutils.jl")) | ||
|
||
function test_univariate_fits(d, D = typeof(d)) | ||
@testset "fit moments" begin | ||
if !occursin("fit(::Type{D}", | ||
string(first(methods(fit, (Type{typeof(d)}, AbstractMoments))))) | ||
m = Moments(mean(d), var(d)) | ||
d_fit = fit(D, m) | ||
@test d ≈ d_fit | ||
@test partype(d_fit) == partype(d) | ||
end | ||
end | ||
@testset "fit two quantiles" begin | ||
qpl = @qp_l(quantile(d, 0.05)) | ||
qpu = @qp_u(quantile(d, 0.95)) | ||
d_fit = fit(D, qpl, qpu) | ||
@test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q] | ||
d_fit = fit(D, qpl, qpu) | ||
@test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q] | ||
d_fit = fit(D, qpu, qpl) # sort | ||
@test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q] | ||
@test partype(d_fit) == partype(d) | ||
end | ||
@testset "fit two quantiles, function version" begin | ||
P = partype(d) | ||
qpl = qp_l(P(quantile(d, 0.05))) | ||
qpu = qp_u(P(quantile(d, 0.95))) | ||
d_fit = fit(D, qpl, qpu) | ||
@test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q] | ||
d_fit = fit(D, qpl, qpu) | ||
@test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q] | ||
d_fit = fit(D, qpu, qpl) # sort | ||
@test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q] | ||
@test partype(d_fit) == partype(d) | ||
end | ||
@testset "typeof mean, mode equals partype" begin | ||
if !(d isa Gamma && first(params(d)) < 1) | ||
@test mean(d) isa partype(d) | ||
@test mode(d) isa partype(d) | ||
end | ||
end | ||
@testset "quantile is of eltype" begin | ||
# quantile still Float64 for Normal of eltype Float32 | ||
if d isa Normal && eltype(d) != Float64 | ||
@test_broken quantile(d, 0.1) isa eltype(d) | ||
else | ||
@test quantile(d, 0.1) isa eltype(d) | ||
end | ||
# quantile is sample-like: stick to eltype - special of normal | ||
# broken, because quantile Normal{Float32} returns Float32 | ||
# but eltype(D{Float32}) is Float64 | ||
if d isa Union{LogNormal, LogitNormal, Exponential, Laplace, Weibull} && | ||
partype(d) != eltype(d) | ||
@test_broken quantile(d, 0.1f0) isa eltype(d) | ||
else | ||
@test quantile(d, 0.1f0) isa eltype(d) | ||
end | ||
end | ||
@testset "fit to quantilepoint and mean" begin | ||
if !occursin("fit_mean_quantile(::Type{D}", | ||
string(first(methods(fit_mean_quantile, | ||
(Type{typeof(d)}, partype(d), QuantilePoint))))) | ||
m = log(mean(d)) | ||
qp = @qp_u(quantile(d, 0.95)) | ||
logger = d isa Exponential ? MinLevelLogger(current_logger(), Logging.Error) : | ||
current_logger() | ||
with_logger(logger) do | ||
d_fit = fit_mean_quantile(D, mean(d), qp) | ||
@test d_fit ≈ d | ||
@test partype(d_fit) == partype(d) | ||
d_fit = fit(D, mean(d), qp, Val(:mean)) | ||
@test d_fit ≈ d | ||
@test partype(d_fit) == partype(d) | ||
# with lower quantile | ||
qp = @qp_l(quantile(d, 0.05)) | ||
d_fit = fit_mean_quantile(D, mean(d), qp) | ||
@test d_fit ≈ d | ||
@test partype(d_fit) == partype(d) | ||
end | ||
# very close to mean can give very different results: | ||
# qp = @qp(mean(d)-1e-4,0.95) | ||
# d_fit = fit_mean_quantile(D, mean(d), qp) | ||
# @test mean(d_fit) ≈ mean(d) && quantile(d_fit, qp.p) ≈ qp.q | ||
end | ||
end | ||
@testset "fit to quantilepoint and mode" begin | ||
if !(d isa Gamma && first(params(d)) < 1) && | ||
!(d isa Weibull) | ||
qp = qp_u(quantile(d, 0.95)) | ||
d_fit = fit_mode_quantile(D, mode(d), qp) | ||
@test d_fit≈d atol=0.1 | ||
d_fit = fit(D, mode(d), qp, Val(:mode)) | ||
@test d_fit≈d atol=0.1 | ||
@test partype(d_fit) == partype(d) | ||
# with lower quantile | ||
qp = qp_ll(quantile(d, 0.025)) | ||
d_fit = fit(D, mode(d), qp, Val(:mode)) | ||
@test mode(d_fit) ≈ mode(d) | ||
@test quantile(d_fit, qp.p)≈qp.q atol=0.01 | ||
@test partype(d_fit) == partype(d) | ||
end | ||
end | ||
@testset "fit to quantilepoint and median" begin | ||
qp = @qp_u(quantile(d, 0.95)) | ||
logger = d isa Exponential ? MinLevelLogger(current_logger(), Logging.Error) : | ||
current_logger() | ||
with_logger(logger) do | ||
d_fit = fit(D, median(d), qp, Val(:median)) | ||
@test d_fit ≈ d | ||
@test partype(d_fit) == partype(d) | ||
end | ||
end | ||
end | ||
|
||
|
||
const tests = [ | ||
"mvlognormal", | ||
] | ||
#tests = ["mvlognormal"] | ||
|
||
for t in tests | ||
@testset "Test $t" begin | ||
Random.seed!(345679) | ||
include(joinpath(testdir,"multivariate","$t.jl")) | ||
end | ||
end |
Oops, something went wrong.