Skip to content

Commit

Permalink
EnzymeRules (#3)
Browse files Browse the repository at this point in the history
Support for EnzymeRules
  • Loading branch information
michel2323 authored Dec 13, 2023
1 parent b97b3eb commit c1f8df6
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 3 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Michel Schanen <[email protected]>"]
version = "0.1.0"

[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Expand Down
37 changes: 36 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,42 @@
# DiffKrylov
[![][build-latest-img]][build-url] [![][codecov-latest-img]][codecov-latest-url]

DiffKrylov provides a differentiable API for [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) using [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). In the future, we will also support Enzyme and its reverse mode.
DiffKrylov provides a differentiable API for
[Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) using
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and
[Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl). This is a work in progress and
eventually should enable numerical comparisons between discrete and continuous
tangent and adjoint methods (see this
[report](http://137.226.34.227/Publications/AIB/2012/2012-10.pdf)).

## Current Technical Limitations

* Only supports `gmres`, `cg`, and `bicgstab` methods
* No support for inplace methods `gmres!`, `cg!`, and `bicgstab!`
* No support for options when using Enzyme
* No support for sparse matrices using Enzyme
* No support for linear operators

## Current Open Questions
* How to handle preconditioners?
* How to set the options for the tangent/adjoint solve based on the options for the forward solve? For example `bicgtab` may return `NaN` for the tangents or adjoints.

## Installation

```julia
] add DiffKrylov
```

## Usage

Using ForwardDiff.jl, we can compute the Jacobian of `x` with respect to `b` using the ForwardDiff.jl API:

```julia
using ForwardDiff, DiffKrylov, Krylov, Random
A = rand(64,64)
b = rand(64)
J = ForwardDiff.jacobian(x -> gmres(A, x)[1], b)
```

[codecov-latest-img]: https://codecov.io/gh/JuliaSmoothOptimizers/DiffKrylov.jl/branch/main/graphs/badge.svg?branch=main
[codecov-latest-url]: https://codecov.io/github/JuliaSmoothOptimizers/DiffKrylov.jl?branch=main
Expand Down
5 changes: 3 additions & 2 deletions src/DiffKrylov.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ module DiffKrylov

using Krylov
using SparseArrays
include("ForwardDiff/krylov.jl")
end
include("ForwardDiff/forwarddiff.jl")
include("EnzymeRules/enzymerules.jl")
end
82 changes: 82 additions & 0 deletions src/EnzymeRules/enzymerules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
using Enzyme
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

for solver in (:cg, :bicgstab, :gmres)
@eval begin
function forward(
func::Const{typeof(Krylov.$solver)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
_A::Union{Const, Duplicated},
_b::Union{Const, Duplicated};
options...
)
A = _A.val
b = _b.val
dx = []
x, stats = Krylov.$solver(A,b; options...)
if isa(_A, Duplicated) && isa(_b, Duplicated)
dA = _A.dval
db = _b.dval
db -= dA*x
dx, dstats = Krylov.$solver(A,db; options...)
elseif isa(_A, Duplicated) && isa(_b, Const)
dA = _A.dval
db = -dA*x
dx, dstats = Krylov.$solver(A,db; options...)
elseif isa(_A, Const) && isa(_b, Duplicated)
db = _b.dval
dx, dstats = Krylov.$solver(A,db; options...)
elseif isa(_A, Const) && isa(_b, Const)
nothing
else
error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))")
end

if RT <: Const
return (x, stats)
elseif RT <: DuplicatedNoNeed
return (dx, stats)
else
return Duplicated((x, stats), (dx, dstats))
end
end
end
end

export forward

function augmented_primal(
config,
func::Const{typeof(Krylov.gmres)},
ret::Type{<:Duplicated},
_A::Union{Const, Duplicated},
_b::Union{Const, Duplicated}
)
A = _A.val
b = _b.val
x, stats = Krylov.gmres(A,b)
bx = zeros(length(x))
bstats = deepcopy(stats)
if needs_primal(config)
return AugmentedReturn((x, stats), (bx, bstats), (A,x, Ref(bx)))
else
return AugmentedReturn(nothing, (bx, bstats), (A,x))
end
end

function reverse(
config,
::Const{typeof(Krylov.gmres)},
dret,
tape,
_A,
_b
)
(A,x,bx) = tape
_b.dval .= gmres(transpose(A), bx[])[1]
_A.dval .= -x .* _b.dval'
return (nothing, nothing)
end

export augmented_primal, reverse
File renamed without changes.
64 changes: 64 additions & 0 deletions test/enzymediff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using Enzyme
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

include("get_div_grad.jl")
include("utils.jl")

@testset "$solver" for solver = (Krylov.cg, Krylov.gmres, Krylov.bicgstab)
function sparse_laplacian(n :: Int=16; FC=Float64)
A = get_div_grad(n, n, n)
b = ones(n^3)
return A, b
end

A, b = sparse_laplacian(4, FC=Float64)
denseA = Matrix(A)
fdm = central_fdm(8, 1);
function A_one_one(x)
_A = copy(denseA)
_A[1,1] = x
solver(_A,b)
end

function b_one(x)
_b = copy(b)
_b[1] = x
solver(denseA,_b)
end

fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a)[1], copy(denseA[1,1]))
fdb = FiniteDifferences.jacobian(fdm, a -> b_one(a)[1], copy(b[1]))
fd =fda[1] + fdb[1]
# Test forward
ddA = Duplicated(denseA, zeros(size(denseA)))
ddb = Duplicated(b, zeros(length(b)))
ddA.dval[1,1] = 1.0
ddb.dval[1] = 1.0
ddx = Enzyme.autodiff(
Forward,
solver,
ddA,
ddb
)
@test isapprox(ddx[1][1], fd, atol=1e-4, rtol=1e-4)
# Test reverse
function driver!(x, A, b)
x .= gmres(A,b)[1]
nothing
end
ddA = Duplicated(denseA, zeros(size(denseA)))
ddb = Duplicated(b, zeros(length(b)))
ddx = Duplicated(zeros(length(b)), zeros(length(b)))
ddx.dval[1] = 1.0
Enzyme.autodiff(
Reverse,
driver!,
ddx,
ddA,
ddb
)

@test isapprox(ddb.dval[1], fdb[1][1], atol=1e-4, rtol=1e-4)
@test isapprox(ddA.dval[1,1], fda[1][1], atol=1e-4, rtol=1e-4)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,7 @@ rtol = 0.0
@testset "ForwardDiff" begin
include("forwarddiff.jl")
end
@testset "Enzyme" begin
include("enzymediff.jl")
end
end

0 comments on commit c1f8df6

Please sign in to comment.