Skip to content

Commit

Permalink
Support preconditioners in BiLQ and QMR (#862)
Browse files Browse the repository at this point in the history
* Support preconditioners in BiLQ and QMR

* Fix preconditioning with ldiv=true in BiLQ and QMR

* Update bilq.jl and qmr.jl
  • Loading branch information
amontoison authored May 20, 2024
1 parent 8ecc293 commit 31fe591
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 48 deletions.
5 changes: 3 additions & 2 deletions docs/src/preconditioners.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ Krylov.jl supports both approaches thanks to the argument `ldiv` of the Krylov s
## How to use preconditioners in Krylov.jl?

!!! info
- A preconditioner only need support the operation `mul!(y, P⁻¹, x)` when `ldiv=false` or `ldiv!(y, P, x)` when `ldiv=true` to be used in Krylov.jl.
- A preconditioner only needs to support the operation `mul!(y, P⁻¹, x)` when `ldiv=false` or `ldiv!(y, P, x)` when `ldiv=true` to be used in Krylov.jl.
- Additional support for `adjoint` with preconditioners is required in the methods [`BILQ`](@ref bilq) and [`QMR`](@ref qmr).
- The default value of a preconditioner in Krylov.jl is the identity operator `I`.

### Square non-Hermitian linear systems

Methods concerned: [`CGS`](@ref cgs), [`BiCGSTAB`](@ref bicgstab), [`DQGMRES`](@ref dqgmres), [`GMRES`](@ref gmres), [`BLOCK-GMRES`](@ref block_gmres), [`FGMRES`](@ref fgmres), [`DIOM`](@ref diom) and [`FOM`](@ref fom).
Methods concerned: [`CGS`](@ref cgs), [`BILQ`](@ref bilq), [`QMR`](@ref qmr), [`BiCGSTAB`](@ref bicgstab), [`DQGMRES`](@ref dqgmres), [`GMRES`](@ref gmres), [`BLOCK-GMRES`](@ref block_gmres), [`FGMRES`](@ref fgmres), [`DIOM`](@ref diom) and [`FOM`](@ref fom).

A Krylov method dedicated to non-Hermitian linear systems allows the three variants of preconditioning.

Expand Down
73 changes: 54 additions & 19 deletions src/bilq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ export bilq, bilq!
"""
(x, stats) = bilq(A, b::AbstractVector{FC};
c::AbstractVector{FC}=b, transfer_to_bicg::Bool=true,
atol::T=√eps(T), rtol::T=√eps(T), itmax::Int=0,
timemax::Float64=Inf, verbose::Int=0, history::Bool=false,
M=I, N=I, ldiv::Bool=false, atol::T=√eps(T),
rtol::T=√eps(T), itmax::Int=0, timemax::Float64=Inf,
verbose::Int=0, history::Bool=false,
callback=solver->false, iostream::IO=kstdout)
`T` is an `AbstractFloat` such as `Float32`, `Float64` or `BigFloat`.
Expand All @@ -30,6 +31,7 @@ Solve the square linear system Ax = b of size n using BiLQ.
BiLQ is based on the Lanczos biorthogonalization process and requires two initial vectors `b` and `c`.
The relation `bᴴc ≠ 0` must be satisfied and by default `c = b`.
When `A` is Hermitian and `b = c`, BiLQ is equivalent to SYMMLQ.
BiLQ requires support for `adjoint(M)` and `adjoint(N)` if preconditioners are provided.
#### Input arguments
Expand All @@ -44,6 +46,9 @@ When `A` is Hermitian and `b = c`, BiLQ is equivalent to SYMMLQ.
* `c`: the second initial vector of length `n` required by the Lanczos biorthogonalization process;
* `transfer_to_bicg`: transfer from the BiLQ point to the BiCG point, when it exists. The transfer is based on the residual norm;
* `M`: linear operator that models a nonsingular matrix of size `n` used for left preconditioning;
* `N`: linear operator that models a nonsingular matrix of size `n` used for right preconditioning;
* `ldiv`: define whether the preconditioners use `ldiv!` or `mul!`;
* `atol`: absolute stopping tolerance based on the residual norm;
* `rtol`: relative stopping tolerance based on the residual norm;
* `itmax`: the maximum number of iterations. If `itmax=0`, the default number of iterations is set to `2n`;
Expand Down Expand Up @@ -82,6 +87,9 @@ def_optargs_bilq = (:(x0::AbstractVector),)

def_kwargs_bilq = (:(; c::AbstractVector{FC} = b ),
:(; transfer_to_bicg::Bool = true),
:(; M = I ),
:(; N = I ),
:(; ldiv::Bool = false ),
:(; atol::T = eps(T) ),
:(; rtol::T = eps(T) ),
:(; itmax::Int = 0 ),
Expand All @@ -95,7 +103,7 @@ def_kwargs_bilq = mapreduce(extract_parameters, vcat, def_kwargs_bilq)

args_bilq = (:A, :b)
optargs_bilq = (:x0,)
kwargs_bilq = (:c, :transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)
kwargs_bilq = (:c, :transfer_to_bicg, :M, :N, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function bilq($(def_args_bilq...), $(def_optargs_bilq...); $(def_kwargs_bilq...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
Expand Down Expand Up @@ -131,26 +139,42 @@ kwargs_bilq = (:c, :transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose,
length(b) == m || error("Inconsistent problem size")
(verbose > 0) && @printf(iostream, "BILQ: system of size %d\n", n)

# Check M = Iₙ and N = Iₙ
MisI = (M === I)
NisI = (N === I)

# Check type consistency
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
ktypeof(b) <: S || error("ktypeof(b) is not a subtype of $S")
ktypeof(c) <: S || error("ktypeof(c) is not a subtype of $S")

# Compute the adjoint of A
# Compute the adjoint of A, M and N
Aᴴ = A'
Mᴴ = M'
Nᴴ = N'

# Set up workspace.
allocate_if(!MisI, solver, :t, S, n)
allocate_if(!NisI, solver, :s, S, n)
uₖ₋₁, uₖ, q, vₖ₋₁, vₖ = solver.uₖ₋₁, solver.uₖ, solver.q, solver.vₖ₋₁, solver.vₖ
p, Δx, x, d̅, stats = solver.p, solver.Δx, solver.x, solver.d̅, solver.stats
warm_start = solver.warm_start
rNorms = stats.residuals
reset!(stats)
r₀ = warm_start ? q : b
Mᴴuₖ = MisI ? uₖ : solver.t
t = MisI ? q : solver.t
Nvₖ = NisI ? vₖ : solver.s
s = NisI ? p : solver.s

if warm_start
mul!(r₀, A, Δx)
@kaxpby!(n, one(FC), b, -one(FC), r₀)
end
if !MisI
mulorldiv!(solver.t, M, r₀, ldiv)
r₀ = solver.t
end

# Initial solution x₀ and residual norm ‖r₀‖.
x .= zero(FC)
Expand All @@ -170,10 +194,6 @@ kwargs_bilq = (:c, :transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose,
iter = 0
itmax == 0 && (itmax = 2*n)

ε = atol + rtol * bNorm
(verbose > 0) && @printf(iostream, "%5s %7s %5s\n", "k", "‖rₖ‖", "timer")
kdisplay(iter, verbose) && @printf(iostream, "%5d %7.1e %.2fs\n", iter, bNorm, ktimer(start_time))

# Initialize the Lanczos biorthogonalization process.
cᴴb = @kdot(n, c, r₀) # ⟨c,r₀⟩
if cᴴb == 0
Expand All @@ -186,6 +206,10 @@ kwargs_bilq = (:c, :transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose,
return solver
end

ε = atol + rtol * bNorm
(verbose > 0) && @printf(iostream, "%5s %8s %7s %5s\n", "k", "αₖ", "‖rₖ‖", "timer")
kdisplay(iter, verbose) && @printf(iostream, "%5d %8.1e %7.1e %.2fs\n", iter, cᴴb, bNorm, ktimer(start_time))

βₖ = (abs(cᴴb)) # β₁γ₁ = cᴴ(b - Ax₀)
γₖ = cᴴb / βₖ # β₁γ₁ = cᴴ(b - Ax₀)
vₖ₋₁ .= zero(FC) # v₀ = 0
Expand Down Expand Up @@ -214,23 +238,30 @@ kwargs_bilq = (:c, :transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose,
iter = iter + 1

# Continue the Lanczos biorthogonalization process.
# AVₖ = VₖTₖ + βₖ₊₁vₖ₊₁(eₖ)ᵀ = Vₖ₊₁Tₖ₊₁.ₖ
# AᴴUₖ = Uₖ(Tₖ)ᴴ + γ̄ₖ₊₁uₖ₊₁(eₖ)ᵀ = Uₖ₊₁(Tₖ.ₖ₊₁)ᴴ
# MANVₖ = VₖTₖ + βₖ₊₁vₖ₊₁(eₖ)ᵀ = Vₖ₊₁Tₖ₊₁.ₖ
# NᴴAᴴMᴴUₖ = Uₖ(Tₖ)ᴴ + γ̄ₖ₊₁uₖ₊₁(eₖ)ᵀ = Uₖ₊₁(Tₖ.ₖ₊₁)ᴴ

mul!(q, A , vₖ) # Forms vₖ₊₁ : q ← Avₖ
mul!(p, Aᴴ, uₖ) # Forms uₖ₊₁ : p ← Aᴴuₖ
# Forms vₖ₊₁ : q ← MANvₖ
NisI || mulorldiv!(Nvₖ, N, vₖ, ldiv)
mul!(t, A, Nvₖ)
MisI || mulorldiv!(q, M, t, ldiv)

# Forms uₖ₊₁ : p ← NᴴAᴴMᴴuₖ
MisI || mulorldiv!(Mᴴuₖ, Mᴴ, uₖ, ldiv)
mul!(s, Aᴴ, Mᴴuₖ)
NisI || mulorldiv!(p, Nᴴ, s, ldiv)

@kaxpy!(n, -γₖ, vₖ₋₁, q) # q ← q - γₖ * vₖ₋₁
@kaxpy!(n, -βₖ, uₖ₋₁, p) # p ← p - β̄ₖ * uₖ₋₁

αₖ = @kdot(n, uₖ, q) # αₖ = ⟨uₖ,q⟩
αₖ = @kdot(n, uₖ, q) # αₖ = ⟨uₖ,q⟩

@kaxpy!(n, - αₖ , vₖ, q) # q ← q - αₖ * vₖ
@kaxpy!(n, -conj(αₖ), uₖ, p) # p ← p - ᾱₖ * uₖ
@kaxpy!(n, - αₖ , vₖ, q) # q ← q - αₖ * vₖ
@kaxpy!(n, -conj(αₖ), uₖ, p) # p ← p - ᾱₖ * uₖ

pᴴq = @kdot(n, p, q) # pᴴq = ⟨p,q⟩
βₖ₊₁ = (abs(pᴴq)) # βₖ₊₁ = √(|pᴴq|)
γₖ₊₁ = pᴴq / βₖ₊₁ # γₖ₊₁ = pᴴq / βₖ₊₁
pᴴq = @kdot(n, p, q) # pᴴq = ⟨p,q⟩
βₖ₊₁ = (abs(pᴴq)) # βₖ₊₁ = √(|pᴴq|)
γₖ₊₁ = pᴴq / βₖ₊₁ # γₖ₊₁ = pᴴq / βₖ₊₁

# Update the LQ factorization of Tₖ = L̅ₖQₖ.
# [ α₁ γ₂ 0 • • • 0 ] [ δ₁ 0 • • • • 0 ]
Expand Down Expand Up @@ -353,7 +384,7 @@ kwargs_bilq = (:c, :transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose,
breakdown = !solved_lq && !solved_cg && (pᴴq == 0)
timer = time_ns() - start_time
overtimed = timer > timemax_ns
kdisplay(iter, verbose) && @printf(iostream, "%5d %7.1e %.2fs\n", iter, rNorm_lq, ktimer(start_time))
kdisplay(iter, verbose) && @printf(iostream, "%5d %8.1e %7.1e %.2fs\n", iter, αₖ, rNorm_lq, ktimer(start_time))
end
(verbose > 0) && @printf(iostream, "\n")

Expand All @@ -372,6 +403,10 @@ kwargs_bilq = (:c, :transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose,
overtimed && (status = "time limit exceeded")

# Update x
if !NisI
copyto!(solver.s, x)
mulorldiv!(x, N, solver.s, ldiv)
end
warm_start && @kaxpy!(n, one(FC), Δx, x)
solver.warm_start = false

Expand Down
12 changes: 10 additions & 2 deletions src/krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,8 @@ mutable struct BilqSolver{T,FC,S} <: KrylovSolver{T,FC,S}
Δx :: S
x :: S
:: S
t :: S
s :: S
warm_start :: Bool
stats :: SimpleStats{T}
end
Expand All @@ -1018,8 +1020,10 @@ function BilqSolver(m, n, S)
Δx = S(undef, 0)
x = S(undef, n)
= S(undef, n)
t = S(undef, 0)
s = S(undef, 0)
stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown")
solver = BilqSolver{T,FC,S}(m, n, uₖ₋₁, uₖ, q, vₖ₋₁, vₖ, p, Δx, x, d̅, false, stats)
solver = BilqSolver{T,FC,S}(m, n, uₖ₋₁, uₖ, q, vₖ₋₁, vₖ, p, Δx, x, d̅, t, s, false, stats)
return solver
end

Expand Down Expand Up @@ -1052,6 +1056,8 @@ mutable struct QmrSolver{T,FC,S} <: KrylovSolver{T,FC,S}
x :: S
wₖ₋₂ :: S
wₖ₋₁ :: S
t :: S
s :: S
warm_start :: Bool
stats :: SimpleStats{T}
end
Expand All @@ -1069,8 +1075,10 @@ function QmrSolver(m, n, S)
x = S(undef, n)
wₖ₋₂ = S(undef, n)
wₖ₋₁ = S(undef, n)
t = S(undef, 0)
s = S(undef, 0)
stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown")
solver = QmrSolver{T,FC,S}(m, n, uₖ₋₁, uₖ, q, vₖ₋₁, vₖ, p, Δx, x, wₖ₋₂, wₖ₋₁, false, stats)
solver = QmrSolver{T,FC,S}(m, n, uₖ₋₁, uₖ, q, vₖ₋₁, vₖ, p, Δx, x, wₖ₋₂, wₖ₋₁, t, s, false, stats)
return solver
end

Expand Down
Loading

0 comments on commit 31fe591

Please sign in to comment.