Skip to content

Commit

Permalink
add convenience constructors for LocationScaleLowRank
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Aug 9, 2024
1 parent b24737f commit 1d56953
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 11 deletions.
8 changes: 6 additions & 2 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,14 @@ include("objectives/elbo/repgradelbo.jl")
export
MvLocationScale,
MeanFieldGaussian,
FullRankGaussian,
MvLocationScaleLowRank
FullRankGaussian

include("families/location_scale.jl")

export
MvLocationScaleLowRank,
LowRankGaussian

include("families/location_scale_low_rank.jl")


Expand Down
12 changes: 6 additions & 6 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ function Distributions.cov(q::MvLocationScale)
end

"""
FullRankGaussian(location, scale; check_args = true)
FullRankGaussian(μ, L; check_args = true)
Construct a Gaussian variational approximation with a dense covariance matrix.
# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.
# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).
Expand All @@ -142,13 +142,13 @@ function FullRankGaussian(
end

"""
MeanFieldGaussian(location, scale; check_args = true)
MeanFieldGaussian(μ, L; check_args = true)
Construct a Gaussian variational approximation with a diagonal covariance matrix.
# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.
# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).
Expand Down
25 changes: 25 additions & 0 deletions src/families/location_scale_low_rank.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
MvLocationLowRankScale(location, scale_diag, scale_factors, dist) <: ContinuousMultivariateDistribution
Variational family with a covariance in the form of a diagonal matrix plus a squared low-rank matrix.
The rank is given by `size(scale_factors, 2)`.
It generally represents any distribution for which the sampling path can be
represented as follows:
Expand Down Expand Up @@ -135,3 +136,27 @@ function update_variational_params!(

opt_st, params
end

"""
LowRankGaussian(location, scale_diag, scale_factors; check_args = true)
Construct a Gaussian variational approximation with a diagonal plus low-rank covariance matrix.
# Arguments
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `D::Vector{T}`: Diagonal of the scale.
- `U::Matrix{T}`: Low-rank factors of the scale, where `size(U,2)` is the rank.
# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).
"""
function LowRankGaussian(
μ::AbstractVector{T},
D::Vector{T},
U::Matrix{T};
scale_eps::T = sqrt(eps(T))
) where {T <: Real}
@assert minimum(D) sqrt(scale_eps) "Initial scale is too small (smallest diagonal scale value is $(minimum(D)). This might result in unstable optimization behavior."
q_base = Normal{T}(zero(T), one(T))
MvLocationScaleLowRank(μ, D, U, q_base, scale_eps)
end
4 changes: 1 addition & 3 deletions test/families/location_scale_low_rank.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
Σ = Diagonal(D.^2) + U*U'

q = if basedist == :gaussian
MvLocationScaleLowRank(
μ, D, U, Normal{realtype}(zero(realtype), one(realtype))
)
LowRankGaussian(μ, D, U)
end
q_true = if basedist == :gaussian
MvNormal(μ, Σ)
Expand Down

0 comments on commit 1d56953

Please sign in to comment.