
Commit 26c24f0

Update test for NStepBatchSampler
1 parent bfc7610 commit 26c24f0

1 file changed: +11 −11 lines changed

test/samplers.jl

Lines changed: 11 additions & 11 deletions
@@ -87,42 +87,42 @@ import ReinforcementLearningTrajectories.fetch
     push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
 end
 push!(eb, (state = 7, action = 7))
-for (j,i) = enumerate(8:11)
+for (j,i) = enumerate(8:12)
     push!(eb, (state = i, action =i, reward = i-1, terminal = false))
 end
 weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
-@test weights == [0,1,1,1,1,0,0,1,1,1,0]
-@test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode
+@test weights == [0,1,1,1,0,0,1,1,1,0,0]
+@test ns == [3,3,2,1,-1,3,3,3,2,1,0] #the -1 is due to ep_lengths[5] being that of 2nd episode but step_numbers[6] being that of 1st episode
 inds = [i for i in eachindex(weights) if weights[i] == 1]
 batch = sample(s1, eb)
 for key in keys(eb)
     @test haskey(batch, key)
 end
 #state: samples with stacksize
 states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
-@test states == [1 2 3 4 7 8 9;
-                 2 3 4 5 8 9 10]
+@test states == [1 2 3 6 7 8;
+                 2 3 4 7 8 9]
 @test all(in(eachcol(states)), unique(eachcol(batch[:state])))
 #next_state: samples with stacksize and nsteps forward
 next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
-@test next_states == [4 5 5 5 10 10 10;
-                      5 6 6 6 11 11 11]
+@test next_states == [4 4 4 9 10 10;
+                      5 5 5 10 11 11]
 @test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
 #action: samples normally
 actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
-@test actions == inds
+@test actions == [3, 4, 5, 8, 9, 10]
 @test all(in(actions), unique(batch[:action]))
 #next_action: is a multiplex trace: should automatically sample nsteps forward
 next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
-@test next_actions == [5, 6, 6, 6, 11, 11, 11]
+@test next_actions == [6, 6, 6, 11, 12, 12]
 @test all(in(next_actions), unique(batch[:next_action]))
 #reward: discounted sum
 rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
-@test rewards ≈ [2+0.99*3+0.99^2*4, 3+0.99*4+0.99^2*5, 4+0.99*5, 5, 8+0.99*9+0.99^2*10,9+0.99*10, 10]
+@test rewards ≈ [2+0.99*3+0.99^2*4, 3+0.99*4, 4, 7+0.99*8+0.99^2*9, 8+0.99*9+0.99^2*10,9+0.99*10]
 @test all(in(rewards), unique(batch[:reward]))
 #terminal: nsteps forward
 terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
-@test terminals == [0,1,1,1,0,0,0]
+@test terminals == [0,0,0,0,0,0]

 ### CircularPrioritizedTraces and NStepBatchSampler
 γ = 0.99
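For context on the "discounted sum" expectations in the reward test above, here is a minimal sketch of how those entries can be reproduced, assuming γ = 0.99 and the n-step horizons given by ns[inds]. The helper nstep_return and the flattened rewards vector are illustrative assumptions, not part of the ReinforcementLearningTrajectories API.

# Hypothetical helper (not part of ReinforcementLearningTrajectories):
# n-step discounted return starting at buffer index t, truncated to n rewards.
nstep_return(rewards, t, n; γ = 0.99) = sum(γ^(k - 1) * rewards[t + k - 1] for k in 1:n)

# Reward trace assumed from the pushes in the test: rewards 1..5 for the first
# episode, a filler 0 where the second episode starts (the state = 7 push
# carries no reward), then rewards 7..11 for the second episode.
rewards = [1, 2, 3, 4, 5, 0, 7, 8, 9, 10, 11]

nstep_return(rewards, 2, 3)  # 2 + 0.99*3 + 0.99^2*4 ≈ 8.8904, the first expected entry
nstep_return(rewards, 4, 1)  # 4, truncated to a single step near the episode boundary
nstep_return(rewards, 7, 3)  # 7 + 0.99*8 + 0.99^2*9, the first entry from the second episode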
