Commit

Improve collector (#125)
* remove multibuf

* reward_metric

* make fields an empty Batch rather than None after reset

* many fixes and refactor
Co-authored-by: Trinkle23897 <[email protected]>
youkaichao authored and Trinkle23897 committed Jul 13, 2020
1 parent 5599a6d commit 26fb874
Showing 3 changed files with 183 additions and 177 deletions.
test/base/env.py (38 changes: 22 additions & 16 deletions)
@@ -1,19 +1,34 @@
-import time
 import gym
+import time
 from gym.spaces.discrete import Discrete


 class MyTestEnv(gym.Env):
-    def __init__(self, size, sleep=0, dict_state=False):
+    """This is a "going right" task. The task is to go right ``size`` steps.
+    """
+
+    def __init__(self, size, sleep=0, dict_state=False, ma_rew=0):
         self.size = size
         self.sleep = sleep
         self.dict_state = dict_state
+        self.ma_rew = ma_rew
         self.action_space = Discrete(2)
         self.reset()

     def reset(self, state=0):
         self.done = False
         self.index = state
         return self._get_dict_state()

+    def _get_reward(self):
+        """Generate a non-scalar reward if ma_rew is True."""
+        x = int(self.done)
+        if self.ma_rew > 0:
+            return [x] * self.ma_rew
+        return x
+
+    def _get_dict_state(self):
+        """Generate a dict_state if dict_state is True."""
+        return {'index': self.index} if self.dict_state else self.index
+
     def step(self, action):
@@ -23,22 +38,13 @@ def step(self, action):
             time.sleep(self.sleep)
         if self.index == self.size:
             self.done = True
-            if self.dict_state:
-                return {'index': self.index}, 0, True, {}
-            else:
-                return self.index, 0, True, {}
+            return self._get_dict_state(), self._get_reward(), self.done, {}
         if action == 0:
             self.index = max(self.index - 1, 0)
-            if self.dict_state:
-                return {'index': self.index}, 0, False, {'key': 1, 'env': self}
-            else:
-                return self.index, 0, False, {}
+            return self._get_dict_state(), self._get_reward(), self.done, \
+                {'key': 1, 'env': self} if self.dict_state else {}
         elif action == 1:
             self.index += 1
             self.done = self.index == self.size
-            if self.dict_state:
-                return {'index': self.index}, int(self.done), self.done, \
-                    {'key': 1, 'env': self}
-            else:
-                return self.index, int(self.done), self.done, \
-                    {'key': 1, 'env': self}
+            return self._get_dict_state(), self._get_reward(), \
+                self.done, {'key': 1, 'env': self}
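
A quick illustration (not part of the commit) of what the refactored env produces when ma_rew is set: the reward becomes a list with one entry per agent, which is what the reward_metric hook passed to the Collector in the tests below is meant to reduce.

env = MyTestEnv(size=5, ma_rew=4)
obs = env.reset()
obs, rew, done, info = env.step(1)   # one step to the right
print(rew)                           # [0, 0, 0, 0]; becomes [1, 1, 1, 1] once index reaches size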
test/base/test_collector.py (50 changes: 46 additions & 4 deletions)
@@ -27,16 +27,16 @@ def learn(self):

 def preprocess_fn(**kwargs):
     # modify info before adding into the buffer
-    if kwargs.get('info', None) is not None:
+    # if info is not provided from env, it will be a ``Batch()``.
+    if not kwargs.get('info', Batch()).is_empty():
         n = len(kwargs['obs'])
         info = kwargs['info']
         for i in range(n):
             info[i].update(rew=kwargs['rew'][i])
         return {'info': info}
-        # or
-        # return Batch(info=info)
+        # or: return Batch(info=info)
     else:
-        return {}
+        return Batch()


 class Logger(object):
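
This hunk reflects the commit note about empty-Batch fields: the collector now hands preprocess_fn an empty Batch() (rather than None) when the env supplies no info, so the hook checks emptiness instead of comparing with None. A minimal standalone sketch of that check, assuming only tianshou's Batch class, which the diff itself imports:

from tianshou.data import Batch

info = Batch()            # what the hook now receives when the env provides no info
print(info.is_empty())    # True, so preprocess_fn returns Batch() instead of {}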
@@ -119,6 +119,48 @@ def test_collector_with_dict_state():
     print(batch['obs_next']['index'])


+def test_collector_with_ma():
+    def reward_metric(x):
+        return x.sum()
+    env = MyTestEnv(size=5, sleep=0, ma_rew=4)
+    policy = MyPolicy()
+    c0 = Collector(policy, env, ReplayBuffer(size=100),
+                   preprocess_fn, reward_metric=reward_metric)
+    r = c0.collect(n_step=3)['rew']
+    assert np.asanyarray(r).size == 1 and r == 0.
+    r = c0.collect(n_episode=3)['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
+               for i in [2, 3, 4, 5]]
+    envs = VectorEnv(env_fns)
+    c1 = Collector(policy, envs, ReplayBuffer(size=100),
+                   preprocess_fn, reward_metric=reward_metric)
+    r = c1.collect(n_step=10)['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    batch = c1.sample(10)
+    print(batch)
+    c0.buffer.update(c1.buffer)
+    obs = [
+        0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
+        0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
+        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]
+    assert np.allclose(c0.buffer[:len(c0.buffer)].obs, obs)
+    rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
+           0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
+           0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]
+    assert np.allclose(c0.buffer[:len(c0.buffer)].rew,
+                       [[x] * 4 for x in rew])
+    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
+                   preprocess_fn, reward_metric=reward_metric)
+    r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    batch = c2.sample(10)
+    print(batch['obs_next'])
+
+
 if __name__ == '__main__':
     test_collector()
     test_collector_with_dict_state()
+    test_collector_with_ma()
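
The assertions above rely on reward_metric collapsing the 4-entry per-agent reward into a single scalar before the collector reports it. A standalone sketch of that reduction in plain numpy, outside the Collector and for illustration only:

import numpy as np

def reward_metric(x):
    return x.sum()

episode_rew = np.array([1., 1., 1., 1.])   # ma_rew=4: each agent receives 1 on success
print(reward_metric(episode_rew))          # 4.0, the value the tests expect from collect()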