Skip to content

Commit adde736

Browse files
committed
Add SimpleGMRES implementation
1 parent f531dd0 commit adde736

File tree

6 files changed

+199
-1
lines changed

6 files changed

+199
-1
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@
55
Manifest.toml
66

77
*.swp
8+
.vscode
9+
wip

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,30 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2828
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2929

3030
[weakdeps]
31+
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
3132
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3233
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3334
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
3435
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
3536
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
3637
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
38+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3739
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
3840

3941
[extensions]
42+
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
4043
LinearSolveCUDAExt = "CUDA"
4144
LinearSolveHYPREExt = "HYPRE"
4245
LinearSolveIterativeSolversExt = "IterativeSolvers"
4346
LinearSolveKrylovKitExt = "KrylovKit"
4447
LinearSolveMKLExt = "MKL_jll"
4548
LinearSolveMetalExt = "Metal"
49+
LinearSolveNNlibExt = "NNlib"
4650
LinearSolvePardisoExt = "Pardiso"
4751

4852
[compat]
4953
ArrayInterface = "7.4.11"
54+
BlockDiagonals = "0.1"
5055
DocStringExtensions = "0.8, 0.9"
5156
EnumX = "1"
5257
FastLapackInterface = "1, 2"
@@ -56,6 +61,7 @@ IterativeSolvers = "0.9.2"
5661
KLU = "0.3.0, 0.4"
5762
Krylov = "0.9"
5863
KrylovKit = "0.5, 0.6"
64+
NNlib = "0.9"
5965
PrecompileTools = "1"
6066
Preferences = "1"
6167
RecursiveFactorization = "0.2.8"

ext/LinearSolveBlockDiagonalsExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
module LinearSolveBlockDiagonalsExt

using LinearSolve, BlockDiagonals

# Cache construction for `SimpleGMRES` when `A` is a `BlockDiagonal`. Detects at runtime
# whether every block is square and uniformly sized; if so, the solver's specialized
# uniform-block-diagonal dispatch can be used.
function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, u, Pl, Pr,
    maxiters::Int, abstol, reltol, verbose, assumptions; zeroinit = true)
    # Validate with a real exception rather than `@assert`, which may be elided at
    # higher optimization levels.
    ndims(A) == 2 ||
        throw(ArgumentError("ndims(A) == $(ndims(A)). `A` must have ndims == 2."))
    # We need to perform this check even when `zeroinit == true`, since the type of the
    # cache is dependent on whether we are able to use the specialized dispatch.
    bsizes = blocksizes(A)
    usize = first(first(bsizes))
    # Every block must be `usize × usize` for the specialized dispatch to apply.
    uniform_blocks = all(bsize -> bsize[1] == usize && bsize[2] == usize, bsizes)
    # Can't help but perform dynamic dispatch here
    return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, u, Pl, Pr, maxiters,
        abstol, reltol, verbose, assumptions; zeroinit)
end

end

ext/LinearSolveNNlibExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Extension module activated when NNlib is loaded alongside LinearSolve. It currently only
# declares the dependency; no methods are defined here yet.
module LinearSolveNNlibExt

using LinearSolve, NNlib

end

src/LinearSolve.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ using Requires
2424
import InteractiveUtils
2525

2626
using LinearAlgebra: BlasInt, LU
27-
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
27+
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
2828
@blasfunc, chkargsok
2929

3030
import GPUArraysCore
@@ -85,6 +85,7 @@ end
8585
include("common.jl")
8686
include("factorization.jl")
8787
include("simplelu.jl")
88+
include("simplegmres.jl")
8889
include("iterative_wrappers.jl")
8990
include("preconditioners.jl")
9091
include("solve_function.jl")
@@ -171,6 +172,8 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
171172
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
172173
KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES
173174

175+
export SimpleGMRES
176+
174177
export HYPREAlgorithm
175178
export CudaOffloadFactorization
176179
export MKLPardisoFactorize, MKLPardisoIterate

src/simplegmres.jl

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
"""
    SimpleGMRES(; restart::Int = 20, blocksize::Int = 0)

A simple GMRES implementation for square non-Hermitian linear systems.

This implementation handles Block Diagonal Matrices with Uniformly Sized Square Blocks with
specialized dispatches.

## Arguments

* `restart::Int = 20`: the number of iterations before restarting. Must be a strictly
  positive integer.
* `blocksize::Int = 0`: If blocksize is `> 0`, the solver assumes that the matrix has a
  uniformly sized block diagonal structure with square blocks of size `blocksize`. Misusing
  this option will lead to incorrect results.
    * If this is set `≤ 0` and during runtime we get a Block Diagonal Matrix, then we will
      check if the specialized dispatch can be used.

!!! warning

    Most users should be using the `KrylovJL_GMRES` solver instead of this implementation.
"""
struct SimpleGMRES{UBD} <: AbstractKrylovSubspaceMethod
    restart::Int
    blocksize::Int

    function SimpleGMRES(; restart::Int = 20, blocksize::Int = 0)
        # Validate user input with a proper exception instead of `@assert`, which may be
        # disabled at higher optimization levels.
        restart >= 1 || throw(ArgumentError("restart must be greater than or equal to 1"))
        # The type parameter `UBD` records whether the uniform-block-diagonal fast path
        # was explicitly requested via a positive `blocksize`.
        return new{blocksize > 0}(restart, blocksize)
    end
end
32+
33+
# Internal solver state for `SimpleGMRES`. The `UBD` type parameter mirrors the one on
# `SimpleGMRES` and selects between the generic and uniform-block-diagonal `solve!` paths.
struct SimpleGMRESCache{UBD, T, QType, HType, xType, rType, βe₁Type, AType, bType, βType}
    M::Int          # Krylov subspace size (restart length)
    N::Int          # problem size, i.e. `checksquare(A)`
    maxiters::Int   # maximum number of iterations
    blocksize::Int  # requested block size (0 means no uniform blocks assumed)
    ϵ::T            # tolerance guarding the Arnoldi normalization step
    Q::QType        # orthonormal Krylov basis, one Arnoldi vector per column
    H::HType        # upper Hessenberg matrix from the Arnoldi process
    x::xType        # current solution iterate (aliases the user's `u`)
    r::rType        # residual buffer
    βe₁::βe₁Type    # right-hand side `β·e₁` of the least-squares subproblem
    A::AType        # system operator
    b::bType        # right-hand side vector
    β::βType        # norm of the initial residual
    abstol::T       # absolute convergence tolerance on the residual norm

    function SimpleGMRESCache{UBD}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, A, b, β,
        abstol) where {UBD}
        # Inner constructor infers all remaining type parameters from the arguments so
        # call sites only need to supply `UBD`.
        return new{UBD, typeof(ϵ), typeof(Q), typeof(H), typeof(x), typeof(r), typeof(βe₁),
            typeof(A), typeof(b), typeof(β)}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁,
            A, b, β, abstol)
    end
end
56+
57+
# Trait-style predicate: `true` iff the given (left/right) preconditioner is a no-op.
# `SimpleGMRES` does not implement preconditioning, so only identity-like values pass.
_no_preconditioner(::Nothing) = true
_no_preconditioner(::IdentityOperator) = true
_no_preconditioner(::UniformScaling) = true
_no_preconditioner(_) = false
61+
62+
# Cache construction entry point when no uniform block structure was requested; forwards
# to the generic `_init_cacheval` dispatch via `Val(false)`.
function init_cacheval(alg::SimpleGMRES{false}, args...; kwargs...)
    return _init_cacheval(Val(false), alg, args...; kwargs...)
end
65+
66+
# TODO: We can check if `A` is a block diagonal matrix with uniformly sized square blocks
# and use the specialized dispatch
#
# Build the `SimpleGMRESCache` for the generic (non-block-diagonal) path. When
# `zeroinit = true` only a typed placeholder with empty buffers is returned; the real
# workspace is allocated later from `solve!` with `zeroinit = false`.
function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int,
    abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true)
    if zeroinit
        return SimpleGMRESCache{false}(0, 0, maxiters, alg.blocksize, zero(eltype(u)),
            similar(b, 0, 0), similar(b, 0, 0), u, similar(b, 0), similar(b, 0),
            A, b, zero(eltype(u)), abstol)
    end

    # Preconditioning is not supported; reject with a proper exception instead of
    # `@assert`, which may be disabled at higher optimization levels.
    if !(_no_preconditioner(Pl) && _no_preconditioner(Pr))
        throw(ArgumentError("Preconditioning not supported! Use KrylovJL_GMRES instead."))
    end
    N = LinearAlgebra.checksquare(A)
    T = eltype(u)
    M = min(maxiters, alg.restart)
    ϵ = eps(T)

    # Initialize the Cache
    ## Use `b` since `A` might be an operator
    Q = similar(b, length(b), M + 1)
    H = similar(b, M + 1, M)
    fill!(H, zero(T))

    # Initial residual r0 = b - A u, built in the first Krylov column.
    # (The original comments had the signs backwards.)
    mul!(@view(Q[:, 1]), A, u, T(-1), T(0)) # Q[:, 1] <- -A u
    axpy!(T(1), b, @view(Q[:, 1]))          # Q[:, 1] <- b - A u
    β = norm(@view(Q[:, 1]), 2)
    Q[:, 1] ./= β                            # normalize the first basis vector

    x = u
    r = similar(b)
    βe₁ = similar(b, M + 1)
    fill!(βe₁, 0)
    βe₁[1:1] .= β # Avoid the scalar indexing error

    return SimpleGMRESCache{false}(M, N, maxiters, alg.blocksize, ϵ, Q, H, x, r, βe₁, A, b,
        β, abstol)
end
102+
103+
# Default aliasing policy for `SimpleGMRES` in the `LinearCache` setup.
# NOTE(review): `false` presumably means "do not alias, copy `A`/`b` into the cache" —
# confirm against the `default_alias_*` definitions in common.jl.
default_alias_A(::SimpleGMRES, ::Any, ::Any) = false
default_alias_b(::SimpleGMRES, ::Any, ::Any) = false
105+
106+
# High-level solve entry point: (re)builds the GMRES workspace when the cache is fresh
# (first solve, or `A`/`b` were updated), then delegates to the cache-level `solve!`.
function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...)
    if cache.isfresh
        # `zeroinit = false` allocates the actual Krylov buffers (see `_init_cacheval`).
        solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr,
            cache.maxiters, cache.abstol, cache.reltol, cache.verbose,
            cache.assumptions; zeroinit = false)
        cache.cacheval = solver
        cache.isfresh = false
    end
    return SciMLBase.solve!(cache.cacheval)
end
116+
117+
# Restarted GMRES iteration for the generic (non-block) cache: Arnoldi expansion with
# modified Gram-Schmidt, dense least-squares solve for the correction, and restart from
# the current residual every `M` inner steps.
function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T}
    @unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol = cache
    norm2 = Base.Fix2(norm, 2)
    res_norm = β

    # FIXME: The performance for this is quite bad when compared to the KrylovJL_GMRES
    # version
    for _ in 1:(maxiters ÷ M + 1)  # outer restart cycles
        for j in 1:M
            # Arnoldi step: new Krylov vector, orthogonalized against all previous
            # columns (modified Gram-Schmidt); coefficients land in column j of H.
            Qⱼ₊₁ = @view(Q[:, j + 1])
            mul!(Qⱼ₊₁, A, @view(Q[:, j])) # Q(:,j+1) <- A Q(:, j)
            for i in 1:j
                H[i, j] = dot(@view(Q[:, i]), Qⱼ₊₁)
                axpy!(-H[i, j], @view(Q[:, i]), Qⱼ₊₁)
            end
            H[j + 1, j] = norm2(Qⱼ₊₁)
            # Skip normalization when the norm is ≤ ϵ (numerically zero vector,
            # i.e. a "happy breakdown") to avoid dividing by ~0.
            H[j + 1, j] > ϵ && (Qⱼ₊₁ ./= H[j + 1, j])

            # FIXME: Figure out a way to avoid the allocation
            # Using views doesn't work very well with LinearSolve
            # Least-squares solve of the (j+1) × j Hessenberg system min ‖βe₁ - H y‖₂.
            y = @view(H[1:(j + 1), 1:j]) \ @view(βe₁[1:(j + 1)])

            # Update the solution
            mul!(x, @view(Q[:, 1:j]), y)
            mul!(r, A, x, T(-1), T(0)) # r <- -A x
            axpy!(T(1), b, r)          # r <- b - A x
            res_norm = norm2(r)

            if res_norm < abstol
                return SciMLBase.build_linear_solution(nothing, x, r, nothing;
                    retcode = ReturnCode.Success)
            end
        end

        # Restart: reseed the first basis vector with the normalized current residual
        # and clear the Hessenberg matrix for the next cycle.
        Q[:, 1] = r ./ res_norm
        fill!(H, zero(T))
    end

    return SciMLBase.build_linear_solution(nothing, x, r, nothing;
        retcode = ReturnCode.MaxIters)
end

0 commit comments

Comments
 (0)