From e65358d08fa0253f7e81b09a74572ab3cfcf7d59 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 22 Jul 2024 19:18:06 +0200 Subject: [PATCH 1/9] Handle `AutoEnzyme(constant_function=true/false)` --- DifferentiationInterface/Project.toml | 2 +- .../DifferentiationInterfaceEnzymeExt.jl | 12 +- .../forward_onearg.jl | 54 ++++++--- .../forward_twoarg.jl | 12 +- .../reverse_onearg.jl | 105 ++++++------------ .../reverse_twoarg.jl | 14 ++- .../test/Back/Enzyme/test.jl | 14 ++- 7 files changed, 114 insertions(+), 99 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 9d5d80f5e..646141687 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -44,7 +44,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] -ADTypes = "1.5.0" +ADTypes = "1.6.1" ChainRulesCore = "1.23.0" Compat = "3,4" Diffractor = "=0.2.6" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 5095a36b4..82cfb6389 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -38,15 +38,21 @@ using Enzyme: make_zero, make_zero! -struct AutoDeferredEnzyme{M} <: ADTypes.AbstractADType +const CONSTANT_FUNCTION_ERROR = """`AutoEnzyme(constant_function=false)` is not yet supported by DifferentiationInterface. For the time being, please use `AutoEnzyme(constant_function=true)` and avoid closures containing differentiable data.""" + +struct AutoDeferredEnzyme{M,constant_function} <: ADTypes.AbstractADType mode::M end ADTypes.mode(backend::AutoDeferredEnzyme) = ADTypes.mode(AutoEnzyme(backend.mode)) -DI.nested(backend::AutoEnzyme) = AutoDeferredEnzyme(backend.mode) +function DI.nested(backend::AutoEnzyme{M,constant_function}) where {M} + return AutoDeferredEnzyme{M,constant_function}(backend.mode) +end -const AnyAutoEnzyme{M} = Union{AutoEnzyme{M},AutoDeferredEnzyme{M}} +const AnyAutoEnzyme{M,constant_function} = Union{ + AutoEnzyme{M,constant_function},AutoDeferredEnzyme{M,constant_function} +} # forward mode if possible forward_mode(backend::AnyAutoEnzyme{<:Mode}) = backend.mode diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index e5d86f9a8..fa1aba60a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -1,11 +1,23 @@ ## Pushforward -function DI.prepare_pushforward(f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx) +function DI.prepare_pushforward( + f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, dx +) return NoPushforwardExtras() end +function DI.prepare_pushforward( + f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},false}, x, dx +) + throw(ArgumentError(CONSTANT_FUNCTION_ERROR)) +end + function DI.value_and_pushforward( - f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras + f, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, + x, + dx, + ::NoPushforwardExtras, ) dx_sametype = convert(typeof(x), dx) x_and_dx = Duplicated(x, dx_sametype) @@ -18,7 +30,11 @@ function DI.value_and_pushforward( end function DI.pushforward( - f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras + f, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, + x, + dx, + ::NoPushforwardExtras, ) dx_sametype = convert(typeof(x), dx) x_and_dx = Duplicated(x, dx_sametype) @@ -33,7 +49,7 @@ end function DI.value_and_pushforward!( f, dy, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, dx, extras::NoPushforwardExtras, @@ -46,7 +62,7 @@ end function DI.pushforward!( f, dy, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, dx, extras::NoPushforwardExtras, @@ -61,34 +77,42 @@ struct EnzymeForwardGradientExtras{B,O} <: GradientExtras shadow::O end -function DI.prepare_gradient(f, backend::AutoEnzyme{<:ForwardMode}, x) +function DI.prepare_gradient(f, backend::AutoEnzyme{<:ForwardMode,true}, x) B = pick_batchsize(backend, length(x)) shadow = chunkedonehot(x, Val(B)) return EnzymeForwardGradientExtras{B,typeof(shadow)}(shadow) end function DI.gradient( - f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B} + f, backend::AutoEnzyme{<:ForwardMode,true}, x, extras::EnzymeForwardGradientExtras{B} ) where {B} grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) return reshape(collect(grad_tup), size(x)) end function DI.value_and_gradient( - f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras + f, backend::AutoEnzyme{<:ForwardMode,true}, x, extras::EnzymeForwardGradientExtras ) return f(x), DI.gradient(f, backend, x, extras) end function DI.gradient!( - f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B} + f, + grad, + backend::AutoEnzyme{<:ForwardMode,true}, + x, + extras::EnzymeForwardGradientExtras{B}, ) where {B} grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) return copyto!(grad, grad_tup) end function DI.value_and_gradient!( - f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B} + f, + grad, + backend::AutoEnzyme{<:ForwardMode,true}, + x, + extras::EnzymeForwardGradientExtras{B}, ) where {B} grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) return f(x), copyto!(grad, grad_tup) @@ -100,7 +124,7 @@ struct EnzymeForwardOneArgJacobianExtras{B,O} <: JacobianExtras shadow::O end -function DI.prepare_jacobian(f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x) +function DI.prepare_jacobian(f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},true}, x) B = pick_batchsize(backend, length(x)) shadow = chunkedonehot(x, Val(B)) return EnzymeForwardOneArgJacobianExtras{B,typeof(shadow)}(shadow) @@ -108,7 +132,7 @@ end function DI.jacobian( f, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, extras::EnzymeForwardOneArgJacobianExtras{B}, ) where {B} @@ -120,7 +144,7 @@ end function DI.value_and_jacobian( f, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, extras::EnzymeForwardOneArgJacobianExtras, ) @@ -130,7 +154,7 @@ end function DI.jacobian!( f, jac, - backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, extras::EnzymeForwardOneArgJacobianExtras, ) @@ -140,7 +164,7 @@ end function DI.value_and_jacobian!( f, jac, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, extras::EnzymeForwardOneArgJacobianExtras, ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index ed242c50b..f224fa8f8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -1,13 +1,21 @@ ## Pushforward -function DI.prepare_pushforward(f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx) +function DI.prepare_pushforward( + f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, dx +) return NoPushforwardExtras() end +function DI.prepare_pushforward( + f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},false}, x, dx +) + throw(ArgumentError(CONSTANT_FUNCTION_ERROR)) +end + function DI.value_and_pushforward( f!, y, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, dx, ::NoPushforwardExtras, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 071846ae7..91020ac6e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -1,14 +1,18 @@ ## Pullback -function DI.prepare_pullback(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy) +function DI.prepare_pullback(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, dy) return NoPullbackExtras() end +function DI.prepare_pullback(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},false}, x, dy) + throw(ArgumentError(CONSTANT_FUNCTION_ERROR)) +end + ### Out-of-place function DI.value_and_pullback( f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x::Number, dy::Number, ::NoPullbackExtras, @@ -24,7 +28,7 @@ end function DI.value_and_pullback( f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x::Number, dy, ::NoPullbackExtras, @@ -39,7 +43,7 @@ end function DI.value_and_pullback( f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, dy::Number, ::NoPullbackExtras, @@ -59,14 +63,22 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, extras::NoPullbackExtras + f, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + x, + dy, + extras::NoPullbackExtras, ) dx = make_zero(x) return DI.value_and_pullback!(f, dx, backend, x, dy, extras) end function DI.pullback( - f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, extras::NoPullbackExtras + f, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + x, + dy, + extras::NoPullbackExtras, ) return DI.value_and_pullback(f, backend, x, dy, extras)[2] end @@ -76,7 +88,7 @@ end function DI.value_and_pullback!( f, dx, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, dy::Number, ::NoPullbackExtras, @@ -97,7 +109,12 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - f, dx, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, ::NoPullbackExtras + f, + dx, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + x, + dy, + ::NoPullbackExtras, ) tf, tx = typeof(f), typeof(x) forw, rev = autodiff_thunk( @@ -114,7 +131,7 @@ end function DI.pullback!( f, dx, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, dy, extras::NoPullbackExtras, @@ -124,12 +141,12 @@ end ## Gradient -function DI.prepare_gradient(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x) +function DI.prepare_gradient(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x) return NoGradientExtras() end function DI.gradient( - f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras + f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, ::NoGradientExtras ) if backend isa AutoDeferredEnzyme grad = make_zero(x) @@ -143,7 +160,7 @@ end function DI.gradient!( f, grad, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, extras::NoGradientExtras, ) @@ -158,13 +175,17 @@ function DI.gradient!( end function DI.value_and_gradient( - f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras + f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, ::NoGradientExtras ) return DI.value_and_pullback(f, backend, x, true, NoPullbackExtras()) end function DI.value_and_gradient!( - f, grad, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ::NoGradientExtras + f, + grad, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + x, + ::NoGradientExtras, ) return DI.value_and_pullback!(f, grad, backend, x, true, NoPullbackExtras()) end @@ -172,59 +193,3 @@ end ## Jacobian # see https://github.com/EnzymeAD/Enzyme.jl/issues/1391 - -#= - -struct EnzymeReverseOneArgJacobianExtras{B,N} end - -function DI.prepare_jacobian(f, backend::AutoReverseEnzyme, x) - B = pick_batchsize(backend, length(x)) - y = f(x) - N = length(y) - return EnzymeReverseOneArgJacobianExtras{B,N}() -end - -function DI.jacobian( - f, - backend::AutoReverseEnzyme, - x::AbstractArray, - ::EnzymeReverseOneArgJacobianExtras{C,N}, -) where {B,N} - jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val(N), Val(B)) - nx = length(x) - ny = length(jac_wrongshape) รท length(x) - jac_rightshape = reshape(jac_wrongshape, ny, nx) - return jac_rightshape -end - -function DI.value_and_jacobian( - f, - backend::AutoReverseEnzyme, - x::AbstractArray, - extras::EnzymeReverseOneArgJacobianExtras, -) - return f(x), DI.jacobian(f, backend, x, extras) -end - -function DI.jacobian!( - f, - jac, - backend::AutoReverseEnzyme, - x::AbstractArray, - extras::EnzymeReverseOneArgJacobianExtras, -) - return copyto!(jac, DI.jacobian(f, backend, x, extras)) -end - -function DI.value_and_jacobian!( - f, - jac, - backend::AutoReverseEnzyme, - x::AbstractArray, - extras::EnzymeReverseOneArgJacobianExtras, -) - y, new_jac = DI.value_and_jacobian(f, backend, x, extras) - return y, copyto!(jac, new_jac) -end - -=# diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 4648e7b6b..e2bb9ba10 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -1,13 +1,21 @@ ## Pullback -function DI.prepare_pullback(f!, y, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy) +function DI.prepare_pullback( + f!, y, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, dy +) return NoPullbackExtras() end +function DI.prepare_pullback( + f!, y, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},false}, x, dy +) + throw(ArgumentError(CONSTANT_FUNCTION_ERROR)) +end + function DI.value_and_pullback( f!, y, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x::Number, dy, ::NoPullbackExtras, @@ -25,7 +33,7 @@ end function DI.value_and_pullback( f!, y, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x::AbstractArray, dy, ::NoPullbackExtras, diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 9a860a603..2c79b9c74 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -6,14 +6,18 @@ using StableRNGs using Test dense_backends = [ - AutoEnzyme(; mode=nothing), - AutoEnzyme(; mode=Enzyme.Forward), - AutoEnzyme(; mode=Enzyme.Reverse), + AutoEnzyme(; mode=nothing, constant_function=true), + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), ] nested_dense_backends = [ - DifferentiationInterface.nested(AutoEnzyme(; mode=Enzyme.Forward)), - DifferentiationInterface.nested(AutoEnzyme(; mode=Enzyme.Reverse)), + DifferentiationInterface.nested( + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true) + ), + DifferentiationInterface.nested( + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true) + ), ] sparse_backends = From d589db4beeb2687eb66db812f2a2e9150afe8438 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 22 Jul 2024 20:14:02 +0200 Subject: [PATCH 2/9] Add make_closure --- .../DifferentiationInterfaceEnzymeExt.jl | 2 +- .../test/Back/Enzyme/test.jl | 8 +++++++ .../src/DifferentiationInterfaceTest.jl | 1 + .../src/scenarios/modify.jl | 22 +++++++++++++++++++ 4 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 DifferentiationInterfaceTest/src/scenarios/modify.jl diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 82cfb6389..113ac05ff 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -46,7 +46,7 @@ end ADTypes.mode(backend::AutoDeferredEnzyme) = ADTypes.mode(AutoEnzyme(backend.mode)) -function DI.nested(backend::AutoEnzyme{M,constant_function}) where {M} +function DI.nested(backend::AutoEnzyme{M,constant_function}) where {M,constant_function} return AutoDeferredEnzyme{M,constant_function}(backend.mode) end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 2c79b9c74..866ca74e2 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -1,5 +1,6 @@ using ADTypes: ADTypes using DifferentiationInterface, DifferentiationInterfaceTest +import DifferentiationInterfaceTest as DIT using Enzyme: Enzyme using SparseConnectivityTracer, SparseMatrixColorings using StableRNGs @@ -44,6 +45,13 @@ test_differentiation( logging=LOGGING, ); +test_differentiation( + AutoEnzyme(; constant_function=true), + DIT.make_closure.(default_scenarios()); + second_order=false, + logging=LOGGING, +); # all of these should fail? + test_differentiation( [ AutoEnzyme(; mode=nothing), diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 41eeb320d..4c7ef0fc2 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -76,6 +76,7 @@ include("scenarios/default.jl") include("scenarios/sparse.jl") include("scenarios/allocfree.jl") include("scenarios/extensions.jl") +include("scenarios/modify.jl") include("utils/zero_backends.jl") include("utils/misc.jl") diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl new file mode 100644 index 000000000..df2f299ac --- /dev/null +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -0,0 +1,22 @@ +""" + make_closure(scen::Scenario) + +Return a new [`Scenario`](@ref) with a modified function `f` or `f!` that closes over differentiable data. +""" +function make_closure(scen::Scenario) + closed_data = Ref(zero(scen.y)) + if nb_args(scen) == 1 + function closure_f(x) + closed_data[] = scen.f(x) + return copy(closed_data[]) + end + return change_function(scen, closure_f) + elseif nb_args(scen) == 2 + function closure_f!(y, x) + scen.f(closed_data[], x) + copyto!(y, closed_data[]) + return nothing + end + return change_function(scen, closure_f!) + end +end From 5b4c5b1672e80c828518675676263b495a76d424 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 23 Jul 2024 10:59:45 +0200 Subject: [PATCH 3/9] Handle forward mod --- .../forward_onearg.jl | 42 ++++++++++--------- .../forward_twoarg.jl | 24 +++++------ .../test/Back/Enzyme/test.jl | 4 +- .../src/scenarios/modify.jl | 41 +++++++++++------- 4 files changed, 62 insertions(+), 49 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index fa1aba60a..8c0d34a74 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -1,47 +1,51 @@ ## Pushforward -function DI.prepare_pushforward( - f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, dx -) +function DI.prepare_pushforward(f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx) return NoPushforwardExtras() end -function DI.prepare_pushforward( - f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},false}, x, dx -) - throw(ArgumentError(CONSTANT_FUNCTION_ERROR)) -end - function DI.value_and_pushforward( f, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},constant_function}, x, dx, ::NoPushforwardExtras, -) +) where {constant_function} + f_and_df = if constant_function + Const(f) + else + df = make_zero(f) + Duplicated(f, df) + end dx_sametype = convert(typeof(x), dx) x_and_dx = Duplicated(x, dx_sametype) y, new_dy = if backend isa AutoDeferredEnzyme - autodiff_deferred(forward_mode(backend), f, Duplicated, x_and_dx) + autodiff_deferred(forward_mode(backend), f_and_df, Duplicated, x_and_dx) else - autodiff(forward_mode(backend), Const(f), Duplicated, x_and_dx) + autodiff(forward_mode(backend), f_and_df, Duplicated, x_and_dx) end return y, new_dy end function DI.pushforward( f, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},constant_function}, x, dx, ::NoPushforwardExtras, -) +) where {constant_function} + f_and_df = if constant_function + Const(f) + else + df = make_zero(f) + Duplicated(f, df) + end dx_sametype = convert(typeof(x), dx) x_and_dx = Duplicated(x, dx_sametype) new_dy = if backend isa AutoDeferredEnzyme - only(autodiff_deferred(forward_mode(backend), f, DuplicatedNoNeed, x_and_dx)) + only(autodiff_deferred(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx)) else - only(autodiff(forward_mode(backend), Const(f), DuplicatedNoNeed, x_and_dx)) + only(autodiff(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx)) end return new_dy end @@ -49,7 +53,7 @@ end function DI.value_and_pushforward!( f, dy, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, extras::NoPushforwardExtras, @@ -62,7 +66,7 @@ end function DI.pushforward!( f, dy, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, extras::NoPushforwardExtras, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index f224fa8f8..591eb1a32 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -1,33 +1,31 @@ ## Pushforward -function DI.prepare_pushforward( - f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, x, dx -) +function DI.prepare_pushforward(f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx) return NoPushforwardExtras() end -function DI.prepare_pushforward( - f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},false}, x, dx -) - throw(ArgumentError(CONSTANT_FUNCTION_ERROR)) -end - function DI.value_and_pushforward( f!, y, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},constant_function}, x, dx, ::NoPushforwardExtras, -) +) where {constant_function} + f!_and_df! = if constant_function + Const(f!) + else + df! = make_zero(f!) + Duplicated(f!, df!) + end dx_sametype = convert(typeof(x), dx) dy_sametype = make_zero(y) y_and_dy = Duplicated(y, dy_sametype) x_and_dx = Duplicated(x, dx_sametype) if backend isa AutoDeferredEnzyme - autodiff_deferred(forward_mode(backend), f!, Const, y_and_dy, x_and_dx) + autodiff_deferred(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx) else - autodiff(forward_mode(backend), Const(f!), Const, y_and_dy, x_and_dx) + autodiff(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx) end return y, dy_sametype end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 866ca74e2..31d4ee560 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -46,11 +46,11 @@ test_differentiation( ); test_differentiation( - AutoEnzyme(; constant_function=true), + AutoEnzyme(; mode=Enzyme.Forward, constant_function=false), DIT.make_closure.(default_scenarios()); second_order=false, logging=LOGGING, -); # all of these should fail? +); test_differentiation( [ diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index df2f299ac..e75518390 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -1,22 +1,33 @@ +struct MyClosure{args,F,X,Y} + f::F + x_buffer::Vector{X} + y_buffer::Vector{Y} +end + +function (mc::MyClosure{1})(x) + mc.x_buffer[1] = x + mc.y_buffer[1] = mc.f(x) + return copy(mc.y_buffer[1]) +end + +function (mc::MyClosure{2})(y, x) + mc.x_buffer[1] = x + mc.f(mc.y_buffer[1], mc.x_buffer[1]) + copyto!(y, mc.y_buffer[1]) + return nothing +end + """ make_closure(scen::Scenario) Return a new [`Scenario`](@ref) with a modified function `f` or `f!` that closes over differentiable data. """ function make_closure(scen::Scenario) - closed_data = Ref(zero(scen.y)) - if nb_args(scen) == 1 - function closure_f(x) - closed_data[] = scen.f(x) - return copy(closed_data[]) - end - return change_function(scen, closure_f) - elseif nb_args(scen) == 2 - function closure_f!(y, x) - scen.f(closed_data[], x) - copyto!(y, closed_data[]) - return nothing - end - return change_function(scen, closure_f!) - end + (; f, x, y) = scen + x_buffer = [zero(x)] + y_buffer = [zero(y)] + closure_f = MyClosure{nb_args(scen),typeof(f),typeof(x),typeof(y)}( + f, x_buffer, y_buffer + ) + return change_function(scen, closure_f) end From b436d77c932024e16d49dee8e505b0651ff68ca3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 23 Jul 2024 11:01:47 +0200 Subject: [PATCH 4/9] Compat --- DifferentiationInterfaceTest/src/scenarios/modify.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index e75518390..db939ccfb 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -23,7 +23,7 @@ end Return a new [`Scenario`](@ref) with a modified function `f` or `f!` that closes over differentiable data. """ function make_closure(scen::Scenario) - (; f, x, y) = scen + @compat (; f, x, y) = scen x_buffer = [zero(x)] y_buffer = [zero(y)] closure_f = MyClosure{nb_args(scen),typeof(f),typeof(x),typeof(y)}( From 2701b00c67c494e575eede3bc5b62ffa58dfcd2e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 23 Jul 2024 11:12:13 +0200 Subject: [PATCH 5/9] Add reverse mod --- .../reverse_onearg.jl | 103 +++++++++++------- .../reverse_twoarg.jl | 34 +++--- .../test/Back/Enzyme/test.jl | 5 +- 3 files changed, 84 insertions(+), 58 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 91020ac6e..017b697bd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -1,26 +1,28 @@ ## Pullback -function DI.prepare_pullback(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, dy) +function DI.prepare_pullback(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy) return NoPullbackExtras() end -function DI.prepare_pullback(f, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},false}, x, dy) - throw(ArgumentError(CONSTANT_FUNCTION_ERROR)) -end - ### Out-of-place function DI.value_and_pullback( f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, x::Number, dy::Number, ::NoPullbackExtras, -) +) where {constant_function} + f_and_df = if constant_function + Const(f) + else + df = make_zero(f) + Duplicated(f, df) + end der, y = if backend isa AutoDeferredEnzyme - autodiff_deferred(ReverseWithPrimal, f, Active, Active(x)) + autodiff_deferred(ReverseWithPrimal, f_and_df, Active, Active(x)) else - autodiff(ReverseWithPrimal, Const(f), Active, Active(x)) + autodiff(ReverseWithPrimal, f_and_df, Active, Active(x)) end new_dx = dy * only(der) return y, new_dx @@ -28,32 +30,45 @@ end function DI.value_and_pullback( f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, x::Number, dy, ::NoPullbackExtras, -) - tf, tx = typeof(f), typeof(x) - forw, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{tf}, Duplicated, Active{tx}) - tape, y, new_dy = forw(Const(f), Active(x)) +) where {constant_function} + f_and_df = if constant_function + Const(f) + else + df = make_zero(f) + Duplicated(f, df) + end + forw, rev = autodiff_thunk( + ReverseSplitWithPrimal, typeof(f_and_df), Duplicated, typeof(Active(x)) + ) + tape, y, new_dy = forw(f_and_df, Active(x)) copyto!(new_dy, dy) - new_dx = only(only(rev(Const(f), Active(x), tape))) + new_dx = only(only(rev(f_and_df, Active(x), tape))) return y, new_dx end function DI.value_and_pullback( f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, x, dy::Number, ::NoPullbackExtras, -) +) where {constant_function} + f_and_df = if constant_function + Const(f) + else + df = make_zero(f) + Duplicated(f, df) + end dx_sametype = make_zero(x) x_and_dx = Duplicated(x, dx_sametype) _, y = if backend isa AutoDeferredEnzyme - autodiff_deferred(ReverseWithPrimal, Const(f), Active, x_and_dx) + autodiff_deferred(ReverseWithPrimal, f_and_df, Active, x_and_dx) else - autodiff(ReverseWithPrimal, Const(f), Active, x_and_dx) + autodiff(ReverseWithPrimal, f_and_df, Active, x_and_dx) end if !isone(dy) # TODO: generalize beyond Arrays? @@ -63,22 +78,14 @@ function DI.value_and_pullback( end function DI.value_and_pullback( - f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, - x, - dy, - extras::NoPullbackExtras, + f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, extras::NoPullbackExtras ) dx = make_zero(x) return DI.value_and_pullback!(f, dx, backend, x, dy, extras) end function DI.pullback( - f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, - x, - dy, - extras::NoPullbackExtras, + f, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, extras::NoPullbackExtras ) return DI.value_and_pullback(f, backend, x, dy, extras)[2] end @@ -88,18 +95,24 @@ end function DI.value_and_pullback!( f, dx, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, x, dy::Number, ::NoPullbackExtras, -) +) where {constant_function} + f_and_df = if constant_function + Const(f) + else + df = make_zero(f) + Duplicated(f, df) + end dx_sametype = convert(typeof(x), dx) make_zero!(dx_sametype) x_and_dx = Duplicated(x, dx_sametype) _, y = if backend isa AutoDeferredEnzyme - autodiff_deferred(ReverseWithPrimal, Const(f), Active, x_and_dx) + autodiff_deferred(ReverseWithPrimal, f_and_df, Active, x_and_dx) else - autodiff(ReverseWithPrimal, Const(f), Active, x_and_dx) + autodiff(ReverseWithPrimal, f_and_df, Active, x_and_dx) end if !isone(dy) # TODO: generalize beyond Arrays? @@ -111,27 +124,33 @@ end function DI.value_and_pullback!( f, dx, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, x, dy, ::NoPullbackExtras, -) - tf, tx = typeof(f), typeof(x) - forw, rev = autodiff_thunk( - ReverseSplitWithPrimal, Const{tf}, Duplicated, Duplicated{tx} - ) +) where {constant_function} + f_and_df = if constant_function + Const(f) + else + df = make_zero(f) + Duplicated(f, df) + end dx_sametype = convert(typeof(x), dx) make_zero!(dx_sametype) - tape, y, new_dy = forw(Const(f), Duplicated(x, dx_sametype)) + x_and_dx = Duplicated(x, dx_sametype) + forw, rev = autodiff_thunk( + ReverseSplitWithPrimal, typeof(f_and_df), Duplicated, typeof(x_and_dx) + ) + tape, y, new_dy = forw(f_and_df, x_and_dx) copyto!(new_dy, dy) - rev(Const(f), Duplicated(x, dx_sametype), tape) + rev(f_and_df, x_and_dx, tape) return y, copyto!(dx, dx_sametype) end function DI.pullback!( f, dx, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, extras::NoPullbackExtras, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index e2bb9ba10..70273474c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -1,31 +1,29 @@ ## Pullback -function DI.prepare_pullback( - f!, y, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, x, dy -) +function DI.prepare_pullback(f!, y, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy) return NoPullbackExtras() end -function DI.prepare_pullback( - f!, y, ::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},false}, x, dy -) - throw(ArgumentError(CONSTANT_FUNCTION_ERROR)) -end - function DI.value_and_pullback( f!, y, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, x::Number, dy, ::NoPullbackExtras, -) +) where {constant_function} + f!_and_df! = if constant_function + Const(f!) + else + df! = make_zero(f!) + Duplicated(f!, df!) + end dy_sametype = convert(typeof(y), copy(dy)) y_and_dy = Duplicated(y, dy_sametype) _, new_dx = if backend isa AutoDeferredEnzyme - only(autodiff_deferred(reverse_mode(backend), f!, Const, y_and_dy, Active(x))) + only(autodiff_deferred(reverse_mode(backend), f!_and_df!, Const, y_and_dy, Active(x))) else - only(autodiff(reverse_mode(backend), Const(f!), Const, y_and_dy, Active(x))) + only(autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, Active(x))) end return y, new_dx end @@ -38,14 +36,20 @@ function DI.value_and_pullback( dy, ::NoPullbackExtras, ) + f!_and_df! = if constant_function + Const(f!) + else + df! = make_zero(f!) + Duplicated(f!, df!) + end dx_sametype = make_zero(x) dy_sametype = convert(typeof(y), copy(dy)) y_and_dy = Duplicated(y, dy_sametype) x_and_dx = Duplicated(x, dx_sametype) if backend isa AutoDeferredEnzyme - autodiff_deferred(reverse_mode(backend), f!, Const, y_and_dy, x_and_dx) + autodiff_deferred(reverse_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx) else - autodiff(reverse_mode(backend), Const(f!), Const, y_and_dy, x_and_dx) + autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx) end return y, dx_sametype end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 31d4ee560..3a1acbbdd 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -46,7 +46,10 @@ test_differentiation( ); test_differentiation( - AutoEnzyme(; mode=Enzyme.Forward, constant_function=false), + [ + AutoEnzyme(; mode=Enzyme.Forward, constant_function=false), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=false), + ], DIT.make_closure.(default_scenarios()); second_order=false, logging=LOGGING, From 09b1f24daf3b9b7c40b9e03f0d98d1312317def3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:00:09 +0200 Subject: [PATCH 6/9] More stuff --- DifferentiationInterface/docs/src/backends.md | 2 +- .../docs/src/tutorial1.md | 2 +- .../DifferentiationInterfaceEnzymeExt.jl | 9 +++ .../forward_onearg.jl | 30 ++-------- .../forward_twoarg.jl | 11 +--- .../reverse_onearg.jl | 60 +++++-------------- .../reverse_twoarg.jl | 18 ++---- .../test/Back/Enzyme/test.jl | 21 +++++-- .../test/Back/SecondOrder/test.jl | 8 ++- .../test/Down/Detector/detector.jl | 2 +- .../docs/src/tutorial.md | 2 +- 11 files changed, 62 insertions(+), 103 deletions(-) diff --git a/DifferentiationInterface/docs/src/backends.md b/DifferentiationInterface/docs/src/backends.md index 408b90700..9820085b4 100644 --- a/DifferentiationInterface/docs/src/backends.md +++ b/DifferentiationInterface/docs/src/backends.md @@ -57,7 +57,7 @@ import Zygote backend_examples = [ AutoDiffractor(), - AutoEnzyme(), + AutoEnzyme(; constant_function=true), AutoFastDifferentiation(), AutoFiniteDiff(), AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), diff --git a/DifferentiationInterface/docs/src/tutorial1.md b/DifferentiationInterface/docs/src/tutorial1.md index 0e3cd1205..d3c26c52c 100644 --- a/DifferentiationInterface/docs/src/tutorial1.md +++ b/DifferentiationInterface/docs/src/tutorial1.md @@ -116,7 +116,7 @@ Typically, for gradients, reverse mode AD might be a better fit, so let's try th ```@example tuto1 import Enzyme -backend2 = AutoEnzyme() +backend2 = AutoEnzyme(constant_function=true) ``` Once the backend is created, things run smoothly with exactly the same syntax as before: diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 113ac05ff..abaf68a5d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -74,6 +74,15 @@ function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T return b end +function get_f_and_df(f, ::AnyAutoEnzyme{M,true}) where {M} + return Const(f) +end + +function get_f_and_df(f, ::AnyAutoEnzyme{M,false}) where {M} + df = make_zero(f) + return Duplicated(f, df) +end + include("forward_onearg.jl") include("forward_twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 8c0d34a74..395a1f0a0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -5,18 +5,9 @@ function DI.prepare_pushforward(f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} end function DI.value_and_pushforward( - f, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},constant_function}, - x, - dx, - ::NoPushforwardExtras, -) where {constant_function} - f_and_df = if constant_function - Const(f) - else - df = make_zero(f) - Duplicated(f, df) - end + f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras +) + f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), dx) x_and_dx = Duplicated(x, dx_sametype) y, new_dy = if backend isa AutoDeferredEnzyme @@ -28,18 +19,9 @@ function DI.value_and_pushforward( end function DI.pushforward( - f, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},constant_function}, - x, - dx, - ::NoPushforwardExtras, -) where {constant_function} - f_and_df = if constant_function - Const(f) - else - df = make_zero(f) - Duplicated(f, df) - end + f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras +) + f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), dx) x_and_dx = Duplicated(x, dx_sametype) new_dy = if backend isa AutoDeferredEnzyme diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 591eb1a32..e05cb273f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -7,17 +7,12 @@ end function DI.value_and_pushforward( f!, y, - backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing},constant_function}, + backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras, -) where {constant_function} - f!_and_df! = if constant_function - Const(f!) - else - df! = make_zero(f!) - Duplicated(f!, df!) - end +) + f!_and_df! = get_f_and_df(f!, backend) dx_sametype = convert(typeof(x), dx) dy_sametype = make_zero(y) y_and_dy = Duplicated(y, dy_sametype) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 017b697bd..35ed56113 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -8,17 +8,12 @@ end function DI.value_and_pullback( f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x::Number, dy::Number, ::NoPullbackExtras, -) where {constant_function} - f_and_df = if constant_function - Const(f) - else - df = make_zero(f) - Duplicated(f, df) - end +) + f_and_df = get_f_and_df(f, backend) der, y = if backend isa AutoDeferredEnzyme autodiff_deferred(ReverseWithPrimal, f_and_df, Active, Active(x)) else @@ -30,17 +25,12 @@ end function DI.value_and_pullback( f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x::Number, dy, ::NoPullbackExtras, -) where {constant_function} - f_and_df = if constant_function - Const(f) - else - df = make_zero(f) - Duplicated(f, df) - end +) + f_and_df = get_f_and_df(f, backend) forw, rev = autodiff_thunk( ReverseSplitWithPrimal, typeof(f_and_df), Duplicated, typeof(Active(x)) ) @@ -52,17 +42,12 @@ end function DI.value_and_pullback( f, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy::Number, ::NoPullbackExtras, -) where {constant_function} - f_and_df = if constant_function - Const(f) - else - df = make_zero(f) - Duplicated(f, df) - end +) + f_and_df = get_f_and_df(f, backend) dx_sametype = make_zero(x) x_and_dx = Duplicated(x, dx_sametype) _, y = if backend isa AutoDeferredEnzyme @@ -95,17 +80,12 @@ end function DI.value_and_pullback!( f, dx, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy::Number, ::NoPullbackExtras, -) where {constant_function} - f_and_df = if constant_function - Const(f) - else - df = make_zero(f) - Duplicated(f, df) - end +) + f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), dx) make_zero!(dx_sametype) x_and_dx = Duplicated(x, dx_sametype) @@ -122,19 +102,9 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - f, - dx, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, - x, - dy, - ::NoPullbackExtras, -) where {constant_function} - f_and_df = if constant_function - Const(f) - else - df = make_zero(f) - Duplicated(f, df) - end + f, dx, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, ::NoPullbackExtras +) + f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), dx) make_zero!(dx_sametype) x_and_dx = Duplicated(x, dx_sametype) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 70273474c..2b5c5582d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -7,17 +7,12 @@ end function DI.value_and_pullback( f!, y, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},constant_function}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x::Number, dy, ::NoPullbackExtras, -) where {constant_function} - f!_and_df! = if constant_function - Const(f!) - else - df! = make_zero(f!) - Duplicated(f!, df!) - end +) + f!_and_df! = get_f_and_df(f!, backend) dy_sametype = convert(typeof(y), copy(dy)) y_and_dy = Duplicated(y, dy_sametype) _, new_dx = if backend isa AutoDeferredEnzyme @@ -36,12 +31,7 @@ function DI.value_and_pullback( dy, ::NoPullbackExtras, ) - f!_and_df! = if constant_function - Const(f!) - else - df! = make_zero(f!) - Duplicated(f!, df!) - end + f!_and_df! = get_f_and_df(f!, backend) dx_sametype = make_zero(x) dy_sametype = convert(typeof(y), copy(dy)) y_and_dy = Duplicated(y, dy_sametype) diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 3a1acbbdd..878cffec3 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -57,10 +57,16 @@ test_differentiation( test_differentiation( [ - AutoEnzyme(; mode=nothing), - AutoEnzyme(; mode=Enzyme.Reverse), - SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Reverse)), - SecondOrder(AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)), + AutoEnzyme(; mode=nothing, constant_function=true), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), + SecondOrder( + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), + ), + SecondOrder( + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true), + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), + ), ]; first_order=false, excluded=[:second_derivative], @@ -68,14 +74,17 @@ test_differentiation( ); test_differentiation( - [AutoEnzyme(; mode=nothing), AutoEnzyme(; mode=Enzyme.Forward)]; + [ + AutoEnzyme(; mode=nothing, constant_function=true), + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true), + ]; first_order=false, excluded=[:hessian, :hvp], logging=LOGGING, ); test_differentiation( - AutoEnzyme(; mode=Enzyme.Forward); # TODO: add more + AutoEnzyme(; mode=Enzyme.Forward, constant_function=true); # TODO: add more correctness=false, type_stability=true, second_order=false, diff --git a/DifferentiationInterface/test/Back/SecondOrder/test.jl b/DifferentiationInterface/test/Back/SecondOrder/test.jl index 413d13c21..2ccf530c1 100644 --- a/DifferentiationInterface/test/Back/SecondOrder/test.jl +++ b/DifferentiationInterface/test/Back/SecondOrder/test.jl @@ -16,8 +16,12 @@ onearg_backends = [ ] twoarg_backends = [ - SecondOrder(AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Forward)), - SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoForwardDiff()), + SecondOrder( + AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Forward, constant_function=true) + ), + SecondOrder( + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), AutoForwardDiff() + ), ] for backend in vcat(onearg_backends, twoarg_backends) diff --git a/DifferentiationInterface/test/Down/Detector/detector.jl b/DifferentiationInterface/test/Down/Detector/detector.jl index 2125637bc..892c906b8 100644 --- a/DifferentiationInterface/test/Down/Detector/detector.jl +++ b/DifferentiationInterface/test/Down/Detector/detector.jl @@ -24,7 +24,7 @@ g(x::AbstractVector) = dot(x, Hc, x) g(x::AbstractMatrix) = g(vec(x)) @testset verbose = true "$(typeof(backend))" for backend in [ - AutoEnzyme(; mode=Enzyme.Reverse), AutoForwardDiff() + AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true), AutoForwardDiff() ] @test_throws ArgumentError DenseSparsityDetector(backend; atol=1e-5, method=:random) @testset "$method" for method in (:iterative, :direct) diff --git a/DifferentiationInterfaceTest/docs/src/tutorial.md b/DifferentiationInterfaceTest/docs/src/tutorial.md index 6ab5f086d..8c88ebbdc 100644 --- a/DifferentiationInterfaceTest/docs/src/tutorial.md +++ b/DifferentiationInterfaceTest/docs/src/tutorial.md @@ -12,7 +12,7 @@ import ForwardDiff, Enzyme The AD backends we want to compare are [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl). ```@example tuto -backends = [AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Reverse)] +backends = [AutoForwardDiff(), AutoEnzyme(; mode=Enzyme.Reverse, constant_function=true)] ``` To do that, we are going to take gradients of a simple function: From ea09bbaa079da251aba1d78c8f4cc3d50e59ae15 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 24 Jul 2024 08:55:03 +0200 Subject: [PATCH 7/9] Fix --- .../ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 2b5c5582d..c6c93651e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -26,7 +26,7 @@ end function DI.value_and_pullback( f!, y, - backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},true}, + backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x::AbstractArray, dy, ::NoPullbackExtras, From aa329cb50f0e2695e4470ae3634ef73b1ad303ce Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:34:25 +0200 Subject: [PATCH 8/9] Add test --- .../DifferentiationInterfaceEnzymeExt.jl | 2 -- DifferentiationInterfaceTest/test/weird.jl | 7 +++++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index abaf68a5d..3775a8ae6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -38,8 +38,6 @@ using Enzyme: make_zero, make_zero! -const CONSTANT_FUNCTION_ERROR = """`AutoEnzyme(constant_function=false)` is not yet supported by DifferentiationInterface. For the time being, please use `AutoEnzyme(constant_function=true)` and avoid closures containing differentiable data.""" - struct AutoDeferredEnzyme{M,constant_function} <: ADTypes.AbstractADType mode::M end diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index 6171e3a11..e067bc55d 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -20,6 +20,13 @@ test_differentiation( logging=LOGGING, ) +test_differentiation( + AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), + DIT.make_closure.(default_scenarios()); + second_order=false, + logging=LOGGING, +); + test_differentiation( AutoZygote(), gpu_scenarios(); correctness=true, second_order=false, logging=LOGGING ) From 9641c7c7e3d68fd947c2f5b02e1532fd90aabd77 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 24 Jul 2024 11:25:34 +0200 Subject: [PATCH 9/9] Better test in DIT --- DifferentiationInterfaceTest/Project.toml | 3 ++- DifferentiationInterfaceTest/test/weird.jl | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index c12be09e5..9571ffe49 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -61,6 +61,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -77,4 +78,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] +test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index e067bc55d..ba86d363b 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -3,6 +3,7 @@ using ComponentArrays: ComponentArrays using DifferentiationInterface using DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT +using FiniteDiff: FiniteDiff using FiniteDifferences: FiniteDifferences using Flux: Flux using ForwardDiff: ForwardDiff @@ -21,7 +22,7 @@ test_differentiation( ) test_differentiation( - AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), + AutoFiniteDiff(), DIT.make_closure.(default_scenarios()); second_order=false, logging=LOGGING,