From 2b9d8fd62abb50282f6bd39cd6cf72261874cbc2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 22 Aug 2023 18:14:26 -0400 Subject: [PATCH 01/13] Add SimpleGMRES implementation --- .gitignore | 2 + Project.toml | 6 ++ ext/LinearSolveBlockDiagonalsExt.jl | 24 +++++ ext/LinearSolveNNlibExt.jl | 5 + src/LinearSolve.jl | 5 +- src/simplegmres.jl | 158 ++++++++++++++++++++++++++++ 6 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 ext/LinearSolveBlockDiagonalsExt.jl create mode 100644 ext/LinearSolveNNlibExt.jl create mode 100644 src/simplegmres.jl diff --git a/.gitignore b/.gitignore index e454bf595..1b6ed4dea 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ Manifest.toml *.swp +.vscode +wip \ No newline at end of file diff --git a/Project.toml b/Project.toml index 4081ee050..e66e5f5ef 100644 --- a/Project.toml +++ b/Project.toml @@ -28,25 +28,30 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] +BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" [extensions] +LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKrylovKitExt = "KrylovKit" LinearSolveMKLExt = "MKL_jll" LinearSolveMetalExt = "Metal" +LinearSolveNNlibExt = "NNlib" LinearSolvePardisoExt = "Pardiso" [compat] ArrayInterface = "7.4.11" +BlockDiagonals = "0.1" DocStringExtensions = "0.8, 0.9" EnumX = "1" FastLapackInterface = "1, 2" @@ -56,6 +61,7 @@ IterativeSolvers = "0.9.2" KLU = "0.3.0, 0.4" Krylov = "0.9" 
KrylovKit = "0.5, 0.6" +NNlib = "0.9" PrecompileTools = "1" Preferences = "1" RecursiveFactorization = "0.2.8" diff --git a/ext/LinearSolveBlockDiagonalsExt.jl b/ext/LinearSolveBlockDiagonalsExt.jl new file mode 100644 index 000000000..6c1f61a68 --- /dev/null +++ b/ext/LinearSolveBlockDiagonalsExt.jl @@ -0,0 +1,24 @@ +module LinearSolveBlockDiagonalsExt + +using LinearSolve, BlockDiagonals + +function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose, assumptions; zeroinit = true) + @assert ndims(A) == 2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2." + # We need to perform this check even when `zeroinit == true`, since the type of the + # cache is dependent on whether we are able to use the specialized dispatch. + bsizes = blocksizes(A) + usize = first(first(bsizes)) + uniform_blocks = true + for bsize in bsizes + if bsize[1] != usize || bsize[2] != usize + uniform_blocks = false + break + end + end + # Can't help but perform dynamic dispatch here + return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, u, Pl, Pr, maxiters, + abstol, reltol, verbose, assumptions; zeroinit) +end + +end diff --git a/ext/LinearSolveNNlibExt.jl b/ext/LinearSolveNNlibExt.jl new file mode 100644 index 000000000..56e6e9a93 --- /dev/null +++ b/ext/LinearSolveNNlibExt.jl @@ -0,0 +1,5 @@ +module LinearSolveNNlibExt + +using LinearSolve, NNlib + +end diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 95115b32b..d63a6e00e 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -24,7 +24,7 @@ using Requires import InteractiveUtils using LinearAlgebra: BlasInt, LU -using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, +using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, @blasfunc, chkargsok import GPUArraysCore @@ -85,6 +85,7 @@ end include("common.jl") include("factorization.jl") include("simplelu.jl") +include("simplegmres.jl") 
include("iterative_wrappers.jl") include("preconditioners.jl") include("solve_function.jl") @@ -171,6 +172,8 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES, IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES, KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES +export SimpleGMRES + export HYPREAlgorithm export CudaOffloadFactorization export MKLPardisoFactorize, MKLPardisoIterate diff --git a/src/simplegmres.jl b/src/simplegmres.jl new file mode 100644 index 000000000..5a468c30c --- /dev/null +++ b/src/simplegmres.jl @@ -0,0 +1,158 @@ +""" + SimpleGMRES(; restart::Int = 20, blocksize::Int = 0) + +A simple GMRES implementation for square non-Hermitian linear systems. + +This implementation handles Block Diagonal Matrices with Uniformly Sized Square Blocks with +specialized dispatches. + +## Arguments + +* `restart::Int = 20`: the number of iterations before restarting. Must be a strictly + positive integer. +* `blocksize::Int = 0`: If blocksize is `> 0`, the solver assumes that the matrix has a + uniformly sized block diagonal structure with square blocks of size `blocksize`. Misusing + this option will lead to incorrect results. + * If this is set `≤ 0` and during runtime we get a Block Diagonal Matrix, then we will + check if the specialized dispatch can be used. + +!!! warning + + Most users should be using the `KrylovJL_GMRES` solver instead of this implementation. 
+""" +struct SimpleGMRES{UBD} <: AbstractKrylovSubspaceMethod + restart::Int + blocksize::Int + + function SimpleGMRES(; restart::Int = 20, blocksize::Int = 0) + @assert restart≥1 "restart must be greater than or equal to 1" + return new{blocksize > 0}(restart, blocksize) + end +end + +struct SimpleGMRESCache{UBD, T, QType, HType, xType, rType, βe₁Type, AType, bType, βType} + M::Int + N::Int + maxiters::Int + blocksize::Int + ϵ::T + Q::QType + H::HType + x::xType + r::rType + βe₁::βe₁Type + A::AType + b::bType + β::βType + abstol::T + + function SimpleGMRESCache{UBD}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, A, b, β, + abstol) where {UBD} + return new{UBD, typeof(ϵ), typeof(Q), typeof(H), typeof(x), typeof(r), typeof(βe₁), + typeof(A), typeof(b), typeof(β)}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, + A, b, β, abstol) + end +end + +_no_preconditioner(::Nothing) = true +_no_preconditioner(::IdentityOperator) = true +_no_preconditioner(::UniformScaling) = true +_no_preconditioner(_) = false + +function init_cacheval(alg::SimpleGMRES{false}, args...; kwargs...) + return _init_cacheval(Val(false), alg, args...; kwargs...) +end + +# TODO: We can check if `A` is a block diagonal matrix with uniformly sized square blocks +# and use the specialized dispatch +function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int, + abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true) + if zeroinit + return SimpleGMRESCache{false}(0, 0, maxiters, alg.blocksize, zero(eltype(u)), + similar(b, 0, 0), similar(b, 0, 0), u, similar(b, 0), similar(b, 0), + A, b, zero(eltype(u)), abstol) + end + + @assert _no_preconditioner(Pl)&&_no_preconditioner(Pr) "Preconditioning not supported! Use KrylovJL_GMRES instead." 
+ N = LinearAlgebra.checksquare(A) + T = eltype(u) + M = min(maxiters, alg.restart) + ϵ = eps(T) + + # Initialize the Cache + ## Use `b` since `A` might be an operator + Q = similar(b, length(b), M + 1) + H = similar(b, M + 1, M) + fill!(H, zero(T)) + + mul!(@view(Q[:, 1]), A, u, T(-1), T(0)) # r0 <- A u + axpy!(T(1), b, @view(Q[:, 1])) # r0 <- r0 - b + β = norm(@view(Q[:, 1]), 2) + Q[:, 1] ./= β + + x = u + r = similar(b) + βe₁ = similar(b, M + 1) + fill!(βe₁, 0) + βe₁[1:1] .= β # Avoid the scalar indexing error + + return SimpleGMRESCache{false}(M, N, maxiters, alg.blocksize, ϵ, Q, H, x, r, βe₁, A, b, + β, abstol) +end + +default_alias_A(::SimpleGMRES, ::Any, ::Any) = false +default_alias_b(::SimpleGMRES, ::Any, ::Any) = false + +function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...) + if cache.isfresh + solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, + cache.maxiters, cache.abstol, cache.reltol, cache.verbose, + cache.assumptions; zeroinit = false) + cache.cacheval = solver + cache.isfresh = false + end + return SciMLBase.solve!(cache.cacheval) +end + +function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T} + @unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol = cache + norm2 = Base.Fix2(norm, 2) + res_norm = β + + # FIXME: The performance for this is quite bad when compared to the KrylovJL_GMRES + # version + for _ in 1:(maxiters ÷ M + 1) + for j in 1:M + Qⱼ₊₁ = @view(Q[:, j + 1]) + mul!(Qⱼ₊₁, A, @view(Q[:, j])) # Q(:,j+1) <- A Q(:, j) + for i in 1:j + H[i, j] = dot(@view(Q[:, i]), Qⱼ₊₁) + axpy!(-H[i, j], @view(Q[:, i]), Qⱼ₊₁) + end + H[j + 1, j] = norm2(Qⱼ₊₁) + H[j + 1, j] > ϵ && (Qⱼ₊₁ ./= H[j + 1, j]) + + # FIXME: Figure out a way to avoid the allocation + # Using views doesn't work very well with LinearSolve + y = @view(H[1:(j + 1), 1:j]) \ @view(βe₁[1:(j + 1)]) + + # Update the solution + mul!(x, @view(Q[:, 1:j]), y) + mul!(r, A, x, T(-1), T(0)) + axpy!(T(1), b, r) + res_norm = norm2(r) + 
+ if res_norm < abstol + return SciMLBase.build_linear_solution(nothing, x, r, nothing; + retcode = ReturnCode.Success) + end + end + + # Restart + Q[:, 1] = r ./ res_norm + fill!(H, zero(T)) + end + + return SciMLBase.build_linear_solution(nothing, x, r, nothing; + retcode = ReturnCode.MaxIters) +end From 786f9b1e8e92c4e977e63fe8c19806c830011e6a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 22 Aug 2023 18:19:20 -0400 Subject: [PATCH 02/13] Fix Project.toml --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index e66e5f5ef..75ff2a4e1 100644 --- a/Project.toml +++ b/Project.toml @@ -75,6 +75,7 @@ UnPack = "1" julia = "1.6" [extras] +BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -85,6 +86,7 @@ MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" From 8547c19a95fadfb79f9628e464e58bf84da2e065 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 23 Aug 2023 14:28:15 -0400 Subject: [PATCH 03/13] Working version for GMRES with blocksize --- ext/LinearSolveBlockDiagonalsExt.jl | 2 +- ext/LinearSolveNNlibExt.jl | 50 ++++++++++++++++++++- src/simplegmres.jl | 69 +++++++++++++++++++++++------ 3 files changed, 105 insertions(+), 16 deletions(-) diff --git a/ext/LinearSolveBlockDiagonalsExt.jl b/ext/LinearSolveBlockDiagonalsExt.jl index 6c1f61a68..654224afc 100644 --- a/ext/LinearSolveBlockDiagonalsExt.jl +++ b/ext/LinearSolveBlockDiagonalsExt.jl @@ -18,7 +18,7 @@ function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, 
A::BlockDiagonal, b, end # Can't help but perform dynamic dispatch here return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, u, Pl, Pr, maxiters, - abstol, reltol, verbose, assumptions; zeroinit) + abstol, reltol, verbose, assumptions; zeroinit, blocksize = usize) end end diff --git a/ext/LinearSolveNNlibExt.jl b/ext/LinearSolveNNlibExt.jl index 56e6e9a93..4598abff2 100644 --- a/ext/LinearSolveNNlibExt.jl +++ b/ext/LinearSolveNNlibExt.jl @@ -1,5 +1,53 @@ module LinearSolveNNlibExt -using LinearSolve, NNlib +using LinearAlgebra, LinearSolve, NNlib +import LinearSolve: SimpleGMRESCache, SimpleGMRES, OperatorAssumptions, _no_preconditioner, + _init_cacheval, _norm2, LinearCache +import UnPack: @unpack + +function SciMLBase.solve!(cache::SimpleGMRESCache{true, T}, lincache::LinearCache) where {T} + @unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol, blocksize = cache + res_norm = β + + # FIXME: The performance for this is quite bad when compared to the KrylovJL_GMRES + # version + for _ in 1:((maxiters ÷ M) + 1) + for j in 1:M + Qⱼ₊₁ = @view(Q[:, j + 1, :]) + mul!(vec(Qⱼ₊₁), A, vec(@view(Q[:, j, :]))) # Q(:,j+1) <- A Q(:, j) + for i in 1:j + H[i, j, :] .= vec(sum(@view(Q[:, i, :]) .* Qⱼ₊₁; dims = 1)) + Qⱼ₊₁ .-= H[i:i, j, :] .* @view(Q[:, i, :]) + end + H[j + 1, j, :] .= vec(_norm2(Qⱼ₊₁, 1)) + Qⱼ₊₁ ./= H[j + 1, j:j, :] + + # FIXME: Figure out a way to avoid the allocation + # Using views doesn't work very well with LinearSolve + y = similar(b, j, 1, size(H, 3)) + for bidx in 1:size(y, 3) + y[:, :, bidx] .= @view(H[1:(j + 1), 1:j, bidx]) \ @view(βe₁[1:(j + 1), bidx]) + end + + # Update the solution + batched_mul!(reshape(x, blocksize, 1, :), @view(Q[:, 1:j, :]), y) + mul!(r, A, x, T(-1), T(0)) + r .+= b + res_norm = _norm2(reshape(r, blocksize, :), 1) + + if maximum(res_norm) < abstol + return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache; + retcode = ReturnCode.Success) + end + end + + # Restart + Q[:, 1, :] = reshape(r, blocksize, 
:) ./ res_norm + fill!(H, zero(T)) + end + + return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache; + retcode = ReturnCode.MaxIters) +end end diff --git a/src/simplegmres.jl b/src/simplegmres.jl index 5a468c30c..0799f4240 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -59,14 +59,15 @@ _no_preconditioner(::IdentityOperator) = true _no_preconditioner(::UniformScaling) = true _no_preconditioner(_) = false -function init_cacheval(alg::SimpleGMRES{false}, args...; kwargs...) - return _init_cacheval(Val(false), alg, args...; kwargs...) +_norm2(x) = norm(x, 2) +_norm2(x, dims) = .√(sum(abs2, x; dims)) + +function init_cacheval(alg::SimpleGMRES{UDB}, args...; kwargs...) where {UDB} + return _init_cacheval(Val(UDB), alg, args...; kwargs...) end -# TODO: We can check if `A` is a block diagonal matrix with uniformly sized square blocks -# and use the specialized dispatch function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int, - abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true) + abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true, kwargs...) if zeroinit return SimpleGMRESCache{false}(0, 0, maxiters, alg.blocksize, zero(eltype(u)), similar(b, 0, 0), similar(b, 0, 0), u, similar(b, 0), similar(b, 0), @@ -75,6 +76,7 @@ function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiter @assert _no_preconditioner(Pl)&&_no_preconditioner(Pr) "Preconditioning not supported! Use KrylovJL_GMRES instead." N = LinearAlgebra.checksquare(A) + @assert N == length(b) "The size of `A` and `b` must match." 
T = eltype(u) M = min(maxiters, alg.restart) ϵ = eps(T) @@ -87,7 +89,7 @@ function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiter mul!(@view(Q[:, 1]), A, u, T(-1), T(0)) # r0 <- A u axpy!(T(1), b, @view(Q[:, 1])) # r0 <- r0 - b - β = norm(@view(Q[:, 1]), 2) + β = _norm2(@view(Q[:, 1])) Q[:, 1] ./= β x = u @@ -100,6 +102,45 @@ function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiter β, abstol) end +function _init_cacheval(::Val{true}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int, + abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true, + blocksize = alg.blocksize) + if zeroinit + return SimpleGMRESCache{true}(0, 0, maxiters, alg.blocksize, zero(eltype(u)), + similar(b, 0, 0, 0), similar(b, 0, 0, 0), u, similar(b, 0), similar(b, 0, 0), + A, b, similar(b, 0, 0), abstol) + end + + @assert _no_preconditioner(Pl)&&_no_preconditioner(Pr) "Preconditioning not supported! Use KrylovJL_GMRES instead." + N = LinearAlgebra.checksquare(A) + @assert mod(N, blocksize)==0 "The blocksize must divide the size of the matrix." + @assert N==length(b) "The size of `A` and `b` must match." 
+ T = eltype(u) + M = min(maxiters, alg.restart) + ϵ = eps(T) + bsize = N ÷ blocksize + + # Initialize the Cache + ## Use `b` since `A` might be an operator + Q = similar(b, blocksize, M + 1, bsize) + H = similar(b, M + 1, M, bsize) + fill!(H, zero(T)) + + mul!(vec(@view(Q[:, 1, :])), A, u, T(-1), T(0)) # r0 <- A u + axpy!(T(1), b, vec(@view(Q[:, 1, :]))) # r0 <- r0 - b + β = _norm2(@view(Q[:, 1, :]), 1) + Q[:, 1, :] ./= β + + x = u + r = similar(b) + βe₁ = similar(b, M + 1, bsize) + fill!(βe₁, 0) + βe₁[1, :] .= vec(β) # Avoid the scalar indexing error + + return SimpleGMRESCache{true}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, A, b, + β, abstol) +end + default_alias_A(::SimpleGMRES, ::Any, ::Any) = false default_alias_b(::SimpleGMRES, ::Any, ::Any) = false @@ -111,17 +152,17 @@ function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...) cache.cacheval = solver cache.isfresh = false end - return SciMLBase.solve!(cache.cacheval) + return SciMLBase.solve!(cache.cacheval, cache) end -function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T} +function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}, + lincache::LinearCache) where {T} @unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol = cache - norm2 = Base.Fix2(norm, 2) res_norm = β # FIXME: The performance for this is quite bad when compared to the KrylovJL_GMRES # version - for _ in 1:(maxiters ÷ M + 1) + for _ in 1:((maxiters ÷ M) + 1) for j in 1:M Qⱼ₊₁ = @view(Q[:, j + 1]) mul!(Qⱼ₊₁, A, @view(Q[:, j])) # Q(:,j+1) <- A Q(:, j) @@ -129,7 +170,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T} H[i, j] = dot(@view(Q[:, i]), Qⱼ₊₁) axpy!(-H[i, j], @view(Q[:, i]), Qⱼ₊₁) end - H[j + 1, j] = norm2(Qⱼ₊₁) + H[j + 1, j] = _norm2(Qⱼ₊₁) H[j + 1, j] > ϵ && (Qⱼ₊₁ ./= H[j + 1, j]) # FIXME: Figure out a way to avoid the allocation @@ -140,10 +181,10 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T} mul!(x, @view(Q[:, 1:j]), y) mul!(r, A, x, 
T(-1), T(0)) axpy!(T(1), b, r) - res_norm = norm2(r) + res_norm = _norm2(r) if res_norm < abstol - return SciMLBase.build_linear_solution(nothing, x, r, nothing; + return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache; retcode = ReturnCode.Success) end end @@ -153,6 +194,6 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}) where {T} fill!(H, zero(T)) end - return SciMLBase.build_linear_solution(nothing, x, r, nothing; + return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache; retcode = ReturnCode.MaxIters) end From f23f83c7ea66f16ffacdfac967b551f8cffeaca3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 23 Aug 2023 14:30:15 -0400 Subject: [PATCH 04/13] Add performance tip --- src/simplegmres.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/simplegmres.jl b/src/simplegmres.jl index 0799f4240..e17709d85 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -19,6 +19,15 @@ specialized dispatches. !!! warning Most users should be using the `KrylovJL_GMRES` solver instead of this implementation. + +!!! tip + + We can automatically detect if the matrix is a Block Diagonal Matrix with Uniformly + Sized Square Blocks. If this is the case, then we can use a specialized dispatch. + However, on most modern systems performing a single matrix-vector multiplication is + faster than performing multiple smaller matrix-vector multiplications (as in the case + of Block Diagonal Matrix). We recommend making the matrix dense (if size permits) and + specifying the `blocksize` argument. 
""" struct SimpleGMRES{UBD} <: AbstractKrylovSubspaceMethod restart::Int From 79fd696bd57816abb9c3c305aad16d837e5e0dae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 24 Aug 2023 18:13:45 -0400 Subject: [PATCH 05/13] Add faster GMRES version --- .JuliaFormatter.toml | 3 +- Project.toml | 6 +- ext/LinearSolveBlockDiagonalsExt.jl | 8 +- ext/LinearSolveNNlibExt.jl | 53 --- src/LinearSolve.jl | 1 + src/iterative_wrappers.jl | 2 +- src/simplegmres.jl | 613 ++++++++++++++++++++++------ 7 files changed, 499 insertions(+), 187 deletions(-) delete mode 100644 ext/LinearSolveNNlibExt.jl diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 9c7935911..0f93ea574 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,2 +1,3 @@ style = "sciml" -format_markdown = true \ No newline at end of file +format_markdown = true +annotate_untyped_fields_with_any = false \ No newline at end of file diff --git a/Project.toml b/Project.toml index 75ff2a4e1..e4c93e851 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "2.5.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641" @@ -35,7 +36,6 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" -NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" [extensions] @@ -61,7 +61,6 @@ IterativeSolvers = "0.9.2" KLU = "0.3.0, 0.4" Krylov = "0.9" KrylovKit = "0.5, 0.6" -NNlib = "0.9" PrecompileTools = "1" Preferences = "1" RecursiveFactorization = "0.2.8" @@ -83,10 +82,9 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" 
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -Metal = "dde4c033-4e86-420c-a63e-0dd931031962" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5" -NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/ext/LinearSolveBlockDiagonalsExt.jl b/ext/LinearSolveBlockDiagonalsExt.jl index 654224afc..1e9b053eb 100644 --- a/ext/LinearSolveBlockDiagonalsExt.jl +++ b/ext/LinearSolveBlockDiagonalsExt.jl @@ -2,8 +2,8 @@ module LinearSolveBlockDiagonalsExt using LinearSolve, BlockDiagonals -function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, u, Pl, Pr, - maxiters::Int, abstol, reltol, verbose, assumptions; zeroinit = true) +function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, args...; + kwargs...) @assert ndims(A) == 2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2." # We need to perform this check even when `zeroinit == true`, since the type of the # cache is dependent on whether we are able to use the specialized dispatch. @@ -17,8 +17,8 @@ function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, end end # Can't help but perform dynamic dispatch here - return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, u, Pl, Pr, maxiters, - abstol, reltol, verbose, assumptions; zeroinit, blocksize = usize) + return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, args...; + blocksize = usize, kwargs...) 
end end diff --git a/ext/LinearSolveNNlibExt.jl b/ext/LinearSolveNNlibExt.jl deleted file mode 100644 index 4598abff2..000000000 --- a/ext/LinearSolveNNlibExt.jl +++ /dev/null @@ -1,53 +0,0 @@ -module LinearSolveNNlibExt - -using LinearAlgebra, LinearSolve, NNlib -import LinearSolve: SimpleGMRESCache, SimpleGMRES, OperatorAssumptions, _no_preconditioner, - _init_cacheval, _norm2, LinearCache -import UnPack: @unpack - -function SciMLBase.solve!(cache::SimpleGMRESCache{true, T}, lincache::LinearCache) where {T} - @unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol, blocksize = cache - res_norm = β - - # FIXME: The performance for this is quite bad when compared to the KrylovJL_GMRES - # version - for _ in 1:((maxiters ÷ M) + 1) - for j in 1:M - Qⱼ₊₁ = @view(Q[:, j + 1, :]) - mul!(vec(Qⱼ₊₁), A, vec(@view(Q[:, j, :]))) # Q(:,j+1) <- A Q(:, j) - for i in 1:j - H[i, j, :] .= vec(sum(@view(Q[:, i, :]) .* Qⱼ₊₁; dims = 1)) - Qⱼ₊₁ .-= H[i:i, j, :] .* @view(Q[:, i, :]) - end - H[j + 1, j, :] .= vec(_norm2(Qⱼ₊₁, 1)) - Qⱼ₊₁ ./= H[j + 1, j:j, :] - - # FIXME: Figure out a way to avoid the allocation - # Using views doesn't work very well with LinearSolve - y = similar(b, j, 1, size(H, 3)) - for bidx in 1:size(y, 3) - y[:, :, bidx] .= @view(H[1:(j + 1), 1:j, bidx]) \ @view(βe₁[1:(j + 1), bidx]) - end - - # Update the solution - batched_mul!(reshape(x, blocksize, 1, :), @view(Q[:, 1:j, :]), y) - mul!(r, A, x, T(-1), T(0)) - r .+= b - res_norm = _norm2(reshape(r, blocksize, :), 1) - - if maximum(res_norm) < abstol - return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache; - retcode = ReturnCode.Success) - end - end - - # Restart - Q[:, 1, :] = reshape(r, blocksize, :) ./ res_norm - fill!(H, zero(T)) - end - - return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache; - retcode = ReturnCode.MaxIters) -end - -end diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index d63a6e00e..4b7946793 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ 
-29,6 +29,7 @@ using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, import GPUArraysCore import Preferences +import ConcreteStructs: @concrete # wrap import Krylov diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index 402c71609..b37571cb5 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -253,7 +253,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) Krylov.solve!(args...; M = M, kwargs...) elseif cache.cacheval isa Krylov.GmresSolver - Krylov.solve!(args...; M = M, N = N, + Krylov.solve!(args...; M = M, N = N, restart = alg.gmres_restart > 0, kwargs...) elseif cache.cacheval isa Krylov.BicgstabSolver Krylov.solve!(args...; M = M, N = N, diff --git a/src/simplegmres.jl b/src/simplegmres.jl index e17709d85..abf8f94ae 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -8,8 +8,9 @@ specialized dispatches. ## Arguments -* `restart::Int = 20`: the number of iterations before restarting. Must be a strictly - positive integer. +* `restart::Bool`: If `true`, then the solver will restart after `memory` iterations. +* `memory::Int = 20`: The number of iterations before restarting. If restart is false, this + value is used to allocate memory and later expanded if more memory is required. * `blocksize::Int = 0`: If blocksize is `> 0`, the solver assumes that the matrix has a uniformly sized block diagonal structure with square blocks of size `blocksize`. Misusing this option will lead to incorrect results. @@ -30,37 +31,44 @@ specialized dispatches. specifying the `blocksize` argument. 
""" struct SimpleGMRES{UBD} <: AbstractKrylovSubspaceMethod - restart::Int + restart::Bool + memory::Int blocksize::Int + warm_start::Bool - function SimpleGMRES(; restart::Int = 20, blocksize::Int = 0) - @assert restart≥1 "restart must be greater than or equal to 1" - return new{blocksize > 0}(restart, blocksize) + function SimpleGMRES(; restart::Bool = true, blocksize::Int = 0, + warm_start::Bool = false, memory::Int = 20) + return new{blocksize > 0}(restart, memory, blocksize, warm_start) end end -struct SimpleGMRESCache{UBD, T, QType, HType, xType, rType, βe₁Type, AType, bType, βType} - M::Int - N::Int +@concrete mutable struct SimpleGMRESCache{UBD} + memory::Int + n::Int + restart::Bool maxiters::Int blocksize::Int - ϵ::T - Q::QType - H::HType - x::xType - r::rType - βe₁::βe₁Type - A::AType - b::bType - β::βType - abstol::T - - function SimpleGMRESCache{UBD}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, A, b, β, - abstol) where {UBD} - return new{UBD, typeof(ϵ), typeof(Q), typeof(H), typeof(x), typeof(r), typeof(βe₁), - typeof(A), typeof(b), typeof(β)}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, - A, b, β, abstol) - end + ε + PlisI::Bool + PrisI::Bool + Pl + Pr + Δx + q + p + x + A + b + abstol + reltol + w + V + s + c + z + R + β + warm_start::Bool end _no_preconditioner(::Nothing) = true @@ -71,138 +79,495 @@ _no_preconditioner(_) = false _norm2(x) = norm(x, 2) _norm2(x, dims) = .√(sum(abs2, x; dims)) +default_alias_A(::SimpleGMRES, ::Any, ::Any) = false +default_alias_b(::SimpleGMRES, ::Any, ::Any) = false + +function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...) + if cache.isfresh + solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, + cache.maxiters, cache.abstol, cache.reltol, cache.verbose, + cache.assumptions; zeroinit = false) + cache.cacheval = solver + cache.isfresh = false + end + return SciMLBase.solve!(cache.cacheval, cache) +end + function init_cacheval(alg::SimpleGMRES{UDB}, args...; kwargs...) 
where {UDB} return _init_cacheval(Val(UDB), alg, args...; kwargs...) end function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int, - abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true, kwargs...) + abstol, reltol, ::Bool, ::OperatorAssumptions; zeroinit = true, kwargs...) + @unpack memory, restart, blocksize, warm_start = alg + if zeroinit - return SimpleGMRESCache{false}(0, 0, maxiters, alg.blocksize, zero(eltype(u)), - similar(b, 0, 0), similar(b, 0, 0), u, similar(b, 0), similar(b, 0), - A, b, zero(eltype(u)), abstol) + return SimpleGMRESCache{false}(memory, 0, restart, maxiters, blocksize, + zero(eltype(u)) * reltol + abstol, false, false, Pl, Pr, similar(u, 0), + similar(u, 0), similar(u, 0), u, A, b, abstol, reltol, similar(u, 0), + Vector{typeof(u)}(undef, 0), Vector{eltype(u)}(undef, 0), + Vector{eltype(u)}(undef, 0), Vector{eltype(u)}(undef, 0), + Vector{eltype(u)}(undef, 0), zero(eltype(u)), warm_start) end - @assert _no_preconditioner(Pl)&&_no_preconditioner(Pr) "Preconditioning not supported! Use KrylovJL_GMRES instead." - N = LinearAlgebra.checksquare(A) - @assert N == length(b) "The size of `A` and `b` must match." T = eltype(u) - M = min(maxiters, alg.restart) - ϵ = eps(T) - - # Initialize the Cache - ## Use `b` since `A` might be an operator - Q = similar(b, length(b), M + 1) - H = similar(b, M + 1, M) - fill!(H, zero(T)) + n = LinearAlgebra.checksquare(A) + @assert n==length(b) "The size of `A` and `b` must match." + memory = min(memory, maxiters) - mul!(@view(Q[:, 1]), A, u, T(-1), T(0)) # r0 <- A u - axpy!(T(1), b, @view(Q[:, 1])) # r0 <- r0 - b - β = _norm2(@view(Q[:, 1])) - Q[:, 1] ./= β + PlisI = _no_preconditioner(Pl) + PrisI = _no_preconditioner(Pr) + Δx = restart ? similar(u, n) : similar(u, 0) + q = PlisI ? similar(u, 0) : similar(u, n) + p = PrisI ? 
similar(u, 0) : similar(u, n) x = u - r = similar(b) - βe₁ = similar(b, M + 1) - fill!(βe₁, 0) - βe₁[1:1] .= β # Avoid the scalar indexing error - return SimpleGMRESCache{false}(M, N, maxiters, alg.blocksize, ϵ, Q, H, x, r, βe₁, A, b, - β, abstol) + w = similar(u, n) + V = [similar(u) for _ in 1:memory] + s = Vector{eltype(x)}(undef, memory) + c = Vector{eltype(x)}(undef, memory) + + z = Vector{eltype(x)}(undef, memory) + R = Vector{eltype(x)}(undef, (memory * (memory + 1)) ÷ 2) + + q = PlisI ? w : q + r₀ = PlisI ? w : q + + # Initial residual r₀. + if warm_start + mul!(w, A, Δx) + axpby!(one(T), b, -one(T), w) + restart && axpy!(one(T), Δx, x) + else + w .= b + end + PlisI || mul!(r₀, Pl, w) # r₀ = Pl(b - Ax₀) + β = _norm2(r₀) # β = ‖r₀‖₂ + + rNorm = β + ε = abstol + reltol * rNorm + + return SimpleGMRESCache{false}(memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI, + Pl, Pr, Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start) +end + +function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache) + @unpack memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI, Pl, Pr = cache + @unpack Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start = cache + + T = eltype(x) + q = PlisI ? w : q + r₀ = PlisI ? w : q + xr = restart ? Δx : x + + if β == 0 + return SciMLBase.build_linear_solution(nothing, x, r₀, nothing; + retcode = ReturnCode.Success) + end + + rNorm = β + npass = 0 # Number of pass + + iter = 0 # Cumulative number of iterations + inner_iter = 0 # Number of iterations in a pass + + # Tolerance for breakdown detection. + btol = eps(T)^(3 / 4) + + # Stopping criterion + breakdown = false + inconsistent = false + solved = rNorm ≤ ε + inner_maxiters = maxiters + tired = iter ≥ maxiters + inner_tired = inner_iter ≥ inner_maxiters + status = ReturnCode.Default + + while !(solved || tired || breakdown) + # Initialize workspace. + nr = 0 # Number of coefficients stored in Rₖ. 
+ #= TODO: Check that not zeroing out doesn't lead to incorrect results. + foreach(V) do v + v .= zero(T) # Orthogonal basis of Kₖ(MAN, Mr₀). + end + s .= zero(T) # Givens sines used for the factorization QₖRₖ = Hₖ₊₁.ₖ. + c .= zero(T) # Givens cosines used for the factorization QₖRₖ = Hₖ₊₁.ₖ. + R .= zero(T) # Upper triangular matrix Rₖ. + z .= zero(T) # Right-hand of the least squares problem min ‖Hₖ₊₁.ₖyₖ - βe₁‖₂. + =# + + if restart + xr .= zero(T) # xr === Δx when restart is set to true + if npass ≥ 1 + mul!(w, A, x) + axpby!(one(T), b, -one(T), w) + PlisI || ldiv!(r₀, Pl, w) + end + end + + # Initial ζ₁ and V₁ + β = _norm2(r₀) + z[1] = β + V[1] .= r₀ / β + + npass = npass + 1 + inner_iter = 0 + inner_tired = false + + while !(solved || inner_tired || breakdown) + # Update iteration index + inner_iter += 1 + # Update workspace if more storage is required and restart is set to false + if !restart && (inner_iter > memory) + append!(R, zeros(T, inner_iter)) + push!(s, zero(T)) + push!(c, zero(T)) + end + + # Continue the Arnoldi process. + p = PrisI ? V[inner_iter] : p + PrisI || ldiv!(p, Pr, V[inner_iter]) # p ← Nvₖ + mul!(w, A, p) # w ← ANvₖ + PlisI || ldiv!(q, Pl, w) # q ← MANvₖ + for i in 1:inner_iter + R[nr + i] = dot(V[i], q) # hᵢₖ = (vᵢ)ᴴq + axpy!(-R[nr + i], V[i], q) # q ← q - hᵢₖvᵢ + end + + # Compute hₖ₊₁.ₖ + Hbis = _norm2(q) # hₖ₊₁.ₖ = ‖vₖ₊₁‖₂ + + # Update the QR factorization of Hₖ₊₁.ₖ. + # Apply previous Givens reflections Ωᵢ. + # [cᵢ sᵢ] [ r̄ᵢ.ₖ ] = [ rᵢ.ₖ ] + # [s̄ᵢ -cᵢ] [rᵢ₊₁.ₖ] [r̄ᵢ₊₁.ₖ] + for i in 1:(inner_iter - 1) + Rtmp = c[i] * R[nr + i] + s[i] * R[nr + i + 1] + R[nr + i + 1] = conj(s[i]) * R[nr + i] - c[i] * R[nr + i + 1] + R[nr + i] = Rtmp + end + + # Compute and apply current Givens reflection Ωₖ. 
+ # [cₖ sₖ] [ r̄ₖ.ₖ ] = [rₖ.ₖ] + # [s̄ₖ -cₖ] [hₖ₊₁.ₖ] [ 0 ] + (c[inner_iter], s[inner_iter], R[nr + inner_iter]) = Krylov.sym_givens(R[nr + inner_iter], + Hbis) + + # Update zₖ = (Qₖ)ᴴβe₁ + ζₖ₊₁ = conj(s[inner_iter]) * z[inner_iter] + z[inner_iter] = c[inner_iter] * z[inner_iter] + + # Update residual norm estimate. + # ‖ Pl(b - Axₖ) ‖₂ = |ζₖ₊₁| + rNorm = abs(ζₖ₊₁) + + # Update the number of coefficients in Rₖ + nr = nr + inner_iter + + # Stopping conditions that do not depend on user input. + # This is to guard against tolerances that are unreasonably small. + resid_decrease_mach = (rNorm + one(T) ≤ one(T)) + + # Update stopping criterion. + resid_decrease_lim = rNorm ≤ ε + breakdown = Hbis ≤ btol + solved = resid_decrease_lim || resid_decrease_mach + inner_tired = restart ? inner_iter ≥ min(memory, inner_maxiters) : + inner_iter ≥ inner_maxiters + + # Compute vₖ₊₁. + if !(solved || inner_tired || breakdown) + if !restart && (inner_iter ≥ memory) + push!(V, similar(first(V))) + push!(z, zero(T)) + end + @. V[inner_iter + 1] = q / Hbis # hₖ₊₁.ₖvₖ₊₁ = q + z[inner_iter + 1] = ζₖ₊₁ + end + end + + # Compute yₖ by solving Rₖyₖ = zₖ with backward substitution. + y = z # yᵢ = zᵢ + for i in inner_iter:-1:1 + pos = nr + i - inner_iter # position of rᵢ.ₖ + for j in inner_iter:-1:(i + 1) + y[i] = y[i] - R[pos] * y[j] # yᵢ ← yᵢ - rᵢⱼyⱼ + pos = pos - j + 1 # position of rᵢ.ⱼ₋₁ + end + # Rₖ can be singular if the system is inconsistent + if abs(R[pos]) ≤ btol + y[i] = zero(T) + inconsistent = true + else + y[i] = y[i] / R[pos] # yᵢ ← yᵢ / rᵢᵢ + end + end + + # Form xₖ = NVₖyₖ + for i in 1:inner_iter + axpy!(y[i], V[i], xr) + end + if !PrisI + p .= xr + ldiv!(xr, Pr, p) + end + restart && axpy!(one(T), xr, x) + + # Update inner_itmax, iter, tired and overtimed variables. 
+ inner_maxiters = inner_maxiters - inner_iter + iter = iter + inner_iter + tired = iter ≥ maxiters + end + + # Termination status + tired && (status = ReturnCode.MaxIters) + solved && (status = ReturnCode.Success) + inconsistent && (status = ReturnCode.Infeasible) + + # Update x + warm_start && !restart && axpy!(one(T), Δx, x) + cache.warm_start = false + + return SciMLBase.build_linear_solution(lincache.alg, x, rNorm, lincache; + retcode = status) end function _init_cacheval(::Val{true}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int, - abstol, ::Any, ::Bool, ::OperatorAssumptions; zeroinit = true, + abstol, reltol, ::Bool, ::OperatorAssumptions; zeroinit = true, blocksize = alg.blocksize) + @unpack memory, restart, warm_start = alg + if zeroinit - return SimpleGMRESCache{true}(0, 0, maxiters, alg.blocksize, zero(eltype(u)), - similar(b, 0, 0, 0), similar(b, 0, 0, 0), u, similar(b, 0), similar(b, 0, 0), - A, b, similar(b, 0, 0), abstol) + return SimpleGMRESCache{true}(memory, 0, restart, maxiters, blocksize, + zero(eltype(u)) * reltol + abstol, false, false, Pl, Pr, similar(u, 0), + similar(u, 0), similar(u, 0), u, A, b, abstol, reltol, similar(u, 0), + [u], [u], [u], [u], [u], zero(eltype(u)), warm_start) end - @assert _no_preconditioner(Pl)&&_no_preconditioner(Pr) "Preconditioning not supported! Use KrylovJL_GMRES instead." - N = LinearAlgebra.checksquare(A) - @assert mod(N, blocksize)==0 "The blocksize must divide the size of the matrix." - @assert N==length(b) "The size of `A` and `b` must match." T = eltype(u) - M = min(maxiters, alg.restart) - ϵ = eps(T) - bsize = N ÷ blocksize + n = LinearAlgebra.checksquare(A) + @assert mod(n, blocksize)==0 "The blocksize must divide the size of the matrix." + @assert n==length(b) "The size of `A` and `b` must match." + memory = min(memory, maxiters) + bsize = n ÷ blocksize + + PlisI = _no_preconditioner(Pl) + PrisI = _no_preconditioner(Pr) + + Δx = restart ? similar(u, n) : similar(u, 0) + q = PlisI ? 
similar(u, 0) : similar(u, n) + p = PrisI ? similar(u, 0) : similar(u, n) + x = u - # Initialize the Cache - ## Use `b` since `A` might be an operator - Q = similar(b, blocksize, M + 1, bsize) - H = similar(b, M + 1, M, bsize) - fill!(H, zero(T)) + w = similar(u, n) + V = [similar(u) for _ in 1:memory] + s = [similar(u, bsize) for _ in 1:memory] + c = [similar(u, bsize) for _ in 1:memory] - mul!(vec(@view(Q[:, 1, :])), A, u, T(-1), T(0)) # r0 <- A u - axpy!(T(1), b, vec(@view(Q[:, 1, :]))) # r0 <- r0 - b - β = _norm2(@view(Q[:, 1, :]), 1) - Q[:, 1, :] ./= β + z = [similar(u, bsize) for _ in 1:memory] + R = [similar(u, bsize) for _ in 1:((memory * (memory + 1)) ÷ 2)] - x = u - r = similar(b) - βe₁ = similar(b, M + 1, bsize) - fill!(βe₁, 0) - βe₁[1, :] .= vec(β) # Avoid the scalar indexing error + q = PlisI ? w : q + r₀ = PlisI ? w : q + + # Initial residual r₀. + if warm_start + mul!(w, A, Δx) + axpby!(one(T), b, -one(T), w) + restart && axpy!(one(T), Δx, x) + else + w .= b + end + PlisI || ldiv!(r₀, Pl, w) # r₀ = Pl(b - Ax₀) + β = _norm2(r₀) # β = ‖r₀‖₂ - return SimpleGMRESCache{true}(M, N, maxiters, blocksize, ϵ, Q, H, x, r, βe₁, A, b, - β, abstol) + rNorm = β + ε = abstol + reltol * rNorm + + return SimpleGMRESCache{true}(memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI, + Pl, Pr, Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start) end -default_alias_A(::SimpleGMRES, ::Any, ::Any) = false -default_alias_b(::SimpleGMRES, ::Any, ::Any) = false +function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache) + @unpack memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI, Pl, Pr = cache + @unpack Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start = cache + bsize = n ÷ blocksize -function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...) 
- if cache.isfresh - solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, - cache.maxiters, cache.abstol, cache.reltol, cache.verbose, - cache.assumptions; zeroinit = false) - cache.cacheval = solver - cache.isfresh = false + __batch = Base.Fix2(reshape, (blocksize, bsize)) + + T = eltype(x) + q = PlisI ? w : q + r₀ = PlisI ? w : q + xr = restart ? Δx : x + + if β == 0 + return SciMLBase.build_linear_solution(nothing, x, r₀, nothing; + retcode = ReturnCode.Success) end - return SciMLBase.solve!(cache.cacheval, cache) -end -function SciMLBase.solve!(cache::SimpleGMRESCache{false, T}, - lincache::LinearCache) where {T} - @unpack M, N, maxiters, ϵ, Q, H, x, r, βe₁, A, b, β, abstol = cache - res_norm = β - - # FIXME: The performance for this is quite bad when compared to the KrylovJL_GMRES - # version - for _ in 1:((maxiters ÷ M) + 1) - for j in 1:M - Qⱼ₊₁ = @view(Q[:, j + 1]) - mul!(Qⱼ₊₁, A, @view(Q[:, j])) # Q(:,j+1) <- A Q(:, j) - for i in 1:j - H[i, j] = dot(@view(Q[:, i]), Qⱼ₊₁) - axpy!(-H[i, j], @view(Q[:, i]), Qⱼ₊₁) + rNorm = β + npass = 0 # Number of pass + + iter = 0 # Cumulative number of iterations + inner_iter = 0 # Number of iterations in a pass + + # Tolerance for breakdown detection. + btol = eps(T)^(3 / 4) + + # Stopping criterion + breakdown = false + inconsistent = false + solved = rNorm ≤ ε + inner_maxiters = maxiters + tired = iter ≥ maxiters + inner_tired = inner_iter ≥ inner_maxiters + status = ReturnCode.Default + + while !(solved || tired || breakdown) + # Initialize workspace. + # TODO: Check that not zeroing out (V, s, c, R, z) doesn't lead to incorrect results. + nr = 0 # Number of coefficients stored in Rₖ. 
+ + if restart + xr .= zero(T) # xr === Δx when restart is set to true + if npass ≥ 1 + mul!(w, A, x) + axpby!(one(T), b, -one(T), w) + PlisI || ldiv!(r₀, Pl, w) + end + end + + # Initial ζ₁ and V₁ + β = _norm2(__batch(r₀), 1) + z[1] .= vec(β) + V[1] .= vec(__batch(r₀) ./ β) + + npass = npass + 1 + inner_iter = 0 + inner_tired = false + + while !(solved || inner_tired || breakdown) + # Update iteration index + inner_iter += 1 + # Update workspace if more storage is required and restart is set to false + if !restart && (inner_iter > memory) + append!(R, [similar(first(R), bsize) for _ in 1:inner_iter]) + push!(s, similar(first(s), bsize)) + push!(c, similar(first(c), bsize)) + end + + # Continue the Arnoldi process. + p = PrisI ? V[inner_iter] : p + PrisI || ldiv!(p, Pr, V[inner_iter]) # p ← Nvₖ + mul!(w, A, p) # w ← ANvₖ + PlisI || ldiv!(q, Pl, w) # q ← MANvₖ + for i in 1:inner_iter + R[nr + i] .= vec(sum(__batch(V[i]) .* __batch(q); dims = 1)) # hᵢₖ = (vᵢ)ᴴq + q .-= vec(R[nr + i]' .* __batch(V[i])) # q ← q - hᵢₖvᵢ + end + + # Compute hₖ₊₁.ₖ + Hbis = vec(_norm2(__batch(q), 1)) # hₖ₊₁.ₖ = ‖vₖ₊₁‖₂ + + # Update the QR factorization of Hₖ₊₁.ₖ. + # Apply previous Givens reflections Ωᵢ. + # [cᵢ sᵢ] [ r̄ᵢ.ₖ ] = [ rᵢ.ₖ ] + # [s̄ᵢ -cᵢ] [rᵢ₊₁.ₖ] [r̄ᵢ₊₁.ₖ] + for i in 1:(inner_iter - 1) + Rtmp = c[i] .* R[nr + i] .+ s[i] .* R[nr + i + 1] + R[nr + i + 1] .= conj.(s[i]) .* R[nr + i] .- c[i] .* R[nr + i + 1] + R[nr + i] .= Rtmp + end + + # Compute and apply current Givens reflection Ωₖ. + # [cₖ sₖ] [ r̄ₖ.ₖ ] = [rₖ.ₖ] + # [s̄ₖ -cₖ] [hₖ₊₁.ₖ] [ 0 ] + # FIXME: Write inplace kernel + __res = Krylov.sym_givens.(R[nr + inner_iter], Hbis) + foreach(1:bsize) do i + c[inner_iter][i] = __res[i][1] + s[inner_iter][i] = __res[i][2] + R[nr + inner_iter][i] = __res[i][3] + end + + # Update zₖ = (Qₖ)ᴴβe₁ + ζₖ₊₁ = conj.(s[inner_iter]) .* z[inner_iter] + z[inner_iter] .= c[inner_iter] .* z[inner_iter] + + # Update residual norm estimate. 
+ # ‖ Pl(b - Axₖ) ‖₂ = |ζₖ₊₁| + rNorm = maximum(abs, ζₖ₊₁) + + # Update the number of coefficients in Rₖ + nr = nr + inner_iter + + # Stopping conditions that do not depend on user input. + # This is to guard against tolerances that are unreasonably small. + resid_decrease_mach = (rNorm + one(T) ≤ one(T)) + + # Update stopping criterion. + resid_decrease_lim = rNorm ≤ ε + breakdown = maximum(Hbis) ≤ btol + solved = resid_decrease_lim || resid_decrease_mach + inner_tired = restart ? inner_iter ≥ min(memory, inner_maxiters) : + inner_iter ≥ inner_maxiters + + # Compute vₖ₊₁. + if !(solved || inner_tired || breakdown) + if !restart && (inner_iter ≥ memory) + push!(V, similar(first(V))) + push!(z, similar(first(z), bsize)) + end + V[inner_iter + 1] .= vec(__batch(q) ./ Hbis') # hₖ₊₁.ₖvₖ₊₁ = q + z[inner_iter + 1] .= ζₖ₊₁ + end + end + + # Compute yₖ by solving Rₖyₖ = zₖ with backward substitution. + y = z # yᵢ = zᵢ + for i in inner_iter:-1:1 + pos = nr + i - inner_iter # position of rᵢ.ₖ + for j in inner_iter:-1:(i + 1) + y[i] .= y[i] .- R[pos] .* y[j] # yᵢ ← yᵢ - rᵢⱼyⱼ + pos = pos - j + 1 # position of rᵢ.ⱼ₋₁ end - H[j + 1, j] = _norm2(Qⱼ₊₁) - H[j + 1, j] > ϵ && (Qⱼ₊₁ ./= H[j + 1, j]) - - # FIXME: Figure out a way to avoid the allocation - # Using views doesn't work very well with LinearSolve - y = @view(H[1:(j + 1), 1:j]) \ @view(βe₁[1:(j + 1)]) - - # Update the solution - mul!(x, @view(Q[:, 1:j]), y) - mul!(r, A, x, T(-1), T(0)) - axpy!(T(1), b, r) - res_norm = _norm2(r) - - if res_norm < abstol - return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache; - retcode = ReturnCode.Success) + # Rₖ can be singular if the system is inconsistent + # FIXME: Write with broadcasting + GPUArraysCore.@allowscalar foreach(1:bsize) do B + if abs(R[pos][B]) ≤ btol + y[i][B] = zero(T) + inconsistent = true + else + y[i][B] /= R[pos][B] + end end end - # Restart - Q[:, 1] = r ./ res_norm - fill!(H, zero(T)) + # Form xₖ = NVₖyₖ + for i in 1:inner_iter + xr .+= 
vec(__batch(V[i]) .* y[i]') + end + if !PrisI + p .= xr + ldiv!(xr, Pr, p) + end + restart && axpy!(one(T), xr, x) + + # Update inner_itmax, iter, tired and overtimed variables. + inner_maxiters = inner_maxiters - inner_iter + iter = iter + inner_iter + tired = iter ≥ maxiters end - return SciMLBase.build_linear_solution(lincache.alg, x, r, lincache; - retcode = ReturnCode.MaxIters) + # Termination status + tired && (status = ReturnCode.MaxIters) + solved && (status = ReturnCode.Success) + inconsistent && (status = ReturnCode.Infeasible) + + # Update x + warm_start && !restart && axpy!(one(T), Δx, x) + + return SciMLBase.build_linear_solution(lincache.alg, x, rNorm, lincache; + retcode = status) end From 07fc3fad6004fc950a4d0d081742d6d252c8098e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 24 Aug 2023 18:15:27 -0400 Subject: [PATCH 06/13] Remove NNlib --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index e4c93e851..a013c7e86 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,6 @@ LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKrylovKitExt = "KrylovKit" LinearSolveMKLExt = "MKL_jll" LinearSolveMetalExt = "Metal" -LinearSolveNNlibExt = "NNlib" LinearSolvePardisoExt = "Pardiso" [compat] From bc748931e28dbb00aabafeee9e2b285941657dac Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Aug 2023 12:10:20 -0400 Subject: [PATCH 07/13] Use a local copy of sym_givens --- src/simplegmres.jl | 44 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/src/simplegmres.jl b/src/simplegmres.jl index abf8f94ae..06637a9dc 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -1,5 +1,6 @@ """ - SimpleGMRES(; restart::Int = 20, blocksize::Int = 0) + SimpleGMRES(; restart::Bool = true, blocksize::Int = 0, warm_start::Bool = false, + memory::Int = 20) A simple GMRES implementation for square non-Hermitian linear systems. 
@@ -71,6 +72,39 @@ end warm_start::Bool end +""" + (c, s, ρ) = _sym_givens(a, b) + +Numerically stable symmetric Givens reflection. +Given `a` and `b` reals, return `(c, s, ρ)` such that + + [ c s ] [ a ] = [ ρ ] + [ s -c ] [ b ] = [ 0 ]. +""" +function _sym_givens(a::T, b::T) where {T <: AbstractFloat} + # This has been taken from Krylov.jl + if b == 0 + c = ifelse(a == 0, one(T), sign(a)) # In Julia, sign(0) = 0. + s = zero(T) + ρ = abs(a) + elseif a == 0 + c = zero(T) + s = sign(b) + ρ = abs(b) + elseif abs(b) > abs(a) + t = a / b + s = sign(b) / sqrt(one(T) + t * t) + c = s * t + ρ = b / s # Computationally better than ρ = a / c since |c| ≤ |s|. + else + t = b / a + c = sign(a) / sqrt(one(T) + t * t) + s = c * t + ρ = a / c # Computationally better than ρ = b / s since |s| ≤ |c| + end + return (c, s, ρ) +end + _no_preconditioner(::Nothing) = true _no_preconditioner(::IdentityOperator) = true _no_preconditioner(::UniformScaling) = true @@ -162,7 +196,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache) xr = restart ? Δx : x if β == 0 - return SciMLBase.build_linear_solution(nothing, x, r₀, nothing; + return SciMLBase.build_linear_solution(lincache.alg, x, r₀, lincache; retcode = ReturnCode.Success) end @@ -251,7 +285,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache) # Compute and apply current Givens reflection Ωₖ. # [cₖ sₖ] [ r̄ₖ.ₖ ] = [rₖ.ₖ] # [s̄ₖ -cₖ] [hₖ₊₁.ₖ] [ 0 ] - (c[inner_iter], s[inner_iter], R[nr + inner_iter]) = Krylov.sym_givens(R[nr + inner_iter], + (c[inner_iter], s[inner_iter], R[nr + inner_iter]) = _sym_givens(R[nr + inner_iter], Hbis) # Update zₖ = (Qₖ)ᴴβe₁ @@ -402,7 +436,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache) xr = restart ? 
Δx : x if β == 0 - return SciMLBase.build_linear_solution(nothing, x, r₀, nothing; + return SciMLBase.build_linear_solution(lincache.alg, x, r₀, lincache; retcode = ReturnCode.Success) end @@ -484,7 +518,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache) # [cₖ sₖ] [ r̄ₖ.ₖ ] = [rₖ.ₖ] # [s̄ₖ -cₖ] [hₖ₊₁.ₖ] [ 0 ] # FIXME: Write inplace kernel - __res = Krylov.sym_givens.(R[nr + inner_iter], Hbis) + __res = _sym_givens.(R[nr + inner_iter], Hbis) foreach(1:bsize) do i c[inner_iter][i] = __res[i][1] s[inner_iter][i] = __res[i][2] From 19312b10fe80eec2920b63956556b751c8552464 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Aug 2023 13:21:31 -0400 Subject: [PATCH 08/13] Use KA.jl for faster sym_givens --- Project.toml | 4 +++ ext/LinearSolveKernelAbstractionsExt.jl | 24 ++++++++++++++ src/LinearSolve.jl | 5 +++ src/simplegmres.jl | 42 ++++++++++--------------- 4 files changed, 50 insertions(+), 25 deletions(-) create mode 100644 ext/LinearSolveKernelAbstractionsExt.jl diff --git a/Project.toml b/Project.toml index a013c7e86..90395cd70 100644 --- a/Project.toml +++ b/Project.toml @@ -33,6 +33,7 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" @@ -43,6 +44,7 @@ LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" +LinearSolveKernelAbstractionsExt = "KernelAbstractions" LinearSolveKrylovKitExt = "KrylovKit" LinearSolveMKLExt = "MKL_jll" LinearSolveMetalExt = "Metal" @@ -58,6 +60,7 @@ GPUArraysCore = "0.1" HYPRE = "1.4.0" IterativeSolvers = "0.9.2" KLU = "0.3.0, 0.4" 
+KernelAbstractions = "0.9" Krylov = "0.9" KrylovKit = "0.5, 0.6" PrecompileTools = "1" @@ -79,6 +82,7 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" diff --git a/ext/LinearSolveKernelAbstractionsExt.jl b/ext/LinearSolveKernelAbstractionsExt.jl new file mode 100644 index 000000000..ba620382f --- /dev/null +++ b/ext/LinearSolveKernelAbstractionsExt.jl @@ -0,0 +1,24 @@ +module LinearSolveKernelAbstractionsExt + +using LinearSolve, KernelAbstractions + +LinearSolve.__is_extension_loaded(::Val{:KernelAbstractions}) = true + +using GPUArraysCore + +function LinearSolve._fast_sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int, Hbis) + backend = get_backend(Hbis) + kernel! = __fast_sym_givens_kernel!(backend) + kernel!(c[inner_iter], s[inner_iter], R[nr + inner_iter], Hbis; ndrange=bsize) + return c, s, R +end + +@kernel function __fast_sym_givens_kernel!(c, s, R, @Const(Hbis)) + idx = @index(Global) + @inbounds _c, _s, _ρ = LinearSolve._sym_givens(R[idx], Hbis[idx]) + @inbounds c[idx] = _c + @inbounds s[idx] = _s + @inbounds R[idx] = _ρ +end + +end diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 4b7946793..d904d6878 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -56,6 +56,11 @@ _isidentity_struct(λ::Number) = isone(λ) _isidentity_struct(A::UniformScaling) = isone(A.λ) _isidentity_struct(::SciMLOperators.IdentityOperator) = true +# Dispatch Friendly way to check if an extension is loaded +__is_extension_loaded(::Val) = false + +function _fast_sym_givens! 
end + # Code const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS) diff --git a/src/simplegmres.jl b/src/simplegmres.jl index 06637a9dc..a4e0c9db5 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -105,6 +105,19 @@ function _sym_givens(a::T, b::T) where {T <: AbstractFloat} return (c, s, ρ) end +function _sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int, Hbis) + if __is_extension_loaded(Val(:KernelAbstractions)) + return _fast_sym_givens!(c, s, R, nr, inner_iter, bsize, Hbis) + end + __res = _sym_givens.(R[nr + inner_iter], Hbis) + GPUArraysCore.@allowscalar foreach(1:bsize) do i + c[inner_iter][i] = __res[i][1] + s[inner_iter][i] = __res[i][2] + R[nr + inner_iter][i] = __res[i][3] + end + return c, s, R +end + _no_preconditioner(::Nothing) = true _no_preconditioner(::IdentityOperator) = true _no_preconditioner(::UniformScaling) = true @@ -221,15 +234,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache) while !(solved || tired || breakdown) # Initialize workspace. nr = 0 # Number of coefficients stored in Rₖ. - #= TODO: Check that not zeroing out doesn't lead to incorrect results. - foreach(V) do v - v .= zero(T) # Orthogonal basis of Kₖ(MAN, Mr₀). - end - s .= zero(T) # Givens sines used for the factorization QₖRₖ = Hₖ₊₁.ₖ. - c .= zero(T) # Givens cosines used for the factorization QₖRₖ = Hₖ₊₁.ₖ. - R .= zero(T) # Upper triangular matrix Rₖ. - z .= zero(T) # Right-hand of the least squares problem min ‖Hₖ₊₁.ₖyₖ - βe₁‖₂. - =# + # TODO: Check that not zeroing out doesn't lead to incorrect results. if restart xr .= zero(T) # xr === Δx when restart is set to true @@ -517,13 +522,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache) # Compute and apply current Givens reflection Ωₖ. 
# [cₖ sₖ] [ r̄ₖ.ₖ ] = [rₖ.ₖ] # [s̄ₖ -cₖ] [hₖ₊₁.ₖ] [ 0 ] - # FIXME: Write inplace kernel - __res = _sym_givens.(R[nr + inner_iter], Hbis) - foreach(1:bsize) do i - c[inner_iter][i] = __res[i][1] - s[inner_iter][i] = __res[i][2] - R[nr + inner_iter][i] = __res[i][3] - end + _sym_givens!(c, s, R, nr, inner_iter, bsize, Hbis) # Update zₖ = (Qₖ)ᴴβe₁ ζₖ₊₁ = conj.(s[inner_iter]) .* z[inner_iter] @@ -567,15 +566,8 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache) pos = pos - j + 1 # position of rᵢ.ⱼ₋₁ end # Rₖ can be singular if the system is inconsistent - # FIXME: Write with broadcasting - GPUArraysCore.@allowscalar foreach(1:bsize) do B - if abs(R[pos][B]) ≤ btol - y[i][B] = zero(T) - inconsistent = true - else - y[i][B] /= R[pos][B] - end - end + y[i] .= ifelse.(abs.(R[pos]) .≤ btol, zero(T), y[i] ./ R[pos]) # yᵢ ← yᵢ / rᵢᵢ + inconsistent = any(abs.(R[pos]) .≤ btol) end # Form xₖ = NVₖyₖ From d354ab1e9749978bc67a6f88319b759f2d577d03 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Aug 2023 15:13:08 -0400 Subject: [PATCH 09/13] Use inplace accumulation --- src/simplegmres.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simplegmres.jl b/src/simplegmres.jl index a4e0c9db5..c2a81b722 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -502,7 +502,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache) mul!(w, A, p) # w ← ANvₖ PlisI || ldiv!(q, Pl, w) # q ← MANvₖ for i in 1:inner_iter - R[nr + i] .= vec(sum(__batch(V[i]) .* __batch(q); dims = 1)) # hᵢₖ = (vᵢ)ᴴq + sum!(R[nr + i]', __batch(V[i]) .* __batch(q)) q .-= vec(R[nr + i]' .* __batch(V[i])) # q ← q - hᵢₖvᵢ end From d738dcc57fc580ca3c66868151ad3f0b41b0ce73 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Aug 2023 17:25:53 -0400 Subject: [PATCH 10/13] Propagate Iteration Statistics --- src/simplegmres.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/simplegmres.jl b/src/simplegmres.jl 
index c2a81b722..1ebd2ea2a 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -369,7 +369,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache) cache.warm_start = false return SciMLBase.build_linear_solution(lincache.alg, x, rNorm, lincache; - retcode = status) + retcode = status, iters = iter) end function _init_cacheval(::Val{true}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int, @@ -595,5 +595,5 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache) warm_start && !restart && axpy!(one(T), Δx, x) return SciMLBase.build_linear_solution(lincache.alg, x, rNorm, lincache; - retcode = status) + retcode = status, iters = iter) end From b1e4bf4335db71a0c6544989dfa6eac3f0884df4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Aug 2023 17:44:23 -0400 Subject: [PATCH 11/13] Add tests --- Project.toml | 2 +- src/simplegmres.jl | 2 ++ test/basictests.jl | 31 ++++++++++++++++++++++++++++++- test/gpu/Project.toml | 1 + test/gpu/cuda.jl | 24 +++++++++++++++++++++++- 5 files changed, 57 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 90395cd70..2ff81fb08 100644 --- a/Project.toml +++ b/Project.toml @@ -94,4 +94,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll"] +test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"] diff --git a/src/simplegmres.jl b/src/simplegmres.jl index 1ebd2ea2a..9ab482d06 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -169,6 +169,7 @@ function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiter q = PlisI ? similar(u, 0) : similar(u, n) p = PrisI ? 
similar(u, 0) : similar(u, n) x = u + x .= zero(T) w = similar(u, n) V = [similar(u) for _ in 1:memory] @@ -398,6 +399,7 @@ function _init_cacheval(::Val{true}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters q = PlisI ? similar(u, 0) : similar(u, n) p = PrisI ? similar(u, 0) : similar(u, n) x = u + x .= zero(T) w = similar(u, n) V = [similar(u) for _ in 1:memory] diff --git a/test/basictests.jl b/test/basictests.jl index a58d4987d..401568d19 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -235,6 +235,10 @@ end end end + @testset "Simple GMRES: restart = $restart" for restart in (true, false) + test_interface(SimpleGMRES(; restart), prob1, prob2) + end + @testset "KrylovJL" begin kwargs = (; gmres_restart = 5) for alg in (("Default", KrylovJL(kwargs...)), @@ -412,7 +416,7 @@ end @testset "DirectLdiv!" begin function get_operator(A, u; add_inverse = true) - + function f(u, p, t) println("using FunctionOperator OOP mul") A * u @@ -470,3 +474,28 @@ lp = LinearProblem(A, b; u0 = view(u0, :)); truesol = solve(lp, LUFactorization()) krylovsol = solve(lp, KrylovJL_GMRES()) @test truesol ≈ krylovsol + +# Block Diagonals +using BlockDiagonals + +@testset "Block Diagonal Specialization" begin + A = BlockDiagonal([rand(2, 2) for _ in 1:3]) + b = rand(size(A, 1)) + + if VERSION > v"1.9-" + x1 = zero(b) + x2 = zero(b) + prob1 = LinearProblem(A, b, x1) + prob2 = LinearProblem(A, b, x2) + test_interface(SimpleGMRES(), prob1, prob2) + end + + x1 = zero(b) + x2 = zero(b) + prob1 = LinearProblem(Array(A), b, x1) + prob2 = LinearProblem(Array(A), b, x2) + + test_interface(SimpleGMRES(; blocksize=2), prob1, prob2) + + @test solve(prob1, SimpleGMRES(; blocksize=2)).u ≈ solve(prob2, SimpleGMRES()).u +end diff --git a/test/gpu/Project.toml b/test/gpu/Project.toml index 8ea63055c..7fc6e3847 100644 --- a/test/gpu/Project.toml +++ b/test/gpu/Project.toml @@ -1,4 +1,5 @@ [deps] +BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" 
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" diff --git a/test/gpu/cuda.jl b/test/gpu/cuda.jl index 9ae035501..042576300 100644 --- a/test/gpu/cuda.jl +++ b/test/gpu/cuda.jl @@ -42,10 +42,32 @@ function test_interface(alg, prob1, prob2) return end -test_interface(CudaOffloadFactorization(), prob1, prob2) +@testset "CudaOffloadFactorization" begin + test_interface(CudaOffloadFactorization(), prob1, prob2) +end + +@testset "Simple GMRES: restart = $restart" for restart in (true, false) + test_interface(SimpleGMRES(; restart), prob1, prob2) +end A1 = prob1.A; b1 = prob1.b; x1 = prob1.u0; y = solve(prob1) @test A1 * y ≈ b1 + +using BlockDiagonals + +@testset "Block Diagonal Specialization" begin + A = BlockDiagonal([rand(2, 2) for _ in 1:3]) |> cu + b = rand(size(A, 1)) |> cu + + x1 = zero(b) + x2 = zero(b) + prob1 = LinearProblem(A, b, x1) + prob2 = LinearProblem(A, b, x2) + + test_interface(SimpleGMRES(; blocksize=2), prob1, prob2) + + @test solve(prob1, SimpleGMRES(; blocksize=2)).u ≈ solve(prob2, SimpleGMRES()).u +end From f6155f6e68cb030c9b52c6c493d16edb58eb30f7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Aug 2023 17:49:25 -0400 Subject: [PATCH 12/13] Add documentation entry --- docs/src/solvers/solvers.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/src/solvers/solvers.md b/docs/src/solvers/solvers.md index f7e52c2a7..fe8a8b6e4 100644 --- a/docs/src/solvers/solvers.md +++ b/docs/src/solvers/solvers.md @@ -72,6 +72,14 @@ choice of Krylov method should be the one most constrained to the type of operat has, for example if positive definite then `Krylov_CG()`, but if no good properties then use `Krylov_GMRES()`. +!!! tip + + If your materialized operator is a uniform block diagonal matrix, then you can use + `SimpleGMRES(; blocksize = <known block size>)` to further improve performance.
+ This often shows up in Neural Networks where the Jacobian wrt the Inputs (almost always) + is a Uniform Block Diagonal matrix of Block Size = size of the input divided by the + batch size. + ## Full List of Methods ### RecursiveFactorization.jl @@ -106,6 +114,7 @@ LinearSolve.jl contains some linear solvers built in for specailized cases. ```@docs SimpleLUFactorization DiagonalFactorization +SimpleGMRES ``` ### FastLapackInterface.jl From 38191f64579cb0b16d2135f9d8932908079ce0b2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Sep 2023 16:12:36 -0400 Subject: [PATCH 13/13] Remove comment --- src/simplegmres.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/simplegmres.jl b/src/simplegmres.jl index 9ab482d06..924cdaeea 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -235,7 +235,6 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache) while !(solved || tired || breakdown) # Initialize workspace. nr = 0 # Number of coefficients stored in Rₖ. - # TODO: Check that not zeroing out doesn't lead to incorrect results. if restart xr .= zero(T) # xr === Δx when restart is set to true