Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve ForwardDiff tagging for HVP #596

Merged
merged 5 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.17"
version = "0.6.18"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ using DifferentiationInterface:
BatchSizeSettings,
Cache,
Constant,
PrepContext,
Context,
FixTail,
DerivativePrep,
DifferentiateWith,
GradientPrep,
Expand All @@ -21,6 +23,7 @@ using DifferentiationInterface:
SecondOrder,
inner,
outer,
shuffled_gradient,
unwrap,
with_contexts
import ForwardDiff.DiffResults as DR
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,6 @@
# Thin wrapper storing a function `f`. An instance of it is what gets passed to
# `ForwardDiff.Tag` in `tag_backend_hvp`, so the tag type is derived from `F`.
struct ForwardDiffOverSomethingHVPWrapper{F}
f::F
end

"""
    tag_backend_hvp(f, ::AutoForwardDiff, x)

Return an `AutoForwardDiff` backend with a fixed tag linked to `f`, so that we know how to prepare the inner gradient of the HVP without depending on what that gradient closure looks like.

This generic method returns `backend` unchanged; the method specialized on a `Nothing` tag parameter is the one that constructs a new tagged backend.
"""
tag_backend_hvp(f, backend::AutoForwardDiff, x) = backend

# Untagged backend (tag parameter is `Nothing`): build a `ForwardDiff.Tag` from
# `f` wrapped in `ForwardDiffOverSomethingHVPWrapper` and from `eltype(x)`, then
# return a backend with the same `chunksize` that carries this tag.
function tag_backend_hvp(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize}
    wrapped_f = ForwardDiffOverSomethingHVPWrapper(f)
    hvp_tag = ForwardDiff.Tag(wrapped_f, eltype(x))
    TaggedBackend = AutoForwardDiff{chunksize,typeof(hvp_tag)}
    return TaggedBackend(hvp_tag)
end

struct ForwardDiffOverSomethingHVPPrep{B<:AutoForwardDiff,G,E<:PushforwardPrep} <: HVPPrep
tagged_outer_backend::B
inner_gradient::G
outer_pushforward_prep::E
struct ForwardDiffOverSomethingHVPPrep{E1<:GradientPrep,E2<:PushforwardPrep} <: HVPPrep
inner_gradient_prep::E1
outer_pushforward_prep::E2
end

function DI.prepare_hvp(
Expand All @@ -27,65 +10,94 @@ function DI.prepare_hvp(
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
rewrap = Rewrap(contexts...)
tagged_outer_backend = tag_backend_hvp(f, outer(backend), x)
T = tag_type(f, tagged_outer_backend, x)
T = tag_type(shuffled_gradient, outer(backend), x)
xdual = make_dual(T, x, tx)
gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
# TODO: get rid of closure?
function inner_gradient(x, unannotated_contexts...)
annotated_contexts = rewrap(unannotated_contexts...)
return DI.gradient(f, gradient_prep, inner(backend), x, annotated_contexts...)
end
outer_pushforward_prep = DI.prepare_pushforward(
inner_gradient, tagged_outer_backend, x, tx, contexts...
inner_gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
return ForwardDiffOverSomethingHVPPrep(
tagged_outer_backend, inner_gradient, outer_pushforward_prep
outer_pushforward_prep = DI.prepare_pushforward(
shuffled_gradient, outer(backend), x, tx, new_contexts...
)
return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep)
end

function DI.hvp(
f::F,
prep::ForwardDiffOverSomethingHVPPrep,
::SecondOrder{<:AutoForwardDiff},
backend::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
(; inner_gradient_prep, outer_pushforward_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
return DI.pushforward(
inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts...
)
end

function DI.hvp!(
f::F,
tg::NTuple,
prep::ForwardDiffOverSomethingHVPPrep,
::SecondOrder{<:AutoForwardDiff},
backend::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
DI.pushforward!(
inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
(; inner_gradient_prep, outer_pushforward_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
return DI.pushforward!(
shuffled_gradient,
tg,
outer_pushforward_prep,
outer(backend),
x,
tx,
new_contexts...,
)
return tg
end

function DI.gradient_and_hvp(
f::F,
prep::ForwardDiffOverSomethingHVPPrep,
::SecondOrder{<:AutoForwardDiff},
backend::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
(; inner_gradient_prep, outer_pushforward_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
return DI.value_and_pushforward(
inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts...
)
end

Expand All @@ -94,14 +106,28 @@ function DI.gradient_and_hvp!(
grad,
tg::NTuple,
prep::ForwardDiffOverSomethingHVPPrep,
::SecondOrder{<:AutoForwardDiff},
backend::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
(; inner_gradient_prep, outer_pushforward_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
new_grad, _ = DI.value_and_pushforward!(
inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
shuffled_gradient,
tg,
outer_pushforward_prep,
outer(backend),
x,
tx,
new_contexts...,
)
return copyto!(grad, new_grad), tg
end
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B}
end

# `Constant` and `PrepContext` contexts are translated by simply unwrapping them;
# the tag type `T` and batch size `B` are not used by these two methods.
_translate(::Type{T}, ::Val{B}, c::Constant) where {T,B} = unwrap(c)
_translate(::Type{T}, ::Val{B}, c::PrepContext) where {T,B} = unwrap(c)

function _translate(::Type{T}, ::Val{B}, c::Cache) where {T,B}
c0 = unwrap(c)
Expand Down
11 changes: 11 additions & 0 deletions DifferentiationInterface/src/first_order/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,14 @@ function shuffled_gradient(
) where {F,C}
return gradient(f, backend, x, rewrap(unannotated_contexts...)...)
end

# Method of `shuffled_gradient` that takes an existing `GradientPrep`.
# The point `x` comes first, followed by `f`, `prep`, `backend` and `rewrap`;
# the trailing unannotated values are passed through `rewrap` and the result is
# splatted into the prepared `gradient` call.
function shuffled_gradient(
    x,
    f::F,
    prep::GradientPrep,
    backend::AbstractADType,
    rewrap::Rewrap{C},
    unannotated_contexts::Vararg{Any,C},
) where {F,C}
    annotated_contexts = rewrap(unannotated_contexts...)
    return gradient(f, prep, backend, x, annotated_contexts...)
end
6 changes: 6 additions & 0 deletions DifferentiationInterface/src/utils/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ unwrap(c::Cache) = c.data

Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data

# Context subtype whose single field `data` holds a preparation object
# (any subtype of `Prep`).
struct PrepContext{T<:Prep} <: Context
data::T
end

# Return the preparation object stored in the `PrepContext`.
unwrap(c::PrepContext) = c.data

struct Rewrap{C,T}
context_makers::T
function Rewrap(contexts::Vararg{Context,C}) where {C}
Expand Down
Loading