diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index aeb8801f05..fb58d072d3 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -99,13 +99,11 @@ function ChainRulesCore.rrule( # function is called in subsequent things... # hobbits = Vector{Tuple}(undef, length(x)) # Unfornately Zygote needs this # accum_init = ChainRulesCore.rrule_via_ad(config, op, init[1], nothing) - # @show typeof(accum_init) accum_init = ChainRulesCore.rrule_via_ad(config, op, init, x[1]) hobbits = accumulate(x[begin+1:end]; init=accum_init) do (a, _), b @show a, b c, back = ChainRulesCore.rrule_via_ad(config, op, a, b) end - # @show typeof(hobbits) y = first(last(hobbits)) axe = axes(x) @@ -114,7 +112,6 @@ function ChainRulesCore.rrule( trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) ds, da, db = back(dc) end - # @show trio f_ds, f_da, f_db = accum_init[2](trio[end][2]) dop = sum(first, trio) + f_ds dx = [[f_db]; map(last, Iterators.reverse(trio))] @@ -124,6 +121,7 @@ function ChainRulesCore.rrule( return y, unfoldl end +# From Lux.jl # function ChainRulesCore.rrule( # config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, # ::typeof(Base.mapfoldl_impl),