diff --git a/Project.toml b/Project.toml index d6b8020..adcecd9 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.6.1" +version = "1.6.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index 6757476..8e6d960 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -39,7 +39,7 @@ struct AutoDiffractor <: AbstractADType end mode(::AutoDiffractor) = ForwardOrReverseMode() """ - AutoEnzyme{M,constant_function} + AutoEnzyme{M} Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation. @@ -47,11 +47,7 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). # Constructors - AutoEnzyme(; mode=nothing, constant_function::Bool=false) - -The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl. -For simple functions, `constant_function` should usually be set to `true`, which leads to increased performance. -However, in the case of closures or callable structs which contain differentiated data, `constant_function` should be set to `false` to ensure correctness (more details below). + AutoEnzyme(; mode=nothing) # Fields @@ -59,53 +55,13 @@ However, in the case of closures or callable structs which contain differentiate + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required + `nothing` to choose the best mode automatically - -# Notes - -We now give several examples of functions. -For each one, we explain how `constant_function` should be set in order to compute the correct derivative with respect to the input `x`. - -```julia -function f1(x) - return x[1] -end -``` - -The function `f1` is not a closure, it does not contain any data. -Thus `f1` can be differentiated with `AutoEnzyme(constant_function=true)` (although here setting `constant_function=false` would change neither correctness nor performance). - -```julia -parameter = [0.0] -function f2(x) - return parameter[1] + x[1] -end -``` - -The function `f2` is a closure over `parameter`, but `parameter` is never modified based on the input `x`. -Thus, `f2` can be differentiated with `AutoEnzyme(constant_function=true)` (setting `constant_function=false` would not change correctness but would hinder performance). - -```julia -cache = [0.0] -function f3(x) - cache[1] = x[1] - return cache[1] + x[1] -end -``` - -The function `f3` is a closure over `cache`, and `cache` is modified based on the input `x`. -That means `cache` cannot be treated as constant, since derivative values must be propagated through it. -Thus `f3` must be differentiated with `AutoEnzyme(constant_function=false)` (setting `constant_function=true` would make the result incorrect). """ -struct AutoEnzyme{M, constant_function} <: AbstractADType +struct AutoEnzyme{M} <: AbstractADType mode::M end -function AutoEnzyme(mode::M; constant_function::Bool = false) where {M} - return AutoEnzyme{M, constant_function}(mode) -end - -function AutoEnzyme(; mode::M = nothing, constant_function::Bool = false) where {M} - return AutoEnzyme{M, constant_function}(mode) +function AutoEnzyme(; mode::M = nothing) where {M} + return AutoEnzyme{M}(mode) end mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension diff --git a/test/dense.jl b/test/dense.jl index 739cf59..7554565 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -28,25 +28,25 @@ end @testset "AutoEnzyme" begin ad = AutoEnzyme() @test ad isa AbstractADType - @test ad isa AutoEnzyme{Nothing, false} + @test ad isa AutoEnzyme{Nothing} @test mode(ad) isa ForwardOrReverseMode @test ad.mode === nothing - ad = AutoEnzyme(EnzymeCore.Forward; constant_function = true) + ad = AutoEnzyme(EnzymeCore.Forward) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), true} + @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)} @test mode(ad) isa ForwardMode @test ad.mode == EnzymeCore.Forward ad = AutoEnzyme(; mode = EnzymeCore.Forward) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), false} + @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)} @test mode(ad) isa ForwardMode @test ad.mode == EnzymeCore.Forward - ad = AutoEnzyme(; mode = EnzymeCore.Reverse, constant_function = true) + ad = AutoEnzyme(; mode = EnzymeCore.Reverse) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), true} + @test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)} @test mode(ad) isa ReverseMode @test ad.mode == EnzymeCore.Reverse end