-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathvector_env.py
32 lines (24 loc) · 873 Bytes
/
vector_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
class DummyVectorEnv:
    """Step several environments one after another behind a batched interface.

    Each wrapped environment is expected to provide ``reset()`` returning an
    observation and ``step(action)`` returning a 4-tuple
    ``(obs, rew, done, info)``. Attributes not defined on this wrapper are
    looked up on the first wrapped environment.
    """

    def __init__(self, envs):
        """Store the environments to drive.

        Args:
            envs: Indexable, sized collection of environment instances.
        """
        self.envs = envs
        self.num = len(envs)

    @staticmethod
    def _concat_obs(obs_list):
        """Stack per-environment observations along a new leading axis.

        Args:
            obs_list: Sequence of equally-shaped observations.

        Returns:
            Array whose first dimension indexes the environments.
        """
        return np.stack(obs_list, axis=0)

    def reset(self):
        """Reset every environment and return the stacked observations."""
        return self._concat_obs([env.reset() for env in self.envs])

    def step(self, act):
        """Advance every environment by one step.

        An environment that reports ``done`` is reset immediately, so the
        observation returned for it is the first observation of its next
        episode, while its reward and done flag still describe the step that
        ended the previous one.

        Args:
            act: Indexable collection where ``act[i]`` is the action for
                ``self.envs[i]``.

        Returns:
            Tuple ``(obs, rew, done, None)``: stacked observations, an array
            of rewards, an array of done flags, and ``None`` in place of the
            per-environment info objects, which are discarded.
        """
        obs_list = []
        rew_list = []
        done_list = []
        for env_id, env in enumerate(self.envs):
            obs, rew, done, _ = env.step(act[env_id])
            if done:
                # Auto-reset so the caller never has to restart a finished env.
                obs = env.reset()
            obs_list.append(obs)
            rew_list.append(rew)
            done_list.append(done)
        return (
            self._concat_obs(obs_list),
            np.array(rew_list),
            np.array(done_list),
            None,
        )

    def __getattr__(self, name):
        """Delegate unknown attribute lookups to the first environment.

        Python calls this only after normal lookup has failed.

        Raises:
            AttributeError: If ``envs`` has not been set on this instance, or
                if there is no wrapped environment to delegate to.
        """
        if name == "envs":
            # ``envs`` is absent on instances created without __init__ (e.g.
            # during copy/unpickle). Reading ``self.envs`` here would re-enter
            # __getattr__ forever, so fail the lookup instead.
            raise AttributeError(name)
        try:
            first_env = self.envs[0]
        except IndexError:
            # Attribute lookup must fail with AttributeError so hasattr() and
            # getattr(obj, name, default) keep working with no environments.
            raise AttributeError(name) from None
        return getattr(first_env, name)