From be4f4a8f16d36b85c98f598a8de26445b8c492fc Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 14 Mar 2024 17:24:52 +0100 Subject: [PATCH] Support mutating functions `f!(y, x)` (#41) --- README.md | 12 +- docs/make.jl | 2 + docs/src/api.md | 9 +- docs/src/backends.md | 13 ++ docs/src/developer.md | 47 +++- docs/src/getting_started.md | 23 +- ...fferentiationInterfaceChainRulesCoreExt.jl | 14 +- .../DifferentiationInterfaceDiffractorExt.jl | 4 +- .../forward.jl | 8 +- .../reverse.jl | 31 ++- .../DifferentiationInterfaceFiniteDiffExt.jl | 8 +- .../DifferentiationInterfaceForwardDiffExt.jl | 34 +++ .../mutating.jl | 97 ++++++++ .../non_mutating.jl} | 123 ++++------ ...tiationInterfacePolyesterForwardDiffExt.jl | 7 +- .../DifferentiationInterfaceReverseDiffExt.jl | 24 +- .../DifferentiationInterfaceTestExt.jl | 3 +- .../scenarios.jl | 121 ++++++++-- .../test_mutating.jl | 217 ++++++++++++++++++ .../{test.jl => test_non_mutating.jl} | 31 +-- .../DifferentiationInterfaceZygoteExt.jl | 4 +- src/DifferentiationTest.jl | 70 ++++-- src/derivative.jl | 4 + src/gradient.jl | 8 + src/jacobian.jl | 45 ++++ src/mode_trait.jl | 25 +- src/multiderivative.jl | 37 +++ src/prepare.jl | 40 +++- src/pullback.jl | 7 + src/pushforward.jl | 6 + test/enzyme_forward.jl | 6 +- test/enzyme_reverse.jl | 2 +- test/finitediff.jl | 6 +- test/forwarddiff.jl | 10 +- test/runtests.jl | 6 + test/zero.jl | 8 + 36 files changed, 927 insertions(+), 185 deletions(-) rename ext/{ => DifferentiationInterfaceChainRulesCoreExt}/DifferentiationInterfaceChainRulesCoreExt.jl (85%) rename ext/{ => DifferentiationInterfaceDiffractorExt}/DifferentiationInterfaceDiffractorExt.jl (86%) rename ext/{ => DifferentiationInterfaceFiniteDiffExt}/DifferentiationInterfaceFiniteDiffExt.jl (93%) create mode 100644 ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl create mode 100644 ext/DifferentiationInterfaceForwardDiffExt/mutating.jl 
rename ext/{DifferentiationInterfaceForwardDiffExt.jl => DifferentiationInterfaceForwardDiffExt/non_mutating.jl} (50%) rename ext/{ => DifferentiationInterfacePolyesterForwardDiffExt}/DifferentiationInterfacePolyesterForwardDiffExt.jl (89%) rename ext/{ => DifferentiationInterfaceReverseDiffExt}/DifferentiationInterfaceReverseDiffExt.jl (81%) create mode 100644 ext/DifferentiationInterfaceTestExt/test_mutating.jl rename ext/DifferentiationInterfaceTestExt/{test.jl => test_non_mutating.jl} (93%) rename ext/{ => DifferentiationInterfaceZygoteExt}/DifferentiationInterfaceZygoteExt.jl (91%) diff --git a/README.md b/README.md index 72993d831..c769b1ede 100644 --- a/README.md +++ b/README.md @@ -9,16 +9,17 @@ An interface to various automatic differentiation backends in Julia. ## Goal -This package provides a backend-agnostic syntax to differentiate functions `f(x) = y`, where `x` and `y` are either real numbers or abstract arrays. +This package provides a backend-agnostic syntax to differentiate functions of the following types: -It supports in-place versions of every operator, and ensures type stability whenever possible. 
+- Allocating: `f(x) = y` where `x` and `y` can be real numbers or abstract arrays +- Mutating: `f!(y, x)` where `y` is an abstract array and `x` can be a real number or an abstract array ## Compatibility We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): | Backend | Type | -|:--------------------------------------------------------------------------------|:-----------------------------------------------------------| +| :------------------------------------------------------------------------------ | :--------------------------------------------------------- | | [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` | | [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` | | [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | `AutoEnzyme(Val(:forward))` or `AutoEnzyme(Val(:reverse))` | @@ -64,13 +65,12 @@ julia> grad ## Related packages -- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) is the original inspiration for DifferentiationInterface.jl. We aim to be less generic (one input, one output, first order only) but more efficient (type stability, memory reuse). -- [AutoDiffOperators.jl](https://github.com/oschulz/AutoDiffOperators.jl) is an attempt to bridge ADTypes.jl with AbstractDifferentiation.jl. We provide similar functionality (except for the matrix-like behavior) but cover more backends. +- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) is the original inspiration for DifferentiationInterface.jl. +- [AutoDiffOperators.jl](https://github.com/oschulz/AutoDiffOperators.jl) is an attempt to bridge ADTypes.jl with AbstractDifferentiation.jl. 
## Roadmap Goals for future releases: - implement backend-specific cache objects -- support in-place functions `f!(y, x)` - define user-facing functions to test and benchmark backends against each other diff --git a/docs/make.jl b/docs/make.jl index 92ebe4ca1..ad7c2a8ef 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,6 @@ using Base: get_extension using DifferentiationInterface +using DifferentiationInterface.DifferentiationTest import DifferentiationInterface as DI using Documenter using DocumenterMermaid @@ -49,6 +50,7 @@ makedocs(; modules=[ ADTypes, DifferentiationInterface, + DifferentiationInterface.DifferentiationTest, ChainRulesCoreExt, DiffractorExt, EnzymeExt, diff --git a/docs/src/api.md b/docs/src/api.md index 50364a1fa..d15c67461 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -64,6 +64,13 @@ These are not part of the public API. ```@autodocs Modules = [DifferentiationInterface] -Pages = ["backends.jl", "mode.jl", "utils.jl", "DifferentiationTest.jl"] Public = false ``` + +## Submodules + +These are not part of the public API. + +```@autodocs +Modules = [DifferentiationTest] +``` diff --git a/docs/src/backends.md b/docs/src/backends.md index cd75ed10e..d141753c4 100644 --- a/docs/src/backends.md +++ b/docs/src/backends.md @@ -24,6 +24,19 @@ AutoReverseDiff AutoZygote ``` +## Accepted functions + +| Backend | `f(x) = y` | `f!(y, x)` | +| -------------------------- | ---------- | ---------- | +| `AutoChainRules` | yes | no | +| `AutoDiffractor` | yes | no | +| `AutoEnzyme` | yes | soon | +| `AutoForwardDiff` | yes | yes | +| `AutoFiniteDiff` | yes | soon | +| `AutoPolyesterForwardDiff` | yes | soon | +| `AutoReverseDiff` | yes | soon | +| `AutoZygote` | yes | no | + ## Package extensions ```@meta diff --git a/docs/src/developer.md b/docs/src/developer.md index 95d819e1e..b3bf9a2a6 100644 --- a/docs/src/developer.md +++ b/docs/src/developer.md @@ -21,8 +21,7 @@ Advanced users are welcome to code more backends and submit pull requests! 
Edge labels correspond to the amount of function calls when applying operators to a function $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$. - -### Forward mode +### Forward mode, allocating functions ```mermaid flowchart LR @@ -69,7 +68,28 @@ flowchart LR end ``` -### Reverse mode +### Forward mode, mutating functions + +```mermaid +flowchart LR + subgraph Multiderivative + value_and_multiderivative! + end + + value_and_multiderivative! --> value_and_pushforward! + + subgraph Jacobian + value_and_jacobian! + end + + value_and_jacobian! --> |n|value_and_pushforward! + + subgraph Pushforward + value_and_pushforward! + end +``` + +### Reverse mode, allocating functions ```mermaid flowchart LR @@ -115,3 +135,24 @@ flowchart LR pullback --> value_and_pullback end ``` + +### Reverse mode, mutating functions + +```mermaid +flowchart LR + subgraph Multiderivative + value_and_multiderivative! + end + + value_and_multiderivative! --> |m|value_and_pullback! + + subgraph Jacobian + value_and_jacobian! + end + + value_and_jacobian! --> |m|value_and_pullback! + + subgraph Pullback + value_and_pullback! + end +``` diff --git a/docs/src/getting_started.md b/docs/src/getting_started.md index ff7f9afeb..d32395b1a 100644 --- a/docs/src/getting_started.md +++ b/docs/src/getting_started.md @@ -12,11 +12,11 @@ We choose the following terminology for the ones we provide: Most backends have custom implementations for all of these, which we reuse whenever possible. 
-### Variants +## Variants Whenever it makes sense, four variants of the same operator are defined: -| **Operator** | **non-mutating** | **mutating** | **non-mutating with primal** | **mutating with primal** | +| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** | | :---------------- | :------------------------ | :------------------------- | :---------------------------------- | :----------------------------------- | | Derivative | [`derivative`](@ref) | N/A | [`value_and_derivative`](@ref) | N/A | | Multiderivative | [`multiderivative`](@ref) | [`multiderivative!`](@ref) | [`value_and_multiderivative`](@ref) | [`value_and_multiderivative!`](@ref) | @@ -49,3 +49,22 @@ This is especially worth it if you plan to call `operator` several times in simi By default, all the preparation functions return `nothing`. We do not make any guarantees on their implementation for each backend, or on the performance gains that can be expected. + +## Mutating functions + +In addition to allocating functions `f(x) = y`, we also support mutating functions `f!(y, x)` whenever the output is an array. +Since they operate in-place and the primal is computed every time, only four operators are defined: + +| **Operator** | **mutating with primal** | +| :---------------- | :----------------------------------- | +| Multiderivative | [`value_and_multiderivative!`](@ref) | +| Jacobian | [`value_and_jacobian!`](@ref) | +| Pushforward (JVP) | [`value_and_pushforward!`](@ref) | +| Pullback (VJP) | [`value_and_pullback!`](@ref) | + +Furthermore, the preparation function takes an additional argument: `prepare_operator(backend, f!, x, y)`. + +## Multiple inputs/outputs + +Restricting the API to one input and one output has many coding advantages, but it is not very flexible. +If you need more than that, use [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap several objects inside a single `ComponentVector`. 
\ No newline at end of file diff --git a/ext/DifferentiationInterfaceChainRulesCoreExt.jl b/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl similarity index 85% rename from ext/DifferentiationInterfaceChainRulesCoreExt.jl rename to ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index 7c5bf6f5e..97a25f6a2 100644 --- a/ext/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -26,7 +26,12 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - dy, backend::AutoForwardChainRules, f, x, dx, extras=nothing + dy::Union{Number,AbstractArray}, + backend::AutoForwardChainRules, + f, + x, + dx, + extras=nothing, ) y, new_dy = DI.value_and_pushforward(backend, f, x, dx, extras) return y, update!(dy, new_dy) @@ -42,7 +47,12 @@ function DI.value_and_pullback( end function DI.value_and_pullback!( - dx, backend::AutoReverseChainRules, f, x, dy, extras=nothing + dx::Union{Number,AbstractArray}, + backend::AutoReverseChainRules, + f, + x, + dy, + extras=nothing, ) y, new_dx = DI.value_and_pullback(backend, f, x, dy, extras) return y, update!(dx, new_dx) diff --git a/ext/DifferentiationInterfaceDiffractorExt.jl b/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl similarity index 86% rename from ext/DifferentiationInterfaceDiffractorExt.jl rename to ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 2e0a5785d..779a818e9 100644 --- a/ext/DifferentiationInterfaceDiffractorExt.jl +++ b/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -16,7 +16,9 @@ function DI.value_and_pushforward(::AutoDiffractor, f, x, dx, extras::Nothing=no return y, dy end -function DI.value_and_pushforward!(dy, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing) +function DI.value_and_pushforward!( + 
dy::Union{Number,AbstractArray}, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing +) vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x) y, new_dy = vpff((dx,)) return y, update!(dy, new_dy) diff --git a/ext/DifferentiationInterfaceEnzymeExt/forward.jl b/ext/DifferentiationInterfaceEnzymeExt/forward.jl index ed2d4431f..095ddb9e1 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/forward.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/forward.jl @@ -4,15 +4,15 @@ DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode() ## Primitives function DI.value_and_pushforward!( - _dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing -) where {X,Y<:Real} + _dy::Real, ::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing +) y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) return y, new_dy end function DI.value_and_pushforward!( - dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing -) where {X,Y<:AbstractArray} + dy::AbstractArray, ::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing +) y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) dy .= new_dy return y, dy diff --git a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl index 9f9699a81..9d3025754 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl @@ -15,16 +15,21 @@ end ## Primitives function DI.value_and_pullback!( - _dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing -) where {X<:Number,Y<:Union{Real,Nothing}} + _dx::Number, ::AutoReverseEnzyme, f, x::Number, dy, extras::Nothing=nothing +) der, y = autodiff(ReverseWithPrimal, f, Active, Active(x)) new_dx = dy * only(der) return y, new_dx end function DI.value_and_pullback!( - dx::X, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing -) where {X<:AbstractArray,Y<:Union{Real,Nothing}} + dx::AbstractArray, + ::AutoReverseEnzyme, + f, + x::AbstractArray, + 
dy::Number, + extras::Nothing=nothing, +) dx .= zero(eltype(dx)) _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx)) dx .*= dy @@ -32,8 +37,13 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - _dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing -) where {X<:Number,Y<:AbstractArray} + _dx::Number, + ::AutoReverseEnzyme, + f, + x::Number, + dy::AbstractArray, + extras::Nothing=nothing, +) y = f(x) mf! = MakeFunctionMutating(f) _, new_dx = only(autodiff(Reverse, mf!, Const, Duplicated(y, copy(dy)), Active(x))) @@ -41,8 +51,13 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing -) where {X<:AbstractArray,Y<:AbstractArray} + dx::AbstractArray, + ::AutoReverseEnzyme, + f, + x::AbstractArray, + dy::AbstractArray, + extras::Nothing=nothing, +) y = f(x) dx_like_x = zero(x) mf! = MakeFunctionMutating(f) diff --git a/ext/DifferentiationInterfaceFiniteDiffExt.jl b/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl similarity index 93% rename from ext/DifferentiationInterfaceFiniteDiffExt.jl rename to ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl index ba5b0d0bb..0f2055eda 100644 --- a/ext/DifferentiationInterfaceFiniteDiffExt.jl +++ b/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl @@ -17,8 +17,8 @@ const FUNCTION_NOT_INPLACE = Val{false} ## Primitives function DI.value_and_pushforward!( - dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing -) where {Y<:Number,fdtype} + _dy::Number, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing +) where {fdtype} y = f(x) step(t::Number)::Number = f(x .+ t .* dx) new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y) @@ -26,8 +26,8 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::AutoFiniteDiff{fdtype}, f, 
x, dx, extras::Nothing=nothing -) where {Y<:AbstractArray,fdtype} + dy::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing +) where {fdtype} y = f(x) step(t::Number)::AbstractArray = f(x .+ t .* dx) finite_difference_gradient!( diff --git a/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl new file mode 100644 index 000000000..533464a5e --- /dev/null +++ b/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -0,0 +1,34 @@ +module DifferentiationInterfaceForwardDiffExt + +using ADTypes: AutoForwardDiff +import DifferentiationInterface as DI +using DiffResults: DiffResults +using DocStringExtensions +using ForwardDiff: + Chunk, + Dual, + DerivativeConfig, + GradientConfig, + JacobianConfig, + Tag, + derivative, + derivative!, + extract_derivative, + extract_derivative!, + gradient, + gradient!, + jacobian, + jacobian!, + value +using LinearAlgebra: mul! 
+ +choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x) +choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}() + +tag_type(::F, ::V) where {F,V<:Number} = Tag{F,V} +tag_type(::F, ::AbstractArray{V}) where {F,V<:Number} = Tag{F,V} + +include("non_mutating.jl") +include("mutating.jl") + +end # module diff --git a/ext/DifferentiationInterfaceForwardDiffExt/mutating.jl b/ext/DifferentiationInterfaceForwardDiffExt/mutating.jl new file mode 100644 index 000000000..587a4ddd6 --- /dev/null +++ b/ext/DifferentiationInterfaceForwardDiffExt/mutating.jl @@ -0,0 +1,97 @@ +## Pushforward + +function DI.value_and_pushforward!( + y::AbstractArray, dy, ::AutoForwardDiff, f!, x::Real, dx, extras::Nothing=nothing +) + T = tag_type(f!, x) + xdual = Dual{T}(x, dx) + ydual = Dual{T}.(y, dy) + f!(ydual, xdual) + y .= value.(T, ydual) + dy = extract_derivative!(T, dy, ydual) + return y, dy +end + +function DI.value_and_pushforward!( + y::AbstractArray, + dy::AbstractArray, + ::AutoForwardDiff, + f!, + x::AbstractArray, + dx, + extras::Nothing=nothing, +) + T = tag_type(f!, x) + xdual = Dual{T}.(x, dx) + ydual = Dual{T}.(y, dy) + f!(ydual, xdual) + y .= value.(T, ydual) + dy = extract_derivative!(T, dy, ydual) + return y, dy +end + +## Multiderivative + +function DI.value_and_multiderivative!( + y::AbstractArray, + multider::AbstractArray, + backend::AutoForwardDiff, + f!, + x::Number, + extras::Nothing, +) + config = DI.prepare_multiderivative(backend, f!, x, y) + return DI.value_and_multiderivative!(y, multider, backend, f!, x, config) +end + +function DI.value_and_multiderivative!( + y::AbstractArray, + multider::AbstractArray, + ::AutoForwardDiff, + f!, + x::Number, + config::DerivativeConfig, +) + result = DiffResults.DiffResult(y, multider) + result = derivative!(result, f!, y, x, config) + return DiffResults.value(result), DiffResults.derivative(result) +end + +## Jacobian + +function DI.value_and_jacobian!( + y::AbstractArray, + jac::AbstractMatrix, + 
backend::AutoForwardDiff, + f!, + x::AbstractArray, + extras::Nothing, +) + config = DI.prepare_jacobian(backend, f!, x, y) + return DI.value_and_jacobian!(y, jac, backend, f!, x, config) +end + +function DI.value_and_jacobian!( + y::AbstractArray, + jac::AbstractMatrix, + ::AutoForwardDiff, + f!, + x::AbstractArray, + config::JacobianConfig, +) + result = DiffResults.DiffResult(y, jac) + result = jacobian!(result, f!, y, x, config) + return DiffResults.value(result), DiffResults.jacobian(result) +end + +## Preparation + +function DI.prepare_multiderivative(::AutoForwardDiff, f!, x::Number, y::AbstractArray) + return DerivativeConfig(f!, y, x) +end + +function DI.prepare_jacobian( + backend::AutoForwardDiff, f!, x::AbstractArray, y::AbstractArray +) + return JacobianConfig(f!, y, x, choose_chunk(backend, x)) +end diff --git a/ext/DifferentiationInterfaceForwardDiffExt.jl b/ext/DifferentiationInterfaceForwardDiffExt/non_mutating.jl similarity index 50% rename from ext/DifferentiationInterfaceForwardDiffExt.jl rename to ext/DifferentiationInterfaceForwardDiffExt/non_mutating.jl index 98b15c863..cf6c9a98b 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt/non_mutating.jl @@ -1,32 +1,9 @@ -module DifferentiationInterfaceForwardDiffExt - -using ADTypes: AutoForwardDiff -import DifferentiationInterface as DI -using DiffResults: DiffResults -using DocStringExtensions -using ForwardDiff: - Chunk, - Dual, - GradientConfig, - JacobianConfig, - Tag, - derivative, - derivative!, - extract_derivative, - extract_derivative!, - gradient, - gradient!, - jacobian, - jacobian!, - value -using LinearAlgebra: mul! 
- ## Pushforward function DI.value_and_pushforward!( - _dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing -) where {X<:Real,Y<:Real} - T = typeof(Tag(f, X)) + _dy::Real, ::AutoForwardDiff, f, x::Real, dx, extras::Nothing=nothing +) + T = tag_type(f, x) xdual = Dual{T}(x, dx) ydual = f(xdual) y = value(T, ydual) @@ -35,9 +12,9 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing -) where {X<:Real,Y<:AbstractArray} - T = typeof(Tag(f, X)) + dy::AbstractArray, ::AutoForwardDiff, f, x::Real, dx, extras::Nothing=nothing +) + T = tag_type(f, x) xdual = Dual{T}(x, dx) ydual = f(xdual) y = value.(T, ydual) @@ -46,9 +23,9 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - _dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing -) where {X<:AbstractArray,Y<:Real} - T = typeof(Tag(f, eltype(X))) + _dy::Real, ::AutoForwardDiff, f, x::AbstractArray, dx, extras::Nothing=nothing +) + T = tag_type(f, x) xdual = Dual{T}.(x, dx) ydual = f(xdual) y = value(T, ydual) @@ -57,9 +34,9 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing -) where {X<:AbstractArray,Y<:AbstractArray} - T = typeof(Tag(f, eltype(X))) + dy::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray, dx, extras::Nothing=nothing +) + T = tag_type(f, x) xdual = Dual{T}.(x, dx) ydual = f(xdual) y = value.(T, ydual) @@ -91,52 +68,54 @@ end function DI.value_and_gradient!( grad::AbstractArray, backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing ) - return DI.value_and_gradient!(grad, backend, f, x, DI.prepare_gradient(backend, f, x)) + config = DI.prepare_gradient(backend, f, x) + return DI.value_and_gradient!(grad, backend, f, x, config) end function DI.value_and_gradient!( - grad::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray, extras::GradientConfig + grad::AbstractArray, 
::AutoForwardDiff, f, x::AbstractArray, config::GradientConfig ) result = DiffResults.DiffResult(zero(eltype(x)), grad) - result = gradient!(result, f, x, extras) + result = gradient!(result, f, x, config) return DiffResults.value(result), DiffResults.gradient(result) end function DI.value_and_gradient( backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing ) - return DI.value_and_gradient(backend, f, x, DI.prepare_gradient(backend, f, x)) + config = DI.prepare_gradient(backend, f, x) + return DI.value_and_gradient(backend, f, x, config) end function DI.value_and_gradient( - ::AutoForwardDiff, f, x::AbstractArray, extras::GradientConfig + ::AutoForwardDiff, f, x::AbstractArray, config::GradientConfig ) result = DiffResults.GradientResult(x) - result = gradient!(result, f, x, extras) + result = gradient!(result, f, x, config) return DiffResults.value(result), DiffResults.gradient(result) end function DI.gradient!( grad::AbstractArray, backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing ) - return DI.gradient!(grad, backend, f, x, DI.prepare_gradient(backend, f, x)) + config = DI.prepare_gradient(backend, f, x) + return DI.gradient!(grad, backend, f, x, config) end function DI.gradient!( - grad::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray, extras::GradientConfig + grad::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray, config::GradientConfig ) - gradient!(grad, f, x, extras) + gradient!(grad, f, x, config) return grad end function DI.gradient(backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing) - return DI.gradient(backend, f, x, DI.prepare_gradient(backend, f, x)) + config = DI.prepare_gradient(backend, f, x) + return DI.gradient(backend, f, x, config) end -function DI.gradient( - ::AutoForwardDiff, f, x::AbstractArray, extras::Union{Nothing,GradientConfig} -) - return gradient(f, x, extras) +function DI.gradient(::AutoForwardDiff, f, x::AbstractArray, config::GradientConfig) + return gradient(f, x, config) end ## 
Jacobian @@ -144,67 +123,61 @@ end function DI.value_and_jacobian!( jac::AbstractMatrix, backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing ) - return DI.value_and_jacobian!(jac, backend, f, x, DI.prepare_jacobian(backend, f, x)) + config = DI.prepare_jacobian(backend, f, x) + return DI.value_and_jacobian!(jac, backend, f, x, config) end function DI.value_and_jacobian!( - jac::AbstractMatrix, ::AutoForwardDiff, f, x::AbstractArray, extras::JacobianConfig + jac::AbstractMatrix, ::AutoForwardDiff, f, x::AbstractArray, config::JacobianConfig ) y = f(x) result = DiffResults.DiffResult(y, jac) - result = jacobian!(result, f, x, extras) + result = jacobian!(result, f, x, config) return DiffResults.value(result), DiffResults.jacobian(result) end function DI.value_and_jacobian( backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing ) - return DI.value_and_jacobian(backend, f, x, DI.prepare_jacobian(backend, f, x)) + config = DI.prepare_jacobian(backend, f, x) + return DI.value_and_jacobian(backend, f, x, config) end function DI.value_and_jacobian( - ::AutoForwardDiff, f, x::AbstractArray, extras::JacobianConfig + ::AutoForwardDiff, f, x::AbstractArray, config::JacobianConfig ) - return f(x), jacobian(f, x, extras) + return f(x), jacobian(f, x, config) end function DI.jacobian!( jac::AbstractMatrix, backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing ) - return DI.jacobian!(jac, backend, f, x, DI.prepare_jacobian(backend, f, x)) + config = DI.prepare_jacobian(backend, f, x) + return DI.jacobian!(jac, backend, f, x, config) end function DI.jacobian!( - jac::AbstractMatrix, ::AutoForwardDiff, f, x::AbstractArray, extras::JacobianConfig + jac::AbstractMatrix, ::AutoForwardDiff, f, x::AbstractArray, config::JacobianConfig ) - jacobian!(jac, f, x, extras) + jacobian!(jac, f, x, config) return jac end function DI.jacobian(backend::AutoForwardDiff, f, x::AbstractArray, extras::Nothing) - return DI.jacobian(backend, f, x, DI.prepare_jacobian(backend, 
f, x)) + config = DI.prepare_jacobian(backend, f, x) + return DI.jacobian(backend, f, x, config) end -function DI.jacobian(::AutoForwardDiff, f, x::AbstractArray, extras::JacobianConfig) - return jacobian(f, x, extras) +function DI.jacobian(::AutoForwardDiff, f, x::AbstractArray, config::JacobianConfig) + return jacobian(f, x, config) end ## Preparation -function DI.prepare_gradient(::AutoForwardDiff{nothing}, f, x::AbstractArray) - return GradientConfig(f, x, Chunk(x)) +function DI.prepare_gradient(backend::AutoForwardDiff, f, x::AbstractArray) + return GradientConfig(f, x, choose_chunk(backend, x)) end -function DI.prepare_gradient(::AutoForwardDiff{C}, f, x::AbstractArray) where {C} - return GradientConfig(f, x, Chunk{C}()) +function DI.prepare_jacobian(backend::AutoForwardDiff, f, x::AbstractArray) + return JacobianConfig(f, x, choose_chunk(backend, x)) end - -function DI.prepare_jacobian(::AutoForwardDiff{nothing}, f, x::AbstractArray) - return JacobianConfig(f, x, Chunk(x)) -end - -function DI.prepare_jacobian(::AutoForwardDiff{C}, f, x::AbstractArray) where {C} - return JacobianConfig(f, x, Chunk{C}()) -end - -end # module diff --git a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl b/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl similarity index 89% rename from ext/DifferentiationInterfacePolyesterForwardDiffExt.jl rename to ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index b20f00c8c..b059d58f8 100644 --- a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -11,7 +11,12 @@ using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian! 
## Primitives function DI.value_and_pushforward!( - dy, ::AutoPolyesterForwardDiff{C}, f, x, dx, extras::Nothing=nothing + dy::Union{Number,AbstractArray}, + ::AutoPolyesterForwardDiff{C}, + f, + x, + dx, + extras::Nothing=nothing, ) where {C} return DI.value_and_pushforward!( dy, AutoForwardDiff{C,Nothing}(nothing), f, x, dx, extras diff --git a/ext/DifferentiationInterfaceReverseDiffExt.jl b/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl similarity index 81% rename from ext/DifferentiationInterfaceReverseDiffExt.jl rename to ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl index a0ef53440..97c50244d 100644 --- a/ext/DifferentiationInterfaceReverseDiffExt.jl +++ b/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl @@ -10,9 +10,14 @@ using ReverseDiff: gradient, gradient!, jacobian, jacobian! ## Primitives function DI.value_and_pullback!( - dx, ::AutoReverseDiff, f, x::X, dy::Y, extras::Nothing=nothing -) where {X<:AbstractArray,Y<:Real} - res = DiffResults.DiffResult(zero(Y), dx) + dx::AbstractArray, + ::AutoReverseDiff, + f, + x::AbstractArray, + dy::Real, + extras::Nothing=nothing, +) + res = DiffResults.DiffResult(zero(dy), dx) res = gradient!(res, f, x) y = DiffResults.value(res) dx .= dy .* DiffResults.gradient(res) @@ -20,8 +25,13 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - dx, ::AutoReverseDiff, f, x::X, dy::Y, extras::Nothing=nothing -) where {X<:AbstractArray,Y<:AbstractArray} + dx::AbstractArray, + ::AutoReverseDiff, + f, + x::AbstractArray, + dy::AbstractArray, + extras::Nothing=nothing, +) res = DiffResults.DiffResult(similar(dy), similar(dy, length(dy), length(x))) res = jacobian!(res, f, x) y = DiffResults.value(res) @@ -31,8 +41,8 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - _dx, backend::AutoReverseDiff, f, x::X, dy::Y, extras::Nothing=nothing -) where {X<:Number,Y} + _dx::Number, 
backend::AutoReverseDiff, f, x::Number, dy, extras::Nothing=nothing +) x_array = [x] dx_array = similar(x_array) y, dx_array = DI.value_and_pullback!(dx_array, backend, f ∘ only, x_array, dy, extras) diff --git a/ext/DifferentiationInterfaceTestExt/DifferentiationInterfaceTestExt.jl b/ext/DifferentiationInterfaceTestExt/DifferentiationInterfaceTestExt.jl index 02a267b3f..3149bd8eb 100644 --- a/ext/DifferentiationInterfaceTestExt/DifferentiationInterfaceTestExt.jl +++ b/ext/DifferentiationInterfaceTestExt/DifferentiationInterfaceTestExt.jl @@ -16,6 +16,7 @@ using Random: AbstractRNG, default_rng, randn! using Test: @test, @testset include("scenarios.jl") -include("test.jl") +include("test_non_mutating.jl") +include("test_mutating.jl") end diff --git a/ext/DifferentiationInterfaceTestExt/scenarios.jl b/ext/DifferentiationInterfaceTestExt/scenarios.jl index 8dd682d7b..f21b66d30 100644 --- a/ext/DifferentiationInterfaceTestExt/scenarios.jl +++ b/ext/DifferentiationInterfaceTestExt/scenarios.jl @@ -1,52 +1,68 @@ - in_type(::Scenario{F,X}) where {F,X} = X out_type(::Scenario{F,X,Y}) where {F,X,Y} = Y +mutating(s::Scenario) = s.mutating + +## Auto fill -function DT.make_scenario(rng::AbstractRNG, f, x) +function DT.Scenario(rng::AbstractRNG, f, x::Union{Number,AbstractArray}) y = f(x) - return make_scenario(rng, f, x, y) + return make_scenario(rng, f, x, y; mutating=false) end -## Auto fill +function DT.Scenario( + rng::AbstractRNG, f!, x::Union{Number,AbstractArray}, s::NTuple{N,<:Integer} +) where {N} + y = randn(rng, eltype(x), s...) 
+ f!(y, x) + return make_scenario(rng, f!, x, y; mutating=true) +end -function DT.make_scenario(rng::AbstractRNG, f, x::Number, y::Number) +function make_scenario(rng::AbstractRNG, f, x::Number, y::Number; mutating) dx = randn(rng, typeof(x)) dy = randn(rng, typeof(y)) der_true = ForwardDiff.derivative(f, x) dx_true = der_true * dy dy_true = der_true * dx - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, der_true) + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, der_true, mutating) end -function DT.make_scenario(rng::AbstractRNG, f, x::Number, y::AbstractArray) +function make_scenario(rng::AbstractRNG, f, x::Number, y::AbstractArray; mutating) dx = randn(rng, typeof(x)) dy = similar(y) randn!(rng, dy) - multider_true = ForwardDiff.derivative(f, x) + if mutating + multider_true = ForwardDiff.derivative(f, y, x) + else + multider_true = ForwardDiff.derivative(f, x) + end dx_true = dot(multider_true, dy) dy_true = multider_true .* dx - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, multider_true) + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, multider_true, mutating) end -function DT.make_scenario(rng::AbstractRNG, f, x::AbstractArray, y::Number) +function make_scenario(rng::AbstractRNG, f, x::AbstractArray, y::Number; mutating) dx = similar(x) randn!(rng, dx) dy = randn(rng, typeof(y)) grad_true = ForwardDiff.gradient(f, x) dx_true = grad_true .* dy dy_true = dot(grad_true, dx) - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, grad_true) + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, grad_true, mutating) end -function DT.make_scenario(rng::AbstractRNG, f, x::AbstractArray, y::AbstractArray) +function make_scenario(rng::AbstractRNG, f, x::AbstractArray, y::AbstractArray; mutating) dx = similar(x) randn!(rng, dx) dy = similar(y) randn!(rng, dy) - jac_true = ForwardDiff.jacobian(f, x) + if mutating + jac_true = ForwardDiff.jacobian(f, y, x) + else + jac_true = ForwardDiff.jacobian(f, x) + end dx_true = reshape(transpose(jac_true) * 
vec(dy), size(x)) dy_true = reshape(jac_true * vec(dx), size(y)) - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, jac_true) + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, jac_true, mutating) end ## Defaults @@ -54,30 +70,87 @@ end f_scalar_scalar(x::Number)::Number = sin(x) f_scalar_vector(x::Number)::AbstractVector = [sin(x), sin(2x)] + +function f!_scalar_vector(y::AbstractVector, x::Number) + y[1] = sin(x) + y[2] = sin(2x) + return nothing +end + f_scalar_matrix(x::Number)::AbstractMatrix = hcat([sin(x) cos(x)], [sin(2x) cos(2x)]) +function f!_scalar_matrix(y::AbstractMatrix, x::Number) + y[1, 1] = sin(x) + y[2, 1] = cos(x) + y[1, 2] = sin(2x) + y[2, 2] = cos(2x) + return nothing +end + f_vector_scalar(x::AbstractVector)::Number = sum(sin, x) f_matrix_scalar(x::AbstractMatrix)::Number = sum(sin, x) f_vector_vector(x::AbstractVector)::AbstractVector = vcat(sin.(x), cos.(x)) + +function f!_vector_vector(y::AbstractVector, x::AbstractVector) + y[1:length(x)] .= sin.(x) + y[(length(x) + 1):(2length(x))] .= cos.(x) + return nothing +end + f_vector_matrix(x::AbstractVector)::AbstractMatrix = hcat(sin.(x), cos.(x)) +function f!_vector_matrix(y::AbstractMatrix, x::AbstractVector) + y[:, 1] .= sin.(x) + y[:, 2] .= cos.(x) + return nothing +end + f_matrix_vector(x::AbstractMatrix)::AbstractVector = vcat(vec(sin.(x)), vec(cos.(x))) + +function f!_matrix_vector(y::AbstractVector, x::AbstractMatrix) + y[1:length(x)] .= sin.(vec(x)) + y[(length(x) + 1):(2length(x))] .= cos.(vec(x)) + return nothing +end + f_matrix_matrix(x::AbstractMatrix)::AbstractMatrix = hcat(vec(sin.(x)), vec(cos.(x))) -function DT.default_scenarios(rng::AbstractRNG) +function f!_matrix_matrix(y::AbstractMatrix, x::AbstractMatrix) + y[:, 1] .= sin.(vec(x)) + y[:, 2] .= cos.(vec(x)) + return nothing +end + +function default_scenarios_non_mutating(rng::AbstractRNG) + scenarios = [ + Scenario(rng, f_scalar_scalar, 1.0), + Scenario(rng, f_scalar_vector, 1.0), + Scenario(rng, f_scalar_matrix, 
1.0), + Scenario(rng, f_vector_scalar, [1.0, 2.0]), + Scenario(rng, f_matrix_scalar, [1.0 2.0; 3.0 4.0]), + Scenario(rng, f_vector_vector, [1.0, 2.0]), + Scenario(rng, f_vector_matrix, [1.0, 2.0]), + Scenario(rng, f_matrix_vector, [1.0 2.0; 3.0 4.0]), + Scenario(rng, f_matrix_matrix, [1.0 2.0; 3.0 4.0]), + ] + return scenarios +end + +function default_scenarios_mutating(rng::AbstractRNG) scenarios = [ - make_scenario(rng, f_scalar_scalar, 1.0), - make_scenario(rng, f_scalar_vector, 1.0), - make_scenario(rng, f_scalar_matrix, 1.0), - make_scenario(rng, f_vector_scalar, [1.0, 2.0]), - make_scenario(rng, f_matrix_scalar, [1.0 2.0; 3.0 4.0]), - make_scenario(rng, f_vector_vector, [1.0, 2.0]), - make_scenario(rng, f_vector_matrix, [1.0, 2.0]), - make_scenario(rng, f_matrix_vector, [1.0 2.0; 3.0 4.0]), - make_scenario(rng, f_matrix_matrix, [1.0 2.0; 3.0 4.0]), + Scenario(rng, f!_scalar_vector, 1.0, (2,)), + Scenario(rng, f!_scalar_matrix, 1.0, (2, 2)), + Scenario(rng, f!_vector_vector, [1.0, 2.0], (4,)), + Scenario(rng, f!_vector_matrix, [1.0, 2.0], (2, 2)), + Scenario(rng, f!_matrix_vector, [1.0 2.0; 3.0 4.0], (8,)), + Scenario(rng, f!_matrix_matrix, [1.0 2.0; 3.0 4.0], (4, 2)), ] return scenarios end +function DT.default_scenarios(rng::AbstractRNG) + return vcat(default_scenarios_non_mutating(rng), default_scenarios_mutating(rng)) +end + DT.default_scenarios() = default_scenarios(default_rng()) diff --git a/ext/DifferentiationInterfaceTestExt/test_mutating.jl b/ext/DifferentiationInterfaceTestExt/test_mutating.jl new file mode 100644 index 000000000..d9536750c --- /dev/null +++ b/ext/DifferentiationInterfaceTestExt/test_mutating.jl @@ -0,0 +1,217 @@ +function DT.test_pushforward_mutating( + ba::AbstractADType, + scenarios::Vector{<:Scenario}; + input_type::Type=Any, + output_type::Type=Any, + correctness::Bool=true, + type_stability::Bool=true, +) + scenarios = filter(scenarios) do s + in_type(s) <: input_type && out_type(s) <: output_type && mutating(s) + end + 
@testset "Pushforward (mutating): $(in_type(scen)) -> $(out_type(scen))" for scen in + scenarios + (; f, x, y, dx, dy_true) = scen + f! = f + extras = prepare_pushforward(ba, f!, x, y) + @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) + y_in = zero(y) + dy_in = zero(dy_true) + y_out, dy_out = value_and_pushforward!( + y_in, dy_in, ba, f!, x, dx, maybe_extras... + ) + + if correctness + @testset "Primal value" begin + @test y_out ≈ y + @testset "Mutation" begin + @test y_in ≈ y + end + end + @testset "Tangent value" begin + @test dy_out ≈ dy_true rtol = 1e-3 + @testset "Mutation" begin + @test dy_in ≈ dy_true rtol = 1e-3 + end + end + end + if type_stability + @testset "Type stability" begin + @test_opt value_and_pushforward!( + y_in, dy_in, ba, f!, x, dx, maybe_extras... + ) + end + end + end + end +end + +function DT.test_pullback_mutating( + ba::AbstractADType, + scenarios::Vector{<:Scenario}; + input_type::Type=Any, + output_type::Type=Any, + correctness::Bool=true, + type_stability::Bool=true, +) + scenarios = filter(scenarios) do s + in_type(s) <: input_type && out_type(s) <: output_type && mutating(s) + end + @testset "Pullback (mutating): $(in_type(scen)) -> $(out_type(scen))" for scen in + scenarios + (; f, x, y, dy, dx_true) = scen + f! = f + extras = prepare_pullback(ba, f!, x, y) + @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) + y_in = zero(y) + dx_in = zero(dx_true) + y_out, dx_out = value_and_pullback!(y_in, dx_in, ba, f!, x, dy, maybe_extras...) + + if correctness + @testset "Primal value" begin + @test y_out ≈ y + @testset "Mutation" begin + @test y_in ≈ y + end + end + @testset "Cotangent value" begin + @test dx_out ≈ dx_true rtol = 1e-3 + if ismutable(dx_true) + @testset "Mutation" begin + @test dx_in ≈ dx_true rtol = 1e-3 + end + end + end + end + if type_stability + @testset "Type stability" begin + @test_opt value_and_pullback!( + y_in, dx_in, ba, f!, x, dy, maybe_extras... 
+ ) + end + end + end + end +end + +function DT.test_multiderivative_mutating( + ba::AbstractADType, + scenarios::Vector{<:Scenario}; + input_type::Type=Number, + output_type::Type=AbstractArray, + correctness::Bool=true, + type_stability::Bool=true, +) + scenarios = filter(scenarios) do s + in_type(s) <: typeintersect(input_type, Number) && + out_type(s) <: typeintersect(output_type, AbstractArray) && + mutating(s) + end + @testset "Multiderivative (mutating): $(in_type(scen)) -> $(out_type(scen))" for scen in + scenarios + (; f, x, y, multider_true) = scen + f! = f + extras = prepare_multiderivative(ba, f!, x, y) + @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) + y_in = zero(y) + multider_in = zero(multider_true) + y_out, multider_out = value_and_multiderivative!( + y_in, multider_in, ba, f!, x, maybe_extras... + ) + + if correctness + @testset "Primal value" begin + @test y_out ≈ y + @testset "Mutation" begin + @test y_in ≈ y + end + end + @testset "Multiderivative value" begin + @test multider_out ≈ multider_true rtol = 1e-3 + @testset "Mutation" begin + @test multider_in ≈ multider_true rtol = 1e-3 + end + end + end + if type_stability + @testset "Type stability" begin + @test_opt value_and_multiderivative!( + y_in, multider_in, ba, f!, x, maybe_extras... + ) + end + end + end + end +end + +function DT.test_jacobian_mutating( + ba::AbstractADType, + scenarios::Vector{<:Scenario}; + input_type::Type=AbstractArray, + output_type::Type=AbstractArray, + correctness::Bool=true, + type_stability::Bool=true, +) + scenarios = filter(scenarios) do s + in_type(s) <: typeintersect(input_type, AbstractArray) && + out_type(s) <: typeintersect(output_type, AbstractArray) && + mutating(s) + end + @testset "Jacobian (mutating): $(in_type(scen)) -> $(out_type(scen))" for scen in + scenarios + (; f, x, y, jac_true) = scen + f! 
= f + extras = prepare_jacobian(ba, f!, x, y) + @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) + y_in = zero(y) + jac_in = similar(y, length(y), length(x)) + y_out, jac_out = value_and_jacobian!(y_in, jac_in, ba, f!, x, maybe_extras...) + + if correctness + @testset "Primal value" begin + @test y_out ≈ y + @testset "Mutation" begin + @test y_in ≈ y + end + end + @testset "Jacobian value" begin + @test jac_out ≈ jac_true rtol = 1e-3 + @testset "Mutation" begin + @test jac_in ≈ jac_true rtol = 1e-3 + end + end + end + if type_stability + @testset "Type stability" begin + @test_opt value_and_jacobian!(y_in, jac_in, ba, f!, x, maybe_extras...) + end + end + end + end +end + +function DT.test_all_operators_mutating( + ba::AbstractADType, + scenarios::Vector{<:Scenario}; + input_type::Type=Any, + output_type::Type=Any, + correctness::Bool=true, + type_stability::Bool=true, +) + if autodiff_mode(ba) isa ForwardMode + @testset "Pushforward (mutating)" test_pushforward_mutating( + ba, scenarios; input_type, output_type, correctness, type_stability + ) + elseif autodiff_mode(ba) isa ReverseMode + @testset "Pullback (mutating)" test_pullback_mutating( + ba, scenarios; input_type, output_type, correctness, type_stability + ) + end + @testset "Multiderivative (mutating)" test_multiderivative_mutating( + ba, scenarios; input_type, output_type, correctness, type_stability + ) + @testset "Jacobian (mutating)" test_jacobian_mutating( + ba, scenarios; input_type, output_type, correctness, type_stability + ) + return nothing +end diff --git a/ext/DifferentiationInterfaceTestExt/test.jl b/ext/DifferentiationInterfaceTestExt/test_non_mutating.jl similarity index 93% rename from ext/DifferentiationInterfaceTestExt/test.jl rename to ext/DifferentiationInterfaceTestExt/test_non_mutating.jl index 22356e5ec..65f7e1a2f 100644 --- a/ext/DifferentiationInterfaceTestExt/test.jl +++ b/ext/DifferentiationInterfaceTestExt/test_non_mutating.jl @@ -1,4 +1,3 @@ - 
function DT.test_pushforward( ba::AbstractADType, scenarios::Vector{<:Scenario}; @@ -8,9 +7,9 @@ function DT.test_pushforward( type_stability::Bool=true, ) scenarios = filter(scenarios) do s - in_type(s) <: input_type && out_type(s) <: output_type + in_type(s) <: input_type && out_type(s) <: output_type && !mutating(s) end - @testset "Pushforward $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios + @testset "Pushforward: $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios (; f, x, y, dx, dy_true) = scen extras = prepare_pushforward(ba, f, x) @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) @@ -61,9 +60,9 @@ function DT.test_pullback( type_stability::Bool=true, ) scenarios = filter(scenarios) do s - (in_type(s) <: input_type) && (out_type(s) <: output_type) + in_type(s) <: input_type && out_type(s) <: output_type && !mutating(s) end - @testset "Pullback $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios + @testset "Pullback: $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios (; f, x, y, dy, dx_true) = scen extras = prepare_pullback(ba, f, x) @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) @@ -115,9 +114,11 @@ function DT.test_derivative( ) scenarios = filter(scenarios) do s in_type(s) <: typeintersect(input_type, Number) && - out_type(s) <: typeintersect(output_type, Number) + out_type(s) <: typeintersect(output_type, Number) && + !mutating(s) end - @testset "Derivative $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios + @testset "Derivative: $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios + @assert !scen.mutating (; f, x, y, der_true) = scen extras = prepare_derivative(ba, f, x) @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) @@ -154,9 +155,10 @@ function DT.test_multiderivative( ) scenarios = filter(scenarios) do s in_type(s) <: typeintersect(input_type, Number) && - out_type(s) <: typeintersect(output_type, 
AbstractArray) + out_type(s) <: typeintersect(output_type, AbstractArray) && + !mutating(s) end - @testset "Multiderivative $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios + @testset "Multiderivative: $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios (; f, x, y, multider_true) = scen extras = prepare_multiderivative(ba, f, x) @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) @@ -210,9 +212,10 @@ function DT.test_gradient( ) scenarios = filter(scenarios) do s in_type(s) <: typeintersect(input_type, AbstractArray) && - out_type(s) <: typeintersect(output_type, Number) + out_type(s) <: typeintersect(output_type, Number) && + !mutating(s) end - @testset "Gradient $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios + @testset "Gradient: $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios (; f, x, y, grad_true) = scen extras = prepare_gradient(ba, f, x) @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) @@ -262,9 +265,11 @@ function DT.test_jacobian( ) scenarios = filter(scenarios) do s in_type(s) <: typeintersect(input_type, AbstractArray) && - out_type(s) <: typeintersect(output_type, AbstractArray) + out_type(s) <: typeintersect(output_type, AbstractArray) && + !mutating(s) end - @testset "Jacobian $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios + @testset "Jacobian: $(in_type(scen)) -> $(out_type(scen))" for scen in scenarios + @assert !scen.mutating (; f, x, y, jac_true) = scen extras = prepare_jacobian(ba, f, x) @testset "Extras: $(isempty(maybe_extras))" for maybe_extras in ((), (extras,)) diff --git a/ext/DifferentiationInterfaceZygoteExt.jl b/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl similarity index 91% rename from ext/DifferentiationInterfaceZygoteExt.jl rename to ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index b358a19af..9fe3c8234 100644 --- 
a/ext/DifferentiationInterfaceZygoteExt.jl +++ b/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -8,7 +8,9 @@ using Zygote: ZygoteRuleConfig, gradient, jacobian, pullback, withgradient, with ## Primitives -function DI.value_and_pullback!(dx, ::AutoZygote, f, x, dy, extras::Nothing=nothing) +function DI.value_and_pullback!( + dx::Union{Number,AbstractArray}, ::AutoZygote, f, x, dy, extras::Nothing=nothing +) y, back = pullback(f, x) new_dx = only(back(dy)) return y, update!(dx, new_dx) diff --git a/src/DifferentiationTest.jl b/src/DifferentiationTest.jl index c90ba7889..41100f5a6 100644 --- a/src/DifferentiationTest.jl +++ b/src/DifferentiationTest.jl @@ -15,7 +15,15 @@ using ADTypes: AbstractADType, AbstractForwardMode, AbstractReverseMode using DifferentiationInterface: ForwardMode, ReverseMode, autodiff_mode, zero! import DifferentiationInterface as DI -@kwdef struct Scenario{F,X,Y,D1,D2,D3,D4} +@kwdef struct Scenario{ + F, + X<:Union{Number,AbstractArray}, + Y<:Union{Number,AbstractArray}, + D1<:Union{Nothing,Number}, + D2<:Union{Nothing,AbstractArray}, + D3<:Union{Nothing,AbstractArray}, + D4<:Union{Nothing,AbstractArray}, +} "function" f::F "argument" @@ -38,49 +46,79 @@ import DifferentiationInterface as DI grad_true::D3 = nothing "Jacobian result" jac_true::D4 = nothing + "mutation" + mutating::Bool = false end -function make_scenario end function default_scenarios end function test_pushforward end +function test_pushforward_mutating end function test_pullback end +function test_pullback_mutating end function test_derivative end function test_multiderivative end +function test_multiderivative_mutating end function test_gradient end function test_jacobian end +function test_jacobian_mutating end function test_all_operators end +function test_all_operators_mutating end struct AutoZeroForward <: AbstractForwardMode end struct AutoZeroReverse <: AbstractReverseMode end -function DI.value_and_pushforward!(dy, ::AutoZeroForward, f, x, 
dx, extras=nothing) +function DI.value_and_pushforward!( + dy::Union{Number,AbstractArray}, ::AutoZeroForward, f, x, dx, extras=nothing +) return f(x), zero!(dy) end -function DI.value_and_pullback!(dx, ::AutoZeroReverse, f, x, dy, extras=nothing) +function DI.value_and_pullback!( + dx::Union{Number,AbstractArray}, ::AutoZeroReverse, f, x, dy, extras=nothing +) return f(x), zero!(dx) end -export Scenario, make_scenario, default_scenarios +function DI.value_and_pushforward!( + y::AbstractArray, + dy::Union{Number,AbstractArray}, + ::AutoZeroForward, + f!, + x, + dx, + extras=nothing, +) + f!(y, x) + return y, zero!(dy) +end + +function DI.value_and_pullback!( + y::AbstractArray, + dx::Union{Number,AbstractArray}, + ::AutoZeroReverse, + f!, + x, + dy, + extras=nothing, +) + f!(y, x) + return y, zero!(dx) +end + +export Scenario, default_scenarios export test_pushforward, test_pullback +export test_pushforward_mutating, test_pullback_mutating export test_derivative, test_multiderivative, test_gradient, test_jacobian +export test_multiderivative_mutating, test_jacobian_mutating export test_all_operators +export test_all_operators_mutating # see https://docs.julialang.org/en/v1/base/base/#Base.Experimental.register_error_hint function __init__() Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs - if exc.f in [ - make_scenario, - default_scenarios, - test_pushforward, - test_pullback, - test_derivative, - test_multiderivative, - test_gradient, - test_jacobian, - test_all_operators, - ] + f_name = string(exc.f) + if (contains(f_name, "scenario") || contains(f_name, "test_")) print( io, """\n diff --git a/src/derivative.jl b/src/derivative.jl index 8e25d208a..1adb7edfb 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -3,6 +3,8 @@ Compute the primal value `y = f(x)` and the derivative `der = f'(x)` of a scalar-to-scalar function. 
""" +function value_and_derivative end + function value_and_derivative(backend::AbstractADType, f, x::Number, extras, ::ForwardMode) return value_and_pushforward(backend, f, x, one(x), extras) end @@ -16,6 +18,8 @@ end Compute the derivative `der = f'(x)` of a scalar-to-scalar function. """ +function derivative end + function derivative(backend::AbstractADType, f, x::Number, extras, ::ForwardMode) return pushforward(backend, f, x, one(x), extras) end diff --git a/src/gradient.jl b/src/gradient.jl index 856ec3aab..3e66aaf27 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -3,6 +3,8 @@ Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. """ +function value_and_gradient! end + function value_and_gradient!( grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode ) @@ -25,6 +27,8 @@ end Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function. """ +function value_and_gradient end + function value_and_gradient( backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode ) @@ -43,6 +47,8 @@ end Compute the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. """ +function gradient! end + function gradient!( grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode ) @@ -60,6 +66,8 @@ end Compute the gradient `grad = ∇f(x)` of an array-to-scalar function. 
""" +function gradient end + function gradient(backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode) return last(value_and_gradient(backend, f, x, extras)) end diff --git a/src/jacobian.jl b/src/jacobian.jl index f5455c9f6..de9d268d9 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -15,11 +15,14 @@ end """ value_and_jacobian!(jac, backend, f, x, [extras]) -> (y, jac) + value_and_jacobian!(y, jac, backend, f!, x, [extras]) -> (y, jac) Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overwriting `jac` if possible. $JAC_NOTES """ +function value_and_jacobian! end + function value_and_jacobian!( jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, extras, ::ForwardMode ) @@ -33,6 +36,24 @@ function value_and_jacobian!( return y, jac end +function value_and_jacobian!( + y::AbstractArray, + jac::AbstractMatrix, + backend::AbstractADType, + f!, + x::AbstractArray, + extras, + ::ForwardMode, +) + check_jac(jac, x, y) + for (k, j) in enumerate(eachindex(IndexCartesian(), x)) + dx_j = basisarray(backend, x, j) + jac_col_j = reshape(view(jac, :, k), size(y)) + value_and_pushforward!(y, jac_col_j, backend, f!, x, dx_j, extras) + end + return y, jac +end + function value_and_jacobian!( jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, extras, ::ReverseMode ) @@ -46,6 +67,24 @@ function value_and_jacobian!( return y, jac end +function value_and_jacobian!( + y::AbstractArray, + jac::AbstractMatrix, + backend::AbstractADType, + f!, + x::AbstractArray, + extras, + ::ReverseMode, +) + check_jac(jac, x, y) + for (k, i) in enumerate(eachindex(IndexCartesian(), y)) + dy_i = basisarray(backend, y, i) + jac_row_i = reshape(view(jac, k, :), size(x)) + value_and_pullback!(y, jac_row_i, backend, f!, x, dy_i, extras) + end + return y, jac +end + """ value_and_jacobian(backend, f, x, [extras]) -> (y, jac) @@ -53,6 +92,8 @@ Compute the primal value `y = f(x)` and the Jacobian matrix `jac 
= ∂f(x)` of a $JAC_NOTES """ +function value_and_jacobian end + function value_and_jacobian( backend::AbstractADType, f, x::AbstractArray, extras, ::AbstractMode ) @@ -69,6 +110,8 @@ Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overw $JAC_NOTES """ +function jacobian! end + function jacobian!( jac::AbstractMatrix, backend::AbstractADType, @@ -87,6 +130,8 @@ Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function. $JAC_NOTES """ +function jacobian end + function jacobian(backend::AbstractADType, f, x::AbstractArray, extras, ::AbstractMode) return last(value_and_jacobian(backend, f, x, extras)) end diff --git a/src/mode_trait.jl b/src/mode_trait.jl index 83cd253f6..d43053071 100644 --- a/src/mode_trait.jl +++ b/src/mode_trait.jl @@ -8,7 +8,9 @@ for operator in [ :gradient, :jacobian, ] - @eval function $operator(backend::AbstractADType, f, x, extras=nothing) + @eval function $operator( + backend::AbstractADType, f, x::Union{Number,AbstractArray}, extras=nothing + ) return $operator(backend, f, x, extras, autodiff_mode(backend)) end end @@ -23,7 +25,26 @@ for operator in [ :gradient!, :jacobian!, ] - @eval function $operator(storage, backend::AbstractADType, f, x, extras=nothing) + @eval function $operator( + storage::AbstractArray, + backend::AbstractADType, + f, + x::Union{Number,AbstractArray}, + extras=nothing, + ) return $operator(storage, backend, f, x, extras, autodiff_mode(backend)) end end + +for operator in [:value_and_multiderivative!, :value_and_jacobian!] 
+ @eval function $operator( + y::AbstractArray, + storage::AbstractArray, + backend::AbstractADType, + f!, + x::Union{Number,AbstractArray}, + extras=nothing, + ) + return $operator(y, storage, backend, f!, x, extras, autodiff_mode(backend)) + end +end diff --git a/src/multiderivative.jl b/src/multiderivative.jl index 8e8fcce4b..97becfbd6 100644 --- a/src/multiderivative.jl +++ b/src/multiderivative.jl @@ -1,14 +1,29 @@ """ value_and_multiderivative!(multider, backend, f, x, [extras]) -> (y, multider) + value_and_multiderivative!(y, multider, backend, f!, x, [extras]) -> (y, multider) Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. """ +function value_and_multiderivative! end + function value_and_multiderivative!( multider::AbstractArray, backend::AbstractADType, f, x::Number, extras, ::ForwardMode ) return value_and_pushforward!(multider, backend, f, x, one(x), extras) end +function value_and_multiderivative!( + y::AbstractArray, + multider::AbstractArray, + backend::AbstractADType, + f!, + x::Number, + extras, + ::ForwardMode, +) + return value_and_pushforward!(y, multider, backend, f!, x, one(x), extras) +end + function value_and_multiderivative!( multider::AbstractArray, backend::AbstractADType, f, x::Number, extras, ::ReverseMode ) @@ -20,11 +35,29 @@ function value_and_multiderivative!( return y, multider end +function value_and_multiderivative!( + y::AbstractArray, + multider::AbstractArray, + backend::AbstractADType, + f!, + x::Number, + extras, + ::ReverseMode, +) + for i in eachindex(IndexCartesian(), multider) + dy_i = basisarray(backend, multider, i) + _, multider[i] = value_and_pullback!(y, multider[i], backend, f!, x, dy_i, extras) + end + return y, multider +end + """ value_and_multiderivative(backend, f, x, [extras]) -> (y, multider) Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array 
function. """ +function value_and_multiderivative end + function value_and_multiderivative( backend::AbstractADType, f, x::Number, extras, ::ForwardMode ) @@ -43,6 +76,8 @@ end Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. """ +function multiderivative! end + function multiderivative!( multider::AbstractArray, backend::AbstractADType, f, x::Number, extras, ::ForwardMode ) @@ -60,6 +95,8 @@ end Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ +function multiderivative end + function multiderivative(backend::AbstractADType, f, x::Number, extras, ::ForwardMode) return pushforward(backend, f, x, one(x), extras) end diff --git a/src/prepare.jl b/src/prepare.jl index 399199a19..ed3553703 100644 --- a/src/prepare.jl +++ b/src/prepare.jl @@ -1,41 +1,69 @@ """ prepare_pullback(backend, f, x) -> extras + prepare_pullback(backend, f!, x, y) -> extras Create an `extras` object that can be given to pullback operators. """ -prepare_pullback(backend::AbstractADType, f, x) = nothing +function prepare_pullback end + +prepare_pullback(::AbstractADType, f, x::Union{Number,AbstractArray}) = nothing +function prepare_pullback( + ::AbstractADType, f!, x::Union{Number,AbstractArray}, y::Union{Number,AbstractArray} +) + return nothing +end """ prepare_pushforward(backend, f, x) -> extras + prepare_pushforward(backend, f!, x, y) -> extras Create an `extras` object that can be given to pushforward operators. """ -prepare_pushforward(backend::AbstractADType, f, x) = nothing +function prepare_pushforward end + +prepare_pushforward(::AbstractADType, f, x::Union{Number,AbstractArray}) = nothing +function prepare_pushforward( + ::AbstractADType, f!, x::Union{Number,AbstractArray}, y::Union{Number,AbstractArray} +) + return nothing +end """ prepare_derivative(backend, f, x) -> extras Create an `extras` object that can be given to derivative operators. 
""" -prepare_derivative(backend::AbstractADType, f, x::Number) = nothing +function prepare_derivative end + +prepare_derivative(::AbstractADType, f, x::Number) = nothing """ prepare_multiderivative(backend, f, x) -> extras + prepare_multiderivative(backend, f!, x, y) -> extras Create an `extras` object that can be given to multiderivative operators. """ -prepare_multiderivative(backend::AbstractADType, f, x::Number) = nothing +function prepare_multiderivative end + +prepare_multiderivative(::AbstractADType, f, x::Number) = nothing +prepare_multiderivative(::AbstractADType, f!, x::Number, y::AbstractArray) = nothing """ prepare_gradient(backend, f, x) -> extras Create an `extras` object that can be given to gradient operators. """ -prepare_gradient(backend::AbstractADType, f, x::AbstractArray) = nothing +function prepare_gradient end + +prepare_gradient(::AbstractADType, f, x::AbstractArray) = nothing """ prepare_jacobian(backend, f, x) -> extras + prepare_jacobian(backend, f!, x, y) -> extras Create an `extras` object that can be given to jacobian operators. """ -prepare_jacobian(backend::AbstractADType, f, x::AbstractArray) = nothing +function prepare_jacobian end + +prepare_jacobian(::AbstractADType, f, x::AbstractArray) = nothing +prepare_jacobian(::AbstractADType, f!, x::AbstractArray, y::AbstractArray) = nothing diff --git a/src/pullback.jl b/src/pullback.jl index 5a816f47c..20ce261a3 100644 --- a/src/pullback.jl +++ b/src/pullback.jl @@ -1,5 +1,6 @@ """ value_and_pullback!(dx, backend, f, x, dy, [extras]) -> (y, dx) + value_and_pullback!(y, dx, backend, f!, x, dy, [extras]) -> (y, dx) Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. @@ -13,6 +14,8 @@ function value_and_pullback! end Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`. 
""" +function value_and_pullback end + function value_and_pullback(backend::AbstractADType, f, x, dy, extras=nothing) dx = mysimilar(x) return value_and_pullback!(dx, backend, f, x, dy, extras) @@ -23,6 +26,8 @@ end Compute the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. """ +function pullback! end + function pullback!(dx, backend::AbstractADType, f, x, dy, extras=nothing) return last(value_and_pullback!(dx, backend, f, x, dy, extras)) end @@ -32,6 +37,8 @@ end Compute the vector-Jacobian product `dx = ∂f(x)' * dy`. """ +function pullback end + function pullback(backend::AbstractADType, f, x, dy, extras=nothing) return last(value_and_pullback(backend, f, x, dy, extras)) end diff --git a/src/pushforward.jl b/src/pushforward.jl index 6e002bf37..5b4767343 100644 --- a/src/pushforward.jl +++ b/src/pushforward.jl @@ -1,5 +1,6 @@ """ value_and_pushforward!(dy, backend, f, x, dx, [extras]) -> (y, dy) + value_and_pushforward!(y, dy, backend, f!, x, dx, [extras]) -> (y, dy) Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. @@ -20,9 +21,12 @@ end """ pushforward!(dy, backend, f, x, dx, [extras]) -> dy + pushforward!(y, dy, backend, f!, x, dx, [extras]) -> dy Compute the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. """ +function pushforward! end + function pushforward!(dy, backend::AbstractADType, f, x, dx, extras=nothing) return last(value_and_pushforward!(dy, backend, f, x, dx, extras)) end @@ -32,6 +36,8 @@ end Compute the Jacobian-vector product `dy = ∂f(x) * dx`. 
""" +function pushforward end + function pushforward(backend::AbstractADType, f, x, dx, extras=nothing) return last(value_and_pushforward(backend, f, x, dx, extras)) end diff --git a/test/enzyme_forward.jl b/test/enzyme_forward.jl index 50599b2c9..f2b14e5a4 100644 --- a/test/enzyme_forward.jl +++ b/test/enzyme_forward.jl @@ -2,4 +2,8 @@ using ADTypes: AutoEnzyme using Enzyme: Enzyme using DifferentiationInterface.DifferentiationTest -test_all_operators(AutoEnzyme(Val(:forward)), default_scenarios(); type_stability=false); +test_pushforward(AutoEnzyme(Val(:forward)), default_scenarios()); +test_derivative(AutoEnzyme(Val(:forward)), default_scenarios()); +test_multiderivative(AutoEnzyme(Val(:forward)), default_scenarios()); +test_gradient(AutoEnzyme(Val(:forward)), default_scenarios()); +test_jacobian(AutoEnzyme(Val(:forward)), default_scenarios(); type_stability=false); diff --git a/test/enzyme_reverse.jl b/test/enzyme_reverse.jl index 0e2a19387..6b9b37eb8 100644 --- a/test/enzyme_reverse.jl +++ b/test/enzyme_reverse.jl @@ -2,4 +2,4 @@ using ADTypes: AutoEnzyme using Enzyme: Enzyme using DifferentiationInterface.DifferentiationTest -test_all_operators(AutoEnzyme(Val(:reverse)), default_scenarios(); type_stability=false); +test_all_operators(AutoEnzyme(Val(:reverse)), default_scenarios(); type_stability=true); diff --git a/test/finitediff.jl b/test/finitediff.jl index 25e6bf8bc..e8c4510f9 100644 --- a/test/finitediff.jl +++ b/test/finitediff.jl @@ -2,4 +2,8 @@ using ADTypes: AutoFiniteDiff using FiniteDiff: FiniteDiff using DifferentiationInterface.DifferentiationTest -test_all_operators(AutoFiniteDiff(), default_scenarios(); type_stability=false); +test_pushforward(AutoFiniteDiff(), default_scenarios(); type_stability=true); +test_derivative(AutoFiniteDiff(), default_scenarios(); type_stability=true); +test_multiderivative(AutoFiniteDiff(), default_scenarios(); type_stability=true); +test_gradient(AutoFiniteDiff(), default_scenarios(); type_stability=true); 
+test_jacobian(AutoFiniteDiff(), default_scenarios(); type_stability=false); diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index b458436d9..5dbb9daf1 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -2,8 +2,8 @@ using ADTypes: AutoForwardDiff using ForwardDiff: ForwardDiff using DifferentiationInterface.DifferentiationTest -test_pushforward(AutoForwardDiff(), default_scenarios(); type_stability=true); -test_derivative(AutoForwardDiff(), default_scenarios(); type_stability=true); -test_multiderivative(AutoForwardDiff(), default_scenarios(); type_stability=true); -test_gradient(AutoForwardDiff(), default_scenarios(); type_stability=false); -test_jacobian(AutoForwardDiff(), default_scenarios(); type_stability=false); +test_all_operators(AutoForwardDiff(; chunksize=2), default_scenarios(); type_stability=true); + +test_all_operators_mutating( + AutoForwardDiff(; chunksize=2), default_scenarios(); type_stability=true +); diff --git a/test/runtests.jl b/test/runtests.jl index a4220bff2..f7faa3da7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,18 @@ ## Imports using Aqua: Aqua +using ChainRulesCore: ChainRulesCore using DifferentiationInterface +using Enzyme: Enzyme +using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using JET: JET using JuliaFormatter: JuliaFormatter +using PolyesterForwardDiff: PolyesterForwardDiff +using ReverseDiff: ReverseDiff using Random: Random using Test +using Zygote: Zygote ## Main tests diff --git a/test/zero.jl b/test/zero.jl index 13e54b286..d9659fd6c 100644 --- a/test/zero.jl +++ b/test/zero.jl @@ -5,6 +5,14 @@ test_all_operators( AutoZeroForward(), default_scenarios(); correctness=false, type_stability=true ); +test_all_operators_mutating( + AutoZeroForward(), default_scenarios(); correctness=false, type_stability=true +); + test_all_operators( AutoZeroReverse(), default_scenarios(); correctness=false, type_stability=true ); + +test_all_operators_mutating( + AutoZeroReverse(), 
default_scenarios(); correctness=false, type_stability=true +);