MyRemoteVectorEnv.py
from typing import Tuple, Callable, Optional
from collections import OrderedDict
import gym
import torch
import numpy as np
import ray
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import EnvType, EnvID, MultiAgentDict
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
from stable_baselines3.common.vec_env.util import dict_to_obs
from MyDummyVecEnv import MyDummyVecEnv

@PublicAPI
class MyRemoteVectorEnv(BaseEnv):
    """Vector env that executes envs in remote workers.

    This provides dynamic batching of inference as observations are returned
    from the remote simulator actors. Both single and multi-agent child envs
    are supported, and envs can be stepped synchronously or asynchronously.

    You shouldn't need to instantiate this class directly. It's automatically
    inserted when you use the `remote_worker_envs` option for Trainers.

    A usage sketch of the poll()/send_actions() loop appears at the bottom of
    this file.
    """
    @property
    def observation_space(self):
        return self._observation_space

    @observation_space.setter
    def observation_space(self, value):
        self._observation_space = value
    def __init__(self, make_env: Callable[[int], EnvType], num_workers: int,
                 env_per_worker: int,
                 observation_space: Optional[gym.spaces.Space],
                 device: torch.device):
        self.make_local_env = make_env
        self.num_workers = num_workers
        self.env_per_worker = env_per_worker
        self.num_envs = num_workers * env_per_worker
        self.poll_timeout = None
        self.actors = None  # lazy init
        self.pending = None  # lazy init
        self.observation_space = observation_space
        self.keys = []
        shapes = {}
        dtypes = {}
        for key, box in observation_space.items():
            self.keys.append(key)
            shapes[key] = box.shape
            dtypes[key] = box.dtype
        self.device = device
        # Pre-allocated result buffers. Worker i owns the slice
        # [i * env_per_worker, (i + 1) * env_per_worker) of each buffer.
        self.buf_obs = OrderedDict(
            [(k, torch.zeros((self.num_envs,) + tuple(shapes[k]),
                             dtype=torch.float, device=self.device))
             for k in self.keys])
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
    def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
        # Copy one worker's batch of observations into its slice of the
        # on-device observation buffer.
        for key in self.keys:
            self.buf_obs[key][env_idx * self.env_per_worker:(env_idx + 1) * self.env_per_worker] = \
                torch.from_numpy(obs[key]).to(self.device, non_blocking=True)
    def poll(self) -> Tuple[VecEnvObs, np.ndarray, np.ndarray, list]:
        if self.actors is None:
            # Lazily create the remote actors on the first poll.
            def make_remote_env(i):
                return _RemoteSingleAgentEnv.remote(
                    self.make_local_env, i, self.env_per_worker)

            self.actors = [make_remote_env(i) for i in range(self.num_workers)]
        if self.pending is None:
            self.pending = {a.reset.remote(): a for a in self.actors}

        ready = []
        # Wait until at least one worker has results ready. With
        # poll_timeout=None this blocks until every pending worker returns.
        while not ready:
            ready, _ = ray.wait(
                list(self.pending),
                num_returns=len(self.pending),
                timeout=self.poll_timeout)

        # Write each finished worker's results into its slice of the buffers.
        for obj_ref in ready:
            actor = self.pending.pop(obj_ref)
            env_id = self.actors.index(actor)
            ob, rew, done, info = ray.get(obj_ref)
            self._save_obs(env_id, ob)
            self.buf_rews[env_id * self.env_per_worker:(env_id + 1) * self.env_per_worker] = rew
            self.buf_dones[env_id * self.env_per_worker:(env_id + 1) * self.env_per_worker] = done
            self.buf_infos[env_id * self.env_per_worker:(env_id + 1) * self.env_per_worker] = info
        return (self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos)
    def _obs_from_buf(self) -> VecEnvObs:
        return dict_to_obs(self.observation_space, self.buf_obs)
    @PublicAPI
    def send_actions(self, action_list) -> None:
        # Dispatch each worker's slice of the flat action batch and remember
        # the returned object ref so the next poll() can collect the result.
        for worker_id in range(self.num_workers):
            actions = action_list[worker_id * self.env_per_worker:(worker_id + 1) * self.env_per_worker]
            actor = self.actors[worker_id]
            obj_ref = actor.step.remote(actions)
            self.pending[obj_ref] = actor
    @PublicAPI
    def try_reset(self,
                  env_id: Optional[EnvID] = None) -> Optional[MultiAgentDict]:
        # Here `env_id` indexes a remote worker; each worker resets all of
        # its sub-envs at once.
        actor = self.actors[env_id]
        obj_ref = actor.reset.remote()
        self.pending[obj_ref] = actor
        return ASYNC_RESET_RETURN
    @PublicAPI
    def stop(self) -> None:
        if self.actors is not None:
            for actor in self.actors:
                actor.__ray_terminate__.remote()

@ray.remote(num_cpus=1)
class _RemoteSingleAgentEnv:
    """Wrapper class for making a gym env a remote actor."""

    def __init__(self, make_env, i, env_per_worker):
        self.env_per_worker = env_per_worker
        # Bind k as a default arg so each sub-env gets its own global index.
        self.env = MyDummyVecEnv(
            [lambda k=k: make_env((i * env_per_worker) + k) for k in range(env_per_worker)])

    def reset(self):
        # One info dict per sub-env keeps buf_infos' length fixed in poll().
        return self.env.reset(), 0.0, False, [{} for _ in range(self.env_per_worker)]

    def step(self, actions):
        return self.env.step(actions)
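
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original API): a minimal driver
# loop, assuming `make_env(rank)` builds a single env whose Dict observation
# space is `obs_space`, and `policy(obs)` returns one action per sub-env.
#
#     env = MyRemoteVectorEnv(make_env, num_workers=4, env_per_worker=8,
#                             observation_space=obs_space,
#                             device=torch.device("cuda"))
#     obs, rews, dones, infos = env.poll()      # first poll resets all workers
#     for _ in range(num_steps):
#         env.send_actions(policy(obs))         # one action per sub-env
#         obs, rews, dones, infos = env.poll()  # gather results into buffers
#     env.stop()                                # terminate the remote actors
# ---------------------------------------------------------------------------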