Add SimpleGMRES implementation #366

Merged · 14 commits · Sep 20, 2023
3 changes: 2 additions & 1 deletion .JuliaFormatter.toml
@@ -1,2 +1,3 @@
style = "sciml"
format_markdown = true
annotate_untyped_fields_with_any = false
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,3 +5,5 @@
Manifest.toml

*.swp
.vscode
wip
13 changes: 11 additions & 2 deletions Project.toml
@@ -5,6 +5,7 @@ version = "2.6.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
@@ -28,32 +29,38 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"

[extensions]
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMKLExt = "MKL_jll"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"

[compat]
ArrayInterface = "7.4.11"
BlockDiagonals = "0.1"
DocStringExtensions = "0.8, 0.9"
EnumX = "1"
FastLapackInterface = "1, 2"
GPUArraysCore = "0.1"
HYPRE = "1.4.0"
IterativeSolvers = "0.9.2"
KLU = "0.3.0, 0.4"
KernelAbstractions = "0.9"
Krylov = "0.9"
KrylovKit = "0.5, 0.6"
PrecompileTools = "1"
Expand All @@ -69,20 +76,22 @@ UnPack = "1"
julia = "1.6"

[extras]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
9 changes: 9 additions & 0 deletions docs/src/solvers/solvers.md
@@ -72,6 +72,14 @@ choice of Krylov method should be the one most constrained to the type of operator one
has, for example if positive definite then `Krylov_CG()`, but if no good properties then
use `Krylov_GMRES()`.
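
A short sketch of that rule of thumb using the exported `KrylovJL_CG`/`KrylovJL_GMRES` wrappers
(the matrices below are made up for illustration):

```julia
using LinearSolve, LinearAlgebra

b = rand(100)

# Symmetric positive definite operator: a CG-type Krylov method is the natural fit.
A_spd = Matrix(Diagonal(1.0:100.0))
sol_cg = solve(LinearProblem(A_spd, b), KrylovJL_CG())

# General nonsymmetric operator with no exploitable structure: fall back to GMRES.
A_gen = rand(100, 100) + 50I
sol_gmres = solve(LinearProblem(A_gen, b), KrylovJL_GMRES())
```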

!!! tip

    If your materialized operator is a uniform block diagonal matrix, you can use
    `SimpleGMRES(; blocksize = <known block size>)` to further improve performance.
    This structure often shows up in neural networks, where the Jacobian with respect to the
    inputs is (almost always) a uniform block diagonal matrix whose block size equals the
    input size divided by the batch size.
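
A minimal usage sketch of that tip (illustrative only; the 4×4 block-diagonal matrix, right-hand
side, and block size of 2 below are made up, not taken from the package's documentation):

```julia
using LinearSolve

# A uniform block-diagonal matrix built from two 2×2 blocks (block size = 2).
A = [2.0 1.0 0.0 0.0;
     1.0 3.0 0.0 0.0;
     0.0 0.0 4.0 1.0;
     0.0 0.0 1.0 2.0]
b = rand(4)

prob = LinearProblem(A, b)
# Passing the known block size lets SimpleGMRES exploit the structure.
sol = solve(prob, SimpleGMRES(; blocksize = 2))
sol.u  # solution vector
```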

## Full List of Methods

### RecursiveFactorization.jl
@@ -106,6 +114,7 @@ LinearSolve.jl contains some linear solvers built in for specialized cases.
```@docs
SimpleLUFactorization
DiagonalFactorization
SimpleGMRES
```

### FastLapackInterface.jl
24 changes: 24 additions & 0 deletions ext/LinearSolveBlockDiagonalsExt.jl
@@ -0,0 +1,24 @@
module LinearSolveBlockDiagonalsExt

using LinearSolve, BlockDiagonals

function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, args...;
        kwargs...)
    @assert ndims(A) == 2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2."
    # We need to perform this check even when `zeroinit == true`, since the type of the
    # cache is dependent on whether we are able to use the specialized dispatch.
    bsizes = blocksizes(A)
    usize = first(first(bsizes))
    uniform_blocks = true
    for bsize in bsizes
        if bsize[1] != usize || bsize[2] != usize
            uniform_blocks = false
            break
        end
    end
    # Can't help but perform dynamic dispatch here
    return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, args...;
        blocksize = usize, kwargs...)
end

end
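
With BlockDiagonals.jl loaded, the dispatch above lets `SimpleGMRES` pick the block size up
automatically from a `BlockDiagonal` operator. A rough usage sketch (it assumes the solver accepts
a `BlockDiagonal` directly once this extension is loaded; the blocks below are made up):

```julia
using LinearSolve, BlockDiagonals, LinearAlgebra

# Two uniform 2×2 blocks, so `blocksizes(A)` is [(2, 2), (2, 2)] and the extension
# forwards `blocksize = 2` to the specialized cache initialization.
A = BlockDiagonal([rand(2, 2) + 2I for _ in 1:2])
b = rand(4)

sol = solve(LinearProblem(A, b), SimpleGMRES())
```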
24 changes: 24 additions & 0 deletions ext/LinearSolveKernelAbstractionsExt.jl
@@ -0,0 +1,24 @@
module LinearSolveKernelAbstractionsExt

using LinearSolve, KernelAbstractions

LinearSolve.__is_extension_loaded(::Val{:KernelAbstractions}) = true

using GPUArraysCore

function LinearSolve._fast_sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int, Hbis)
    backend = get_backend(Hbis)
    kernel! = __fast_sym_givens_kernel!(backend)
    kernel!(c[inner_iter], s[inner_iter], R[nr + inner_iter], Hbis; ndrange = bsize)
    return c, s, R
end

@kernel function __fast_sym_givens_kernel!(c, s, R, @Const(Hbis))
    idx = @index(Global)
    @inbounds _c, _s, _ρ = LinearSolve._sym_givens(R[idx], Hbis[idx])
    @inbounds c[idx] = _c
    @inbounds s[idx] = _s
    @inbounds R[idx] = _ρ
end

end
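
For readers unfamiliar with KernelAbstractions.jl: the kernel above applies, independently for
each of the `bsize` uniform blocks, the symmetric Givens rotation that combines the current `R`
entry with the matching Hessenberg entry. A plain serial loop with the same effect as one kernel
launch might look like this (a sketch only; it assumes `LinearSolve._sym_givens(a, b)` returns a
`(c, s, ρ)` triple, as the kernel body suggests):

```julia
using LinearSolve

# Serial reference for one launch of `__fast_sym_givens_kernel!`
# (the real code runs this loop as a KernelAbstractions kernel on the array's backend).
function fast_sym_givens_serial!(c_i, s_i, R_i, Hbis)
    for idx in eachindex(Hbis)
        _c, _s, _ρ = LinearSolve._sym_givens(R_i[idx], Hbis[idx])
        c_i[idx] = _c
        s_i[idx] = _s
        R_i[idx] = _ρ
    end
    return c_i, s_i, R_i
end
```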
13 changes: 11 additions & 2 deletions src/LinearSolve.jl
@@ -28,15 +28,16 @@ PrecompileTools.@recompile_invalidations begin
import InteractiveUtils

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok

import GPUArraysCore
import Preferences
import ConcreteStructs: @concrete

# wrap
import Krylov

using SciMLBase
end

@@ -62,6 +63,11 @@ _isidentity_struct(λ::Number) = isone(λ)
_isidentity_struct(A::UniformScaling) = isone(A.λ)
_isidentity_struct(::SciMLOperators.IdentityOperator) = true

# Dispatch Friendly way to check if an extension is loaded
__is_extension_loaded(::Val) = false

function _fast_sym_givens! end

# Code

const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
@@ -92,6 +98,7 @@ end
include("common.jl")
include("factorization.jl")
include("simplelu.jl")
include("simplegmres.jl")
include("iterative_wrappers.jl")
include("preconditioners.jl")
include("solve_function.jl")
@@ -176,6 +183,8 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES

export SimpleGMRES

export HYPREAlgorithm
export CudaOffloadFactorization
export MKLPardisoFactorize, MKLPardisoIterate
2 changes: 1 addition & 1 deletion src/iterative_wrappers.jl
@@ -253,7 +253,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
Krylov.solve!(args...; M = M,
kwargs...)
elseif cache.cacheval isa Krylov.GmresSolver
Krylov.solve!(args...; M = M, N = N,
Krylov.solve!(args...; M = M, N = N, restart = alg.gmres_restart > 0,
kwargs...)
elseif cache.cacheval isa Krylov.BicgstabSolver
Krylov.solve!(args...; M = M, N = N,
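
For context on the one-line change in this file: the Krylov.jl GMRES call now receives
`restart = true` whenever the wrapper was built with a positive restart length. A usage sketch
(the matrix, sizes, and restart length are made up; `gmres_restart` is the existing option
referenced as `alg.gmres_restart` in the diff):

```julia
using LinearSolve, LinearAlgebra

A = rand(200, 200) + 200I  # strongly diagonal, so GMRES converges quickly
b = rand(200)

# With gmres_restart > 0, the wrapper now asks Krylov.jl for restarted GMRES.
sol = solve(LinearProblem(A, b), KrylovJL_GMRES(; gmres_restart = 20))
```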