Skip to content

Commit bfc1591

Browse files
Remove special treatment of SARTSA traces
Remove methods specifically defined for SARTSA traces in EpisodesBuffer and CircularPrioritizedTraces
1 parent 1560ff5 commit bfc1591

File tree

5 files changed

+147
-147
lines changed

5 files changed

+147
-147
lines changed

src/common/CircularPrioritizedTraces.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,6 @@ function Base.push!(t::CircularPrioritizedTraces, x)
3434
end
3535
end
3636

37-
function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x)
38-
initial_length = length(t.traces)
39-
push!(t.traces, x)
40-
if length(t.traces) == 1
41-
push!(t.keys, 1)
42-
push!(t.priorities, t.default_priority)
43-
elseif length(t.traces) > 1 && (initial_length < length(t.traces) || initial_length == capacity(t.traces)-1 )
44-
# only add a key if the length changes after insertion of the tuple
45-
# or if the trace is already at capacity
46-
push!(t.keys, t.keys[end] + 1)
47-
push!(t.priorities, t.default_priority)
48-
else
49-
# this may be a partial insertion at the first step; ignore it
50-
end
51-
end
52-
5337
function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
5438
if k === :priority
5539
@assert length(vs) == length(keys)

src/episodes.jl

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,13 @@ end
4343
# The capacity of an EpisodesBuffer is the capacity of the underlying traces, plus one in certain cases
4444
function is_capacity_plus_one(traces::AbstractTraces)
4545
if any(t->t isa MultiplexTraces, traces.traces)
46-
# MultiplexTraces buffer next_state and next_action, so we need to add one to the capacity
47-
return true
48-
elseif traces isa CircularPrioritizedTraces
49-
# CircularPrioritizedTraces buffer next_state and next_action, so we need to add one to the capacity
46+
# MultiplexTraces buffer next_state or next_action, so we need to add one to the capacity
5047
return true
5148
else
5249
false
5350
end
5451
end
52+
is_capacity_plus_one(traces::CircularPrioritizedTraces) = is_capacity_plus_one(traces.traces)
5553

5654
function EpisodesBuffer(traces::AbstractTraces)
5755
cap = is_capacity_plus_one(traces) ? capacity(traces) + 1 : capacity(traces)
@@ -70,7 +68,7 @@ function EpisodesBuffer(traces::AbstractTraces)
7068
end
7169

7270
function Base.getindex(es::EpisodesBuffer, idx::Int...)
73-
@boundscheck all(es.sampleable_inds[idx...])
71+
@boundscheck all(es.sampleable_inds[idx...]) || throw(BoundsError(es.sampleable_inds, idx))
7472
getindex(es.traces, idx...)
7573
end
7674

@@ -79,6 +77,7 @@ function Base.getindex(es::EpisodesBuffer, idx...)
7977
end
8078

8179
Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...)
80+
capacity(eb::EpisodesBuffer) = capacity(eb.traces)
8281
Base.size(eb::EpisodesBuffer) = size(eb.traces)
8382
Base.length(eb::EpisodesBuffer) = length(eb.traces)
8483
Base.keys(eb::EpisodesBuffer) = keys(eb.traces)
@@ -146,7 +145,7 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
146145
push!(eb.episodes_lengths, 0)
147146
push!(eb.sampleable_inds, 0)
148147
elseif !partial #typical inserting
149-
if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched
148+
if haskey(eb,:next_action) # if trace has next_action
150149
if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode
151150
eb.sampleable_inds[end-1] = 1 # steps are indexable one step later
152151
end
@@ -174,28 +173,6 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
174173
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
175174
end
176175

177-
function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, xs::PartialNamedTuple)
178-
if max_length(eb) == capacity(eb.traces)
179-
popfirst!(eb)
180-
end
181-
push!(eb.traces, xs.namedtuple)
182-
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
183-
end
184-
185-
function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}})
186-
if max_length(eb) == capacity(eb.traces)
187-
addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal])
188-
xs = merge(xs.namedtuple, addition)
189-
push!(eb.traces, xs)
190-
pop!(eb.traces[:state].trace)
191-
pop!(eb.traces[:reward])
192-
pop!(eb.traces[:terminal])
193-
else
194-
push!(eb.traces, xs.namedtuple)
195-
eb.sampleable_inds[end-1] = 1
196-
end
197-
end
198-
199176
for f in (:pop!, :popfirst!)
200177
@eval function Base.$f(eb::EpisodesBuffer)
201178
$f(eb.episodes_lengths)

test/common.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,14 @@ end
180180
default_priority=1.0f0
181181
)
182182

183-
eb = EpisodesBuffer(t)
183+
eb = EpisodesBuffer(t)
184184
push!(eb, (state = 1, action = 1))
185185
for i = 1:5
186-
push!(eb, (state = i+1, action =i+1, reward = i, terminal = false))
186+
push!(eb, (state = i+1, action = i+1, reward = i, terminal = false))
187187
end
188188
push!(eb, (state = 7, action = 7))
189189
for (j,i) = enumerate(8:11)
190-
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
190+
push!(eb, (state = i, action = i, reward = i-1, terminal = false))
191191
end
192192
s = BatchSampler(1000)
193193
b = sample(s, eb)
@@ -222,6 +222,8 @@ end
222222

223223
b = sample(s, t)
224224

225+
@test t[:priority] == [1.0f0, 1.0f0, 1.0f0]
226+
225227
t[:priority, [1, 2]] = [0, 0]
226228

227229
# shouldn't be changed since [1,2] are old keys
@@ -240,18 +242,19 @@ end
240242
),
241243
default_priority=1.0f0
242244
)
243-
244-
eb = EpisodesBuffer(t)
245+
246+
eb = EpisodesBuffer(t)
245247
push!(eb, (state = 1,))
246248
for i = 1:5
247-
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
249+
push!(eb, (state = i+1, action = i, reward = i, terminal = false))
248250
end
249251
push!(eb, PartialNamedTuple((action = 6,)))
250252
push!(eb, (state = 7,))
251-
for (j,i) = enumerate(8:11)
252-
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
253+
for i = 8:11
254+
push!(eb, (state = i, action = i-1, reward = i-1, terminal = false))
253255
end
254-
push!(eb, PartialNamedTuple((action=12,)))
256+
push!(eb, PartialNamedTuple((action=11,)))
257+
255258
s = BatchSampler(1000)
256259
b = sample(s, eb)
257260
cm = counter(b[:state])

0 commit comments

Comments
 (0)