Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 23, 2023
1 parent 84c5196 commit bb93d68
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,44 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
else
(func.val(dval, alg.val; kwargs...) for dval in prob.dval)

Check warning on line 16 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L16

Added line #L16 was not covered by tests
end
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing)
d_A = if EnzymeRules.width(config) == 1
dres.A

Check warning on line 19 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L18-L19

Added lines #L18 - L19 were not covered by tests
else
(dval.A for dval in dres)

Check warning on line 21 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L21

Added line #L21 was not covered by tests
end
d_b = if EnzymeRules.width(config) == 1
dres.b

Check warning on line 24 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
else
(dval.b for dval in dres)

Check warning on line 26 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L26

Added line #L26 was not covered by tests
end
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b))

Check warning on line 28 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L28

Added line #L28 was not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
d_A, d_b = cache

Check warning on line 32 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L31-L32

Added lines #L31 - L32 were not covered by tests

if EnzymeRules.width(config) == 1
if d_A !== prob.dval.A
prob.dval.A .+= d_A
d_A .= 0

Check warning on line 37 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L34-L37

Added lines #L34 - L37 were not covered by tests
end
if d_b !== prob.dval.b
prob.dval.b .+= d_b
d_b .= 0

Check warning on line 41 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L39-L41

Added lines #L39 - L41 were not covered by tests
end
else
for i in 1:EnzymeRules.width(config)
if d_A !== prob.dval.A
prob.dval.A[i] .+= d_A[i]
d_A[i] .= 0

Check warning on line 47 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L44-L47

Added lines #L44 - L47 were not covered by tests
end
if d_b !== prob.dval.b
prob.dval.b[i] .+= d_b[i]
d_b[i] .= 0

Check warning on line 51 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L49-L51

Added lines #L49 - L51 were not covered by tests
end
end

Check warning on line 53 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L53

Added line #L53 was not covered by tests
end

return (nothing, nothing)

Check warning on line 56 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L56

Added line #L56 was not covered by tests
end

Expand Down

0 comments on commit bb93d68

Please sign in to comment.