Skip to content

Commit

Permalink
Test ForwardDiff over ReverseDiff (#386)
Browse files Browse the repository at this point in the history
* Test ForwardDiff over ReverseDiff

* Add print

* Fix twoarg

* Add primal value computation
  • Loading branch information
gdalle authored Jul 28, 2024
1 parent 9b46c79 commit 1c35de5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ function DI.prepare_gradient(
end

function DI.value_and_gradient!(
_f,
f,
grad::AbstractArray,
::AutoReverseDiff,
x::AbstractArray,
extras::ReverseDiffGradientExtras,
)
result = MutableDiffResult(zero(eltype(x)), (grad,))
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad,))
result = gradient!(result, extras.tape, x)
return DiffResults.value(result), DiffResults.derivative(result)
end
Expand Down Expand Up @@ -169,26 +170,26 @@ function DI.hessian(
end

function DI.value_gradient_and_hessian!(
_f,
f,
grad,
hess::AbstractMatrix,
::AutoReverseDiff,
x::AbstractArray,
extras::ReverseDiffHessianExtras,
)
result = MutableDiffResult(one(eltype(x)), (grad, hess))
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad, hess))
result = hessian!(result, extras.tape, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
end

function DI.value_gradient_and_hessian(
_f, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras
f, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras
)
result = MutableDiffResult(
one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
result = hessian!(result, extras.tape, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
Expand Down
3 changes: 3 additions & 0 deletions DifferentiationInterface/test/Back/ReverseDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Pkg.add("ReverseDiff")
using DifferentiationInterface, DifferentiationInterfaceTest
using DifferentiationInterface: AutoReverseFromPrimitive
using ReverseDiff: ReverseDiff
using StaticArrays: StaticArrays
using Test

dense_backends = [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)]
Expand All @@ -17,3 +18,5 @@ for backend in vcat(dense_backends, fromprimitive_backends)
end

test_differentiation(vcat(dense_backends, fromprimitive_backends); logging=LOGGING);

test_differentiation(AutoReverseDiff(), static_scenarios(); logging=LOGGING);
1 change: 1 addition & 0 deletions DifferentiationInterface/test/Back/SecondOrder/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ onearg_backends = [
]

twoarg_backends = [
SecondOrder(AutoForwardDiff(), AutoReverseDiff()),
SecondOrder(
AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Forward, constant_function=true)
),
Expand Down

0 comments on commit 1c35de5

Please sign in to comment.