diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index b9a23e624..5ea09ca34 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -10,15 +10,19 @@ function DI.prepare_pushforward(f, ::AutoForwardDiff, x, dx) return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp) end +function compute_ydual_onearg( + f, x::Number, dx, extras::ForwardDiffOneArgPushforwardExtras{T} +) where {T} + xdual_tmp = make_dual(T, x, dx) + ydual = f(xdual_tmp) + return ydual +end + function compute_ydual_onearg( f, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T} ) where {T} (; xdual_tmp) = extras - xdual_tmp = if x isa Number - make_dual(T, x, dx) - else - make_dual!(T, xdual_tmp, x, dx) - end + make_dual!(T, xdual_tmp, x, dx) ydual = f(xdual_tmp) return ydual end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index c9c2bba7f..fea75abee 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -15,14 +15,19 @@ function DI.prepare_pushforward(f!, y, ::AutoForwardDiff, x, dx) end function compute_ydual_twoarg( - f!, y, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T} + ::Type{T}, f!, y, x::Number, dx, extras::ForwardDiffTwoArgPushforwardExtras{T} +) where {T} + (; ydual_tmp) = extras + xdual_tmp = make_dual(T, x, dx) + f!(ydual_tmp, xdual_tmp) + return ydual_tmp +end + +function compute_ydual_twoarg( + ::Type{T}, f!, y, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T} ) where {T} (; xdual_tmp, ydual_tmp) = extras - xdual_tmp = if x isa Number - make_dual(T, x, dx) - else - make_dual!(T, xdual_tmp, x, dx) - end + make_dual!(T, xdual_tmp, x, dx) f!(ydual_tmp, xdual_tmp) return ydual_tmp end @@ -30,7 +35,7 @@ end function DI.value_and_pushforward( f!, y, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T} ) where {T} - ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras) + ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras) myvalue!(T, y, ydual_tmp) dy = myderivative(T, ydual_tmp) return y, dy @@ -39,7 +44,7 @@ end function DI.pushforward( f!, y, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T} ) where {T} - ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras) + ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras) dy = myderivative(T, ydual_tmp) return dy end @@ -47,7 +52,7 @@ end function DI.value_and_pushforward!( f!, y, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T} ) where {T} - ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras) + ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras) myvalue!(T, y, ydual_tmp) myderivative!(T, dy, ydual_tmp) return y, dy @@ -56,7 +61,7 @@ end function DI.pushforward!( f!, y, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T} ) where {T} - ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras) + ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras) myderivative!(T, dy, ydual_tmp) return dy end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 838a644fc..9ffa0d923 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -7,16 +7,35 @@ tag_type(::F, x::AbstractArray) where {F} = Tag{F,eltype(x)} make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx) make_dual(::Type{T}, x::AbstractArray, dx) where {T} = Dual{T}.(x, dx) -make_dual!(::Type{T}, xdual, x::AbstractArray, dx) where {T} = xdual .= Dual{T}.(x, dx) +function make_dual!(::Type{T}, xdual, x::AbstractArray, dx) where {T} + for i in eachindex(xdual, x, dx) + xdual[i] = Dual{T}(x[i], dx[i]) + end + return nothing +end myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual) myvalue(::Type{T}, ydual::AbstractArray) where {T} = value.(T, ydual) -myvalue!(::Type{T}, y::AbstractArray, ydual) where {T} = y .= value.(T, ydual) +function myvalue!(::Type{T}, y::AbstractArray, ydual) where {T} + for i in eachindex(y, ydual) + y[i] = value(T, ydual[i]) + end + return nothing +end myderivative(::Type{T}, ydual::Number) where {T} = extract_derivative(T, ydual) myderivative(::Type{T}, ydual::AbstractArray) where {T} = extract_derivative(T, ydual) function myderivative!(::Type{T}, dy, ydual::AbstractArray) where {T} - return extract_derivative!(T, dy, ydual) + extract_derivative!(T, dy, ydual) + return nothing +end + +function myvalueandderivative!(::Type{T}, y, dy, ydual::AbstractArray) where {T} + for i in eachindex(y, dy, ydual) + y[i] = value(T, ydual[i]) + dy[i] = extract_derivative(T, ydual[i]) + end + return nothing end