From d0701505c19f123428dc82d3472b62e6f82ede2b Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Thu, 24 Aug 2023 10:16:49 -0400 Subject: [PATCH] Update recurrent.jl Minor edits, adding some notes. --- src/layers/recurrent.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index aeb8801f05..fb58d072d3 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -99,13 +99,11 @@ function ChainRulesCore.rrule( # function is called in subsequent things... # hobbits = Vector{Tuple}(undef, length(x)) # Unfornately Zygote needs this # accum_init = ChainRulesCore.rrule_via_ad(config, op, init[1], nothing) - # @show typeof(accum_init) accum_init = ChainRulesCore.rrule_via_ad(config, op, init, x[1]) hobbits = accumulate(x[begin+1:end]; init=accum_init) do (a, _), b @show a, b c, back = ChainRulesCore.rrule_via_ad(config, op, a, b) end - # @show typeof(hobbits) y = first(last(hobbits)) axe = axes(x) @@ -114,7 +112,6 @@ function ChainRulesCore.rrule( trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) ds, da, db = back(dc) end - # @show trio f_ds, f_da, f_db = accum_init[2](trio[end][2]) dop = sum(first, trio) + f_ds dx = [[f_db]; map(last, Iterators.reverse(trio))] @@ -124,6 +121,7 @@ function ChainRulesCore.rrule( return y, unfoldl end +# From Lux.jl # function ChainRulesCore.rrule( # config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, # ::typeof(Base.mapfoldl_impl),