Contexts for ReverseDiff (#505)

gdalle authored Sep 25, 2024
1 parent 959f634 commit 9ec4e7e
Showing 6 changed files with 332 additions and 71 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.0"
version = "0.6.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
25 changes: 22 additions & 3 deletions DifferentiationInterface/docs/src/explanation/backends.md
@@ -62,6 +62,25 @@ In practice, many AD backends have custom implementations for high-level operators
| `AutoTracker` | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| `AutoZygote` | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | 🔀 | ❌ |

Moreover, each context type is supported by a specific subset of backends:

| | [`Constant`](@ref) |
| -------------------------- | ------------------ |
| `AutoChainRules` ||
| `AutoDiffractor` ||
| `AutoEnzyme` (forward) ||
| `AutoEnzyme` (reverse) ||
| `AutoFastDifferentiation` ||
| `AutoFiniteDiff` ||
| `AutoFiniteDifferences` ||
| `AutoForwardDiff` ||
| `AutoMooncake` ||
| `AutoPolyesterForwardDiff` ||
| `AutoReverseDiff` ||
| `AutoSymbolics` ||
| `AutoTracker` ||
| `AutoZygote` ||
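
For illustration, here is a minimal sketch of how a [`Constant`](@ref) context is passed to an operator, assuming a backend that supports it (the function and values are made up):

```julia
using DifferentiationInterface
import ForwardDiff

# the active input x comes first, context arguments follow
f(x, scale) = scale * sum(abs2, x)

# Constant(...) marks `scale` as data that is not differentiated
gradient(f, AutoForwardDiff(), [1.0, 2.0], Constant(3.0))  # == [6.0, 12.0]
```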

## Second order

For second-order operators like [`second_derivative`](@ref), [`hessian`](@ref) and [`hvp`](@ref), there are two main options.
@@ -81,9 +100,9 @@ In general, using a forward outer backend over a reverse inner backend will yield the best performance.
## Backend switch

The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
It takes a function `f` and specifies that `f` should be differentiated with the backend of your choice, instead of whatever other backend the code is trying to use.
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, backend1)` with `backend2`, then `backend1` steps in and `backend2` does nothing.
At the moment, `DifferentiateWith` only works when `backend2` supports [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl).
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend.
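
A minimal sketch of the mechanism (the backend pairing here is illustrative):

```julia
using DifferentiationInterface
import ForwardDiff, Zygote

f(x) = sum(abs2, x)

# Zygote is the true backend, but inside `dw` ForwardDiff takes over
dw = DifferentiateWith(f, AutoForwardDiff())
gradient(dw, AutoZygote(), [1.0, 2.0])  # f itself is differentiated by ForwardDiff
```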

## Implementations

@@ -1,9 +1,20 @@
module DifferentiationInterfaceReverseDiffExt

using ADTypes: AutoReverseDiff
using Base: Fix2
import DifferentiationInterface as DI
using DifferentiationInterface:
DerivativePrep, GradientPrep, HessianPrep, JacobianPrep, NoPullbackPrep
Context,
DerivativePrep,
GradientPrep,
HessianPrep,
JacobianPrep,
NoGradientPrep,
NoHessianPrep,
NoJacobianPrep,
NoPullbackPrep,
unwrap,
with_contexts
using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
using LinearAlgebra: dot, mul!
using ReverseDiff:
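
For orientation, `with_contexts` (imported above) bundles `f` with its context arguments so that the wrapped callable takes only the active input. Conceptually it behaves like the following sketch (`FixTail` is a made-up name, not the actual implementation):

```julia
# sketch: a callable that appends pre-stored context values as trailing arguments
struct FixTail{F,A<:Tuple}
    f::F
    tail::A
end
(ft::FixTail)(x) = ft.f(x, ft.tail...)

# with_contexts(f, c1, c2, ...) then roughly amounts to
# FixTail(f, map(unwrap, (c1, c2, ...)))
```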
@@ -1,48 +1,74 @@
## Pullback

DI.prepare_pullback(f, ::AutoReverseDiff, x, ty::NTuple) = NoPullbackPrep()
function DI.prepare_pullback(
f, ::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{Context,C}
) where {C}
return NoPullbackPrep()
end
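
The `contexts::Vararg{Context,C}` signature is the standard Julia idiom for accepting any number of context arguments while still specializing each method on their count; a small standalone illustration:

```julia
# a method parameterized by the number of trailing arguments
count_tail(x, rest::Vararg{Any,N}) where {N} = N

count_tail(1.0)          # 0
count_tail(1.0, :a, :b)  # 2
```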

function DI.value_and_pullback(
f, ::NoPullbackPrep, ::AutoReverseDiff, x::AbstractArray, ty::NTuple
)
y = f(x)
f,
::NoPullbackPrep,
::AutoReverseDiff,
x::AbstractArray,
ty::NTuple,
contexts::Vararg{Context,C},
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x)
dotclosure(z, dy) = dot(fc(z), dy)
tx = map(ty) do dy
if y isa Number
dy .* gradient(f, x)
dy .* gradient(fc, x)
elseif y isa AbstractArray
gradient(z -> dot(f(z), dy), x)
gradient(Fix2(dotclosure, dy), x)
end
end
return y, tx
end

function DI.value_and_pullback!(
f, ::NoPullbackPrep, tx::NTuple, ::AutoReverseDiff, x::AbstractArray, ty::NTuple
)
y = f(x)
f,
::NoPullbackPrep,
tx::NTuple,
::AutoReverseDiff,
x::AbstractArray,
ty::NTuple,
contexts::Vararg{Context,C},
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x)
dotclosure(z, dy) = dot(fc(z), dy)
for b in eachindex(tx, ty)
dx, dy = tx[b], ty[b]
if y isa Number
dx = gradient!(dx, f, x)
dx = gradient!(dx, fc, x)
dx .*= dy
elseif y isa AbstractArray
gradient!(dx, z -> dot(f(z), dy), x)
gradient!(dx, Fix2(dotclosure, dy), x)
end
end
return y, tx
end

function DI.value_and_pullback(
f, ::NoPullbackPrep, backend::AutoReverseDiff, x::Number, ty::NTuple
)
f,
::NoPullbackPrep,
backend::AutoReverseDiff,
x::Number,
ty::NTuple,
contexts::Vararg{Context,C},
) where {C}
x_array = [x]
f_array = f ∘ only
y, tx_array = DI.value_and_pullback(f_array, backend, x_array, ty)
f_array(x_array, args...) = f(only(x_array), args...)
y, tx_array = DI.value_and_pullback(f_array, backend, x_array, ty, contexts...)
return y, only.(tx_array)
end
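
`Base.Fix2` partially applies the second argument, turning the anonymous closure `z -> dot(f(z), dy)` from the old code into a concrete callable type; a small illustration:

```julia
using Base: Fix2
using LinearAlgebra: dot

dotclosure(z, dy) = dot(z, dy)
g = Fix2(dotclosure, [1.0, 2.0])  # behaves like z -> dotclosure(z, [1.0, 2.0])
g([3.0, 4.0])                     # 3*1 + 4*2 == 11.0
```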

## Gradient

### Without contexts

struct ReverseDiffGradientPrep{T} <: GradientPrep
tape::T
end
@@ -56,7 +82,7 @@ function DI.prepare_gradient(f, ::AutoReverseDiff{Compile}, x) where {Compile}
end

function DI.value_and_gradient!(
f, grad::AbstractArray, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x
f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad,))
@@ -71,18 +97,55 @@ function DI.value_and_gradient(
return DI.value_and_gradient!(f, grad, prep, backend, x)
end

function DI.gradient!(
_f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x::AbstractArray
)
function DI.gradient!(_f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x)
return gradient!(grad, prep.tape, x)
end

function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x)
return gradient!(prep.tape, x)
end
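
For reference, the tape stored in `ReverseDiffGradientPrep` follows ReverseDiff's usual record-then-replay workflow; a sketch of that pattern in plain ReverseDiff:

```julia
import ReverseDiff

f(x) = sum(abs2, x)
x = [1.0, 2.0]

tape = ReverseDiff.GradientTape(f, x)  # record the computation once
ctape = ReverseDiff.compile(tape)      # optionally compile the tape for speed
grad = similar(x)
ReverseDiff.gradient!(grad, ctape, x)  # replay on (possibly new) inputs
```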

### With contexts

function DI.prepare_gradient(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C}
return NoGradientPrep()
end

function DI.value_and_gradient!(
f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad,))
result = gradient!(result, fc, x)
return DiffResults.value(result), DiffResults.derivative(result)
end

function DI.value_and_gradient(
f, prep::NoGradientPrep, backend::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
grad = similar(x)
return DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
end

function DI.gradient!(
f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
return gradient!(grad, fc, x)
end

function DI.gradient(
f, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
return gradient(fc, x)
end
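
With contexts, preparation returns `NoGradientPrep` instead of a recorded tape, presumably because a tape would bake the context values into the trace. A hypothetical call on the new code path:

```julia
using DifferentiationInterface
import ReverseDiff

f(x, c) = c * sum(abs2, x)
gradient(f, AutoReverseDiff(), [1.0, 2.0], Constant(3.0))  # == [6.0, 12.0]
```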

## Jacobian

### Without contexts

struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep
tape::T
end
@@ -116,8 +179,47 @@ function DI.jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x)
return jacobian!(prep.tape, x)
end

### With contexts

function DI.prepare_jacobian(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C}
return NoJacobianPrep()
end

function DI.value_and_jacobian!(
f, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x)
result = MutableDiffResult(y, (jac,))
result = jacobian!(result, fc, x)
return DiffResults.value(result), DiffResults.derivative(result)
end

function DI.value_and_jacobian(
f, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
return fc(x), jacobian(fc, x)
end

function DI.jacobian!(
f, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
return jacobian!(jac, fc, x)
end

function DI.jacobian(
f, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
return jacobian(fc, x)
end

## Hessian

### Without contexts

struct ReverseDiffHessianPrep{T} <: HessianPrep
tape::T
end
@@ -152,11 +254,54 @@ end
function DI.value_gradient_and_hessian(
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
)
result = MutableDiffResult(
one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
result = hessian!(result, prep.tape, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
end
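
The `MutableDiffResult` idiom above is how DiffResults collects the value, gradient, and Hessian from a single call; a standalone sketch with plain ReverseDiff:

```julia
import ReverseDiff
using ReverseDiff.DiffResults: MutableDiffResult, value, gradient, hessian

f(x) = sum(abs2, x)
x = [1.0, 2.0]

result = MutableDiffResult(f(x), (similar(x), similar(x, 2, 2)))
result = ReverseDiff.hessian!(result, f, x)
value(result), gradient(result), hessian(result)  # (5.0, [2.0, 4.0], [2 0; 0 2])
```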

### With contexts

function DI.prepare_hessian(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C}
return NoHessianPrep()
end

function DI.hessian!(
f, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
return hessian!(hess, fc, x)
end

function DI.hessian(
f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
return hessian(fc, x)
end

function DI.value_gradient_and_hessian!(
f, grad, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad, hess))
result = hessian!(result, fc, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
end

function DI.value_gradient_and_hessian(
f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
result = hessian!(result, fc, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
end