diff --git a/src/DiffKrylov.jl b/src/DiffKrylov.jl
index bef9a09..9ff750a 100644
--- a/src/DiffKrylov.jl
+++ b/src/DiffKrylov.jl
@@ -2,6 +2,7 @@ module DiffKrylov
 
 using Krylov
 using SparseArrays
+using LinearAlgebra
 include("ForwardDiff/forwarddiff.jl")
 include("EnzymeRules/enzymerules.jl")
 end
diff --git a/src/EnzymeRules/enzymerules.jl b/src/EnzymeRules/enzymerules.jl
index d7bd627..3a589c4 100644
--- a/src/EnzymeRules/enzymerules.jl
+++ b/src/EnzymeRules/enzymerules.jl
@@ -5,34 +5,92 @@ using .EnzymeRules
 export augmented_primal, reverse, forward
 
 for AMT in (:Matrix, :SparseMatrixCSC)
-    for solver in (:cg, :bicgstab, :gmres)
+    # Forward-mode rules for the solvers that take a left (`M`) and a right (`N`)
+    # preconditioner. Differentiating `A*x = b` gives `A*dx = db - dA*x`, so the
+    # tangent `dx` comes from a second solve with the same matrix and keyword
+    # arguments.
+    for solver in (:bicgstab, :gmres)
         @eval begin
             function forward(
                 func::Const{typeof(Krylov.$solver)},
                 ret::Type{RT},
                 _A::Annotation{MT},
-                _b::Annotation{VT},
+                _b::Annotation{VT};
+                verbose = 0,
+                M = I,
+                N = I,
                 options...
             ) where {RT, MT <: $AMT, VT <: Vector}
                 psolver = $solver
                 pamt = $AMT
-                # println("($psolver, $pamt) forward rule")
+                if verbose > 0
+                    @info "($psolver, $pamt) forward rule"
+                end
+                A = _A.val
+                b = _b.val
+                dx = []
+                x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
+                if isa(_A, Duplicated) && isa(_b, Duplicated)
+                    dA = _A.dval
+                    db = _b.dval
+                    db -= dA*x
+                    dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
+                elseif isa(_A, Duplicated) && isa(_b, Const)
+                    dA = _A.dval
+                    db = -dA*x
+                    dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
+                elseif isa(_A, Const) && isa(_b, Duplicated)
+                    db = _b.dval
+                    dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
+                elseif isa(_A, Const) && isa(_b, Const)
+                    nothing
+                else
+                    error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))")
+                end
+
+                if RT <: Const
+                    return (x, stats)
+                elseif RT <: DuplicatedNoNeed
+                    return (dx, stats)
+                else
+                    return Duplicated((x, stats), (dx, dstats))
+                end
+            end
+        end
+    end
+    # Same forward-mode rule for `cg`, which only takes the preconditioner `M`.
+    for solver in (:cg,)
+        @eval begin
+            function forward(
+                func::Const{typeof(Krylov.$solver)},
+                ret::Type{RT},
+                _A::Annotation{MT},
+                _b::Annotation{VT};
+                verbose = 0,
+                M = I,
+                options...
+            ) where {RT, MT <: $AMT, VT <: Vector}
+                psolver = $solver
+                pamt = $AMT
+                if verbose > 0
+                    @info "($psolver, $pamt) forward rule"
+                end
                 A = _A.val
                 b = _b.val
                 dx = []
-                x, stats = Krylov.$solver(A,b; options...)
+                x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...)
                 if isa(_A, Duplicated) && isa(_b, Duplicated)
                     dA = _A.dval
                     db = _b.dval
                     db -= dA*x
-                    dx, dstats = Krylov.$solver(A,db; options...)
+                    dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
                 elseif isa(_A, Duplicated) && isa(_b, Const)
                     dA = _A.dval
                     db = -dA*x
-                    dx, dstats = Krylov.$solver(A,db; options...)
+                    dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
                 elseif isa(_A, Const) && isa(_b, Duplicated)
                     db = _b.dval
-                    dx, dstats = Krylov.$solver(A,db; options...)
+                    dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
                 elseif isa(_A, Const) && isa(_b, Const)
                     nothing
                 else
@@ -53,25 +111,92 @@ end
 
 
 for AMT in (:Matrix, :SparseMatrixCSC)
-    for solver in (:cg, :bicgstab, :gmres)
+    # Reverse-mode rules for the solvers that take a left (`M`) and a right (`N`)
+    # preconditioner.
+    for solver in (:bicgstab, :gmres)
+        @eval begin
+            function augmented_primal(
+                config,
+                func::Const{typeof(Krylov.$solver)},
+                ret::Type{RT},
+                _A::Annotation{MT},
+                _b::Annotation{VT};
+                M=I,
+                N=I,
+                verbose=0,
+                options...
+            ) where {RT, MT <: $AMT, VT <: Vector}
+                psolver = $solver
+                pamt = $AMT
+                if verbose > 0
+                    @info "($psolver, $pamt) augmented forward"
+                end
+                A = _A.val
+                b = _b.val
+                x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
+                bx = zeros(length(x))
+                bstats = deepcopy(stats)
+                # `reverse` unpacks six entries from the tape, so the same tape must
+                # be stored whether or not the primal is requested.
+                tape = (A, x, Ref(bx), verbose, M, N)
+                if needs_primal(config)
+                    return AugmentedReturn((x, stats), (bx, bstats), tape)
+                else
+                    return AugmentedReturn(nothing, (bx, bstats), tape)
+                end
+            end
+
+            function reverse(
+                config,
+                ::Const{typeof(Krylov.$solver)},
+                dret::Type{RT},
+                cache,
+                _A::Annotation{MT},
+                _b::Annotation{<:Vector};
+                options...
+            ) where {RT, MT <: $AMT}
+                (A,x,bx,verbose,M,N) = cache
+                psolver = $solver
+                pamt = $AMT
+                if verbose > 0
+                    @info "($psolver, $pamt) reverse"
+                end
+                # Shadow of b: solve transpose(A) * bbar = bx[].
+                _b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, N=N, verbose=verbose, options...)[1]
+                # Shadow of A is -bbar * x' (rows follow b, columns follow x).
+                _A.dval .= -_b.dval .* x'
+                return (nothing, nothing)
+            end
+        end
+    end
+    # Reverse-mode rules for `cg`, which only takes the preconditioner `M`.
+    for solver in (:cg,)
         @eval begin
             function augmented_primal(
                 config,
                 func::Const{typeof(Krylov.$solver)},
                 ret::Type{RT},
                 _A::Annotation{MT},
-                _b::Annotation{VT}
+                _b::Annotation{VT};
+                M=I,
+                verbose=0,
+                options...
             ) where {RT, MT <: $AMT, VT <: Vector}
                 psolver = $solver
                 pamt = $AMT
-                # println("($psolver, $pamt) augmented forward")
+                if verbose > 0
+                    @info "($psolver, $pamt) augmented forward"
+                end
                 A = _A.val
                 b = _b.val
-                x, stats = Krylov.$solver(A,b)
+                x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...)
                 bx = zeros(length(x))
                 bstats = deepcopy(stats)
+                # `reverse` unpacks five entries from the tape, so the same tape must
+                # be stored whether or not the primal is requested.
+                tape = (A, x, Ref(bx), verbose, M)
                 if needs_primal(config)
-                    return AugmentedReturn((x, stats), (bx, bstats), (A,x, Ref(bx)))
+                    return AugmentedReturn((x, stats), (bx, bstats), tape)
                 else
-                    return AugmentedReturn(nothing, (bx, bstats), (A,x))
+                    return AugmentedReturn(nothing, (bx, bstats), tape)
                 end
@@ -83,13 +208,18 @@ for AMT in (:Matrix, :SparseMatrixCSC)
                 dret::Type{RT},
                 cache,
                 _A::Annotation{MT},
-                _b::Annotation{<:Vector},
+                _b::Annotation{<:Vector};
+                options...
             ) where {RT, MT <: $AMT}
+                (A,x,bx,verbose,M) = cache
                 psolver = $solver
                 pamt = $AMT
-                # println("($psolver, $pamt) reverse")
-                (A,x,bx) = cache
-                _b.dval .= $solver(transpose(A), bx[])[1]
-                _A.dval .= -x .* _b.dval'
+                if verbose > 0
+                    @info "($psolver, $pamt) reverse"
+                end
+                # Shadow of b: solve transpose(A) * bbar = bx[].
+                _b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, verbose=verbose, options...)[1]
+                # Shadow of A is -bbar * x' (rows follow b, columns follow x).
+                _A.dval .= -_b.dval .* x'
                 return (nothing, nothing)
             end