Skip to content

Commit 0d8a47a

Browse files
committed
fix: forward rules aliasing issue
1 parent f5aefbe commit 0d8a47a

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,15 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
1717
return nothing
1818
end
1919
end
20+
2021
dres = func.val(prob.dval, alg.val; kwargs...)
21-
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
22-
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
22+
23+
if dres.b == res.b
24+
dres.b .= false
25+
end
26+
if dres.A == res.A
27+
dres.A .= false
28+
end
2329

2430
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
2531
return Duplicated(res, dres)
@@ -50,14 +56,12 @@ function EnzymeRules.forward(
5056
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
5157
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
5258
end
53-
b = deepcopy(linsolve.val.b)
5459

55-
db = linsolve.dval.b
56-
dA = linsolve.dval.A
60+
res = deepcopy(res) # Without this copy, the next solve will end up mutating the result
5761

58-
linsolve.val.b = db - dA * res.u
62+
b = linsolve.val.b
63+
linsolve.val.b = linsolve.dval.b - linsolve.dval.A * res.u
5964
dres = func.val(linsolve.val; kwargs...)
60-
6165
linsolve.val.b = b
6266

6367
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)

test/enzyme.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ end
209209
en_jac = map(onehot(A)) do dA
210210
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
211211
Duplicated(A, dA), Const(b1), Const(alg)))
212-
end |> collect |> (x -> reshape(x, n, n))
212+
end |> collect
213213
@show en_jac
214214

215215
@test en_jacfd_jac rtol=1e-4

0 commit comments

Comments
 (0)