From 0f24d223bfeee4545f898530072b98d473e9a9da Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 17 Oct 2024 09:49:54 +0200 Subject: [PATCH 1/2] Cache support with ForwardDiff --- .../docs/src/explanation/advanced.md | 2 +- .../docs/src/tutorials/advanced.md | 2 + .../DifferentiationInterfaceForwardDiffExt.jl | 2 + .../onearg.jl | 146 ++++++++++-------- .../twoarg.jl | 50 +++--- .../utils.jl | 14 ++ DifferentiationInterface/src/utils/context.jl | 31 ++-- .../test/Back/ForwardDiff/test.jl | 8 + .../test/Misc/Internals/context.jl | 5 +- 9 files changed, 159 insertions(+), 101 deletions(-) diff --git a/DifferentiationInterface/docs/src/explanation/advanced.md b/DifferentiationInterface/docs/src/explanation/advanced.md index 76c5ba012..686558032 100644 --- a/DifferentiationInterface/docs/src/explanation/advanced.md +++ b/DifferentiationInterface/docs/src/explanation/advanced.md @@ -19,7 +19,7 @@ Right now, there are two kinds of context: [`Constant`](@ref) and [`Cache`](@ref This feature is still experimental and will not be supported by all backends. At the moment: - `Constant` is supported by all backends except symbolic ones - - `Cache` is only supported by finite difference backends + - `Cache` is only supported by finite difference backends and [`AutoForwardDiff`](@ref), but it is not yet optimized Semantically, both of these calls compute the partial gradient of `f(x, c)` with respect to `x`, but they consider `c` differently: diff --git a/DifferentiationInterface/docs/src/tutorials/advanced.md b/DifferentiationInterface/docs/src/tutorials/advanced.md index 036973708..ceb4ee26d 100644 --- a/DifferentiationInterface/docs/src/tutorials/advanced.md +++ b/DifferentiationInterface/docs/src/tutorials/advanced.md @@ -49,6 +49,8 @@ prep_other_constant = prepare_gradient(f_multiarg, backend, x, Constant(-1)) gradient(f_multiarg, prep_other_constant, backend, x, Constant(10)) ``` +For additional arguments which act as mutated buffers, the [`Cache`](@ref) wrapper is the appropriate choice instead of [`Constant`](@ref). + ## Sparsity Sparse AD is very useful when Jacobian or Hessian matrices have a lot of zeros. diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index 042a2c5b0..e24b4eec8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -5,6 +5,8 @@ using Base: Fix1, Fix2 import DifferentiationInterface as DI using DifferentiationInterface: BatchSizeSettings, + Cache, + Constant, Context, DerivativePrep, DifferentiateWith, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 3f317cf59..6c76e6587 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -6,19 +6,26 @@ function DI.value_and_pushforward( f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{Context,C} ) where {F,B,C} T = tag_type(f, backend, x) - xdual_tmp = make_dual(T, x, tx) - ydual = f(xdual_tmp, map(unwrap, contexts)...) + xdual = make_dual(T, x, tx) + contexts_dual = translate(T, Val(B), contexts...) + ydual = f(xdual, contexts_dual...) y = myvalue(T, ydual) ty = mypartials(T, Val(B), ydual) return y, ty end function DI.value_and_pushforward!( - f::F, ty::NTuple, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{Context,C} -) where {F,C} + f::F, + ty::NTuple{B}, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{Context,C}, +) where {F,B,C} T = tag_type(f, backend, x) - xdual_tmp = make_dual(T, x, tx) - ydual = f(xdual_tmp, map(unwrap, contexts)...) + xdual = make_dual(T, x, tx) + contexts_dual = translate(T, Val(B), contexts...) + ydual = f(xdual, contexts_dual...) y = myvalue(T, ydual) mypartials!(T, ty, ydual) return y, ty @@ -28,18 +35,25 @@ function DI.pushforward( f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{Context,C} ) where {F,B,C} T = tag_type(f, backend, x) - xdual_tmp = make_dual(T, x, tx) - ydual = f(xdual_tmp, map(unwrap, contexts)...) + xdual = make_dual(T, x, tx) + contexts_dual = translate(T, Val(B), contexts...) + ydual = f(xdual, contexts_dual...) ty = mypartials(T, Val(B), ydual) return ty end function DI.pushforward!( - f::F, ty::NTuple, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{Context,C} -) where {F,C} + f::F, + ty::NTuple{B}, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{Context,C}, +) where {F,B,C} T = tag_type(f, backend, x) - xdual_tmp = make_dual(T, x, tx) - ydual = f(xdual_tmp, map(unwrap, contexts)...) + xdual = make_dual(T, x, tx) + contexts_dual = translate(T, Val(B), contexts...) + ydual = f(xdual, contexts_dual...) mypartials!(T, ty, ydual) return ty end @@ -62,11 +76,12 @@ function compute_ydual_onearg( f::F, prep::ForwardDiffOneArgPushforwardPrep{T}, x::Number, - tx::NTuple, + tx::NTuple{B}, contexts::Vararg{Context,C}, -) where {F,T,C} - xdual_tmp = make_dual(T, x, tx) - ydual = f(xdual_tmp, map(unwrap, contexts)...) +) where {F,T,B,C} + xdual = make_dual(T, x, tx) + contexts_dual = translate(T, Val(B), contexts...) + ydual = f(xdual, contexts_dual...) return ydual end @@ -74,12 +89,13 @@ function compute_ydual_onearg( f::F, prep::ForwardDiffOneArgPushforwardPrep{T}, x, - tx::NTuple, + tx::NTuple{B}, contexts::Vararg{Context,C}, -) where {F,T,C} +) where {F,T,B,C} (; xdual_tmp) = prep make_dual!(T, xdual_tmp, x, tx) - ydual = f(xdual_tmp, map(unwrap, contexts)...) + contexts_dual = translate(T, Val(B), contexts...) + ydual = f(xdual_tmp, contexts_dual...) return ydual end @@ -146,7 +162,7 @@ struct ForwardDiffOneArgDerivativePrep{E} <: DerivativePrep end function DI.prepare_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...) return ForwardDiffOneArgDerivativePrep(pushforward_prep) @@ -157,7 +173,7 @@ function DI.value_and_derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} y, ty = DI.value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -171,7 +187,7 @@ function DI.value_and_derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} y, _ = DI.value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -184,7 +200,7 @@ function DI.derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} return only( DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) @@ -197,7 +213,7 @@ function DI.derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der @@ -208,7 +224,7 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_gradient!( - f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -224,7 +240,7 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -238,7 +254,7 @@ function DI.value_and_gradient( end function DI.gradient!( - f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -250,7 +266,7 @@ function DI.gradient!( end function DI.gradient( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -268,7 +284,7 @@ struct ForwardDiffGradientPrep{C} <: GradientPrep end function DI.prepare_gradient( - f::F, backend::AutoForwardDiff, x::AbstractArray, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff, x::AbstractArray, contexts::Vararg{Constant,C} ) where {F,C} fc = with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -283,7 +299,7 @@ function DI.value_and_gradient!( prep::ForwardDiffGradientPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc = with_contexts(f, contexts...) result = DiffResult(zero(eltype(x)), (grad,)) @@ -294,7 +310,7 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f::F, prep::ForwardDiffGradientPrep, ::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, prep::ForwardDiffGradientPrep, ::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} fc = with_contexts(f, contexts...) result = GradientResult(x) @@ -308,14 +324,14 @@ function DI.gradient!( prep::ForwardDiffGradientPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc = with_contexts(f, contexts...) return gradient!(grad, fc, x, prep.config) end function DI.gradient( - f::F, prep::ForwardDiffGradientPrep, ::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, prep::ForwardDiffGradientPrep, ::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} fc = with_contexts(f, contexts...) return gradient(fc, x, prep.config) @@ -326,7 +342,7 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_jacobian!( - f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -343,7 +359,7 @@ function DI.value_and_jacobian!( end function DI.value_and_jacobian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -355,7 +371,7 @@ function DI.value_and_jacobian( end function DI.jacobian!( - f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -367,7 +383,7 @@ function DI.jacobian!( end function DI.jacobian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -385,7 +401,7 @@ struct ForwardDiffOneArgJacobianPrep{C} <: JacobianPrep end function DI.prepare_jacobian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} fc = with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -400,7 +416,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffOneArgJacobianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc = with_contexts(f, contexts...) y = fc(x) @@ -416,7 +432,7 @@ function DI.value_and_jacobian( prep::ForwardDiffOneArgJacobianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc = with_contexts(f, contexts...) return fc(x), jacobian(fc, x, prep.config) @@ -428,7 +444,7 @@ function DI.jacobian!( prep::ForwardDiffOneArgJacobianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc = with_contexts(f, contexts...) return jacobian!(jac, fc, x, prep.config) @@ -439,7 +455,7 @@ function DI.jacobian( prep::ForwardDiffOneArgJacobianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc = with_contexts(f, contexts...) return jacobian(fc, x, prep.config) @@ -448,13 +464,17 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} return NoSecondDerivativePrep() end function DI.second_derivative( - f::F, ::NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, + ::NoSecondDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{Constant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -469,7 +489,7 @@ function DI.second_derivative!( ::NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -479,7 +499,11 @@ function DI.second_derivative!( end function DI.value_derivative_and_second_derivative( - f::F, ::NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, + ::NoSecondDerivativePrep, + backend::AutoForwardDiff, + x, + contexts::Vararg{Constant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -498,7 +522,7 @@ function DI.value_derivative_and_second_derivative!( ::NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -513,7 +537,7 @@ end ## HVP function DI.prepare_hvp( - f::F, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{Constant,C} ) where {F,C} return DI.prepare_hvp(f, SecondOrder(backend, backend), x, tx, contexts...) end @@ -524,7 +548,7 @@ function DI.hvp( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} return DI.hvp(f, prep, SecondOrder(backend, backend), x, tx, contexts...) end @@ -536,7 +560,7 @@ function DI.hvp!( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} return DI.hvp!(f, tg, prep, SecondOrder(backend, backend), x, tx, contexts...) end @@ -547,7 +571,7 @@ function DI.gradient_and_hvp( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} return DI.gradient_and_hvp(f, prep, SecondOrder(backend, backend), x, tx, contexts...) end @@ -560,7 +584,7 @@ function DI.gradient_and_hvp!( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} return DI.gradient_and_hvp!( f, grad, tg, prep, SecondOrder(backend, backend), x, tx, contexts... @@ -572,7 +596,7 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.hessian!( - f::F, hess, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, hess, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -584,7 +608,7 @@ function DI.hessian!( end function DI.hessian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -596,7 +620,7 @@ function DI.hessian( end function DI.value_gradient_and_hessian!( - f::F, grad, hess, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, grad, hess, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -613,7 +637,7 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = with_contexts(f, contexts...) @@ -634,7 +658,7 @@ struct ForwardDiffHessianPrep{C1,C2} <: HessianPrep end function DI.prepare_hessian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} fc = with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -651,14 +675,14 @@ function DI.hessian!( prep::ForwardDiffHessianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc = with_contexts(f, contexts...) return hessian!(hess, fc, x, prep.array_config) end function DI.hessian( - f::F, prep::ForwardDiffHessianPrep, ::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, prep::ForwardDiffHessianPrep, ::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} fc = with_contexts(f, contexts...) return hessian(fc, x, prep.array_config) @@ -671,7 +695,7 @@ function DI.value_gradient_and_hessian!( prep::ForwardDiffHessianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc = with_contexts(f, contexts...) result = DiffResult(one(eltype(x)), (grad, hess)) @@ -683,7 +707,7 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f::F, prep::ForwardDiffHessianPrep, ::AutoForwardDiff, x, contexts::Vararg{Context,C} + f::F, prep::ForwardDiffHessianPrep, ::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} fc = with_contexts(f, contexts...) result = HessianResult(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 26ad4d364..ce74dd283 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -21,12 +21,13 @@ function compute_ydual_twoarg( y, prep::ForwardDiffTwoArgPushforwardPrep{T}, x::Number, - tx::NTuple, + tx::NTuple{B}, contexts::Vararg{Context,C}, -) where {F,T,C} +) where {F,T,B,C} (; ydual_tmp) = prep xdual_tmp = make_dual(T, x, tx) - f!(ydual_tmp, xdual_tmp, map(unwrap, contexts)...) + contexts_dual = translate(T, Val(B), contexts...) + f!(ydual_tmp, xdual_tmp, contexts_dual...) return ydual_tmp end @@ -35,12 +36,13 @@ function compute_ydual_twoarg( y, prep::ForwardDiffTwoArgPushforwardPrep{T}, x, - tx::NTuple, + tx::NTuple{B}, contexts::Vararg{Context,C}, -) where {F,T,C} +) where {F,T,B,C} (; xdual_tmp, ydual_tmp) = prep make_dual!(T, xdual_tmp, x, tx) - f!(ydual_tmp, xdual_tmp, map(unwrap, contexts)...) + contexts_dual = translate(T, Val(B), contexts...) + f!(ydual_tmp, xdual_tmp, contexts_dual...) return ydual_tmp end @@ -109,7 +111,7 @@ end ### Unprepared, only when tag is not specified function DI.value_and_derivative( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if T === Nothing fc! = with_contexts(f!, contexts...) @@ -123,7 +125,7 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if T === Nothing fc! = with_contexts(f!, contexts...) @@ -137,7 +139,7 @@ function DI.value_and_derivative!( end function DI.derivative( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if T === Nothing fc! = with_contexts(f!, contexts...) @@ -149,7 +151,7 @@ function DI.derivative( end function DI.derivative!( - f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if T === Nothing fc! = with_contexts(f!, contexts...) @@ -167,7 +169,7 @@ struct ForwardDiffTwoArgDerivativePrep{C} <: DerivativePrep end function DI.prepare_derivative( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{Context,C} + f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} fc! = with_contexts(f!, contexts...) tag = get_tag(fc!, backend, x) @@ -181,7 +183,7 @@ function DI.value_and_derivative( prep::ForwardDiffTwoArgDerivativePrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc! = with_contexts(f!, contexts...) result = MutableDiffResult(y, (similar(y),)) @@ -196,7 +198,7 @@ function DI.value_and_derivative!( prep::ForwardDiffTwoArgDerivativePrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc! = with_contexts(f!, contexts...) result = MutableDiffResult(y, (der,)) @@ -210,7 +212,7 @@ function DI.derivative( prep::ForwardDiffTwoArgDerivativePrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc! = with_contexts(f!, contexts...) return derivative(fc!, y, x, prep.config) @@ -223,7 +225,7 @@ function DI.derivative!( prep::ForwardDiffTwoArgDerivativePrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc! = with_contexts(f!, contexts...) return derivative!(der, fc!, y, x, prep.config) @@ -234,7 +236,7 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_jacobian( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = with_contexts(f!, contexts...) @@ -249,7 +251,7 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = with_contexts(f!, contexts...) @@ -263,7 +265,7 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = with_contexts(f!, contexts...) @@ -275,7 +277,7 @@ function DI.jacobian( end function DI.jacobian!( - f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C} + f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Constant,C} ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = with_contexts(f!, contexts...) @@ -293,7 +295,7 @@ struct ForwardDiffTwoArgJacobianPrep{C} <: JacobianPrep end function DI.prepare_jacobian( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{Context,C} + f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{Constant,C} ) where {F,C} fc! = with_contexts(f!, contexts...) chunk = choose_chunk(backend, x) @@ -308,7 +310,7 @@ function DI.value_and_jacobian( prep::ForwardDiffTwoArgJacobianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc! = with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -324,7 +326,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffTwoArgJacobianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc! = with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) @@ -338,7 +340,7 @@ function DI.jacobian( prep::ForwardDiffTwoArgJacobianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc! = with_contexts(f!, contexts...) return jacobian(fc!, y, x, prep.config) @@ -351,7 +353,7 @@ function DI.jacobian!( prep::ForwardDiffTwoArgJacobianPrep, ::AutoForwardDiff, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Constant,C}, ) where {F,C} fc! = with_contexts(f!, contexts...) return jacobian!(jac, fc!, y, x, prep.config) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index a7ae576a9..a5e90757f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -83,3 +83,17 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} end return ty end + +_translate(::Type{T}, ::Val{B}, c::Constant) where {T,B} = unwrap(c) + +function _translate(::Type{T}, ::Val{B}, c::Cache) where {T,B} + c0 = unwrap(c) + return make_dual(T, c0, ntuple(_ -> similar(c0), Val(B))) # TODO: optimize +end + +function translate(::Type{T}, ::Val{B}, contexts::Vararg{Context,C}) where {T,B,C} + new_contexts = map(contexts) do c + _translate(T, Val(B), c) + end + return new_contexts +end diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 9815fd647..02fd7e1c5 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -51,13 +51,10 @@ struct Constant{T} <: Context data::T end +constant_maker(c) = Constant(c) unwrap(c::Constant) = c.data -function Base.convert(::Type{Constant{T}}, x::Constant) where {T} - return Constant(convert(T, x.data)) -end - -Base.convert(::Type{Constant{T}}, x) where {T} = Constant(convert(T, x)) +Base.:(==)(c1::Constant, c2::Constant) = c1.data == c2.data """ Cache @@ -70,25 +67,33 @@ struct Cache{T} <: Context data::T end +cache_maker(c) = Cache(c) unwrap(c::Cache) = c.data -function Base.convert(::Type{Cache{T}}, x::Cache) where {T} - return Cache(convert(T, x.data)) -end - -Base.convert(::Type{Cache{T}}, x) where {T} = Cache(convert(T, x)) +Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data struct Rewrap{C,T} + context_makers::T function Rewrap(contexts::Vararg{Context,C}) where {C} - T = typeof(contexts) - return new{C,T}() + context_makers = map(contexts) do c + if c isa Cache + cache_maker + elseif c isa Constant + constant_maker + else + nothing + end + end + return new{C,typeof(context_makers)}(context_makers) end end (::Rewrap{0})() = () function (r::Rewrap{C,T})(unannotated_contexts::Vararg{Any,C}) where {C,T} - return T(unannotated_contexts) + return map(r.context_makers, unannotated_contexts) do maker, c + maker(c) + end end with_contexts(f) = f diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index f84666f44..56dd666f9 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -24,6 +24,14 @@ test_differentiation( backends, default_scenarios(; include_constantified=true); logging=LOGGING ); +test_differentiation( + AutoForwardDiff(), + default_scenarios(; + include_normal=false, include_batchified=false, include_cachified=true + ); + logging=LOGGING, +); + test_differentiation( AutoForwardDiff(); correctness=false, type_stability=:prepared, logging=LOGGING ); diff --git a/DifferentiationInterface/test/Misc/Internals/context.jl b/DifferentiationInterface/test/Misc/Internals/context.jl index 4a8c3c19a..fa898b6fe 100644 --- a/DifferentiationInterface/test/Misc/Internals/context.jl +++ b/DifferentiationInterface/test/Misc/Internals/context.jl @@ -14,6 +14,7 @@ contexts = () r = @inferred Rewrap() @test r() == () -contexts = (Constant(1), Constant(2.0)) +contexts = (Constant(1.0), Cache([2.0])) r = @inferred Rewrap(contexts...) -@test r(3, 4) == (Constant(3), Constant(4.0)) +@test (@inferred r(3.0, [4.0])) == (Constant(3.0), Cache([4.0])) +@test (@inferred r(3, [4.0f0])) isa Tuple{Constant{Int},Cache{Vector{Float32}}} From cff6e759e8aadd29900a70b173f88395b00ce0be Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 17 Oct 2024 12:46:51 +0200 Subject: [PATCH 2/2] Fix rewrap --- DifferentiationInterface/Project.toml | 2 +- DifferentiationInterface/src/utils/context.jl | 12 +++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 2b2cafe01..5f6992b65 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.15" +version = "0.6.16" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 02fd7e1c5..3eb2ef879 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -52,6 +52,7 @@ struct Constant{T} <: Context end constant_maker(c) = Constant(c) +maker(::Constant) = constant_maker unwrap(c::Constant) = c.data Base.:(==)(c1::Constant, c2::Constant) = c1.data == c2.data @@ -68,6 +69,7 @@ struct Cache{T} <: Context end cache_maker(c) = Cache(c) +maker(::Cache) = cache_maker unwrap(c::Cache) = c.data Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data @@ -75,15 +77,7 @@ Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data struct Rewrap{C,T} context_makers::T function Rewrap(contexts::Vararg{Context,C}) where {C} - context_makers = map(contexts) do c - if c isa Cache - cache_maker - elseif c isa Constant - constant_maker - else - nothing - end - end + context_makers = map(maker, contexts) return new{C,typeof(context_makers)}(context_makers) end end