Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 10, 2024
1 parent 92898c5 commit 44aa1be
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/internal/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,

if !has_analytic_jac && needs_jac
autodiff = construct_concrete_adtype(f, autodiff)
di_extras = if iip
di_extras = if !iip && autodiff isa AutoEnzyme
Enzyme.onehot(u)
elseif iip
DI.prepare_jacobian(f, fu, autodiff, u, Constant(prob.p))
else
DI.prepare_jacobian(f, autodiff, u, Constant(prob.p))
Expand All @@ -90,8 +92,8 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
if iip
DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p))
else
if autodiff <: AutoEnzyme()
hcat(Enzyme.autodiff(Forward, f, BatchDuplicated(u, Enzyme.onehot(u)), Const(p))[1]...)
if autodiff isa AutoEnzyme
hcat(Enzyme.autodiff(Forward, f, BatchDuplicated(u, di_extras), Const(p))[1]...)
else
DI.jacobian(f, di_extras, autodiff, u, Constant(p))
end
Expand Down Expand Up @@ -159,6 +161,8 @@ function (cache::JacobianCache{iip})(
else
if SciMLBase.has_jac(cache.f)
return cache.f.jac(u, p)
elseif cache.autodiff isa AutoEnzyme
hcat(Enzyme.autodiff(Forward, cache.f, BatchDuplicated(u, cache.di_extras), Const(p))[1]...)
else
return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
end
Expand Down

0 comments on commit 44aa1be

Please sign in to comment.