396
396
397
397
# # HVP
398
398
399
- struct FastDifferentiationHVPPrep{E2,E2!} <: HVPPrep
399
+ struct FastDifferentiationHVPPrep{E2,E2!,E1 } <: HVPPrep
400
400
hvp_exe:: E2
401
401
hvp_exe!:: E2!
402
+ gradient_prep:: E1
402
403
end
403
404
404
405
function DI. prepare_hvp (f, :: AutoFastDifferentiation , x, tx:: NTuple )
@@ -409,7 +410,9 @@ function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple)
409
410
hv_vec_var, v_vec_var = hessian_times_v (y_var, x_vec_var)
410
411
hvp_exe = make_function (hv_vec_var, vcat (x_vec_var, v_vec_var); in_place= false )
411
412
hvp_exe! = make_function (hv_vec_var, vcat (x_vec_var, v_vec_var); in_place= true )
412
- return FastDifferentiationHVPPrep (hvp_exe, hvp_exe!)
413
+
414
+ gradient_prep = DI. prepare_gradient (f, backend, x)
415
+ return FastDifferentiationHVPPrep (hvp_exe, hvp_exe!, gradient_prep)
413
416
end
414
417
415
418
function DI. hvp (
@@ -439,6 +442,28 @@ function DI.hvp!(
439
442
return tg
440
443
end
441
444
445
+ function DI. gradient_and_hvp (
446
+ f, prep:: FastDifferentiationHVPPrep , backend:: AutoFastDifferentiation , x, tx:: NTuple
447
+ )
448
+ tg = DI. hvp (f, prep, backend, x, tx)
449
+ grad = DI. gradient (f, prep. gradient_prep, backend, x)
450
+ return grad, tg
451
+ end
452
+
453
+ function DI. hvp! (
454
+ f,
455
+ grad,
456
+ tg:: NTuple ,
457
+ prep:: FastDifferentiationHVPPrep ,
458
+ backend:: AutoFastDifferentiation ,
459
+ x,
460
+ tx:: NTuple ,
461
+ )
462
+ DI. hvp! (f, tg, prep, backend, x, tx)
463
+ DI. gradient! (f, grad, prep. gradient_prep, backend, x)
464
+ return grad, tg
465
+ end
466
+
442
467
# # Hessian
443
468
444
469
struct FastDifferentiationHessianPrep{G,E2,E2!} <: HessianPrep
0 commit comments