-
Notifications
You must be signed in to change notification settings - Fork 0
/
flatquad_sweep_fig.py
executable file
·290 lines (212 loc) · 9.72 KB
/
flatquad_sweep_fig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
#!/usr/bin/env python
import diffrax
import flax
import jax
import jax.numpy as np
from jax import config
config.update("jax_enable_x64", True)
import gzip
import os
import pickle
import warnings
import datetime
from functools import partial
import ipdb
import matplotlib
import matplotlib.pyplot as pl
import numpy as onp
import scipy
import subprocess
import tqdm
import wandb
import sys
import pontryagin_utils
from fig_config import *
from misc import *
from flatquad_landing_experiment import base_algo_params, define_problem_params
# line fig plots, but for an entire sweep.
def pull_runs(sysname, sweep_name, T=np.inf):
    """Look up the wandb runs of one sweep and rsync their data from euler.

    Args:
        sysname: the system name used in problem_params; selects the wandb
            project ``levelsets_{sysname}``.
        sweep_name: the sweep name as given by algo_params['sweep_name'].
        T: max age of data we include, in hours. NOTE: the age filter below
            is currently disabled, so this argument has no effect.

    Returns:
        Whatever the wandb api returns for the filtered run query.
    """
    # 1. use wandb api to get all runs matching that sweep name
    print(f'fetching runs for sweep {sweep_name} from wandb...')
    api = wandb.Api()
    runs = api.runs(
        path=f'mbjd-projects/levelsets_{sysname}',
        filters={'config.sweep_name': sweep_name},
    )

    if False:
        # Disabled: drop runs that are too old. Done client-side because no
        # way of expressing this with the wandb filters above was found.
        print(f'got {len(runs)} runs')
        cutoff_datetime = datetime.datetime.now() - datetime.timedelta(hours=T)

        def is_recent(r):
            run_datetime = datetime.datetime.strptime(r.createdAt, '%Y-%m-%dT%H:%M:%S')
            return run_datetime > cutoff_datetime

        runs = [r for r in runs if is_recent(r)]
        print(f'...{len(runs)} of which are recent enough')

    # 2. get the runs from euler if not present already.
    print('pulling output data from euler (only current sweep)...')

    # The command is passed as an argument list (no shell). The filter
    # patterns therefore must NOT carry shell quotes: without a shell nobody
    # strips them, and rsync would see the quote characters as part of the
    # pattern. (Insert '--dry-run' after 'rsync' to preview the transfer.)
    # NOTE(review): the remote host strings below look garbled (possibly
    # e-mail obfuscation of user@host) -- confirm against the real file.
    run_data_cmd = ['rsync']
    run_data_cmd.extend(f'--include={r.id}' for r in runs)
    run_data_cmd.extend([
        '--include=*.msgpack.gz',
        '--exclude=*',
        '-av',
        '--progress',
        '[email protected]:/cluster/scratch/dbalduin/flatquad_runs/',
        './euler_runs/',
    ])
    _run_rsync(run_data_cmd)

    print('pulling eval/plot data from euler (all runs)...')
    # here we just get everything, much less data
    plot_data_cmd = [
        'rsync',
        '-av',
        '--progress',
        '[email protected]:/cluster/scratch/dbalduin/plot_data/',
        'plot_data/',
    ]
    _run_rsync(plot_data_cmd)

    return runs


def _run_rsync(cmd):
    """Run one rsync argument list; report failures without aborting."""
    try:
        result = subprocess.run(cmd)
    except OSError as e:
        # e.g. rsync is not installed; keep going with whatever is on disk.
        print(f'could not start {cmd[0]}: {e}')
        return
    if result.returncode != 0:
        print(f'{cmd[0]} exited with code {result.returncode}')
def plot_sweep(sysname, sweep_name, sweep_config):
    """Plot suboptimality CDF evaluations across one hyperparameter sweep.

    Pulls the runs of the sweep, collects per-run fractions of points whose
    relative suboptimality is below 0.05 / 0.5 / 5, groups them by the value
    of the swept config variable and plots min/mean/max bands in two panels
    (wrt. the learned value from wandb, and wrt. the reference solution
    costs read from disk). The figure is saved and optionally shown.

    It is easiest to pass both the sweep name AND the sweep config, so they
    can be named independently.

    Args:
        sysname: system name; selects the wandb project (via pull_runs) and
            is the prefix of all data / figure file names.
        sweep_name: value of the dummy algoparam arg, used for getting the
            correct runs; also part of the output file name.
        sweep_config: which config variable is modified in the sweep. Must
            be a key of the nice-name table below (KeyError otherwise).
    """
    # pull the data from euler
    runs = pull_runs(sysname, sweep_name, T=72)

    # can we infer this from the runs object?
    # sweep_config = 'active_learning_batchsize'

    # nice name for plotting. Raw strings keep the LaTeX backslashes literal.
    nice_sweep_config = {
        'active_learning_batchsize': r'Active learning batch size $N_\text{batch}$',
        'weight_decay': 'Weight Decay',
        'nn_layer_dim': 'NN Layer size',
        'pontryagin_solver_rtol': 'ODE Solver rtol',
        'vx_loss_d': r'$\lambda$ Huber width $\delta$',
        'dtmax': r'ODE solver $\Delta t_\text{max}$',
        'inv_vx_loss_fadeout': r'Loss fadeout $\mu$',
        'lr_final': r'$\text{lr}_\text{final}$',
    }[sweep_config]

    # Labels of the lines gamma_1..gamma_6 (kept for reference; not used by
    # this figure):
    #   gamma_1(s) = [-10 + 20 s, 0, 0, 1, 0, 0, 0]
    #   gamma_2(s) = [-10 + 20 s, 0, 0, -1, 0, 5, 0]
    #   gamma_3(s) = [-10 + 20 s, 0, 0, -1, 0, 10, 0]
    #   gamma_4(s) = [-5, 5 s, 0, -1, 5 s, 5, 0]
    #   gamma_5(s) = [0, 0, sin(2 pi s), cos(2 pi s), 0, 5, 0]
    #   gamma_6(s) = [-5, 0, sin(2 pi s), cos(2 pi s), 5, 5, 0]

    # put the closed loop / learned percentiles in a dict with key
    # being the swept config.
    fracs = dict()
    fracs_TO = dict()

    # read refsol costs only once.
    fpath = os.path.join(data_dir, f'{sysname}_refsol_costs.msgpack.gz')
    with gzip.open(fpath, 'rb') as f:
        bs = f.read()
    refsol_outputs = flax.serialization.msgpack_restore(bs)
    refsol_outputs = jtm(np.array, refsol_outputs)  # np array -> jax array

    for r in runs:
        try:
            last = r.history(pandas=False)[-1]
        except IndexError:
            print('run has empty history')
            continue

        relevant_config = r.config[sweep_config]

        # The key is created even if the metrics turn out to be missing;
        # empty entries are filtered out again in dict_to_arrays.
        if relevant_config not in fracs:
            fracs[relevant_config] = []

        try:
            fracs[relevant_config].append((
                last['frac_ratio_005'],
                last['frac_ratio_050'],
                last['frac_ratio_500'],
            ))
        except KeyError:
            # data not present in wandb :(((( can we get it otherwise?
            print('key not found.')

        # also plot a bit of the other stats?
        run_id = r.id

        try:
            fpath = os.path.join(data_dir, f'{sysname}_{run_id}_controlcosts_lines.msgpack.gz')
            with gzip.open(fpath, 'rb') as f:
                bs = f.read()
            eval_outputs_lines = flax.serialization.msgpack_restore(bs)
            eval_outputs_lines = jtm(np.array, eval_outputs_lines)  # np array -> jax array

            # Not used further below, but a run without this file is skipped
            # for the reference-value panel, exactly as before.
            fpath = os.path.join(data_dir, f'{sysname}_{run_id}_controlcosts_common.msgpack.gz')
            with gzip.open(fpath, 'rb') as f:
                bs = f.read()
            eval_outputs_common = flax.serialization.msgpack_restore(bs)
            eval_outputs_common = jtm(np.array, eval_outputs_common)  # np array -> jax array

            # calculate control cost wrt TO refsol.
            all_TO_refsols = np.concatenate([np.minimum(n['left'], n['right']) for n in refsol_outputs])
            all_closedloop_costs = np.concatenate([oup['costs'] for oup in eval_outputs_lines])
            suboptimalities = all_closedloop_costs / all_TO_refsols - 1

            fracs_TO.setdefault(relevant_config, []).append((
                (suboptimalities < 0.05).mean().item(),
                (suboptimalities < 0.50).mean().item(),
                (suboptimalities < 5.00).mean().item(),
            ))
        except Exception as e:
            # best effort: a broken / missing file only drops this run.
            print(f'error in run {run_id}: {e}')

    # can we pull this out into a function to do it for other metrics as
    # well???
    def dict_to_arrays(data_dict):
        # converts a dict {x: [v0, v1, ...]} (with x representing a
        # particular value of the swept variable) to arrays:
        # xs = [x0, x1, ...]
        # vmins = [v0min, v1min, ...]
        # vmeans = [v0mean, v1mean, ...]
        # vmaxs = [v0max, v1max, ...]
        # All outputs are ordered by ascending x; empty entries are dropped.
        arraydict = {k: np.array(v) for k, v in data_dict.items() if len(v) > 0}
        keys = sorted(arraydict)
        min_array = np.array([np.min(arraydict[k], axis=0) for k in keys])
        mean_array = np.array([np.mean(arraydict[k], axis=0) for k in keys])
        max_array = np.array([np.max(arraydict[k], axis=0) for k in keys])
        xs = np.array(keys)
        return xs, min_array, mean_array, max_array

    fig = pl.figure('sweepfig', figsize=(pagewidth, .6*pagewidth))
    # pl.figure returns the already existing figure when this function is
    # called again with the same name; clear it so sweeps do not overlap.
    fig.clf()

    def plot_fracs(fracs):
        # plots mean lines plus min/max bands into the current axes.
        if not any(fracs.values()):
            # nothing collected for this panel; the array indexing below
            # would fail on empty data.
            print('no data to plot for this panel.')
            return

        xs, ys_min, ys_mean, ys_max = dict_to_arrays(fracs)
        pl.semilogx(xs, ys_mean, label=('p = 0.05', 'p = 0.5', 'p = 5'))
        # restart the property cycle so the bands below start from the same
        # colors as the lines above.
        pl.gca().set_prop_cycle(None)
        for j in range(3):
            pl.fill_between(xs, ys_min[:, j], ys_max[:, j], alpha=confidence_band_alpha)
        pl.legend()
        pl.xlabel(nice_sweep_config)
        pl.ylabel(r'P(relative suboptimality $\leq$ p)')
        ylims = pl.ylim()
        pl.ylim([ylims[0], 1])
        pl.grid('on')

    pl.subplot(121)
    plot_fracs(fracs)
    pl.gca().set_title(
        'CDF Evaluations of relative suboptimality\n'
        r'wrt. learned value: $\frac{V^\text{cl}_\Theta(x)}{\mu^\Theta(x)} - 1$'
    )
    # same but with boldsymbol theta.

    pl.subplot(122)
    plot_fracs(fracs_TO)
    pl.gca().set_title(
        'CDF Evaluations of relative suboptimality\n'
        r'wrt. reference value: $\frac{V^\text{cl}_\Theta(x)}{V_\text{ref}(x)} - 1$'
    )

    # second subplot with other stats like N iterations etc?
    # most sensibly we would probably plot the following:
    # - control cost on the six sweeps (compared to TO cost and/or by # itself)
    #   (as cdf of the suboptimality ratio?)
    # - mean control cost on eval_xs (which are not the same for each run,
    #   not even from the same same distribution bc level sets differ a bit)
    # - stats like number of steps, runtime, floating point ops even?

    fig.tight_layout()
    pl.savefig(f'./{fig_dir}/{sysname}_sweep_{sweep_name}.{fig_format}', bbox_inches='tight', dpi=dpi)
    if show:
        pl.show()
if __name__ == '__main__':
    # Script entry point: plot the selected sweeps for one system.
    sys_name = 'flatquad'

    # Sweeps that are currently switched off:
    # plot_sweep(sys_name, 'vx_fadeout', 'inv_vx_loss_fadeout')
    # plot_sweep(sys_name, 'lr_final', 'lr_final')
    # plot_sweep(sys_name, 'dtmax', 'dtmax')
    # plot_sweep(sys_name, 'vxd', 'vx_loss_d')
    # plot_sweep(sys_name, 'batchsize', 'active_learning_batchsize')
    # plot_sweep(sys_name, 'weight_decay', 'weight_decay')

    # completely uninteresting sadly
    # Each entry: (sweep name, swept config variable), plotted in order.
    active_sweeps = (
        ('rtol', 'pontryagin_solver_rtol'),
        ('layerdim', 'nn_layer_dim'),
    )
    for sweep_label, swept_key in active_sweeps:
        plot_sweep(sys_name, sweep_label, swept_key)