Skip to content

Commit

Permalink
New stack design (#59)
Browse files Browse the repository at this point in the history
* sve

* new stack

* fit new compiler

* update version and examples
  • Loading branch information
GiggleLiu authored May 11, 2021
1 parent aba8b80 commit 19b96b2
Show file tree
Hide file tree
Showing 26 changed files with 222 additions and 258 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NiLang"
uuid = "ab4ef3a6-0b42-11ea-31f6-e34652774712"
authors = ["JinGuo Liu", "thautwarm"]
version = "0.8.5"
version = "0.9.0"

[deps]
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
Expand All @@ -10,14 +10,15 @@ LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
MatchCore = "5dd3f0b1-72a9-48ad-ae6e-79f673da005f"
NiLangCore = "575d3204-02a4-11ea-3f62-238caa8bf11e"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[compat]
FixedPointNumbers = "0.6, 0.7, 0.8"
LogarithmicNumbers = "0.4"
MatchCore = "0.1"
NiLangCore = "0.9.1"
NiLangCore = "0.10.1"
Reexport = "0.2, 1.0"
TupleTools = "1.2"
julia = "1.3"
Expand Down
Empty file added benchmark/stack.jl
Empty file.
7 changes: 5 additions & 2 deletions examples/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end

using NiLang
@i function i_fft!(x::AbstractVector{T}) where T
@invcheckoff N length(x)
@routine @invcheckoff N length(x)
@safe @assert N%2 == 0
@invcheckoff @inbounds if N <= 1
elseif N == 2
Expand All @@ -57,12 +57,15 @@ using NiLang
end
# combine
for i=1:N÷2
θ -2*π*(i-1)/N
@routine θ -2*π*(i-1)/N
ROT(x[i+N÷2].re, x[i+N÷2].im, θ)
HADAMARD(x[i].re, x[i+N÷2].re)
HADAMARD(x[i].im, x[i+N÷2].im)
~@routine
end
x2 zeros(T, N)
end
~@routine
end

using Test, FFTW
Expand Down
5 changes: 3 additions & 2 deletions examples/fib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
using NiLang

@i function rfib(out!, n::T) where T
n1 zero(T)
n2 zero(T)
@routine begin
n1 zero(T)
n2 zero(T)
n1 += n - 1
n2 += n - 2
end
Expand All @@ -31,6 +31,7 @@ end
rfib(out, n!)
end
~rfib(out, n!)
out 0
end

# In this example, the postcondition `n!=0` in the `while` statement is false before entering the loop, and it becomes true in later iterations. In the reverse program, the `while` statement stops at `n==0`.
Expand Down
1 change: 1 addition & 0 deletions examples/lax_wendroff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ using NiLang
cache[nx+1,j] += q[2]
SWAP(q[nx], cache[nx+1,j])
end
nx length(q)
end
nt = 2000
i_lax_wendroff!(nt, 1.0, q_init, zero(q_init), zeros(length(q_init)+1,nt))
1 change: 1 addition & 0 deletions examples/nice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ const NiceNetwork{T} = Vector{NiceLayer{T}}
else
@inbounds nice_layer!(x! |> subarray(1:np÷2), network[i], x! |> subarray(np÷2+1:np))
end
np length(x!)
end
end

Expand Down
13 changes: 8 additions & 5 deletions examples/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

# ## Functions used in this example

using NiLang, NiLang.AD
using NiLang, NiLang.AD, Test

# ## The QR decomposition
# Let us consider a naive implementation of QR decomposition from scratch.
# This implementation is just a proof of principle which does not consider reorthogonalization and other practical issues.

@i function qr(Q, R, A::Matrix{T}) where T
anc_norm zero(T)
anc_dot zeros(T, size(A,2))
ri zeros(T, size(A,1))
@routine begin
anc_norm zero(T)
anc_dot zeros(T, size(A,2))
ri zeros(T, size(A,1))
end
for col = 1:size(A, 1)
ri .+= A[:,col]
for precol = 1:col-1
Expand Down Expand Up @@ -41,6 +43,7 @@ using NiLang, NiLang.AD
i_norm2(anc_norm, ri)
end
end
~@routine
end

# Here, in order to avoid frequent uncomputing, we allocate ancillas `ri` and `anc_dot` as vectors.
Expand All @@ -55,7 +58,7 @@ q, r = zero(A), zero(A)
i_sum(out, q)
end

check_grad(test1, (0.0, q, r, A); iloss=1)
@test check_grad(test1, (0.0, q, r, A); iloss=1)

# Here, the loss function `test1` is defined as the sum of the output unitary matrix `q`.
# The `check_grad` function is a gradient checker function defined in module `NiLang.AD`.
2 changes: 1 addition & 1 deletion src/NiLang.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Reexport
import NiLangCore: invtype

using FixedPointNumbers: Q20f43, Fixed
import NiLangCore: empty_global_stacks!, loaddata
export Fixed43
const Fixed43 = Q20f43

Expand All @@ -13,7 +14,6 @@ include("wrappers.jl")
include("vars.jl")
include("instructs.jl")
include("ulog.jl")
include("stack.jl")
include("complex.jl")
include("autobcast.jl")
include("macros.jl")
Expand Down
3 changes: 2 additions & 1 deletion src/autodiff/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ using NiLangCore
using MatchCore, TupleTools

import ..NiLang: ROT, IROT, SWAP,
chfield, value, NoGrad, loaddata, INC, DEC, HADAMARD,
chfield, value, NoGrad, INC, DEC, HADAMARD,
AddConst, SubConst, NEG, INV
using NiLangCore: default_constructor

export GVar, grad, Loss, NoGrad, @nograd

include("vars.jl")
include("stack.jl")
include("gradfunc.jl")
include("checks.jl")

Expand Down
15 changes: 15 additions & 0 deletions src/autodiff/stack.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This is a patch for loading a data to GVar correctly.
import NiLangCore

NiLangCore.loaddata(::Type{GT}, x::T) where {T, GT<:GVar{T}} = convert(GT, x)
function NiLangCore.loaddata(t::Type{VT}, x::AbstractVector) where {T, VT<:AbstractVector{T}}
convert.(T, x)
end

function NiLangCore.loaddata(t::VT, x::AbstractVector) where {T, VT<:AbstractVector{T}}
convert(VT, NiLangCore.loaddata.(t, x))
end

function NiLangCore.loaddata(::Type{T}, x::XT) where {N, T<:Tuple{N}, XT<:Tuple{N}}
ntuple(i=>NiLangCore.loaddata.(T.parameters[i], [i]), N)
end
19 changes: 1 addition & 18 deletions src/autodiff/vars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ chfield(x::GVar, ::typeof(grad), g::GVar) = GVar(x.x, g)
chfield(x::Complex{<:GVar}, ::typeof(grad), g::Complex) = Complex(GVar(value(x.re), g.re), GVar(value(x.im), g.im))

# NOTE: superwarning: check value only to make ancilla gradient descardable.
NiLangCore.deanc(x::GVar, val::GVar) = NiLangCore.deanc(value(x), value(val))
NiLangCore.deanc(x::GVar{T}, val::GVar{T}) where T = NiLangCore.deanc(value(x), value(val))
function deanc(x::T, val::T) where {T<:AbstractArray}
x === val || deanc.(x, val)
end
Expand Down Expand Up @@ -186,19 +186,6 @@ macro nograd(ex)
end
end

# load data from stack
function loaddata(::Type{TG}, x::T) where {T,TG<:GVar{T}}
TG(x)
end

function loaddata(::Type{T}, x::T) where T <: GVar
x
end

function loaddata(::Type{AGT}, x::AT) where {T, GT, AT<:AbstractArray{T}, AGT<:AbstractArray{GVar{T,T}}}
map(x->GVar(x, zero(x)), x)
end

# ULogarithmic
_content(x::ULogarithmic) = x.log
NiLang.AD.GVar(x::ULogarithmic) = exp(ULogarithmic, GVar(_content(x), zero(_content(x))))
Expand All @@ -209,10 +196,6 @@ Base.one(::Type{ULogarithmic{GVar{T,GT}}}) where {T,GT} = exp(ULogarithmic, GVar
Base.zero(x::ULogarithmic{GVar{T,GT}}) where {T,GT} =zero(ULogarithmic{GVar{T,GT}})
Base.zero(::Type{ULogarithmic{GVar{T,T}}}) where T = exp(ULogarithmic, GVar(zero(T), zero(T)))

function NiLang.loaddata(::Type{Array{<:ULogarithmic{GVar{T,T}}}}, data::Array{<:ULogarithmic{T}}) where {T}
GVar.(data)
end

# the patch for dicts
function GVar(d::Dict)
Dict([(k=>GVar(v)) for (k, v) in d])
Expand Down
37 changes: 37 additions & 0 deletions src/instructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export SWAP, FLIP
export ROT, IROT
export INC, DEC, NEG, INV, AddConst, SubConst
export HADAMARD
export PUSH!, POP!, COPYPOP!, COPYPUSH!

"""
NoGrad{T} <: IWrapper{T}
Expand Down Expand Up @@ -180,3 +181,39 @@ Base.:~(ac::SubConst) = AddConst(ac.x)
for F in [:INV, :NEG, :FLIP, :INC, :DEC]
@eval NiLangCore.chfield(x::T, ::typeof($F), xval::T) where T<:Real = (~$F)(xval)
end

#### The following functions are not safe!
@i @inline function PUSH!(x::T) where T
PUSH!((@skip! GLOBAL_STACK), x)
end

@i @inline function POP!(x::T) where T
POP!((@skip! GLOBAL_STACK), x)
end

@i @inline function COPYPUSH!(x)
COPYPUSH!((@skip! GLOBAL_STACK), x)
end

@i @inline function COPYPOP!(x)
COPYPOP!((@skip! GLOBAL_STACK), x)
end

# reversibility turned off, in principle, we can not deallocate `GVar{T}` to `T`
@i @inline function PUSH!(st, x::T) where T
@invcheckoff st[end+1] x
@invcheckoff x _zero(T)
end

@i @inline function POP!(st, x::T) where T
@invcheckoff x _zero(T)
@invcheckoff st[end] (x::T)::∅
end

@i @inline function COPYPUSH!(st, x)
@invcheckoff st[end+1] x
end

@i @inline function COPYPOP!(st, x)
@invcheckoff st[end] x
end
78 changes: 0 additions & 78 deletions src/stack.jl

This file was deleted.

15 changes: 10 additions & 5 deletions src/stdlib/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ end
@i function i_mul!(out!::AbstractVector{T}, x::AbstractMatrix, y::AbstractVector) where T
@safe size(x, 2) == size(y, 1) || throw(DimensionMismatch())
@invcheckoff @inbounds for j=1:size(x,2)
yj zero(T)
yj += y[j]
@routine begin
yj zero(T)
yj += y[j]
end
for i=1:size(x,1)
out![i] += x[i,j] * yj
end
yj -= y[j]
~@routine
end
end

Expand Down Expand Up @@ -81,8 +83,10 @@ end
Compute unitary matrix multiplication on `x`, where the unitary matrix is parameterized by (N+1)*N/2 `θ`s.
"""
@i function i_umm!(x!::AbstractArray, θ)
M size(x!, 1)
N size(x!, 2)
@routine begin
M size(x!, 1)
N size(x!, 2)
end
k 0
@safe @assert length(θ) == M*(M-1)/2
for l = 1:N
Expand All @@ -95,4 +99,5 @@ Compute unitary matrix multiplication on `x`, where the unitary matrix is parame
end

k length(θ)
~@routine
end
1 change: 1 addition & 0 deletions src/stdlib/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Reversible `mapfoldl` function, `map` can be irreversible, but `fold` should be
fold(out!, anc)
anc -= map(iter[i])
end
anc zero(T)
end

"""
Expand Down
Loading

0 comments on commit 19b96b2

Please sign in to comment.