Skip to content

Commit

Permalink
add indirection for update step, add projection for LocationScale (#65
Browse files Browse the repository at this point in the history
)

* add indirection for update step, add projection for `LocationScale`
* add projection for `Bijectors` with `MvLocationScale`
  • Loading branch information
Red-Portal authored Jun 13, 2024
1 parent 95a83c3 commit 5ced9c2
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 22 deletions.
24 changes: 24 additions & 0 deletions ext/AdvancedVIBijectorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,37 @@ module AdvancedVIBijectorsExt
if isdefined(Base, :get_extension)
using AdvancedVI
using Bijectors
using LinearAlgebra
using Optimisers
using Random
else
using ..AdvancedVI
using ..Bijectors
using ..LinearAlgebra
using ..Optimisers
using ..Random
end

# Optimizer step for a bijector-transformed `MvLocationScale`, followed by a
# projection of the base distribution's scale.
#
# The step itself is delegated to `Optimisers.update!`. Afterwards every diagonal
# entry of the base distribution's `scale` is raised to at least its `scale_eps`,
# and the projected distribution is flattened again.
#
# Returns the tuple `(opt_st, params)` holding the updated optimizer state and the
# projected flat parameters.
function AdvancedVI.update_variational_params!(
    ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
    opt_st,
    params,
    restructure,
    grad,
)
    opt_st, params = Optimisers.update!(opt_st, params, grad)

    # Rebuild the structured distribution so the scale can be edited in place.
    q = restructure(params)
    base = q.dist
    scale = base.scale
    lower_bound = base.scale_eps

    # Project the scale matrix to the set of positive definite triangular matrices
    # by clamping its diagonal from below.
    diag_positions = diagind(scale)
    scale[diag_positions] .= max.(scale[diag_positions], lower_bound)

    # Flatten the projected distribution back into a parameter vector.
    params, _ = Optimisers.destructure(q)

    return opt_st, params
end

function AdvancedVI.reparam_with_entropy(
rng ::Random.AbstractRNG,
q ::Bijectors.TransformedDistribution,
Expand Down
27 changes: 27 additions & 0 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,33 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di
"""
function value_and_gradient! end

# Update for gradient descent step
"""
    update_variational_params!(family_type, opt_st, params, restructure, grad)

Update the variational parameters according to the update rule stored in the
optimizer state `opt_st` and the variational family `family_type`.

This is a wrapper around `Optimisers.update!` that provides a layer of indirection:
depending on the optimizer and the variational family, a method may do additional
work such as applying a projection or a proximal mapping.

As with the default behavior of `Optimisers.update!`, `params` and `opt_st` may be
modified by this routine and are no longer valid after calling this function.
Use the returned values instead.

# Arguments
- `family_type::Type`: Type of the variational family, `typeof(restructure(params))`.
- `opt_st`: Optimizer state returned by `Optimisers.setup`.
- `params`: Current set of parameters to be updated.
- `restructure`: Callable for restructuring the variational distribution from `params`.
- `grad`: Gradient to be used by the update rule of `opt_st`.

# Returns
- `opt_st`: Updated optimizer state.
- `params`: Updated parameters.
"""
function update_variational_params! end

# Fallback for families without a specialised method: a plain optimizer step with
# no projection. `restructure` is accepted only to keep the signature uniform.
function update_variational_params!(::Type, opt_st, params, restructure, grad)
    return Optimisers.update!(opt_st, params, grad)
end

# estimators
"""
AbstractVariationalObjective
Expand Down
62 changes: 41 additions & 21 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,21 @@ represented as follows:
```
"""
struct MvLocationScale{
S, D <: ContinuousDistribution, L
S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
location::L
scale ::S
dist ::D
location ::L
scale ::S
dist ::D
scale_eps::E
end

# Keyword-argument convenience constructor: forwards to the positional
# four-argument constructor, with `scale_eps` defaulting to `sqrt(eps(T))`.
function MvLocationScale(
    location::AbstractVector{T},
    scale::AbstractMatrix{T},
    dist::ContinuousDistribution;
    scale_eps::T=sqrt(eps(T)),
) where {T<:Real}
    return MvLocationScale(location, scale, dist, scale_eps)
end

Functors.@functor MvLocationScale (location, scale)
Expand All @@ -36,14 +46,14 @@ function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
MvLocationScale(location, scale, re.q.dist)
MvLocationScale(location, scale, re.q.dist, re.q.scale_eps)
end

function Optimisers.destructure(
q::MvLocationScale{<:Diagonal, D, L}
) where {D, L}
@unpack location, scale, dist = q
flat = vcat(location, diag(scale))
flat = vcat(location, diag(scale))
flat, RestructureMeanField(q)
end
# end
Expand All @@ -57,17 +67,17 @@ Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)
function StatsBase.entropy(q::MvLocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
n_dims*convert(eltype(location), entropy(dist)) + first(logdet(scale))
n_dims*convert(eltype(location), entropy(dist)) + logdet(scale)
end

function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
end

function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
end

function Distributions.rand(q::MvLocationScale)
Expand Down Expand Up @@ -128,14 +138,11 @@ Construct a Gaussian variational approximation with a dense covariance matrix.
function FullRankGaussian(
μ::AbstractVector{T},
L::LinearAlgebra.AbstractTriangular{T};
check_args::Bool = true
scale_eps::T = sqrt(eps(T))
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
end
@assert minimum(diag(L)) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
q_base = Normal{T}(zero(T), one(T))
MvLocationScale(μ, L, q_base)
MvLocationScale(μ, L, q_base, scale_eps)
end

"""
Expand All @@ -153,12 +160,25 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix
function MeanFieldGaussian(
μ::AbstractVector{T},
L::Diagonal{T};
check_args::Bool = true
scale_eps::T = sqrt(eps(T)),
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
end
@assert minimum(diag(L)) ≥ sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
q_base = Normal{T}(zero(T), one(T))
MvLocationScale(μ, L, q_base)
MvLocationScale(μ, L, q_base, scale_eps)
end

# Optimizer step for `MvLocationScale` families, followed by a projection of the
# scale: after `Optimisers.update!`, each diagonal entry of `scale` is raised to at
# least `scale_eps` and the distribution is flattened again.
#
# Returns the tuple `(opt_st, params)` holding the updated optimizer state and the
# projected flat parameters.
function update_variational_params!(
    ::Type{<:MvLocationScale}, opt_st, params, restructure, grad
)
    opt_st, params = Optimisers.update!(opt_st, params, grad)

    # Rebuild the structured distribution so the scale can be edited in place.
    q = restructure(params)
    scale = q.scale
    lower_bound = q.scale_eps

    # Project the scale matrix to the set of positive definite triangular matrices
    # by clamping its diagonal from below.
    diag_positions = diagind(scale)
    scale[diag_positions] .= max.(scale[diag_positions], lower_bound)

    # Flatten the projected distribution back into a parameter vector.
    params, _ = Optimisers.destructure(q)

    return opt_st, params
end
4 changes: 3 additions & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ function optimize(
stat = merge(stat, stat′)

grad = DiffResults.gradient(grad_buf)
opt_st, params = Optimisers.update!(opt_st, params, grad)
opt_st, params = update_variational_params!(
typeof(q_init), opt_st, params, restructure, grad
)

if !isnothing(callback)
stat′ = callback(
Expand Down
32 changes: 32 additions & 0 deletions test/interface/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,35 @@
@test q == re(λ)
end
end

# Checks that `AdvancedVI.update_variational_params!` keeps the variance of the
# updated distribution at or above `ϵ^2`, for every combination of covariance
# structure, element type, and bijector setting listed below.
@testset "scale positive definite projection" begin
@testset "$(string(covtype)) $(realtype) $(bijector)" for
covtype = [:meanfield, :fullrank],
realtype = [Float32, Float64],
bijector = [nothing, :identity]

d = 5
μ = zeros(realtype, d)
# Lower bound handed to the constructors as `scale_eps`.
ϵ = sqrt(realtype(0.5))
# Start from an identity scale (all diagonal entries equal to one).
q = if covtype == :fullrank
L = LowerTriangular(Matrix{realtype}(I,d,d))
FullRankGaussian(μ, L; scale_eps=ϵ)
elseif covtype == :meanfield
L = Diagonal(ones(realtype, d))
MeanFieldGaussian(μ, L; scale_eps=ϵ)
end
# NOTE(review): `q_trans` is built here but never referenced below; the update call
# dispatches on `typeof(q)`, so the `:identity` bijector case looks like it runs the
# same code path as the `nothing` case. Confirm whether `q_trans` was meant to be
# destructured and passed instead.
q_trans = if isnothing(bijector)
q
else
Bijectors.TransformedDistribution(q, identity)
end
# The "gradient" is the flattened parameters of a copy of `q`.
g = deepcopy(q)

λ, re = Optimisers.destructure(q)
grad, _ = Optimisers.destructure(g)
# Presumably `Descent(one(realtype))` is plain gradient descent with unit step size,
# so the raw step would move the scale diagonal below `ϵ` and force the projection
# to act. TODO confirm against Optimisers.
opt_st = Optimisers.setup(Descent(one(realtype)), λ)
_, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad)
q′ = re(λ′)
# Variances of the updated distribution must not fall below the squared bound.
@test all(diag(var(q′)) .≥ ϵ^2)
end
end

1 comment on commit 5ced9c2

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 5ced9c2 Previous: 75eb334 Ratio
normal + bijector/meanfield/ForwardDiff 535462845 ns 498137471 ns 1.07
normal + bijector/meanfield/ReverseDiff 192925746.5 ns 141700614.5 ns 1.36

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.