Skip to content

Commit

Permalink
Hessian from HVPs, tested in each mod (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Mar 27, 2024
1 parent 38fd496 commit e2611fa
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module DifferentiationInterfaceZygoteExt
using ADTypes: AutoZygote
import DifferentiationInterface as DI
using DocStringExtensions
using Zygote: ZygoteRuleConfig, gradient, jacobian, pullback, withgradient, withjacobian
using Zygote:
ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian

DI.supports_mutation(::AutoZygote) = DI.MutationNotSupported()

Expand Down Expand Up @@ -36,12 +37,26 @@ end

## Jacobian

function DI.value_and_jacobian(f, ::AutoZygote, x::AbstractArray, extras::Nothing)
return f(x), only(jacobian(f, x))
function DI.value_and_jacobian(f, ::AutoZygote, x, extras::Nothing)
return f(x), only(jacobian(f, x)) # https://github.com/FluxML/Zygote.jl/issues/1506
end

function DI.jacobian(f, ::AutoZygote, x::AbstractArray, extras::Nothing)
function DI.jacobian(f, ::AutoZygote, x, extras::Nothing)
return only(jacobian(f, x))
end

function DI.value_and_jacobian!!(f, jac, backend::AutoZygote, x, extras::Nothing)
return DI.value_and_jacobian(f, backend, x, extras)
end

function DI.jacobian!!(f, jac, backend::AutoZygote, x, extras::Nothing)
return DI.jacobian(f, backend, x, extras)
end

## Hessian

function DI.hessian(f, ::AutoZygote, x, extras::Nothing)
return hessian(f, x)
end

end
2 changes: 1 addition & 1 deletion src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using ADTypes:
AbstractSymbolicDifferentiationMode
using DocStringExtensions
using FillArrays: OneElement
using LinearAlgebra: dot
using LinearAlgebra: Symmetric, dot

"""
AutoFastDifferentiation
Expand Down
9 changes: 5 additions & 4 deletions src/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ function hessian(f, backend::AbstractADType, x, extras=prepare_hessian(f, backen
end

function hessian(f, backend::SecondOrder, x, extras=prepare_hessian(f, backend, x))
# suboptimal for reverse-over-forward
gradient_closure(z) = gradient(f, inner(backend), z, inner(extras))
hess = jacobian(gradient_closure, outer(backend), x, outer(extras))
return hess
hess = stack(vec(CartesianIndices(x))) do j
hess_col_j = hvp(f, backend, x, basis(backend, x, j), extras)
vec(hess_col_j)
end
return Symmetric(hess)
end
2 changes: 1 addition & 1 deletion src/hvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function hvp(f, backend::SecondOrder, x, v, extras=prepare_hvp(f, backend, x))
return hvp_aux(f, backend, x, v, extras, hvp_mode(backend))
end

function hvp_aux(f, backend, x, v, extras, orwardOverReverse)
function hvp_aux(f, backend, x, v, extras, ::ForwardOverReverse)
# JVP of the gradient
gradient_closure(z) = gradient(f, inner(backend), z, inner(extras))
p = pushforward(gradient_closure, outer(backend), x, v, outer(extras))
Expand Down
9 changes: 9 additions & 0 deletions test/second_order.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
using Enzyme: Enzyme
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using ReverseDiff: ReverseDiff
using Tracker: Tracker
using Zygote: Zygote

second_order_backends = [AutoForwardDiff(), AutoReverseDiff()]

second_order_mixed_backends = [
# forward over forward
SecondOrder(AutoEnzyme(Enzyme.Forward), AutoForwardDiff()),
SecondOrder(AutoForwardDiff(), AutoEnzyme(Enzyme.Forward)),
# forward over reverse
SecondOrder(AutoForwardDiff(), AutoZygote()),
# reverse over forward
SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoForwardDiff()),
# reverse over reverse
SecondOrder(AutoReverseDiff(), AutoZygote()),
]

for backend in vcat(second_order_backends, second_order_mixed_backends)
Expand Down

0 comments on commit e2611fa

Please sign in to comment.