Skip to content

Commit

Permalink
Preconditioner support (#5)
Browse files Browse the repository at this point in the history
* Preconditioner support
  • Loading branch information
michel2323 authored Apr 25, 2024
1 parent 3ca4968 commit 2d1a718
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/DiffKrylov.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module DiffKrylov

using Krylov
using SparseArrays
using LinearAlgebra
include("ForwardDiff/forwarddiff.jl")
include("EnzymeRules/enzymerules.jl")
end
152 changes: 136 additions & 16 deletions src/EnzymeRules/enzymerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,87 @@ using .EnzymeRules
export augmented_primal, reverse, forward

for AMT in (:Matrix, :SparseMatrixCSC)
for solver in (:cg, :bicgstab, :gmres)
for solver in (:bicgstab, :gmres)
@eval begin
function forward(
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT},
_b::Annotation{VT};
verbose = 0,
M = I,
N = I,
options...
) where {RT, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
# println("($psolver, $pamt) forward rule")
if verbose > 0
@info "($psolver, $pamt) forward rule"
end
A = _A.val
b = _b.val
dx = []
x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
if isa(_A, Duplicated) && isa(_b, Duplicated)
dA = _A.dval
db = _b.dval
db -= dA*x
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Duplicated) && isa(_b, Const)
dA = _A.dval
db = -dA*x
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Duplicated)
db = _b.dval
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Const)
nothing
else
error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))")
end

if RT <: Const
return (x, stats)
elseif RT <: DuplicatedNoNeed
return (dx, stats)
else
return Duplicated((x, stats), (dx, dstats))
end
end
end
end
for solver in (:cg,)
@eval begin
function forward(
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT};
verbose = 0,
M = I,
options...
) where {RT, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) forward rule"
end
A = _A.val
b = _b.val
dx = []
x, stats = Krylov.$solver(A,b; options...)
x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...)
if isa(_A, Duplicated) && isa(_b, Duplicated)
dA = _A.dval
db = _b.dval
db -= dA*x
dx, dstats = Krylov.$solver(A,db; options...)
dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
elseif isa(_A, Duplicated) && isa(_b, Const)
dA = _A.dval
db = -dA*x
dx, dstats = Krylov.$solver(A,db; options...)
dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Duplicated)
db = _b.dval
dx, dstats = Krylov.$solver(A,db; options...)
dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Const)
nothing
else
Expand All @@ -53,25 +106,89 @@ end


for AMT in (:Matrix, :SparseMatrixCSC)
for solver in (:cg, :bicgstab, :gmres)
for solver in (:bicgstab, :gmres)
@eval begin
function augmented_primal(
config,
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT};
M=I,
N=I,
verbose=0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) augmented forward"
end
A = _A.val
b = _b.val
x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
bx = zeros(length(x))
bstats = deepcopy(stats)
if needs_primal(config)
return AugmentedReturn(
(x, stats),
(bx, bstats),
(A,x, Ref(bx), verbose, M, N)
)
else
return AugmentedReturn(nothing, (bx, bstats), (A,x))
end
end

function reverse(
config,
::Const{typeof(Krylov.$solver)},
dret::Type{RT},
cache,
_A::Annotation{MT},
_b::Annotation{<:Vector};
options...
) where {RT, MT <: $AMT}
(A,x,bx,verbose,M,N) = cache
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) reverse"
end
_b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, N=N, verbose=verbose, options...)[1]
_A.dval .= -x .* _b.dval'
return (nothing, nothing)
end
end
end
for solver in (:cg,)
@eval begin
function augmented_primal(
config,
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT}
_b::Annotation{VT};
M=I,
verbose=0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
# println("($psolver, $pamt) augmented forward")
if verbose > 0
@info "($psolver, $pamt) augmented forward"
end
A = _A.val
b = _b.val
x, stats = Krylov.$solver(A,b)
x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...)
bx = zeros(length(x))
bstats = deepcopy(stats)
if needs_primal(config)
return AugmentedReturn((x, stats), (bx, bstats), (A,x, Ref(bx)))
return AugmentedReturn(
(x, stats),
(bx, bstats),
(A,x, Ref(bx), verbose, M)
)
else
return AugmentedReturn(nothing, (bx, bstats), (A,x))
end
Expand All @@ -83,13 +200,16 @@ for AMT in (:Matrix, :SparseMatrixCSC)
dret::Type{RT},
cache,
_A::Annotation{MT},
_b::Annotation{<:Vector},
_b::Annotation{<:Vector};
options...
) where {RT, MT <: $AMT}
(A,x,bx,verbose,M) = cache
psolver = $solver
pamt = $AMT
# println("($psolver, $pamt) reverse")
(A,x,bx) = cache
_b.dval .= $solver(transpose(A), bx[])[1]
if verbose > 0
@info "($psolver, $pamt) reverse"
end
_b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, verbose=verbose, options...)[1]
_A.dval .= -x .* _b.dval'
return (nothing, nothing)
end
Expand Down

0 comments on commit 2d1a718

Please sign in to comment.