Skip to content

Commit

Permalink
Merge pull request #400 from avik-pal/ap/needs_square_A
Browse files Browse the repository at this point in the history
Add `needs_square_A` trait
  • Loading branch information
ChrisRackauckas authored Oct 25, 2023
2 parents a880003 + 12c7ed9 commit 4ee3e1e
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.11.1"
version = "2.12.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
29 changes: 29 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

# Util
# Trait: does the operator have fewer rows than columns (underdetermined system)?
# Non-matrix inputs (e.g. `nothing`) are treated as not underdetermined.
is_underdetermined(x) = false
function is_underdetermined(A::Union{AbstractMatrix, AbstractSciMLOperator})
    nrows, ncols = size(A, 1), size(A, 2)
    return nrows < ncols
end

_isidentity_struct(A) = false
_isidentity_struct(λ::Number) = isone(λ)
Expand Down Expand Up @@ -96,6 +99,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
NormalCholeskyFactorization
AppleAccelerateLUFactorization
MKLLUFactorization
QRFactorizationPivoted
end

struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
Expand Down Expand Up @@ -143,6 +147,31 @@ end
include("factorization_sparse.jl")
end

# Solver Specific Traits
## Needs Square Matrix
"""
    needs_square_A(alg)

Returns `true` if the algorithm requires a square matrix.
"""
needs_square_A(::Nothing) = false # Linear Solve automatically will use a correct alg!
# Conservative fallback: assume a solver needs a square `A` unless stated otherwise.
needs_square_A(alg::SciMLLinearSolveAlgorithm) = true
# Factorizations that handle rectangular / least-squares problems.
for alg in (:QRFactorization, :FastQRFactorization, :NormalCholeskyFactorization,
    :NormalBunchKaufmanFactorization)
    @eval needs_square_A(::$(alg)) = false
end
# Krylov wrappers around least-squares / minimum-norm solvers (LSMR, CRAIGMR)
# also accept non-square operators.
for kralg in (Krylov.lsmr!, Krylov.craigmr!)
    @eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false
end
# Direct factorizations that require a square operator. (Redundant with the
# fallback above, but listed explicitly for documentation purposes.)
for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization,
    :GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization,
    :RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
    :DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
    :CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
    :MKLLUFactorization, :MetalLUFactorization)
    @eval needs_square_A(::$(alg)) = true
end

const IS_OPENBLAS = Ref(true)
isopenblas() = IS_OPENBLAS[]

Expand Down
33 changes: 28 additions & 5 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17, T18}
T13, T14, T15, T16, T17, T18, T19}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
NormalCholeskyFactorization::T16
AppleAccelerateLUFactorization::T17
MKLLUFactorization::T18
QRFactorizationPivoted::T19
end

# Legacy fallback
Expand Down Expand Up @@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.MKLLUFactorization
Expand Down Expand Up @@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions)
elseif assump.condition === OperatorCondition.WellConditioned
DefaultAlgorithmChoice.NormalCholeskyFactorization
elseif assump.condition === OperatorCondition.IllConditioned
DefaultAlgorithmChoice.QRFactorization
if is_underdetermined(A)
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted
else
DefaultAlgorithmChoice.QRFactorization
end
elseif assump.condition === OperatorCondition.VeryIllConditioned
DefaultAlgorithmChoice.QRFactorization
if is_underdetermined(A)
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted
else
DefaultAlgorithmChoice.QRFactorization
end
elseif assump.condition === OperatorCondition.SuperIllConditioned
DefaultAlgorithmChoice.SVDFactorization
else
Expand Down Expand Up @@ -247,6 +258,12 @@ function algchoice_to_alg(alg::Symbol)
NormalCholeskyFactorization()
elseif alg === :AppleAccelerateLUFactorization
AppleAccelerateLUFactorization()
elseif alg === :QRFactorizationPivoted
@static if VERSION ≥ v"1.7beta"
QRFactorization(ColumnNorm())
else
QRFactorization(Val(true))
end
else
error("Algorithm choice symbol $alg not allowed in the default")
end
Expand Down Expand Up @@ -311,6 +328,12 @@ function defaultalg_symbol(::Type{T}) where {T}
end
defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization

# Map the pivoted-QR algorithm type back to its default-choice symbol.
# Julia ≥ 1.7 spells column pivoting as `ColumnNorm`; earlier versions used
# the legacy `Val{true}` flag.
@static if VERSION >= v"1.7"
    defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted
else
    defaultalg_symbol(::Type{<:QRFactorization{Val{true}}}) = :QRFactorizationPivoted
end

"""
if alg.alg === DefaultAlgorithmChoice.LUFactorization
SciMLBase.solve!(cache, LUFactorization(), args...; kwargs...))
Expand Down
10 changes: 10 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ function QRFactorization(inplace = true)
QRFactorization(pivot, 16, inplace)
end

# Convenience constructor taking an explicit pivoting choice (default block
# size 16). Julia 1.7 replaced the `Val(true)` pivot flag for `qr` with the
# `LinearAlgebra.PivotingStrategy` types (e.g. `ColumnNorm()`), so the accepted
# pivot argument type depends on the Julia version.
# NOTE: the version comparison operator (`≥`) was dropped in the scraped text;
# restored here.
@static if VERSION ≥ v"1.7beta"
    function QRFactorization(pivot::LinearAlgebra.PivotingStrategy, inplace::Bool = true)
        QRFactorization(pivot, 16, inplace)
    end
else
    function QRFactorization(pivot::Val, inplace::Bool = true)
        QRFactorization(pivot, 16, inplace)
    end
end

function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
Expand Down
16 changes: 16 additions & 0 deletions test/nonsquare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ b = rand(m)
# Overdetermined (m > n) dense problem: compare against the least-squares
# solution from `\`. The `≈` operators were dropped in the scraped text;
# restored here.
prob = LinearProblem(A, b)
res = A \ b
@test solve(prob).u ≈ res
@test !LinearSolve.needs_square_A(QRFactorization())
@test solve(prob, QRFactorization()) ≈ res
@test !LinearSolve.needs_square_A(FastQRFactorization())
@test solve(prob, FastQRFactorization()) ≈ res
@test !LinearSolve.needs_square_A(KrylovJL_LSMR())
@test solve(prob, KrylovJL_LSMR()) ≈ res

A = sprand(m, n, 0.5)
Expand All @@ -23,6 +27,7 @@ A = sprand(n, m, 0.5)
# Sparse non-square problem solved with CRAIGMR; reference solution via a
# dense `\`. The `≈` was dropped in the scraped text; restored here.
b = rand(n)
prob = LinearProblem(A, b)
res = Matrix(A) \ b
@test !LinearSolve.needs_square_A(KrylovJL_CRAIGMR())
@test solve(prob, KrylovJL_CRAIGMR()) ≈ res

A = sprandn(1000, 100, 0.1)
Expand All @@ -35,7 +40,9 @@ A = randn(1000, 100)
b = randn(1000)
@test isapprox(solve(LinearProblem(A, b)).u, Symmetric(A' * A) \ (A' * b))
solve(LinearProblem(A, b)).u;
@test !LinearSolve.needs_square_A(NormalCholeskyFactorization())
solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
@test !LinearSolve.needs_square_A(NormalBunchKaufmanFactorization())
solve(LinearProblem(A, b), (LinearSolve.NormalBunchKaufmanFactorization())).u;
solve(LinearProblem(A, b),
assumptions = (OperatorAssumptions(false;
Expand All @@ -49,3 +56,12 @@ solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
solve(LinearProblem(A, b),
assumptions = (OperatorAssumptions(false;
condition = OperatorCondition.WellConditioned))).u;

# Underdetermined (m < n) problem: the default algorithm should pick a solver
# that handles rectangular `A` (pivoted QR); compare against the minimum-norm
# solution from `\`. The `≈` was dropped in the scraped text; restored here.
m, n = 2, 3

A = rand(m, n)
b = rand(m)
prob = LinearProblem(A, b)
res = A \ b
@test solve(prob).u ≈ res

0 comments on commit 4ee3e1e

Please sign in to comment.