Skip to content

Commit

Permalink
Clean up ReverseDiff type annotations (#498)
Browse files Browse the repository at this point in the history
* Remove unneeded evaluation for ReverseDiff

* Undo fix
  • Loading branch information
gdalle authored Sep 25, 2024
1 parent 93c0659 commit 987ca87
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ struct ReverseDiffGradientPrep{T} <: GradientPrep
tape::T
end

function DI.prepare_gradient(
f, ::AutoReverseDiff{Compile}, x::AbstractArray
) where {Compile}
function DI.prepare_gradient(f, ::AutoReverseDiff{Compile}, x) where {Compile}
tape = GradientTape(f, x)
if Compile
tape = compile(tape)
Expand All @@ -58,11 +56,7 @@ function DI.prepare_gradient(
end

function DI.value_and_gradient!(
f,
grad::AbstractArray,
prep::ReverseDiffGradientPrep,
::AutoReverseDiff,
x::AbstractArray,
f, grad::AbstractArray, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad,))
Expand All @@ -71,23 +65,19 @@ function DI.value_and_gradient!(
end

function DI.value_and_gradient(
f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff, x::AbstractArray
f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff, x
)
grad = similar(x)
return DI.value_and_gradient!(f, grad, prep, backend, x)
end

function DI.gradient!(
_f,
grad::AbstractArray,
prep::ReverseDiffGradientPrep,
::AutoReverseDiff,
x::AbstractArray,
_f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x::AbstractArray
)
return gradient!(grad, prep.tape, x)
end

function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x::AbstractArray)
function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x)
return gradient!(prep.tape, x)
end

Expand All @@ -97,9 +87,7 @@ struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep
tape::T
end

function DI.prepare_jacobian(
f, ::AutoReverseDiff{Compile}, x::AbstractArray
) where {Compile}
function DI.prepare_jacobian(f, ::AutoReverseDiff{Compile}, x) where {Compile}
tape = JacobianTape(f, x)
if Compile
tape = compile(tape)
Expand All @@ -108,37 +96,23 @@ function DI.prepare_jacobian(
end

function DI.value_and_jacobian!(
f,
jac::AbstractMatrix,
prep::ReverseDiffOneArgJacobianPrep,
::AutoReverseDiff,
x::AbstractArray,
f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x
)
y = f(x)
result = MutableDiffResult(y, (jac,))
result = jacobian!(result, prep.tape, x)
return DiffResults.value(result), DiffResults.derivative(result)
end

function DI.value_and_jacobian(
f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x::AbstractArray
)
function DI.value_and_jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x)
return f(x), jacobian!(prep.tape, x)
end

function DI.jacobian!(
_f,
jac::AbstractMatrix,
prep::ReverseDiffOneArgJacobianPrep,
::AutoReverseDiff,
x::AbstractArray,
)
function DI.jacobian!(_f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x)
return jacobian!(jac, prep.tape, x)
end

function DI.jacobian(
f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x::AbstractArray
)
function DI.jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x)
return jacobian!(prep.tape, x)
end

Expand All @@ -148,35 +122,24 @@ struct ReverseDiffHessianPrep{T} <: HessianPrep
tape::T
end

function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x::AbstractArray) where {Compile}
function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x) where {Compile}
tape = HessianTape(f, x)
if Compile
tape = compile(tape)
end
return ReverseDiffHessianPrep(tape)
end

function DI.hessian!(
_f,
hess::AbstractMatrix,
prep::ReverseDiffHessianPrep,
::AutoReverseDiff,
x::AbstractArray,
)
function DI.hessian!(_f, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x)
return hessian!(hess, prep.tape, x)
end

function DI.hessian(_f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x::AbstractArray)
function DI.hessian(_f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x)
return hessian!(prep.tape, x)
end

function DI.value_gradient_and_hessian!(
f,
grad,
hess::AbstractMatrix,
prep::ReverseDiffHessianPrep,
::AutoReverseDiff,
x::AbstractArray,
f, grad, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad, hess))
Expand All @@ -187,10 +150,11 @@ function DI.value_gradient_and_hessian!(
end

function DI.value_gradient_and_hessian(
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x::AbstractArray
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
result = MutableDiffResult(
one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
)
result = hessian!(result, prep.tape, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ struct ReverseDiffTwoArgJacobianPrep{T} <: JacobianPrep
tape::T
end

function DI.prepare_jacobian(
f!, y::AbstractArray, ::AutoReverseDiff{Compile}, x::AbstractArray
) where {Compile}
function DI.prepare_jacobian(f!, y, ::AutoReverseDiff{Compile}, x) where {Compile}
tape = JacobianTape(f!, y, x)
if Compile
tape = compile(tape)
Expand Down

0 comments on commit 987ca87

Please sign in to comment.