From 02a9cca82f7098c4c26a6a4258523f24dad7ea0c Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 24 Sep 2023 21:36:43 -0400 Subject: [PATCH] WIP: fix KrylovJL_GMRES with Enzyme --- ext/LinearSolveEnzymeExt.jl | 3 +++ test/enzyme.jl | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index abd2232e..70aadc40 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -5,6 +5,9 @@ using LinearSolve.LinearAlgebra using EnzymeCore using EnzymeCore: EnzymeRules +@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.KrylovJL}) = true +@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.Krylov.GmresSolver}) = true + function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} diff --git a/test/enzyme.jl b/test/enzyme.jl index b09c0de5..cd638996 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -157,7 +157,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), @test db1 ≈ db12 @test db2 ≈ db22 -#= + function f3(A, b1, b2; alg = KrylovJL_GMRES()) prob = LinearProblem(A, b1) cache = init(prob, alg) @@ -167,12 +167,14 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES()) norm(s1 + s2) end +dA = zeros(n, n); +db1 = zeros(n); +db2 = zeros(n); Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12 @test db2 ≈ db22 -=# A = rand(n, n); dA = zeros(n, n);