add mean field
ngiann committed Apr 27, 2023
1 parent 9188e4a commit e586c5b
Showing 4 changed files with 250 additions and 151 deletions.
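In brief: a new src/VIdiag.jl implements the mean-field (diagonal-covariance) core coreVIdiag, src/coreVIrank1.jl is renamed to src/VIrank1.jl with its covariance-root handling refactored, and VIdiag joins the module's exports.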
6 changes: 4 additions & 2 deletions src/GaussianVariationalInference.jl
@@ -15,7 +15,8 @@ module GaussianVariationalInference

include("interface.jl")
include("VIfull.jl")
include("coreVIrank1.jl")
include("VIdiag.jl")
include("VIrank1.jl")
include("entropy.jl")

# include("VIdiag.jl")
@@ -27,6 +28,7 @@ module GaussianVariationalInference
# Utilities

# include("util/report.jl")
include("util/pickoptimiser.jl")
include("util/generatelatentZ.jl")
include("util/defaultgradient.jl")
include("util/verifygradient.jl")
@@ -42,7 +44,7 @@ module GaussianVariationalInference



export VI, VIrank1 #, VIdiag, VIfixedcov, MVI, laplace
export VI, VIdiag, VIrank1 #, VIdiag, VIfixedcov, MVI, laplace

export exampleproblem1

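With VIdiag exported, the mean-field variant can be called alongside VI. A minimal usage sketch, assuming the VIdiag wrapper in interface.jl mirrors the VI entry point (the keyword names follow the coreVIdiag signature below; the toy target, starting point, and exact wrapper signature are assumptions):

using GaussianVariationalInference

# Toy log-target: a standard bivariate normal, for which the optimal
# mean-field approximation is exact.
logp(x) = -sum(abs2, x) / 2

# Hypothetical call; the actual wrapper signature lives in interface.jl
# and may differ from this sketch.
q, elbovalue, Cdiag = VIdiag(logp, zeros(2);
                             S = 100,           # Monte Carlo samples per iteration
                             iterations = 1000, # Optim iterations
                             show_every = 100)  # progress reporting frequency

Here q would be the MvNormal with diagonal covariance that coreVIdiag constructs on return.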
161 changes: 161 additions & 0 deletions src/VIdiag.jl
@@ -0,0 +1,161 @@
function coreVIdiag(logp::Function, μ₀::AbstractArray{T, 1}, Σ₀diag::AbstractArray{T, 1}; gradlogp = gradlogp, seed = seed, S = S, test_every = test_every, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every) where T

D = length(μ₀)

#----------------------------------------------------
# generate latent variables
#----------------------------------------------------

Ztrain = generatelatentZ(S = S, D = D, seed = seed)


#----------------------------------------------------
# Auxiliary function for handling parameters
#----------------------------------------------------

function unpack(param)

@assert(length(param) == D+D)

local μ = param[1:D]

local Cdiag = reshape(param[D+1:D+D], D)

return μ, Cdiag

end


#----------------------------------------------------
# Objective and gradient functions for Optim.optimize
#----------------------------------------------------

function minauxiliary(param)

local μ, Cdiag = unpack(param)

local ℓ = elbo(μ, Cdiag, Ztrain)

update!(trackELBO; newelbo = ℓ, μ = μ, C = Cdiag)

return -1.0 * ℓ # Optim.optimise is minimising

end


function minauxiliary_grad(param)

local μ, Cdiag = unpack(param)

return -1.0 * elbo_grad(μ, Cdiag, Ztrain) # Optim.optimise is minimising

end


#----------------------------------------------------
# Functions for covariance and covariance root
#----------------------------------------------------


function getcov(Cdiag)

Diagonal(Cdiag.^2)

end


function getcovroot(Cdiag)

return Cdiag

end


#----------------------------------------------------
# Approximate evidence lower bound and its gradient
#----------------------------------------------------

function elbo(μ, Cdiag, Z)

local aux = z -> logp(μ .+ Cdiag.*z)

Transducers.foldxt(+, Map(aux), Z) / length(Z) + GaussianVariationalInference.entropy(Cdiag)

end


function partial_elbo_grad(μ, Cdiag, z)

local g = gradlogp(μ .+ Cdiag.*z)

[g; vec(g.*z)]

end


function elbo_grad(μ, Cdiag, Z)

local aux = z -> partial_elbo_grad(μ, Cdiag, z)

local gradμCdiag = Transducers.foldxt(+, Map(aux), Z) / length(Z)

# entropy contribution to covariance

gradμCdiag[D+1:end] .+= vec(1.0 ./ Cdiag)

return gradμCdiag

end


# Package Optim requires that the gradient function has the following signature

gradhelper(storage, param) = copyto!(storage, minauxiliary_grad(param))


#----------------------------------------------------
# Numerically verify gradient
#----------------------------------------------------

numerical_verification ? verifygradient(μ₀, Σ₀diag, elbo, minauxiliary_grad, unpack, Ztrain) : nothing


#----------------------------------------------------
# Define callback function called at each iteration
#----------------------------------------------------

# We want to keep track of the best variational
# parameters encountered during the optimisation of
# the elbo. Unfortunately, the otherwise superb
# package Optim.jl does not provide a consistent way
# across different optimisers to do this.


trackELBO = RecordELBOProgress(; μ = zeros(D), C = zeros(D),
Stest = Stest,
show_every = show_every,
test_every = test_every,
elbo = elbo, seed = seed)
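
# For illustration, the idea is simply to retain the best parameters
# seen so far, along these lines (a generic sketch, not the actual
# RecordELBOProgress implementation, which lives under src/util/):
#
#   mutable struct BestSoFar; elbo::Float64; μ; C; end
#
#   function update!(b::BestSoFar; newelbo, μ, C)
#       if newelbo > b.elbo
#           b.elbo, b.μ, b.C = newelbo, copy(μ), copy(C)
#       end
#   end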



#----------------------------------------------------
# Call optimiser to minimise *negative* elbo
#----------------------------------------------------

options = Optim.Options(extended_trace = false, store_trace = false, show_every = 1, show_trace = false, iterations = iterations, g_tol = 1e-6, callback = trackELBO)

result = Optim.optimize(minauxiliary, gradhelper, [μ₀; vec(sqrt.(Σ₀diag))], optimiser, options)

μopt, Copt = unpack(result.minimizer)


#----------------------------------------------------
# Return results
#----------------------------------------------------

Σopt = getcov(Copt)

return MvNormal(μopt, Σopt), elbo(μopt, Copt, Ztrain), Copt

end
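For reference, the estimator that coreVIdiag implements can be read directly off elbo and elbo_grad above: with C the diagonal covariance root, z_s the fixed standard-normal samples in Ztrain, and g_s = ∇ log p(μ + C ⊙ z_s),

\[
\mathcal{L}(\mu, C) \approx \frac{1}{S}\sum_{s=1}^{S} \log p(\mu + C \odot z_s) + \sum_{d=1}^{D} \log C_d + \mathrm{const},
\]

\[
\nabla_{\mu}\mathcal{L} \approx \frac{1}{S}\sum_{s=1}^{S} g_s, \qquad
\nabla_{C}\mathcal{L} \approx \frac{1}{S}\sum_{s=1}^{S} g_s \odot z_s + \frac{1}{C},
\]

with 1/C taken elementwise, matching the entropy correction gradμCdiag[D+1:end] .+= vec(1.0 ./ Cdiag). (The additive constant assumes the usual Gaussian entropy; the exact form of GaussianVariationalInference.entropy is not shown in this diff.)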
57 changes: 20 additions & 37 deletions src/coreVIrank1.jl → src/VIrank1.jl
@@ -1,6 +1,9 @@
function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractArray{T, 2}; gradlogp = gradlogp, seed = seed, S = S, test_every = test_every, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, transform = transform, seedtest = seedtest) where T
function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C::AbstractArray{T, 2}; gradlogp = gradlogp, seed = seed, S = S, test_every = test_every, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, transform = transform, seedtest = seedtest) where T

D = length(μ₀)

rg = MersenneTwister(seed)

D = length(μ₀); @assert(D == size(C₀, 1) == size(C₀, 2))

#----------------------------------------------------
# generate latent variables
@@ -13,7 +16,7 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
# Define jacobian of transformation via AD
#----------------------------------------------------

jac_transform = transform == identity ? Matrix(I, D, D) : x -> ForwardDiff.jacobian(transform, x)
# jac_transform = transform == identity ? Matrix(I, D, D) : x -> ForwardDiff.jacobian(transform, x)


#----------------------------------------------------
@@ -41,11 +44,9 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA

local μ, u, v = unpack(param)

local C = getcovroot(C₀, u, v)
local ℓ = elbo(μ, u, v, Ztrain)

local ℓ = elbo(μ, C, Ztrain)

update!(trackELBO; newelbo = ℓ, μ = μ, C = C)
update!(trackELBO; newelbo = ℓ, μ = μ, C = getcovroot(u, v))

return -1.0 * ℓ # Optim.optimise is minimising

@@ -65,9 +66,9 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
# Functions for covariance and covariance root
#----------------------------------------------------

function getcov(C₀, u, v)
function getcov(u, v)

local aux = getcovroot(C₀, u, v)
local aux = getcovroot(u, v)

local Σ = aux*aux'

@@ -76,9 +77,9 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
end


function getcovroot(C₀, u, v)
function getcovroot(u, v)

C₀ + u*v'
C + u*v'

end

@@ -87,9 +88,10 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
# Approximate evidence lower bound and its gradient
#----------------------------------------------------

function elbo(μ, C, Z)
function elbo(μ, u, v, Z)

local C = getcovroot(u, v)


local ℋ = GaussianVariationalInference.entropy(C)

# if transform !== identity
@@ -126,7 +128,7 @@

function elbo_grad(μ, u, v, Z)

local C = getcovroot(C₀, u, v)
local C = getcovroot(u, v)

local aux = z -> partial_elbo_grad(μ, C, u, v, z)

@@ -152,26 +154,9 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
# Numerically verify gradient
#----------------------------------------------------

# COMMENT BACK IN AFTER VERIFICATION
#numerical_verification ? verifygradient(μ₀, Σ₀, elbo, minauxiliary_grad, unpack, Ztrain) : nothing

# DELETE AFTER VERIFICATION
# let

# local u,v = randn(D), randn(D)
numerical_verification ? verifygradient(μ₀, 1e-2*randn(rg, D), 1e-2*randn(rg, D), elbo, minauxiliary_grad, unpack, Ztrain) : nothing

# local angrad = minauxiliary_grad([μ₀;vec(u);vec(v)])

# adgrad = ForwardDiff.gradient(minauxiliary, [μ₀; vec(u);vec(v)])

# discrepancy = maximum(abs.(vec(adgrad) - vec(angrad)))

# display([angrad adgrad])

# @printf("Maximum absolute difference between AD and analytical gradient is %f\n", discrepancy)

# end

#----------------------------------------------------
# Define callback function called at each iteration
#----------------------------------------------------
@@ -196,7 +181,7 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA

options = Optim.Options(extended_trace = false, store_trace = false, show_trace = false, iterations = iterations, g_tol = 1e-6, callback = trackELBO)

result = Optim.optimize(minauxiliary, gradhelper, [μ₀; 1e-2*randn(2D)], optimiser, options)
result = Optim.optimize(minauxiliary, gradhelper, [μ₀; 1e-2*randn(rg, 2D)], optimiser, options)

μopt, uopt, vopt = unpack(result.minimizer)

@@ -205,10 +190,8 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
# Return results
#----------------------------------------------------

Copt = getcovroot(C₀, uopt, vopt)

# Σopt = getcov(C₀, uopt, vopt)
Copt = getcovroot(uopt, vopt)

return μopt, Copt, elbo(μopt, Copt, Ztrain)
return MvNormal(μopt, getcov(uopt, vopt)), elbo(μopt, uopt, vopt, Ztrain), Copt

end
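The rank-1 refactor changes plumbing rather than maths: getcovroot now closes over the fixed input root C (the argument formerly named C₀) instead of receiving it at every call site, so the approximating covariance is still

\[
\Sigma(u, v) = \left(C + u v^{\top}\right)\left(C + u v^{\top}\right)^{\top},
\]

as computed by getcov, with only μ, u and v optimised. The function also now returns an MvNormal built from getcov(uopt, vopt) instead of the former raw (μ, C, elbo) tuple.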
