diff --git a/recur_funcs.jl b/recur_funcs.jl new file mode 100644 index 0000000000..453de64c37 --- /dev/null +++ b/recur_funcs.jl @@ -0,0 +1,53 @@ +using Flux + +function run_new_recur() + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + + x_block = reshape(reduce(vcat, x), 1, 1, length(x)) + nm_layer = Flux.NewRecur(cell; return_sequence = true) + _out = layer(x_block) + e, g = Flux.withgradient(nm_layer) do layer + out = layer(x_block) + sum(out[1, 1, end]) + end + grads = g[1][:cell] + + @show primal[1] ≈ e + @show ∇Wi ≈ grads[:Wi] + @show ∇Wh ≈ grads[:Wh] + @show ∇b ≈ grads[:b] + @show ∇state0 ≈ grads[:state0] + + return +end + +function run_scan_full() + + x = [[2.0f0], [3.0f0], [4.0f0]] + x_block = reshape(reduce(vcat, x), 1, 1, length(x)) + # nm_layer = Flux.NewRecur(cell; return_sequence = true) + w = zeros(1) + _out = Flux.scan_full((a, b)->(sum(w.*b), sum(w.*b)), 0.0f0, x_block) + e, g = Flux.withgradient(w) do layer + out = Flux.scan_full((a, b)->(sum(w.*b), sum(w.*b)), 0.0f0, x_block) + sum(out[1, 1, end]) + end + grads = g[1][:cell] + return +end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 375ff43d52..aeb8801f05 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -51,6 +51,187 @@ end reshape_cell_output(h, x) = reshape(h, :, size(x)[2:end]...) + + +# non-stateful recurrence + +""" + scan_full + +Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the full output of the sequence and the final carry. See `scan_partial` to only return the final output of the sequence. +""" +function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) + # Recurrence operation used in the fold. Takes the state of the + # fold and the next input, returns the new state. + function recurrence_op((carry, outputs), input) + carry, out = func(carry, input) + return carry, vcat(outputs, [out]) + end + # Fold left to right. + return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs) +end + +function scan_full(func, init_carry, x_block) + # x_block is an abstractarray and we want to scan over the last dimension. + xs_ = Flux.eachlastdim(x_block) + + # this is needed due to a bug in eachlastdim which produces a vector in a + # gradient context, but a generator otherwise. + xs = if xs_ isa Base.Generator + collect(xs_) # eachlastdim produces a generator in non-gradient environment + else + xs_ + end + scan_full(func, init_carry, xs) +end + +# Chain Rule for Base.mapfoldl_impl +function ChainRulesCore.rrule( + config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, + ::typeof(Base.mapfoldl_impl), + ::typeof(identity), + op::G, + init, + x::Union{AbstractArray, Tuple}; +) where {G} + # Hobbits has two types afaict, first is for the first component, then the second component. + # This has to do with the entrance I believe (i.e. we don't know what function enters, but we know what + # function is called in subsequent things... 
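+    # Outline of this rrule:
+    #  * forward: re-run the fold, recording at every step the primal fold state and
+    #    the pullback returned by `rrule_via_ad`; `accum_init` covers the first
+    #    element of `x`, `hobbits` covers the remaining ones.
+    #  * reverse (`unfoldl`): walk the recorded pullbacks from last to first,
+    #    threading the carry cotangent backwards. The carry cotangent left over after
+    #    the last reverse step is pushed through the first step's pullback to recover
+    #    the cotangents for `init` and `x[1]`; the per-element cotangents are then
+    #    re-assembled in forward order to form the gradient for `x`.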
+ # hobbits = Vector{Tuple}(undef, length(x)) # Unfornately Zygote needs this + # accum_init = ChainRulesCore.rrule_via_ad(config, op, init[1], nothing) + # @show typeof(accum_init) + accum_init = ChainRulesCore.rrule_via_ad(config, op, init, x[1]) + hobbits = accumulate(x[begin+1:end]; init=accum_init) do (a, _), b + @show a, b + c, back = ChainRulesCore.rrule_via_ad(config, op, a, b) + end + # @show typeof(hobbits) + + y = first(last(hobbits)) + axe = axes(x) + project = ChainRulesCore.ProjectTo(x) + function unfoldl(dy) + trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) + ds, da, db = back(dc) + end + # @show trio + f_ds, f_da, f_db = accum_init[2](trio[end][2]) + dop = sum(first, trio) + f_ds + dx = [[f_db]; map(last, Iterators.reverse(trio))] + d_init = f_da + return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe))) + end + return y, unfoldl +end + +# function ChainRulesCore.rrule( +# config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, +# ::typeof(Base.mapfoldl_impl), +# ::typeof(identity), +# op::G, +# init, +# x::Union{AbstractArray, Tuple}; +# ) where {G} +# hobbits = Vector{Any}(undef, length(x)) # Unfornately Zygote needs this +# accumulate!(hobbits, x; init=(init, nothing)) do (a, _), b +# c, back = ChainRulesCore.rrule_via_ad(config, op, a, b) +# end +# y = first(last(hobbits)) +# axe = axes(x) +# project = ChainRulesCore.ProjectTo(x) +# function unfoldl(dy) +# trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) +# ds, da, db = back(dc) +# end +# dop = sum(first, trio) +# dx = map(last, Iterators.reverse(trio)) +# @show dx +# d_init = trio[end][2] +# return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe))) +# end +# return y, unfoldl +# end + + +""" + scan_partial + +Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the final output of the sequence and the final carry. See `scan_full` to return the entire output sequence. +""" +function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray}) + x_init, x_rest = Iterators.peel(xs) + (carry, y) = func(init_carry, x_init) + for x in x_rest + (carry, y) = func(carry, x) + end + carry, y +end + +function scan_partial(func, init_carry, x_block) + # x_block is an abstractarray and we want to scan over the last dimension. + xs_ = Flux.eachlastdim(x_block) + + # this is needed due to a bug in eachlastdim which produces a vector in a + # gradient context, but a generator otherwise. + xs = if xs_ isa Base.Generator + collect(xs_) # eachlastdim produces a generator in non-gradient environment + else + xs_ + end + scan_partial(func, init_carry, xs) +end + + +""" + NewRecur +New Recur. An experimental recur interface for removing statefullness in recurrent architectures for flux. This struct has two type parameters. The first `RET_SEQUENCE` is a boolean which determines whether `scan_full` (`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan through the sequence. This structure has no internal state, and instead returns: + +```julia +l = NewRNN(1,2) +xs # Some input array Input x BatchSize x Time +init_carry # the initial carry of the cell. +l(xs) # -> returns the output of the RNN, uses cell.state0 as init_carry. +l(init_carry, xs) # -> returns (final_carry, output), where the size ofoutput is determined by RET_SEQUENCE. 
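+
+# Output shapes for the `NewRNN(1, 2)` above, with `xs` of size 1 × batch × len:
+#   RET_SEQUENCE = true  -> 2 × batch × len  (every step, stacked along the last dim)
+#   RET_SEQUENCE = false -> 2 × batch        (final step only)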
+``` +""" +struct NewRecur{RET_SEQUENCE, T} + cell::T + # state::S + function NewRecur(cell; return_sequence::Bool=false) + new{return_sequence, typeof(cell)}(cell) + end + function NewRecur{true}(cell) + new{true, typeof(cell)}(cell) + end + function NewRecur{false}(cell) + new{false, typeof(cell)}(cell) + end +end + +Flux.@functor NewRecur +Flux.trainable(a::NewRecur) = (; cell = a.cell) +Base.show(io::IO, m::NewRecur) = print(io, "NewRecur(", m.cell, ")") + +(l::NewRecur)(init_carry, x_mat::AbstractMatrix) = MethodError("Matrix is ambiguous with NewRecur") +(l::NewRecur)(init_carry, x_mat::AbstractVector{T}) where {T<:Number} = MethodError("Vector is ambiguous with NewRecur") + +function (l::NewRecur)(xs::AbstractArray) + results = l(l.cell.state0, xs) + results[2] # Only return the output here. +end + +function (l::NewRecur{false})(init_carry, xs) + results = scan_partial(l.cell, init_carry, xs) + results[1], results[2] +end + +function (l::NewRecur{true})(init_carry, xs) + results = scan_full(l.cell, init_carry, xs) + results[1], stack(results[2], dims=3) +end + + + # Stateful recurrence """ @@ -187,8 +368,14 @@ function (m::Recur)(x::AbstractArray{T, 3}) where T reshape(reduce(hcat, h), sze[1], sze[2], length(h)) end -# Vanilla RNN +######## +# +# Recurrent Cells +# +######## + +# Vanilla RNN struct RNNCell{F,I,H,V,S} σ::F Wi::I @@ -289,6 +476,8 @@ julia> r(rand(4, 10)) |> size # batch size of 10 RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) Recur(m::RNNCell) = Recur(m, m.state0) +NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence) + # LSTM struct LSTMCell{I,H,V,S} @@ -362,6 +551,8 @@ julia> l(rand(Float32, 3, 10)) |> size # batch size of 10 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) Recur(m::LSTMCell) = Recur(m, m.state0) +NewLSTM(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.LSTMCell(a...; ka...); return_sequence=return_sequence) + # GRU function _gru_output(gxs, ghs, bs) @@ -436,6 +627,8 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10 GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) Recur(m::GRUCell) = Recur(m, m.state0) +NewGRU(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.GRUCell(a...; ka...); return_sequence=return_sequence) + # GRU v3 struct GRUv3Cell{I,H,V,HH,S} @@ -505,3 +698,5 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10 """ GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...)) Recur(m::GRUv3Cell) = Recur(m, m.state0) + +NewGRUv3(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.GRUv3Cell(a...; ka...); return_sequence=return_sequence) diff --git a/test/layers/new_recur.jl b/test/layers/new_recur.jl new file mode 100644 index 0000000000..49b3152a65 --- /dev/null +++ b/test/layers/new_recur.jl @@ -0,0 +1,188 @@ +@testset "NewRecur RNN" begin + @testset "Forward Pass" begin + # tanh is needed for forward check to determine ordering of inputs. + cell = Flux.RNNCell(1, 1, tanh) + layer = Flux.NewRecur(cell; return_sequence=true) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = reshape([2.0f0, 3.0f0], 1, 1, 2) + + # Lets make sure th output is correct + h = cell.state0 + h, out = cell(h, [2.0f0]) + h, out = cell(h, [3.0f0]) + + @test eltype(layer(x)) <: Float32 + @test size(layer(x)) == (1, 1, 2) + @test layer(x)[1, 1, 2] ≈ out[1,1] + + @test length(layer(cell.state0, x)) == 2 # should return a tuple. Maybe better test is needed. 
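+    # The returned tuple is (final_carry, stacked_outputs); the next test compares the
+    # last time step of the stacked outputs against the manually unrolled cell above.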
+ @test layer(cell.state0, x)[2][1,1,2] ≈ out[1,1] + + @test_throws MethodError layer([2.0f0]) + @test_throws MethodError layer([2.0f0;; 3.0f0]) + end + + @testset "gradients-implicit" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + nm_layer = Flux.NewRecur(cell; return_sequence = true) + ps = Flux.params(nm_layer) + x_block = reshape(vcat(x...), 1, 1, length(x)) + e, g = Flux.withgradient(ps) do + out = nm_layer(x_block) + sum(out[1, 1, 2]) + end + + @test primal[1] ≈ e + @test ∇Wi ≈ g[ps[1]] + @test ∇Wh ≈ g[ps[2]] + @test ∇b ≈ g[ps[3]] + @test ∇state0 ≈ g[ps[4]] + end + + @testset "gradients-explicit" begin + + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + + x_block = reshape(vcat(x...), 1, 1, length(x)) + nm_layer = Flux.NewRecur(cell; return_sequence = true) + e, g = Flux.withgradient(nm_layer) do layer + out = layer(x_block) + sum(out[1, 1, 2]) + end + grads = g[1][:cell] + + @test primal[1] ≈ e + @test ∇Wi ≈ grads[:Wi] + @test ∇Wh ≈ grads[:Wh] + @test ∇b ≈ grads[:b] + @test ∇state0 ≈ grads[:state0] + end +end + +@testset "New Recur RNN Partial Sequence" begin + @testset "Forward Pass" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.NewRecur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = reshape([2.0f0, 3.0f0], 1, 1, 2) + + h = cell.state0 + h, out = cell(h, [2.0f0]) + h, out = cell(h, [3.0f0]) + + @test eltype(layer(x)) <: Float32 + @test size(layer(x)) == (1, 1) + @test layer(x)[1, 1] ≈ out[1,1] + + @test length(layer(cell.state0, x)) == 2 + @test layer(cell.state0, x)[2][1,1] ≈ out[1,1] + + @test_throws MethodError layer([2.0f0]) + @test_throws MethodError layer([2.0f0;; 3.0f0]) + end + + @testset "gradients-implicit" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + nm_layer = Flux.NewRecur(cell; return_sequence = false) + ps = Flux.params(nm_layer) + x_block = reshape(vcat(x...), 1, 1, length(x)) + e, g = Flux.withgradient(ps) do + out = (nm_layer)(x_block) + sum(out) + end + + @test primal[1] ≈ e + @test ∇Wi ≈ g[ps[1]] + @test ∇Wh ≈ g[ps[2]] + @test ∇b ≈ g[ps[3]] + @test ∇state0 ≈ g[ps[4]] + end + + @testset "gradients-explicit" 
begin + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + x_block = reshape(vcat(x...), 1, 1, length(x)) + nm_layer = Flux.NewRecur(cell; return_sequence = false) + e, g = Flux.withgradient(nm_layer) do layer + out = layer(x_block) + sum(out) + end + grads = g[1][:cell] + + @test primal[1] ≈ e + @test ∇Wi ≈ grads[:Wi] + @test ∇Wh ≈ grads[:Wh] + @test ∇b ≈ grads[:b] + @test ∇state0 ≈ grads[:state0] + + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 3378aaa6d5..b3f92b569e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,55 +17,56 @@ Random.seed!(0) @testset verbose=true "Flux.jl" begin if get(ENV, "FLUX_TEST_CPU", "true") == "true" - @testset "Utils" begin - include("utils.jl") - end + # @testset "Utils" begin + # include("utils.jl") + # end - @testset "Loading" begin - include("loading.jl") - end + # @testset "Loading" begin + # include("loading.jl") + # end - @testset "Optimise / Train" begin - include("optimise.jl") - include("train.jl") - end + # @testset "Optimise / Train" begin + # include("optimise.jl") + # include("train.jl") + # end - @testset "Data" begin - include("data.jl") - end + # @testset "Data" begin + # include("data.jl") + # end - @testset "Losses" begin - include("losses.jl") - include("ctc.jl") - end + # @testset "Losses" begin + # include("losses.jl") + # include("ctc.jl") + # end @testset "Layers" begin - include("layers/attention.jl") - include("layers/basic.jl") - include("layers/normalisation.jl") - include("layers/stateless.jl") - include("layers/recurrent.jl") - include("layers/conv.jl") - include("layers/upsample.jl") - include("layers/show.jl") + # include("layers/attention.jl") + # include("layers/basic.jl") + # include("layers/normalisation.jl") + # include("layers/stateless.jl") + # include("layers/recurrent.jl") + # include("layers/conv.jl") + # include("layers/upsample.jl") + # include("layers/show.jl") + include("layers/new_recur.jl") end - @testset "outputsize" begin - using Flux: outputsize - include("outputsize.jl") - end - - @testset "functors" begin - include("functors.jl") - end - - @static if VERSION == v"1.9" - using Documenter - @testset "Docs" begin - DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) - doctest(Flux) - end - end + # @testset "outputsize" begin + # using Flux: outputsize + # include("outputsize.jl") + # end + + # @testset "functors" begin + # include("functors.jl") + # end + + # @static if VERSION == v"1.9" + # using Documenter + # @testset "Docs" begin + # DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) + # doctest(Flux) + # end + # end else @info "Skipping CPU tests." end
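For reference, a minimal usage sketch of the interface added by this patch. It only uses names defined above (`NewRNN`, `NewRecur`, `scan_full`); `accumulate_step` is an illustrative helper, and the `features × batch × time` layout follows the docstring and tests:

```julia
using Flux

# Non-stateful recurrent layer: an RNNCell(1, 2) wrapped in NewRecur via the
# NewRNN convenience constructor defined in src/layers/recurrent.jl above.
l = Flux.NewRNN(1, 2; return_sequence = true)

xs = rand(Float32, 1, 4, 10)        # features × batch × time
ys = l(xs)                          # size (2, 4, 10): all steps, stacked along time
carry, ys2 = l(l.cell.state0, xs)   # explicit initial carry -> (final_carry, outputs)

# scan_full can also be used directly with any (carry, x) -> (new_carry, out) function.
accumulate_step(carry, x) = (carry .+ x, carry .+ x)   # illustrative helper
final_carry, outs = Flux.scan_full(accumulate_step, zeros(Float32, 1, 4), xs)
```

With `return_sequence = false` the same layer call goes through `scan_partial` and returns only the final step as a `2 × 4` matrix.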