Add second derivative and simplify dispatch (#50)
gdalle authored Mar 16, 2024
1 parent 53fd898 commit 3324c90
Showing 46 changed files with 991 additions and 614 deletions.
4 changes: 1 addition & 3 deletions README.md
@@ -16,7 +16,7 @@ This package provides a backend-agnostic syntax to differentiate functions of th

## Compatibility

- We support some of the first order backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):
We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

| Backend | Object |
| :------------------------------------------------------------------------------ | :----------------------------------------------------------- |
@@ -29,8 +29,6 @@ We support some of the first order backends defined by [ADTypes.jl](https://gith
| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` |
| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |
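
Differentiating through one of these backends might look as follows (a minimal sketch based on the operator names in this package's docs; the API of this early version may differ in detail):

```julia
using ADTypes: AutoForwardDiff
using DifferentiationInterface

f(x) = sum(abs2, x)  # array-to-scalar function

# the backend object from the table selects which AD package does the work
y, grad = value_and_gradient(AutoForwardDiff(), f, [1.0, 2.0, 3.0])
# mathematically, y == 14.0 and grad == [2.0, 4.0, 6.0]
```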

- We also provide a second order backend `SecondOrder(reverse_backend, forward_backend)` for hessian computations.

## Example

Setup:
8 changes: 4 additions & 4 deletions benchmark/utils.jl
@@ -2,7 +2,7 @@ using ADTypes
using ADTypes: AbstractADType
using BenchmarkTools
using DifferentiationInterface
- using DifferentiationInterface: ForwardMode, ReverseMode, autodiff_mode
using DifferentiationInterface: ForwardMode, ReverseMode, mode

BenchmarkTools.DEFAULT_PARAMETERS.seconds = 1

@@ -12,7 +12,7 @@ pretty_backend(::AutoChainRules{<:ZygoteRuleConfig}) = "ChainRules{Zygote}"
pretty_backend(::AutoDiffractor) = "Diffractor (forward)"

function pretty_backend(backend::AutoEnzyme)
-     return autodiff_mode(backend) isa ForwardMode ? "Enzyme (forward)" : "Enzyme (reverse)"
return mode(backend) isa ForwardMode ? "Enzyme (forward)" : "Enzyme (reverse)"
end

pretty_backend(::AutoFiniteDiff) = "FiniteDiff"
@@ -45,7 +45,7 @@ function add_pushforward_benchmarks!(
dx = n == 1 ? randn() : randn(n)
dy = m == 1 ? 0.0 : zeros(m)

-     if !isa(autodiff_mode(backend), ForwardMode)
if !isa(mode(backend), ForwardMode)
return nothing
end

@@ -90,7 +90,7 @@ function add_pullback_benchmarks!(
dx = n == 1 ? 0.0 : zeros(n)
dy = m == 1 ? randn() : randn(m)

-     if !isa(autodiff_mode(backend), ReverseMode)
if !isa(mode(backend), ReverseMode)
return nothing
end

21 changes: 21 additions & 0 deletions docs/src/api.md
@@ -37,6 +37,20 @@ Modules = [DifferentiationInterface]
Pages = ["jacobian.jl"]
```

## Second order

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["second_order.jl"]
```

## Second derivative

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["second_derivative.jl"]
```

## Hessian

```@autodocs
@@ -58,6 +72,13 @@ Modules = [DifferentiationInterface]
Pages = ["pullback.jl"]
```

## Hessian-vector product (HVP)

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["hessian_vector_product.jl"]
```

## Preparation

```@autodocs
7 changes: 4 additions & 3 deletions docs/src/backends.md
@@ -24,7 +24,7 @@ AutoReverseDiff
AutoZygote
```

- ## [Mutation compatibility](@id mutcompat)
## [Mutation support](@id backend_support_mutation)

All backends are compatible with allocating functions `f(x) = y`. Only some are compatible with mutating functions `f!(y, x) = nothing`:

@@ -40,14 +40,15 @@ All backends are compatible with allocating functions `f(x) = y`. Only some are
| `AutoReverseDiff()` ||
| `AutoZygote()` ||
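
To illustrate the distinction, here is the same computation written in allocating and in mutating style; the `f!(y, x)` convention is the one described above, and the table tells you which backends accept it:

```julia
# allocating style: returns a fresh output array
f(x) = 2 .* x

# mutating style: writes into a preallocated output and returns nothing
function f!(y, x)
    y .= 2 .* x
    return nothing
end

x = [1.0, 2.0]
y = zeros(2)
f!(y, x)  # y is now [2.0, 4.0]
```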

- ## [Second order combinations](@id secondcombin)
## [Second order combinations](@id backend_combination)

For hessian computations, in theory we can combine any pair of backends into a [`SecondOrder`](@ref).
In practice, many combinations will fail.
Here are the ones we tested for you:

- | Reverse backend | Forward backend | Hessian tested |
| Inner backend | Outer backend | Hessian tested |
| :------------------ | :--------------------------- | -------------- |
| `AutoForwardDiff()` | `AutoForwardDiff()` ||
| `AutoZygote()` | `AutoForwardDiff()` ||
| `AutoReverseDiff()` | `AutoForwardDiff()` ||
| `AutoZygote()` | `AutoEnzyme(Enzyme.Forward)` ||
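
Constructing one of the tested combinations from the table could be sketched as follows; the `(inner, outer)` argument order is an assumption based on the new column names:

```julia
using ADTypes: AutoForwardDiff, AutoZygote
using DifferentiationInterface: SecondOrder

# the inner backend differentiates once; the outer backend
# differentiates the result of the inner one
backend = SecondOrder(AutoZygote(), AutoForwardDiff())
```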
27 changes: 0 additions & 27 deletions docs/src/developer.md
Expand Up @@ -156,30 +156,3 @@ flowchart LR
value_and_pullback!
end
```

- ### Second order, scalar-valued functions
-
- ```mermaid
- flowchart LR
- subgraph First order
- gradient!
- value_and_pushforward!
- end
- subgraph Hessian-vector product
- gradient_and_hessian_vector_product!
- gradient_and_hessian_vector_product --> gradient_and_hessian_vector_product!
- end
- gradient_and_hessian_vector_product! --> gradient!
- gradient_and_hessian_vector_product! --> value_and_pushforward!
- subgraph Hessian
- value_and_gradient_and_hessian!
- value_and_gradient_and_hessian --> value_and_gradient_and_hessian!
- hessian! --> value_and_gradient_and_hessian!
- hessian --> value_and_gradient_and_hessian
- end
- value_and_gradient_and_hessian! --> |n|gradient_and_hessian_vector_product!
- ```
74 changes: 38 additions & 36 deletions docs/src/getting_started.md
@@ -3,43 +3,58 @@
## [Operators](@id operators)

Depending on the type of input and output, differentiation operators can have various names.
- We choose the following terminology for the first-order operators we provide:
We choose the following terminology for the operators we provide:

- |                  | **scalar output** | **array output**  |
| First order | **scalar output** | **array output** |
| ---------------- | ----------------- | ----------------- |
| **scalar input** | `derivative` | `multiderivative` |
| **array input** | `gradient` | `jacobian` |

| Second order | **scalar output** | **array output** |
| ---------------- | ------------------- | ---------------- |
| **scalar input** | `second_derivative` | not implemented |
| **array input** | `hessian` | not implemented |

Most backends have custom implementations for all of these, which we reuse if possible.
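
The four input/output combinations in the tables correspond to function signatures like these:

```julia
f1(x::Number) = sin(x)               # scalar input, scalar output → `derivative`
f2(x::Number) = [x, x^2, x^3]        # scalar input, array output  → `multiderivative`
f3(x::AbstractArray) = sum(abs2, x)  # array input,  scalar output → `gradient`
f4(x::AbstractArray) = reverse(x)    # array input,  array output  → `jacobian`
```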

## Variants

Whenever it makes sense, four variants of the same operator are defined:

- | **Operator**      | **allocating**            | **mutating**               | **allocating with primal**          | **mutating with primal**             |
- | :---------------- | :------------------------ | :------------------------- | :---------------------------------- | :----------------------------------- |
- | Derivative        | [`derivative`](@ref)      | N/A                        | [`value_and_derivative`](@ref)      | N/A                                  |
- | Multiderivative   | [`multiderivative`](@ref) | [`multiderivative!`](@ref) | [`value_and_multiderivative`](@ref) | [`value_and_multiderivative!`](@ref) |
- | Gradient          | [`gradient`](@ref)        | [`gradient!`](@ref)        | [`value_and_gradient`](@ref)        | [`value_and_gradient!`](@ref)        |
- | Jacobian          | [`jacobian`](@ref)        | [`jacobian!`](@ref)        | [`value_and_jacobian`](@ref)        | [`value_and_jacobian!`](@ref)        |
- | Pushforward (JVP) | [`pushforward`](@ref)     | [`pushforward!`](@ref)     | [`value_and_pushforward`](@ref)     | [`value_and_pushforward!`](@ref)     |
- | Pullback (VJP)    | [`pullback`](@ref)        | [`pullback!`](@ref)        | [`value_and_pullback`](@ref)        | [`value_and_pullback!`](@ref)        |
| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** |
| :---------------- | :-------------------------- | :------------------------- | :----------------------------------------------- | :------------------------------------ |
| Derivative | [`derivative`](@ref) | N/A | [`value_and_derivative`](@ref) | N/A |
| Multiderivative | [`multiderivative`](@ref) | [`multiderivative!`](@ref) | [`value_and_multiderivative`](@ref) | [`value_and_multiderivative!`](@ref) |
| Gradient | [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) |
| Jacobian | [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) |
| Second derivative | [`second_derivative`](@ref) | N/A | [`value_derivative_and_second_derivative`](@ref) | N/A |
| Hessian | [`hessian`](@ref) | [`hessian!`](@ref) | [`value_gradient_and_hessian`](@ref) | [`value_gradient_and_hessian!`](@ref) |

Note that scalar outputs can't be mutated, which is why `derivative` and `second_derivative` do not have mutating variants.

- Note that scalar outputs can't be mutated, which is why `derivative` doesn't have mutating variants.
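
For the gradient, the variants in the table above would be used roughly as follows (a sketch; the argument order of the mutating variant is an assumption):

```julia
using ADTypes: AutoForwardDiff
using DifferentiationInterface

f(x) = sum(abs2, x)
backend = AutoForwardDiff()
x = [1.0, 2.0]

grad = gradient(backend, f, x)               # allocating
gradient!(similar(x), backend, f, x)         # mutating: fills the provided array
y, grad = value_and_gradient(backend, f, x)  # allocating, also returns f(x)
```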
For advanced users, lower-level operators are also exposed:

| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** |
| :--------------------------- | :------------------------------- | :-------------------------------- | :-------------------------------------------- | :--------------------------------------------- |
| Pushforward (JVP) | [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
| Pullback (VJP) | [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) |
| Hessian-vector product (HVP) | [`hessian_vector_product`](@ref) | [`hessian_vector_product!`](@ref) | [`gradient_and_hessian_vector_product`](@ref) | [`gradient_and_hessian_vector_product!`](@ref) |
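
A pushforward, for example, evaluates a Jacobian-vector product without materializing the Jacobian; the call shape `value_and_pushforward(backend, f, x, dx)` mirrors the primitives defined in this commit (with `extras` defaulting to `nothing`):

```julia
using ADTypes: AutoForwardDiff
using DifferentiationInterface

f(x) = [sum(abs2, x), prod(x)]
x = [1.0, 2.0]
dx = [1.0, 0.0]  # seed vector: direction of differentiation

# returns f(x) and J(x) * dx without ever building J(x)
y, dy = value_and_pushforward(AutoForwardDiff(), f, x, dx)
# mathematically, dy == [2.0, 2.0] for this x and dx
```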

## Preparation

In many cases, automatic differentiation can be accelerated if the function has been run at least once (e.g. to record a tape) and if some cache objects are provided.
This is a backend-specific procedure, but we expose a common syntax to achieve it.

- | **Operator**      | **preparation function**          |
- | :---------------- | :-------------------------------- |
- | Derivative        | [`prepare_derivative`](@ref)      |
- | Multiderivative   | [`prepare_multiderivative`](@ref) |
- | Gradient          | [`prepare_gradient`](@ref)        |
- | Jacobian          | [`prepare_jacobian`](@ref)        |
- | Pushforward (JVP) | [`prepare_pushforward`](@ref)     |
- | Pullback (VJP)    | [`prepare_pullback`](@ref)        |
| **Operator** | **preparation function** |
| :---------------- | :---------------------------------- |
| Derivative | [`prepare_derivative`](@ref) |
| Multiderivative | [`prepare_multiderivative`](@ref) |
| Gradient | [`prepare_gradient`](@ref) |
| Jacobian | [`prepare_jacobian`](@ref) |
| Second derivative | [`prepare_second_derivative`](@ref) |
| Hessian | [`prepare_hessian`](@ref) |
| Pushforward (JVP) | [`prepare_pushforward`](@ref) |
| Pullback (VJP) | [`prepare_pullback`](@ref) |

If you run `prepare_operator(backend, f, x)`, it will create an object called `extras` containing the necessary information to speed up `operator` and its variants.
This information is specific to `backend` and `f`, as well as the _type and size_ of the input `x`, but it should work with different _values_ of `x`.
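
Concretely, the workflow described above might look like this (a sketch; passing `extras` as the last positional argument mirrors the primitive signatures in this commit):

```julia
using ADTypes: AutoReverseDiff
using DifferentiationInterface

f(x) = sum(abs2, x)
backend = AutoReverseDiff()
x = rand(10)

extras = prepare_gradient(backend, f, x)  # e.g. records a tape for ReverseDiff
grad = gradient(backend, f, x, extras)    # reuses the cached objects

# `extras` can be reused for other inputs of the same type and size
grad2 = gradient(backend, f, rand(10), extras)
```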
@@ -65,25 +80,12 @@ Since `f!` operates in-place and the primal is computed every time, only four op

Furthermore, the preparation function takes an additional argument: `prepare_operator(backend, f!, x, y)`.

- Check out the list of [backends that support mutating functions](@ref mutcompat).

- ## Second order
-
- For array-to-scalar functions, the Hessian matrix is of significant interest.
- That is why we provide the following second-order operators:
-
- | **Operator** | **allocating**    | **mutating**       | **allocating with primal**               | **mutating with primal**                  |
- | :----------- | :---------------- | :----------------- | ---------------------------------------- | ----------------------------------------- |
- | Hessian      | [`hessian`](@ref) | [`hessian!`](@ref) | [`value_and_gradient_and_hessian`](@ref) | [`value_and_gradient_and_hessian!`](@ref) |
-
- When the Hessian is too costly to allocate entirely, its products with vectors can be cheaper to compute.
Check out the list of [backends that support mutating functions](@ref backend_support_mutation).

- | **Operator**                 | **allocating**                                | **mutating**                                   |
- | :--------------------------- | :-------------------------------------------- | :--------------------------------------------- |
- | Hessian-vector product (HVP) | [`gradient_and_hessian_vector_product`](@ref) | [`gradient_and_hessian_vector_product!`](@ref) |
## Second order backends

- At the moment, second order operators can only be used with a specific backend of type [`SecondOrder`](@ref).
- Check out the [compatibility table between backends](@ref secondcombin).
Second order differentiation operators can only be used with a backend of type [`SecondOrder`](@ref).
Check out the [combination table](@ref backend_combination) between backends.
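
Putting it together, a Hessian computation through a `SecondOrder` backend could be sketched as follows (the argument order inside `SecondOrder` is assumed from the combination table):

```julia
using ADTypes: AutoForwardDiff, AutoZygote
using DifferentiationInterface

f(x) = sum(abs2, x)
backend = SecondOrder(AutoZygote(), AutoForwardDiff())
H = hessian(backend, f, [1.0, 2.0])
# for f(x) = sum(abs2, x), the Hessian is mathematically 2I
```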

## Multiple inputs/outputs

@@ -12,12 +12,14 @@ ruleconfig(backend::AutoChainRules) = backend.ruleconfig
const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}}
const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}}

- DI.autodiff_mode(::AutoForwardChainRules) = DI.ForwardMode()
- DI.autodiff_mode(::AutoReverseChainRules) = DI.ReverseMode()
DI.mode(::AutoForwardChainRules) = DI.ForwardMode()
DI.mode(::AutoReverseChainRules) = DI.ReverseMode()

## Primitives

- function DI.value_and_pushforward(backend::AutoForwardChainRules, f, x, dx, extras::Nothing)
function DI.value_and_pushforward(
    backend::AutoForwardChainRules, f, x, dx, extras::Nothing=nothing
)
rc = ruleconfig(backend)
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
return y, new_dy
Expand All @@ -35,15 +37,22 @@ function DI.value_and_pushforward!(
return y, update!(dy, new_dy)
end

- function DI.value_and_pullback(backend::AutoReverseChainRules, f, x, dy, extras::Nothing)
function DI.value_and_pullback(
    backend::AutoReverseChainRules, f, x, dy, extras::Nothing=nothing
)
rc = ruleconfig(backend)
y, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
return y, new_dx
end

function DI.value_and_pullback!(
-     dx::Union{Number,AbstractArray}, backend::AutoReverseChainRules, f, x, dy, extras
    dx::Union{Number,AbstractArray},
    backend::AutoReverseChainRules,
    f,
    x,
    dy,
    extras=nothing,
)
y, new_dx = DI.value_and_pullback(backend, f, x, dy, extras)
return y, update!(dx, new_dx)
@@ -7,17 +7,17 @@ import DifferentiationInterface as DI
using Diffractor: DiffractorForwardBackend, DiffractorRuleConfig
using DocStringExtensions

- DI.autodiff_mode(::AutoDiffractor) = DI.ForwardMode()
- DI.autodiff_mode(::AutoChainRules{<:DiffractorRuleConfig}) = DI.ForwardMode()
DI.mode(::AutoDiffractor) = DI.ForwardMode()
DI.mode(::AutoChainRules{<:DiffractorRuleConfig}) = DI.ForwardMode()

- function DI.value_and_pushforward(::AutoDiffractor, f, x, dx, extras::Nothing)
function DI.value_and_pushforward(::AutoDiffractor, f, x, dx, extras::Nothing=nothing)
vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x)
y, dy = vpff((dx,))
return y, dy
end

function DI.value_and_pushforward!(
-     dy::Union{Number,AbstractArray}, ::AutoDiffractor, f, x, dx, extras::Nothing
    dy::Union{Number,AbstractArray}, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing
)
vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x)
y, new_dy = vpff((dx,))
@@ -29,14 +29,14 @@ AutoEnzyme
const AutoForwardEnzyme = AutoEnzyme{<:ForwardMode}
const AutoReverseEnzyme = AutoEnzyme{<:ReverseMode}

- function DI.autodiff_mode(::AutoEnzyme)
function DI.mode(::AutoEnzyme)
return error(
"You need to specify the Enzyme mode with `AutoEnzyme(Enzyme.Forward)` or `AutoEnzyme(Enzyme.Reverse)`",
)
end

- DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode()
- DI.autodiff_mode(::AutoReverseEnzyme) = DI.ReverseMode()
DI.mode(::AutoForwardEnzyme) = DI.ForwardMode()
DI.mode(::AutoReverseEnzyme) = DI.ReverseMode()

# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type
function DI.basisarray(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
18 changes: 11 additions & 7 deletions ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl
@@ -1,47 +1,51 @@
## Primitives

function DI.value_and_pushforward!(
-     _dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing
    _dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
return y, new_dy
end

function DI.value_and_pushforward!(
-     dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing
    dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
dy .= new_dy
return y, dy
end

- function DI.pushforward!(_dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing)
function DI.pushforward!(
    _dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
return new_dy
end

function DI.pushforward!(
-     dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing
    dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
dy .= new_dy
return dy
end

- function DI.value_and_pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing)
function DI.value_and_pushforward(
    backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
return y, dy
end

- function DI.pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing)
function DI.pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing)
dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
return dy
end

## Utilities

function DI.value_and_jacobian(
-     backend::AutoForwardEnzyme, f, x::AbstractArray, extras::Nothing
    backend::AutoForwardEnzyme, f, x::AbstractArray, extras::Nothing=nothing
)
y = f(x)
jac = jacobian(backend.mode, f, x)