Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make RecursiveFactorization.jl optional #569

Merged
27 commits merged into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2c0d8ab
Make RecursiveFactorization.jl optional
ChrisRackauckas Feb 1, 2025
75d2813
get it working
ChrisRackauckas Feb 1, 2025
b62f277
Update ext/LinearSolveRecursiveFactorization.jl
ChrisRackauckas Feb 1, 2025
d81e8a2
Update src/default.jl
ChrisRackauckas Feb 1, 2025
aa865ba
Update src/factorization.jl
ChrisRackauckas Feb 1, 2025
b8cd21b
Update src/extension_algs.jl
ChrisRackauckas Feb 1, 2025
1c2eae4
Update src/extension_algs.jl
ChrisRackauckas Feb 1, 2025
6606dd4
Update src/extension_algs.jl
ChrisRackauckas Feb 1, 2025
45d94fd
Update src/default.jl
ChrisRackauckas Feb 1, 2025
cd3a29e
add RecursiveFactorization in tests
ChrisRackauckas Feb 1, 2025
80e6614
Update and rename LinearSolveRecursiveFactorization.jl to LinearSolve…
ChrisRackauckas Feb 5, 2025
c280c46
Update LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
50e8cbe
Update LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
4e358f6
Update LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
0278ecd
Update LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
e18d864
namespace PreallocatedLU
ChrisRackauckas Feb 5, 2025
a689fd6
one more
ChrisRackauckas Feb 5, 2025
8faa4e6
one more
ChrisRackauckas Feb 5, 2025
7d1f54a
namespace
ChrisRackauckas Feb 5, 2025
c93a6b2
namespace
ChrisRackauckas Feb 5, 2025
41a786b
fix default
ChrisRackauckas Feb 5, 2025
7569440
fix inference on recfact load
ChrisRackauckas Feb 5, 2025
8704062
don't double
ChrisRackauckas Feb 5, 2025
3ad68c7
Update src/extension_algs.jl
ChrisRackauckas Feb 5, 2025
4d3a346
Update src/extension_algs.jl
ChrisRackauckas Feb 5, 2025
486d924
Update src/default.jl
ChrisRackauckas Feb 5, 2025
1fad945
Update ext/LinearSolveRecursiveFactorizationExt.jl
ChrisRackauckas Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Expand All @@ -45,6 +44,7 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"

[extensions]
LinearSolveBandedMatricesExt = "BandedMatrices"
Expand All @@ -60,6 +60,7 @@ LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
LinearSolveRecursiveFactorizationExt = "RecursiveFactorization"

[compat]
AllocCheck = "0.2"
Expand Down Expand Up @@ -140,11 +141,12 @@ MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization"]
31 changes: 31 additions & 0 deletions ext/LinearSolveRecursiveFactorizationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module LinearSolveRecursiveFactorizationExt

using LinearSolve
using LinearSolve.LinearAlgebra, LinearSolve.ArrayInterface, RecursiveFactorization

LinearSolve.userecursivefactorization(A::Union{Nothing, AbstractMatrix}) = true

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization{P, T};
kwargs...) where {P, T}
A = cache.A
A = convert(AbstractMatrix, A)
fact, ipiv = LinearSolve.@get_cacheval(cache, :RFLUFactorization)
if cache.isfresh
if length(ipiv) != min(size(A)...)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
end
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
cache.cacheval = (fact, ipiv)

if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
end

cache.isfresh = false
end
y = ldiv!(cache.u, LinearSolve.@get_cacheval(cache, :RFLUFactorization)[1], cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

end
4 changes: 1 addition & 3 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ end

import PrecompileTools
using ArrayInterface
using RecursiveFactorization
using Base: cache_dependencies, Bool
using LinearAlgebra
using SparseArrays
Expand Down Expand Up @@ -127,6 +126,7 @@ end
const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64}

include("common.jl")
include("extension_algs.jl")
include("factorization.jl")
include("appleaccelerate.jl")
include("mkl.jl")
Expand All @@ -137,7 +137,6 @@ include("preconditioners.jl")
include("solve_function.jl")
include("default.jl")
include("init.jl")
include("extension_algs.jl")
include("adjoint.jl")
include("deprecated.jl")

Expand Down Expand Up @@ -212,7 +211,6 @@ PrecompileTools.@compile_workload begin
prob = LinearProblem(A, b)
sol = solve(prob)
sol = solve(prob, LUFactorization())
sol = solve(prob, RFLUFactorization())
sol = solve(prob, KrylovJL_GMRES())
end

Expand Down
9 changes: 6 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
end
end

userecursivefactorization(A) = false

# Allows A === nothing as a stand-in for dense matrix
function defaultalg(A, b, assump::OperatorAssumptions{Bool})
alg = if assump.issq
Expand All @@ -178,14 +180,15 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
(__conditioning(assump) === OperatorCondition.IllConditioned ||
__conditioning(assump) === OperatorCondition.WellConditioned)
if length(b) <= 10
DefaultAlgorithmChoice.RFLUFactorization
DefaultAlgorithmChoice.GenericLUFactorization
elseif appleaccelerate_isavailable() && b isa Array &&
eltype(b) <: Union{Float32, Float64, ComplexF32, ComplexF64}
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
(usemkl && length(b) <= 200)) &&
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
eltype(A) <: Union{Float32, Float64}) &&
userecursivefactorization(A)
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
Expand Down Expand Up @@ -265,7 +268,7 @@ function algchoice_to_alg(alg::Symbol)
elseif alg === :GenericLUFactorization
GenericLUFactorization()
elseif alg === :RFLUFactorization
RFLUFactorization()
RFLUFactorization(throwerror = false)
elseif alg === :BunchKaufmanFactorization
BunchKaufmanFactorization()
elseif alg === :CHOLMODFactorization
Expand Down
25 changes: 25 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,31 @@ struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization
end
end

## RFLUFactorization

"""
`RFLUFactorization()`

A fast pure Julia LU-factorization implementation
using RecursiveFactorization.jl. This is by far the fastest LU-factorization
implementation, usually outperforming OpenBLAS and MKL for smaller matrices
(<500x500), but currently optimized only for Base `Array` with `Float32` or `Float64`.
Additional optimization for complex matrices is in the works.
"""
struct RFLUFactorization{P, T} <: AbstractDenseFactorization
function RFLUFactorization(::Val{P}, ::Val{T}; throwerror = true) where {P, T}
if !userecursivefactorization(nothing)
throwerror &&
error("RFLUFactorization requires that RecursiveFactorization.jl is loaded, i.e. `using RecursiveFactorization`")
end
return new{P, T}()
end
end

function RFLUFactorization(; pivot = Val(true), thread = Val(true), throwerror = true)
RFLUFactorization(pivot, thread; throwerror)
end

"""
```julia
MKLPardisoFactorize(; nprocs::Union{Int, Nothing} = nothing,
Expand Down
99 changes: 30 additions & 69 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,36 @@ function init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters::In
do_factorization(alg, convert(AbstractMatrix, A), b, u)
end

## RFLU Factorization

function LinearSolve.init_cacheval(alg::RFLUFactorization, A, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), ipiv
end

function LinearSolve.init_cacheval(
alg::RFLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
PREALLOCATED_LU, ipiv
end

function LinearSolve.init_cacheval(alg::RFLUFactorization,
A::Union{AbstractSparseArray, AbstractSciMLOperator}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing, nothing
end

function LinearSolve.init_cacheval(alg::RFLUFactorization,
A::Union{Diagonal, SymTridiagonal, Tridiagonal}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing, nothing
end

## LU Factorizations

"""
Expand Down Expand Up @@ -989,75 +1019,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs.
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

## RFLUFactorization

"""
`RFLUFactorization()`

A fast pure Julia LU-factorization implementation
using RecursiveFactorization.jl. This is by far the fastest LU-factorization
implementation, usually outperforming OpenBLAS and MKL for smaller matrices
(<500x500), but currently optimized only for Base `Array` with `Float32` or `Float64`.
Additional optimization for complex matrices is in the works.
"""
struct RFLUFactorization{P, T} <: AbstractDenseFactorization
RFLUFactorization(::Val{P}, ::Val{T}) where {P, T} = new{P, T}()
end

function RFLUFactorization(; pivot = Val(true), thread = Val(true))
RFLUFactorization(pivot, thread)
end

function init_cacheval(alg::RFLUFactorization, A, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), ipiv
end

function init_cacheval(alg::RFLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
PREALLOCATED_LU, ipiv
end

function init_cacheval(alg::RFLUFactorization,
A::Union{AbstractSparseArray, AbstractSciMLOperator}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing, nothing
end

function init_cacheval(alg::RFLUFactorization,
A::Union{Diagonal, SymTridiagonal, Tridiagonal}, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing, nothing
end

function SciMLBase.solve!(cache::LinearCache, alg::RFLUFactorization{P, T};
kwargs...) where {P, T}
A = cache.A
A = convert(AbstractMatrix, A)
fact, ipiv = @get_cacheval(cache, :RFLUFactorization)
if cache.isfresh
if length(ipiv) != min(size(A)...)
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
end
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
cache.cacheval = (fact, ipiv)

if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
end

cache.isfresh = false
end
y = ldiv!(cache.u, @get_cacheval(cache, :RFLUFactorization)[1], cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

## NormalCholeskyFactorization

"""
Expand Down
2 changes: 1 addition & 1 deletion test/adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Zygote, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff
using FiniteDiff, RecursiveFactorization
using LazyArrays: BroadcastArray

n = 4
Expand Down
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
using SciMLOperators
using SciMLOperators, RecursiveFactorization
using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners
using Test
import Random
Expand Down
4 changes: 2 additions & 2 deletions test/default_algs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearSolve, LinearAlgebra, SparseArrays, Test, JET
using LinearSolve, RecursiveFactorization, LinearAlgebra, SparseArrays, Test, JET
@test LinearSolve.defaultalg(nothing, zeros(3)).alg ===
LinearSolve.DefaultAlgorithmChoice.RFLUFactorization
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
prob = LinearProblem(rand(3, 3), rand(3))
solve(prob)

Expand Down
2 changes: 1 addition & 1 deletion test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff
using FiniteDiff, RecursiveFactorization

n = 4
A = rand(n, n);
Expand Down
2 changes: 1 addition & 1 deletion test/retcodes.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearSolve
using LinearSolve, RecursiveFactorization

alglist = (
LUFactorization,
Expand Down
Loading