Skip to content

Commit

Permalink
implment fit_mean_Σ for Normal and MvNormal (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgctw authored Jan 15, 2024
1 parent eb6a57d commit 638e6c3
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 14 deletions.
11 changes: 11 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ fit(::Type{D}, ::QuantilePoint, ::QuantilePoint) where {D<:Distribution}
```@docs
fit(::Type{D}, ::Any, ::QuantilePoint, ::Val{stats} = Val(:mean)) where {D<:Distribution, stats}
```

## Fit to mean and uncertainty parameter
For bayesian inversion it is often required to specify a distribution given
the expected value (the predction of the population value) and a description of
uncertainty of an observation.

```@docs
fit_mean_Σ
```


## Currently supported distributions
Univariate continuous
- Normal
Expand Down
4 changes: 0 additions & 4 deletions docs/src/mvlognormal.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ CurrentModule = DistributionFits
Can be fitted to a given mean, provided the Covariance of the underlying
normal distribution.

```@docs
fit_mean_Σ(::Type{MvLogNormal}, mean::AbstractVector{T1}, Σ::AbstractMatrix{T2}) where {T1 <:Real,T2 <:Real}
```

```jldoctest; output = false, setup = :(using DistributionFits)
Σ = hcat([0.6,0.02],[0.02,0.7])
μ = [1.2,1.3]
Expand Down
12 changes: 12 additions & 0 deletions src/fitstats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,15 @@ end
# function fit_mean_quantile(d::Type{D}, mean::Real, qp::QuantilePoint) where D<:Distribution
# error("fit_mean_quantile not yet implemented for Distribution of type: $D")
# end

"""
fit_mean_Σ(::Type{<:Distribution}, mean, Σ)
Fit a Distribution to mean and uncertainty quantificator Σ.
The meaning of `Σ` depends on the type of distribution:
- `MvLogNormal`, `MvNormal`: the Covariancematrix of the associated normal distribution
- `LogNormal`, `Normal`: the scale parameter, i.e. the standard deviation at log-scale, `σ`
"""
function fit_mean_Σ end

9 changes: 0 additions & 9 deletions src/multivariate/mvlognormal.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
"""
fit_mean_Σ(::Type{<:Distribution}, mean, Σ)
Fit a Distribution to mean and uncertainty quantificator Σ.
The meaning of `Σ` depends on the type of distribution:
- `MvLogNormal`: the Covariancematrix of the associated normal distribution
- `LogNormal`: the scale parameter, i.e. the standard deviation at log-scale, `σ`
"""
function fit_mean_Σ(::Type{MvLogNormal}, mean::AbstractVector{T1}, Σ::AbstractMatrix{T2}) where {T1 <:Real,T2 <:Real}
_T = promote_type(T1, T2)
fit_mean_Σ(MvLogNormal{_T}, mean, Σ)
Expand Down
15 changes: 15 additions & 0 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
function fit_mean_Σ(::Type{MvNormal}, mean::AbstractVector{T1}, Σ::AbstractMatrix{T2}) where {T1 <:Real,T2 <:Real}
_T = promote_type(T1, T2)
fit_mean_Σ(MvNormal{_T}, mean, Σ)
end
function fit_mean_Σ(::Type{MvNormal{T}}, mean::AbstractVector{T1}, Σ::AbstractMatrix{T2}) where {T, T1 <:Real,T2 <:Real}
meanT = T1 == T ? mean : begin
meanT = similar(mean, T)
meanT .= mean
end
ΣT = T2 == T ? Σ : begin
ΣT = similar(Σ, T)
ΣT .= Σ
end
MvNormal(meanT, ΣT)
end
2 changes: 1 addition & 1 deletion src/multivariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ for fname in [
# "multinomial.jl",
# "dirichletmultinomial.jl",
# "jointorderstatistics.jl",
# "mvnormal.jl",
"mvnormal.jl",
# "mvnormalcanon.jl",
# "mvlogitnormal.jl",
"mvlognormal.jl",
Expand Down
9 changes: 9 additions & 0 deletions src/univariate/continuous/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,12 @@ end
function fit_mode_quantile(D::Type{Normal{T}}, mode::Real, qp::QuantilePoint) where {T}
fit(D, QuantilePoint(mode, 0.5), qp)
end

function fit_mean_Σ(::Type{Normal}, mean::T1, σ::T2) where {T1 <: Real, T2 <: Real}
_T = promote_type(T1,T2)
fit_mean_Σ(Normal{_T}, mean, σ)
end
function fit_mean_Σ(D::Type{Normal{T}}, mean::Real, σ::Real) where {T}
Normal{T}(mean, σ)
end

11 changes: 11 additions & 0 deletions test/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using DistributionFits
using Test

@testset "fit_mean_Σ" begin
m = [3.0f0, 4.0f0]
Σ = hcat([2.0f0,0.2],[0.2, 2.0f0])
d = fit_mean_Σ(MvNormal, m, Σ)
@test mean(d) == m
@test cov(d) == Σ
end;

1 change: 1 addition & 0 deletions test/multivariate/test_multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ end

const tests = [
"mvlognormal",
"mvnormal",
]
#tests = ["mvlognormal"]

Expand Down
8 changes: 8 additions & 0 deletions test/univariate/continuous/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ end;
d = Normal(3.0f0, 2.0f0)
test_univariate_fits(d)
end;

@testset "fit_mean_Σ" begin
m = 3.0f0
σ = 2.0f0
d = fit_mean_Σ(Normal, m, σ)
@test mean(d) == m
@test scale(d) == σ
end;

0 comments on commit 638e6c3

Please sign in to comment.