Skip to content

Commit 024ac94

Browse files
Merge pull request #73 from SciML/gd/fix_enzyme_doc
Fix AutoEnzyme docstring for constant_function
2 parents 39da305 + 6e3a00e commit 024ac94

File tree

2 files changed

+25
-15
lines changed

2 files changed

+25
-15
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.6.0"
6+
version = "1.6.1"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/dense.jl

+24-14
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
5050
AutoEnzyme(; mode=nothing, constant_function::Bool=false)
5151
5252
The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl.
53-
For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data that can be treated as constant, `constant_function` should be set to `true` for increased performance (more details below).
53+
For simple functions, `constant_function` should usually be set to `true`, which leads to increased performance.
54+
However, in the case of closures or callable structs which contain differentiated data, `constant_function` should be set to `false` to ensure correctness (more details below).
5455
5556
# Fields
5657
@@ -61,30 +62,39 @@ For simple functions, `constant_function` should usually be set to `false`, but
6162
6263
# Notes
6364
64-
If `constant_function = true` but the enclosed data is not truly constant, then Enzyme.jl will not compute the correct derivative values.
65-
An example of such a function is:
65+
We now give several examples of functions.
66+
For each one, we explain how `constant_function` should be set in order to compute the correct derivative with respect to the input `x`.
6667
6768
```julia
68-
cache = [0.0]
69-
function f(x)
70-
cache[1] = x[1]^2
71-
cache[1] + x[1]
69+
function f1(x)
70+
return x[1]
7271
end
7372
```
7473
75-
In this case, the enclosed cache is a function of the differentiated input, and thus its values are non-constant with respect to the input.
76-
Thus, in order to compute the correct derivative of the output, the derivative must propagate through the `cache` value, and said `cache` must not be treated as constant.
77-
78-
Conversely, the following function can treat `parameter` as a constant, because `parameter` is never modified based on the input `x`:
74+
The function `f1` is not a closure, it does not contain any data.
75+
Thus `f1` can be differentiated with `AutoEnzyme(constant_function=true)` (although here setting `constant_function=false` would change neither correctness nor performance).
7976
8077
```julia
8178
parameter = [0.0]
82-
function f(x)
83-
parameter[1] + x[1]
79+
function f2(x)
80+
return parameter[1] + x[1]
81+
end
82+
```
83+
84+
The function `f2` is a closure over `parameter`, but `parameter` is never modified based on the input `x`.
85+
Thus, `f2` can be differentiated with `AutoEnzyme(constant_function=true)` (setting `constant_function=false` would not change correctness but would hinder performance).
86+
87+
```julia
88+
cache = [0.0]
89+
function f3(x)
90+
cache[1] = x[1]
91+
return cache[1] + x[1]
8492
end
8593
```
8694
87-
In this case, `constant_function = true` would allow the chosen differentiation system to perform extra memory and compute optimizations, under the assumption that `parameter` is kept constant.
95+
The function `f3` is a closure over `cache`, and `cache` is modified based on the input `x`.
96+
That means `cache` cannot be treated as constant, since derivative values must be propagated through it.
97+
Thus `f3` must be differentiated with `AutoEnzyme(constant_function=false)` (setting `constant_function=true` would make the result incorrect).
8898
"""
8999
struct AutoEnzyme{M, constant_function} <: AbstractADType
90100
mode::M

0 commit comments

Comments
 (0)