Skip to content

Commit

Permalink
Adding new recur and playing w/ gradient.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkschleg committed Aug 16, 2023
1 parent 1348828 commit 7a467cc
Show file tree
Hide file tree
Showing 4 changed files with 479 additions and 42 deletions.
53 changes: 53 additions & 0 deletions recur_funcs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using Flux

function run_new_recur()
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2


x_block = reshape(reduce(vcat, x), 1, 1, length(x))
nm_layer = Flux.NewRecur(cell; return_sequence = true)
_out = layer(x_block)
e, g = Flux.withgradient(nm_layer) do layer
out = layer(x_block)
sum(out[1, 1, end])
end
grads = g[1][:cell]

@show primal[1] e
@show ∇Wi grads[:Wi]
@show ∇Wh grads[:Wh]
@show ∇b grads[:b]
@show ∇state0 grads[:state0]

return
end

function run_scan_full()

x = [[2.0f0], [3.0f0], [4.0f0]]
x_block = reshape(reduce(vcat, x), 1, 1, length(x))
# nm_layer = Flux.NewRecur(cell; return_sequence = true)
w = zeros(1)
_out = Flux.scan_full((a, b)->(sum(w.*b), sum(w.*b)), 0.0f0, x_block)
e, g = Flux.withgradient(w) do layer
out = Flux.scan_full((a, b)->(sum(w.*b), sum(w.*b)), 0.0f0, x_block)
sum(out[1, 1, end])
end
grads = g[1][:cell]
return
end
197 changes: 196 additions & 1 deletion src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,187 @@ end

reshape_cell_output(h, x) = reshape(h, :, size(x)[2:end]...)



# non-stateful recurrence

"""
scan_full
Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the full output of the sequence and the final carry. See `scan_partial` to only return the final output of the sequence.
"""
function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray})
# Recurrence operation used in the fold. Takes the state of the
# fold and the next input, returns the new state.
function recurrence_op((carry, outputs), input)
carry, out = func(carry, input)
return carry, vcat(outputs, [out])
end
# Fold left to right.
return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs)
end

function scan_full(func, init_carry, x_block)
# x_block is an abstractarray and we want to scan over the last dimension.
xs_ = Flux.eachlastdim(x_block)

# this is needed due to a bug in eachlastdim which produces a vector in a
# gradient context, but a generator otherwise.
xs = if xs_ isa Base.Generator
collect(xs_) # eachlastdim produces a generator in non-gradient environment
else
xs_
end
scan_full(func, init_carry, xs)
end

# Chain Rule for Base.mapfoldl_impl
function ChainRulesCore.rrule(
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
::typeof(Base.mapfoldl_impl),
::typeof(identity),
op::G,
init,
x::Union{AbstractArray, Tuple};
) where {G}
# Hobbits has two types afaict, first is for the first component, then the second component.
# This has to do with the entrance I believe (i.e. we don't know what function enters, but we know what
# function is called in subsequent things...
# hobbits = Vector{Tuple}(undef, length(x)) # Unfornately Zygote needs this
# accum_init = ChainRulesCore.rrule_via_ad(config, op, init[1], nothing)
# @show typeof(accum_init)
accum_init = ChainRulesCore.rrule_via_ad(config, op, init, x[1])
hobbits = accumulate(x[begin+1:end]; init=accum_init) do (a, _), b
@show a, b
c, back = ChainRulesCore.rrule_via_ad(config, op, a, b)
end
# @show typeof(hobbits)

y = first(last(hobbits))
axe = axes(x)
project = ChainRulesCore.ProjectTo(x)
function unfoldl(dy)
trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
end
# @show trio
f_ds, f_da, f_db = accum_init[2](trio[end][2])
dop = sum(first, trio) + f_ds
dx = [[f_db]; map(last, Iterators.reverse(trio))]
d_init = f_da
return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe)))
end
return y, unfoldl
end

# function ChainRulesCore.rrule(
# config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
# ::typeof(Base.mapfoldl_impl),
# ::typeof(identity),
# op::G,
# init,
# x::Union{AbstractArray, Tuple};
# ) where {G}
# hobbits = Vector{Any}(undef, length(x)) # Unfornately Zygote needs this
# accumulate!(hobbits, x; init=(init, nothing)) do (a, _), b
# c, back = ChainRulesCore.rrule_via_ad(config, op, a, b)
# end
# y = first(last(hobbits))
# axe = axes(x)
# project = ChainRulesCore.ProjectTo(x)
# function unfoldl(dy)
# trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
# ds, da, db = back(dc)
# end
# dop = sum(first, trio)
# dx = map(last, Iterators.reverse(trio))
# @show dx
# d_init = trio[end][2]
# return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe)))
# end
# return y, unfoldl
# end


"""
scan_partial
Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the final output of the sequence and the final carry. See `scan_full` to return the entire output sequence.
"""
function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray})
x_init, x_rest = Iterators.peel(xs)
(carry, y) = func(init_carry, x_init)
for x in x_rest
(carry, y) = func(carry, x)
end
carry, y
end

function scan_partial(func, init_carry, x_block)
# x_block is an abstractarray and we want to scan over the last dimension.
xs_ = Flux.eachlastdim(x_block)

# this is needed due to a bug in eachlastdim which produces a vector in a
# gradient context, but a generator otherwise.
xs = if xs_ isa Base.Generator
collect(xs_) # eachlastdim produces a generator in non-gradient environment
else
xs_
end
scan_partial(func, init_carry, xs)
end


"""
NewRecur
New Recur. An experimental recur interface for removing statefullness in recurrent architectures for flux. This struct has two type parameters. The first `RET_SEQUENCE` is a boolean which determines whether `scan_full` (`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan through the sequence. This structure has no internal state, and instead returns:
```julia
l = NewRNN(1,2)
xs # Some input array Input x BatchSize x Time
init_carry # the initial carry of the cell.
l(xs) # -> returns the output of the RNN, uses cell.state0 as init_carry.
l(init_carry, xs) # -> returns (final_carry, output), where the size ofoutput is determined by RET_SEQUENCE.
```
"""
struct NewRecur{RET_SEQUENCE, T}
cell::T
# state::S
function NewRecur(cell; return_sequence::Bool=false)
new{return_sequence, typeof(cell)}(cell)
end
function NewRecur{true}(cell)
new{true, typeof(cell)}(cell)
end
function NewRecur{false}(cell)
new{false, typeof(cell)}(cell)
end
end

Flux.@functor NewRecur
Flux.trainable(a::NewRecur) = (; cell = a.cell)
Base.show(io::IO, m::NewRecur) = print(io, "NewRecur(", m.cell, ")")

(l::NewRecur)(init_carry, x_mat::AbstractMatrix) = MethodError("Matrix is ambiguous with NewRecur")
(l::NewRecur)(init_carry, x_mat::AbstractVector{T}) where {T<:Number} = MethodError("Vector is ambiguous with NewRecur")

function (l::NewRecur)(xs::AbstractArray)
results = l(l.cell.state0, xs)
results[2] # Only return the output here.
end

function (l::NewRecur{false})(init_carry, xs)
results = scan_partial(l.cell, init_carry, xs)
results[1], results[2]
end

function (l::NewRecur{true})(init_carry, xs)
results = scan_full(l.cell, init_carry, xs)
results[1], stack(results[2], dims=3)
end



# Stateful recurrence

"""
Expand Down Expand Up @@ -187,8 +368,14 @@ function (m::Recur)(x::AbstractArray{T, 3}) where T
reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end

# Vanilla RNN

########
#
# Recurrent Cells
#
########

# Vanilla RNN
struct RNNCell{F,I,H,V,S}
σ::F
Wi::I
Expand Down Expand Up @@ -289,6 +476,8 @@ julia> r(rand(4, 10)) |> size # batch size of 10
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
Recur(m::RNNCell) = Recur(m, m.state0)

NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence)

# LSTM

struct LSTMCell{I,H,V,S}
Expand Down Expand Up @@ -362,6 +551,8 @@ julia> l(rand(Float32, 3, 10)) |> size # batch size of 10
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
Recur(m::LSTMCell) = Recur(m, m.state0)

NewLSTM(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.LSTMCell(a...; ka...); return_sequence=return_sequence)

# GRU

function _gru_output(gxs, ghs, bs)
Expand Down Expand Up @@ -436,6 +627,8 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
Recur(m::GRUCell) = Recur(m, m.state0)

NewGRU(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.GRUCell(a...; ka...); return_sequence=return_sequence)

# GRU v3

struct GRUv3Cell{I,H,V,HH,S}
Expand Down Expand Up @@ -505,3 +698,5 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)

NewGRUv3(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.GRUv3Cell(a...; ka...); return_sequence=return_sequence)
Loading

0 comments on commit 7a467cc

Please sign in to comment.