Skip to content

Commit

Permalink
Da fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Oct 16, 2024
1 parent 2deb5da commit 1d5d6a0
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function DI.prepare_hvp(
T = tag_type(f, tagged_outer_backend, x)
xdual = make_dual(T, x, tx)
gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
# TODO: get rid of closure?
function inner_gradient(x, unannotated_contexts...)
annotated_contexts = rewrap(unannotated_contexts...)
return DI.gradient(f, gradient_prep, inner(backend), x, annotated_contexts...)
Expand Down Expand Up @@ -77,27 +78,30 @@ end
function DI.gradient_and_hvp(
f::F,
prep::ForwardDiffOverSomethingHVPPrep,
backend::SecondOrder{<:AutoForwardDiff},
::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
tg = DI.hvp(f, prep, backend, x, tx, contexts...)
grad = DI.gradient(f, inner(backend), x, tx, contexts...) # TODO: optimize
return grad, tg
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
return DI.value_and_pushforward(
inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
)
end

function DI.gradient_and_hvp!(
f::F,
grad,
tg::NTuple,
prep::ForwardDiffOverSomethingHVPPrep,
backend::SecondOrder{<:AutoForwardDiff},
::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
DI.hvp!(f, tg, prep, backend, x, tx, contexts...)
DI.gradient(f, grad, inner(backend), x, tx, contexts...) # TODO: optimize
return grad, tg
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
new_grad, _ = DI.value_and_pushforward!(
inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
)
return copyto!(grad, new_grad), tg
end
Loading

0 comments on commit 1d5d6a0

Please sign in to comment.