diff --git a/src/LimitedLDLFactorizations.jl b/src/LimitedLDLFactorizations.jl
index ca49270..a67aa26 100644
--- a/src/LimitedLDLFactorizations.jl
+++ b/src/LimitedLDLFactorizations.jl
@@ -769,7 +769,7 @@ function abspermute!(
   end
 end
 
-function lldl_lsolve!(n, x, Lp, Li, Lx)
+function lldl_lsolve!(n, x::AbstractVector, Lp, Li, Lx)
   @inbounds for j = 1:n
     xj = x[j]
     @inbounds for p = Lp[j]:(Lp[j + 1] - 1)
@@ -779,14 +779,14 @@ function lldl_lsolve!(n, x, Lp, Li, Lx)
   return x
 end
 
-function lldl_dsolve!(n, x, D)
+function lldl_dsolve!(n, x::AbstractVector, D)
   @inbounds for j = 1:n
     x[j] /= D[j]
   end
   return x
 end
 
-function lldl_ltsolve!(n, x, Lp, Li, Lx)
+function lldl_ltsolve!(n, x::AbstractVector, Lp, Li, Lx)
   @inbounds for j = n:-1:1
     xj = x[j]
     @inbounds for p = Lp[j]:(Lp[j + 1] - 1)
@@ -797,7 +797,7 @@ function lldl_ltsolve!(n, x, Lp, Li, Lx)
   return x
 end
 
-function lldl_solve!(n, b, Lp, Li, Lx, D, P)
+function lldl_solve!(n, b::AbstractVector, Lp, Li, Lx, D, P)
   @views y = b[P]
   lldl_lsolve!(n, y, Lp, Li, Lx)
   lldl_dsolve!(n, y, D)
@@ -805,23 +805,77 @@ function lldl_solve!(n, b, Lp, Li, Lx, D, P)
   return b
 end
 
-import Base.(\)
-function (\)(LLDL::LimitedLDLFactorization, b::AbstractVector)
-  y = copy(b)
-  factorized(LLDL) || throw(LLDLException(error_string))
-  lldl_solve!(LLDL.n, y, LLDL.colptr, LLDL.Lrowind, LLDL.Lnzvals, LLDL.D, LLDL.P)
+# Solve kernels for multiple right-hand sides: each column of X is overwritten in place.
+
+# For each stored entry (Li[p], j) of L, subtract Lx[p] * X[j, :] from row Li[p] of X.
+function lldl_lsolve!(n, X::AbstractMatrix, Lp, Li, Lx)
+  @inbounds for j = 1:n
+    @inbounds for p = Lp[j]:(Lp[j + 1] - 1)
+      for k ∈ axes(X, 2)
+        X[Li[p], k] -= Lx[p] * X[j, k]
+      end
+    end
+  end
+  return X
 end
 
+# Divide row j of X by D[j] for j = 1, …, n.
+function lldl_dsolve!(n, X::AbstractMatrix, D)
+  @inbounds for j = 1:n
+    for k ∈ axes(X, 2)
+      X[j, k] /= D[j]
+    end
+  end
+  return X
+end
+
+# Same sweep over L as lldl_lsolve!, in reverse column order and with conjugated entries:
+# for each stored entry (Li[p], j), subtract conj(Lx[p]) * X[Li[p], :] from row j of X.
+function lldl_ltsolve!(n, X::AbstractMatrix, Lp, Li, Lx)
+  @inbounds for j = n:-1:1
+    @inbounds for p = Lp[j]:(Lp[j + 1] - 1)
+      for k ∈ axes(X, 2)
+        X[j, k] -= conj(Lx[p]) * X[Li[p], k]
+      end
+    end
+  end
+  return X
+end
+
+# Apply the three sweeps to the rows of B selected by P; B is modified through the view.
+function lldl_solve!(n, B::AbstractMatrix, Lp, Li, Lx, D, P)
+  @views Y = B[P, :]
+  lldl_lsolve!(n, Y, Lp, Li, Lx)
+  lldl_dsolve!(n, Y, D)
+  lldl_ltsolve!(n, Y, Lp, Li, Lx)
+  return B
+end
+
+# `\` solves on a copy so that the right-hand side is left untouched.
+import Base.(\)
+(\)(LLDL::LimitedLDLFactorization, b::AbstractVector) = ldiv!(LLDL, copy(b))
+(\)(LLDL::LimitedLDLFactorization, B::AbstractMatrix) = ldiv!(LLDL, copy(B))
+
 import LinearAlgebra.ldiv!
 function ldiv!(LLDL::LimitedLDLFactorization, b::AbstractVector)
   factorized(LLDL) || throw(LLDLException(error_string))
   lldl_solve!(LLDL.n, b, LLDL.colptr, LLDL.Lrowind, LLDL.Lnzvals, LLDL.D, LLDL.P)
 end
+function ldiv!(LLDL::LimitedLDLFactorization, B::AbstractMatrix)
+  factorized(LLDL) || throw(LLDLException(error_string))
+  lldl_solve!(LLDL.n, B, LLDL.colptr, LLDL.Lrowind, LLDL.Lnzvals, LLDL.D, LLDL.P)
+end
 
 function ldiv!(y::AbstractVector, LLDL::LimitedLDLFactorization, b::AbstractVector)
+  # check before copying so that y is left untouched when LLDL is not factorized
   factorized(LLDL) || throw(LLDLException(error_string))
   y .= b
-  lldl_solve!(LLDL.n, y, LLDL.colptr, LLDL.Lrowind, LLDL.Lnzvals, LLDL.D, LLDL.P)
+  ldiv!(LLDL, y)
+end
+function ldiv!(Y::AbstractMatrix, LLDL::LimitedLDLFactorization, B::AbstractMatrix)
+  factorized(LLDL) || throw(LLDLException(error_string))
+  Y .= B
+  ldiv!(LLDL, Y)
 end
 
 import SparseArrays.nnz
diff --git a/test/runtests.jl b/test/runtests.jl
index 5ce7099..51f235a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -82,36 +82,45 @@ end
     0 0.01 0 0 0.53 0 0.56 0 0 3.1
   ]
   A = sparse(A)
-  B = tril(A)
+  Al = tril(A)
 
   for perm ∈ (1:(A.n), amd(A), Metis.permutation(A)[1])
-    LLDL = lldl(B, P = perm, memory = 0)
+    LLDL = lldl(Al, P = perm, memory = 0)
     nnzl0 = nnz(LLDL)
     @test nnzl0 == nnz(tril(A))
     @test LLDL.α_out == 0
 
-    LLDL = lldl(B, P = perm, memory = 5)
+    LLDL = lldl(Al, P = perm, memory = 5)
     nnzl5 = nnz(LLDL)
     @test nnzl5 ≥ nnzl0
     @test LLDL.α_out == 0
 
-    LLDL = lldl(B, P = perm, memory = 10)
+    LLDL = lldl(Al, P = perm, memory = 10)
     @test nnz(LLDL) ≥ nnzl5
     @test LLDL.α_out == 0
     L = LLDL.L + I
     @test norm(L * diagm(0 => LLDL.D) * L' - A[perm, perm]) ≤ sqrt(eps()) * norm(A)
 
     sol = ones(A.n)
+    Sol = [1 / (i + j) for i = 1:(A.n), j = 1:4] # deterministic matrix of solutions
     b = A * sol
+    B = A * Sol # test matrix rhs
     x = LLDL \ b
+    X = LLDL \ B
     @test x ≈ sol
+    @test isapprox(X, Sol, atol = sqrt(eps()))
 
     y = similar(b)
+    Y = similar(B)
     ldiv!(y, LLDL, b)
+    ldiv!(Y, LLDL, B)
     @test y ≈ sol
+    @test isapprox(Y, Sol, atol = sqrt(eps()))
 
     ldiv!(LLDL, b)
+    ldiv!(LLDL, B)
     @test b ≈ sol
+    @test isapprox(B, Sol, atol = sqrt(eps()))
   end
 end
 
@@ -184,8 +193,14 @@ end
   ldiv!(y, LLDL, b)
   @test y ≈ sol
 
+  allocs = @allocated ldiv!(y, LLDL, b)
+  @test allocs == 0
+
   ldiv!(LLDL, b)
   @test b ≈ sol
+
+  allocs = @allocated ldiv!(LLDL, b)
+  @test allocs == 0
 end
 
 @testset "with shift lower triangle" begin