Skip to content

Commit

Permalink
fix: forward rules aliasing issue
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 26, 2024
1 parent f5aefbe commit 0d8a47a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
18 changes: 11 additions & 7 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
return nothing
end
end

dres = func.val(prob.dval, alg.val; kwargs...)
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A

if dres.b == res.b
dres.b .= false
end
if dres.A == res.A
dres.A .= false
end

if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(res, dres)
Expand Down Expand Up @@ -50,14 +56,12 @@ function EnzymeRules.forward(
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
end
b = deepcopy(linsolve.val.b)

db = linsolve.dval.b
dA = linsolve.dval.A
res = deepcopy(res) # Without this copy, the next solve will end up mutating the result

linsolve.val.b = db - dA * res.u
b = linsolve.val.b
linsolve.val.b = linsolve.dval.b - linsolve.dval.A * res.u
dres = func.val(linsolve.val; kwargs...)

linsolve.val.b = b

if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
Expand Down
2 changes: 1 addition & 1 deletion test/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ end
en_jac = map(onehot(A)) do dA
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
Duplicated(A, dA), Const(b1), Const(alg)))
end |> collect |> (x -> reshape(x, n, n))
end |> collect
@show en_jac

@test en_jacfd_jac rtol=1e-4
Expand Down

0 comments on commit 0d8a47a

Please sign in to comment.