Skip to content

Commit

Permalink
Replace Vector stack with Stack stack (#74)
Browse files Browse the repository at this point in the history
* Replace vector with stack

* Update unit test
  • Loading branch information
willtebbutt authored Feb 7, 2024
1 parent b6a1ac0 commit 1656be8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions src/interpreter/reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,11 @@ end
function build_coinsts(x::PiNode, _, _rrule!!, n::Int, b::Int, is_blk_end::Bool)
val = _get_slot(x.val, _rrule!!)
ret = _rrule!!.slots[n]
old_vals = Vector{eltype(ret)}(undef, 0)
sizehint!(old_vals, 10)
old_vals = Stack{eltype(ret)}()
return build_coinsts(PiNode, val, ret, old_vals, _standard_next_block(is_blk_end, b))
end
function build_coinsts(
::Type{PiNode}, val::CoDualSlot{V}, ret::CoDualSlot{R}, old_vals::Vector, next_blk::Int,
::Type{PiNode}, val::CoDualSlot{V}, ret::CoDualSlot{R}, old_vals::Stack, next_blk::Int,
) where {V, R}
make_fwds(v) = R(primal(v), tangent(v))
fwds_inst = @opaque function (p::Int)
Expand Down
2 changes: 1 addition & 1 deletion test/interpreter/reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
@testset "PiNode" begin
val = SlotRef{CoDual{Any, Any}}(CoDual{Any, Any}(5.0, 0.0))
ret = SlotRef{CoDual{Float64, Float64}}(CoDual(-1.0, -1.0))
old_vals = Vector{CoDual{Float64, Float64}}(undef, 0)
old_vals = Stack{CoDual{Float64, Float64}}()
next_blk = 5
fwds_inst, bwds_inst = build_coinsts(PiNode, val, ret, old_vals, next_blk)

Expand Down

0 comments on commit 1656be8

Please sign in to comment.