Add SimpleGMRES implementation
avik-pal committed Aug 22, 2023
1 parent f531dd0 commit adde736
Showing 6 changed files with 199 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,3 +5,5 @@
Manifest.toml

*.swp
.vscode
wip
6 changes: 6 additions & 0 deletions Project.toml
@@ -28,25 +28,30 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"

[extensions]
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMKLExt = "MKL_jll"
LinearSolveMetalExt = "Metal"
LinearSolveNNlibExt = "NNlib"
LinearSolvePardisoExt = "Pardiso"

[compat]
ArrayInterface = "7.4.11"
BlockDiagonals = "0.1"
DocStringExtensions = "0.8, 0.9"
EnumX = "1"
FastLapackInterface = "1, 2"
@@ -56,6 +61,7 @@ IterativeSolvers = "0.9.2"
KLU = "0.3.0, 0.4"
Krylov = "0.9"
KrylovKit = "0.5, 0.6"
NNlib = "0.9"
PrecompileTools = "1"
Preferences = "1"
RecursiveFactorization = "0.2.8"
24 changes: 24 additions & 0 deletions ext/LinearSolveBlockDiagonalsExt.jl
@@ -0,0 +1,24 @@
module LinearSolveBlockDiagonalsExt

using LinearSolve, BlockDiagonals

function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose, assumptions; zeroinit = true)
@assert ndims(A) == 2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2."
# We need to perform this check even when `zeroinit == true`, since the type of the
# cache is dependent on whether we are able to use the specialized dispatch.
bsizes = blocksizes(A)
usize = first(first(bsizes))
uniform_blocks = true
for bsize in bsizes
if bsize[1] != usize || bsize[2] != usize
uniform_blocks = false
break
end
end
# Can't help but perform dynamic dispatch here
return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, u, Pl, Pr, maxiters,
abstol, reltol, verbose, assumptions; zeroinit)
end

end
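
For orientation, here is a sketch of how this extension gets exercised (hypothetical example, not part of the diff; it assumes BlockDiagonals.jl is loaded so the extension activates, and it uses non-uniform blocks because the uniform-block `Val(true)` specialization is not included in this commit):

using LinearSolve, BlockDiagonals, LinearAlgebra

# Blocks of different sizes, so the uniform-block check above fails and the
# generic (Val(false)) GMRES path is taken.
A = BlockDiagonal([rand(2, 2) + 2I, rand(3, 3) + 3I])
b = rand(5)
sol = solve(LinearProblem(A, b), SimpleGMRES())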
5 changes: 5 additions & 0 deletions ext/LinearSolveNNlibExt.jl
@@ -0,0 +1,5 @@
module LinearSolveNNlibExt

using LinearSolve, NNlib

end
5 changes: 4 additions & 1 deletion src/LinearSolve.jl
@@ -24,7 +24,7 @@ using Requires
import InteractiveUtils

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok

import GPUArraysCore
@@ -85,6 +85,7 @@ end
include("common.jl")
include("factorization.jl")
include("simplelu.jl")
include("simplegmres.jl")
include("iterative_wrappers.jl")
include("preconditioners.jl")
include("solve_function.jl")
@@ -171,6 +172,8 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES

export SimpleGMRES

export HYPREAlgorithm
export CudaOffloadFactorization
export MKLPardisoFactorize, MKLPardisoIterate
158 changes: 158 additions & 0 deletions src/simplegmres.jl
@@ -0,0 +1,158 @@
"""
SimpleGMRES(; restart::Int = 20, blocksize::Int = 0)
A simple GMRES implementation for square non-Hermitian linear systems.
This implementation handles Block Diagonal Matrices with Uniformly Sized Square Blocks with
specialized dispatches.
## Arguments
* `restart::Int = 20`: the number of iterations before restarting. Must be a strictly
positive integer.
* `blocksize::Int = 0`: If blocksize is `> 0`, the solver assumes that the matrix has a
uniformly sized block diagonal structure with square blocks of size `blocksize`. Misusing
this option will lead to incorrect results.
* If this is set `≤ 0` and during runtime we get a Block Diagonal Matrix, then we will
check if the specialized dispatch can be used.
!!! warning
Most users should be using the `KrylovJL_GMRES` solver instead of this implementation.
"""
struct SimpleGMRES{UBD} <: AbstractKrylovSubspaceMethod
restart::Int
blocksize::Int

function SimpleGMRES(; restart::Int = 20, blocksize::Int = 0)
@assert restart ≥ 1 "restart must be greater than or equal to 1"
return new{blocksize > 0}(restart, blocksize)
end
end
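
A usage sketch for reference (hypothetical example, not part of the commit; any well-conditioned square system works):

using LinearSolve, LinearAlgebra

A = rand(8, 8) + 8I # diagonally dominant, so GMRES converges quickly
b = rand(8)

sol = solve(LinearProblem(A, b), SimpleGMRES(; restart = 8))
sol.u ≈ A \ b # expected to hold up to the solver tolerance

For genuinely block diagonal matrices with uniformly sized square blocks, passing `blocksize` is meant to opt into the specialized dispatch up front (the `SimpleGMRES{true}` code path is not part of this diff).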

struct SimpleGMRESCache{UBD, T, QType, HType, xType, rType, βe₁Type, AType, bType, βType}
M::Int
N::Int
maxiters::Int
blocksize::Int
ϵ::T
Q::QType
H::HType
x::xType
r::rType
βe₁::βe₁Type
A::AType
b::bType
β::βType
abstol::T

function SimpleGMRESCache{UBD}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, A, b, β,
abstol) where {UBD}
return new{UBD, typeof(ϵ), typeof(Q), typeof(H), typeof(x), typeof(r), typeof(βe₁),
typeof(A), typeof(b), typeof(β)}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁,
A, b, β, abstol)
end
end

_no_preconditioner(::Nothing) = true
_no_preconditioner(::IdentityOperator) = true
_no_preconditioner(::UniformScaling) = true
_no_preconditioner(_) = false

function init_cacheval(alg::SimpleGMRES{false}, args...; kwargs...)
return _init_cacheval(Val(false), alg, args...; kwargs...)
end

# TODO: We can check if `A` is a block diagonal matrix with uniformly sized square blocks
# and use the specialized dispatch
function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int,
abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true)
if zeroinit
return SimpleGMRESCache{false}(0, 0, maxiters, alg.blocksize, zero(eltype(u)),
similar(b, 0, 0), similar(b, 0, 0), u, similar(b, 0), similar(b, 0),
A, b, zero(eltype(u)), abstol)
end

@assert _no_preconditioner(Pl)&&_no_preconditioner(Pr) "Preconditioning not supported! Use KrylovJL_GMRES instead."
N = LinearAlgebra.checksquare(A)
T = eltype(u)
M = min(maxiters, alg.restart)
ϵ = eps(T)

# Initialize the Cache
## Use `b` since `A` might be an operator
Q = similar(b, length(b), M + 1)
H = similar(b, M + 1, M)
fill!(H, zero(T))
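# Sizes follow the Arnoldi relation A * Q[:, 1:M] ≈ Q[:, 1:(M + 1)] * H,
# where H is an (M + 1) × M upper-Hessenberg matrix.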

mul!(@view(Q[:, 1]), A, u, T(-1), T(0)) # Q(:, 1) <- -A u
axpy!(T(1), b, @view(Q[:, 1])) # Q(:, 1) <- r0 = b - A u
β = norm(@view(Q[:, 1]), 2)
Q[:, 1] ./= β

x = u
r = similar(b)
βe₁ = similar(b, M + 1)
fill!(βe₁, 0)
βe₁[1:1] .= β # Avoid the scalar indexing error

return SimpleGMRESCache{false}(M, N, maxiters, alg.blocksize, ϵ, Q, H, x, r, βe₁, A, b,
β, abstol)
end

default_alias_A(::SimpleGMRES, ::Any, ::Any) = false
default_alias_b(::SimpleGMRES, ::Any, ::Any) = false

function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...)
if cache.isfresh
solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr,
cache.maxiters, cache.abstol, cache.reltol, cache.verbose,
cache.assumptions; zeroinit = false)
cache.cacheval = solver
cache.isfresh = false
end
return SciMLBase.solve!(cache.cacheval)
end

function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T}
@unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol = cache
norm2 = Base.Fix2(norm, 2)
res_norm = β

# FIXME: The performance of this implementation is quite bad compared to the
# KrylovJL_GMRES version
for _ in 1:(maxiters ÷ M + 1)
for j in 1:M
Qⱼ₊₁ = @view(Q[:, j + 1])
mul!(Qⱼ₊₁, A, @view(Q[:, j])) # Q(:,j+1) <- A Q(:, j)
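# Modified Gram-Schmidt: orthogonalize Qⱼ₊₁ against the previous basis
# vectors in place, storing the projection coefficients in H.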
for i in 1:j
H[i, j] = dot(@view(Q[:, i]), Qⱼ₊₁)
axpy!(-H[i, j], @view(Q[:, i]), Qⱼ₊₁)
end
H[j + 1, j] = norm2(Qⱼ₊₁)
H[j + 1, j] > ϵ && (Qⱼ₊₁ ./= H[j + 1, j])

# FIXME: Figure out a way to avoid the allocation
# Using views doesn't work very well with LinearSolve
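# `\` on the (j + 1) × j Hessenberg block solves the small least-squares
# problem min_y ‖βe₁ - H y‖₂ that defines the GMRES iterate.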
y = @view(H[1:(j + 1), 1:j]) \ @view(βe₁[1:(j + 1)])

# Update the solution
mul!(x, @view(Q[:, 1:j]), y)
mul!(r, A, x, T(-1), T(0))
axpy!(T(1), b, r)
res_norm = norm2(r)

if res_norm < abstol
return SciMLBase.build_linear_solution(nothing, x, r, nothing;
retcode = ReturnCode.Success)
end
end

# Restart
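# Restarting rebuilds the Krylov basis from the current residual, trading
# convergence speed for bounded memory (Q stays length(b) × (M + 1)).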
Q[:, 1] = r ./ res_norm
fill!(H, zero(T))
end

return SciMLBase.build_linear_solution(nothing, x, r, nothing;
retcode = ReturnCode.MaxIters)
end
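
Since the docstring steers most users to `KrylovJL_GMRES`, a quick agreement check one might run against this implementation (hypothetical snippet, not part of the commit):

using LinearSolve, LinearAlgebra

A = rand(10, 10) + 10I
b = rand(10)
prob = LinearProblem(A, b)

u_simple = solve(prob, SimpleGMRES(; restart = 10)).u
u_krylov = solve(prob, KrylovJL_GMRES()).u
maximum(abs, u_simple - u_krylov) # expected to be small, near solver tolerances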
