Skip to content

Commit

Permalink
Update recurrent.jl
Browse files Browse the repository at this point in the history
Minor edits, adding some notes.
  • Loading branch information
mkschleg authored Aug 24, 2023
1 parent 7a467cc commit d070150
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,11 @@ function ChainRulesCore.rrule(
# function is called in subsequent things...
# hobbits = Vector{Tuple}(undef, length(x)) # Unfornately Zygote needs this
# accum_init = ChainRulesCore.rrule_via_ad(config, op, init[1], nothing)
# @show typeof(accum_init)
accum_init = ChainRulesCore.rrule_via_ad(config, op, init, x[1])
hobbits = accumulate(x[begin+1:end]; init=accum_init) do (a, _), b
@show a, b
c, back = ChainRulesCore.rrule_via_ad(config, op, a, b)
end
# @show typeof(hobbits)

y = first(last(hobbits))
axe = axes(x)
Expand All @@ -114,7 +112,6 @@ function ChainRulesCore.rrule(
trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
end
# @show trio
f_ds, f_da, f_db = accum_init[2](trio[end][2])
dop = sum(first, trio) + f_ds
dx = [[f_db]; map(last, Iterators.reverse(trio))]
Expand All @@ -124,6 +121,7 @@ function ChainRulesCore.rrule(
return y, unfoldl
end

# From Lux.jl
# function ChainRulesCore.rrule(
# config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
# ::typeof(Base.mapfoldl_impl),
Expand Down

0 comments on commit d070150

Please sign in to comment.