Skip to content

Commit

Permalink
Special case Enzyme usage
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 9, 2024
1 parent e7e218b commit 92898c5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ version = "3.15.1"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Expand All @@ -32,6 +34,7 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Expand Down
8 changes: 7 additions & 1 deletion src/internal/jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Enzyme

"""
JacobianCache(prob, alg, f::F, fu, u, p; autodiff = nothing,
vjp_autodiff = nothing, jvp_autodiff = nothing, linsolve = missing) where {F}
Expand Down Expand Up @@ -88,7 +90,11 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
if iip
DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p))
else
DI.jacobian(f, di_extras, autodiff, u, Constant(p))
if autodiff <: AutoEnzyme()
hcat(Enzyme.autodiff(Forward, f, BatchDuplicated(u, Enzyme.onehot(u)), Const(p))[1]...)
else
DI.jacobian(f, di_extras, autodiff, u, Constant(p))
end
end
end
else
Expand Down

0 comments on commit 92898c5

Please sign in to comment.