From a165d080b02a30688096f5bd677819e085f57e00 Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Fri, 4 Jun 2021 21:27:34 +0800 Subject: [PATCH] more contents to tutorial --- docs/src/tutorial.md | 155 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 152 insertions(+), 3 deletions(-) diff --git a/docs/src/tutorial.md b/docs/src/tutorial.md index 3968cc3..551d42d 100644 --- a/docs/src/tutorial.md +++ b/docs/src/tutorial.md @@ -47,7 +47,43 @@ end end ``` -Functions do not have return statements, they return input arguments instead. +First let's explain the code a bit. The `@i` macro is the core of NiLang; all reversible +functions in NiLang have this decorator. By adding this macro you're promising that +`r_axpy!` is a reversible function. In other words, the composition of `r_axpy!` and +its inverse `~r_axpy!` is an identity map, i.e., `(~r_axpy!)(r_axpy!(args...)...) ≈ args` +for all valid input `args`. + +Some functions and variables in NiLang end with `!`. This is a Julia convention saying +that `r_axpy!` is an in-place function which modifies its input, and that the inputs `y!` and `out!` +will be modified. + +Perhaps surprisingly, you can't write an explicit `return` in `@i`. This is because NiLang's +reversible programming requires a complete "compute-copy-uncompute" paradigm. NiLang is sometimes +smart enough to infer the copy and uncompute stages so you won't need to write them manually. A +complete version of `r_axpy!` is: + +```julia +@i function r_axpy!(a::T, x::AbstractVector{T}, y!::AbstractVector{T}) where T + @safe @assert length(x) == length(y!) + # compute + @routine begin + for i=1:length(x) + y![i] += a * x[i] + end + end + + # no copy operation here + + # uncompute + ~@routine + + # `@i` forces returning all input variables as outputs, i.e., + # return a, x, y! + # and you can't override this +end +``` + +Functions do not have return statements; they return all input arguments instead. 
Hence `r_loss` defines a 5 variable to 5 variable bijection. Let's check the reversibility ```julia @@ -63,7 +99,7 @@ julia> out, a, x, y, z = r_loss(out, a, x, y, z) ``` We find the contents in `out` and `y` are changed after calling the loss function. -Then we call the inverse loss function `~r_loss`. +If we call the inverse loss function `~r_loss`, then the values are restored: ```julia julia> out, a, x, y, z = (~r_loss)(out, a, x, y, z) @@ -72,13 +108,39 @@ julia> out, a, x, y, z = (~r_loss)(out, a, x, y, z) [1.1603953198942412, 0.5562855137395296, 1.9650050430758796]) ``` -Values are restored. Here, instead of assigning variables one by one, +Here, instead of assigning variables one by one, one can also use the macro `@instr` ```julia @instr r_loss(out, a, x, y, z) ``` `@instr` macro is for executing a reversible statement. +Let's go back a bit and restate the "compute-copy-uncompute" paradigm. It is sometimes +tedious and cumbersome to write the uncompute stage manually. NiLang introduces the `@routine` macro +to record all operations, and automatically reverses these operations with `~@routine`. +For this reason, each compute stage `@routine` should always have a corresponding +uncompute stage `~@routine`: + +!!! tip + You can intuitively think of `@routine` as a special 0-argument function that only lives + in the `@i` scope, with `~@routine` as its inverse. + +```julia +julia> @i function r_axpy!(a::T, x::AbstractVector{T}, y!::AbstractVector{T}) where T + @safe @assert length(x) == length(y!) + # compute + @routine begin + for i=1:length(x) + y![i] += a * x[i] + end + end + # no corresponding uncompute stage +end +ERROR: LoadError: `@routine` and `~@routine` must appear in pairs, mising `~@routine`! +Stacktrace: +[...] 
+``` + ## My first reversible AD program ```julia @@ -129,3 +191,90 @@ julia> grad(gz) 0.5418352557049606 0.6004325146280002 ``` + +## Writing irreversible functions in a reversible way + +Not all functions are reversible: operations that erase the compute history are not reversible +operations. For example, in general: + +- `x *= 0` resets `x` to zero, and thus `*` (and `/`) are not reversible operations; +- the shared read/write operation `y += f(y)` overwrites the old state of `y` and is not reversible. + +But in practice, you can always rewrite them in a reversible way. + +The first trick is to _use extra bits to record the intermediate results_: + +```julia +@i function i_wsqeuclidean(out!::T, X::AbstractArray{T}, Y::AbstractArray{T}, W::AbstractArray{T}) where T + @safe @assert size(W) == size(X) == size(Y) + for i = 1:length(X) + # compute stage + @routine begin + @zeros T d d2 + # All intermediate results need to be recorded in the compute + # stage so that they can be successfully uncomputed. + d += X[i] - Y[i] + d2 += abs2(d) + end + + # copy + out! += d2 * W[i] + + # uncompute stage reverses the computation to its initial status, + # it also restores intermediate results `d` and `d2` to zero value + ~@routine + end +end +``` + +Don't worry if you don't know the whole set of irreversible operations: when constructing a function +with `@i`, a reversibility check is run and throws an error when the function body is not +reversible. Hence you're always on the safe side. + +```julia +julia> @i function irreversible_f(out!) + out! += abs2(out!) + end +ERROR: LoadError: InvertibilityError("1-th argument and 2-th argument shares the same memory out!, shared read and shared write are not allowed!") +``` + +Of course, the reversibility check takes time, and the overhead can be quite significant in very tight loops. +Take our previous `i_wsqeuclidean` as an example: `length(X)` reversibility checks are applied here. 
+ +```julia +using BenchmarkTools + +X, Y, W = rand(5, 5), rand(5, 5), rand(5, 5) +@btime i_wsqeuclidean(0.0, X, Y, W)[1] +# 1.122 μs (2 allocations: 64 bytes) +``` + +The macro `@invcheckoff` can be used to disable the reversibility check for the entire block. +For example, by adding it to the for-loop block, all checks are disabled. + +```julia +@i function i_wsqeuclidean(out!::T, X::AbstractArray{T}, Y::AbstractArray{T}, W::AbstractArray{T}) where T + @safe @assert size(W) == size(X) == size(Y) + @invcheckoff for i = 1:length(X) + @routine begin + @zeros T d d2 + # All intermediate results need to be recorded in the compute + # stage so that they can be successfully uncomputed. + d += X[i] - Y[i] + d2 += abs2(d) + end + out! += d2 * W[i] + + # uncompute stage reverses the computation to its initial status, + # it also restores intermediate results `d` and `d2` to zero value + ~@routine + end +end +``` + +It's significantly faster now: + +```julia +@btime i_wsqeuclidean(0.0, X, Y, W)[1] +# 67.269 ns (2 allocations: 64 bytes) +```
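
Because `@invcheckoff` removes the safety net, a cheap sanity check is to compare the reversible result against a plain, irreversible implementation of the same weighted squared distance. The sketch below uses only base Julia. `wsqeuclidean_ref` is a hypothetical helper name that is not part of NiLang or of this patch, and it assumes `i_wsqeuclidean` accumulates `abs2(X[i] - Y[i]) * W[i]` exactly as in the tutorial code.

```julia
# Plain (irreversible) reference for the weighted squared Euclidean distance.
# `wsqeuclidean_ref` is a hypothetical name used only for this illustration.
function wsqeuclidean_ref(X::AbstractArray{T}, Y::AbstractArray{T}, W::AbstractArray{T}) where T
    @assert size(W) == size(X) == size(Y)
    out = zero(T)
    for i in eachindex(X, Y, W)
        # Same per-element term that the reversible version adds to `out!`.
        out += abs2(X[i] - Y[i]) * W[i]
    end
    return out
end
```

With `X`, `Y`, `W` from the benchmark, `i_wsqeuclidean(0.0, X, Y, W)[1] ≈ wsqeuclidean_ref(X, Y, W)` would be expected to hold both with and without `@invcheckoff`, under the assumption stated above.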