Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] more contents to tutorial #68

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 152 additions & 3 deletions docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,43 @@ end
end
```

Functions do not have return statements, they return input arguments instead.
First let's explain the codes a bit. `@i` macro is the core of NiLang; all reversible
functions in NiLang has this decorator. By adding this macro you're promising that
`r_axpy!` is a reversible function. In other words, the composition of `r_axpy!` and
its inverse `~r_axpy!` becomes an identity map, i.e., `(~r_axpy!)(r_axpy!(args...)...) ≈ args`
for all valid input `args`.

Some functions and variables in NiLang will ends with `!`. This is a Julia convention saying
that `r_axpy!` is an in-place function which modifies the input, and that input `y!` and `out!`
will be modified.

Perhaps surprisingly, you can't write an explicit `return` in `@i`. This is because NiLang's
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn’t this because NiLang returns input arguments automatically?

reversible programming requires a complete "compute–copy–uncompute" paradigm. NiLang is sometimes
smart enough to infer the copy and uncompute stage so you won't need to manually write it. A
complete version of `r_axpy!` is:

```julia
@i function r_axpy!(a::T, x::AbstractVector{T}, y!::AbstractVector{T}) where T
@safe @assert length(x) == length(y!)
# compute
@routine begin
for i=1:length(x)
y![i] += a * x[i]
end
end

# no copy operation here

# uncompute
~@routine

# `@i` forces returning all input variables as outputs, i.e.,
# return a, x, y!
# and you can't override this
end
```

Functions do not have return statements, they return all input arguments instead.
Hence `r_loss` defines a 5 variable to 5 variable bijection.
Let's check the reversibility
```julia
Expand All @@ -63,7 +99,7 @@ julia> out, a, x, y, z = r_loss(out, a, x, y, z)
```

We find the contents in `out` and `y` are changed after calling the loss function.
Then we call the inverse loss function `~r_loss`.
If we call the inverse loss function `~r_loss`, then values are restored:

```julia
julia> out, a, x, y, z = (~r_loss)(out, a, x, y, z)
Expand All @@ -72,13 +108,39 @@ julia> out, a, x, y, z = (~r_loss)(out, a, x, y, z)
[1.1603953198942412, 0.5562855137395296, 1.9650050430758796])
```

Values are restored. Here, instead of assigning variables one by one,
Here, instead of assigning variables one by one,
one can also use the macro `@instr`
```julia
@instr r_loss(out, a, x, y, z)
```
`@instr` macro is for executing a reversible statement.

Let's go back a bit and restate the "compute-copy-uncompute" paradigm. It is sometimes
tedious and cumbersome to write an uncompute manually. NiLang introduces `@routine` macro
to record all operations, and automatically reverse these operations with `~@routine`.
For this reason, each compute stage `@routine` should always have a corresponding
uncompute stage `~@routine`:

!!! tip
You can intuitively take `@routine` as a special 0-argument function that only lives
in the `@i` scope, and `~@routine` is its inverse.

```julia
julia> @i function r_axpy!(a::T, x::AbstractVector{T}, y!::AbstractVector{T}) where T
@safe @assert length(x) == length(y!)
# compute
@routine begin
for i=1:length(x)
y![i] += a * x[i]
end
end
# no corresponding uncompute stage
end
ERROR: LoadError: `@routine` and `~@routine` must appear in pairs, mising `~@routine`!
Stacktrace:
[...]
```

## My first reversible AD program

```julia
Expand Down Expand Up @@ -129,3 +191,90 @@ julia> grad(gz)
0.5418352557049606
0.6004325146280002
```

## Writing irreversible function in a reversible way

Not all functions are reversible: operations that erases the compute history is not reversible
operations. For example, in general:

- `x *= 0` reset `x` to be zero and thus `*` (and `/`) are not a reversible operations;
- shared read/write operation `y += f(y)` clears the old status of `y` and is not reversible

But in practice, you can always rewrite them in a reversible way.

The first trick is to _use extra bits to record the intermediate results_:

```julia
@i function i_wsqeuclidean(out!::T, X::AbstractArray{T}, Y::AbstractArray{T}, W::AbstractArray{T}) where T
@safe @assert size(W) == size(X) == size(Y)
for i = 1:length(X)
# compute stage
@routine begin
@zeros T d d2
# All intermediate results need to be recorded in the compute
# stage so that they can be successfully uncomputed.
d += X[i] - Y[i]
d2 += abs2(d)
end

# copy
out! += d2 * W[i]

# uncompute stage reverses the computation to its initial status,
# it also restores intermediate results `d` and `d2` to zero value
~@routine
end
end
```

Don't worry if you don't know the whole set of irreversible operations, when constructing function
with `@i`, a reversibility check will be called and throw errors when the function body is not
reversible. Hence you're always in a safe status.

```julia
julia> @i function irreversible_f(out!)
out! += abs2(out!)
end
ERROR: LoadError: InvertibilityError("1-th argument and 2-th argument shares the same memory out!, shared read and shared write are not allowed!")
```

Of course, reversibility check takes time, and the overhead and be quite significant in very tight loops,
take our previous `i_wsqeuclidean` as an example, `length(X)` reversible checks are applied here.

```julia
using Benchmark

X, Y, W = rand(5, 5), rand(5, 5), rand(5, 5)
@btime i_wsqeuclidean(0.0, X, Y, W)[1]
# 1.122 μs (2 allocations: 64 bytes)
```

The macro `@invcheckoff` can be used to disable the reversibility check for the entire block.
For example, by adding it to the for-loop block, all checks are disabled.

```julia
@i function i_wsqeuclidean(out!::T, X::AbstractArray{T}, Y::AbstractArray{T}, W::AbstractArray{T}) where T
@safe @assert size(W) == size(X) == size(Y)
@invcheckoff for i = 1:length(X)
@routine begin
@zeros T d d2
# All intermediate results need to be recorded in the compute
# stage so that they can be successfully uncomputed.
d += X[i] - Y[i]
d2 += abs2(d)
end
out! += d2 * W[i]

# uncompute stage reverses the computation to its initial status,
# it also restores intermediate results `d` and `d2` to zero value
~@routine
end
end
```

it's significantly faster now:

```julia
@btime i_wsqeuclidean(0.0, X, Y, W)[1]
# 67.269 ns (2 allocations: 64 bytes)
```