Skip to content

Commit

Permalink
Support mutating functions f!(y, x) (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Mar 14, 2024
1 parent ccb7f43 commit be4f4a8
Show file tree
Hide file tree
Showing 36 changed files with 927 additions and 185 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ An interface to various automatic differentiation backends in Julia.

## Goal

This package provides a backend-agnostic syntax to differentiate functions `f(x) = y`, where `x` and `y` are either real numbers or abstract arrays.
This package provides a backend-agnostic syntax to differentiate functions of the following types:

It supports in-place versions of every operator, and ensures type stability whenever possible.
- Allocating: `f(x) = y` where `x` and `y` can be real numbers or abstract arrays
- Mutating: `f!(y, x)` where `y` is an abstract array and `x` can be a real number or an abstract array

## Compatibility

We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

| Backend | Type |
|:--------------------------------------------------------------------------------|:-----------------------------------------------------------|
| :------------------------------------------------------------------------------ | :--------------------------------------------------------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | `AutoEnzyme(Val(:forward))` or `AutoEnzyme(Val(:reverse))` |
Expand Down Expand Up @@ -64,13 +65,12 @@ julia> grad

## Related packages

- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) is the original inspiration for DifferentiationInterface.jl. We aim to be less generic (one input, one output, first order only) but more efficient (type stability, memory reuse).
- [AutoDiffOperators.jl](https://github.com/oschulz/AutoDiffOperators.jl) is an attempt to bridge ADTypes.jl with AbstractDifferentiation.jl. We provide similar functionality (except for the matrix-like behavior) but cover more backends.
- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) is the original inspiration for DifferentiationInterface.jl.
- [AutoDiffOperators.jl](https://github.com/oschulz/AutoDiffOperators.jl) is an attempt to bridge ADTypes.jl with AbstractDifferentiation.jl.

## Roadmap

Goals for future releases:

- implement backend-specific cache objects
- support in-place functions `f!(y, x)`
- define user-facing functions to test and benchmark backends against each other
2 changes: 2 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Base: get_extension
using DifferentiationInterface
using DifferentiationInterface.DifferentiationTest
import DifferentiationInterface as DI
using Documenter
using DocumenterMermaid
Expand Down Expand Up @@ -49,6 +50,7 @@ makedocs(;
modules=[
ADTypes,
DifferentiationInterface,
DifferentiationInterface.DifferentiationTest,
ChainRulesCoreExt,
DiffractorExt,
EnzymeExt,
Expand Down
9 changes: 8 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ These are not part of the public API.

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["backends.jl", "mode.jl", "utils.jl", "DifferentiationTest.jl"]
Public = false
```

## Submodules

These are not part of the public API.

```@autodocs
Modules = [DifferentiationTest]
```
13 changes: 13 additions & 0 deletions docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ AutoReverseDiff
AutoZygote
```

## Accepted functions

| Backend | `f(x) = y` | `f!(y, x)` |
| -------------------------- | ---------- | ---------- |
| `AutoChainRules` | yes | no |
| `AutoDiffractor` | yes | no |
| `AutoEnzyme` | yes | soon |
| `AutoForwardDiff` | yes | yes |
| `AutoFiniteDiff` | yes | soon |
| `AutoPolyesterForwardDiff` | yes | soon |
| `AutoReverseDiff` | yes | soon |
| `AutoZygote` | yes | no |

## Package extensions

```@meta
Expand Down
47 changes: 44 additions & 3 deletions docs/src/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ Advanced users are welcome to code more backends and submit pull requests!

Edge labels correspond to the amount of function calls when applying operators to a function $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$.


### Forward mode
### Forward mode, allocating functions

```mermaid
flowchart LR
Expand Down Expand Up @@ -69,7 +68,28 @@ flowchart LR
end
```

### Reverse mode
### Forward mode, mutating functions

```mermaid
flowchart LR
subgraph Multiderivative
value_and_multiderivative!
end
value_and_multiderivative! --> value_and_pushforward!
subgraph Jacobian
value_and_jacobian!
end
value_and_jacobian! --> |n|value_and_pushforward!
subgraph Pushforward
value_and_pushforward!
end
```

### Reverse mode, allocating functions

```mermaid
flowchart LR
Expand Down Expand Up @@ -115,3 +135,24 @@ flowchart LR
pullback --> value_and_pullback
end
```

### Reverse mode, mutating functions

```mermaid
flowchart LR
subgraph Multiderivative
value_and_multiderivative!
end
value_and_multiderivative! --> |m|value_and_pullback!
subgraph Jacobian
value_and_jacobian!
end
value_and_jacobian! --> |m|value_and_pullback!
subgraph Pullback
value_and_pullback!
end
```
23 changes: 21 additions & 2 deletions docs/src/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ We choose the following terminology for the ones we provide:

Most backends have custom implementations for all of these, which we reuse whenever possible.

### Variants
## Variants

Whenever it makes sense, four variants of the same operator are defined:

| **Operator** | **non-mutating** | **mutating** | **non-mutating with primal** | **mutating with primal** |
| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** |
| :---------------- | :------------------------ | :------------------------- | :---------------------------------- | :----------------------------------- |
| Derivative | [`derivative`](@ref) | N/A | [`value_and_derivative`](@ref) | N/A |
| Multiderivative | [`multiderivative`](@ref) | [`multiderivative!`](@ref) | [`value_and_multiderivative`](@ref) | [`value_and_multiderivative!`](@ref) |
Expand Down Expand Up @@ -49,3 +49,22 @@ This is especially worth it if you plan to call `operator` several times in simi

By default, all the preparation functions return `nothing`.
We do not make any guarantees on their implementation for each backend, or on the performance gains that can be expected.

## Mutating functions

In addition to allocating functions `f(x) = y`, we also support mutating functions `f!(y, x)` whenever the output is an array.
Since they operate in-place and the primal is computed every time, only four operators are defined:

| **Operator** | **mutating with primal** |
| :---------------- | :----------------------------------- |
| Multiderivative | [`value_and_multiderivative!`](@ref) |
| Jacobian | [`value_and_jacobian!`](@ref) |
| Pushforward (JVP) | [`value_and_pushforward!`](@ref) |
| Pullback (VJP) | [`value_and_pullback!`](@ref) |

Furthermore, the preparation function takes an additional argument: `prepare_operator(backend, f!, x, y)`.

## Multiple inputs/outputs

Restricting the API to one input and one output has many coding advantages, but it is not very flexible.
If you need more than that, use [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap several objects inside a single `ComponentVector`.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ function DI.value_and_pushforward(
end

function DI.value_and_pushforward!(
dy, backend::AutoForwardChainRules, f, x, dx, extras=nothing
dy::Union{Number,AbstractArray},
backend::AutoForwardChainRules,
f,
x,
dx,
extras=nothing,
)
y, new_dy = DI.value_and_pushforward(backend, f, x, dx, extras)
return y, update!(dy, new_dy)
Expand All @@ -42,7 +47,12 @@ function DI.value_and_pullback(
end

function DI.value_and_pullback!(
dx, backend::AutoReverseChainRules, f, x, dy, extras=nothing
dx::Union{Number,AbstractArray},
backend::AutoReverseChainRules,
f,
x,
dy,
extras=nothing,
)
y, new_dx = DI.value_and_pullback(backend, f, x, dy, extras)
return y, update!(dx, new_dx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ function DI.value_and_pushforward(::AutoDiffractor, f, x, dx, extras::Nothing=no
return y, dy
end

function DI.value_and_pushforward!(dy, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing)
function DI.value_and_pushforward!(
dy::Union{Number,AbstractArray}, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing
)
vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x)
y, new_dy = vpff((dx,))
return y, update!(dy, new_dy)
Expand Down
8 changes: 4 additions & 4 deletions ext/DifferentiationInterfaceEnzymeExt/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode()
## Primitives

function DI.value_and_pushforward!(
_dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing
) where {X,Y<:Real}
_dy::Real, ::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
return y, new_dy
end

function DI.value_and_pushforward!(
dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing
) where {X,Y<:AbstractArray}
dy::AbstractArray, ::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
dy .= new_dy
return y, dy
Expand Down
31 changes: 23 additions & 8 deletions ext/DifferentiationInterfaceEnzymeExt/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,49 @@ end
## Primitives

function DI.value_and_pullback!(
_dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing
) where {X<:Number,Y<:Union{Real,Nothing}}
_dx::Number, ::AutoReverseEnzyme, f, x::Number, dy, extras::Nothing=nothing
)
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
new_dx = dy * only(der)
return y, new_dx
end

function DI.value_and_pullback!(
dx::X, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing
) where {X<:AbstractArray,Y<:Union{Real,Nothing}}
dx::AbstractArray,
::AutoReverseEnzyme,
f,
x::AbstractArray,
dy::Number,
extras::Nothing=nothing,
)
dx .= zero(eltype(dx))
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx))
dx .*= dy
return y, dx
end

function DI.value_and_pullback!(
_dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing
) where {X<:Number,Y<:AbstractArray}
_dx::Number,
::AutoReverseEnzyme,
f,
x::Number,
dy::AbstractArray,
extras::Nothing=nothing,
)
y = f(x)
mf! = MakeFunctionMutating(f)
_, new_dx = only(autodiff(Reverse, mf!, Const, Duplicated(y, copy(dy)), Active(x)))
return y, new_dx
end

function DI.value_and_pullback!(
dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing
) where {X<:AbstractArray,Y<:AbstractArray}
dx::AbstractArray,
::AutoReverseEnzyme,
f,
x::AbstractArray,
dy::AbstractArray,
extras::Nothing=nothing,
)
y = f(x)
dx_like_x = zero(x)
mf! = MakeFunctionMutating(f)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ const FUNCTION_NOT_INPLACE = Val{false}
## Primitives

function DI.value_and_pushforward!(
dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing
) where {Y<:Number,fdtype}
_dy::Number, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing
) where {fdtype}
y = f(x)
step(t::Number)::Number = f(x .+ t .* dx)
new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y)
return y, new_dy
end

function DI.value_and_pushforward!(
dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing
) where {Y<:AbstractArray,fdtype}
dy::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing
) where {fdtype}
y = f(x)
step(t::Number)::AbstractArray = f(x .+ t .* dx)
finite_difference_gradient!(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module DifferentiationInterfaceForwardDiffExt

using ADTypes: AutoForwardDiff
import DifferentiationInterface as DI
using DiffResults: DiffResults
using DocStringExtensions
using ForwardDiff:
Chunk,
Dual,
DerivativeConfig,
GradientConfig,
JacobianConfig,
Tag,
derivative,
derivative!,
extract_derivative,
extract_derivative!,
gradient,
gradient!,
jacobian,
jacobian!,
value
using LinearAlgebra: mul!

choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}()

tag_type(::F, ::V) where {F,V<:Number} = Tag{F,V}
tag_type(::F, ::AbstractArray{V}) where {F,V<:Number} = Tag{F,V}

include("non_mutating.jl")
include("mutating.jl")

end # module
Loading

0 comments on commit be4f4a8

Please sign in to comment.