Skip to content

Commit 39da305

Browse files
Merge pull request #72 from SciML/gd/constant_f_enzyme
Add constant_function kwarg to AutoEnzyme
2 parents 97d5146 + 091d3b6 commit 39da305

File tree

4 files changed

+53
-14
lines changed

4 files changed

+53
-14
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.5.4"
6+
version = "1.6.0"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/dense.jl

+42-4
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,63 @@ struct AutoDiffractor <: AbstractADType end
3939
mode(::AutoDiffractor) = ForwardOrReverseMode()
4040

4141
"""
42-
AutoEnzyme{M}
42+
AutoEnzyme{M,constant_function}
4343
4444
Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation.
4545
4646
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
4747
4848
# Constructors
4949
50-
AutoEnzyme(; mode=nothing)
50+
AutoEnzyme(; mode=nothing, constant_function::Bool=false)
51+
52+
The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl.
53+
For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data that can be treated as constant, `constant_function` should be set to `true` for increased performance (more details below).
5154
5255
# Fields
5356
5457
- `mode::M`: can be either
5558
5659
+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
5760
+ `nothing` to choose the best mode automatically
61+
62+
# Notes
63+
64+
If `constant_function = true` but the enclosed data is not truly constant, then Enzyme.jl will not compute the correct derivative values.
65+
An example of such a function is:
66+
67+
```julia
68+
cache = [0.0]
69+
function f(x)
70+
cache[1] = x[1]^2
71+
cache[1] + x[1]
72+
end
73+
```
74+
75+
In this case, the enclosed cache is a function of the differentiated input, and thus its values are non-constant with respect to the input.
76+
Thus, in order to compute the correct derivative of the output, the derivative must propagate through the `cache` value, and said `cache` must not be treated as constant.
77+
78+
Conversely, the following function can treat `parameter` as a constant, because `parameter` is never modified based on the input `x`:
79+
80+
```julia
81+
parameter = [0.0]
82+
function f(x)
83+
parameter[1] + x[1]
84+
end
85+
```
86+
87+
In this case, `constant_function = true` would allow the chosen differentiation system to perform extra memory and compute optimizations, under the assumption that `parameter` is kept constant.
5888
"""
59-
Base.@kwdef struct AutoEnzyme{M} <: AbstractADType
60-
mode::M = nothing
89+
struct AutoEnzyme{M, constant_function} <: AbstractADType
90+
mode::M
91+
end
92+
93+
function AutoEnzyme(mode::M; constant_function::Bool = false) where {M}
94+
return AutoEnzyme{M, constant_function}(mode)
95+
end
96+
97+
function AutoEnzyme(; mode::M = nothing, constant_function::Bool = false) where {M}
98+
return AutoEnzyme{M, constant_function}(mode)
6199
end
62100

63101
mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension

test/dense.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,25 @@ end
2828
@testset "AutoEnzyme" begin
2929
ad = AutoEnzyme()
3030
@test ad isa AbstractADType
31-
@test ad isa AutoEnzyme{Nothing}
31+
@test ad isa AutoEnzyme{Nothing, false}
3232
@test mode(ad) isa ForwardOrReverseMode
3333
@test ad.mode === nothing
3434

35+
ad = AutoEnzyme(EnzymeCore.Forward; constant_function = true)
36+
@test ad isa AbstractADType
37+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), true}
38+
@test mode(ad) isa ForwardMode
39+
@test ad.mode == EnzymeCore.Forward
40+
3541
ad = AutoEnzyme(; mode = EnzymeCore.Forward)
3642
@test ad isa AbstractADType
37-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)}
43+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), false}
3844
@test mode(ad) isa ForwardMode
3945
@test ad.mode == EnzymeCore.Forward
4046

41-
ad = AutoEnzyme(; mode = EnzymeCore.Reverse)
47+
ad = AutoEnzyme(; mode = EnzymeCore.Reverse, constant_function = true)
4248
@test ad isa AbstractADType
43-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)}
49+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), true}
4450
@test mode(ad) isa ReverseMode
4551
@test ad.mode == EnzymeCore.Reverse
4652
end

test/misc.jl

-5
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@ end
2121
@test length(string(sparse_backend1)) < length(string(sparse_backend2))
2222
end
2323

24-
import ADTypes
25-
26-
struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end
27-
struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end
28-
2924
for backend in [
3025
# dense
3126
ADTypes.AutoChainRules(; ruleconfig = :rc),

0 commit comments

Comments
 (0)