Hessian (#49)
* Hessian

* Shuffle
gdalle authored Mar 16, 2024
1 parent a2805cf commit 53fd898
Showing 25 changed files with 1,143 additions and 565 deletions.
21 changes: 6 additions & 15 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,34 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[weakdeps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore"
DifferentiationInterfaceDiffractorExt = [
"Diffractor",
"AbstractDifferentiation",
]
DifferentiationInterfaceDiffractorExt = ["Diffractor", "AbstractDifferentiation"]
DifferentiationInterfaceEnzymeExt = "Enzyme"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfacePolyesterForwardDiffExt = [
"PolyesterForwardDiff",
"ForwardDiff",
"DiffResults",
]
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceTestExt = ["ForwardDiff", "JET", "Random", "Test"]
DifferentiationInterfaceTestExt = ["ForwardDiff", "JET", "Test"]
DifferentiationInterfaceZygoteExt = ["Zygote"]

[compat]
AbstractDifferentiation = "0.6"
ADTypes = "0.2.7"
AbstractDifferentiation = "0.6"
ChainRulesCore = "1.19"
Diffractor = "0.2"
DiffResults = "1.1"
Diffractor = "0.2"
DocStringExtensions = "0.9"
Enzyme = "0.11"
FillArrays = "1"
Expand All @@ -56,7 +48,6 @@ ForwardDiff = "0.10"
JET = "0.8"
LinearAlgebra = "1"
PolyesterForwardDiff = "0.1"
Random = "1"
ReverseDiff = "1.15"
Test = "1"
Zygote = "0.6"
Expand Down
37 changes: 20 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,29 @@ This package provides a backend-agnostic syntax to differentiate functions of th

## Compatibility

We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

| Backend | Object | Allocating | Mutating |
| :------------------------------------------------------------------------------ | :-------------------------------------- | :--------- | :------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |||
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |||
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) (forward) | `AutoEnzyme(Enzyme.Forward)` |||
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) (reverse) | `AutoEnzyme(Enzyme.Reverse)` |||
| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) | `AutoFiniteDiff()` || soon |
| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | `AutoForwardDiff()` |||
| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize)` |||
| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` |||
| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |||
We support some of the first-order backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

| Backend | Object |
| :------------------------------------------------------------------------------ | :----------------------------------------------------------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | `AutoEnzyme(Enzyme.Forward)` or `AutoEnzyme(Enzyme.Reverse)` |
| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) | `AutoFiniteDiff()` |
| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | `AutoForwardDiff()` |
| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize)` |
| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` |
| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |

We also provide a second-order backend `SecondOrder(reverse_backend, forward_backend)` for Hessian computations.

## Example

Setup:

```jldoctest readme
julia> import DifferentiationInterface, ADTypes, ForwardDiff
julia> import ADTypes, ForwardDiff
julia> using DifferentiationInterface
julia> backend = ADTypes.AutoForwardDiff();
Expand All @@ -45,7 +48,7 @@ julia> f(x) = sum(abs2, x);
Out-of-place gradient:

```jldoctest readme
julia> DifferentiationInterface.value_and_gradient(backend, f, [1., 2., 3.])
julia> value_and_gradient(backend, f, [1., 2., 3.])
(14.0, [2.0, 4.0, 6.0])
```

Expand All @@ -54,7 +57,7 @@ In-place gradient:
```jldoctest readme
julia> grad = zeros(3);
julia> DifferentiationInterface.value_and_gradient!(grad, backend, f, [1., 2., 3.])
julia> value_and_gradient!(grad, backend, f, [1., 2., 3.])
(14.0, [2.0, 4.0, 6.0])
julia> grad
Expand All @@ -73,5 +76,5 @@ julia> grad

Goals for future releases:

- implement backend-specific cache objects
- optimize performance for each backend
- define user-facing functions to test and benchmark backends against each other
9 changes: 8 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ Modules = [DifferentiationInterface]
Pages = ["jacobian.jl"]
```

## Hessian

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["hessian.jl"]
```

## Pushforward (JVP)

```@autodocs
Expand Down Expand Up @@ -73,4 +80,4 @@ These are not part of the public API.

```@autodocs
Modules = [DifferentiationTest, Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceTestExt)]
```
```
28 changes: 28 additions & 0 deletions docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,34 @@ AutoReverseDiff
AutoZygote
```

## [Mutation compatibility](@id mutcompat)

All backends are compatible with allocating functions `f(x) = y`. Only some are compatible with mutating functions `f!(y, x) = nothing`:

| Backend | Mutating functions |
| :-------------------------------------- | ------------------ |
| `AutoChainRules(ruleconfig)` ||
| `AutoDiffractor()` ||
| `AutoEnzyme(Enzyme.Forward)` ||
| `AutoEnzyme(Enzyme.Reverse)` ||
| `AutoFiniteDiff()` | soon |
| `AutoForwardDiff()` ||
| `AutoPolyesterForwardDiff(; chunksize)` ||
| `AutoReverseDiff()` ||
| `AutoZygote()` ||

## [Second order combinations](@id secondcombin)

For Hessian computations, we can in theory combine any pair of backends into a [`SecondOrder`](@ref).
In practice, many combinations will fail.
Here are the ones we tested for you:

| Reverse backend | Forward backend | Hessian tested |
| :------------------ | :--------------------------- | -------------- |
| `AutoZygote()` | `AutoForwardDiff()` ||
| `AutoReverseDiff()` | `AutoForwardDiff()` ||
| `AutoZygote()` | `AutoEnzyme(Enzyme.Forward)` ||
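
A minimal sketch of the first tested combination above, assuming the `hessian(backend, f, x)` argument order used by the other operators in this package:

```julia
using DifferentiationInterface
using ADTypes: AutoZygote, AutoForwardDiff
using Zygote, ForwardDiff

# Reverse-over-forward: Zygote for the outer gradient, ForwardDiff inside.
backend = SecondOrder(AutoZygote(), AutoForwardDiff())

f(x) = sum(abs2, x)
H = hessian(backend, f, [1.0, 2.0, 3.0])
# For this f, the Hessian is 2I (a 3×3 diagonal matrix of 2.0)
```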

## Package extensions

```@meta
Expand Down
27 changes: 27 additions & 0 deletions docs/src/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,30 @@ flowchart LR
value_and_pullback!
end
```

### Second order, scalar-valued functions

```mermaid
flowchart LR
subgraph First order
gradient!
value_and_pushforward!
end
subgraph Hessian-vector product
gradient_and_hessian_vector_product!
gradient_and_hessian_vector_product --> gradient_and_hessian_vector_product!
end
gradient_and_hessian_vector_product! --> gradient!
gradient_and_hessian_vector_product! --> value_and_pushforward!
subgraph Hessian
value_and_gradient_and_hessian!
value_and_gradient_and_hessian --> value_and_gradient_and_hessian!
hessian! --> value_and_gradient_and_hessian!
hessian --> value_and_gradient_and_hessian
end
value_and_gradient_and_hessian! --> |n|gradient_and_hessian_vector_product!
```
29 changes: 25 additions & 4 deletions docs/src/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
## [Operators](@id operators)

Depending on the type of input and output, differentiation operators can have various names.
We choose the following terminology for the ones we provide:
We choose the following terminology for the first-order operators we provide:

| | **scalar output** | **array output** |
| ---------------- | ----------------- | ----------------- |
| **scalar input** | `derivative` | `multiderivative` |
| **array input** | `gradient` | `jacobian` |

Most backends have custom implementations for all of these, which we reuse whenever possible.
Most backends have custom implementations for all of these, which we reuse if possible.

## Variants

Expand Down Expand Up @@ -52,8 +52,9 @@ We do not make any guarantees on their implementation for each backend, or on th

## Mutating functions

In addition to allocating functions `f(x) = y`, we also support mutating functions `f!(y, x) = nothing` whenever the output is an array (beware that it must return `nothing`).
Since they operate in-place and the primal is computed every time, only four operators are defined:
In addition to allocating functions `f(x) = y`, some backends also support mutating functions `f!(y, x) = nothing` whenever the output is an array.
Beware that the function `f!` must return `nothing`!
Since `f!` operates in-place and the primal is computed every time, only four operators are defined:

| **Operator** | **mutating with primal** |
| :---------------- | :----------------------------------- |
Expand All @@ -64,6 +65,26 @@ Since they operate in-place and the primal is computed every time, only four ope

Furthermore, the preparation function takes an additional argument: `prepare_operator(backend, f!, x, y)`.

Check out the list of [backends that support mutating functions](@ref mutcompat).
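
As a sketch, here is what differentiating a mutating function might look like. The argument order below mirrors `value_and_gradient!(grad, backend, f, x)` from the README, with the output buffer `y` added first; the exact signature of the mutating operators is an assumption here.

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff

# A mutating function: writes into y and returns nothing, as required.
f!(y, x) = (y .= 2 .* x; nothing)

backend = AutoForwardDiff()
x = [1.0, 2.0, 3.0]
y = zeros(3)       # primal output buffer
jac = zeros(3, 3)  # Jacobian buffer

value_and_jacobian!(y, jac, backend, f!, x)
# Afterwards y holds 2x and jac holds the Jacobian (2I for this f!)
```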

## Second order

For array-to-scalar functions, the Hessian matrix is of significant interest.
That is why we provide the following second-order operators:

| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** |
| :----------- | :---------------- | :----------------- | ---------------------------------------- | ----------------------------------------- |
| Hessian | [`hessian`](@ref) | [`hessian!`](@ref) | [`value_and_gradient_and_hessian`](@ref) | [`value_and_gradient_and_hessian!`](@ref) |

When the Hessian is too costly to allocate entirely, its products with vectors can be cheaper to compute.

| **Operator** | **allocating** | **mutating** |
| :--------------------------- | :-------------------------------------------- | :--------------------------------------------- |
| Hessian-vector product (HVP) | [`gradient_and_hessian_vector_product`](@ref) | [`gradient_and_hessian_vector_product!`](@ref) |

At the moment, second-order operators can only be used with a specific backend of type [`SecondOrder`](@ref).
Check out the [compatibility table between backends](@ref secondcombin).
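
A Hessian-vector product can be sketched as follows. We assume the operator from the table above takes the seed vector `v` as its last argument and returns the gradient together with the product; treat both as assumptions rather than the definitive signature.

```julia
using DifferentiationInterface
using ADTypes: AutoZygote, AutoForwardDiff
using Zygote, ForwardDiff

backend = SecondOrder(AutoZygote(), AutoForwardDiff())

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
v = [1.0, 0.0, 0.0]  # direction along which to probe the Hessian

g, hvp = gradient_and_hessian_vector_product(backend, f, x, v)
# For f(x) = sum(abs2, x): g == 2x and hvp == 2v, without ever
# allocating the full Hessian matrix.
```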

## Multiple inputs/outputs

Restricting the API to one input and one output has many coding advantages, but it is not very flexible.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ using LinearAlgebra: dot
# new dependencies
using ForwardDiff: ForwardDiff
using JET: @test_opt
using Random: AbstractRNG, default_rng, randn!
using Test: @test, @testset

include("scenarios.jl")
include("correctness.jl")
include("type_stability.jl")
include("test_allocating.jl")
include("test_mutating.jl")

Expand Down