Skip to content

Commit

Permalink
verify gradient for VIfull, VIdiag and VIrank1
Browse files Browse the repository at this point in the history
  • Loading branch information
ngiann committed Apr 25, 2023
1 parent 136da74 commit 9188e4a
Showing 1 changed file with 41 additions and 7 deletions.
48 changes: 41 additions & 7 deletions src/util/verifygradient.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,54 @@
function verifygradient(μ₀, Σ₀, elbo, minauxiliary_grad, unpack, Ztrain)
# Checks gradient for VI with full covariance matrix
function verifygradient(μ, Σ::Matrix, elbo, minauxiliary_grad, unpack, Z)

C = Matrix(cholesky(Σ₀).L)
C = vec(Matrix(cholesky(Σ).L))

angrad = minauxiliary_grad([μ; C])

local angrad = minauxiliary_grad([μ₀;vec(C)])
adgrad = ForwardDiff.gradient(p -> -elbo(unpack(p)..., Z), [μ; C])

reportdiscrepancy(angrad, adgrad)

end


# Checks gradient for VI with diagonal covariance matrix
function verifygradient(μ, Σdiag::Vector, elbo, minauxiliary_grad, unpack, Z)

C = sqrt.(Σdiag)

angrad = minauxiliary_grad([μ; C])

adgrad = ForwardDiff.gradient(p -> -elbo(unpack(p)..., Z), [μ; C])

reportdiscrepancy(angrad, adgrad)

end


# Checks gradient for VI with rank 1 parametrised covariance matrix
function verifygradient(μ, u::Vector, v::Vector, elbo, minauxiliary_grad, unpack, Z)

angrad = minauxiliary_grad([μ; u; v])

adgrad = ForwardDiff.gradient(p -> -elbo(unpack(p)..., Ztrain), [μ₀; vec(C)])
adgrad = ForwardDiff.gradient(p -> -elbo(unpack(p)..., Z), [μ; u; v])

discrepancy = maximum(abs.(vec(adgrad) - vec(angrad)))
reportdiscrepancy(angrad, adgrad)

msg = @sprintf("Maximum absolute difference between AD and analytical gradient is %f\n", discrepancy)
end


function reportdiscrepancy(angrad, adgrad)

discrepancy = maximum(abs.(vec(adgrad) - vec(angrad)))

msg = @sprintf("Maximum absolute difference between AD and analytical gradient is %.8f\n", discrepancy)

clr = discrepancy > 1e-5 ? :red : :cyan

print(Crayon(foreground = clr, bold=true), msg)

print(Crayon(foreground = :white, bold=false), "")
print(Crayon(foreground = :white, bold=false), "", Crayon(reset = true))

nothing

Expand Down

0 comments on commit 9188e4a

Please sign in to comment.