Skip to content

Commit

Permalink
Less allocs in ForwardDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Apr 28, 2024
1 parent 1e30916 commit 1602931
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@ function DI.prepare_pushforward(f, ::AutoForwardDiff, x, dx)
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
end

function compute_ydual_onearg(
f, x::Number, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
) where {T}
xdual_tmp = make_dual(T, x, dx)
ydual = f(xdual_tmp)
return ydual
end

function compute_ydual_onearg(
f, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
) where {T}
(; xdual_tmp) = extras
xdual_tmp = if x isa Number
make_dual(T, x, dx)
else
make_dual!(T, xdual_tmp, x, dx)
end
make_dual!(T, xdual_tmp, x, dx)
ydual = f(xdual_tmp)
return ydual
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,27 @@ function DI.prepare_pushforward(f!, y, ::AutoForwardDiff, x, dx)
end

function compute_ydual_twoarg(
f!, y, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
::Type{T}, f!, y, x::Number, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
) where {T}
(; ydual_tmp) = extras
xdual_tmp = make_dual(T, x, dx)
f!(ydual_tmp, xdual_tmp)
return ydual_tmp
end

function compute_ydual_twoarg(
::Type{T}, f!, y, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
) where {T}
(; xdual_tmp, ydual_tmp) = extras
xdual_tmp = if x isa Number
make_dual(T, x, dx)
else
make_dual!(T, xdual_tmp, x, dx)
end
make_dual!(T, xdual_tmp, x, dx)
f!(ydual_tmp, xdual_tmp)
return ydual_tmp
end

function DI.value_and_pushforward(
f!, y, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
) where {T}
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras)
myvalue!(T, y, ydual_tmp)
dy = myderivative(T, ydual_tmp)
return y, dy
Expand All @@ -39,15 +44,15 @@ end
function DI.pushforward(
f!, y, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
) where {T}
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras)
dy = myderivative(T, ydual_tmp)
return dy
end

function DI.value_and_pushforward!(
f!, y, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
) where {T}
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras)
myvalue!(T, y, ydual_tmp)
myderivative!(T, dy, ydual_tmp)
return y, dy
Expand All @@ -56,7 +61,7 @@ end
function DI.pushforward!(
f!, y, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
) where {T}
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras)
myderivative!(T, dy, ydual_tmp)
return dy
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,35 @@ tag_type(::F, x::AbstractArray) where {F} = Tag{F,eltype(x)}
make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx)
make_dual(::Type{T}, x::AbstractArray, dx) where {T} = Dual{T}.(x, dx)

make_dual!(::Type{T}, xdual, x::AbstractArray, dx) where {T} = xdual .= Dual{T}.(x, dx)
function make_dual!(::Type{T}, xdual, x::AbstractArray, dx) where {T}
for i in eachindex(xdual, x, dx)
xdual[i] = Dual{T}(x[i], dx[i])
end
return nothing
end

myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual)
myvalue(::Type{T}, ydual::AbstractArray) where {T} = value.(T, ydual)

myvalue!(::Type{T}, y::AbstractArray, ydual) where {T} = y .= value.(T, ydual)
function myvalue!(::Type{T}, y::AbstractArray, ydual) where {T}
for i in eachindex(y, ydual)
y[i] = value(T, ydual[i])
end
return nothing
end

myderivative(::Type{T}, ydual::Number) where {T} = extract_derivative(T, ydual)
myderivative(::Type{T}, ydual::AbstractArray) where {T} = extract_derivative(T, ydual)

function myderivative!(::Type{T}, dy, ydual::AbstractArray) where {T}
return extract_derivative!(T, dy, ydual)
extract_derivative!(T, dy, ydual)
return nothing
end

function myvalueandderivative!(::Type{T}, y, dy, ydual::AbstractArray) where {T}
for i in eachindex(y, dy, ydual)
y[i] = value(T, ydual[i])
dy[i] = extract_derivative(T, ydual[i])
end
return nothing
end

0 comments on commit 1602931

Please sign in to comment.