Skip to content

Commit

Permalink
Merge pull request #366 from avik-pal/ap/simplegmres
Browse files Browse the repository at this point in the history
Add SimpleGMRES implementation
  • Loading branch information
ChrisRackauckas authored Sep 20, 2023
2 parents 60ae26a + 38e3401 commit ac2b5aa
Show file tree
Hide file tree
Showing 12 changed files with 738 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
style = "sciml"
format_markdown = true
format_markdown = true
annotate_untyped_fields_with_any = false
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
Manifest.toml

*.swp
.vscode
wip
13 changes: 11 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "2.6.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
Expand All @@ -28,32 +29,38 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"

[extensions]
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMKLExt = "MKL_jll"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"

[compat]
ArrayInterface = "7.4.11"
BlockDiagonals = "0.1"
DocStringExtensions = "0.8, 0.9"
EnumX = "1"
FastLapackInterface = "1, 2"
GPUArraysCore = "0.1"
HYPRE = "1.4.0"
IterativeSolvers = "0.9.2"
KLU = "0.3.0, 0.4"
KernelAbstractions = "0.9"
Krylov = "0.9"
KrylovKit = "0.5, 0.6"
PrecompileTools = "1"
Expand All @@ -69,20 +76,22 @@ UnPack = "1"
julia = "1.6"

[extras]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
9 changes: 9 additions & 0 deletions docs/src/solvers/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ choice of Krylov method should be the one most constrained to the type of operat
has. For example, if the operator is positive definite, use `KrylovJL_CG()`; if it has no
special properties, use `KrylovJL_GMRES()`.

!!! tip

    If your materialized operator is a uniform block diagonal matrix, then you can use
    `SimpleGMRES(; blocksize = <known block size>)` to further improve performance.
    This often shows up in neural networks, where the Jacobian with respect to the inputs
    is (almost always) a uniform block diagonal matrix whose block size equals the size of
    the input divided by the batch size.

## Full List of Methods

### RecursiveFactorization.jl
Expand Down Expand Up @@ -106,6 +114,7 @@ LinearSolve.jl contains some linear solvers built in for specialized cases.
```@docs
SimpleLUFactorization
DiagonalFactorization
SimpleGMRES
```

### FastLapackInterface.jl
Expand Down
24 changes: 24 additions & 0 deletions ext/LinearSolveBlockDiagonalsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module LinearSolveBlockDiagonalsExt

using LinearSolve, BlockDiagonals

# Cache initialisation for `SimpleGMRES{false}` when the operator is a `BlockDiagonal`.
#
# Inspects the block sizes of `A` to decide whether every block is square with the same
# side length as the first block, then forwards to `LinearSolve._init_cacheval` with that
# decision lifted into a `Val` and the first block's side length passed as `blocksize`.
function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, args...;
        kwargs...)
    @assert ndims(A) == 2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2."
    # The uniformity check has to run even when `zeroinit == true`: the type of the
    # cache depends on whether the specialized dispatch can be used.
    block_dims = blocksizes(A)
    usize = first(first(block_dims))
    is_uniform = all(dims -> dims[1] == usize && dims[2] == usize, block_dims)
    # `Val(is_uniform)` is built from a runtime value, so this call is a dynamic
    # dispatch; that is unavoidable here.
    return LinearSolve._init_cacheval(Val(is_uniform), alg, A, b, args...;
        blocksize = usize, kwargs...)
end

end
24 changes: 24 additions & 0 deletions ext/LinearSolveKernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module LinearSolveKernelAbstractionsExt

using LinearSolve, KernelAbstractions

# Report to LinearSolve that the KernelAbstractions extension is loaded.
LinearSolve.__is_extension_loaded(::Val{:KernelAbstractions}) = true

using GPUArraysCore

# Kernel-based method for `LinearSolve._fast_sym_givens!`.
#
# Selects the `inner_iter`-th entries of `c` and `s` and the `(nr + inner_iter)`-th entry
# of `R`, and hands them together with `Hbis` to `__fast_sym_givens_kernel!`, launched on
# the backend of `Hbis` with `ndrange = bsize`. Returns the (mutated) `c, s, R`.
# NOTE(review): presumably each selected entry is itself an array with at least `bsize`
# elements -- confirm against the caller.
function LinearSolve._fast_sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int,
        Hbis)
    backend = get_backend(Hbis)
    kernel! = __fast_sym_givens_kernel!(backend)
    kernel!(c[inner_iter], s[inner_iter], R[nr + inner_iter], Hbis; ndrange = bsize)
    return c, s, R
end

# For each global index `idx`, calls `LinearSolve._sym_givens(R[idx], Hbis[idx])` and
# stores its three results into `c[idx]`, `s[idx]` and `R[idx]` respectively. `Hbis` is
# only read (`@Const`).
@kernel function __fast_sym_givens_kernel!(c, s, R, @Const(Hbis))
    idx = @index(Global)
    @inbounds _c, _s, _ρ = LinearSolve._sym_givens(R[idx], Hbis[idx])
    @inbounds c[idx] = _c
    @inbounds s[idx] = _s
    # Write the third result back into `R`; the original assignment had no right-hand side.
    @inbounds R[idx] = _ρ
end

end
13 changes: 11 additions & 2 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ PrecompileTools.@recompile_invalidations begin
import InteractiveUtils

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok

import GPUArraysCore
import Preferences
import ConcreteStructs: @concrete

# wrap
import Krylov

using SciMLBase
end

Expand All @@ -62,6 +63,11 @@ _isidentity_struct(λ::Number) = isone(λ)
_isidentity_struct(A::UniformScaling) = isone(A.λ)
_isidentity_struct(::SciMLOperators.IdentityOperator) = true

# Dispatch-friendly way to check if an extension is loaded. This generic fallback
# returns `false` for any `Val`; an extension module can add a more specific method for
# its own `Val{:Name}` tag that returns `true`.
__is_extension_loaded(::Val) = false

# Generic function declared here with no methods, so that methods can be added to it
# from outside this file (e.g. by an extension module).
function _fast_sym_givens! end

# Code

const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
Expand Down Expand Up @@ -92,6 +98,7 @@ end
include("common.jl")
include("factorization.jl")
include("simplelu.jl")
include("simplegmres.jl")
include("iterative_wrappers.jl")
include("preconditioners.jl")
include("solve_function.jl")
Expand Down Expand Up @@ -176,6 +183,8 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES

export SimpleGMRES

export HYPREAlgorithm
export CudaOffloadFactorization
export MKLPardisoFactorize, MKLPardisoIterate
Expand Down
2 changes: 1 addition & 1 deletion src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
Krylov.solve!(args...; M = M,
kwargs...)
elseif cache.cacheval isa Krylov.GmresSolver
Krylov.solve!(args...; M = M, N = N,
Krylov.solve!(args...; M = M, N = N, restart = alg.gmres_restart > 0,
kwargs...)
elseif cache.cacheval isa Krylov.BicgstabSolver
Krylov.solve!(args...; M = M, N = N,
Expand Down
Loading

0 comments on commit ac2b5aa

Please sign in to comment.