@@ -87,42 +87,42 @@ import ReinforcementLearningTrajectories.fetch
    push!(eb, (state = i+1, action = i+1, reward = i, terminal = i == 5))
end
push!(eb, (state = 7, action = 7))
- for (j,i) = enumerate(8:11)
+ for (j,i) = enumerate(8:12)
    push!(eb, (state = i, action = i, reward = i-1, terminal = false))
end
weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
- @test weights == [0,1,1,1,1,0,0,1,1,1,0]
- @test ns == [3,3,3,2,1,-1,3,3,2,1,0] # the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode
+ @test weights == [0,1,1,1,0,0,1,1,1,0,0]
+ @test ns == [3,3,2,1,-1,3,3,3,2,1,0] # the -1 is due to ep_lengths[5] being that of 2nd episode but step_numbers[5] being that of 1st episode
inds = [i for i in eachindex(weights) if weights[i] == 1]
batch = sample(s1, eb)
for key in keys(eb)
    @test haskey(batch, key)
end
# state: samples with stacksize
states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
- @test states == [1 2 3 4 7 8 9;
-                  2 3 4 5 8 9 10]
+ @test states == [1 2 3 6 7 8;
+                  2 3 4 7 8 9]
@test all(in(eachcol(states)), unique(eachcol(batch[:state])))
# next_state: samples with stacksize and nsteps forward
next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
- @test next_states == [4 5 5 5 10 10 10;
-                       5 6 6 6 11 11 11]
+ @test next_states == [4 4 4 9 10 10;
+                       5 5 5 10 11 11]
@test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
# action: samples normally
actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
- @test actions == inds
+ @test actions == [3, 4, 5, 8, 9, 10]
@test all(in(actions), unique(batch[:action]))
# next_action: is a multiplex trace: should automatically sample nsteps forward
next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
- @test next_actions == [5, 6, 6, 6, 11, 11, 11]
+ @test next_actions == [6, 6, 6, 11, 12, 12]
@test all(in(next_actions), unique(batch[:next_action]))
# reward: discounted sum
rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
- @test rewards ≈ [2 + 0.99*3 + 0.99^2*4, 3 + 0.99*4 + 0.99^2*5, 4 + 0.99*5, 5, 8 + 0.99*9 + 0.99^2*10, 9 + 0.99*10, 10]
+ @test rewards ≈ [2 + 0.99*3 + 0.99^2*4, 3 + 0.99*4, 4, 7 + 0.99*8 + 0.99^2*9, 8 + 0.99*9 + 0.99^2*10, 9 + 0.99*10]
@test all(in(rewards), unique(batch[:reward]))
# terminal: nsteps forward
terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
- @test terminals == [0,1,1,1,0,0,0]
+ @test terminals == [0,0,0,0,0,0]

### CircularPrioritizedTraces and NStepBatchSampler
γ = 0.99
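
Note on the expected `rewards` above: each entry is a truncated discounted sum over the steps counted in `ns`. The sketch below (not part of the test file; `nstep_return` is a made-up helper, and the reward trace is reconstructed from the expected values, which imply `reward[i] == i`) reproduces the new expectations with γ = 0.99:

# Standalone sketch of the discounted n-step sum behind the expected `rewards`;
# `nstep_return` is a hypothetical helper, not library API.
nstep_return(r, γ, i, n) = sum(γ^t * r[i + t] for t in 0:n-1)

reward_trace = collect(1:11)     # reward[i] == i, as implied by the expected values
γ = 0.99
inds = [2, 3, 4, 7, 8, 9]        # indices where weights == 1 in the new test
ns_at_inds = [3, 2, 1, 3, 3, 2]  # ns[inds] in the new test

sums = [nstep_return(reward_trace, γ, i, n) for (i, n) in zip(inds, ns_at_inds)]
@assert sums ≈ [2 + 0.99*3 + 0.99^2*4, 3 + 0.99*4, 4,
                7 + 0.99*8 + 0.99^2*9, 8 + 0.99*9 + 0.99^2*10, 9 + 0.99*10]

The 2- and 1-step sums near index 4 and 9 reflect the horizon being cut off by the episode boundary, which is exactly what the `ns` values encode.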
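
The `states` / `next_states` expectations follow the same indexing: with a stack size of 2, the column for index i stacks `state[i-1]` and `state[i]`, and `next_state` takes the same stack `ns[i]` steps ahead. A rough sketch under the same assumptions (state trace reconstructed as `state[i] == i`; `stack2` is a made-up helper, not the library's `fetch`):

# Rough indexing sketch behind the expected `states`/`next_states` matrices;
# `stack2` is a hypothetical helper, not the library's fetch implementation.
stack2(s, i) = [s[i-1], s[i]]   # stacksize = 2: previous and current state

state_trace = collect(1:12)     # state[i] == i, as implied by the expected values
inds = [2, 3, 4, 7, 8, 9]
ns_at_inds = [3, 2, 1, 3, 3, 2]

states_cols      = reduce(hcat, [stack2(state_trace, i)     for i in inds])
next_states_cols = reduce(hcat, [stack2(state_trace, i + n) for (i, n) in zip(inds, ns_at_inds)])

@assert states_cols      == [1 2 3 6 7 8; 2 3 4 7 8 9]
@assert next_states_cols == [4 4 4 9 10 10; 5 5 5 10 11 11]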