From c1f8df6b2718c777635b2e36c5465c6f96cc5888 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Wed, 13 Dec 2023 14:27:17 -0600 Subject: [PATCH] EnzymeRules (#3) Support for EnzymeRules --- Project.toml | 1 + README.md | 37 ++++++++- src/DiffKrylov.jl | 5 +- src/EnzymeRules/enzymerules.jl | 82 +++++++++++++++++++ src/ForwardDiff/{krylov.jl => forwarddiff.jl} | 0 test/enzymediff.jl | 64 +++++++++++++++ test/runtests.jl | 3 + 7 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 src/EnzymeRules/enzymerules.jl rename src/ForwardDiff/{krylov.jl => forwarddiff.jl} (100%) create mode 100644 test/enzymediff.jl diff --git a/Project.toml b/Project.toml index e20af54..7082957 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Michel Schanen "] version = "0.1.0" [deps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" diff --git a/README.md b/README.md index 8150225..7abeea9 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,42 @@ # DiffKrylov [![][build-latest-img]][build-url] [![][codecov-latest-img]][codecov-latest-url] -DiffKrylov provides a differentiable API for [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) using [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). In the future, we will also support Enzyme and its reverse mode. +DiffKrylov provides a differentiable API for +[Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) using +[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and +[Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl). This is a work in progress and +eventually should enable numerical comparisons between discrete and continuous +tangent and adjoint methods (see this +[report](http://137.226.34.227/Publications/AIB/2012/2012-10.pdf)). + +## Current Technical Limitations + +* Only supports `gmres`, `cg`, and `bicgstab` methods +* No support for inplace methods `gmres!`, `cg!`, and `bicgstab!` +* No support for options when using Enzyme +* No support for sparse matrices using Enzyme +* No support for linear operators + +## Current Open Questions +* How to handle preconditioners? +* How to set the options for the tangent/adjoint solve based on the options for the forward solve? For example `bicgtab` may return `NaN` for the tangents or adjoints. + +## Installation + +```julia +] add DiffKrylov +``` + +## Usage + +Using ForwardDiff.jl, we can compute the Jacobian of `x` with respect to `b` using the ForwardDiff.jl API: + +```julia +using ForwardDiff, DiffKrylov, Krylov, Random +A = rand(64,64) +b = rand(64) +J = ForwardDiff.jacobian(x -> gmres(A, x)[1], b) +``` [codecov-latest-img]: https://codecov.io/gh/JuliaSmoothOptimizers/DiffKrylov.jl/branch/main/graphs/badge.svg?branch=main [codecov-latest-url]: https://codecov.io/github/JuliaSmoothOptimizers/DiffKrylov.jl?branch=main diff --git a/src/DiffKrylov.jl b/src/DiffKrylov.jl index ee69a4f..bef9a09 100644 --- a/src/DiffKrylov.jl +++ b/src/DiffKrylov.jl @@ -2,5 +2,6 @@ module DiffKrylov using Krylov using SparseArrays -include("ForwardDiff/krylov.jl") -end \ No newline at end of file +include("ForwardDiff/forwarddiff.jl") +include("EnzymeRules/enzymerules.jl") +end diff --git a/src/EnzymeRules/enzymerules.jl b/src/EnzymeRules/enzymerules.jl new file mode 100644 index 0000000..fe151ab --- /dev/null +++ b/src/EnzymeRules/enzymerules.jl @@ -0,0 +1,82 @@ +using Enzyme +import .EnzymeRules: forward, reverse, augmented_primal +using .EnzymeRules + +for solver in (:cg, :bicgstab, :gmres) + @eval begin + function forward( + func::Const{typeof(Krylov.$solver)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + _A::Union{Const, Duplicated}, + _b::Union{Const, Duplicated}; + options... + ) + A = _A.val + b = _b.val + dx = [] + x, stats = Krylov.$solver(A,b; options...) + if isa(_A, Duplicated) && isa(_b, Duplicated) + dA = _A.dval + db = _b.dval + db -= dA*x + dx, dstats = Krylov.$solver(A,db; options...) + elseif isa(_A, Duplicated) && isa(_b, Const) + dA = _A.dval + db = -dA*x + dx, dstats = Krylov.$solver(A,db; options...) + elseif isa(_A, Const) && isa(_b, Duplicated) + db = _b.dval + dx, dstats = Krylov.$solver(A,db; options...) + elseif isa(_A, Const) && isa(_b, Const) + nothing + else + error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))") + end + + if RT <: Const + return (x, stats) + elseif RT <: DuplicatedNoNeed + return (dx, stats) + else + return Duplicated((x, stats), (dx, dstats)) + end + end + end +end + +export forward + +function augmented_primal( + config, + func::Const{typeof(Krylov.gmres)}, + ret::Type{<:Duplicated}, + _A::Union{Const, Duplicated}, + _b::Union{Const, Duplicated} +) + A = _A.val + b = _b.val + x, stats = Krylov.gmres(A,b) + bx = zeros(length(x)) + bstats = deepcopy(stats) + if needs_primal(config) + return AugmentedReturn((x, stats), (bx, bstats), (A,x, Ref(bx))) + else + return AugmentedReturn(nothing, (bx, bstats), (A,x)) + end +end + +function reverse( + config, + ::Const{typeof(Krylov.gmres)}, + dret, + tape, + _A, + _b +) + (A,x,bx) = tape + _b.dval .= gmres(transpose(A), bx[])[1] + _A.dval .= -x .* _b.dval' + return (nothing, nothing) +end + +export augmented_primal, reverse diff --git a/src/ForwardDiff/krylov.jl b/src/ForwardDiff/forwarddiff.jl similarity index 100% rename from src/ForwardDiff/krylov.jl rename to src/ForwardDiff/forwarddiff.jl diff --git a/test/enzymediff.jl b/test/enzymediff.jl new file mode 100644 index 0000000..be91bb8 --- /dev/null +++ b/test/enzymediff.jl @@ -0,0 +1,64 @@ +using Enzyme +import .EnzymeRules: forward, reverse, augmented_primal +using .EnzymeRules + +include("get_div_grad.jl") +include("utils.jl") + +@testset "$solver" for solver = (Krylov.cg, Krylov.gmres, Krylov.bicgstab) + function sparse_laplacian(n :: Int=16; FC=Float64) + A = get_div_grad(n, n, n) + b = ones(n^3) + return A, b + end + + A, b = sparse_laplacian(4, FC=Float64) + denseA = Matrix(A) + fdm = central_fdm(8, 1); + function A_one_one(x) + _A = copy(denseA) + _A[1,1] = x + solver(_A,b) + end + + function b_one(x) + _b = copy(b) + _b[1] = x + solver(denseA,_b) + end + + fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a)[1], copy(denseA[1,1])) + fdb = FiniteDifferences.jacobian(fdm, a -> b_one(a)[1], copy(b[1])) + fd =fda[1] + fdb[1] + # Test forward + ddA = Duplicated(denseA, zeros(size(denseA))) + ddb = Duplicated(b, zeros(length(b))) + ddA.dval[1,1] = 1.0 + ddb.dval[1] = 1.0 + ddx = Enzyme.autodiff( + Forward, + solver, + ddA, + ddb + ) + @test isapprox(ddx[1][1], fd, atol=1e-4, rtol=1e-4) + # Test reverse + function driver!(x, A, b) + x .= gmres(A,b)[1] + nothing + end + ddA = Duplicated(denseA, zeros(size(denseA))) + ddb = Duplicated(b, zeros(length(b))) + ddx = Duplicated(zeros(length(b)), zeros(length(b))) + ddx.dval[1] = 1.0 + Enzyme.autodiff( + Reverse, + driver!, + ddx, + ddA, + ddb + ) + + @test isapprox(ddb.dval[1], fdb[1][1], atol=1e-4, rtol=1e-4) + @test isapprox(ddA.dval[1,1], fda[1][1], atol=1e-4, rtol=1e-4) +end diff --git a/test/runtests.jl b/test/runtests.jl index e9ee66b..ddcdbbb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,4 +13,7 @@ rtol = 0.0 @testset "ForwardDiff" begin include("forwarddiff.jl") end + @testset "Enzyme" begin + include("enzymediff.jl") + end end