Commit 9627bd6

Update documentation (#681)
* Update dependencies
* Update fenced code block examples
* Fix indents
* Add headers in docstrings
* Add backticks
* Fix admonition blocks
* Fix more doctests
1 parent a95c181 commit 9627bd6

18 files changed: +631 -502 lines changed

docs/Manifest.toml

+469-382
Large diffs are not rendered by default.

docs/src/FAQ.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ For example, in `access(xs, n) = xs[n]`, the derivative of `access` with respect
6363
When no custom `frule` or `rrule` exists, if you try to call one of those, it will return `nothing` by default.
6464
As a result, you may encounter errors like
6565

66-
```julia
66+
```plain
6767
MethodError: no method matching iterate(::Nothing)
6868
```
6969

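Note on the hunk above: the `MethodError` comes from destructuring a `nothing` return value into a `(y, pullback)` pair. The following is a minimal sketch of that failure mode. It assumes a hypothetical stand-in `my_rrule` that mimics the "return `nothing` when no rule exists" fallback described in the FAQ text; it is not the ChainRulesCore implementation.

```julia
# Hypothetical stand-in (assumption): with no custom rule defined,
# the rule lookup returns `nothing`, as the FAQ text describes.
my_rrule(f, args...) = nothing

# Try to destructure the result the way AD code usually does,
# returning the exception (if any) instead of throwing it.
function try_destructure(res)
    try
        y, pullback = res   # iterates `res`; fails when `res === nothing`
        return nothing
    catch e
        return e
    end
end

err = try_destructure(my_rrule(sin, 1.0))
println(err isa MethodError)

# Checking for `nothing` before destructuring avoids the error.
res = my_rrule(sin, 1.0)
if res === nothing
    println("no rule defined; use another AD path")
end
```

The `res === nothing` check mirrors the pattern shown in the `opt_out.md` hunk that follows.
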
docs/src/ad_author/opt_out.md

+1-1
@@ -7,7 +7,7 @@ We provide two ways to know that a rule has been opted out of.
77
`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`.
88

99
If you are in a position to generate code, in response to values returned by function calls then you can do something like:
10-
```@julia
10+
```julia
1111
res = rrule(f, xs)
1212
if res === nothing
1313
y, pullback = perform_ad_via_decomposition(r, xs) # do AD without hitting the rrule

docs/src/design/changing_the_primal.md

+37-37
@@ -63,7 +63,7 @@ What about using `sincos`?
6363
```@raw html
6464
<details open><summary>Example for `sin`</summary>
6565
```
66-
```julia
66+
```julia-repl
6767
julia> using BenchmarkTools
6868
6969
julia> @btime sin(x) setup=(x=rand());
@@ -76,7 +76,7 @@ julia> 3.838 + 4.795
7676
8.633
7777
```
7878
vs computing both together:
79-
```julia
79+
```julia-repl
8080
julia> @btime sincos(x) setup=(x=rand());
8181
6.028 ns (0 allocations: 0 bytes)
8282
```
@@ -96,7 +96,7 @@ So we can save time, if we can reuse that `exp(x)`.
9696
<details open><summary>Example for the logistic sigmoid</summary>
9797
```
9898
If we have to computing separately:
99-
```julia
99+
```julia-repl
100100
julia> @btime 1/(1+exp(x)) setup=(x=rand());
101101
5.622 ns (0 allocations: 0 bytes)
102102
@@ -108,7 +108,7 @@ julia> 5.622 + 6.036
108108
```
109109

110110
vs reusing `exp(x)`:
111-
```julia
111+
```julia-repl
112112
julia> @btime exp(x) setup=(x=rand());
113113
5.367 ns (0 allocations: 0 bytes)
114114
@@ -148,8 +148,8 @@ x̄ = pullback_at(f, x, y, ȳ, intermediates)
148148
```
149149
```julia
150150
function augmented_primal(::typeof(sin), x)
151-
y, cx = sincos(x)
152-
return y, (; cx=cx) # use a NamedTuple for the intermediates
151+
y, cx = sincos(x)
152+
return y, (; cx=cx) # use a NamedTuple for the intermediates
153153
end
154154

155155
pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
@@ -163,9 +163,9 @@ pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
163163
```
164164
```julia
165165
function augmented_primal(::typeof(σ), x)
166-
ex = exp(x)
167-
y = ex / (1 + ex)
168-
return y, (; ex=ex) # use a NamedTuple for the intermediates
166+
ex = exp(x)
167+
y = ex / (1 + ex)
168+
return y, (; ex=ex) # use a NamedTuple for the intermediates
169169
end
170170

171171
pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex)
@@ -189,8 +189,8 @@ And storing all these things on the tape — inputs, outputs, sensitivities, int
189189
What if we generalized the idea of the `intermediate` named tuple, and had `augmented_primal` return a struct that just held anything we might want put on the tape.
190190
```julia
191191
struct PullbackMemory{P, S}
192-
primal_function::P
193-
state::S
192+
primal_function::P
193+
state::S
194194
end
195195
# convenience constructor:
196196
PullbackMemory(primal_function; state...) = PullbackMemory(primal_function, state)
@@ -211,8 +211,8 @@ which is much cleaner.
211211
```
212212
```julia
213213
function augmented_primal(::typeof(sin), x)
214-
y, cx = sincos(x)
215-
return y, PullbackMemory(sin; cx=cx)
214+
y, cx = sincos(x)
215+
return y, PullbackMemory(sin; cx=cx)
216216
end
217217

218218
pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
@@ -226,9 +226,9 @@ pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
226226
```
227227
```julia
228228
function augmented_primal(::typeof(σ), x)
229-
ex = exp(x)
230-
y = ex / (1 + ex)
231-
return y, PullbackMemory(σ; y=y, ex=ex)
229+
ex = exp(x)
230+
y = ex / (1 + ex)
231+
return y, PullbackMemory(σ; y=y, ex=ex)
232232
end
233233

234234
pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex)
@@ -256,8 +256,8 @@ x̄ = pb(ȳ)
256256
```
257257
```julia
258258
function augmented_primal(::typeof(sin), x)
259-
y, cx = sincos(x)
260-
return y, PullbackMemory(sin; cx=cx)
259+
y, cx = sincos(x)
260+
return y, PullbackMemory(sin; cx=cx)
261261
end
262262
(pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx
263263
```
@@ -271,9 +271,9 @@ end
271271
```
272272
```julia
273273
function augmented_primal(::typeof(σ), x)
274-
ex = exp(x)
275-
y = ex / (1 + ex)
276-
return y, PullbackMemory(σ; y=y, ex=ex)
274+
ex = exp(x)
275+
y = ex / (1 + ex)
276+
return y, PullbackMemory(σ; y=y, ex=ex)
277277
end
278278

279279
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex)
@@ -295,16 +295,16 @@ Let's go back and think about the changes we would have make to go from our orig
295295
To rewrite that original formulation in the new pullback form we have:
296296
```julia
297297
function augmented_primal(::typeof(sin), x)
298-
y = sin(x)
299-
return y, PullbackMemory(sin; x=x)
298+
y = sin(x)
299+
return y, PullbackMemory(sin; x=x)
300300
end
301301
(pb::PullbackMemory)(ȳ) = ȳ * cos(pb.x)
302302
```
303303
To go from that to:
304304
```julia
305305
function augmented_primal(::typeof(sin), x)
306-
y, cx = sincos(x)
307-
return y, PullbackMemory(sin; cx=cx)
306+
y, cx = sincos(x)
307+
return y, PullbackMemory(sin; cx=cx)
308308
end
309309
(pb::PullbackMemory)(ȳ) = ȳ * pb.cx
310310
```
@@ -317,17 +317,17 @@ end
317317
```
318318
```julia
319319
function augmented_primal(::typeof(σ), x)
320-
y = σ(x)
321-
return y, PullbackMemory(σ; y=y, x=x)
320+
y = σ(x)
321+
return y, PullbackMemory(σ; y=y, x=x)
322322
end
323323
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y * σ(-pb.x)
324324
```
325325
to get to:
326326
```julia
327327
function augmented_primal(::typeof(σ), x)
328-
ex = exp(x)
329-
y = ex/(1 + ex)
330-
return y, PullbackMemory(σ; y=y, ex=ex)
328+
ex = exp(x)
329+
y = ex/(1 + ex)
330+
return y, PullbackMemory(σ; y=y, ex=ex)
331331
end
332332
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex)
333333
```
@@ -356,9 +356,9 @@ Replacing `PullbackMemory` with a closure that works the same way lets us avoid
356356
```
357357
```julia
358358
function augmented_primal(::typeof(sin), x)
359-
y, cx = sincos(x)
360-
pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
361-
return y, pb
359+
y, cx = sincos(x)
360+
pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
361+
return y, pb
362362
end
363363
```
364364
```@raw html
@@ -370,10 +370,10 @@ end
370370
```
371371
```julia
372372
function augmented_primal(::typeof(σ), x)
373-
ex = exp(x)
374-
y = ex / (1 + ex)
375-
pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
376-
return y, pb
373+
ex = exp(x)
374+
y = ex / (1 + ex)
375+
pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
376+
return y, pb
377377
end
378378
```
379379
```@raw html

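Note on the `changing_the_primal.md` hunks: the re-indented closure-based examples can be checked in isolation. The sketch below assembles the final closure form of `augmented_primal` for `sin` and the logistic sigmoid from the hunks above. The definition of `σ` and the evaluation point `x = 0.3` are assumptions added for illustration; they do not appear in the diff.

```julia
# Assumption: σ is the logistic sigmoid; the diff uses it without showing a definition.
σ(x) = 1 / (1 + exp(-x))

function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    pb = ȳ -> cx * ȳ  # pullback closure, closes over `cx`
    return y, pb
end

function augmented_primal(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    pb = ȳ -> ȳ * y / (1 + ex)  # pullback closure, closes over `y` and `ex`
    return y, pb
end

x = 0.3  # arbitrary evaluation point

# sin: the pullback of a unit cotangent should equal cos(x)
y_sin, pb_sin = augmented_primal(sin, x)
println(pb_sin(1.0) ≈ cos(x))

# σ: the pullback of a unit cotangent should equal σ(x) * (1 - σ(x))
y_σ, pb_σ = augmented_primal(σ, x)
println(pb_σ(1.0) ≈ σ(x) * (1 - σ(x)))
```

The σ pullback works because `1 - σ(x) == 1 / (1 + exp(x))`, so `y / (1 + ex)` is the usual derivative `σ(x) * (1 - σ(x))` computed without a second call to `exp`.
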
docs/src/design/many_tangents.md

+3-3
@@ -45,7 +45,7 @@ Structural tangents are derived from the structure of the input.
4545
Either automatically, as part of the AD, or manually, as part of a custom rule.
4646

4747
Consider the structure of `DateTime`:
48-
```julia
48+
```julia-repl
4949
julia> dump(now())
5050
DateTime
5151
instant: UTInstant{Millisecond}
@@ -83,15 +83,15 @@ Where there is no natural tangent type for the outermost type but there is for s
8383

8484
Consider if we had a representation of a country's GDP as output by some continuous time model like a Gaussian Process, where that representation is as a sequence of `TimeSample`s
8585
structured as follows:
86-
```julia
86+
```julia-repl
8787
julia> struct TimeSample
8888
time::DateTime
8989
value::Float64
9090
end
9191
```
9292

9393
We can look at its structure:
94-
```julia
94+
```julia-repl
9595
julia> dump(TimeSample(now(), 2.6e9))
9696
TimeSample
9797
time: DateTime

docs/src/index.md

+1
@@ -191,6 +191,7 @@ end
191191
# output
192192
193193
```
194+
194195
```jldoctest index
195196
#### Find dfoo/dx via rrules
196197
#### First the forward pass, gathering up the pullbacks

docs/src/rule_author/converting_zygoterules.md

+1-1
@@ -1,4 +1,4 @@
1-
# Converting ZygoteRules.@adjoint to `rrule`s
1+
# Converting `ZygoteRules.@adjoint` to `rrule`s
22

33
[ZygoteRules.jl](https://github.com/FluxML/ZygoteRules.jl) is a legacy package similar to ChainRulesCore but supporting [Zygote.jl](https://github.com/FluxML/Zygote.jl) only.
44

docs/src/rule_author/example.md

+2-1
@@ -38,8 +38,9 @@ end
3838
```
3939

4040
We can check this rule against a finite-differences approach using [`ChainRulesTestUtils`](https://github.com/JuliaDiff/ChainRulesTestUtils.jl):
41-
```julia
41+
```julia-repl
4242
julia> using ChainRulesTestUtils
43+
4344
julia> test_rrule(foo_mul, Foo(rand(3, 3), 3.0), rand(3, 3))
4445
Test Summary: | Pass Total
4546
test_rrule: foo_mul on Foo{Float64},Matrix{Float64} | 10 10

docs/src/rule_author/which_functions_need_rules.md

+13-10
@@ -34,8 +34,9 @@ function addone(a::AbstractArray)
3434
end
3535
```
3636
complains that
37-
```julia
37+
```julia-repl
3838
julia> using Zygote
39+
3940
julia> gradient(addone, a)
4041
ERROR: Mutating arrays is not supported
4142
```
@@ -50,7 +51,7 @@ function ChainRules.rrule(::typeof(addone), a)
5051
end
5152
```
5253
the gradient can be evaluated:
53-
```julia
54+
```julia-repl
5455
julia> gradient(addone, a)
5556
([1.0, 1.0, 1.0],)
5657
```
@@ -86,7 +87,7 @@ function exception(x)
8687
end
8788
```
8889
does not work
89-
```julia
90+
```julia-repl
9091
julia> gradient(exception, 3.0)
9192
ERROR: Compiling Tuple{typeof(exception),Int64}: try/catch is not supported.
9293
```
@@ -101,7 +102,7 @@ function ChainRulesCore.rrule(::typeof(exception), x)
101102
end
102103
```
103104

104-
```julia
105+
```julia-repl
105106
julia> gradient(exception, 3.0)
106107
(6.0,)
107108
```
@@ -123,9 +124,11 @@ function mse(y, ŷ)
123124
end
124125
```
125126
takes a lot longer to AD through
126-
```julia
127-
julia> y = rand(30)
128-
julia> ŷ = rand(30)
127+
```julia-repl
128+
julia> y = rand(30);
129+
130+
julia> ŷ = rand(30);
131+
129132
julia> @btime gradient(mse, $y, $ŷ)
130133
38.180 μs (993 allocations: 65.00 KiB)
131134
```
@@ -142,7 +145,7 @@ function ChainRules.rrule(::typeof(mse), x, x̂)
142145
end
143146
```
144147
which is much faster
145-
```julia
148+
```julia-repl
146149
julia> @btime gradient(mse, $y, $ŷ)
147150
143.697 ns (2 allocations: 672 bytes)
148151
```
@@ -159,7 +162,7 @@ function sum3(array)
159162
return x+y+z
160163
end
161164
```
162-
```julia
165+
```julia-repl
163166
julia> @btime gradient(sum3, rand(30))
164167
424.510 ns (9 allocations: 2.06 KiB)
165168
```
@@ -176,7 +179,7 @@ function ChainRulesCore.rrule(::typeof(sum3), a)
176179
end
177180
```
178181
turns out to be significantly faster
179-
```julia
182+
```julia-repl
180183
julia> @btime gradient(sum3, rand(30))
181184
192.818 ns (3 allocations: 784 bytes)
182185
```

docs/src/rule_author/writing_good_rules.md

+3-3
@@ -110,7 +110,7 @@ Because `typeof(Bar)` is `DataType`, using this to define an `rrule`/`frule` wil
110110

111111
You can check which to use with `Core.Typeof`:
112112

113-
```julia
113+
```julia-repl
114114
julia> function foo end
115115
foo (generic function with 0 methods)
116116
@@ -254,7 +254,7 @@ function ChainRulesCore.rrule(::typeof(double_it), x)
254254
end
255255
```
256256
Ends up infering a return type of `Any`
257-
```julia
257+
```julia-repl
258258
julia> _, pullback = rrule(double_it, [2.0, 3.0])
259259
([4.0, 6.0], var"#double_it_pullback#8"(Core.Box(var"#double_it_pullback#8"(#= circular reference @-2 =#))))
260260
@@ -289,7 +289,7 @@ function ChainRulesCore.rrule(::typeof(double_it), x)
289289
end
290290
```
291291
This infers just fine:
292-
```julia
292+
```julia-repl
293293
julia> _, pullback = rrule(double_it, [2.0, 3.0])
294294
([4.0, 6.0], _double_it_pullback)
295295
