Skip to content

Commit

Permalink
Add symbolic backends
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Oct 16, 2024
1 parent 1d5d6a0 commit e5dfd5c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,10 @@ end

## HVP

struct FastDifferentiationHVPPrep{E2,E2!} <: HVPPrep
struct FastDifferentiationHVPPrep{E2,E2!,E1} <: HVPPrep
hvp_exe::E2
hvp_exe!::E2!
gradient_prep::E1
end

function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple)
Expand All @@ -409,7 +410,9 @@ function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple)
hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var)
hvp_exe = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=false)
hvp_exe! = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=true)
return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!)

gradient_prep = DI.prepare_gradient(f, backend, x)
return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!, gradient_prep)
end

function DI.hvp(
Expand Down Expand Up @@ -439,6 +442,28 @@ function DI.hvp!(
return tg
end

function DI.gradient_and_hvp(
f, prep::FastDifferentiationHVPPrep, backend::AutoFastDifferentiation, x, tx::NTuple
)
tg = DI.hvp(f, prep, backend, x, tx)
grad = DI.gradient(f, prep.gradient_prep, backend, x)
return grad, tg
end

function DI.hvp!(
f,
grad,
tg::NTuple,
prep::FastDifferentiationHVPPrep,
backend::AutoFastDifferentiation,
x,
tx::NTuple,
)
DI.hvp!(f, tg, prep, backend, x, tx)
DI.gradient!(f, grad, prep.gradient_prep, backend, x)
return grad, tg
end

## Hessian

struct FastDifferentiationHessianPrep{G,E2,E2!} <: HessianPrep
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,22 @@ function DI.hvp!(
return tg
end

function DI.gradient_and_hvp(
f, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple
)
tg = DI.hvp(f, prep, backend, x, tx)
grad = DI.gradient(f, prep.gradient_prep, backend, x)
return grad, tg
end

function DI.gradient_and_hvp!(
f, grad, tg::NTuple, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple
)
DI.hvp!(f, tg, prep, backend, x, tx)
DI.gradient!(f, grad, prep.gradient_prep, backend, x)
return grad, tg
end

## Second derivative

struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: SecondDerivativePrep
Expand Down

0 comments on commit e5dfd5c

Please sign in to comment.