Problems with Zygote and the PDMat constructor #159
Doesn't help, but I've seen these errors quite often.
A major difference is that your type is not a subtype of …
The usual approach to fixing these errors is to define an `rrule` or a projector with ChainRules (CR), as Will discussed in the linked PR.
This is a sure sign that I don't actually understand how Zygote works. Where in the call stack would it matter whether or not …
EDIT: In other words, the fact that Zygote can differentiate …
Making it a subtype of `AbstractMatrix` triggers the error:
```julia
using LinearAlgebra, Zygote

# `kernel` is defined in the original example (not shown here)
struct MyOtherPDMat{T<:Real,S<:AbstractMatrix} <: AbstractMatrix{T}
    dim::Int
    mat::S
    chol::Cholesky{T,S}
    MyOtherPDMat{T,S}(d::Int, m::AbstractMatrix{T}, c::Cholesky{T,S}) where {T,S} = new{T,S}(d, m, c)
end

# Outer constructor mirroring PDMat(mat, chol): check dimensions, then store both
function MyOtherPDMat(mat::AbstractMatrix, chol::Cholesky{T,S}) where {T,S}
    d = size(mat, 1)
    size(chol, 1) == d ||
        throw(DimensionMismatch("Dimensions of mat and chol are inconsistent."))
    MyOtherPDMat{T,S}(d, convert(S, mat), chol)
end
MyOtherPDMat(mat::AbstractMatrix) = MyOtherPDMat(mat, cholesky(mat))

# Left division goes through the stored Cholesky factorization
Base.:\(a::MyOtherPDMat, x::AbstractVecOrMat) = cholesky(a) \ x
LinearAlgebra.cholesky(a::MyOtherPDMat) = a.chol

h(x) = tr(MyOtherPDMat(kernel(0.1)) \ kernel(x))
Zygote.gradient(x -> h(only(x)), [0.2]) # ERROR
```
Which function is it that needs an `rrule`: `PDMat(mat::AbstractMatrix) = PDMat(mat, cholesky(mat))`?
If it doesn't happen for your own type that doesn't subtype …
Thanks, this sounds great! Any tips on how to efficiently find out which rrule is problematic?
I wasn't able to figure out which rrule to opt out of. However, I came up with an rrule that allows me to differentiate through the … I started with … However, Zygote does not seem to recognize my rrule and still complains about a missing adjoint:
```julia
using LinearAlgebra
using PDMats
using Zygote

x = [1. 0.2; 0.2 1.]
y = [1. 0.1; 0.1 1.]

Zygote.gradient(logdet ∘ PDMat, x) |> only # works
Zygote.gradient(det ∘ PDMat, x) |> only # ERROR
# ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}

# Since `logdet` has an overload for `PDMat` and `det` doesn't,
# and `logdet` above works while `det` doesn't, try adding an overload
# for `det` and see what happens:
LinearAlgebra.det(A::PDMat) = det(A.chol)
Zygote.gradient(det ∘ PDMat, x) |> only
# ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}
# While the overload may be useful for efficiency, it does not solve the AD issue.

# Next, try defining an rrule
using ChainRules, ChainRulesCore

# Draft rrule for the PDMat constructor.
# This is probably not completely correct, but it should be a good start.
function ChainRulesCore.rrule(::Type{PDMat}, mat)
    chol, chol_pullback = rrule(cholesky, mat)
    y = PDMat(mat, chol)
    # Dense tangent: the wrapper only re-packages `mat`, so pass the tangent through
    function PDMat_pullbackCR(m̄at::AbstractMatrix)
        @info "Using CR for PDMat, AbstractMatrix tangent"
        return NoTangent(), m̄at
    end
    # Structural tangent: pull the `chol` field back through `cholesky`.
    # A ChainRules pullback returns a tuple (function tangent, argument tangents...),
    # so take element 2, the tangent with respect to `mat`.
    function PDMat_pullbackCR(m̄at::Tangent)
        @info "Using CR for PDMat, Tangent type"
        return NoTangent(), chol_pullback(m̄at.chol)[2]
    end
    return y, PDMat_pullbackCR
end

# Perform the individual forward and backward steps manually:
a, a_pullback = rrule(PDMat, x)
b, b_pullback = rrule(det, a)
b̄ = 1.
_, ā = b_pullback(b̄)
_, x̄ = a_pullback(ā)

# Compare to the result without the `PDMat` wrapper
unthunk(x̄) ≈ Zygote.gradient(det, x) |> only # true

Zygote.gradient(det ∘ PDMat, x) |> only
# ERROR: Need an adjoint for constructor PDMat{Float64, Matrix{Float64}}. Gradient is of type Matrix{Float64}
```
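To see what the manual forward/backward check above computes, without PDMats or ChainRules, the same two-step chain can be written with plain closures. This is a minimal sketch under stated assumptions: `Wrapped`, `wrap_with_pullback` and `det_with_pullback` are hypothetical stand-ins (not PDMats or ChainRules API), and the `det` pullback uses the standard identity ∂det(X)/∂X = det(X)·X⁻ᵀ.

```julia
using LinearAlgebra

# Hypothetical stand-in for a PDMat-like wrapper (not the PDMats type)
struct Wrapped{T}
    mat::Matrix{T}
    chol::Cholesky{T,Matrix{T}}
end

# rrule-style pair for the "constructor": primal value plus a pullback closure
function wrap_with_pullback(x::Matrix)
    a = Wrapped(x, cholesky(x))
    # The wrapper only re-packages x, so a dense tangent passes straight through
    wrap_pullback(ā::AbstractMatrix) = ā
    return a, wrap_pullback
end

# rrule-style pair for det, computed from the stored Cholesky factor
function det_with_pullback(a::Wrapped)
    Ω = det(a.chol)
    # d det(X) = det(X) tr(X⁻¹ dX), so the gradient is det(X) * inv(X)'
    det_pullback(b̄) = b̄ * Ω * inv(a.chol)'
    return Ω, det_pullback
end

x = [1.0 0.2; 0.2 1.0]
a, a_pullback = wrap_with_pullback(x)   # forward step 1
b, b_pullback = det_with_pullback(a)    # forward step 2
x̄ = a_pullback(b_pullback(1.0))         # backward pass, seeded with b̄ = 1
```

With only the dense-tangent path, `x̄` equals `det(x) * inv(x)'`, which for this `x` is `[1.0 -0.2; -0.2 1.0]`.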
You often need to do a …
Just tried this code, and …
Interestingly, ForwardDiff returns a different gradient for `det ∘ PDMat` than for `det` alone:
```julia
julia> ForwardDiff.gradient(det ∘ PDMat, x)
2×2 Matrix{Float64}:
 1.0  -0.4
 0.0   1.0

julia> ForwardDiff.gradient(det, PDMat(x))
2×2 Matrix{Float64}:
  1.0  -0.2
 -0.2   1.0

julia> ForwardDiff.gradient(det, x)
2×2 Matrix{Float64}:
  1.0  -0.2
 -0.2   1.0
```
Probably one should restrict … Regardless of AD, I think it would be useful to add definitions of … I think the right approach would be to add a projection mechanism for PDMat and rrules for the constructor, similar to the ones for Hermitian and Symmetric matrices (https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2d75b4be102bb41ba3ac6df6dec8bb9617b20f0f/src/projection.jl#L425-L451 and https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/symmetric.jl#L5-L92).
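The symmetrization idea behind such a projection can be sketched with the standard library alone. Assumption: `project_symmetric` is a hypothetical helper that roughly mimics the averaging a projector for `Symmetric` applies to a non-symmetric dense tangent; it is not the ChainRulesCore API itself. Applied to the upper-triangular gradient that ForwardDiff returned for `det ∘ PDMat`, it recovers the symmetric gradient of `det`.

```julia
using LinearAlgebra

# Hypothetical helper: project a dense tangent onto symmetric matrices
# by averaging it with its transpose (a sketch, not ChainRulesCore's ProjectTo)
project_symmetric(dx::AbstractMatrix) = Symmetric((dx + transpose(dx)) / 2)

# Gradient of det ∘ PDMat reported by ForwardDiff in this thread
g_upper = [1.0 -0.4; 0.0 1.0]

# Averaging splits the off-diagonal entry evenly between (1,2) and (2,1)
g_sym = project_symmetric(g_upper)   # [1.0 -0.2; -0.2 1.0]
```

The design choice here is that any two tangents with the same symmetric part are treated as equivalent, so the projection picks the symmetric representative.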
I think this would be a great test to do for all the functions overloaded for …
I opened #161.
The ForwardDiff issue is not related to PDMats:
```julia
julia> using PDMats, ForwardDiff, LinearAlgebra

julia> x = [1. 0.2; 0.2 1.];

julia> ForwardDiff.gradient(det ∘ PDMat, x)
2×2 Matrix{Float64}:
 1.0  -0.4
 0.0   1.0

julia> ForwardDiff.gradient(det ∘ cholesky, x)
2×2 Matrix{Float64}:
 1.0  -0.4
 0.0   1.0
```
From the perspective of …
So …
Especially because using a …
```julia
julia> ForwardDiff.gradient(logdet ∘ PDMat, Symmetric(x))
ERROR: ArgumentError: Cannot set a non-diagonal index in a symmetric matrix
```
whereas it works for Zygote:
```julia
julia> Zygote.gradient(logdet ∘ PDMat, Symmetric(x)) |> only
2×2 Symmetric{Float64, Matrix{Float64}}:
  1.04167   -0.208333
 -0.208333   1.04167
```
Actually, I don't think there's anything wrong with the derivatives of ForwardDiff and …
This is a general issue with …
Yes, as I said, I do believe the differential is correct, because it is only defined for symmetric tangents. I think we can focus on getting the gradients to work in Zygote.
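The claim that the two ForwardDiff results differ only outside the symmetric subspace can be checked numerically with central finite differences. A minimal sketch under stated assumptions: `fd_gradient` and `det_upper` are hypothetical helpers, and `det_upper` uses `Symmetric(X, :U)` to mimic a Cholesky factorization that reads only the upper triangle; it is not PDMats code.

```julia
using LinearAlgebra

# Hypothetical helper: entrywise central-difference gradient of f at X
function fd_gradient(f, X::Matrix{Float64}; h = 1e-6)
    G = zeros(size(X))
    for i in eachindex(X)
        E = zeros(size(X))
        E[i] = h
        G[i] = (f(X + E) - f(X - E)) / (2h)
    end
    return G
end

# det computed through a Cholesky factorization that only reads the upper triangle
det_upper(X) = det(cholesky(Symmetric(X, :U)))

x = [1.0 0.2; 0.2 1.0]
G_upper = fd_gradient(det_upper, x)   # ≈ [1.0 -0.4; 0.0 1.0]
G_full  = fd_gradient(det, x)         # ≈ [1.0 -0.2; -0.2 1.0]

# The two gradients differ entrywise, but agree on every symmetric direction V
V = [0.3 0.7; 0.7 -0.5]
dot(G_upper, V) ≈ dot(G_full, V)      # true
```

`G_upper` matches the `ForwardDiff.gradient(det ∘ cholesky, x)` output in the thread and `G_full` the one for `det` alone; perturbing only `x[2,1]` leaves `det_upper` unchanged, which is where the zero entry comes from.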
The only things to make it work with Zygote are (copied from above):
I just checked it locally, with these changes also …
An error is thrown when differentiating the trace of a matrix division with a PDMat:
This seems like a strange error. I tried to reproduce it with my own type, but couldn't:
BTW, all of these functions can be differentiated with ForwardDiff.