diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 24c68904d..094b98dee 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -53,14 +53,14 @@ function DI.prepare_gradient( end function DI.value_and_gradient!( - _f, + f, grad::AbstractArray, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffGradientExtras, ) - result = MutableDiffResult(zero(eltype(x)), (grad,)) - println(result) + y = f(x) # TODO: remove once ReverseDiff#251 is fixed + result = MutableDiffResult(y, (grad,)) result = gradient!(result, extras.tape, x) return DiffResults.value(result), DiffResults.derivative(result) end @@ -170,14 +170,15 @@ function DI.hessian( end function DI.value_gradient_and_hessian!( - _f, + f, grad, hess::AbstractMatrix, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras, ) - result = MutableDiffResult(one(eltype(x)), (grad, hess)) + y = f(x) # TODO: remove once ReverseDiff#251 is fixed + result = MutableDiffResult(y, (grad, hess)) result = hessian!(result, extras.tape, x) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) @@ -185,11 +186,10 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - _f, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras + f, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras ) - result = MutableDiffResult( - one(eltype(x)), (similar(x), similar(x, length(x), length(x))) - ) + y = f(x) # TODO: remove once ReverseDiff#251 is fixed + result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x)))) result = hessian!(result, extras.tape, x) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index 954df9b7a..c5be00d7e 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -4,6 +4,7 @@ Pkg.add("ReverseDiff") using DifferentiationInterface, DifferentiationInterfaceTest using DifferentiationInterface: AutoReverseFromPrimitive using ReverseDiff: ReverseDiff +using StaticArrays: StaticArrays using Test dense_backends = [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] @@ -17,3 +18,5 @@ for backend in vcat(dense_backends, fromprimitive_backends) end test_differentiation(vcat(dense_backends, fromprimitive_backends); logging=LOGGING); + +test_differentiation(AutoReverseDiff(), static_scenarios(); logging=LOGGING);