Skip to content

Commit

Permalink
Support inplace methods directly
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed May 1, 2024
1 parent 23efac7 commit 6c86e88
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 110 deletions.
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@ tangent and adjoint methods (see this
## Current Technical Limitations

* Only supports `gmres`, `cg`, and `bicgstab` methods
* No support for inplace methods `gmres!`, `cg!`, and `bicgstab!`
* No support for options when using Enzyme
* No support for sparse matrices using Enzyme
* No support for linear operators

## Current Open Questions
* How to handle preconditioners?
* How to set the options for the tangent/adjoint solve based on the options for the forward solve? For example `bicgtab` may return `NaN` for the tangents or adjoints.

## Installation
Expand Down
147 changes: 71 additions & 76 deletions src/EnzymeRules/enzymerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,99 +5,93 @@ using .EnzymeRules
export augmented_primal, reverse, forward

for AMT in (:Matrix, :SparseMatrixCSC)
for solver in (:bicgstab, :gmres)
for solver in (:bicgstab!, :gmres!)
@eval begin
function forward(
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
solver::Annotation{ST},
_A::Annotation{MT},
_b::Annotation{VT};
M = I,
N = I,
verbose = 0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
) where {RT <: Annotation, ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) forward rule"
end
A = _A.val
b = _b.val
dx = []
x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
Krylov.$solver(solver.val, A,b; M=M, N=N, verbose=verbose, options...)
if isa(_A, Duplicated) && isa(_b, Duplicated)
dA = _A.dval
db = _b.dval
db -= dA*x
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
db -= dA*solver.val.x
Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Duplicated) && isa(_b, Const)
dA = _A.dval
db = -dA*x
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Duplicated)
db = _b.dval
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Const)
nothing
else
error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))")
end

if RT <: Const
return (x, stats)
elseif RT <: DuplicatedNoNeed
return (dx, stats)
return solver.val
else
return Duplicated((x, stats), (dx, dstats))
return solver
end
end
end
end
for solver in (:cg,)
for solver in (:cg!,)
@eval begin
function forward(
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
solver::Annotation{ST},
_A::Annotation{MT},
_b::Annotation{VT};
verbose = 0,
M = I,
options...
) where {RT, MT <: $AMT, VT <: Vector}
) where {RT <: Annotation, ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) forward rule"
end
A = _A.val
b = _b.val
dx = []
x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...)
Krylov.$solver(solver.val,A,b; M=M, verbose=verbose, options...)
if isa(_A, Duplicated) && isa(_b, Duplicated)
dA = _A.dval
db = _b.dval
db -= dA*x
dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
db -= dA*solver.val.x
Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...)
elseif isa(_A, Duplicated) && isa(_b, Const)
dA = _A.dval
db = -dA*x
dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
db = -dA*solver.val.x
Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Duplicated)
db = _b.dval
dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Const)
nothing
else
error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))")
end

if RT <: Const
return (x, stats)
elseif RT <: DuplicatedNoNeed
return (dx, stats)
return solver.val
else
return Duplicated((x, stats), (dx, dstats))
return solver
end
end
end
Expand All @@ -106,116 +100,117 @@ end


for AMT in (:Matrix, :SparseMatrixCSC)
for solver in (:bicgstab, :gmres)
for solver in (:bicgstab!, :gmres!)
@eval begin
function augmented_primal(
config,
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT};
ret::Type{<:Annotation},
solver::Annotation{ST},
A::Annotation{MT},
b::Annotation{VT};
M=I,
N=I,
verbose=0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) augmented forward"
end
A = _A.val
b = _b.val
x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
bx = zeros(length(x))
bstats = deepcopy(stats)
if needs_primal(config)
return AugmentedReturn(
(x, stats),
(bx, bstats),
(A,x, Ref(bx), verbose, M, N)
)
else
return AugmentedReturn(nothing, (bx, bstats), (A,x))
end
Krylov.$solver(
solver.val, A.val,b.val;
M=M, verbose=verbose, options...
)

cache = (solver.val.x, A.val, verbose,M,N)
return AugmentedReturn(nothing, nothing, cache)
end

function reverse(
config,
::Const{typeof(Krylov.$solver)},
dret::Type{RT},
cache,
solver::Annotation{ST},
_A::Annotation{MT},
_b::Annotation{<:Vector};
_b::Annotation{VT};
options...
) where {RT, MT <: $AMT}
(A,x,bx,verbose,M,N) = cache
) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector, RT}
(x, A, verbose,M,N) = cache
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) reverse"
end
adjM = adjoint(N)
adjN = adjoint(M)
_b.dval .= Krylov.$solver(adjoint(A), bx[]; M=adjM, N=adjN, verbose=verbose, options...)[1]
Krylov.$solver(
solver.dval,
adjoint(A), copy(solver.dval.x); M=adjM, N=adjN,
verbose=verbose, options...
)
copyto!(_b.dval, solver.dval.x)
if isa(_A, Duplicated)
_A.dval .= -x .* _b.dval'
end
return (nothing, nothing)
return (nothing, nothing, nothing)
end
end
end
for solver in (:cg,)
for solver in (:cg!,)
@eval begin
function augmented_primal(
config,
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT};
ret::Type{<:Annotation},
solver::Annotation{ST},
A::Annotation{MT},
b::Annotation{VT};
M=I,
verbose=0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) augmented forward"
end
A = _A.val
b = _b.val
x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...)
bx = zeros(length(x))
bstats = deepcopy(stats)
if needs_primal(config)
return AugmentedReturn(
(x, stats),
(bx, bstats),
(A,x, Ref(bx), verbose, M)
)
else
return AugmentedReturn(nothing, (bx, bstats), (A,x))
end
Krylov.$solver(
solver.val, A.val,b.val;
M=M, verbose=verbose, options...
)
cache = (solver.val.x, A.val,verbose,M)
return AugmentedReturn(nothing, nothing, cache)
end

function reverse(
config,
::Const{typeof(Krylov.$solver)},
dret::Type{RT},
cache,
solver::Annotation{ST},
_A::Annotation{MT},
_b::Annotation{<:Vector};
_b::Annotation{VT};
options...
) where {RT, MT <: $AMT}
(A,x,bx,verbose,M) = cache
) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector, RT}
(x, A, verbose,M) = cache
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) reverse"
end
_b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, verbose=verbose, options...)[1]
_A.dval .= -x .* _b.dval'
return (nothing, nothing)
Krylov.$solver(
solver.dval,
A, copy(solver.dval.x); M=M,
verbose=verbose, options...
)
copyto!(_b.dval, solver.dval.x)
if isa(_A, Duplicated)
_A.dval .= -x .* _b.dval'
end
return (nothing, nothing, nothing)
end
end
end
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ include("utils.jl")
atol = 1e-12
rtol = 0.0
@testset "DiffKrylov" begin
@testset "ForwardDiff" begin
include("forwarddiff.jl")
end
# @testset "ForwardDiff" begin
# include("forwarddiff.jl")
# end
@testset "Enzyme" begin
include("enzymediff.jl")
end
Expand Down
Loading

0 comments on commit 6c86e88

Please sign in to comment.