From 9ec4e7ea72819d605f6ececcb5fad2f0412f8693 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 25 Sep 2024 21:57:45 +0200 Subject: [PATCH] Contexts for ReverseDiff (#505) --- DifferentiationInterface/Project.toml | 2 +- .../docs/src/explanation/backends.md | 25 ++- .../DifferentiationInterfaceReverseDiffExt.jl | 13 +- .../onearg.jl | 189 ++++++++++++++++-- .../twoarg.jl | 165 +++++++++++---- .../test/Back/ReverseDiff/test.jl | 9 +- 6 files changed, 332 insertions(+), 71 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index d65ec40cf..f45db0723 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.0" +version = "0.6.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index cb0744fed..f1edbaad6 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -62,6 +62,25 @@ In practice, many AD backends have custom implementations for high-level operato | `AutoTracker` | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | `AutoZygote` | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | 🔀 | ❌ | +Moreover, each context type is supported by a specific subset of backends: + +| | [`Constant`](@ref) | +| -------------------------- | ------------------ | +| `AutoChainRules` | ✅ | +| `AutoDiffractor` | ❌ | +| `AutoEnzyme` (forward) | ✅ | +| `AutoEnzyme` (reverse) | ✅ | +| `AutoFastDifferentiation` | ❌ | +| `AutoFiniteDiff` | ✅ | +| `AutoFiniteDifferences` | ✅ | +| `AutoForwardDiff` | ✅ | +| `AutoMooncake` | ✅ | +| `AutoPolyesterForwardDiff` | ✅ | +| `AutoReverseDiff` | ✅ | +| `AutoSymbolics` | ❌ | +| `AutoTracker` | ✅ | +| `AutoZygote` | ✅ | + ## Second order For second-order operators like [`second_derivative`](@ref), [`hessian`](@ref) and [`hvp`](@ref), there are two main options. @@ -81,9 +100,9 @@ In general, using a forward outer backend over a reverse inner backend will yiel ## Backend switch The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends. -It takes a function `f` and specifies that `f` should be differentiated with the backend of your choice, instead of whatever other backend the code is trying to use. -In other words, when someone tries to differentiate `dw = DifferentiateWith(f, backend1)` with `backend2`, then `backend1` steps in and `backend2` does nothing. -At the moment, `DifferentiateWith` only works when `backend2` supports [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). +It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use. +In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself. +At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend. ## Implementations diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl index 3646ff45c..bc691fd67 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl @@ -1,9 +1,20 @@ module DifferentiationInterfaceReverseDiffExt using ADTypes: AutoReverseDiff +using Base: Fix2 import DifferentiationInterface as DI using DifferentiationInterface: - DerivativePrep, GradientPrep, HessianPrep, JacobianPrep, NoPullbackPrep + Context, + DerivativePrep, + GradientPrep, + HessianPrep, + JacobianPrep, + NoGradientPrep, + NoHessianPrep, + NoJacobianPrep, + NoPullbackPrep, + unwrap, + with_contexts using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult using LinearAlgebra: dot, mul! using ReverseDiff: diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 21ddc99f6..2cd21d07c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -1,48 +1,74 @@ ## Pullback -DI.prepare_pullback(f, ::AutoReverseDiff, x, ty::NTuple) = NoPullbackPrep() +function DI.prepare_pullback( + f, ::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{Context,C} +) where {C} + return NoPullbackPrep() +end function DI.value_and_pullback( - f, ::NoPullbackPrep, ::AutoReverseDiff, x::AbstractArray, ty::NTuple -) - y = f(x) + f, + ::NoPullbackPrep, + ::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{Context,C}, +) where {C} + fc = with_contexts(f, contexts...) + y = fc(x) + dotclosure(z, dy) = dot(fc(z), dy) tx = map(ty) do dy if y isa Number - dy .* gradient(f, x) + dy .* gradient(fc, x) elseif y isa AbstractArray - gradient(z -> dot(f(z), dy), x) + gradient(Fix2(dotclosure, dy), x) end end return y, tx end function DI.value_and_pullback!( - f, ::NoPullbackPrep, tx::NTuple, ::AutoReverseDiff, x::AbstractArray, ty::NTuple -) - y = f(x) + f, + ::NoPullbackPrep, + tx::NTuple, + ::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{Context,C}, +) where {C} + fc = with_contexts(f, contexts...) + y = fc(x) + dotclosure(z, dy) = dot(fc(z), dy) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] if y isa Number - dx = gradient!(dx, f, x) + dx = gradient!(dx, fc, x) dx .*= dy elseif y isa AbstractArray - gradient!(dx, z -> dot(f(z), dy), x) + gradient!(dx, Fix2(dotclosure, dy), x) end end return y, tx end function DI.value_and_pullback( - f, ::NoPullbackPrep, backend::AutoReverseDiff, x::Number, ty::NTuple -) + f, + ::NoPullbackPrep, + backend::AutoReverseDiff, + x::Number, + ty::NTuple, + contexts::Vararg{Context,C}, +) where {C} x_array = [x] - f_array = f ∘ only - y, tx_array = DI.value_and_pullback(f_array, backend, x_array, ty) + f_array(x_array, args...) = f(only(x_array), args...) + y, tx_array = DI.value_and_pullback(f_array, backend, x_array, ty, contexts...) return y, only.(tx_array) end ## Gradient +### Without contexts + struct ReverseDiffGradientPrep{T} <: GradientPrep tape::T end @@ -56,7 +82,7 @@ function DI.prepare_gradient(f, ::AutoReverseDiff{Compile}, x) where {Compile} end function DI.value_and_gradient!( - f, grad::AbstractArray, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x + f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x ) y = f(x) # TODO: remove once ReverseDiff#251 is fixed result = MutableDiffResult(y, (grad,)) @@ -71,9 +97,7 @@ function DI.value_and_gradient( return DI.value_and_gradient!(f, grad, prep, backend, x) end -function DI.gradient!( - _f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x::AbstractArray -) +function DI.gradient!(_f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x) return gradient!(grad, prep.tape, x) end @@ -81,8 +105,47 @@ function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x) return gradient!(prep.tape, x) end +### With contexts + +function DI.prepare_gradient(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C} + return NoGradientPrep() +end + +function DI.value_and_gradient!( + f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + y = fc(x) # TODO: remove once ReverseDiff#251 is fixed + result = MutableDiffResult(y, (grad,)) + result = gradient!(result, fc, x) + return DiffResults.value(result), DiffResults.derivative(result) +end + +function DI.value_and_gradient( + f, prep::NoGradientPrep, backend::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + grad = similar(x) + return DI.value_and_gradient!(f, grad, prep, backend, x, contexts...) +end + +function DI.gradient!( + f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + return gradient!(grad, fc, x) +end + +function DI.gradient( + f, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + return gradient(fc, x) +end + ## Jacobian +### Without contexts + struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep tape::T end @@ -116,8 +179,47 @@ function DI.jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, return jacobian!(prep.tape, x) end +### With contexts + +function DI.prepare_jacobian(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C} + return NoJacobianPrep() +end + +function DI.value_and_jacobian!( + f, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + y = fc(x) + result = MutableDiffResult(y, (jac,)) + result = jacobian!(result, fc, x) + return DiffResults.value(result), DiffResults.derivative(result) +end + +function DI.value_and_jacobian( + f, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + return fc(x), jacobian(fc, x) +end + +function DI.jacobian!( + f, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + return jacobian!(jac, fc, x) +end + +function DI.jacobian( + f, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + return jacobian(fc, x) +end + ## Hessian +### Without contexts + struct ReverseDiffHessianPrep{T} <: HessianPrep tape::T end @@ -152,11 +254,54 @@ end function DI.value_gradient_and_hessian( f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x ) - result = MutableDiffResult( - one(eltype(x)), (similar(x), similar(x, length(x), length(x))) - ) + y = f(x) # TODO: remove once ReverseDiff#251 is fixed + result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x)))) result = hessian!(result, prep.tape, x) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) ) end + +### With contexts + +function DI.prepare_hessian(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C} + return NoHessianPrep() +end + +function DI.hessian!( + f, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + return hessian!(hess, fc, x) +end + +function DI.hessian( + f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + return hessian(fc, x) +end + +function DI.value_gradient_and_hessian!( + f, grad, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + y = fc(x) # TODO: remove once ReverseDiff#251 is fixed + result = MutableDiffResult(y, (grad, hess)) + result = hessian!(result, fc, x) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) +end + +function DI.value_gradient_and_hessian( + f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc = with_contexts(f, contexts...) + y = fc(x) # TODO: remove once ReverseDiff#251 is fixed + result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x)))) + result = hessian!(result, fc, x) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 3dd289038..4b0673be7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -1,65 +1,99 @@ ## Pullback -DI.prepare_pullback(f!, y, ::AutoReverseDiff, x, ty::NTuple) = NoPullbackPrep() +function DI.prepare_pullback( + f!, y, ::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{Context,C} +) where {C} + return NoPullbackPrep() +end ### Array in function DI.value_and_pullback( - f!, y, ::NoPullbackPrep, ::AutoReverseDiff, x::AbstractArray, ty::NTuple -) + f!, + y, + ::NoPullbackPrep, + ::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{Context,C}, +) where {C} + fc! = with_contexts(f!, contexts...) + function dotclosure(x, dy) + y_copy = similar(y, eltype(x)) + fc!(y_copy, x) + return dot(y_copy, dy) + end tx = map(ty) do dy - function dotproduct_closure(x) - y_copy = similar(y, eltype(x)) - f!(y_copy, x) - return dot(y_copy, dy) - end - gradient(dotproduct_closure, x) + gradient(Fix2(dotclosure, dy), x) end - f!(y, x) + fc!(y, x) return y, tx end function DI.value_and_pullback!( - f!, y, tx::NTuple, ::NoPullbackPrep, ::AutoReverseDiff, x::AbstractArray, ty::NTuple -) + f!, + y, + tx::NTuple, + ::NoPullbackPrep, + ::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{Context,C}, +) where {C} + fc! = with_contexts(f!, contexts...) + function dotclosure(x, dy) + y_copy = similar(y, eltype(x)) + fc!(y_copy, x) + return dot(y_copy, dy) + end for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - function dotproduct_closure(x) - y_copy = similar(y, eltype(x)) - f!(y_copy, x) - return dot(y_copy, dy) - end - gradient!(dx, dotproduct_closure, x) + gradient!(dx, Fix2(dotclosure, dy), x) end - f!(y, x) + fc!(y, x) return y, tx end function DI.pullback( - f!, y, ::NoPullbackPrep, ::AutoReverseDiff, x::AbstractArray, ty::NTuple -) + f!, + y, + ::NoPullbackPrep, + ::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{Context,C}, +) where {C} + fc! = with_contexts(f!, contexts...) + function dotclosure(x, dy) + y_copy = similar(y, eltype(x)) + fc!(y_copy, x) + return dot(y_copy, dy) + end tx = map(ty) do dy - function dotproduct_closure(x) - y_copy = similar(y, eltype(x)) - f!(y_copy, x) - return dot(y_copy, dy) - end - gradient(dotproduct_closure, x) + gradient(Fix2(dotclosure, dy), x) end return tx end function DI.pullback!( - f!, y, tx::NTuple, ::NoPullbackPrep, ::AutoReverseDiff, x::AbstractArray, ty::NTuple -) + f!, + y, + tx::NTuple, + ::NoPullbackPrep, + ::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{Context,C}, +) where {C} + fc! = with_contexts(f!, contexts...) + function dotclosure(x, dy) + y_copy = similar(y, eltype(x)) + fc!(y_copy, x) + return dot(y_copy, dy) + end for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - function dotproduct_closure(x) - y_copy = similar(y, eltype(x)) - f!(y_copy, x) - return dot(y_copy, dy) - end - gradient!(dx, dotproduct_closure, x) + gradient!(dx, Fix2(dotclosure, dy), x) end return tx end @@ -67,16 +101,26 @@ end ### Number in, not supported function DI.value_and_pullback( - f!, y, ::NoPullbackPrep, backend::AutoReverseDiff, x::Number, ty::NTuple{B} -) where {B} + f!, + y, + ::NoPullbackPrep, + backend::AutoReverseDiff, + x::Number, + ty::NTuple, + contexts::Vararg{Context,C}, +) where {C} x_array = [x] - f!_array(_y::AbstractArray, _x_array) = f!(_y, only(_x_array)) - y, tx_array = DI.value_and_pullback(f!_array, y, backend, x_array, ty) + function f!_array(_y::AbstractArray, _x_array, args...) + return f!(_y, only(_x_array), args...) + end + y, tx_array = DI.value_and_pullback(f!_array, y, backend, x_array, ty, contexts...) return y, only.(tx_array) end ## Jacobian +### Without contexts + struct ReverseDiffTwoArgJacobianPrep{T} <: JacobianPrep tape::T end @@ -117,3 +161,46 @@ function DI.jacobian!( jac = jacobian!(jac, prep.tape, x) return jac end + +### With contexts + +function DI.prepare_jacobian( + f!, y, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + return NoJacobianPrep() +end + +function DI.value_and_jacobian( + f!, y, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc! = with_contexts(f!, contexts...) + jac = similar(y, length(y), length(x)) + result = MutableDiffResult(y, (jac,)) + result = jacobian!(result, fc!, y, x) + return DiffResults.value(result), DiffResults.derivative(result) +end + +function DI.value_and_jacobian!( + f!, y, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc! = with_contexts(f!, contexts...) + result = MutableDiffResult(y, (jac,)) + result = jacobian!(result, fc!, y, x) + return DiffResults.value(result), DiffResults.derivative(result) +end + +function DI.jacobian( + f!, y, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc! = with_contexts(f!, contexts...) + jac = jacobian(fc!, y, x) + return jac +end + +function DI.jacobian!( + f!, y, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} +) where {C} + fc! = with_contexts(f!, contexts...) + jac = jacobian!(jac, fc!, y, x) + return jac +end diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index c2144562e..b492d7a28 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -2,7 +2,6 @@ using Pkg Pkg.add("ReverseDiff") using DifferentiationInterface, DifferentiationInterfaceTest -using DifferentiationInterface: AutoReverseFromPrimitive using ReverseDiff: ReverseDiff using StaticArrays: StaticArrays using Test @@ -11,13 +10,13 @@ LOGGING = get(ENV, "CI", "false") == "false" dense_backends = [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] -fromprimitive_backends = [AutoReverseFromPrimitive(AutoReverseDiff())] - -for backend in vcat(dense_backends, fromprimitive_backends) +for backend in dense_backends @test check_available(backend) @test check_inplace(backend) end -test_differentiation(vcat(dense_backends, fromprimitive_backends); logging=LOGGING); +test_differentiation( + dense_backends, default_scenarios(; include_constantified=true); logging=LOGGING +); test_differentiation(AutoReverseDiff(), static_scenarios(); logging=LOGGING);