# buffer.py
import numpy as np
from collections import deque
import random
class BasicBuffer_a:
    """Fixed-capacity circular buffer of transitions stored in NumPy arrays.

    Each transition is ``(obs, act, rew, next_obs, done)``. Storage is
    preallocated as float32 arrays; once ``max_size`` transitions have been
    pushed, new ones overwrite the oldest slot.

    Attributes:
        ptr: Index of the slot the next ``push`` writes to.
        size: Number of valid transitions currently stored.
        max_size: Capacity given at construction.
    """

    def __init__(self, size, obs_dim, act_dim):
        """Allocate zero-filled float32 storage.

        Args:
            size: Maximum number of transitions to keep.
            obs_dim: Length of one observation row.
            act_dim: Length of one action row.
        """
        self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.rews_buf = np.zeros([size], dtype=np.float32)
        self.done_buf = np.zeros([size], dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

    def push(self, obs, act, rew, next_obs, done):
        """Write one transition at ``ptr`` and advance the write position.

        Args:
            obs: Observation, assignable to a row of length ``obs_dim``.
            act: Action, assignable to a row of length ``act_dim``.
            rew: Reward; a scalar or any array-like holding exactly one value.
            next_obs: Next observation, same layout as ``obs``.
            done: Episode-termination flag, stored as float32.

        Raises:
            ValueError: If ``rew`` holds more than one value.
        """
        slot = self.ptr
        self.obs1_buf[slot] = obs
        self.obs2_buf[slot] = next_obs
        self.acts_buf[slot] = act
        # Store a genuine scalar: assigning a 1-element array to a single
        # array element relies on ndim>0 -> scalar conversion, which NumPy
        # deprecated in 1.25.
        self.rews_buf[slot] = np.asarray(rew, dtype=np.float32).item()
        self.done_buf[slot] = done
        # Wrap around so the oldest entry is overwritten once full.
        self.ptr = (slot + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size=32):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            Tuple ``(obs, act, rew, next_obs, done)`` of float32 arrays with
            shapes ``(batch, obs_dim)``, ``(batch, act_dim)``, ``(batch, 1)``,
            ``(batch, obs_dim)`` and ``(batch,)``.

        Raises:
            ValueError: If nothing has been pushed yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty buffer")
        # Only indices below ``size`` hold real data.
        idxs = np.random.randint(0, self.size, size=batch_size)
        return (
            self.obs1_buf[idxs],
            self.acts_buf[idxs],
            self.rews_buf[idxs].reshape(-1, 1),
            self.obs2_buf[idxs],
            self.done_buf[idxs],
        )
class BasicBuffer_b:
    """Bounded FIFO buffer of transitions backed by ``collections.deque``.

    Each entry is the tuple ``(state, action, np.array([reward]),
    next_state, done)``. When the deque is full, appending a new entry
    discards the oldest one.

    Attributes:
        max_size: Capacity given at construction.
        buffer: The underlying deque of stored transitions.
        size: Number of transitions currently stored.
    """

    def __init__(self, size, obs_dim=None, act_dim=None):
        """Create an empty buffer holding at most ``size`` transitions.

        Args:
            size: Maximum number of transitions to keep.
            obs_dim: Unused by this class.
            act_dim: Unused by this class.
        """
        self.max_size = size
        self.buffer = deque(maxlen=size)
        self.size = 0

    def push(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest when at capacity.

        The reward is wrapped in a one-element NumPy array before storage.
        """
        experience = (state, action, np.array([reward]), next_state, done)
        self.buffer.append(experience)
        # The deque enforces the bound, so its length is the true count.
        self.size = len(self.buffer)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Because draws are with replacement, ``batch_size`` may exceed the
        number of stored transitions.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            Tuple ``(state_batch, action_batch, reward_batch,
            next_state_batch, done_batch)`` of Python lists, each of length
            ``batch_size``, aligned by position.

        Raises:
            ValueError: If the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty buffer")
        idxs = np.random.randint(0, len(self.buffer), size=batch_size)
        state_batch, action_batch, reward_batch = [], [], []
        next_state_batch, done_batch = [], []
        columns = (
            state_batch,
            action_batch,
            reward_batch,
            next_state_batch,
            done_batch,
        )
        for idx in idxs:
            # Scatter the five fields of the chosen transition into their lists.
            for column, value in zip(columns, self.buffer[idx]):
                column.append(value)
        return columns