diff --git a/README.md b/README.md index daf142af7..6d9b82c57 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ This package provides a backend-agnostic syntax to differentiate functions of th ## Compatibility -We support some of the first order backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): +We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): | Backend | Object | | :------------------------------------------------------------------------------ | :----------------------------------------------------------- | @@ -29,8 +29,6 @@ We support some of the first order backends defined by [ADTypes.jl](https://gith | [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` | | [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` | -We also provide a second order backend `SecondOrder(reverse_backend, forward_backend)` for hessian computations. - ## Example Setup: diff --git a/benchmark/utils.jl b/benchmark/utils.jl index a4b5d2707..a2c675f54 100644 --- a/benchmark/utils.jl +++ b/benchmark/utils.jl @@ -2,7 +2,7 @@ using ADTypes using ADTypes: AbstractADType using BenchmarkTools using DifferentiationInterface -using DifferentiationInterface: ForwardMode, ReverseMode, autodiff_mode +using DifferentiationInterface: ForwardMode, ReverseMode, mode BenchmarkTools.DEFAULT_PARAMETERS.seconds = 1 @@ -12,7 +12,7 @@ pretty_backend(::AutoChainRules{<:ZygoteRuleConfig}) = "ChainRules{Zygote}" pretty_backend(::AutoDiffractor) = "Diffractor (forward)" function pretty_backend(backend::AutoEnzyme) - return autodiff_mode(backend) isa ForwardMode ? "Enzyme (forward)" : "Enzyme (reverse)" + return mode(backend) isa ForwardMode ? "Enzyme (forward)" : "Enzyme (reverse)" end pretty_backend(::AutoFiniteDiff) = "FiniteDiff" @@ -45,7 +45,7 @@ function add_pushforward_benchmarks!( dx = n == 1 ? randn() : randn(n) dy = m == 1 ? 0.0 : zeros(m) - if !isa(autodiff_mode(backend), ForwardMode) + if !isa(mode(backend), ForwardMode) return nothing end @@ -90,7 +90,7 @@ function add_pullback_benchmarks!( dx = n == 1 ? 0.0 : zeros(n) dy = m == 1 ? randn() : randn(m) - if !isa(autodiff_mode(backend), ReverseMode) + if !isa(mode(backend), ReverseMode) return nothing end diff --git a/docs/src/api.md b/docs/src/api.md index 221846c2e..3c4f5cb0e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -37,6 +37,20 @@ Modules = [DifferentiationInterface] Pages = ["jacobian.jl"] ``` +## Second order + +```@autodocs +Modules = [DifferentiationInterface] +Pages = ["second_order.jl"] +``` + +## Second derivative + +```@autodocs +Modules = [DifferentiationInterface] +Pages = ["second_derivative.jl"] +``` + ## Hessian ```@autodocs @@ -58,6 +72,13 @@ Modules = [DifferentiationInterface] Pages = ["pullback.jl"] ``` +## Hessian-vector product (HVP) + +```@autodocs +Modules = [DifferentiationInterface] +Pages = ["hessian_vector_product.jl"] +``` + ## Preparation ```@autodocs diff --git a/docs/src/backends.md b/docs/src/backends.md index 38dac13a8..4d518d2bd 100644 --- a/docs/src/backends.md +++ b/docs/src/backends.md @@ -24,7 +24,7 @@ AutoReverseDiff AutoZygote ``` -## [Mutation compatibility](@id mutcompat) +## [Mutation support](@id backend_support_mutation) All backends are compatible with allocating functions `f(x) = y`. Only some are compatible with mutating functions `f!(y, x) = nothing`: @@ -40,14 +40,15 @@ All backends are compatible with allocating functions `f(x) = y`. Only some are | `AutoReverseDiff()` | ✓ | | `AutoZygote()` | ✗ | -## [Second order combinations](@id secondcombin) +## [Second order combinations](@id backend_combination) For hessian computations, in theory we can combine any pair of backends into a [`SecondOrder`](@ref). In practice, many combinations will fail. Here are the ones we tested for you: -| Reverse backend | Forward backend | Hessian tested | +| Inner backend | Outer backend | Hessian tested | | :------------------ | :--------------------------- | -------------- | +| `AutoForwardDiff()` | `AutoForwardDiff()` | ✓ | | `AutoZygote()` | `AutoForwardDiff()` | ✓ | | `AutoReverseDiff()` | `AutoForwardDiff()` | ✓ | | `AutoZygote()` | `AutoEnzyme(Enzyme.Forward)` | ✓ | diff --git a/docs/src/developer.md b/docs/src/developer.md index 1f6d29b34..b3bf9a2a6 100644 --- a/docs/src/developer.md +++ b/docs/src/developer.md @@ -156,30 +156,3 @@ flowchart LR value_and_pullback! end ``` - -### Second order, scalar-valued functions - -```mermaid -flowchart LR - subgraph First order - gradient! - value_and_pushforward! - end - - subgraph Hessian-vector product - gradient_and_hessian_vector_product! - gradient_and_hessian_vector_product --> gradient_and_hessian_vector_product! - end - - gradient_and_hessian_vector_product! --> gradient! - gradient_and_hessian_vector_product! --> value_and_pushforward! - - subgraph Hessian - value_and_gradient_and_hessian! - value_and_gradient_and_hessian --> value_and_gradient_and_hessian! - hessian! --> value_and_gradient_and_hessian! - hessian --> value_and_gradient_and_hessian - end - - value_and_gradient_and_hessian! --> |n|gradient_and_hessian_vector_product! -``` diff --git a/docs/src/getting_started.md b/docs/src/getting_started.md index 39abd0316..3edaa83ea 100644 --- a/docs/src/getting_started.md +++ b/docs/src/getting_started.md @@ -3,43 +3,58 @@ ## [Operators](@id operators) Depending on the type of input and output, differentiation operators can have various names. -We choose the following terminology for the first-order operators we provide: +We choose the following terminology for the operators we provide: -| | **scalar output** | **array output** | +| First order | **scalar output** | **array output** | | ---------------- | ----------------- | ----------------- | | **scalar input** | `derivative` | `multiderivative` | | **array input** | `gradient` | `jacobian` | +| Second order | **scalar output** | **array output** | +| ---------------- | ------------------- | ---------------- | +| **scalar input** | `second_derivative` | not implemented | +| **array input** | `hessian` | not implemented | + Most backends have custom implementations for all of these, which we reuse if possible. ## Variants Whenever it makes sense, four variants of the same operator are defined: -| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** | -| :---------------- | :------------------------ | :------------------------- | :---------------------------------- | :----------------------------------- | -| Derivative | [`derivative`](@ref) | N/A | [`value_and_derivative`](@ref) | N/A | -| Multiderivative | [`multiderivative`](@ref) | [`multiderivative!`](@ref) | [`value_and_multiderivative`](@ref) | [`value_and_multiderivative!`](@ref) | -| Gradient | [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) | -| Jacobian | [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) | -| Pushforward (JVP) | [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) | -| Pullback (VJP) | [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) | +| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** | +| :---------------- | :-------------------------- | :------------------------- | :----------------------------------------------- | :------------------------------------ | +| Derivative | [`derivative`](@ref) | N/A | [`value_and_derivative`](@ref) | N/A | +| Multiderivative | [`multiderivative`](@ref) | [`multiderivative!`](@ref) | [`value_and_multiderivative`](@ref) | [`value_and_multiderivative!`](@ref) | +| Gradient | [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) | +| Jacobian | [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) | +| Second derivative | [`second_derivative`](@ref) | N/A | [`value_derivative_and_second_derivative`](@ref) | N/A | +| Hessian | [`hessian`](@ref) | [`hessian!`](@ref) | [`value_gradient_and_hessian`](@ref) | [`value_gradient_and_hessian!`](@ref) | + +Note that scalar outputs can't be mutated, which is why `derivative` and `second_derivative` do not have mutating variants. -Note that scalar outputs can't be mutated, which is why `derivative` doesn't have mutating variants. +For advanced users, lower-level operators are also exposed: + +| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** | +| :--------------------------- | :------------------------------- | :-------------------------------- | :-------------------------------------------- | :--------------------------------------------- | +| Pushforward (JVP) | [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) | +| Pullback (VJP) | [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) | +| Hessian-vector product (HVP) | [`hessian_vector_product`](@ref) | [`hessian_vector_product!`](@ref) | [`gradient_and_hessian_vector_product`](@ref) | [`gradient_and_hessian_vector_product!`](@ref) | ## Preparation In many cases, automatic differentiation can be accelerated if the function has been run at least once (e.g. to record a tape) and if some cache objects are provided. This is a backend-specific procedure, but we expose a common syntax to achieve it. -| **Operator** | **preparation function** | -| :---------------- | :-------------------------------- | -| Derivative | [`prepare_derivative`](@ref) | -| Multiderivative | [`prepare_multiderivative`](@ref) | -| Gradient | [`prepare_gradient`](@ref) | -| Jacobian | [`prepare_jacobian`](@ref) | -| Pushforward (JVP) | [`prepare_pushforward`](@ref) | -| Pullback (VJP) | [`prepare_pullback`](@ref) | +| **Operator** | **preparation function** | +| :---------------- | :---------------------------------- | +| Derivative | [`prepare_derivative`](@ref) | +| Multiderivative | [`prepare_multiderivative`](@ref) | +| Gradient | [`prepare_gradient`](@ref) | +| Jacobian | [`prepare_jacobian`](@ref) | +| Second derivative | [`prepare_second_derivative`](@ref) | +| Hessian | [`prepare_hessian`](@ref) | +| Pushforward (JVP) | [`prepare_pushforward`](@ref) | +| Pullback (VJP) | [`prepare_pullback`](@ref) | If you run `prepare_operator(backend, f, x)`, it will create an object called `extras` containing the necessary information to speed up `operator` and its variants. This information is specific to `backend` and `f`, as well as the _type and size_ of the input `x`, but it should work with different _values_ of `x`. @@ -65,25 +80,12 @@ Since `f!` operates in-place and the primal is computed every time, only four op Furthermore, the preparation function takes an additional argument: `prepare_operator(backend, f!, x, y)`. -Check out the list of [backends that support mutating functions](@ref mutcompat). - -## Second order - -For array-to-scalar functions, the Hessian matrix is of significant interest. -That is why we provide the following second-order operators: - -| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** | -| :----------- | :---------------- | :----------------- | ---------------------------------------- | ----------------------------------------- | -| Hessian | [`hessian`](@ref) | [`hessian!`](@ref) | [`value_and_gradient_and_hessian`](@ref) | [`value_and_gradient_and_hessian!`](@ref) | - -When the Hessian is too costly to allocate entirely, its products with vectors can be cheaper to compute. +Check out the list of [backends that support mutating functions](@ref backend_support_mutation). -| **Operator** | **allocating** | **mutating** | -| :--------------------------- | :-------------------------------------------- | :--------------------------------------------- | -| Hessian-vector product (HVP) | [`gradient_and_hessian_vector_product`](@ref) | [`gradient_and_hessian_vector_product!`](@ref) | +## Second order backends -At the moment, second order operators can only be used with a specific backend of type [`SecondOrder`](@ref). -Check out the [compatibility table between backends](@ref secondcombin). +Second order differentiation operators can only be used with a backend of type [`SecondOrder`](@ref). +Check out the [combination table](@ref backend_combination) between backends. ## Multiple inputs/outputs diff --git a/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl b/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index 68e8dc7bf..f2a088448 100644 --- a/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -12,12 +12,14 @@ ruleconfig(backend::AutoChainRules) = backend.ruleconfig const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}} const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}} -DI.autodiff_mode(::AutoForwardChainRules) = DI.ForwardMode() -DI.autodiff_mode(::AutoReverseChainRules) = DI.ReverseMode() +DI.mode(::AutoForwardChainRules) = DI.ForwardMode() +DI.mode(::AutoReverseChainRules) = DI.ReverseMode() ## Primitives -function DI.value_and_pushforward(backend::AutoForwardChainRules, f, x, dx, extras::Nothing) +function DI.value_and_pushforward( + backend::AutoForwardChainRules, f, x, dx, extras::Nothing=nothing +) rc = ruleconfig(backend) y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x) return y, new_dy @@ -35,7 +37,9 @@ function DI.value_and_pushforward!( return y, update!(dy, new_dy) end -function DI.value_and_pullback(backend::AutoReverseChainRules, f, x, dy, extras::Nothing) +function DI.value_and_pullback( + backend::AutoReverseChainRules, f, x, dy, extras::Nothing=nothing +) rc = ruleconfig(backend) y, pullback = rrule_via_ad(rc, f, x) _, new_dx = pullback(dy) @@ -43,7 +47,12 @@ function DI.value_and_pullback(backend::AutoReverseChainRules, f, x, dy, extras: end function DI.value_and_pullback!( - dx::Union{Number,AbstractArray}, backend::AutoReverseChainRules, f, x, dy, extras + dx::Union{Number,AbstractArray}, + backend::AutoReverseChainRules, + f, + x, + dy, + extras=nothing, ) y, new_dx = DI.value_and_pullback(backend, f, x, dy, extras) return y, update!(dx, new_dx) diff --git a/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 854f04eea..0da5691ff 100644 --- a/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -7,17 +7,17 @@ import DifferentiationInterface as DI using Diffractor: DiffractorForwardBackend, DiffractorRuleConfig using DocStringExtensions -DI.autodiff_mode(::AutoDiffractor) = DI.ForwardMode() -DI.autodiff_mode(::AutoChainRules{<:DiffractorRuleConfig}) = DI.ForwardMode() +DI.mode(::AutoDiffractor) = DI.ForwardMode() +DI.mode(::AutoChainRules{<:DiffractorRuleConfig}) = DI.ForwardMode() -function DI.value_and_pushforward(::AutoDiffractor, f, x, dx, extras::Nothing) +function DI.value_and_pushforward(::AutoDiffractor, f, x, dx, extras::Nothing=nothing) vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x) y, dy = vpff((dx,)) return y, dy end function DI.value_and_pushforward!( - dy::Union{Number,AbstractArray}, ::AutoDiffractor, f, x, dx, extras::Nothing + dy::Union{Number,AbstractArray}, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing ) vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x) y, new_dy = vpff((dx,)) diff --git a/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 6cf03437d..cd4f1e78b 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -29,14 +29,14 @@ AutoEnzyme const AutoForwardEnzyme = AutoEnzyme{<:ForwardMode} const AutoReverseEnzyme = AutoEnzyme{<:ReverseMode} -function DI.autodiff_mode(::AutoEnzyme) +function DI.mode(::AutoEnzyme) return error( "You need to specify the Enzyme mode with `AutoEnzyme(Enzyme.Forward)` or `AutoEnzyme(Enzyme.Reverse)`", ) end -DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode() -DI.autodiff_mode(::AutoReverseEnzyme) = DI.ReverseMode() +DI.mode(::AutoForwardEnzyme) = DI.ForwardMode() +DI.mode(::AutoReverseEnzyme) = DI.ReverseMode() # Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type function DI.basisarray(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T} diff --git a/ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl b/ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl index f07c7ab7e..bb6d4bda3 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl @@ -1,39 +1,43 @@ ## Primitives function DI.value_and_pushforward!( - _dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing + _dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing ) y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx)) return y, new_dy end function DI.value_and_pushforward!( - dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing + dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing ) y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx)) dy .= new_dy return y, dy end -function DI.pushforward!(_dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing) +function DI.pushforward!( + _dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing +) new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx))) return new_dy end function DI.pushforward!( - dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing + dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing ) new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx))) dy .= new_dy return dy end -function DI.value_and_pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing) +function DI.value_and_pushforward( + backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing +) y, dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx)) return y, dy end -function DI.pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing) +function DI.pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing) dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx))) return dy end @@ -41,7 +45,7 @@ end ## Utilities function DI.value_and_jacobian( - backend::AutoForwardEnzyme, f, x::AbstractArray, extras::Nothing + backend::AutoForwardEnzyme, f, x::AbstractArray, extras::Nothing=nothing ) y = f(x) jac = jacobian(backend.mode, f, x) diff --git a/ext/DifferentiationInterfaceEnzymeExt/forward_mutating.jl b/ext/DifferentiationInterfaceEnzymeExt/forward_mutating.jl index 060cd891a..6bd83b543 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/forward_mutating.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/forward_mutating.jl @@ -7,7 +7,7 @@ function DI.value_and_pushforward!( f!, x, dx, - extras::Nothing, + extras::Nothing=nothing, ) dx_sametype = convert(typeof(x), dx) dy_sametype = convert(typeof(y), dy) diff --git a/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl b/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl index 57f748d94..6751d5062 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl @@ -12,7 +12,7 @@ end ## Primitives function DI.value_and_pullback!( - _dx::Number, ::AutoReverseEnzyme, f, x::Number, dy::Number, extras::Nothing + _dx::Number, ::AutoReverseEnzyme, f, x::Number, dy::Number, extras::Nothing=nothing ) der, y = autodiff(ReverseWithPrimal, f, Active, Active(x)) new_dx = dy * only(der) @@ -20,7 +20,12 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - dx::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, dy::Number, extras::Nothing + dx::AbstractArray, + ::AutoReverseEnzyme, + f, + x::AbstractArray, + dy::Number, + extras::Nothing=nothing, ) dx .= zero(eltype(dx)) _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx)) @@ -29,7 +34,7 @@ function DI.value_and_pullback!( end function DI.pullback!( - _dx::Number, ::AutoReverseEnzyme, f, x::Number, dy::Number, extras::Nothing + _dx::Number, ::AutoReverseEnzyme, f, x::Number, dy::Number, extras::Nothing=nothing ) der = only(autodiff(Reverse, f, Active, Active(x))) new_dx = dy * only(der) @@ -37,7 +42,12 @@ function DI.pullback!( end function DI.pullback!( - dx::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, dy::Number, extras::Nothing + dx::AbstractArray, + ::AutoReverseEnzyme, + f, + x::AbstractArray, + dy::Number, + extras::Nothing=nothing, ) dx .= zero(eltype(dx)) autodiff(Reverse, f, Active, Duplicated(x, dx)) @@ -46,7 +56,12 @@ function DI.pullback!( end function DI.value_and_pullback!( - dx::Number, backend::AutoReverseEnzyme, f, x::Number, dy::AbstractArray, extras::Nothing + dx::Number, + backend::AutoReverseEnzyme, + f, + x::Number, + dy::AbstractArray, + extras::Nothing=nothing, ) y = f(x) f! = MakeFunctionMutating(f) @@ -59,7 +74,7 @@ function DI.value_and_pullback!( f, x::AbstractArray, dy::AbstractArray, - extras::Nothing, + extras::Nothing=nothing, ) y = f(x) f! = MakeFunctionMutating(f) @@ -69,27 +84,29 @@ end ## Utilities function DI.value_and_gradient!( - grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing + grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing ) y = f(x) gradient!(Reverse, grad, f, x) return y, grad end -function DI.value_and_gradient(::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing) +function DI.value_and_gradient( + ::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing +) y = f(x) grad = gradient(Reverse, f, x) return y, grad end function DI.gradient!( - grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing + grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing ) gradient!(Reverse, grad, f, x) return grad end -function DI.gradient(::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing) +function DI.gradient(::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing) grad = gradient(Reverse, f, x) return grad end diff --git a/ext/DifferentiationInterfaceEnzymeExt/reverse_mutating.jl b/ext/DifferentiationInterfaceEnzymeExt/reverse_mutating.jl index 898f0085f..2370cd2af 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/reverse_mutating.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/reverse_mutating.jl @@ -7,7 +7,7 @@ function DI.value_and_pullback!( f!, x::Number, dy::AbstractArray, - extras::Nothing, + extras::Nothing=nothing, ) _, dx = only(autodiff(Reverse, f!, Const, Duplicated(y, copy(dy)), Active(x))) return y, dx @@ -20,7 +20,7 @@ function DI.value_and_pullback!( f!, x::AbstractArray, dy::AbstractArray, - extras::Nothing, + extras::Nothing=nothing, ) dx_sametype = convert(typeof(x), dx) dx_sametype .= zero(eltype(dx_sametype)) diff --git a/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl b/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl index 6a36081a2..0f2055eda 100644 --- a/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl +++ b/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl @@ -17,7 +17,7 @@ const FUNCTION_NOT_INPLACE = Val{false} ## Primitives function DI.value_and_pushforward!( - _dy::Number, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing + _dy::Number, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing ) where {fdtype} y = f(x) step(t::Number)::Number = f(x .+ t .* dx) @@ -26,7 +26,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing + dy::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing ) where {fdtype} y = f(x) step(t::Number)::AbstractArray = f(x .+ t .* dx) @@ -39,7 +39,7 @@ end ## Utilities function DI.value_and_derivative( - ::AutoFiniteDiff{fdtype}, f, x::Number, extras::Nothing + ::AutoFiniteDiff{fdtype}, f, x::Number, extras::Nothing=nothing ) where {fdtype} y = f(x) der = finite_difference_derivative(f, x, fdtype, eltype(y), y) @@ -47,7 +47,7 @@ function DI.value_and_derivative( end function DI.value_and_multiderivative!( - multider::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::Number, extras::Nothing + multider::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::Number, extras::Nothing=nothing ) where {fdtype} y = f(x) finite_difference_gradient!(multider, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -55,7 +55,7 @@ function DI.value_and_multiderivative!( end function DI.value_and_multiderivative( - ::AutoFiniteDiff{fdtype}, f, x::Number, extras::Nothing + ::AutoFiniteDiff{fdtype}, f, x::Number, extras::Nothing=nothing ) where {fdtype} y = f(x) multider = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -63,7 +63,11 @@ function DI.value_and_multiderivative( end function DI.value_and_gradient!( - grad::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::AbstractArray, extras::Nothing + grad::AbstractArray, + ::AutoFiniteDiff{fdtype}, + f, + x::AbstractArray, + extras::Nothing=nothing, ) where {fdtype} y = f(x) finite_difference_gradient!(grad, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -71,7 +75,7 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - ::AutoFiniteDiff{fdtype}, f, x::AbstractArray, extras::Nothing + ::AutoFiniteDiff{fdtype}, f, x::AbstractArray, extras::Nothing=nothing ) where {fdtype} y = f(x) grad = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -79,7 +83,7 @@ function DI.value_and_gradient( end function DI.value_and_jacobian( - ::AutoFiniteDiff{fdtype}, f, x::AbstractArray, extras::Nothing + ::AutoFiniteDiff{fdtype}, f, x::AbstractArray, extras::Nothing=nothing ) where {fdtype} y = f(x) jac = finite_difference_jacobian(f, x, fdtype, eltype(y)) @@ -87,7 +91,11 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - jac::AbstractMatrix, backend::AutoFiniteDiff, f, x::AbstractArray, extras::Nothing + jac::AbstractMatrix, + backend::AutoFiniteDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, ) y, new_jac = DI.value_and_jacobian(backend, f, x, extras) jac .= new_jac diff --git a/ext/DifferentiationInterfaceForwardDiffExt/allocating.jl b/ext/DifferentiationInterfaceForwardDiffExt/allocating.jl index 9f0084e0d..67f5cfa93 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt/allocating.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt/allocating.jl @@ -1,7 +1,7 @@ ## Pushforward function DI.value_and_pushforward!( - _dy::Real, ::AutoForwardDiff, f, x::Real, dx, extras::Nothing + _dy::Real, ::AutoForwardDiff, f, x::Real, dx, extras::Nothing=nothing ) T = tag_type(f, x) xdual = Dual{T}(x, dx) @@ -12,7 +12,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::AbstractArray, ::AutoForwardDiff, f, x::Real, dx, extras::Nothing + dy::AbstractArray, ::AutoForwardDiff, f, x::Real, dx, extras::Nothing=nothing ) T = tag_type(f, x) xdual = Dual{T}(x, dx) @@ -23,7 +23,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - _dy::Real, ::AutoForwardDiff, f, x::AbstractArray, dx, extras::Nothing + _dy::Real, ::AutoForwardDiff, f, x::AbstractArray, dx, extras::Nothing=nothing ) T = tag_type(f, x) xdual = Dual{T}.(x, dx) @@ -34,7 +34,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray, dx, extras::Nothing + dy::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray, dx, extras::Nothing=nothing ) T = tag_type(f, x) xdual = Dual{T}.(x, dx) @@ -46,20 +46,20 @@ end ## Derivative -function DI.derivative(::AutoForwardDiff, f, x::Number, extras::Nothing) +function DI.derivative(::AutoForwardDiff, f, x::Number, extras::Nothing=nothing) return derivative(f, x) end ## Multiderivative function DI.multiderivative!( - multider::AbstractArray, ::AutoForwardDiff, f, x::Number, extras::Nothing + multider::AbstractArray, ::AutoForwardDiff, f, x::Number, extras::Nothing=nothing ) derivative!(multider, f, x) return multider end -function DI.multiderivative(::AutoForwardDiff, f, x::Number, extras::Nothing) +function DI.multiderivative(::AutoForwardDiff, f, x::Number, extras::Nothing=nothing) return derivative(f, x) end @@ -68,27 +68,35 @@ end ### Unprepared function DI.value_and_gradient!( - grad::AbstractArray, backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing + grad::AbstractArray, + backend::AutoForwardDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, ) config = DI.prepare_gradient(backend, f, x) return DI.value_and_gradient!(grad, backend, f, x, config) end function DI.value_and_gradient( - backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing + backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing=nothing ) config = DI.prepare_gradient(backend, f, x) return DI.value_and_gradient(backend, f, x, config) end function DI.gradient!( - grad::AbstractArray, backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing + grad::AbstractArray, + backend::AutoForwardDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, ) config = DI.prepare_gradient(backend, f, x) return DI.gradient!(grad, backend, f, x, config) end -function DI.gradient(backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing) +function DI.gradient(backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing=nothing) config = DI.prepare_gradient(backend, f, x) return DI.gradient(backend, f, x, config) end @@ -127,27 +135,35 @@ end ### Unprepared function DI.value_and_jacobian!( - jac::AbstractMatrix, backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing + jac::AbstractMatrix, + backend::AutoForwardDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, ) config = DI.prepare_jacobian(backend, f, x) return DI.value_and_jacobian!(jac, backend, f, x, config) end function DI.value_and_jacobian( - backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing + backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing=nothing ) config = DI.prepare_jacobian(backend, f, x) return DI.value_and_jacobian(backend, f, x, config) end function DI.jacobian!( - jac::AbstractMatrix, backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing + jac::AbstractMatrix, + backend::AutoForwardDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, ) config = DI.prepare_jacobian(backend, f, x) return DI.jacobian!(jac, backend, f, x, config) end -function DI.jacobian(backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing) +function DI.jacobian(backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing=nothing) config = DI.prepare_jacobian(backend, f, x) return DI.jacobian(backend, f, x, config) end diff --git a/ext/DifferentiationInterfaceForwardDiffExt/mutating.jl b/ext/DifferentiationInterfaceForwardDiffExt/mutating.jl index 089cb50fa..05371f539 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt/mutating.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt/mutating.jl @@ -1,7 +1,13 @@ ## Pushforward function DI.value_and_pushforward!( - y::AbstractArray, dy::AbstractArray, ::AutoForwardDiff, f!, x::Real, dx, extras::Nothing + y::AbstractArray, + dy::AbstractArray, + ::AutoForwardDiff, + f!, + x::Real, + dx, + extras::Nothing=nothing, ) T = tag_type(f!, x) xdual = Dual{T}(x, dx) @@ -19,7 +25,7 @@ function DI.value_and_pushforward!( f!, x::AbstractArray, dx, - extras::Nothing, + extras::Nothing=nothing, ) T = tag_type(f!, x) xdual = Dual{T}.(x, dx) @@ -40,7 +46,7 @@ function DI.value_and_multiderivative!( backend::AutoForwardDiff, f!, x::Number, - extras::Nothing, + extras::Nothing=nothing, ) config = DI.prepare_multiderivative(backend, f!, x, y) return DI.value_and_multiderivative!(y, multider, backend, f!, x, config) @@ -71,7 +77,7 @@ function DI.value_and_jacobian!( backend::AutoForwardDiff, f!, x::AbstractArray, - extras::Nothing, + extras::Nothing=nothing, ) config = DI.prepare_jacobian(backend, f!, x, y) return DI.value_and_jacobian!(y, jac, backend, f!, x, config) diff --git a/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index 6ef9f61c0..aea656826 100644 --- a/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -16,7 +16,7 @@ function DI.value_and_pushforward!( f, x, dx, - extras::Nothing, + extras::Nothing=nothing, ) where {C} return DI.value_and_pushforward!( dy, AutoForwardDiff{C,Nothing}(nothing), f, x, dx, extras @@ -30,7 +30,7 @@ function DI.value_and_pushforward!( f!, x, dx, - extras::Nothing, + extras::Nothing=nothing, ) where {C} return DI.value_and_pushforward!( y, dy, AutoForwardDiff{C,Nothing}(nothing), f!, x, dx, extras @@ -40,7 +40,11 @@ end ## Utilities function DI.value_and_gradient!( - grad::AbstractArray, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray, extras::Nothing + grad::AbstractArray, + ::AutoPolyesterForwardDiff{C}, + f, + x::AbstractArray, + extras::Nothing=nothing, ) where {C} y = f(x) threaded_gradient!(f, grad, x, Chunk{C}()) @@ -48,14 +52,22 @@ function DI.value_and_gradient!( end function DI.gradient!( - grad::AbstractArray, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray, extras::Nothing + grad::AbstractArray, + ::AutoPolyesterForwardDiff{C}, + f, + x::AbstractArray, + extras::Nothing=nothing, ) where {C} threaded_gradient!(f, grad, x, Chunk{C}()) return grad end function DI.value_and_jacobian!( - jac::AbstractMatrix, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray, extras::Nothing + jac::AbstractMatrix, + ::AutoPolyesterForwardDiff{C}, + f, + x::AbstractArray, + extras::Nothing=nothing, ) where {C} y = f(x) threaded_jacobian!(f, jac, x, Chunk{C}()) @@ -63,7 +75,11 @@ function DI.value_and_jacobian!( end function DI.jacobian!( - jac::AbstractMatrix, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray, extras::Nothing + jac::AbstractMatrix, + ::AutoPolyesterForwardDiff{C}, + f, + x::AbstractArray, + extras::Nothing=nothing, ) where {C} threaded_jacobian!(f, jac, x, Chunk{C}()) return jac @@ -75,7 +91,7 @@ function DI.value_and_jacobian!( ::AutoPolyesterForwardDiff{C}, f!, x::AbstractArray, - extras::Nothing, + extras::Nothing=nothing, ) where {C} f!(y, x) threaded_jacobian!(f!, y, jac, x, Chunk{C}()) diff --git a/ext/DifferentiationInterfaceReverseDiffExt/allocating.jl b/ext/DifferentiationInterfaceReverseDiffExt/allocating.jl index f99b3edc2..4132a0a43 100644 --- a/ext/DifferentiationInterfaceReverseDiffExt/allocating.jl +++ b/ext/DifferentiationInterfaceReverseDiffExt/allocating.jl @@ -1,7 +1,12 @@ ## Pullback function DI.value_and_pullback!( - dx::AbstractArray, ::AutoReverseDiff, f, x::AbstractArray, dy::Real, extras::Nothing + dx::AbstractArray, + ::AutoReverseDiff, + f, + x::AbstractArray, + dy::Real, + extras::Nothing=nothing, ) res = DiffResults.DiffResult(zero(dy), dx) res = gradient!(res, f, x) @@ -16,7 +21,7 @@ function DI.value_and_pullback!( f, x::AbstractArray, dy::AbstractArray, - extras::Nothing, + extras::Nothing=nothing, ) res = DiffResults.DiffResult(similar(dy), similar(dy, length(dy), length(x))) res = jacobian!(res, f, x) @@ -27,7 +32,7 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - _dx::Number, backend::AutoReverseDiff, f, x::Number, dy, extras::Nothing + _dx::Number, backend::AutoReverseDiff, f, x::Number, dy, extras::Nothing=nothing ) x_array = [x] dx_array = similar(x_array) @@ -40,24 +45,32 @@ end ### Unprepared function DI.value_and_gradient!( - grad::AbstractArray, backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing + grad::AbstractArray, + backend::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, ) return DI.value_and_gradient!(grad, backend, f, x, DI.prepare_gradient(backend, f, x)) end function DI.value_and_gradient( - backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing + backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing=nothing ) return DI.value_and_gradient(backend, f, x, DI.prepare_gradient(backend, f, x)) end function DI.gradient!( - grad::AbstractArray, backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing + grad::AbstractArray, + backend::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, ) return DI.gradient!(grad, backend, f, x, DI.prepare_gradient(backend, f, x)) end -function DI.gradient(backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing) +function DI.gradient(backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing=nothing) return DI.gradient(backend, f, x, DI.prepare_gradient(backend, f, x)) end @@ -107,31 +120,39 @@ end ### Unprepared function DI.value_and_jacobian!( - jac::AbstractArray, backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing + jac::AbstractMatrix, + backend::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, ) return DI.value_and_jacobian!(jac, backend, f, x, DI.prepare_jacobian(backend, f, x)) end function DI.value_and_jacobian( - backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing + backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing=nothing ) return DI.value_and_jacobian(backend, f, x, DI.prepare_jacobian(backend, f, x)) end function DI.jacobian!( - jac::AbstractArray, backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing + jac::AbstractMatrix, + backend::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, ) return DI.jacobian!(jac, backend, f, x, DI.prepare_jacobian(backend, f, x)) end -function DI.jacobian(backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing) +function DI.jacobian(backend::AutoReverseDiff, f, x::AbstractArray, extras::Nothing=nothing) return DI.jacobian(backend, f, x, DI.prepare_jacobian(backend, f, x)) end ### Prepared function DI.value_and_jacobian!( - jac::AbstractArray, + jac::AbstractMatrix, backend::AutoReverseDiff, f, x::AbstractArray, @@ -155,7 +176,7 @@ function DI.value_and_jacobian( end function DI.jacobian!( - jac::AbstractArray, + jac::AbstractMatrix, backend::AutoReverseDiff, f, x::AbstractArray, diff --git a/ext/DifferentiationInterfaceReverseDiffExt/mutating.jl b/ext/DifferentiationInterfaceReverseDiffExt/mutating.jl index a50f6d5db..8bdd59fb3 100644 --- a/ext/DifferentiationInterfaceReverseDiffExt/mutating.jl +++ b/ext/DifferentiationInterfaceReverseDiffExt/mutating.jl @@ -7,7 +7,7 @@ function DI.value_and_pullback!( f!, x::AbstractArray, dy::AbstractArray, - extras::Nothing, + extras::Nothing=nothing, ) res = DiffResults.DiffResult(y, similar(dy, length(y), length(x))) res = jacobian!(res, f!, y, x) @@ -23,7 +23,7 @@ function DI.value_and_pullback!( f!, x::Number, dy, - extras::Nothing, + extras::Nothing=nothing, ) x_array = [x] dx_array = similar(x_array) @@ -36,11 +36,11 @@ end function DI.value_and_jacobian!( y::AbstractArray, - jac::AbstractArray, + jac::AbstractMatrix, backend::AutoReverseDiff, f!, x::AbstractArray, - extras::Nothing, + extras::Nothing=nothing, ) return DI.value_and_jacobian!( y, jac, backend, f!, x, DI.prepare_jacobian(backend, f!, x, y) @@ -49,7 +49,7 @@ end function DI.value_and_jacobian!( y::AbstractArray, - jac::AbstractArray, + jac::AbstractMatrix, backend::AutoReverseDiff, f!, x::AbstractArray, diff --git a/ext/DifferentiationInterfaceTestExt/DifferentiationInterfaceTestExt.jl b/ext/DifferentiationInterfaceTestExt/DifferentiationInterfaceTestExt.jl index c8e8958a4..b3823bf08 100644 --- a/ext/DifferentiationInterfaceTestExt/DifferentiationInterfaceTestExt.jl +++ b/ext/DifferentiationInterfaceTestExt/DifferentiationInterfaceTestExt.jl @@ -4,11 +4,11 @@ module DifferentiationInterfaceTestExt using ADTypes: AbstractADType using DifferentiationInterface using DifferentiationInterface.DifferentiationTest -using DifferentiationInterface: ForwardMode, ReverseMode, autodiff_mode +using DifferentiationInterface: ForwardMode, ReverseMode, mode import DifferentiationInterface as DI import DifferentiationInterface.DifferentiationTest as DT using DocStringExtensions -using LinearAlgebra: dot +using LinearAlgebra: LinearAlgebra, dot # new dependencies using ForwardDiff: ForwardDiff diff --git a/ext/DifferentiationInterfaceTestExt/correctness.jl b/ext/DifferentiationInterfaceTestExt/correctness.jl index afd337e68..f6f836b88 100644 --- a/ext/DifferentiationInterfaceTestExt/correctness.jl +++ b/ext/DifferentiationInterfaceTestExt/correctness.jl @@ -286,65 +286,6 @@ function test_correctness_gradient_allocating( end end -## Hessian - -function test_correctness_hessian_allocating( - ba::AbstractADType, scenario::Scenario, maybe_extras... -) - (; f, x, y, dx) = deepcopy(scenario) - grad_true = ForwardDiff.gradient(f, x) - hess_true = ForwardDiff.hessian(f, x) - hvp_true = reshape((hess_true * vec(dx)), size(x)) - - y_out1, grad_out1, hess_out1 = value_and_gradient_and_hessian(ba, f, x, maybe_extras...) - grad_in2, hess_in2 = zero(grad_out1), zero(hess_out1) - y_out2, grad_out2, hess_out2 = value_and_gradient_and_hessian!( - grad_in2, hess_in2, ba, f, x, maybe_extras... - ) - - hess_out3 = hessian(ba, f, x, maybe_extras...) - hess_in4 = zero(hess_out3) - hess_out4 = hessian!(hess_in4, ba, f, x, maybe_extras...) - - grad_out5, hvp_out5 = gradient_and_hessian_vector_product(ba, f, x, dx, maybe_extras...) - grad_in6, hvp_in6 = zero(grad_out5), zero(hvp_out5) - grad_out6, hvp_out6 = gradient_and_hessian_vector_product!( - grad_in6, hvp_in6, ba, f, x, dx, maybe_extras... - ) - - @testset "Primal value" begin - @test y_out1 ≈ y - @test y_out2 ≈ y - end - @testset "Gradient value" begin - @test grad_out1 ≈ grad_true rtol = 1e-3 - @test grad_out2 ≈ grad_true rtol = 1e-3 - @test grad_out5 ≈ grad_true rtol = 1e-3 - @test grad_out6 ≈ grad_true rtol = 1e-3 - @testset "Mutation" begin - @test grad_in2 ≈ grad_true rtol = 1e-3 - @test grad_in6 ≈ grad_true rtol = 1e-3 - end - end - @testset "Hessian value" begin - @test hess_out1 ≈ hess_true rtol = 1e-3 - @test hess_out2 ≈ hess_true rtol = 1e-3 - @test hess_out3 ≈ hess_true rtol = 1e-3 - @test hess_out4 ≈ hess_true rtol = 1e-3 - @testset "Mutation" begin - @test hess_in2 ≈ hess_true rtol = 1e-3 - @test hess_in4 ≈ hess_true rtol = 1e-3 - end - end - @testset "Hessian-vector product value" begin - @test hvp_out5 ≈ hvp_true rtol = 1e-3 - @test hvp_out6 ≈ hvp_true rtol = 1e-3 - @testset "Mutation" begin - @test hvp_in6 ≈ hvp_true rtol = 1e-3 - end - end -end - ## Jacobian function test_correctness_jacobian_allocating( @@ -401,3 +342,91 @@ function test_correctness_jacobian_mutating( end end end + +## Second derivative + +function test_correctness_second_derivative_allocating( + ba::AbstractADType, scenario::Scenario, maybe_extras... +) + (; f, x, y) = deepcopy(scenario) + der_true = ForwardDiff.derivative(f, x) + derder_true = ForwardDiff.derivative(x) do z + ForwardDiff.derivative(f, z) + end + + y_out1, der_out1, derder_out1 = value_derivative_and_second_derivative( + ba, f, x, maybe_extras... + ) + + derder_out2 = second_derivative(ba, f, x, maybe_extras...) + + @testset "Primal value" begin + @test y_out1 ≈ y + end + @testset "Derivative value" begin + @test der_out1 ≈ der_true rtol = 1e-3 + end + @testset "Second derivative value" begin + @test derder_out1 ≈ derder_true rtol = 1e-3 + @test derder_out2 ≈ derder_true rtol = 1e-3 + end +end + +## Hessian + +function test_correctness_hessian_allocating( + ba::AbstractADType, scenario::Scenario, maybe_extras... +) + (; f, x, y, dx) = deepcopy(scenario) + grad_true = ForwardDiff.gradient(f, x) + hess_true = ForwardDiff.hessian(f, x) + hvp_true = reshape((hess_true * vec(dx)), size(x)) + + y_out1, grad_out1, hess_out1 = value_gradient_and_hessian(ba, f, x, maybe_extras...) + grad_in2, hess_in2 = zero(grad_out1), zero(hess_out1) + y_out2, grad_out2, hess_out2 = value_gradient_and_hessian!( + grad_in2, hess_in2, ba, f, x, maybe_extras... + ) + + hess_out3 = hessian(ba, f, x, maybe_extras...) + hess_in4 = zero(hess_out3) + hess_out4 = hessian!(hess_in4, ba, f, x, maybe_extras...) + + grad_out5, hvp_out5 = gradient_and_hessian_vector_product(ba, f, x, dx, maybe_extras...) + grad_in6, hvp_in6 = zero(grad_out5), zero(hvp_out5) + grad_out6, hvp_out6 = gradient_and_hessian_vector_product!( + grad_in6, hvp_in6, ba, f, x, dx, maybe_extras... + ) + + @testset "Primal value" begin + @test y_out1 ≈ y + @test y_out2 ≈ y + end + @testset "Gradient value" begin + @test grad_out1 ≈ grad_true rtol = 1e-3 + @test grad_out2 ≈ grad_true rtol = 1e-3 + @test grad_out5 ≈ grad_true rtol = 1e-3 + @test grad_out6 ≈ grad_true rtol = 1e-3 + @testset "Mutation" begin + @test grad_in2 ≈ grad_true rtol = 1e-3 + @test grad_in6 ≈ grad_true rtol = 1e-3 + end + end + @testset "Hessian value" begin + @test hess_out1 ≈ hess_true rtol = 1e-3 + @test hess_out2 ≈ hess_true rtol = 1e-3 + @test hess_out3 ≈ hess_true rtol = 1e-3 + @test hess_out4 ≈ hess_true rtol = 1e-3 + @testset "Mutation" begin + @test hess_in2 ≈ hess_true rtol = 1e-3 + @test hess_in4 ≈ hess_true rtol = 1e-3 + end + end + @testset "Hessian-vector product value" begin + @test hvp_out5 ≈ hvp_true rtol = 1e-3 + @test hvp_out6 ≈ hvp_true rtol = 1e-3 + @testset "Mutation" begin + @test hvp_in6 ≈ hvp_true rtol = 1e-3 + end + end +end diff --git a/ext/DifferentiationInterfaceTestExt/scenarios.jl b/ext/DifferentiationInterfaceTestExt/scenarios.jl index 3e7b4fc0e..3c308a259 100644 --- a/ext/DifferentiationInterfaceTestExt/scenarios.jl +++ b/ext/DifferentiationInterfaceTestExt/scenarios.jl @@ -20,26 +20,24 @@ function gradient_scenarios(scenarios::Vector{<:Scenario}) end end -function hessian_scenarios(scenarios::Vector{<:Scenario}) - filter(scenarios) do scen - in_type(scen) <: AbstractArray && out_type(scen) <: Number - end -end - function jacobian_scenarios(scenarios::Vector{<:Scenario}) filter(scenarios) do scen in_type(scen) <: AbstractArray && out_type(scen) <: AbstractArray end end +second_derivative_scenarios(scenarios) = derivative_scenarios(scenarios) +hessian_scenarios(scenarios) = gradient_scenarios(scenarios) + for prep in [ :prepare_pushforward, :prepare_pullback, :prepare_derivative, :prepare_multiderivative, :prepare_gradient, - :prepare_hessian, :prepare_jacobian, + :prepare_second_derivative, + :prepare_hessian, ] @eval function DI.$prep(ba::AbstractADType, scen::Scenario) if is_mutating(scen) diff --git a/ext/DifferentiationInterfaceTestExt/test_allocating.jl b/ext/DifferentiationInterfaceTestExt/test_allocating.jl index 7b4777d73..c3fdb096f 100644 --- a/ext/DifferentiationInterfaceTestExt/test_allocating.jl +++ b/ext/DifferentiationInterfaceTestExt/test_allocating.jl @@ -94,7 +94,7 @@ function test_gradient_allocating( end end -function test_hessian_allocating( +function test_jacobian_allocating( ba::AbstractADType, scenarios::Vector{<:Scenario}; correctness::Bool=true, @@ -102,18 +102,18 @@ function test_hessian_allocating( allocs::Bool=false, ) @testset "$(in_type(scen)) -> $(out_type(scen))" for scen in scenarios - @testset "Extras: $(isempty(mex))" for mex in ((), (prepare_hessian(ba, scen),)) + @testset "Extras: $(isempty(mex))" for mex in ((), (prepare_jacobian(ba, scen),)) if correctness - test_correctness_hessian_allocating(ba, scen, mex...) + test_correctness_jacobian_allocating(ba, scen, mex...) end if type_stability - test_type_hessian_allocating(ba, scen, mex...) + test_type_jacobian_allocating(ba, scen, mex...) end end end end -function test_jacobian_allocating( +function test_second_derivative_allocating( ba::AbstractADType, scenarios::Vector{<:Scenario}; correctness::Bool=true, @@ -121,12 +121,32 @@ function test_jacobian_allocating( allocs::Bool=false, ) @testset "$(in_type(scen)) -> $(out_type(scen))" for scen in scenarios - @testset "Extras: $(isempty(mex))" for mex in ((), (prepare_jacobian(ba, scen),)) + @testset "Extras: $(isempty(mex))" for mex in + ((), (prepare_second_derivative(ba, scen),)) if correctness - test_correctness_jacobian_allocating(ba, scen, mex...) + test_correctness_second_derivative_allocating(ba, scen, mex...) end if type_stability - test_type_jacobian_allocating(ba, scen, mex...) + test_type_second_derivative_allocating(ba, scen, mex...) + end + end + end +end + +function test_hessian_allocating( + ba::AbstractADType, + scenarios::Vector{<:Scenario}; + correctness::Bool=true, + type_stability::Bool=true, + allocs::Bool=false, +) + @testset "$(in_type(scen)) -> $(out_type(scen))" for scen in scenarios + @testset "Extras: $(isempty(mex))" for mex in ((), (prepare_hessian(ba, scen),)) + if correctness + test_correctness_hessian_allocating(ba, scen, mex...) + end + if type_stability + test_type_hessian_allocating(ba, scen, mex...) end end end @@ -153,12 +173,12 @@ function DT.test_operators_allocating( !is_mutating(scen) && in_type(scen) <: input_type && out_type(scen) <: output_type end - if autodiff_mode(ba) isa ForwardMode && :pushforward in kept + if mode(ba) isa ForwardMode && :pushforward in kept @testset "Pushforward" test_pushforward_allocating( ba, scenarios; correctness, type_stability, allocs ) end - if autodiff_mode(ba) isa ReverseMode && :pullback in kept + if mode(ba) isa ReverseMode && :pullback in kept @testset "Pullback" test_pullback_allocating( ba, scenarios; correctness, type_stability, allocs ) @@ -178,11 +198,6 @@ function DT.test_operators_allocating( ba, gradient_scenarios(scenarios); correctness, type_stability, allocs ) end - if :hessian in kept - @testset "Hessian" test_hessian_allocating( - ba, hessian_scenarios(scenarios); correctness, type_stability, allocs - ) - end if :jacobian in kept @testset "Jacobian" test_jacobian_allocating( ba, jacobian_scenarios(scenarios); correctness, type_stability, allocs @@ -190,3 +205,35 @@ function DT.test_operators_allocating( end return nothing end + +""" +$(TYPEDSIGNATURES) +""" +function DT.test_second_order_operators_allocating( + ba::AbstractADType, + scenarios::Vector{<:Scenario}=default_scenarios(); + input_type::Type=Any, + output_type::Type=Any, + correctness::Bool=true, + type_stability::Bool=true, + allocs::Bool=false, + included::Vector{Symbol}=[:second_derivative, :hessian], + excluded::Vector{Symbol}=Symbol[], +) + kept = symdiff(included, excluded) + scenarios = filter(scenarios) do scen + !is_mutating(scen) && in_type(scen) <: input_type && out_type(scen) <: output_type + end + + if :second_derivative in kept + @testset "Second derivative" test_second_derivative_allocating( + ba, second_derivative_scenarios(scenarios); correctness, type_stability, allocs + ) + end + if :hessian in kept + @testset "Hessian" test_hessian_allocating( + ba, hessian_scenarios(scenarios); correctness, type_stability, allocs + ) + end + return nothing +end diff --git a/ext/DifferentiationInterfaceTestExt/test_mutating.jl b/ext/DifferentiationInterfaceTestExt/test_mutating.jl index bc13e9571..d20f2b4de 100644 --- a/ext/DifferentiationInterfaceTestExt/test_mutating.jl +++ b/ext/DifferentiationInterfaceTestExt/test_mutating.jl @@ -96,12 +96,12 @@ function DT.test_operators_mutating( is_mutating(scen) && in_type(scen) <: input_type && out_type(scen) <: output_type end - if autodiff_mode(ba) isa ForwardMode && :pushforward in kept + if mode(ba) isa ForwardMode && :pushforward in kept @testset "Pushforward (mutating)" test_pushforward_mutating( ba, scenarios; correctness, type_stability, allocs ) end - if autodiff_mode(ba) isa ReverseMode && :pullback in kept + if mode(ba) isa ReverseMode && :pullback in kept @testset "Pullback (mutating)" test_pullback_mutating( ba, scenarios; correctness, type_stability, allocs ) diff --git a/ext/DifferentiationInterfaceTestExt/type_stability.jl b/ext/DifferentiationInterfaceTestExt/type_stability.jl index 19367f381..a1147d8a9 100644 --- a/ext/DifferentiationInterfaceTestExt/type_stability.jl +++ b/ext/DifferentiationInterfaceTestExt/type_stability.jl @@ -92,25 +92,6 @@ function test_type_gradient_allocating( @test_opt gradient(ba, f, x, maybe_extras...) end -## Hessian - -function test_type_hessian_allocating( - ba::AbstractADType, scenario::Scenario, maybe_extras... -) - (; f, x, dx) = deepcopy(scenario) - grad_in = zero(dx) - hvp_in = zero(dx) - hess_in = zeros(eltype(x), length(x), length(x)) - @test_opt value_and_gradient_and_hessian!(grad_in, hess_in, ba, f, x, maybe_extras...) - @test_opt hessian!(hess_in, ba, f, x, maybe_extras...) - @test_opt gradient_and_hessian_vector_product!( - grad_in, hvp_in, ba, f, x, dx, maybe_extras... - ) - @test_opt value_and_gradient_and_hessian(ba, f, x, maybe_extras...) - @test_opt hessian(ba, f, x, maybe_extras...) - @test_opt gradient_and_hessian_vector_product(ba, f, x, dx, maybe_extras...) -end - ## Jacobian function test_type_jacobian_allocating( @@ -133,3 +114,40 @@ function test_type_jacobian_mutating( jac_in = zeros(eltype(y), length(y), length(x)) @test_opt value_and_jacobian!(y_in, jac_in, ba, f!, x, maybe_extras...) end + +## Second derivative + +function test_type_second_derivative_allocating( + ba::AbstractADType, scenario::Scenario, maybe_extras... +) + (; f, x) = deepcopy(scenario) + @test_opt value_derivative_and_second_derivative(ba, f, x, maybe_extras...) + @test_opt second_derivative(ba, f, x, maybe_extras...) +end + +## Hessian + +function test_type_hessian_allocating( + ba::AbstractADType, scenario::Scenario, maybe_extras... +) + (; f, x, dx) = deepcopy(scenario) + grad_in = zero(dx) + hvp_in = zero(dx) + hess_in = zeros(eltype(x), length(x), length(x)) + @test_opt ignored_modules = (LinearAlgebra,) value_gradient_and_hessian!( + grad_in, hess_in, ba, f, x, maybe_extras... + ) + @test_opt ignored_modules = (LinearAlgebra,) hessian!( + hess_in, ba, f, x, maybe_extras... + ) + @test_opt ignored_modules = (LinearAlgebra,) gradient_and_hessian_vector_product!( + grad_in, hvp_in, ba, f, x, dx, maybe_extras... + ) + @test_opt ignored_modules = (LinearAlgebra,) value_gradient_and_hessian( + ba, f, x, maybe_extras... + ) + @test_opt ignored_modules = (LinearAlgebra,) hessian(ba, f, x, maybe_extras...) + @test_opt ignored_modules = (LinearAlgebra,) gradient_and_hessian_vector_product( + ba, f, x, dx, maybe_extras... + ) +end diff --git a/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 4d35217c9..9fe3c8234 100644 --- a/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -9,14 +9,14 @@ using Zygote: ZygoteRuleConfig, gradient, jacobian, pullback, withgradient, with ## Primitives function DI.value_and_pullback!( - dx::Union{Number,AbstractArray}, ::AutoZygote, f, x, dy, extras::Nothing + dx::Union{Number,AbstractArray}, ::AutoZygote, f, x, dy, extras::Nothing=nothing ) y, back = pullback(f, x) new_dx = only(back(dy)) return y, update!(dx, new_dx) end -function DI.value_and_pullback(::AutoZygote, f, x, dy, extras::Nothing) +function DI.value_and_pullback(::AutoZygote, f, x, dy, extras::Nothing=nothing) y, back = pullback(f, x) dx = only(back(dy)) return y, dx @@ -24,27 +24,27 @@ end ## Utilities -function DI.value_and_gradient(::AutoZygote, f, x::AbstractArray, extras::Nothing) +function DI.value_and_gradient(::AutoZygote, f, x::AbstractArray, extras::Nothing=nothing) res = withgradient(f, x) return res.val, only(res.grad) end function DI.value_and_gradient!( - grad::AbstractArray, backend::AutoZygote, f, x::AbstractArray, extras + grad::AbstractArray, backend::AutoZygote, f, x::AbstractArray, extras=nothing ) y, new_grad = DI.value_and_gradient(backend, f, x, extras) grad .= new_grad return y, grad end -function DI.value_and_jacobian(::AutoZygote, f, x::AbstractArray, extras::Nothing) +function DI.value_and_jacobian(::AutoZygote, f, x::AbstractArray, extras::Nothing=nothing) y = f(x) jac = jacobian(f, x) return y, only(jac) end function DI.value_and_jacobian!( - jac::AbstractMatrix, backend::AutoZygote, f, x::AbstractArray, extras::Nothing + jac::AbstractMatrix, backend::AutoZygote, f, x::AbstractArray, extras::Nothing=nothing ) y, new_jac = DI.value_and_jacobian(backend, f, x, extras) jac .= new_jac diff --git a/src/DifferentiationInterface.jl b/src/DifferentiationInterface.jl index 0b9d1507d..ae6d60941 100644 --- a/src/DifferentiationInterface.jl +++ b/src/DifferentiationInterface.jl @@ -13,20 +13,26 @@ using ADTypes: AbstractADType, AbstractForwardMode, AbstractReverseMode, AbstractFiniteDifferencesMode using DocStringExtensions using FillArrays: OneElement +using LinearAlgebra: dot -include("backends.jl") include("mode.jl") include("utils.jl") + include("pushforward.jl") include("pullback.jl") +include("zero.jl") + include("derivative.jl") include("multiderivative.jl") include("gradient.jl") include("jacobian.jl") + +include("second_order.jl") +include("second_derivative.jl") +include("hessian_vector_product.jl") include("hessian.jl") -include("zero.jl") + include("prepare.jl") -include("additional_args.jl") # submodules include("DifferentiationTest.jl") @@ -35,47 +41,59 @@ export SecondOrder export value_and_pushforward!, value_and_pushforward export value_and_pullback!, value_and_pullback + export value_and_derivative export value_and_multiderivative!, value_and_multiderivative export value_and_gradient!, value_and_gradient export value_and_jacobian!, value_and_jacobian + export gradient_and_hessian_vector_product!, gradient_and_hessian_vector_product -export value_and_gradient_and_hessian!, value_and_gradient_and_hessian +export hessian_vector_product!, hessian_vector_product + +export value_derivative_and_second_derivative +export value_gradient_and_hessian!, value_gradient_and_hessian export pushforward!, pushforward export pullback!, pullback + export derivative export multiderivative!, multiderivative export gradient!, gradient export jacobian!, jacobian + +export second_derivative export hessian!, hessian export prepare_pushforward export prepare_pullback + export prepare_derivative export prepare_multiderivative export prepare_gradient export prepare_jacobian + +export prepare_second_derivative export prepare_hessian function __init__() Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs f_name = string(exc.f) if ( - f_name == "autodiff_mode" || + f_name == "mode" || contains(f_name, "pushforward") || contains(f_name, "pullback") || contains(f_name, "derivative") || contains(f_name, "gradient") || - contains(f_name, "jacobian") + contains(f_name, "jacobian") || + contains(f_name, "hessian") ) for T in argtypes if T <: AbstractADType print( io, """\n -HINT: To use `DifferentiationInterface` with backend `$T`, you need to load the corresponding package extension. - """, + HINT: To use `DifferentiationInterface` with backend `$T`, you need to load the corresponding package extension. + """, ) return nothing end diff --git a/src/DifferentiationTest.jl b/src/DifferentiationTest.jl index 8d222628e..a4feee5f0 100644 --- a/src/DifferentiationTest.jl +++ b/src/DifferentiationTest.jl @@ -62,10 +62,12 @@ end function default_scenarios end function test_operators_allocating end +function test_second_order_operators_allocating end function test_operators_mutating end export Scenario, default_scenarios -export test_operators_allocating, test_operators_mutating +export test_operators_allocating, + test_second_order_operators_allocating, test_operators_mutating # see https://docs.julialang.org/en/v1/base/base/#Base.Experimental.register_error_hint @@ -76,10 +78,10 @@ function __init__() print( io, """\n -HINT: To use the `DifferentiationInterface.DifferentiationTest` submodule, you need to load the `DifferentiationInterfaceTestExt` package extension. Run the following command in your REPL: + HINT: To use the `DifferentiationInterface.DifferentiationTest` submodule, you need to load the `DifferentiationInterfaceTestExt` package extension. Run the following command in your REPL: - import ForwardDiff, JET, Test -""", + import ForwardDiff, JET, Test + """, ) end end diff --git a/src/additional_args.jl b/src/additional_args.jl deleted file mode 100644 index 7745b7b14..000000000 --- a/src/additional_args.jl +++ /dev/null @@ -1,107 +0,0 @@ -for operator in [:value_and_pushforward, :value_and_pullback, :pushforward, :pullback] - @eval function $operator(backend::AbstractADType, f, x, seed) - return $operator(backend, f, x, seed, nothing) - end -end - -for operator! in [:value_and_pushforward!, :value_and_pullback!, :pushforward!, :pullback!] - @eval function $operator!( - storage::Union{Number,AbstractArray}, backend::AbstractADType, f, x, seed - ) - return $operator!(storage, backend, f, x, seed, nothing) - end -end - -for operator! in [:value_and_pushforward!, :value_and_pullback!] - @eval function $operator!( - y::AbstractArray, - storage::Union{Number,AbstractArray}, - backend::AbstractADType, - f!, - x, - seed, - ) - return $operator!(y, storage, backend, f!, x, seed, nothing) - end -end - -for operator in [ - :value_and_derivative, - :value_and_multiderivative, - :value_and_gradient, - :value_and_jacobian, - :derivative, - :multiderivative, - :gradient, - :jacobian, -] - @eval function $operator(backend::AbstractADType, f, x) - return $operator(backend, f, x, nothing) - end - @eval function $operator(backend::AbstractADType, f, x, extras) - return $operator(backend, f, x, extras, autodiff_mode(backend)) - end -end - -for operator! in [ - :value_and_multiderivative!, - :value_and_gradient!, - :value_and_jacobian!, - :multiderivative!, - :gradient!, - :jacobian!, -] - @eval function $operator!(storage::AbstractArray, backend::AbstractADType, f, x) - return $operator!(storage, backend, f, x, nothing) - end - @eval function $operator!(storage::AbstractArray, backend::AbstractADType, f, x, extras) - return $operator!(storage, backend, f, x, extras, autodiff_mode(backend)) - end -end - -for operator! in [:value_and_multiderivative!, :value_and_jacobian!] - @eval function $operator!( - y::AbstractArray, storage::AbstractArray, backend::AbstractADType, f!, x - ) - return $operator!(y, storage, backend, f!, x, nothing) - end - @eval function $operator!( - y::AbstractArray, storage::AbstractArray, backend::AbstractADType, f!, x, extras - ) - return $operator!(y, storage, backend, f!, x, extras, autodiff_mode(backend)) - end -end - -for operator in [:gradient_and_hessian_vector_product] - @eval function $operator(backend::AbstractADType, f, x, v) - return $operator(backend::AbstractADType, f, x, v, nothing) - end -end - -for operator! in [:gradient_and_hessian_vector_product!] - @eval function $operator!( - grad::AbstractArray, hvp::AbstractArray, backend::AbstractADType, f, x, v - ) - return $operator!(grad, hvp, backend, f, x, v, nothing) - end -end - -for operator in [:value_and_gradient_and_hessian, :hessian] - @eval function $operator(backend::AbstractADType, f, x) - return $operator(backend, f, x, nothing) - end -end - -for operator! in [:value_and_gradient_and_hessian!] - @eval function $operator!( - grad::AbstractArray, hess::AbstractMatrix, backend::AbstractADType, f, x - ) - return $operator!(grad, hess, backend, f, x, nothing) - end -end - -for operator! in [:hessian!] - @eval function $operator!(hess::AbstractMatrix, backend::AbstractADType, f, x) - return $operator!(hess, backend, f, x, nothing) - end -end diff --git a/src/backends.jl b/src/backends.jl deleted file mode 100644 index d5d5984ed..000000000 --- a/src/backends.jl +++ /dev/null @@ -1,14 +0,0 @@ -## Traits and access - -""" - autodiff_mode(backend) - -Return `ForwardMode()` or `ReverseMode()` in a statically predictable way. - -This function must be overloaded for backends that do not inherit from `ADTypes.AbstractForwardMode` or `ADTypes.AbstractReverseMode` (e.g. because they support both forward and reverse). - -We classify `ADTypes.AbstractFiniteDifferencesMode` as forward mode. -""" -autodiff_mode(::AbstractForwardMode) = ForwardMode() -autodiff_mode(::AbstractFiniteDifferencesMode) = ForwardMode() -autodiff_mode(::AbstractReverseMode) = ReverseMode() diff --git a/src/derivative.jl b/src/derivative.jl index 1adb7edfb..8ca284bea 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -3,13 +3,19 @@ Compute the primal value `y = f(x)` and the derivative `der = f'(x)` of a scalar-to-scalar function. """ -function value_and_derivative end +function value_and_derivative(backend::AbstractADType, f, x::Number, extras=nothing) + return value_and_derivative_aux(backend, f, x, extras, mode(backend)) +end -function value_and_derivative(backend::AbstractADType, f, x::Number, extras, ::ForwardMode) +function value_and_derivative_aux( + backend::AbstractADType, f, x::Number, extras, ::ForwardMode +) return value_and_pushforward(backend, f, x, one(x), extras) end -function value_and_derivative(backend::AbstractADType, f, x::Number, extras, ::ReverseMode) +function value_and_derivative_aux( + backend::AbstractADType, f, x::Number, extras, ::ReverseMode +) return value_and_pullback(backend, f, x, one(x), extras) end @@ -18,12 +24,14 @@ end Compute the derivative `der = f'(x)` of a scalar-to-scalar function. """ -function derivative end +function derivative(backend::AbstractADType, f, x::Number, extras=nothing) + return derivative_aux(backend, f, x, extras, mode(backend)) +end -function derivative(backend::AbstractADType, f, x::Number, extras, ::ForwardMode) +function derivative_aux(backend::AbstractADType, f, x::Number, extras, ::ForwardMode) return pushforward(backend, f, x, one(x), extras) end -function derivative(backend::AbstractADType, f, x::Number, extras, ::ReverseMode) +function derivative_aux(backend::AbstractADType, f, x::Number, extras, ::ReverseMode) return pullback(backend, f, x, one(x), extras) end diff --git a/src/gradient.jl b/src/gradient.jl index cce0a1f23..801b02bfd 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -3,9 +3,13 @@ Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad`. """ -function value_and_gradient! end - function value_and_gradient!( + grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, extras=nothing +) + return value_and_gradient_aux!(grad, backend, f, x, extras, mode(backend)) +end + +function value_and_gradient_aux!( grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode ) y = f(x) @@ -16,7 +20,7 @@ function value_and_gradient!( return y, grad end -function value_and_gradient!( +function value_and_gradient_aux!( grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, extras, ::ReverseMode ) return value_and_pullback!(grad, backend, f, x, one(eltype(x)), extras) @@ -27,16 +31,18 @@ end Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function. """ -function value_and_gradient end +function value_and_gradient(backend::AbstractADType, f, x::AbstractArray, extras=nothing) + return value_and_gradient_aux(backend, f, x, extras, mode(backend)) +end -function value_and_gradient( +function value_and_gradient_aux( backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode ) grad = similar(x) return value_and_gradient!(grad, backend, f, x, extras) end -function value_and_gradient( +function value_and_gradient_aux( backend::AbstractADType, f, x::AbstractArray, extras, ::ReverseMode ) return value_and_pullback(backend, f, x, one(eltype(x)), extras) @@ -47,15 +53,19 @@ end Compute the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad`. """ -function gradient! end - function gradient!( + grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, extras=nothing +) + return gradient_aux!(grad, backend, f, x, extras, mode(backend)) +end + +function gradient_aux!( grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode ) return last(value_and_gradient!(grad, backend, f, x, extras)) end -function gradient!( +function gradient_aux!( grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, extras, ::ReverseMode ) return pullback!(grad, backend, f, x, one(eltype(x)), extras) @@ -66,12 +76,14 @@ end Compute the gradient `grad = ∇f(x)` of an array-to-scalar function. """ -function gradient end +function gradient(backend::AbstractADType, f, x::AbstractArray, extras=nothing) + return gradient_aux(backend, f, x, extras, mode(backend)) +end -function gradient(backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode) +function gradient_aux(backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode) return last(value_and_gradient(backend, f, x, extras)) end -function gradient(backend::AbstractADType, f, x::AbstractArray, extras, ::ReverseMode) +function gradient_aux(backend::AbstractADType, f, x::AbstractArray, extras, ::ReverseMode) return pullback(backend, f, x, one(eltype(x)), extras) end diff --git a/src/hessian.jl b/src/hessian.jl index 4c2ea5315..eadb455fd 100644 --- a/src/hessian.jl +++ b/src/hessian.jl @@ -1,86 +1,3 @@ -#= -Sources: -- https://d2jud02ci9yv69.cloudfront.net/2024-05-07-bench-hvp-81/blog/bench-hvp/ -- https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html -=# - -""" - SecondOrder - -Combination of two backends for second-order differentiation of array-to-scalar functions. - -# Fields - -$(TYPEDFIELDS) -""" -struct SecondOrder{AD1<:AbstractADType,AD2<:AbstractADType} <: AbstractADType - "backend for the inner differentiation, must be reverse mode" - first::AD1 - "backend for the outer differentiation, must be forward mode" - second::AD2 - - function SecondOrder(first::AbstractADType, second::AbstractADType) - if !(autodiff_mode(first) isa ReverseMode) - throw( - ArgumentError( - "Second order is only supported with forward-over-reverse, and $first is not reverse mode.", - ), - ) - elseif !(autodiff_mode(second) isa ForwardMode) - throw( - ArgumentError( - "Second order is only supported with forward-over-reverse, and $second is not forward mode.", - ), - ) - end - return new{typeof(first),typeof(second)}(first, second) - end -end - -function Base.show(io::IO, backend::SecondOrder) - return print(io, "SecondOrder($(backend.first), $(backend.second))") -end - -function autodiff_mode(backend::SecondOrder) - return (autodiff_mode(backend.first), autodiff_mode(backend.second)) -end - -""" - gradient_and_hessian_vector_product!(grad, hvp, backend, f, x, v, [extras]) -> (grad, hvp) - -Compute the gradient `grad = ∇f(x)` and the Hessian-vector product `hvp = ∇²f(x) * v` of an array-to-scalar function, overwriting `grad` and `hvp`. -""" -function gradient_and_hessian_vector_product! end - -function gradient_and_hessian_vector_product!( - grad::AbstractArray, - hvp::AbstractArray, - backend::SecondOrder, - f, - x::AbstractArray, - v::AbstractArray, - extras, -) - function grad_aux!(grad, x) - gradient!(grad, backend.first, f, x, extras) - return nothing - end - return value_and_pushforward!(grad, hvp, backend.second, grad_aux!, x, v, extras) -end - -""" - gradient_and_hessian_vector_product(f, x, v, [extras]) -> (grad, hvp) - -Compute the gradient `grad = ∇f(x)` and the Hessian-vector product `hvp = ∇²f(x) * v` of an array-to-scalar function. -""" -function gradient_and_hessian_vector_product end - -function gradient_and_hessian_vector_product( - backend::SecondOrder, f, x::AbstractArray, v::AbstractArray, extras -) - grad_aux(x) = gradient(backend.first, f, x, extras) - return value_and_pushforward(backend.second, grad_aux, x, v, extras) -end const HESS_NOTES = """ ## Notes @@ -98,21 +15,34 @@ function check_hess(hess::AbstractMatrix, x::AbstractArray) end """ - value_and_gradient_and_hessian!(grad, hess, backend, f, x, [extras]) -> (y, grad, hess) + value_gradient_and_hessian!(grad, hess, backend, f, x, [extras]) -> (y, grad, hess) Compute the primal value `y = f(x)`, the gradient `grad = ∇f(x)` and the Hessian `hess = ∇²f(x)` of an array-to-scalar function, overwriting `grad` and `hess`. $HESS_NOTES """ -function value_and_gradient_and_hessian! end +function value_gradient_and_hessian!( + grad::AbstractArray, + hess::AbstractMatrix, + backend::SecondOrder, + f, + x::AbstractArray, + extras=nothing, +) + return value_gradient_and_hessian_aux!( + grad, hess, backend, f, x, extras, mode(inner(backend)), mode(outer(backend)) + ) +end -function value_and_gradient_and_hessian!( +function value_gradient_and_hessian_aux!( grad::AbstractArray, hess::AbstractMatrix, - backend::AbstractADType, + backend::SecondOrder, f, x::AbstractArray, extras, + ::AbstractMode, + ::ForwardMode, ) y = f(x) check_hess(hess, x) @@ -124,21 +54,39 @@ function value_and_gradient_and_hessian!( return y, grad, hess end +function value_gradient_and_hessian_aux!( + grad::AbstractArray, + hess::AbstractMatrix, + backend::SecondOrder, + f, + x::AbstractArray, + extras, + ::AbstractMode, + ::ReverseMode, +) + y, _ = value_and_gradient!(grad, inner(backend), f, x, extras) + check_hess(hess, x) + for (k, j) in enumerate(eachindex(IndexCartesian(), x)) + dx_j = basisarray(backend, x, j) + hess_col_j = reshape(view(hess, :, k), size(x)) + hessian_vector_product!(hess_col_j, backend, f, x, dx_j, extras) + end + return y, grad, hess +end + """ - value_and_gradient_and_hessian(backend, f, x, [extras]) -> (y, grad, hess) + value_gradient_and_hessian(backend, f, x, [extras]) -> (y, grad, hess) Compute the primal value `y = f(x)`, the gradient `grad = ∇f(x)` and the Hessian `hess = ∇²f(x)` of an array-to-scalar function, overwriting `grad` and `hess`. $HESS_NOTES """ -function value_and_gradient_and_hessian end - -function value_and_gradient_and_hessian( - backend::AbstractADType, f, x::AbstractArray, extras +function value_gradient_and_hessian( + backend::AbstractADType, f, x::AbstractArray, extras=nothing ) grad = similar(x) hess = similar(x, length(x), length(x)) - return value_and_gradient_and_hessian!(grad, hess, backend, f, x, extras) + return value_gradient_and_hessian!(grad, hess, backend, f, x, extras) end """ @@ -148,13 +96,11 @@ Compute the Hessian `hess = ∇²f(x)` of an array-to-scalar function, overwriti $HESS_NOTES """ -function hessian! end - function hessian!( - hess::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, extras + hess::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, extras=nothing ) grad = similar(x) - return last(value_and_gradient_and_hessian!(grad, hess, backend, f, x, extras)) + return last(value_gradient_and_hessian!(grad, hess, backend, f, x, extras)) end """ @@ -164,8 +110,6 @@ Compute the Hessian `hess = ∇²f(x)` of an array-to-scalar function. $HESS_NOTES """ -function hessian end - -function hessian(backend::AbstractADType, f, x::AbstractArray, extras) - return last(value_and_gradient_and_hessian(backend, f, x, extras)) +function hessian(backend::AbstractADType, f, x::AbstractArray, extras=nothing) + return last(value_gradient_and_hessian(backend, f, x, extras)) end diff --git a/src/hessian_vector_product.jl b/src/hessian_vector_product.jl new file mode 100644 index 000000000..b98ac598a --- /dev/null +++ b/src/hessian_vector_product.jl @@ -0,0 +1,216 @@ +#= +Sources: +- https://d2jud02ci9yv69.cloudfront.net/2024-05-07-bench-hvp-81/blog/bench-hvp/ +- https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html + +Start by reading the allocating versions +=# + +## Forward-over-something gives gradient too + +""" + gradient_and_hessian_vector_product(backend, f, x, v, [extras]) -> (grad, hvp) + +Compute the gradient `grad = ∇f(x)` and the Hessian-vector product `hvp = ∇²f(x) * v` of an array-to-scalar function. + +!!! warning + Only works with a forward outer mode. +""" +function gradient_and_hessian_vector_product( + backend::SecondOrder, f, x::AbstractArray, v::AbstractArray, extras=nothing +) + return gradient_and_hessian_vector_product_aux( + backend, f, x, v, extras, mode(inner(backend)), mode(outer(backend)) + ) +end + +function gradient_and_hessian_vector_product_aux( + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::AbstractMode, + ::ForwardMode, +) + grad_aux(z) = gradient(inner(backend), f, z, extras) + return value_and_pushforward(outer(backend), grad_aux, x, v, extras) +end + +function gradient_and_hessian_vector_product_aux( + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::AbstractMode, + ::ReverseMode, +) + throw(ArgumentError("HVP must be computed without gradient for reverse-over-something")) +end + +""" + gradient_and_hessian_vector_product!(grad, backend, hvp, backend, f, x, v, [extras]) -> (grad, hvp) + +Compute the gradient `grad = ∇f(x)` and the Hessian-vector product `hvp = ∇²f(x) * v` of an array-to-scalar function, overwriting `grad` and `hvp`. + +!!! warning + Only works with a forward outer mode. +""" +function gradient_and_hessian_vector_product!( + grad::AbstractArray, + hvp::AbstractArray, + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras=nothing, +) + return gradient_and_hessian_vector_product_aux!( + grad, hvp, backend, f, x, v, extras, mode(inner(backend)), mode(outer(backend)) + ) +end + +function gradient_and_hessian_vector_product_aux!( + grad::AbstractArray, + hvp::AbstractArray, + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::AbstractMode, + ::ForwardMode, +) + function grad_aux!(storage, z) + gradient!(storage, inner(backend), f, z, extras) + return nothing + end + return value_and_pushforward!(grad, hvp, outer(backend), grad_aux!, x, v, extras) +end + +function gradient_and_hessian_vector_product_aux!( + grad::AbstractArray, + hvp::AbstractArray, + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::AbstractMode, + ::ReverseMode, +) + throw(ArgumentError("HVP must be computed without gradient for reverse-over-something")) +end + +## Reverse-over-something only gives hvp + +""" + hessian_vector_product(backend, f, x, v, [extras]) -> hvp + +Compute the Hessian-vector product `hvp = ∇²f(x) * v` of an array-to-scalar function. +""" +function hessian_vector_product( + backend::SecondOrder, f, x::AbstractArray, v::AbstractArray, extras=nothing +) + return hessian_vector_product_aux( + backend, f, x, v, extras, mode(inner(backend)), mode(outer(backend)) + ) +end + +function hessian_vector_product_aux( + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::ReverseMode, + ::ReverseMode, +) + dotgrad_aux(z) = dot(gradient(inner(backend), f, z), v, extras) + return gradient(outer(backend), dotgrad_aux, x, extras) +end + +function hessian_vector_product_aux( + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::ForwardMode, + ::ReverseMode, +) + jvp_aux(z) = pushforward(inner(backend), f, z, v, extras) + return gradient(outer(backend), jvp_aux, x, extras) +end + +function hessian_vector_product_aux( + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::AbstractMode, + ::ForwardMode, +) + throw(ArgumentError("HVP must be computed with gradient for forward-over-something")) +end + +""" + hessian_vector_product!(hvp, backend, f, x, v, [extras]) -> hvp + +Compute the Hessian-vector product `hvp = ∇²f(x) * v` of an array-to-scalar function, overwriting `hvp`. +""" +function hessian_vector_product!( + hvp::AbstractArray, + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras=nothing, +) + return hessian_vector_product_aux!( + hvp, backend, f, x, v, extras, mode(inner(backend)), mode(outer(backend)) + ) +end + +function hessian_vector_product_aux!( + hvp::AbstractArray, + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::ReverseMode, + ::ReverseMode, +) + dotgrad_aux(z) = dot(gradient(inner(backend), f, z), v, extras) # allocates + return gradient!(hvp, outer(backend), dotgrad_aux, x, extras) +end + +function hessian_vector_product_aux!( + hvp::AbstractArray, + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::ForwardMode, + ::ReverseMode, +) + jvp_aux(z) = pushforward(inner(backend), f, z, v, extras) + return gradient!(hvp, outer(backend), jvp_aux, x, extras) +end + +function hessian_vector_product_aux!( + hvp::AbstractArray, + backend::SecondOrder, + f, + x::AbstractArray, + v::AbstractArray, + extras, + ::AbstractMode, + ::ForwardMode, +) + throw(ArgumentError("HVP must be computed with gradient for forward-over-something")) +end diff --git a/src/jacobian.jl b/src/jacobian.jl index 954c95ed4..514603b83 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -21,9 +21,24 @@ Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of a $JAC_NOTES """ -function value_and_jacobian! end +function value_and_jacobian!( + jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, extras=nothing +) + return value_and_jacobian_aux!(jac, backend, f, x, extras, mode(backend)) +end function value_and_jacobian!( + y::AbstractArray, + jac::AbstractMatrix, + backend::AbstractADType, + f, + x::AbstractArray, + extras=nothing, +) + return value_and_jacobian_aux!(y, jac, backend, f, x, extras, mode(backend)) +end + +function value_and_jacobian_aux!( jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode ) y = f(x) @@ -36,7 +51,7 @@ function value_and_jacobian!( return y, jac end -function value_and_jacobian!( +function value_and_jacobian_aux!( y::AbstractArray, jac::AbstractMatrix, backend::AbstractADType, @@ -54,7 +69,7 @@ function value_and_jacobian!( return y, jac end -function value_and_jacobian!( +function value_and_jacobian_aux!( jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, extras, ::ReverseMode ) y = f(x) @@ -67,7 +82,7 @@ function value_and_jacobian!( return y, jac end -function value_and_jacobian!( +function value_and_jacobian_aux!( y::AbstractArray, jac::AbstractMatrix, backend::AbstractADType, @@ -92,11 +107,7 @@ Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of a $JAC_NOTES """ -function value_and_jacobian end - -function value_and_jacobian( - backend::AbstractADType, f, x::AbstractArray, extras, ::AbstractMode -) +function value_and_jacobian(backend::AbstractADType, f, x::AbstractArray, extras=nothing) y = f(x) T = promote_type(eltype(x), eltype(y)) jac = similar(y, T, length(y), length(x)) @@ -110,15 +121,8 @@ Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overw $JAC_NOTES """ -function jacobian! end - function jacobian!( - jac::AbstractMatrix, - backend::AbstractADType, - f, - x::AbstractArray, - extras, - ::AbstractMode, + jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, extras=nothing ) return last(value_and_jacobian!(jac, backend, f, x, extras)) end @@ -130,8 +134,6 @@ Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function. $JAC_NOTES """ -function jacobian end - -function jacobian(backend::AbstractADType, f, x::AbstractArray, extras, ::AbstractMode) +function jacobian(backend::AbstractADType, f, x::AbstractArray, extras=nothing) return last(value_and_jacobian(backend, f, x, extras)) end diff --git a/src/mode.jl b/src/mode.jl index 365e93e5a..0b36c6241 100644 --- a/src/mode.jl +++ b/src/mode.jl @@ -3,15 +3,28 @@ abstract type AbstractMode end """ ForwardMode -Trait identifying forward mode AD backends. -Used for internal dispatch only. +Trait identifying forward mode first-order AD backends. """ struct ForwardMode <: AbstractMode end """ ReverseMode -Trait identifying reverse mode AD backends. -Used for internal dispatch only. +Trait identifying reverse mode first-order AD backends. """ struct ReverseMode <: AbstractMode end + +""" + mode(backend) + +Return the AD mode of a backend in a statically predictable way. + +This function must be overloaded for backends that do not inherit from `ADTypes.AbstractForwardMode` or `ADTypes.AbstractReverseMode` (e.g. because they support both forward and reverse). + +We classify `ADTypes.AbstractFiniteDifferencesMode` as forward mode. +""" +function mode end + +mode(::AbstractForwardMode) = ForwardMode() +mode(::AbstractFiniteDifferencesMode) = ForwardMode() +mode(::AbstractReverseMode) = ReverseMode() diff --git a/src/multiderivative.jl b/src/multiderivative.jl index 4aa1b804c..49b55b93b 100644 --- a/src/multiderivative.jl +++ b/src/multiderivative.jl @@ -4,15 +4,30 @@ Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider`. """ -function value_and_multiderivative! end +function value_and_multiderivative!( + multider::AbstractArray, backend::AbstractADType, f, x::Number, extras=nothing +) + return value_and_multiderivative_aux!(multider, backend, f, x, extras, mode(backend)) +end function value_and_multiderivative!( + y::AbstractArray, + multider::AbstractArray, + backend::AbstractADType, + f, + x::Number, + extras=nothing, +) + return value_and_multiderivative_aux!(y, multider, backend, f, x, extras, mode(backend)) +end + +function value_and_multiderivative_aux!( multider::AbstractArray, backend::AbstractADType, f, x::Number, extras, ::ForwardMode ) return value_and_pushforward!(multider, backend, f, x, one(x), extras) end -function value_and_multiderivative!( +function value_and_multiderivative_aux!( y::AbstractArray, multider::AbstractArray, backend::AbstractADType, @@ -24,7 +39,7 @@ function value_and_multiderivative!( return value_and_pushforward!(y, multider, backend, f!, x, one(x), extras) end -function value_and_multiderivative!( +function value_and_multiderivative_aux!( multider::AbstractArray, backend::AbstractADType, f, x::Number, extras, ::ReverseMode ) y = f(x) @@ -35,7 +50,7 @@ function value_and_multiderivative!( return y, multider end -function value_and_multiderivative!( +function value_and_multiderivative_aux!( y::AbstractArray, multider::AbstractArray, backend::AbstractADType, @@ -56,15 +71,17 @@ end Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function value_and_multiderivative end +function value_and_multiderivative(backend::AbstractADType, f, x::Number, extras=nothing) + return value_and_multiderivative_aux(backend, f, x, extras, mode(backend)) +end -function value_and_multiderivative( +function value_and_multiderivative_aux( backend::AbstractADType, f, x::Number, extras, ::ForwardMode ) return value_and_pushforward(backend, f, x, one(x), extras) end -function value_and_multiderivative( +function value_and_multiderivative_aux( backend::AbstractADType, f, x::Number, extras, ::ReverseMode ) multider = similar(f(x)) @@ -76,15 +93,19 @@ end Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider`. """ -function multiderivative! end - function multiderivative!( + multider::AbstractArray, backend::AbstractADType, f, x::Number, extras=nothing +) + return multiderivative_aux!(multider, backend, f, x, extras, mode(backend)) +end + +function multiderivative_aux!( multider::AbstractArray, backend::AbstractADType, f, x::Number, extras, ::ForwardMode ) return pushforward!(multider, backend, f, x, one(x), extras) end -function multiderivative!( +function multiderivative_aux!( multider::AbstractArray, backend::AbstractADType, f, x::Number, extras, ::ReverseMode ) return last(value_and_multiderivative!(multider, backend, f, x, extras)) @@ -95,12 +116,14 @@ end Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function multiderivative end +function multiderivative(backend::AbstractADType, f, x::Number, extras=nothing) + return multiderivative_aux(backend, f, x, extras, mode(backend)) +end -function multiderivative(backend::AbstractADType, f, x::Number, extras, ::ForwardMode) +function multiderivative_aux(backend::AbstractADType, f, x::Number, extras, ::ForwardMode) return pushforward(backend, f, x, one(x), extras) end -function multiderivative(backend::AbstractADType, f, x::Number, extras, ::ReverseMode) +function multiderivative_aux(backend::AbstractADType, f, x::Number, extras, ::ReverseMode) return last(value_and_multiderivative(backend, f, x, extras)) end diff --git a/src/prepare.jl b/src/prepare.jl index 981c6d31f..483d58745 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -60,6 +60,15 @@ function prepare_jacobian end prepare_jacobian(::AbstractADType, f, x::AbstractArray) = nothing prepare_jacobian(::AbstractADType, f!, x::AbstractArray, y::AbstractArray) = nothing +""" + prepare_second_derivative(backend, f, x) -> extras + +Create an `extras` object that can be given to second derivative operators. +""" +function prepare_second_derivative end + +prepare_second_derivative(::AbstractADType, f, x::Number) = nothing + """ prepare_hessian(backend, f, x) -> extras diff --git a/src/pullback.jl b/src/pullback.jl index 79774c635..0a50a44a6 100644 --- a/src/pullback.jl +++ b/src/pullback.jl @@ -14,9 +14,7 @@ function value_and_pullback! end Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`. """ -function value_and_pullback end - -function value_and_pullback(backend::AbstractADType, f, x, dy, extras) +function value_and_pullback(backend::AbstractADType, f, x, dy, extras=nothing) dx = mysimilar(x) return value_and_pullback!(dx, backend, f, x, dy, extras) end @@ -26,9 +24,7 @@ end Compute the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. """ -function pullback! end - -function pullback!(dx, backend::AbstractADType, f, x, dy, extras) +function pullback!(dx, backend::AbstractADType, f, x, dy, extras=nothing) return last(value_and_pullback!(dx, backend, f, x, dy, extras)) end @@ -37,8 +33,6 @@ end Compute the vector-Jacobian product `dx = ∂f(x)' * dy`. """ -function pullback end - -function pullback(backend::AbstractADType, f, x, dy, extras) +function pullback(backend::AbstractADType, f, x, dy, extras=nothing) return last(value_and_pullback(backend, f, x, dy, extras)) end diff --git a/src/pushforward.jl b/src/pushforward.jl index 2ba5236e4..eeafb3e22 100644 --- a/src/pushforward.jl +++ b/src/pushforward.jl @@ -14,7 +14,7 @@ function value_and_pushforward! end Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`. """ -function value_and_pushforward(backend::AbstractADType, f, x, dx, extras) +function value_and_pushforward(backend::AbstractADType, f, x, dx, extras=nothing) dy = mysimilar(f(x)) return value_and_pushforward!(dy, backend, f, x, dx, extras) end @@ -24,9 +24,7 @@ end Compute the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. """ -function pushforward! end - -function pushforward!(dy, backend::AbstractADType, f, x, dx, extras) +function pushforward!(dy, backend::AbstractADType, f, x, dx, extras=nothing) return last(value_and_pushforward!(dy, backend, f, x, dx, extras)) end @@ -35,8 +33,6 @@ end Compute the Jacobian-vector product `dy = ∂f(x) * dx`. """ -function pushforward end - -function pushforward(backend::AbstractADType, f, x, dx, extras) +function pushforward(backend::AbstractADType, f, x, dx, extras=nothing) return last(value_and_pushforward(backend, f, x, dx, extras)) end diff --git a/src/second_derivative.jl b/src/second_derivative.jl new file mode 100644 index 000000000..61a95d1c8 --- /dev/null +++ b/src/second_derivative.jl @@ -0,0 +1,36 @@ +""" + value_derivative_and_second_derivative(backend, f, x, [extras]) -> (y, der, derder) + +Compute the primal value `y = f(x)`, the derivative `der = f'(x)` and the second derivative `derder = f''(x)` of a scalar-to-scalar function. +""" +function value_derivative_and_second_derivative( + backend::SecondOrder, f, x::Number, extras=nothing +) + y = f(x) + der_aux(x) = derivative(inner(backend), f, x, extras) + der, derder = value_and_derivative(outer(backend), der_aux, x, extras) + return y, der, derder +end + +function value_derivative_and_second_derivative( + backend::AbstractADType, f, x::Number, extras=nothing +) + return value_derivative_and_second_derivative( + SecondOrder(backend, backend), f, x, extras + ) +end + +""" + second_derivative(backend, f, x, [extras]) -> derder + +Compute the second derivative `derder = f''(x)` of a scalar-to-scalar function. +""" +function second_derivative(backend::SecondOrder, f, x::Number, extras=nothing) + der_aux(x) = derivative(inner(backend), f, x, extras) + derder = derivative(outer(backend), der_aux, x, extras) + return derder +end + +function second_derivative(backend::AbstractADType, f, x::Number, extras=nothing) + return second_derivative(SecondOrder(backend, backend), f, x, extras) +end diff --git a/src/second_order.jl b/src/second_order.jl new file mode 100644 index 000000000..547ecbea7 --- /dev/null +++ b/src/second_order.jl @@ -0,0 +1,22 @@ +""" + SecondOrder + +Combination of two backends for second-order differentiation. + +# Fields + +$(TYPEDFIELDS) +""" +struct SecondOrder{AD1<:AbstractADType,AD2<:AbstractADType} <: AbstractADType + "backend for the inner differentiation" + inner::AD1 + "backend for the outer differentiation" + outer::AD2 +end + +inner(backend::SecondOrder) = backend.inner +outer(backend::SecondOrder) = backend.outer + +function Base.show(io::IO, backend::SecondOrder) + return print(io, "SecondOrder($(inner(backend)), $(outer(backend)))") +end diff --git a/src/zero.jl b/src/zero.jl index 264700d14..c8a3bfd2c 100644 --- a/src/zero.jl +++ b/src/zero.jl @@ -14,13 +14,13 @@ Trivial backend that sets all derivatives to zero. Used in testing and benchmark struct AutoZeroReverse <: AbstractReverseMode end function value_and_pushforward!( - dy::Union{Number,AbstractArray}, ::AutoZeroForward, f, x, dx, extras::Nothing + dy::Union{Number,AbstractArray}, ::AutoZeroForward, f, x, dx, extras::Nothing=nothing ) return f(x), zero!(dy) end function value_and_pullback!( - dx::Union{Number,AbstractArray}, ::AutoZeroReverse, f, x, dy, extras::Nothing + dx::Union{Number,AbstractArray}, ::AutoZeroReverse, f, x, dy, extras::Nothing=nothing ) return f(x), zero!(dx) end @@ -32,7 +32,7 @@ function value_and_pushforward!( f!, x, dx, - extras::Nothing, + extras::Nothing=nothing, ) f!(y, x) return y, zero!(dy) @@ -45,7 +45,7 @@ function value_and_pullback!( f!, x, dy, - extras::Nothing, + extras::Nothing=nothing, ) f!(y, x) return y, zero!(dx) diff --git a/test/runtests.jl b/test/runtests.jl index 414aa983b..db842726c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,35 +33,37 @@ using Zygote: Zygote include("zero.jl") end - @testset "ChainRules (forward)" begin - @test_skip include("chainrules_forward.jl") - end - @testset "ChainRules (reverse)" begin - include("chainrules_reverse.jl") - end - @testset "Diffractor (forward)" begin - include("diffractor.jl") - end - @testset "Enzyme (forward)" begin - include("enzyme_forward.jl") - end - @testset "Enzyme (reverse)" begin - include("enzyme_reverse.jl") - end - @testset "FiniteDiff" begin - include("finitediff.jl") - end - @testset "ForwardDiff" begin - include("forwarddiff.jl") - end - @testset "PolyesterForwardDiff" begin - include("polyesterforwarddiff.jl") - end - @testset "ReverseDiff" begin - include("reversediff.jl") - end - @testset "Zygote" begin - include("zygote.jl") + @testset verbose = true "First order" begin + @testset "ChainRules (forward)" begin + @test_skip include("chainrules_forward.jl") + end + @testset "ChainRules (reverse)" begin + include("chainrules_reverse.jl") + end + @testset "Diffractor (forward)" begin + include("diffractor.jl") + end + @testset "Enzyme (forward)" begin + include("enzyme_forward.jl") + end + @testset "Enzyme (reverse)" begin + include("enzyme_reverse.jl") + end + @testset "FiniteDiff" begin + include("finitediff.jl") + end + @testset "ForwardDiff" begin + include("forwarddiff.jl") + end + @testset "PolyesterForwardDiff" begin + include("polyesterforwarddiff.jl") + end + @testset "ReverseDiff" begin + include("reversediff.jl") + end + @testset "Zygote" begin + include("zygote.jl") + end end @testset "Second order" begin diff --git a/test/second_order.jl b/test/second_order.jl index 040bf44ff..610edb62e 100644 --- a/test/second_order.jl +++ b/test/second_order.jl @@ -9,14 +9,13 @@ using ReverseDiff: ReverseDiff using Zygote: Zygote second_order_backends = [ + SecondOrder(AutoForwardDiff(), AutoForwardDiff()), SecondOrder(AutoZygote(), AutoForwardDiff()), SecondOrder(AutoReverseDiff(), AutoForwardDiff()), SecondOrder(AutoZygote(), AutoEnzyme(Enzyme.Forward)), ] -@testset "$(typeof(backend.second)) over $(typeof(backend.first))" for backend in - second_order_backends - test_operators_allocating( - backend; input_type=AbstractVector, included=[:hessian], type_stability=false - ) +@testset "$(typeof(backend.outer)) over $(typeof(backend.inner))" for backend in + second_order_backends + test_second_order_operators_allocating(backend; type_stability=false) end; diff --git a/test/zero.jl b/test/zero.jl index f7d0652fc..caa376044 100644 --- a/test/zero.jl +++ b/test/zero.jl @@ -1,4 +1,4 @@ -using DifferentiationInterface: AutoZeroForward, AutoZeroReverse +using DifferentiationInterface: SecondOrder, AutoZeroForward, AutoZeroReverse using DifferentiationInterface.DifferentiationTest test_operators_allocating(AutoZeroForward(); correctness=false); @@ -7,8 +7,26 @@ test_operators_mutating(AutoZeroForward(); correctness=false); test_operators_allocating(AutoZeroReverse(); correctness=false); test_operators_mutating(AutoZeroReverse(); correctness=false); -test_operators_allocating( +test_second_order_operators_allocating( + SecondOrder(AutoZeroForward(), AutoZeroForward()); + correctness=false, + included=[:second_derivative, :hessian], +) + +test_second_order_operators_allocating( SecondOrder(AutoZeroReverse(), AutoZeroForward()); correctness=false, - included=[:hessian], + included=[:second_derivative, :hessian], +) + +test_second_order_operators_allocating( + SecondOrder(AutoZeroForward(), AutoZeroReverse()); + correctness=false, + included=[:second_derivative, :hessian], +) + +test_second_order_operators_allocating( + SecondOrder(AutoZeroReverse(), AutoZeroReverse()); + correctness=false, + included=[:second_derivative, :hessian], )