from collections import OrderedDict

import numpy as np
from gym.spaces import Discrete

from utils.env_utils import get_dim


class ReplayBuffer(object):
    def __init__(
            self,
            max_replay_buffer_size,
            ob_space,
            action_space,
    ):
        """
        Immutable state: fixed at construction and never changed afterwards.
        """
        self._ob_space = ob_space
        self._action_space = action_space

        ob_dim = get_dim(self._ob_space)
        ac_dim = get_dim(self._action_space)

        self._max_replay_buffer_size = max_replay_buffer_size

        """
        Mutable state: the stored transitions and the ring-buffer bookkeeping.
        """
        self._observations = np.zeros((max_replay_buffer_size, ob_dim))
        # It's a bit memory inefficient to save the observations twice,
        # but it makes the code *much* easier since you no longer have to
        # worry about termination conditions.
        self._next_obs = np.zeros((max_replay_buffer_size, ob_dim))
        self._actions = np.zeros((max_replay_buffer_size, ac_dim))

        # Make everything a 2D np array to make it easier for other code to
        # reason about the shape of the data.
        self._rewards = np.zeros((max_replay_buffer_size, 1))
        # self._terminals[i] = a terminal flag was received at time i
        self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')

        # _top is the next index to write to; _size is the number of valid
        # entries currently stored (at most max_replay_buffer_size).
        self._top = 0
        self._size = 0

    def add_path(self, path):
        """
        Add a path to the replay buffer.

        This default implementation naively goes through every step, but you
        may want to optimize this (a vectorized variant is sketched in
        add_path_vectorized below).
        """
        for i, (
                obs,
                action,
                reward,
                next_obs,
                terminal,
                agent_info,
                env_info
        ) in enumerate(zip(
            path["observations"],
            path["actions"],
            path["rewards"],
            path["next_observations"],
            path["terminals"],
            path["agent_infos"],
            path["env_infos"],
        )):
            self.add_sample(
                observation=obs,
                action=action,
                reward=reward,
                next_observation=next_obs,
                terminal=terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
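
    # A minimal sketch, not part of the original oac-explore code: a
    # vectorized variant of add_path for when the per-step Python loop above
    # becomes a bottleneck. It assumes the path entries are NumPy arrays of
    # length n and that n is no larger than the buffer capacity; the method
    # name is illustrative.
    def add_path_vectorized(self, path):
        n = len(path["rewards"])
        # Wrap the write indices around the ring buffer explicitly.
        indices = np.arange(self._top, self._top + n) % self._max_replay_buffer_size
        self._observations[indices] = path["observations"]
        self._actions[indices] = path["actions"]
        self._rewards[indices] = np.reshape(path["rewards"], (n, 1))
        self._terminals[indices] = np.reshape(path["terminals"], (n, 1))
        self._next_obs[indices] = path["next_observations"]
        self._top = (self._top + n) % self._max_replay_buffer_size
        self._size = min(self._size + n, self._max_replay_buffer_size)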

    def add_paths(self, paths):
        for path in paths:
            self.add_path(path)

    def add_sample(self, observation, action, reward, next_observation,
                   terminal, env_info, **kwargs):
        # Discrete action spaces are not supported by this buffer; actions
        # are stored as flat continuous vectors.
        assert not isinstance(self._action_space, Discrete)

        self._observations[self._top] = observation
        self._actions[self._top] = action
        self._rewards[self._top] = reward
        self._terminals[self._top] = terminal
        self._next_obs[self._top] = next_observation
        self._advance()

    def _advance(self):
        # Move the write pointer forward, wrapping around so the oldest
        # samples are overwritten once the buffer is full.
        self._top = (self._top + 1) % self._max_replay_buffer_size
        if self._size < self._max_replay_buffer_size:
            self._size += 1

    def random_batch(self, batch_size):
        # Indices are drawn uniformly, with replacement, from the filled
        # portion of the buffer.
        indices = np.random.randint(0, self._size, batch_size)
        batch = dict(
            observations=self._observations[indices],
            actions=self._actions[indices],
            rewards=self._rewards[indices],
            terminals=self._terminals[indices],
            next_observations=self._next_obs[indices],
        )
        return batch

    def num_steps_can_sample(self):
        return self._size

    def get_diagnostics(self):
        return OrderedDict([
            ('size', self._size)
        ])

    def end_epoch(self, epoch):
        return

    def get_snapshot(self):
        return dict(
            _observations=self._observations,
            _next_obs=self._next_obs,
            _actions=self._actions,
            _rewards=self._rewards,
            _terminals=self._terminals,
            _top=self._top,
            _size=self._size,
        )

    def restore_from_snapshot(self, ss):
        # Only restore attributes that already exist on the buffer, so a
        # malformed snapshot fails loudly instead of silently adding state.
        for key in ss.keys():
            assert hasattr(self, key)
            setattr(self, key, ss[key])
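

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module. It assumes gym
    # is installed and that utils.env_utils.get_dim returns the flattened
    # dimension of a Box space; the space shapes below are arbitrary.
    from gym.spaces import Box

    ob_space = Box(low=-1.0, high=1.0, shape=(11,))
    action_space = Box(low=-1.0, high=1.0, shape=(3,))
    buffer = ReplayBuffer(
        max_replay_buffer_size=1000,
        ob_space=ob_space,
        action_space=action_space,
    )

    # Add a handful of random transitions, then sample a training batch.
    for _ in range(10):
        buffer.add_sample(
            observation=ob_space.sample(),
            action=action_space.sample(),
            reward=0.0,
            next_observation=ob_space.sample(),
            terminal=False,
            env_info={},
        )
    batch = buffer.random_batch(batch_size=4)
    print({k: v.shape for k, v in batch.items()})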