Skip to content

Commit

Permalink
fix #80 (multiple rhs)
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffroyleconte committed Oct 18, 2023
1 parent a171349 commit 6112cee
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 15 deletions.
66 changes: 55 additions & 11 deletions src/LimitedLDLFactorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ function abspermute!(
end
end

function lldl_lsolve!(n, x, Lp, Li, Lx)
function lldl_lsolve!(n, x::AbstractVector, Lp, Li, Lx)
@inbounds for j = 1:n
xj = x[j]
@inbounds for p = Lp[j]:(Lp[j + 1] - 1)
Expand All @@ -779,14 +779,14 @@ function lldl_lsolve!(n, x, Lp, Li, Lx)
return x
end

function lldl_dsolve!(n, x, D)
function lldl_dsolve!(n, x::AbstractVector, D)
@inbounds for j = 1:n
x[j] /= D[j]
end
return x
end

function lldl_ltsolve!(n, x, Lp, Li, Lx)
function lldl_ltsolve!(n, x::AbstractVector, Lp, Li, Lx)
@inbounds for j = n:-1:1
xj = x[j]
@inbounds for p = Lp[j]:(Lp[j + 1] - 1)
Expand All @@ -797,31 +797,75 @@ function lldl_ltsolve!(n, x, Lp, Li, Lx)
return x
end

function lldl_solve!(n, b, Lp, Li, Lx, D, P)
function lldl_solve!(n, b::AbstractVector, Lp, Li, Lx, D, P)
@views y = b[P]
lldl_lsolve!(n, y, Lp, Li, Lx)
lldl_dsolve!(n, y, D)
lldl_ltsolve!(n, y, Lp, Li, Lx)
return b
end

import Base.(\)
function (\)(LLDL::LimitedLDLFactorization, b::AbstractVector)
y = copy(b)
factorized(LLDL) || throw(LLDLException(error_string))
lldl_solve!(LLDL.n, y, LLDL.colptr, LLDL.Lrowind, LLDL.Lnzvals, LLDL.D, LLDL.P)
# solve functions for multiple rhs
function lldl_lsolve!(n, X::AbstractMatrix, Lp, Li, Lx)
@inbounds for j = 1:n
@inbounds for p = Lp[j]:(Lp[j + 1] - 1)
for k axes(X, 2)
X[Li[p], k] -= Lx[p] * X[j, k]
end
end
end
return X
end

function lldl_dsolve!(n, X::AbstractMatrix, D)
@inbounds for j = 1:n
for k axes(X, 2)
X[j, k] /= D[j]
end
end
return X
end

function lldl_ltsolve!(n, X::AbstractMatrix, Lp, Li, Lx)
@inbounds for j = n:-1:1
@inbounds for p = Lp[j]:(Lp[j + 1] - 1)
for k axes(X, 2)
X[j, k] -= conj(Lx[p]) * X[Li[p], k]
end
end
end
return X
end

function lldl_solve!(n, B::AbstractMatrix, Lp, Li, Lx, D, P)
@views Y = B[P, :]
lldl_lsolve!(n, Y, Lp, Li, Lx)
lldl_dsolve!(n, Y, D)
lldl_ltsolve!(n, Y, Lp, Li, Lx)
return B
end

import Base.(\)
(\)(LLDL::LimitedLDLFactorization, b::AbstractVector) = ldiv!(LLDL, copy(b))
(\)(LLDL::LimitedLDLFactorization, B::AbstractVecOrMat) = ldiv!(LLDL, copy(B))

import LinearAlgebra.ldiv!
function ldiv!(LLDL::LimitedLDLFactorization, b::AbstractVector)
factorized(LLDL) || throw(LLDLException(error_string))
lldl_solve!(LLDL.n, b, LLDL.colptr, LLDL.Lrowind, LLDL.Lnzvals, LLDL.D, LLDL.P)
end
function ldiv!(LLDL::LimitedLDLFactorization, B::AbstractMatrix)
factorized(LLDL) || throw(LLDLException(error_string))
lldl_solve!(LLDL.n, B, LLDL.colptr, LLDL.Lrowind, LLDL.Lnzvals, LLDL.D, LLDL.P)
end

function ldiv!(y::AbstractVector, LLDL::LimitedLDLFactorization, b::AbstractVector)
factorized(LLDL) || throw(LLDLException(error_string))
y .= b
lldl_solve!(LLDL.n, y, LLDL.colptr, LLDL.Lrowind, LLDL.Lnzvals, LLDL.D, LLDL.P)
ldiv!(LLDL, y)
end
function ldiv!(Y::AbstractMatrix, LLDL::LimitedLDLFactorization, B::AbstractMatrix)
Y .= B
ldiv!(LLDL, Y)
end

import SparseArrays.nnz
Expand Down
23 changes: 19 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,36 +82,45 @@ end
0 0.01 0 0 0.53 0 0.56 0 0 3.1
]
A = sparse(A)
B = tril(A)
Al = tril(A)

for perm (1:(A.n), amd(A), Metis.permutation(A)[1])
LLDL = lldl(B, P = perm, memory = 0)
LLDL = lldl(Al, P = perm, memory = 0)
nnzl0 = nnz(LLDL)
@test nnzl0 == nnz(tril(A))
@test LLDL.α_out == 0

LLDL = lldl(B, P = perm, memory = 5)
LLDL = lldl(Al, P = perm, memory = 5)
nnzl5 = nnz(LLDL)
@test nnzl5 nnzl0
@test LLDL.α_out == 0

LLDL = lldl(B, P = perm, memory = 10)
LLDL = lldl(Al, P = perm, memory = 10)
@test nnz(LLDL) nnzl5
@test LLDL.α_out == 0
L = LLDL.L + I
@test norm(L * diagm(0 => LLDL.D) * L' - A[perm, perm]) sqrt(eps()) * norm(A)

sol = ones(A.n)
Sol = rand(A.n, 4)
b = A * sol
B = A * Sol # test matrix rhs
x = LLDL \ b
X = LLDL \ B
@test x sol
@test isapprox(X, Sol, atol = sqrt(eps()))

y = similar(b)
Y = similar(B)
ldiv!(y, LLDL, b)
ldiv!(Y, LLDL, B)
@test y sol
@test isapprox(Y, Sol, atol = sqrt(eps()))

ldiv!(LLDL, b)
ldiv!(LLDL, B)
@test b sol
@test isapprox(B, Sol, atol = sqrt(eps()))
end
end

Expand Down Expand Up @@ -184,8 +193,14 @@ end
ldiv!(y, LLDL, b)
@test y sol

allocs = @allocated ldiv!(y, LLDL, b)
@test allocs == 0

ldiv!(LLDL, b)
@test b sol

allocs = @allocated ldiv!(LLDL, b)
@test allocs == 0
end

@testset "with shift lower triangle" begin
Expand Down

0 comments on commit 6112cee

Please sign in to comment.