Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

det, logdet, and logabsdet rrules for SparseMatrixCSC #730

Merged
merged 8 commits into from
Aug 16, 2023
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
Adapt = "3.4.0"
Expand Down
1 change: 1 addition & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Structured matrices
using LinearAlgebra: AbstractTriangular
using SparseInverseSubset

# Matrix wrapper types that we know are square and are thus potentially invertible. For
# these we can use simpler definitions for `/` and `\`.
Expand Down
88 changes: 88 additions & 0 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,91 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)

return (I, V), findnz_pullback
end

if VERSION < v"1.7"
#=
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
for the rrules for the determinants of sparse matrices below to work, they need to be
able to compute the primals as well, so this import from the future is included. For
more recent versions of Julia, this definition lives in:
julia/stdlib/SuiteSparse/src/umfpack.jl
=#
using SuiteSparse.UMFPACK: UmfpackLU

# compute the sign/parity of a permutation
function _signperm(p)
n = length(p)
result = 0
todo = trues(n)
while any(todo)
k = findfirst(todo)
todo[k] = false
result += 1 # increment element count
j = p[k]
while j != k
result += 1 # increment element count
todo[j] = false
j = p[j]
end
result += 1 # increment cycle count
end
return ifelse(isodd(result), -1, 1)
end

function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
n = checksquare(F)
issuccess(F) || return log(zero(real(T))), zero(T)
U = F.U
Rs = F.Rs
p = F.p
q = F.q
s = _signperm(p)*_signperm(q)*one(real(T))
P = one(T)
abs_det = zero(real(T))
@inbounds for i in 1:n
dg_ii = U[i, i] / Rs[i]
P *= sign(dg_ii)
abs_det += log(abs(dg_ii))
end
return abs_det, s * P
end
end


function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
F = cholesky(x)
L, D, U, P = SparseInverseSubset.get_ldup(F)
Ω = logabsdet(D)
function logabsdet_pullback(ΔΩ)
(Δy, Δsigny) = ΔΩ
(_, signy) = Ω
f = signy' * Δsigny
imagf = f - real(f)
g = real(Δy) + imagf
Z, P = sparseinv(F, depermute=true)
∂x = g * Z'
return (NoTangent(), ∂x)
end
return Ω, logabsdet_pullback
end

function rrule(::typeof(logdet), x::SparseMatrixCSC)
Ω = logdet(x)
function logdet_pullback(ΔΩ)
Z, p = sparseinv(x, depermute=true)
∂x = ΔΩ * Z'
return (NoTangent(), ∂x)
end
return Ω, logdet_pullback
end

function rrule(::typeof(det), x::SparseMatrixCSC)
Ω = det(x)
function det_pullback(ΔΩ)
Z, _ = sparseinv(x, depermute=true)
∂x = Z' * dot(Ω, ΔΩ)
return (NoTangent(), ∂x)
end
return Ω, det_pullback
end
10 changes: 10 additions & 0 deletions test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,13 @@ end
V̄ = rand!(similar(V))
test_rrule(findnz, v ⊢ dv, output_tangent=(zeros(length(I)), V̄))
end

@testset "[log[abs[det]]] SparseMatrixCSC" begin
ii = [1:5; 2; 4]
jj = [1:5; 4; 2]
x = [ones(5); 0.1; 0.1]
A = sparse(ii, jj, x)
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
end