MvNormal constructor unnecessarily recomputes cholesky every time with ForwardDiff #1781

Closed
marius311 opened this issue Sep 19, 2023 · 1 comment · Fixed by JuliaStats/PDMats.jl#179

Comments


marius311 commented Sep 19, 2023

In a typical Turing model, even if you pre-convert an MvNormal covariance to a PDMat, its Cholesky factorization is still recomputed on every gradient call when using ForwardDiff:

using Turing, Distributions, ForwardDiff, LinearAlgebra
Turing.setadbackend(:forwarddiff)

# pirate this function just to see the call every time
@eval LinearAlgebra function cholesky(A::AbstractMatrix, ::NoPivot=NoPivot(); check::Bool = true)
    println("here") 
    cholesky!(cholcopy(A); check)
end

@model function foo(Σ, d)
    μ ~ filldist(Uniform(), 2)
    d ~ MvNormal(μ, Σ)
end

Σ = Distributions.PDMat([1 0; 0 2])
d = rand(2)
model = foo(Σ, d)

ForwardDiff.gradient(μ -> logjoint(model, (;μ)), rand(2)) # "here" printed every time

Imo the problem is that the MvNormal constructor is a little too greedy about promoting its arguments when the eltypes don't match, and in the process it recomputes the Cholesky factorization (in this case the mean is Duals but the covariance is Floats).

A workaround is for the user to call the MvNormal{T,Cov,Mean}(...) constructor directly, but this seems like an easy and potentially big performance footgun, even for users who were smart enough to convert to a PDMat manually. It would be nice to fix this in some package (my sense is here, but maybe elsewhere).
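For reference, a minimal sketch of that workaround outside of Turing, differentiating `logpdf` directly. The helper name `unpromoted_mvnormal` and the choice of `eltype(μ)` for the `T` parameter are assumptions made for illustration, not something Distributions provides or prescribes:

```julia
using Distributions, ForwardDiff, PDMats

# Factorize the covariance once, up front.
Σ = PDMat([1.0 0.0; 0.0 2.0])
x = [0.3, 0.7]

# Hypothetical helper: call the fully parameterized constructor so that
# μ (Duals under ForwardDiff) and Σ (Float64) are stored as given,
# without promoting Σ to the eltype of μ.
unpromoted_mvnormal(μ, Σ) = MvNormal{eltype(μ), typeof(Σ), typeof(μ)}(μ, Σ)

# Gradient of the log-density with respect to the mean.
g = ForwardDiff.gradient(μ -> logpdf(unpromoted_mvnormal(μ, Σ), x), [0.1, 0.2])
```

Analytically the gradient with respect to the mean is Σ⁻¹(x − μ), which gives a quick sanity check that skipping the promotion did not change the result.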

devmotion (Member) commented

Of course, allocations would be minimized if MvNormal just takes whatever types the user puts into the struct. But I assume then it becomes much more likely to run into type instability issues and you have to be much more careful when implementing methods operating with MvNormal.

Regardless, the unnecessary Cholesky decompositions are not caused by inefficiencies or bugs in Distributions, but rather by inefficient and missing convert definitions in PDMats. JuliaStats/PDMats.jl#179 fixes your example.
