-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathreplay_buffer.py
41 lines (32 loc) · 1.09 KB
/
replay_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from collections import namedtuple, deque
import random
# One step of agent/environment interaction, stored as an immutable record.
Transition = namedtuple(
    "Transition",
    ["state", "action", "reward", "next_state"],
)
class ExperienceReplayBuffer:
    """
    Fixed-capacity store of `Transition` records with uniform random sampling.

    Backed by a `deque` created with `maxlen=capacity`, so once the buffer is
    full each new experience pushes out the oldest one.
    """

    def __init__(self, capacity: int) -> None:
        """
        Creates an experience replay buffer with length `capacity`
        Implemented using a `deque`

        Params:
            capacity: int, maximum number of experiences kept at once
        """
        self.buffer = deque([], maxlen=capacity)

    def add(self, *args) -> None:
        """
        Adds the experience specified by the `args` to the experience buffer
        `args` are first stored as a `Transition`

        Params:
            args: should be state, action, reward, next state (in that order)

        Raises:
            TypeError: if `args` does not match the fields of `Transition`
        """
        self.buffer.append(Transition(*args))

    def sample_experience(self, batch_size: int) -> list:
        """
        Sample `batch_size` experiences from buffer. Returns a list of `Transitions`
        Sampling is done with `random.sample`, i.e. without replacement.

        Params:
            batch_size: int, number of experiences to draw

        Raises:
            ValueError: if `batch_size` is negative or larger than the number
                of experiences currently stored
        """
        return random.sample(self.buffer, batch_size)

    def __len__(self) -> int:
        """
        More ergonomic way of getting the length of the buffer
        (number of experiences currently stored, not the capacity)
        """
        return len(self.buffer)