# run.py
import sys
sys.path.insert(0,'src')
#@title Run to install MuJoCo and `dm_control`
import distutils.util
import subprocess
# if subprocess.run('nvidia-smi').returncode:
# raise RuntimeError(
# 'Cannot communicate with GPU. '
# 'Make sure you are using a GPU Colab runtime. '
# 'Go to the Runtime menu and select Choose runtime type.')
print('Installing dm_control...')
# !pip install -q dm_control==1.0.8
# # Configure dm_control to use the EGL rendering backend (requires GPU)
# %env MUJOCO_GL=osmesa
# %env PYOPENGL_PLATFORM=
# %env PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
print('Checking that the dm_control installation succeeded...')
try:
    from dm_control import suite
    env = suite.load('cartpole', 'swingup')
    pixels = env.physics.render()
except Exception as e:
    raise e from RuntimeError(
        'Something went wrong during installation. Check the shell output above '
        'for more information.\n'
        'If using a hosted Colab runtime, make sure you enable GPU acceleration '
        'by going to the Runtime menu and selecting "Choose runtime type".')
else:
    del pixels, suite
# !echo Installed dm_control $(pip show dm_control | grep -Po "(?<=Version: ).+")
# !rm -r "=1.0.8"
#@title All `dm_control` imports required for this tutorial
# The basic mujoco wrapper.
from dm_control import mujoco
# Access to enums and MuJoCo library functions.
from dm_control.mujoco.wrapper.mjbindings import enums
from dm_control.mujoco.wrapper.mjbindings import mjlib
# PyMJCF
from dm_control import mjcf
# Composer high level imports
from dm_control import composer
from dm_control.composer.observation import observable
from dm_control.composer import variation
# Imports for Composer tutorial example
from dm_control.composer.variation import distributions
from dm_control.composer.variation import noises
from dm_control.locomotion.arenas import floors
# Control Suite
from dm_control import suite
# Run through corridor example
from dm_control.locomotion.walkers import cmu_humanoid
from dm_control.locomotion.arenas import corridors as corridor_arenas
from dm_control.locomotion.tasks import corridors as corridor_tasks
# # Soccer
# from dm_control.locomotion import soccer
# Manipulation
from dm_control import manipulation
#@title Other imports and helper functions
# General
import copy
import os
import itertools
from IPython.display import clear_output
import numpy as np
import tree  # dm-tree; used below for tree.flatten and tree.map_structure
# Graphics-related
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML
import PIL.Image
# Internal loading of video libraries.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import Adam
# from torch.utils.tensorboard import SummaryWriter
# Import acme before the local project modules below to avoid import-order issues.
from acme import wrappers
from model import *
from utils import *
# Soft-Actor-Critic Model
from sac import *
from replay_memory import *
import argparse
import datetime
import random
import math
import pickle
# # Use svg backend for figure rendering
# %config InlineBackend.figure_format = 'svg'
# Font sizes
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12
plt.rc('font', size=SMALL_SIZE) # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
# Inline video helper function
if os.environ.get('COLAB_NOTEBOOK_TEST', False):
    # We skip video generation during tests, as it is quite expensive.
    display_video = lambda *args, **kwargs: None
else:
    def display_video(frames, framerate=30):
        height, width, _ = frames[0].shape
        dpi = 70
        orig_backend = matplotlib.get_backend()
        matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.
        fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
        matplotlib.use(orig_backend)  # Switch back to the original backend.
        ax.set_axis_off()
        ax.set_aspect('equal')
        ax.set_position([0, 0, 1, 1])
        im = ax.imshow(frames[0])
        def update(frame):
            im.set_data(frame)
            return [im]
        interval = 1000 / framerate
        anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                       interval=interval, blit=True, repeat=False)
        return HTML(anim.to_html5_video())
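# Example usage (a sketch, commented out; `_frames` is a hypothetical list of
# HxWx3 uint8 arrays, e.g. collected via env.physics.render()):
# _frames = [env.physics.render(camera_id=0, height=200, width=200) for _ in range(10)]
# display_video(_frames, framerate=30)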
# Seed numpy's global RNG so that cell outputs are deterministic. We also try to
# use RandomState instances that are local to a single cell wherever possible.
np.random.seed(42)
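# If reproducibility of the torch side is also desired, the global torch RNG
# could be seeded as well (an optional addition, commented out):
# torch.manual_seed(42)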
##### Environment wrappers #####
from dm_env import specs
class NormalizeActionSpecWrapper(wrappers.EnvironmentWrapper):
    """Rescales each action dimension to the range [-1, 1]."""

    def __init__(self, environment):
        super().__init__(environment)
        action_spec = environment.action_spec()
        self._scale = action_spec.maximum - action_spec.minimum
        self._offset = action_spec.minimum
        # Arrays of -1 and +1 with the same shape and dtype as the original bounds.
        minimum = action_spec.minimum * 0 - 1.
        maximum = action_spec.minimum * 0 + 1.
        self._action_spec = specs.BoundedArray(
            action_spec.shape,
            action_spec.dtype,
            minimum,
            maximum,
            name=action_spec.name)

    def _from_normal_actions(self, actions):
        actions = 0.5 * (actions + 1.0)  # actions are now in the range [0, 1]
        # Scale up to the range [minimum, maximum].
        return actions * self._scale + self._offset

    def step(self, action):
        action = self._from_normal_actions(action)
        return self._environment.step(action)

    def action_spec(self):
        return self._action_spec
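# A quick sanity check of the rescaling math above (a sketch, commented out;
# `_env`, `_w` and `_spec` are hypothetical names used only here):
# _env = suite.load('cheetah', 'run')
# _w = NormalizeActionSpecWrapper(_env)
# _spec = _env.action_spec()
# # A normalized action of -1 should map to spec.minimum, and +1 to spec.maximum.
# assert np.allclose(_w._from_normal_actions(-np.ones(_spec.shape)), _spec.minimum)
# assert np.allclose(_w._from_normal_actions(np.ones(_spec.shape)), _spec.maximum)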
class MujocoActionNormalizer(wrappers.EnvironmentWrapper):
    """Rescale actions to the [-1, 1] range for the MuJoCo physics engine.

    For control environments whose actions are bounded in [-1, 1], this adapter
    rescales actions to the desired range. This allows the actor network to
    output unscaled actions for better gradient dynamics.
    """

    def __init__(self, environment, rescale='clip'):
        super().__init__(environment)
        self._rescale = rescale

    def step(self, action):
        """Rescale actions to the [-1, 1] range before stepping the wrapped environment."""
        if self._rescale == 'tanh':
            scaled_actions = tree.map_structure(np.tanh, action)
        elif self._rescale == 'clip':
            scaled_actions = tree.map_structure(lambda a: np.clip(a, -1., 1.), action)
        else:
            raise ValueError('Unrecognized scaling option: %s' % self._rescale)
        return self._environment.step(scaled_actions)
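# Illustration of the two rescale modes (a sketch; `_a` is a hypothetical
# out-of-range action used only for this example):
# _a = np.array([-2.0, 0.5, 3.0])
# np.clip(_a, -1., 1.)  # 'clip' -> [-1., 0.5, 1.]
# np.tanh(_a)           # 'tanh' -> smooth squashing into (-1, 1)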
from IPython.display import display
#@title Loading and simulating a `suite` task{vertical-output: true}
# Load the environment
# random_state = np.random.RandomState(42)
# env = suite.load('hopper', 'stand', task_kwargs={'random': random_state})
# Simulate an episode, acting with the trained agent, and record frames.
def visualize(duration=10, save=False, name="vids.mp4"):
    frames = []
    ticks = []
    rewards = []
    observations = []

    time_step = env.reset()
    while env.physics.data.time < duration:
        state = get_flat_obs(time_step)
        action = agent.select_action(state)
        time_step = env.step(action)

        camera0 = env.physics.render(camera_id=0, height=400, width=400)
        camera1 = env.physics.render(camera_id=1, height=400, width=400)
        frames.append(np.hstack((camera0, camera1)))
        rewards.append(time_step.reward)
        observations.append(copy.deepcopy(time_step.observation))
        ticks.append(env.physics.data.time)

    html_video = display_video(frames, framerate=1. / env.control_timestep())

    # Show video and plot reward and observations.
    num_sensors = len(time_step.observation)
    _, ax = plt.subplots(1 + num_sensors, 1, sharex=True, figsize=(4, 8))
    ax[0].plot(ticks, rewards)
    ax[0].set_ylabel('reward')
    ax[-1].set_xlabel('time')
    for i, key in enumerate(time_step.observation):
        data = np.asarray([observations[j][key] for j in range(len(observations))])
        ax[i + 1].plot(ticks, data, label=key)
        ax[i + 1].set_ylabel(key)

    if save:
        save_video(frames, video_name=name)
    return html_video
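# Example call (commented out; valid once `env` and `agent` below exist):
# visualize(duration=5, save=True, name="cheetah_run.mp4")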
# Load the environment.
env = suite.load(domain_name="cheetah", task_name="run")
# Add wrappers onto the environment.
env = NormalizeActionSpecWrapper(env)
env = MujocoActionNormalizer(environment=env, rescale='clip')
env = wrappers.SinglePrecisionWrapper(env)
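# Sanity check (a sketch, commented out), assuming the wrappers behave as
# described above: every action dimension should now be bounded by [-1, 1].
# _spec = env.action_spec()
# assert np.all(_spec.minimum == -1.) and np.all(_spec.maximum == 1.)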
class Args:
    env_name = 'whatever'
    policy = 'Gaussian'
    eval = True
    gamma = 0.99
    tau = 0.005
    lr = 0.0003
    alpha = 0.2
    automatic_entropy_tuning = True
    seed = 42
    batch_size = 256
    num_steps = 1000000
    hidden_size = 256
    updates_per_step = 1
    start_steps = 10000
    target_update_interval = 1
    replay_size = 1000000
    # Use CUDA to speed up training/inference.
    cuda = True

args = Args()
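# Note: `cuda = True` presumably selects a GPU device inside the local SAC
# implementation; set it to False if no GPU is available.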
# Get the dimensionality of the observation_spec after flattening.
flat_obs = tree.flatten(env.observation_spec())
# Combine all the shapes; scalar specs have an empty shape and contribute one dimension.
obs_dim = 0
for i in flat_obs:
    try:
        obs_dim += i.shape[0]
    except IndexError:
        obs_dim += 1
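# An equivalent one-liner (a sketch) that also handles specs with more than one
# dimension, since np.prod of an empty shape is 1:
# obs_dim = sum(int(np.prod(item.shape)) for item in flat_obs)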
# Set up the agent, using the Soft Actor-Critic model.
agent = SAC(obs_dim, env.action_spec(), args)
# Load checkpoint - UPLOAD YOUR FILE HERE!
model_path = 'src/models/sac_checkpoint_cheetah_123456_10000'
agent.load_checkpoint(model_path, evaluate=True)
# Pull out the policy network.
model = agent.policy
# Set up the hook dict for recording intermediate activations.
hook_dict = init_hook_dict(model)
# Register a forward hook on every linear layer.
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        module.register_forward_hook(recordtodict_hook(name=name, hook_dict=hook_dict))
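# `init_hook_dict` and `recordtodict_hook` come from the local utils/model
# modules. A minimal sketch of what such a recording hook plausibly looks like
# (hypothetical, for illustration only):
# def recordtodict_hook(name, hook_dict):
#     def hook(module, inputs, output):
#         hook_dict[name].append(output.detach().cpu())
#     return hook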
print(model)
print(f"Successfully Loaded the checkpoint from {model_path}")