From e5dfd5c95e15e130a601dcce84bef2a920aef807 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 16 Oct 2024 18:46:52 +0200 Subject: [PATCH] Add symbolic backends --- .../onearg.jl | 29 +++++++++++++++++-- .../onearg.jl | 16 ++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index b9ce86f32..c2b4cf1f5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -396,9 +396,10 @@ end ## HVP -struct FastDifferentiationHVPPrep{E2,E2!} <: HVPPrep +struct FastDifferentiationHVPPrep{E2,E2!,E1} <: HVPPrep hvp_exe::E2 hvp_exe!::E2! + gradient_prep::E1 end function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple) @@ -409,7 +410,9 @@ function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple) hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var) hvp_exe = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=false) hvp_exe! = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=true) - return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!) + + gradient_prep = DI.prepare_gradient(f, backend, x) + return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!, gradient_prep) end function DI.hvp( @@ -439,6 +442,28 @@ function DI.hvp!( return tg end +function DI.gradient_and_hvp( + f, prep::FastDifferentiationHVPPrep, backend::AutoFastDifferentiation, x, tx::NTuple +) + tg = DI.hvp(f, prep, backend, x, tx) + grad = DI.gradient(f, prep.gradient_prep, backend, x) + return grad, tg +end + +function DI.hvp!( + f, + grad, + tg::NTuple, + prep::FastDifferentiationHVPPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, +) + DI.hvp!(f, tg, prep, backend, x, tx) + DI.gradient!(f, grad, prep.gradient_prep, backend, x) + return grad, tg +end + ## Hessian struct FastDifferentiationHessianPrep{G,E2,E2!} <: HessianPrep diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 60c222310..2ad158a1c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -317,6 +317,22 @@ function DI.hvp!( return tg end +function DI.gradient_and_hvp( + f, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple +) + tg = DI.hvp(f, prep, backend, x, tx) + grad = DI.gradient(f, prep.gradient_prep, backend, x) + return grad, tg +end + +function DI.gradient_and_hvp!( + f, grad, tg::NTuple, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple +) + DI.hvp!(f, tg, prep, backend, x, tx) + DI.gradient!(f, grad, prep.gradient_prep, backend, x) + return grad, tg +end + ## Second derivative struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: SecondDerivativePrep