Skip to content

Commit e5dfd5c

Browse files
committed
Add symbolic backends
1 parent 1d5d6a0 commit e5dfd5c

File tree

2 files changed

+43
-2
lines changed
  • DifferentiationInterface/ext
    • DifferentiationInterfaceFastDifferentiationExt
    • DifferentiationInterfaceSymbolicsExt

2 files changed

+43
-2
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,10 @@ end
396396

397397
## HVP
398398

399-
struct FastDifferentiationHVPPrep{E2,E2!} <: HVPPrep
399+
struct FastDifferentiationHVPPrep{E2,E2!,E1} <: HVPPrep
400400
hvp_exe::E2
401401
hvp_exe!::E2!
402+
gradient_prep::E1
402403
end
403404

404405
function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple)
@@ -409,7 +410,9 @@ function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple)
409410
hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var)
410411
hvp_exe = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=false)
411412
hvp_exe! = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=true)
412-
return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!)
413+
414+
gradient_prep = DI.prepare_gradient(f, backend, x)
415+
return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!, gradient_prep)
413416
end
414417

415418
function DI.hvp(
@@ -439,6 +442,28 @@ function DI.hvp!(
439442
return tg
440443
end
441444

445+
function DI.gradient_and_hvp(
446+
f, prep::FastDifferentiationHVPPrep, backend::AutoFastDifferentiation, x, tx::NTuple
447+
)
448+
tg = DI.hvp(f, prep, backend, x, tx)
449+
grad = DI.gradient(f, prep.gradient_prep, backend, x)
450+
return grad, tg
451+
end
452+
453+
function DI.hvp!(
454+
f,
455+
grad,
456+
tg::NTuple,
457+
prep::FastDifferentiationHVPPrep,
458+
backend::AutoFastDifferentiation,
459+
x,
460+
tx::NTuple,
461+
)
462+
DI.hvp!(f, tg, prep, backend, x, tx)
463+
DI.gradient!(f, grad, prep.gradient_prep, backend, x)
464+
return grad, tg
465+
end
466+
442467
## Hessian
443468

444469
struct FastDifferentiationHessianPrep{G,E2,E2!} <: HessianPrep

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,22 @@ function DI.hvp!(
317317
return tg
318318
end
319319

320+
function DI.gradient_and_hvp(
321+
f, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple
322+
)
323+
tg = DI.hvp(f, prep, backend, x, tx)
324+
grad = DI.gradient(f, prep.gradient_prep, backend, x)
325+
return grad, tg
326+
end
327+
328+
function DI.gradient_and_hvp!(
329+
f, grad, tg::NTuple, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple
330+
)
331+
DI.hvp!(f, tg, prep, backend, x, tx)
332+
DI.gradient!(f, grad, prep.gradient_prep, backend, x)
333+
return grad, tg
334+
end
335+
320336
## Second derivative
321337

322338
struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: SecondDerivativePrep

0 commit comments

Comments
 (0)