Skip to content

Commit

Permalink
Add primal value computation
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jul 28, 2024
1 parent e5e9faf commit 7f5528d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ function DI.prepare_gradient(
end

function DI.value_and_gradient!(
_f,
f,
grad::AbstractArray,
::AutoReverseDiff,
x::AbstractArray,
extras::ReverseDiffGradientExtras,
)
result = MutableDiffResult(zero(eltype(x)), (grad,))
println(result)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad,))
result = gradient!(result, extras.tape, x)
return DiffResults.value(result), DiffResults.derivative(result)
end
Expand Down Expand Up @@ -170,26 +170,26 @@ function DI.hessian(
end

function DI.value_gradient_and_hessian!(
_f,
f,
grad,
hess::AbstractMatrix,
::AutoReverseDiff,
x::AbstractArray,
extras::ReverseDiffHessianExtras,
)
result = MutableDiffResult(one(eltype(x)), (grad, hess))
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad, hess))
result = hessian!(result, extras.tape, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
end

function DI.value_gradient_and_hessian(
_f, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras
f, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras
)
result = MutableDiffResult(
one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
result = hessian!(result, extras.tape, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
Expand Down
3 changes: 3 additions & 0 deletions DifferentiationInterface/test/Back/ReverseDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Pkg.add("ReverseDiff")
using DifferentiationInterface, DifferentiationInterfaceTest
using DifferentiationInterface: AutoReverseFromPrimitive
using ReverseDiff: ReverseDiff
using StaticArrays: StaticArrays
using Test

dense_backends = [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)]
Expand All @@ -17,3 +18,5 @@ for backend in vcat(dense_backends, fromprimitive_backends)
end

test_differentiation(vcat(dense_backends, fromprimitive_backends); logging=LOGGING);

test_differentiation(AutoReverseDiff(), static_scenarios(); logging=LOGGING);

0 comments on commit 7f5528d

Please sign in to comment.