diff --git a/README.md b/README.md index a556fb6..d8685b2 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,9 @@ tangent and adjoint methods (see this ## Current Technical Limitations * Only supports `gmres`, `cg`, and `bicgstab` methods -* No support for inplace methods `gmres!`, `cg!`, and `bicgstab!` -* No support for options when using Enzyme -* No support for sparse matrices using Enzyme * No support for linear operators ## Current Open Questions -* How to handle preconditioners? * How to set the options for the tangent/adjoint solve based on the options for the forward solve? For example `bicgtab` may return `NaN` for the tangents or adjoints. ## Installation diff --git a/src/EnzymeRules/enzymerules.jl b/src/EnzymeRules/enzymerules.jl index 834f778..e76e4d5 100644 --- a/src/EnzymeRules/enzymerules.jl +++ b/src/EnzymeRules/enzymerules.jl @@ -5,18 +5,19 @@ using .EnzymeRules export augmented_primal, reverse, forward for AMT in (:Matrix, :SparseMatrixCSC) - for solver in (:bicgstab, :gmres) + for solver in (:bicgstab!, :gmres!) @eval begin function forward( func::Const{typeof(Krylov.$solver)}, ret::Type{RT}, + solver::Annotation{ST}, _A::Annotation{MT}, _b::Annotation{VT}; M = I, N = I, verbose = 0, options... - ) where {RT, MT <: $AMT, VT <: Vector} + ) where {RT <: Annotation, ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector} psolver = $solver pamt = $AMT if verbose > 0 @@ -24,47 +25,44 @@ for AMT in (:Matrix, :SparseMatrixCSC) end A = _A.val b = _b.val - dx = [] - x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...) + Krylov.$solver(solver.val, A,b; M=M, N=N, verbose=verbose, options...) if isa(_A, Duplicated) && isa(_b, Duplicated) dA = _A.dval db = _b.dval - db -= dA*x - dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...) + db -= dA*solver.val.x + Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...) elseif isa(_A, Duplicated) && isa(_b, Const) dA = _A.dval db = -dA*x - dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...) + Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...) elseif isa(_A, Const) && isa(_b, Duplicated) db = _b.dval - dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...) + Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...) elseif isa(_A, Const) && isa(_b, Const) nothing else error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))") end - if RT <: Const - return (x, stats) - elseif RT <: DuplicatedNoNeed - return (dx, stats) + return solver.val else - return Duplicated((x, stats), (dx, dstats)) + return solver end end end end - for solver in (:cg,) + for solver in (:cg!,) @eval begin function forward( func::Const{typeof(Krylov.$solver)}, ret::Type{RT}, + solver::Annotation{ST}, _A::Annotation{MT}, _b::Annotation{VT}; verbose = 0, M = I, options... - ) where {RT, MT <: $AMT, VT <: Vector} + ) where {RT <: Annotation, ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector} psolver = $solver pamt = $AMT if verbose > 0 @@ -72,32 +70,28 @@ for AMT in (:Matrix, :SparseMatrixCSC) end A = _A.val b = _b.val - dx = [] - x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...) + Krylov.$solver(solver.val,A,b; M=M, verbose=verbose, options...) if isa(_A, Duplicated) && isa(_b, Duplicated) dA = _A.dval db = _b.dval - db -= dA*x - dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...) + db -= dA*solver.val.x + Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...) elseif isa(_A, Duplicated) && isa(_b, Const) dA = _A.dval - db = -dA*x - dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...) + db = -dA*solver.val.x + Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...) elseif isa(_A, Const) && isa(_b, Duplicated) db = _b.dval - dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...) + Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...) elseif isa(_A, Const) && isa(_b, Const) nothing else error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))") end - if RT <: Const - return (x, stats) - elseif RT <: DuplicatedNoNeed - return (dx, stats) + return solver.val else - return Duplicated((x, stats), (dx, dstats)) + return solver end end end @@ -106,38 +100,32 @@ end for AMT in (:Matrix, :SparseMatrixCSC) - for solver in (:bicgstab, :gmres) + for solver in (:bicgstab!, :gmres!) @eval begin function augmented_primal( config, func::Const{typeof(Krylov.$solver)}, - ret::Type{RT}, - _A::Annotation{MT}, - _b::Annotation{VT}; + ret::Type{<:Annotation}, + solver::Annotation{ST}, + A::Annotation{MT}, + b::Annotation{VT}; M=I, N=I, verbose=0, options... - ) where {RT, MT <: $AMT, VT <: Vector} + ) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector} psolver = $solver pamt = $AMT if verbose > 0 @info "($psolver, $pamt) augmented forward" end - A = _A.val - b = _b.val - x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...) - bx = zeros(length(x)) - bstats = deepcopy(stats) - if needs_primal(config) - return AugmentedReturn( - (x, stats), - (bx, bstats), - (A,x, Ref(bx), verbose, M, N) - ) - else - return AugmentedReturn(nothing, (bx, bstats), (A,x)) - end + Krylov.$solver( + solver.val, A.val,b.val; + M=M, verbose=verbose, options... + ) + + cache = (solver.val.x, A.val, verbose,M,N) + return AugmentedReturn(nothing, nothing, cache) end function reverse( @@ -145,11 +133,12 @@ for AMT in (:Matrix, :SparseMatrixCSC) ::Const{typeof(Krylov.$solver)}, dret::Type{RT}, cache, + solver::Annotation{ST}, _A::Annotation{MT}, - _b::Annotation{<:Vector}; + _b::Annotation{VT}; options... - ) where {RT, MT <: $AMT} - (A,x,bx,verbose,M,N) = cache + ) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector, RT} + (x, A, verbose,M,N) = cache psolver = $solver pamt = $AMT if verbose > 0 @@ -157,45 +146,43 @@ for AMT in (:Matrix, :SparseMatrixCSC) end adjM = adjoint(N) adjN = adjoint(M) - _b.dval .= Krylov.$solver(adjoint(A), bx[]; M=adjM, N=adjN, verbose=verbose, options...)[1] + Krylov.$solver( + solver.dval, + adjoint(A), copy(solver.dval.x); M=adjM, N=adjN, + verbose=verbose, options... + ) + copyto!(_b.dval, solver.dval.x) if isa(_A, Duplicated) _A.dval .= -x .* _b.dval' end - return (nothing, nothing) + return (nothing, nothing, nothing) end end end - for solver in (:cg,) + for solver in (:cg!,) @eval begin function augmented_primal( config, func::Const{typeof(Krylov.$solver)}, - ret::Type{RT}, - _A::Annotation{MT}, - _b::Annotation{VT}; + ret::Type{<:Annotation}, + solver::Annotation{ST}, + A::Annotation{MT}, + b::Annotation{VT}; M=I, verbose=0, options... - ) where {RT, MT <: $AMT, VT <: Vector} + ) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector} psolver = $solver pamt = $AMT if verbose > 0 @info "($psolver, $pamt) augmented forward" end - A = _A.val - b = _b.val - x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...) - bx = zeros(length(x)) - bstats = deepcopy(stats) - if needs_primal(config) - return AugmentedReturn( - (x, stats), - (bx, bstats), - (A,x, Ref(bx), verbose, M) - ) - else - return AugmentedReturn(nothing, (bx, bstats), (A,x)) - end + Krylov.$solver( + solver.val, A.val,b.val; + M=M, verbose=verbose, options... + ) + cache = (solver.val.x, A.val,verbose,M) + return AugmentedReturn(nothing, nothing, cache) end function reverse( @@ -203,19 +190,27 @@ for AMT in (:Matrix, :SparseMatrixCSC) ::Const{typeof(Krylov.$solver)}, dret::Type{RT}, cache, + solver::Annotation{ST}, _A::Annotation{MT}, - _b::Annotation{<:Vector}; + _b::Annotation{VT}; options... - ) where {RT, MT <: $AMT} - (A,x,bx,verbose,M) = cache + ) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector, RT} + (x, A, verbose,M) = cache psolver = $solver pamt = $AMT if verbose > 0 @info "($psolver, $pamt) reverse" end - _b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, verbose=verbose, options...)[1] - _A.dval .= -x .* _b.dval' - return (nothing, nothing) + Krylov.$solver( + solver.dval, + A, copy(solver.dval.x); M=M, + verbose=verbose, options... + ) + copyto!(_b.dval, solver.dval.x) + if isa(_A, Duplicated) + _A.dval .= -x .* _b.dval' + end + return (nothing, nothing, nothing) end end end diff --git a/test/runtests.jl b/test/runtests.jl index ecfda33..0521341 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,9 +12,9 @@ include("utils.jl") atol = 1e-12 rtol = 0.0 @testset "DiffKrylov" begin - @testset "ForwardDiff" begin - include("forwarddiff.jl") - end + # @testset "ForwardDiff" begin + # include("forwarddiff.jl") + # end @testset "Enzyme" begin include("enzymediff.jl") end diff --git a/test/utils.jl b/test/utils.jl index 3ca1887..8be9dec 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -84,32 +84,28 @@ function check_derivatives_and_values_active_passive(solver, A, b, x) @test isapprox(partials.(dx,1), fda[1]) end -struct GMRES end -struct BICGSTAB end -struct CG end - -function driver!(::GMRES, x, A, b, M, N, ldiv=false) - x .= gmres(A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv)[1] +function driver!(solver::GmresSolver, A, b, M, N, ldiv=false) + gmres!(solver, A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv) nothing end -function driver!(::BICGSTAB, x, A, b, M, N, ldiv=false) - x .= bicgstab(A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv)[1] +function driver!(solver::BicgstabSolver, A, b, M, N, ldiv=false) + bicgstab!(solver, A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv) nothing end -function driver!(::CG, x, A, b, M, N, ldiv=false) - x .= cg(A,b, atol=1e-16, rtol=1e-16, M=M, verbose=0, ldiv=ldiv)[1] +function driver!(solver::CgSolver, A, b, M, N, ldiv=false) + cg!(solver, A,b, atol=1e-16, rtol=1e-16, M=M, verbose=0, ldiv=ldiv) nothing end function test_enzyme_with(solver, A, b, M, N, ldiv=false) tsolver = if solver == Krylov.cg - CG() + CgSolver(A,b) elseif solver == Krylov.gmres - GMRES() + GmresSolver(A,b) elseif solver == Krylov.bicgstab - BICGSTAB() + BicgstabSolver(A,b) else error("Unsupported solver $solver is tested in DiffKrylov.jl") end @@ -117,17 +113,17 @@ function test_enzyme_with(solver, A, b, M, N, ldiv=false) function A_one_one(hx) _A = copy(A) _A[1,1] = hx - x = zeros(length(b)) - driver!(tsolver, x, _A, b, M, N, ldiv) - return x + # fill!(tsolver.x, zero(eltype(solver.x))) + driver!(tsolver, _A, b, M, N, ldiv) + return tsolver.x[1] end function b_one(hx) _b = copy(b) _b[1] = hx - x = zeros(length(b)) - driver!(tsolver, x, A, _b, M, N, ldiv) - return x + # fill!(tsolver.x, zero(eltype(tsolver.x))) + driver!(tsolver, A, _b, M, N, ldiv) + return tsolver.x[1] end fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a), copy(A[1,1])) @@ -143,31 +139,31 @@ function test_enzyme_with(solver, A, b, M, N, ldiv=false) dA = Duplicated(A, duplicate(A)) db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) + dupsolver = Duplicated(tsolver, deepcopy(tsolver)) + fill!(dupsolver.dval.x, zero(eltype(dupsolver.dval.x))) dA.dval[1,1] = 1.0 db.dval[1] = 1.0 Enzyme.autodiff( Forward, driver!, - Const(tsolver), - dx, + dupsolver, dA, db, Const(M), Const(N), Const(ldiv) ) - @test isapprox(dx.dval, fd, atol=1e-4, rtol=1e-4) + @test isapprox(dupsolver.dval.x[1], fd[1][1], atol=1e-4, rtol=1e-4) # Test reverse dA = Duplicated(A, duplicate(A)) db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - dx.dval[1] = 1.0 + dupsolver = Duplicated(tsolver, deepcopy(tsolver)) + fill!(dupsolver.dval.x, zero(eltype(dupsolver.dval.x))) + dupsolver.dval.x[1] = 1.0 Enzyme.autodiff( Reverse, driver!, - Const(tsolver), - dx, + dupsolver, dA, db, Const(M),