Skip to content

Commit

Permalink
Typeof tag
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Jul 28, 2024
1 parent 87c2a39 commit b520af0
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end
function DI.prepare_hvp(
f::F, backend::SecondOrder{<:AutoForwardDiff,<:AutoReverseDiff}, x, dx
) where {F}
T = _hvp_tag(f, outer(backend), x)
T = typeof(_hvp_tag(f, outer(backend), x))
xdual = DIForwardDiffExt.make_dual(T, x, dx)
tape = ReverseDiff.GradientTape(f, xdual)
if inner(backend) isa AutoReverseDiff{true}
Expand All @@ -57,7 +57,7 @@ function DI.hvp(
extras::ForwardDiffOverReverseDiffHVPExtras,
) where {F}
@compat (; inner_gradient) = extras
T = _hvp_tag(f, outer(backend), x)
T = typeof(_hvp_tag(f, outer(backend), x))
xdual = DIForwardDiffExt.make_dual(T, x, dx)
ydual = inner_gradient(xdual)
return DIForwardDiffExt.myderivative(T, ydual)
Expand All @@ -72,7 +72,7 @@ function DI.hvp!(
extras::ForwardDiffOverReverseDiffHVPExtras,
) where {F}
@compat (; inner_gradient) = extras
T = _hvp_tag(f, outer(backend), x)
T = typeof(_hvp_tag(f, outer(backend), x))
xdual = DIForwardDiffExt.make_dual(T, x, dx)
ydual = inner_gradient(xdual)
DIForwardDiffExt.myderivative!(T, dg, ydual)
Expand Down

0 comments on commit b520af0

Please sign in to comment.