Skip to content

Commit

Permalink
WIP: fix KrylovJL_GMRES with Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Sep 29, 2024
1 parent 272808b commit 02a9cca
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ using LinearSolve.LinearAlgebra
using EnzymeCore
using EnzymeCore: EnzymeRules

@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.KrylovJL}) = true
@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.Krylov.GmresSolver}) = true

function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
Expand Down
6 changes: 4 additions & 2 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA),
@test db1 db12
@test db2 db22

#=

function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
Expand All @@ -167,12 +167,14 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES())
norm(s1 + s2)
end

dA = zeros(n, n);
db1 = zeros(n);
db2 = zeros(n);
Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

@test dA dA2 atol=5e-5
@test db1 db12
@test db2 db22
=#

A = rand(n, n);
dA = zeros(n, n);
Expand Down

0 comments on commit 02a9cca

Please sign in to comment.