Skip to content

Commit

Permalink
implement and document fit_mean_Σ
Browse files Browse the repository at this point in the history
additionally filter specific non-relevant warning messages with JET.test_package
  • Loading branch information
bgctw authored Jan 12, 2024
1 parent a9c6509 commit eb6a57d
Show file tree
Hide file tree
Showing 15 changed files with 279 additions and 15 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.3.7-DEV"
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -28,6 +29,7 @@ Requires = "1.2"
StaticArrays = "1.2"
Statistics = "1"
StatsAPI = "1.6"
LinearAlgebra = "1.6"
StatsFuns = "0.9.15, 1"
julia = "1.6"

Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ makedocs(;
"LogitNormal" => "logitnormal.md",
"Weibull" => "weibull.md",
"Gamma" => "gamma.md",
"MvLogNormal" => "mvlognormal.md",
],
"Dependencies" => "set_optimize.md",
"API" => "api.md",
Expand Down
19 changes: 17 additions & 2 deletions docs/src/lognormal.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,29 @@ d = LogNormal(log(2), log(1.2))
true
```

Alternatively the distribution can be specified by its mean and ``\sigma^*`` using type [`AbstractΣstar`](@ref)
Alternatively, the distribution can be specified by its mean and one of

- the multiplicative standard deviation, ``\sigma^*``, using type [`AbstractΣstar`](@ref),
- the standard deviation at log-scale, ``\sigma``, or
- the relative error, ``cv``.

```jldoctest; output = false, setup = :(using DistributionFits,Optim)
d = fit(LogNormal, 2, Σstar(1.2))
(mean(d), σstar(d)) == (2, 1.2)
# output
true
```
```jldoctest; output = false, setup = :(using DistributionFits,Optim)
d = fit_mean_Σ(LogNormal, 2, 1.2)
(mean(d), d.σ) == (2, 1.2)
# output
true
```
```jldoctest; output = false, setup = :(using DistributionFits,Optim)
d = fit_mean_relerror(LogNormal, 2, 0.2)
(mean(d), std(d)/mean(d)) .≈ (2, 0.2)
# output
(true, true)
```

## Detailed API

Expand All @@ -37,7 +52,7 @@ true
```

```@docs
fit(::Type{LogNormal}, ::T, ::AbstractΣstar) where T<:Real
fit(d::Type{LogNormal}, mean, σstar::AbstractΣstar)
```

```@docs
Expand Down
22 changes: 22 additions & 0 deletions docs/src/mvlognormal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
```@meta
CurrentModule = DistributionFits
```

# Multivariate LogNormal distribution

A multivariate LogNormal distribution can be fitted to a given mean vector,
provided the covariance matrix of the underlying normal distribution.

```@docs
fit_mean_Σ(::Type{MvLogNormal}, mean::AbstractVector{T1}, Σ::AbstractMatrix{T2}) where {T1 <:Real,T2 <:Real}
```

```jldoctest; output = false, setup = :(using DistributionFits)
# Covariance matrix and mean vector of the underlying normal distribution
Σ = hcat([0.6,0.02],[0.02,0.7])
μ = [1.2,1.3]
d = MvLogNormal(μ, Σ)
# Refitting to the mean of `d` and the same Σ must recover `d`.
d2 = fit_mean_Σ(MvLogNormal, mean(d), Σ)
# A relative tolerance of 1e6 would accept any result; use a tight tolerance.
isapprox(d2, d, rtol = 1e-6)
# output
true
```
5 changes: 4 additions & 1 deletion src/DistributionFits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Reexport

using FillArrays, StaticArrays
using StatsFuns: logit, logistic, normcdf
using LinearAlgebra
#using Infiltrator

if !isdefined(Base, :get_extension)
Expand All @@ -26,7 +27,8 @@ export
@qs_cf90, @qs_cf95,
qp, qp_ll, qp_l, qp_m, qp_u, qp_uu,
qs_cf90, qs_cf95,
fit_mean_relerror
fit_mean_relerror,
fit_mean_Σ

# document but do not export - need to qualify by 'DistributionFits.'
# export
Expand All @@ -53,5 +55,6 @@ end
# fitting distributions to stats
include("fitstats.jl")
include("univariates.jl")
include("multivariates.jl")

end
26 changes: 26 additions & 0 deletions src/multivariate/mvlognormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
    fit_mean_Σ(::Type{<:Distribution}, mean, Σ)

Fit a distribution to a given mean and an uncertainty parameter `Σ`.

The meaning of `Σ` depends on the type of distribution:
- `MvLogNormal`: the covariance matrix of the associated normal distribution
- `LogNormal`: the scale parameter, i.e. the standard deviation at log-scale, `σ`
"""
function fit_mean_Σ(::Type{MvLogNormal}, mean::AbstractVector{T1},
        Σ::AbstractMatrix{T2}) where {T1 <: Real, T2 <: Real}
    # No parameter type was requested: promote the element types of `mean` and `Σ`
    # and forward to the method for the parametrized distribution type.
    return fit_mean_Σ(MvLogNormal{promote_type(T1, T2)}, mean, Σ)
end
function fit_mean_Σ(::Type{MvLogNormal{T}}, mean::AbstractVector{T1},
        Σ::AbstractMatrix{T2}) where {T, T1 <: Real, T2 <: Real}
    # Inputs that already have element type `T` are used as they are; otherwise
    # their values are copied into a `similar` container with element type `T`.
    if T1 == T
        mean_t = mean
    else
        mean_t = similar(mean, T)
        mean_t .= mean
    end
    if T2 == T
        Σ_t = Σ
    else
        Σ_t = similar(Σ, T)
        Σ_t .= Σ
    end
    # Location parameter of the underlying normal: log.(mean) .- diag(Σ) ./ 2
    variances = diag(Σ_t)
    μ = log.(mean_t) .- variances ./ 2
    return MvLogNormal(μ, Σ_t)
end
16 changes: 16 additions & 0 deletions src/multivariates.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
##### Specific distributions #####

# Include the source file of each multivariate distribution that has fitting methods.
# Commented-out entries are currently disabled.
for srcfile in [
    # "dirichlet.jl",
    # "multinomial.jl",
    # "dirichletmultinomial.jl",
    # "jointorderstatistics.jl",
    # "mvnormal.jl",
    # "mvnormalcanon.jl",
    # "mvlogitnormal.jl",
    "mvlognormal.jl",
    # "mvtdist.jl",
    # "vonmisesfisher.jl"
]
    srcpath = joinpath("multivariate", srcfile)
    include(srcpath)
end
27 changes: 18 additions & 9 deletions src/univariate/continuous/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,20 @@ true
σstar(d::LogNormal) = exp(params(d)[2])

"""
fit(D, mean, σstar)
fit(D, mean, σstar::AbstractΣstar)
fit_mean_Σ(D, mean, σ::Real)
Fit a statistical distribution of type `D` to mean and multiplicative
standard deviation.
standard deviation, `σstar`, or scale parameter at log-scale: `σ`.
# Arguments
- `D`: The type of distribution to fit
- `mean`: The mean of the distribution
- `σstar::AbstractΣstar`: The multiplicative standard deviation
- `σ`: The standard-deviation parameter at log-scale
See also [`σstar`](@ref), [`AbstractΣstar`](@ref).
The first version uses type [`AbstractΣstar`](@ref) to distinguish from
other methods of function fit.
# Examples
```jldoctest fm1; output = false, setup = :(using DistributionFits)
Expand All @@ -121,17 +124,23 @@ d = fit(LogNormal, 2, Σstar(1.1));
true
```
"""
function fit(d::Type{LogNormal}, mean, σstar::AbstractΣstar)
    # `σstar()` is the multiplicative standard deviation; its logarithm is the
    # standard deviation at log-scale that `fit_mean_Σ` expects.
    return fit_mean_Σ(d, mean, log(σstar()))
end

function fit(d::Type{LogNormal{T}}, mean::Real, σstar::AbstractΣstar) where {T}
    # Same conversion as for the unparametrized type: pass the log of the
    # multiplicative standard deviation on to `fit_mean_Σ`, keeping parameter type `T`.
    return fit_mean_Σ(d, mean, log(σstar()))
end
function fit_mean_Σ(::Type{LogNormal}, mean::T1, σ::T2) where {T1 <: Real, T2 <: Real}
    # `float` keeps the parameter type floating point, so that all-integer
    # arguments (e.g. `fit_mean_Σ(LogNormal, 2, 1)`) do not raise an `InexactError`
    # when the non-integer location parameter is converted to the parameter type.
    _T = float(promote_type(T1, T2))
    return fit_mean_Σ(LogNormal{_T}, mean, σ)
end
function fit_mean_Σ(::Type{LogNormal{T}}, mean::Real, σ::Real) where {T}
    # Location parameter: log of the mean minus half the log-scale variance.
    half_variance = σ * σ / 2
    μ = log(mean) - half_variance
    return LogNormal(T(μ), T(σ))
end


"""
fit_mean_relerror(D, mean, relerror)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand Down
17 changes: 17 additions & 0 deletions test/multivariate/mvlognormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using PDMats
using DistributionFits
using Test

@testset "fit_mean_Σ" begin
    # Reference distribution with a diagonal covariance of the underlying normal
    Σ = PDiagMat([0.6, 0.7])
    μ = [1.2, 1.2]
    d = MvLogNormal(μ, Σ)
    # Refitting to the mean of `d` and its covariance must recover `d`.
    d2 = fit_mean_Σ(MvLogNormal, mean(d), params(d)[2])
    @test d2 ≈ d rtol = 1e-6
    #
    # Float32: check the fitted Float32 distribution itself (not `d2` again).
    # Means are compared instead of whole distributions to avoid relying on
    # `isapprox` between differently typed covariance matrices.
    d2_f32 = fit_mean_Σ(MvLogNormal, Float32.(mean(d)), Float32.(params(d)[2]))
    @test mean(d2_f32) ≈ mean(d) rtol = 1e-5
    @test partype(d2_f32) === Float32
end;
134 changes: 134 additions & 0 deletions test/multivariate/test_multivariate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
using DistributionFits
using Test
using Random: Random
using LoggingExtras
using Optim

pkgdir = dirname(dirname(pathof(DistributionFits)))
testdir = joinpath(pkgdir, "test")
include(joinpath(testdir,"testutils.jl"))

function test_univariate_fits(d, D = typeof(d))
    # Runs a battery of fitting tests against the distribution `d`.
    # `D` is the distribution type passed to the fitting functions; it defaults
    # to `typeof(d)`.
    @testset "fit moments" begin
        # Only run when the first matching `fit` method is not printed as
        # `fit(::Type{D}, ...` — presumably the generic fallback; confirm.
        if !occursin("fit(::Type{D}",
            string(first(methods(fit, (Type{typeof(d)}, AbstractMoments)))))
            m = Moments(mean(d), var(d))
            d_fit = fit(D, m)
            @test d ≈ d_fit
            @test partype(d_fit) == partype(d)
        end
    end
    @testset "fit two quantiles" begin
        qpl = @qp_l(quantile(d, 0.05))
        qpu = @qp_u(quantile(d, 0.95))
        # NOTE(review): the quantile assertions evaluate `d`, not `d_fit` — confirm
        # whether `d_fit` was intended.
        d_fit = fit(D, qpl, qpu)
        @test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q]
        d_fit = fit(D, qpu, qpl) # sort
        @test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q]
        @test partype(d_fit) == partype(d)
    end
    @testset "fit two quantiles, function version" begin
        P = partype(d)
        qpl = qp_l(P(quantile(d, 0.05)))
        qpu = qp_u(P(quantile(d, 0.95)))
        # NOTE(review): the quantile assertions evaluate `d`, not `d_fit` — confirm
        # whether `d_fit` was intended.
        d_fit = fit(D, qpl, qpu)
        @test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q]
        d_fit = fit(D, qpu, qpl) # sort
        @test quantile.(d, [qpl.p, qpu.p]) ≈ [qpl.q, qpu.q]
        @test partype(d_fit) == partype(d)
    end
    @testset "typeof mean, mode equals partype" begin
        if !(d isa Gamma && first(params(d)) < 1)
            @test mean(d) isa partype(d)
            @test mode(d) isa partype(d)
        end
    end
    @testset "quantile is of eltype" begin
        # quantile still Float64 for Normal of eltype Float32
        if d isa Normal && eltype(d) != Float64
            @test_broken quantile(d, 0.1) isa eltype(d)
        else
            @test quantile(d, 0.1) isa eltype(d)
        end
        # quantile is sample-like: stick to eltype - special of normal
        # broken, because quantile Normal{Float32} returns Float32
        # but eltype(D{Float32}) is Float64
        if d isa Union{LogNormal, LogitNormal, Exponential, Laplace, Weibull} &&
           partype(d) != eltype(d)
            @test_broken quantile(d, 0.1f0) isa eltype(d)
        else
            @test quantile(d, 0.1f0) isa eltype(d)
        end
    end
    @testset "fit to quantilepoint and mean" begin
        if !occursin("fit_mean_quantile(::Type{D}",
            string(first(methods(fit_mean_quantile,
                (Type{typeof(d)}, partype(d), QuantilePoint)))))
            qp = @qp_u(quantile(d, 0.95))
            # For Exponential, only messages of level Error and above are logged.
            logger = d isa Exponential ? MinLevelLogger(current_logger(), Logging.Error) :
                     current_logger()
            with_logger(logger) do
                d_fit = fit_mean_quantile(D, mean(d), qp)
                @test d_fit ≈ d
                @test partype(d_fit) == partype(d)
                d_fit = fit(D, mean(d), qp, Val(:mean))
                @test d_fit ≈ d
                @test partype(d_fit) == partype(d)
                # with lower quantile
                qp = @qp_l(quantile(d, 0.05))
                d_fit = fit_mean_quantile(D, mean(d), qp)
                @test d_fit ≈ d
                @test partype(d_fit) == partype(d)
            end
            # very close to mean can give very different results:
            # qp = @qp(mean(d)-1e-4,0.95)
            # d_fit = fit_mean_quantile(D, mean(d), qp)
            # @test mean(d_fit) ≈ mean(d) && quantile(d_fit, qp.p) ≈ qp.q
        end
    end
    @testset "fit to quantilepoint and mode" begin
        if !(d isa Gamma && first(params(d)) < 1) &&
           !(d isa Weibull)
            qp = qp_u(quantile(d, 0.95))
            d_fit = fit_mode_quantile(D, mode(d), qp)
            @test d_fit ≈ d atol=0.1
            d_fit = fit(D, mode(d), qp, Val(:mode))
            @test d_fit ≈ d atol=0.1
            @test partype(d_fit) == partype(d)
            # with lower quantile
            qp = qp_ll(quantile(d, 0.025))
            d_fit = fit(D, mode(d), qp, Val(:mode))
            @test mode(d_fit) ≈ mode(d)
            @test quantile(d_fit, qp.p) ≈ qp.q atol=0.01
            @test partype(d_fit) == partype(d)
        end
    end
    @testset "fit to quantilepoint and median" begin
        qp = @qp_u(quantile(d, 0.95))
        # For Exponential, only messages of level Error and above are logged.
        logger = d isa Exponential ? MinLevelLogger(current_logger(), Logging.Error) :
                 current_logger()
        with_logger(logger) do
            d_fit = fit(D, median(d), qp, Val(:median))
            @test d_fit ≈ d
            @test partype(d_fit) == partype(d)
        end
    end
end


# Names of the test files (without extension) in test/multivariate to run.
const tests = [
    "mvlognormal",
]

# Run each test file in its own testset, reseeding the RNG before each file.
for name in tests
    testfile = joinpath(testdir, "multivariate", string(name, ".jl"))
    @testset "Test $name" begin
        Random.seed!(345679)
        include(testfile)
    end
end
Loading

0 comments on commit eb6a57d

Please sign in to comment.