From 92898c5c87bee170b858a24e2cfe88e9beb013fd Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 9 Oct 2024 19:52:31 -0400 Subject: [PATCH 1/2] Special case Enzyme usage --- Project.toml | 3 +++ src/internal/jacobian.jl | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1651aab2a..dfc4201f7 100644 --- a/Project.toml +++ b/Project.toml @@ -6,9 +6,11 @@ version = "3.15.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -32,6 +34,7 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" diff --git a/src/internal/jacobian.jl b/src/internal/jacobian.jl index b78eb7383..50c6dfc51 100644 --- a/src/internal/jacobian.jl +++ b/src/internal/jacobian.jl @@ -1,3 +1,5 @@ +using Enzyme + """ JacobianCache(prob, alg, f::F, fu, u, p; autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing, linsolve = missing) where {F} @@ -88,7 +90,11 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, if iip DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p)) else - DI.jacobian(f, di_extras, autodiff, u, Constant(p)) + if autodiff <: AutoEnzyme() + hcat(Enzyme.autodiff(Forward, f, BatchDuplicated(u, Enzyme.onehot(u)), Const(p))[1]...) + else + DI.jacobian(f, di_extras, autodiff, u, Constant(p)) + end end end else From 44aa1be5c1e9133542fed490b586dee3dbd9cb40 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 9 Oct 2024 20:02:02 -0400 Subject: [PATCH 2/2] fix --- src/internal/jacobian.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/internal/jacobian.jl b/src/internal/jacobian.jl index 50c6dfc51..aa1a3c83b 100644 --- a/src/internal/jacobian.jl +++ b/src/internal/jacobian.jl @@ -63,7 +63,9 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, if !has_analytic_jac && needs_jac autodiff = construct_concrete_adtype(f, autodiff) - di_extras = if iip + di_extras = if !iip && autodiff isa AutoEnzyme + Enzyme.onehot(u) + elseif iip DI.prepare_jacobian(f, fu, autodiff, u, Constant(prob.p)) else DI.prepare_jacobian(f, autodiff, u, Constant(prob.p)) @@ -90,8 +92,8 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, if iip DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p)) else - if autodiff <: AutoEnzyme() - hcat(Enzyme.autodiff(Forward, f, BatchDuplicated(u, Enzyme.onehot(u)), Const(p))[1]...) + if autodiff isa AutoEnzyme + hcat(Enzyme.autodiff(Forward, f, BatchDuplicated(u, di_extras), Const(p))[1]...) else DI.jacobian(f, di_extras, autodiff, u, Constant(p)) end @@ -159,6 +161,8 @@ function (cache::JacobianCache{iip})( else if SciMLBase.has_jac(cache.f) return cache.f.jac(u, p) + elseif cache.autodiff isa AutoEnzyme + hcat(Enzyme.autodiff(Forward, cache.f, BatchDuplicated(u, cache.di_extras), Const(p))[1]...) else return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) end