{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "5.custom_gym_env.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/Alian3785/disciples-ai/blob/master/%D0%95%D1%89%D1%91%20%D0%B8%D0%B7%20%D0%B1%D0%BB%D0%BE%D0%BA%D0%BD%D0%BE%D1%82%D0%B0%20stable%20baselines%203\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AoxOjIlOImwx"
},
"source": [
"# Stable Baselines Tutorial - Creating a custom Gym environment\n",
"\n",
"Github repo: https://github.com/araffin/rl-tutorial-jnrr19\n",
"\n",
"Stable-Baselines: https://github.com/hill-a/stable-baselines\n",
"\n",
"Documentation: https://stable-baselines.readthedocs.io/en/master/\n",
"\n",
"RL Baselines zoo: https://github.com/araffin/rl-baselines-zoo\n",
"\n",
"\n",
"## Introduction\n",
"\n",
"In this notebook, you will learn how to use your own environment following the OpenAI Gym interface.\n",
"Once it is done, you can easily use any compatible (depending on the action space) RL algorithm from Stable Baselines on that environment.\n",
"\n",
"## Install Dependencies and Stable Baselines Using Pip\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Sp8rSS4DIhEV"
},
"source": [
"# Stable Baselines only supports tensorflow 1.x for now\n",
"\n",
"!pip install stable-baselines3[extra]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "rzevZcgmJmhi"
},
"source": [
"## First steps with the gym interface\n",
"\n",
"As you have noticed in the previous notebooks, an environment that follows the gym interface is quite simple to use.\n",
"It provides to this user mainly three methods:\n",
"- `reset()` called at the beginning of an episode, it returns an observation\n",
"- `step(action)` called to take an action with the environment, it returns the next observation, the immediate reward, whether the episode is over and additional information\n",
"- (Optional) `render(method='human')` which allow to visualize the agent in action. Note that graphical interface does not work on google colab, so we cannot use it directly (we have to rely on `method='rbg_array'` to retrieve an image of the scene\n",
"\n",
"Under the hood, it also contains two useful properties:\n",
"- `observation_space` which one of the gym spaces (`Discrete`, `Box`, ...) and describe the type and shape of the observation\n",
"- `action_space` which is also a gym space object that describes the action space, so the type of action that can be taken\n",
"\n",
"The best way to learn about gym spaces is to look at the [source code](https://github.com/openai/gym/tree/master/gym/spaces), but you need to know at least the main ones:\n",
"- `gym.spaces.Box`: A (possibly unbounded) box in $R^n$. Specifically, a Box represents the Cartesian product of n closed intervals. Each interval has the form of one of [a, b], (-oo, b], [a, oo), or (-oo, oo). Example: A 1D-Vector or an image observation can be described with the Box space.\n",
"```python\n",
"# Example for using image as input:\n",
"observation_space = spaces.Box(low=0, high=255, shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)\n",
"``` \n",
"\n",
"- `gym.spaces.Discrete`: A discrete space in $\\{ 0, 1, \\dots, n-1 \\}$\n",
" Example: if you have two actions (\"left\" and \"right\") you can represent your action space using `Discrete(2)`, the first action will be 0 and the second 1.\n",
"\n",
"\n",
"\n",
"[Documentation on custom env](https://stable-baselines.readthedocs.io/en/master/guide/custom_env.html)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "I98IKKyNJl6K"
},
"source": [
"import gym\n",
"\n",
"env = gym.make(\"CartPole-v1\")\n",
"\n",
"# Box(4,) means that it is a Vector with 4 components\n",
"print(\"Observation space:\", env.observation_space)\n",
"print(\"Shape:\", env.observation_space.shape)\n",
"# Discrete(2) means that there is two discrete actions\n",
"print(\"Action space:\", env.action_space)\n",
"\n",
"# The reset method is called at the beginning of an episode\n",
"obs = env.reset()\n",
"# Sample a random action\n",
"action = env.action_space.sample()\n",
"print(\"Sampled action:\", action)\n",
"obs, reward, done, info = env.step(action)\n",
"# Note the obs is a numpy array\n",
"# info is an empty dict for now but can contain any debugging info\n",
"# reward is a scalar\n",
"print(obs.shape, reward, done, info)\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "RqxatIwPOXe_"
},
"source": [
"## Gym env skeleton\n",
"\n",
"In practice this is how a gym environment looks like.\n",
"Here, we have implemented a simple grid world were the agent must learn to go always left."
]
},
{
"cell_type": "code",
"metadata": {
"id": "rYzDXA9vJfz1"
},
"source": [
"import numpy as np\n",
"import gym\n",
"from gym import spaces\n",
"\n",
"\n",
"class GoLeftEnv(gym.Env):\n",
" \"\"\"\n",
" Custom Environment that follows gym interface.\n",
" This is a simple env where the agent must learn to go always left. \n",
" \"\"\"\n",
" # Because of google colab, we cannot implement the GUI ('human' render mode)\n",
" metadata = {'render.modes': ['console']}\n",
" # Define constants for clearer code\n",
" LEFT = 0\n",
" RIGHT = 1\n",
" HIGH = 2\n",
" LOW = 3\n",
" ATTACK = 4\n",
" def __init__(self, grid_size=10):\n",
" super(GoLeftEnv, self).__init__()\n",
"\n",
" # Size of the 1D-grid\n",
" self.grid_size = grid_size\n",
" # Initialize the agent at the right of the grid\n",
" self.agent_pos = grid_size - 1\n",
"\n",
" # Define action and observation space\n",
" # They must be gym.spaces objects\n",
" # Example when using discrete actions, we have two: left and right\n",
" n_actions = 5\n",
" self.action_space = spaces.Discrete(n_actions)\n",
" # The observation will be the coordinate of the agent\n",
" # this can be described both by Discrete and Box space\n",
" self.observation_space = spaces.Box(low=0, high=100,\n",
" shape=(3,), dtype=np.float32)\n",
"\n",
" def reset(self):\n",
" \"\"\"\n",
" Important: the observation must be a numpy array\n",
" :return: (np.array) \n",
" \"\"\"\n",
" # Initialize the agent at the right of the grid\n",
" self.agent_posX = 1\n",
" self.agent_posY = 1\n",
" self.enemyhealth = 10\n",
" # here we convert to float32 to make it more general (in case we want to use continuous actions)\n",
" return np.array([self.agent_posX, self.agent_posY, self.enemyhealth]).astype(np.float32)\n",
"\n",
" def step(self, action):\n",
" if action == self.LEFT:\n",
" if (self.agent_posX == 1):\n",
" self.agent_posX = 1\n",
" else: \n",
" self.agent_posX -= 1\n",
" elif action == self.RIGHT:\n",
" if (self.agent_posX == 10):\n",
" self.agent_posX = 10\n",
" else: \n",
" self.agent_posX += 1\n",
" elif action == self.LOW:\n",
" if (self.agent_posY == 1):\n",
" self.agent_posY = 1\n",
" else: \n",
" self.agent_posY -= 1\n",
" elif action == self.HIGH:\n",
" if (self.agent_posY == 10):\n",
" self.agent_posY = 10\n",
" else: \n",
" self.agent_posY += 1\n",
" elif action == self.ATTACK:\n",
" if (self.agent_posX == 7 and self.agent_posY == 6):\n",
" self.enemyhealth = 0 \n",
" else:\n",
" raise ValueError(\"Received invalid action={} which is not part of the action space\".format(action))\n",
"\n",
" # Are we at the left of the grid?\n",
" done = bool(self.enemyhealth == 0)\n",
"\n",
" # Null reward everywhere except when reaching the goal (left of the grid)\n",
" reward = 100 if self.enemyhealth == 0 else -1\n",
"\n",
" # Optionally we can pass additional info, we are not using that for now\n",
" info = {}\n",
"\n",
" return np.array([self.agent_posX, self.agent_posY, self.enemyhealth]).astype(np.float32), reward, done, info\n",
"\n",
" def render(self, mode='console'):\n",
" if mode != 'console':\n",
" raise NotImplementedError()\n",
" # agent is represented as a cross, rest as a dot\n",
" print(\"Координата X\")\n",
" print(self.agent_posX)\n",
" print(\"Координата Y\")\n",
" print(self.agent_posY)\n",
"\n",
" def close(self):\n",
" pass\n",
" "
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zy5mlho1-Ine"
},
"source": [
"### Validate the environment\n",
"\n",
"Stable Baselines provides a [helper](https://stable-baselines.readthedocs.io/en/master/common/env_checker.html) to check that your environment follows the Gym interface. It also optionally checks that the environment is compatible with Stable-Baselines (and emits warning if necessary)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "9DOpP_B0-LXm"
},
"source": [
"from stable_baselines.common.env_checker import check_env"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1CcUVatq-P0l"
},
"source": [
"env = GoLeftEnv()\n",
"# If the environment don't follow the interface, an error will be thrown\n",
"check_env(env, warn=True)"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "eJ3khFtkSE0g"
},
"source": [
"### Testing the environment"
]
},
{
"cell_type": "code",
"metadata": {
"id": "i62yf2LvSAYY"
},
"source": [
"env = GoLeftEnv(grid_size=10)\n",
"\n",
"obs = env.reset()\n",
"env.render()\n",
"\n",
"print(env.observation_space)\n",
"print(env.action_space)\n",
"print(env.action_space.sample())\n",
"\n",
"GO_LEFT = 0\n",
"# Hardcoded best agent: always go left!\n",
"n_steps = 20\n",
"for step in range(n_steps):\n",
" print(\"Step {}\".format(step + 1))\n",
" obs, reward, done, info = env.step(GO_LEFT)\n",
" print('obs=', obs, 'reward=', reward, 'done=', done)\n",
" env.render()\n",
" if done:\n",
" print(\"Goal reached!\", \"reward=\", reward)\n",
" break"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pv1e1qJETfHU"
},
"source": [
"### Try it with Stable-Baselines\n",
"\n",
"Once your environment follow the gym interface, it is quite easy to plug in any algorithm from stable-baselines"
]
},
{
"cell_type": "code",
"metadata": {
"id": "PQfLBE28SNDr"
},
"source": [
"from stable_baselines3 import DQN, A2C, HER, TD3, PPO\n",
"from stable_baselines3.common.env_util import make_vec_env\n",
"\n",
"# Instantiate the env\n",
"env = GoLeftEnv(grid_size=10)\n",
"# wrap it\n",
"env = make_vec_env(lambda: env, n_envs=1)"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zRV4Q7FVUKB6"
},
"source": [
"# Train the agent\n",
"model = DQN('MlpPolicy', env, verbose=1).learn(400000)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BJbeiF0RUN-p"
},
"source": [
"# Test the trained agent\n",
"obs = env.reset()\n",
"n_steps = 100\n",
"for step in range(n_steps):\n",
" action, _ = model.predict(obs, deterministic=True)\n",
" print(\"Step {}\".format(step + 1))\n",
" print(\"Action: \", action)\n",
" obs, reward, done, info = env.step(action)\n",
" print('obs=', obs, 'reward=', reward, 'done=', done)\n",
" env.render(mode='console')\n",
" if done:\n",
" # Note that the VecEnv resets automatically\n",
" # when a done signal is encountered\n",
" print(\"Goal reached!\", \"reward=\", reward)\n",
" break"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "jOggIa9sU--b"
},
"source": [
"## It is your turn now, be creative!\n",
"\n",
"As an exercise, that's now your turn to build a custom gym environment.\n",
"There is no constrain about what to do, be creative! (but not too creative, there is not enough time for that)\n",
"\n",
"If you don't have any idea, here is is a list of the environment you can implement:\n",
"- Transform the discrete grid world to a continuous one, you will need to change a bit the logic and the action space\n",
"- Create a 2D grid world and add walls\n",
"- Create a tic-tac-toe game\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "lBDp4Pm-Uh4D"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}