# jax
from functools import partial
from typing import Optional, Sequence

import operator  # used below for the tree_reduce weight-norm computation

import flax
import ipdb
import jax
import jax.numpy as np
import jax_tqdm
import optax
from equinox import filter_jit
from flax import linen as nn

# misc is expected to provide small helpers used below (e.g. jtm as a tree_map alias).
from misc import *

# other, trivial stuff
# import numpy as onp
# import matplotlib.pyplot as pl

def train_test_split(ys, train_frac=0.9):

    assert 0. < train_frac <= 1., '<gordon ramsay voice> this "training fraction" is not even a fraction you donkey'

    N_train = int(train_frac * ys['x'].shape[0])

    # deterministic -- always the last data points as test set. dumb idea.
    # train_ys = jax.tree_util.tree_map(lambda n: n[:N_train], ys)
    # test_ys = jax.tree_util.tree_map(lambda n: n[N_train:], ys)

    # instead: a random split, expressed as a boolean mask.
    train_idx = jax.random.choice(jax.random.PRNGKey(0), ys['x'].shape[0], (N_train,), replace=False)
    train_mask = np.zeros(ys['x'].shape[0], dtype=bool).at[train_idx].set(True)
    test_mask = ~train_mask

    train_ys = jax.tree_util.tree_map(lambda n: n[train_mask], ys)
    test_ys = jax.tree_util.tree_map(lambda n: n[test_mask], ys)

    return train_ys, test_ys
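
# Minimal usage sketch for train_test_split (illustration only, never called from the
# library code; the dataset shapes here are my own assumption about a typical dict of
# 'x', 'v', 'vx' arrays with a shared leading dimension).
def _example_train_test_split():
    ys = {'x': np.zeros((100, 3)), 'v': np.zeros((100,)), 'vx': np.zeros((100, 3))}
    train_ys, test_ys = train_test_split(ys, train_frac=0.9)
    # 90/10 split, applied consistently to every leaf of the dict.
    assert train_ys['x'].shape[0] == 90 and test_ys['x'].shape[0] == 10
    return train_ys, test_ys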

class data_normaliser(object):

    def __init__(self, train_ode_states, problem_params, algo_params):

        # train_ode_states: dict with entries 'x', 'v', 'vx'.
        #   ['x'].shape  == (N_pts, nx)
        #   ['v'].shape  == (N_pts,)
        #   ['vx'].shape == (N_pts, nx)

        # we scale each x component and v to zero mean and unit variance.
        # standard differentiation rules then tell us how vx and vxx
        # must change under this linear change of variables.

        # could drop problem_params again...

        x_means = train_ode_states['x'].mean(axis=0)
        x_stds = train_ode_states['x'].std(axis=0)

        N_pts, nx = train_ode_states['x'].shape

        if 'normalise_states' in algo_params:
            normalise_mask = algo_params['normalise_states']
            assert normalise_mask.shape == (nx,), 'statewise normalisation mask invalid (shape)'
            assert normalise_mask.dtype == bool, 'statewise normalisation mask invalid (dtype)'

            # we do NOT normalise the states associated with the manifold.
            # this keeps everything related to tangent/normal spaces independent
            # of the normalisation. I *think* this should work. what we do, though,
            # is squeeze the rest of the state space, so it may still affect how much
            # each costate is penalised in the sobolev loss function. maybe that is
            # even a good thing, as long as the vx values all stay in reasonable ranges.
            dont_normalise = ~normalise_mask
            x_means = x_means.at[dont_normalise].set(0)
            x_stds = x_stds.at[dont_normalise].set(1)

        self.normalise_x = lambda x: (x - x_means) / x_stds
        self.unnormalise_x = lambda xn: xn * x_stds + x_means

        # maybe min/max would make more sense here? always put it in the range [0, 1] or so?
        v_mean = train_ode_states['v'].mean()
        v_std = train_ode_states['v'].std()

        self.normalise_v = lambda v: (v - v_mean) / v_std
        self.unnormalise_v = lambda vn: vn * v_std + v_mean

        # now the hard part, vx: the chain rule for the linear change of variables.
        self.normalise_vx = lambda vx: vx * x_stds / v_std
        self.unnormalise_vx = lambda vx_n: (vx_n / x_stds) * v_std

        # proper multivariate version would be: vxx transformed = A vxx A.T,
        # where A is the coordinate transformation.
        # https://math.stackexchange.com/questions/1514680/gradient-and-hessian-for-linear-change-of-coordinates
        self.normalise_vxx = lambda vxx: np.diag(x_stds) @ vxx @ np.diag(x_stds).T / v_std
        self.unnormalise_vxx = lambda vxx_n: np.diag(1/x_stds) @ (vxx_n * v_std) @ np.diag(1/x_stds)

    def normalise_all(self, train_ode_states):

        # old format where everything was stacked into one array.
        print('normaliser: old format is not recommended!')

        nn_xs = jax.vmap(self.normalise_x)(train_ode_states['x'])
        nn_vs = jax.vmap(self.normalise_v)(train_ode_states['v'])
        nn_vxs = jax.vmap(self.normalise_vx)(train_ode_states['vx'])

        # to get the data format expected by the nn code...
        # maybe change this so that we can use a nicer dict format
        # with entries 'x', 'v', 'vx' and maybe 'vxx'?
        nn_ys = np.column_stack([nn_vxs, nn_vs])

        return nn_xs, nn_ys

    def normalise_all_dict(self, train_ode_states):

        oup = {
            'x': jax.vmap(self.normalise_x)(train_ode_states['x']),
            'v': jax.vmap(self.normalise_v)(train_ode_states['v']),
            'vx': jax.vmap(self.normalise_vx)(train_ode_states['vx']),
        }

        if 'vxx' in train_ode_states:
            oup['vxx'] = jax.vmap(self.normalise_vxx)(train_ode_states['vxx'])

        return oup
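
# Sanity-check sketch for the gradient scaling above (illustration only, not called
# anywhere; the quadratic test value function and shapes are my own choice). Under the
# linear change of variables x_n = (x - x_means)/x_stds, v_n = (v - v_mean)/v_std, the
# chain rule gives dv_n/dx_n = (dv/dx) * x_stds / v_std, which is what normalise_vx does.
def _check_normaliser_gradient_scaling():
    key = jax.random.PRNGKey(0)
    xs = jax.random.normal(key, (50, 3)) * np.array([1., 5., 0.1])
    v_fn = lambda x: np.sum(x**2)                 # arbitrary smooth test value function
    data = {'x': xs, 'v': jax.vmap(v_fn)(xs), 'vx': jax.vmap(jax.grad(v_fn))(xs)}

    norm = data_normaliser(data, problem_params={}, algo_params={})

    # gradient of the normalised value function wrt normalised inputs, at one point
    vn_of_xn = lambda xn: norm.normalise_v(v_fn(norm.unnormalise_x(xn)))
    xn0 = norm.normalise_x(xs[0])
    assert np.allclose(jax.grad(vn_of_xn)(xn0), norm.normalise_vx(data['vx'][0]), rtol=1e-4, atol=1e-4)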

class my_nn_nonsmooth(nn.Module):
    '''
    maybe this is the way in which we represent the nonsmoothness exactly?

    idea: we output several (penultimate_dim) "possible" value functions in a vector
    z, and among them we take the lowest, v = min(z). Nonsmoothly!!!

    this is the easy part. training will be less trivial. we want a loss
    function that makes the NN do these things:

    - for each label (v, vx), there should be AT LEAST one i such that
      (z_i, grad_x z_i) ≈ (v, vx), whether the label is optimal or not.
    - there should be no "spurious" solution z_i that is lower than all
      solutions given by data.
    - it is bound to happen that two solutions z_i, z_j switch first/second
      places for the min(.) WITHOUT there being an actual discontinuity in the
      solution (just imagine a single integrator on the unit circle: if the decision
      boundary is represented nonsmoothly there needs to be another nonsmooth
      point too). we should somehow make sure that in that case the transition
      is somewhat smooth, either by making the two z's actually quite close in
      those regions, or by smoothening the transition somehow, or both.
    - make sure that two different local solutions don't randomly switch
      places for no reason.

    can we somehow do it with *minimal* adaptation to the training procedure?
    maybe it works if we just use this nn with the normal training procedure???
    surely not.
    '''

    # here penultimate_dim is the dimensionality of the last hidden variable:
    #   z = NN(x), z.shape == (penultimate_dim,)
    #   v = min(z)
    features: Sequence[int]
    penultimate_dim: Optional[int]
    output_dim: Optional[int]

    @nn.compact
    def __call__(self, x):

        for feat in self.features:
            x = nn.Dense(features=feat)(x)
            x = nn.softplus(x)

        if self.output_dim is not None:
            assert self.output_dim == 1
            x = nn.Dense(features=self.penultimate_dim)(x)
            x = np.min(x)

        return x.squeeze()
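
# Sketch of how the min-output head behaves under autodiff (illustration only, not
# called anywhere; the layer sizes and input are my own choice). jax.grad of the scalar
# output returns the gradient of whichever z_i attains the minimum (almost everywhere),
# which is what the "at least one i matches (v, vx)" idea above relies on.
def _example_min_head_gradient():
    model = my_nn_nonsmooth(features=(16, 16), penultimate_dim=8, output_dim=1)
    params = model.init(jax.random.PRNGKey(0), np.zeros(2))

    x = np.array([0.3, -1.2])
    v = model.apply(params, x)                                 # scalar: min over the 8 heads
    vx = jax.grad(lambda x_: model.apply(params, x_))(x)       # gradient of the active head
    return v, vx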

class my_nn_experimental(nn.Module):

    # here penultimate_dim is the dimensionality of the last hidden variable:
    #   z = NN(x), z.shape == (penultimate_dim,)
    #   v = min(z) (here: a "softmin"-weighted average of z)
    features: Sequence[int]
    penultimate_dim: Optional[int]
    output_dim: Optional[int]

    @nn.compact
    def __call__(self, x):

        for feat in self.features:
            x = nn.Dense(features=feat)(x)
            x = 0.9 * nn.softplus(x) + 0.1 * x

        if self.output_dim is not None:
            # last layer relu?
            assert self.output_dim == 1

            '''
            # last layer of half relu and half softplus
            x = nn.Dense(features=self.penultimate_dim)(x)
            half = self.penultimate_dim // 2
            x_relu = nn.relu(x[:half])
            x_softplus = nn.softplus(x[half:])
            x = np.concatenate([x_relu, x_softplus])
            x = nn.Dense(features=1)(x)
            '''

            '''
            # regular "leaky squareplus" last layer
            x = nn.Dense(features=self.penultimate_dim)(x)
            x = 0.9 * jax.nn.squareplus(x) + 0.1 * x
            x = nn.Dense(features=1)(x)
            '''

            x = nn.Dense(features=self.penultimate_dim)(x)

            # "softmin" weights: weighted average of the heads, with the weights
            # concentrating on the smallest entries.
            x_contribs = nn.softmax(-x)
            x = np.dot(x, x_contribs)

        return x.squeeze()

class my_nn_leaky(nn.Module):

    # simple, fully connected NN class with "leaky softplus" activations.
    # for bells & whistles -> nn_wrapper class :)
    features: Sequence[int]
    output_dim: Optional[int]

    @nn.compact
    def __call__(self, x):

        for feat in self.features:
            x = nn.Dense(features=feat)(x)
            x = 0.9 * nn.softplus(x) + 0.1 * x

        if self.output_dim is not None:
            x = nn.Dense(features=self.output_dim)(x)

        return x.squeeze()

class my_nn_flax(nn.Module):

    # simple, fully connected NN class with softplus activations.
    # for bells & whistles -> nn_wrapper class :)
    features: Sequence[int]
    output_dim: Optional[int]

    @nn.compact
    def __call__(self, x):

        '''
        if x.shape == (2,):
            # some classic old feature engineering :)
            # x = np.concatenate(
            #     [x, np.array([np.arctan2(x[0], x[1]), np.arctan2(x[1], x[0])]), np.sum(np.square(x))]
            # )
            # x = np.concatenate(
            #     [x, np.sum(np.square(x)), x/np.linalg.norm(x)]
            # )

            # don't even tell it about the original state, preposterous
            x = np.concatenate(
                [np.sum(np.square(x)), x/np.linalg.norm(x)]
            )
        '''

        for feat in self.features:
            x = nn.Dense(features=feat)(x)
            x = nn.softplus(x)

        if self.output_dim is not None:
            x = nn.Dense(features=self.output_dim)(x)

        return x.squeeze()
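
# Usage sketch for the plain flax modules above (illustration only; the sizes are my
# own choice). The sobolev losses below evaluate the scalar output and its gradient
# with respect to the input in exactly this way.
def _example_flax_module_usage():
    nx = 4
    model = my_nn_flax(features=(32, 32), output_dim=1)
    params = model.init(jax.random.PRNGKey(0), np.zeros(nx))

    x = np.ones(nx)
    v = model.apply(params, x)                             # scalar value prediction
    vx = jax.jacrev(model.apply, argnums=1)(params, x)     # value gradient, shape (nx,)
    return v, vx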

class nn_wrapper():

    # all the usual business logic around the NN:
    # initialisation, data loading, training, loss plotting.

    def __init__(self, problem_params, algo_params):

        self.input_dim = problem_params['nx']
        # self.layer_dims = algo_params['nn_layer_dims']
        self.layer_dims = algo_params['nn_n_layers'] * (algo_params['nn_layer_dim'],)
        self.output_dim = 1

        if algo_params['nn_type'] == 'softplus':
            self.nn = my_nn_flax(features=self.layer_dims, output_dim=self.output_dim)
        elif algo_params['nn_type'] == 'leaky':
            self.nn = my_nn_leaky(features=self.layer_dims, output_dim=self.output_dim)
        elif algo_params['nn_type'] == 'minout_softplus':
            self.nn = my_nn_nonsmooth(features=self.layer_dims, penultimate_dim=32, output_dim=self.output_dim)
        elif algo_params['nn_type'] == 'experimental':
            self.nn = my_nn_experimental(features=self.layer_dims, penultimate_dim=32, output_dim=self.output_dim)
        else:
            raise ValueError(f'NN type {algo_params["nn_type"]} unknown')

        print(self.layer_dims)

    # so we can use it like a function :)
    def __call__(self, params, x):
        return self.nn.apply(params, x)

    def init_nn_params(self, key):
        params = self.nn.init(key, np.zeros((self.input_dim,)))
        return params

    def init_and_train(self, key, xs, ys, algo_params):
        '''
        key(s): an array of PRNG keys.
        other arguments same as train.

        returns: params, a dict just like usual, but each entry has an extra leading
        dimension arising from the vmap.
        '''

        raise NotImplementedError('this still uses old data format')

        params = self.nn.init(key, np.zeros((self.input_dim,)))
        params, outputs = self.train(xs, ys, params, algo_params, key)
        return params, outputs

    def ensemble_mean_std(self, params, xs):
        '''
        each node of params should have an extra leading dimension,
        as generated by ensemble_init_and_train.

        returns two arrays of shape (N_points, 1+nx), where the second index
        is 0 for the value output and 1...nx+1 for the costate/value gradient.
        '''

        raise NotImplementedError('this still uses old data format')

        outputs = jax.vmap(self.nn.apply, in_axes=(0, None))(params, xs)
        grad_outputs = jax.vmap(self.apply_grad, in_axes=(0, None))(params, xs)
        all_outputs = np.concatenate([outputs, grad_outputs], axis=2)

        means = all_outputs.mean(axis=0)
        stds = all_outputs.std(axis=0)

        return means, stds
    def sobolev_loss_with_prior(self, key, y, params, v_prior, prior_extent, problem_params, algo_params):

        # calculates the usual sobolev loss BUT adds a functional prior loss to
        # it. main purpose is to avoid situations where v_nn(x) < 0 or ≈ 0
        # outside of the data region, where we kind of know the value function
        # must be HIGHER than all available data.
        # here we could also slightly regularise ||vx||^2 to make
        # it low-ish outside of the data region.

        # as with sobolev_loss, this is for a single data point; vmap outside.

        sobolev_key, prior_key, noise_key = jax.random.split(key, 3)

        original_loss, loss_terms = self.sobolev_loss(sobolev_key, y, params, problem_params, algo_params)

        # evaluate the prior loss at a random point.
        # extent = np.array([20, 20, 0., 0., 20, 20, 20])  # TODO put in algo_params too?

        # impose *that* prior only at the point where otherwise we would
        # get "wrong" values close to 0. this is more of a practical fix
        # and less of a bayesian-inspired functional prior type story, but
        # if it works, who am I to judge (myself...) -- even outside of the
        # manifold!

        if algo_params['prior_strength'] > 0:

            v_prior = algo_params['v_prior']

            # ugly hardcoded things: where do we need the prior?
            if problem_params['nx'] == 7 and problem_params['system_name'] == 'flatquad':
                # small region around the problematic "upside down" state
                pushup_x = np.array([0, 0, 0, -1., 0, 0, 0]) + jax.random.normal(prior_key, shape=(problem_params['nx'],)) * 0.1
            elif problem_params['nx'] == 2 and problem_params['system_name'] == 'orbits':
                # large-ish circle around the equilibrium
                pushup_x = jax.random.normal(prior_key, shape=(problem_params['nx'],))
                pushup_x = 3 * pushup_x / np.linalg.norm(pushup_x) + problem_params['x_eq']
            else:
                raise ValueError('invalid configuration - no prior loss known')

            v_pred = self.nn.apply(params, pushup_x)

            # only penalise too-small v's, not too-high ones:
            # v_prior - v_pred > 0 <=> v_prior > v_pred, which is bad.
            # conversely, if < 0 (then the 0 is chosen instead) we overestimate, which is good.
            pushup_loss = np.maximum(0, v_prior - v_pred)

            # smooth version for nicer plots hehehe
            # pushup_loss = jax.nn.softplus(v_prior - v_pred)

            prior_loss = pushup_loss
            total_loss = original_loss + algo_params['prior_strength'] * prior_loss
            loss_terms['pushup_prior'] = pushup_loss
        else:
            # no prior loss. mostly not needed for simpler R^n state spaces.
            total_loss = original_loss

        return total_loss, loss_terms
    def sobolev_loss(self, key, y, params, problem_params, algo_params):

        # this is for a *single* datapoint in dict form. vmap later.
        # needs a PRNG key for the hvp in a random direction -- this is
        # only a stochastic approximation of the actual sobolev loss.
        # (vxx not used anymore, thus the key is not really needed.)

        # y is the dict with training data.
        # tree_map(lambda z: z.shape, ys) should be: {
        #     'x': (nx,), 'v': (1,), 'vx': (nx,), 'vxx': (nx, nx)
        # }

        v_pred = self.nn.apply(params, y['x'])

        # does the same if jacobian is replaced by grad, jacfwd, jacrev \o/
        # apparently jacobian = jacrev. grad is also reverse-mode.
        # jacfwd is definitely not smart here (n arguments, 1 output).
        vx_pred = jax.jacrev(self.nn.apply, argnums=1)(params, y['x'])

        return self.sobolev_loss_inner(key, y, v_pred, vx_pred, problem_params, algo_params)
    def sobolev_loss_inner(self, key, y, v_pred, vx_pred, problem_params, algo_params):

        # adapted so as not to "strengthen" the loss too much for tiny labels:
        # 1 + x = smoothed max(1, x). replace the 1 with the smallest order of
        # magnitude we want to be accurate at.

        aux_output = dict()

        # for the ones we already know we might as well use quadratic loss.
        # use_quadratic_loss = y['v'] <= v_k
        use_quadratic_loss = False

        # asymmetric, smooth huber type loss function.
        # penalises overestimation heavily, underestimation less.
        d = algo_params['v_loss_d']
        v_rel_err = (v_pred - y['v']) / (algo_params['min_important_v'] + y['v'])
        rel_err_sq = (v_rel_err)**2

        # rel_err_smoothhuber = 2 * (np.sqrt(1 + rel_err_sq) - 1)
        rel_err_smoothhuber = d**2 * 2 * (np.sqrt(1 + rel_err_sq/d**2) - 1)
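        # (this is the standard pseudo-huber function scaled by d: it behaves like
        # rel_err_sq for |v_rel_err| << d and like 2*d*|v_rel_err| for |v_rel_err| >> d,
        # so large outliers are penalised only linearly.)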

        underestimation = v_pred < y['v']

        # also output a flag that says whether we are in the linear-ish
        # region (-> outlier) or not (not an outlier).
        smooth_huber_linear = rel_err_sq / d**2 > 1
        aux_output['v_loss_linear'] = underestimation & smooth_huber_linear

        # v_loss = use_quadratic_loss * rel_err_sq + ~use_quadratic_loss * rel_err_sq
        v_loss = jax.lax.select(use_quadratic_loss, rel_err_sq, rel_err_smoothhuber)
        # v_loss = rel_err_sq  # basic one again.

        # nn_sobolev_weights = np.array(algo_params['nn_sobolev_weights'])
        nn_sobolev_weights = np.array([algo_params['nn_sobolev_weight_v'], algo_params['nn_sobolev_weight_vx']])

        if 'nn_sobolev_weight_vxx' in algo_params and algo_params['nn_sobolev_weight_vxx'] != 0.:
            raise NotImplementedError('vxx loss is stale code, do not use')

        nx = problem_params['nx']

        if problem_params['m'] is not None:

            assert nn_sobolev_weights.shape == (2,), 'vxx not implemented with manifold state space'

            # in this case the state space is a submanifold of R^n:
            #     M = {x in R^n: m(x) = 0}.
            # we still define the NN for inputs in the ambient space R^n. the v loss
            # stays the same, but the vx loss has to be adjusted so we only take
            # derivatives in tangent space directions.

            # we do this by constructing an orthonormal basis for the normal
            # space, based on the constraint function m.
            # this is trivial if the normal space is 1D (scalar constraint fct).
            B = jax.jacobian(problem_params['m'])(y['x'])
            assert B.shape == (nx,), 'only manifolds of codimension 1 supported rn'

            # if codimension > 1, we will have to do one of:
            # 1. define m such that B is always an orthonormal basis (and sanity check)
            # 2. (ortho?)normalise it here after calculating the jacobian
            # 3. use the pseudoinverse in the projection instead of the transpose
            # 4. just ignore it and regularise in the directions given by the
            #    jacobian anyway -- it is only a small regularisation after all.
            # if X is the cartesian product of several independent manifolds of
            # codimension 1, all vectors \nabla_x m(x) are pairwise
            # orthogonal, and we have basically done point 1. above.

            # and normalise just for good measure.
            B = B / np.linalg.norm(B)

            # the main dish.
            # orthogonal projection onto the normal space at the current x
            P_normal = np.outer(B, B)  # (B @ B.T would calculate the dot product instead :( )
            # orthogonal projection onto the tangent space at the current x
            P_tangent = np.eye(nx) - P_normal
        else:
            # R^n has an empty normal space (wrt itself)
            P_normal = np.zeros((nx, nx))
            P_tangent = np.eye(nx)
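
        # (note: P_normal + P_tangent = I and both are orthogonal projections, so the
        # gradient error below splits cleanly into a tangential fitting term and a
        # normal-direction regularisation term.)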

        # factor out this entire calculation too. in the plain R^n case we naturally have
        # P_tangent = I, reducing this to the previous calculation, and P_normal = 0,
        # so zero regularisation loss.
        vx_err = (vx_pred - y['vx']) @ P_tangent
        vx_normaliser = algo_params['min_important_vx'] + np.linalg.norm(y['vx'] @ P_tangent)
        vx_label_loss_quadratic = np.sum( (vx_err / vx_normaliser)**2 )
        # vx_label_loss_quadratic = np.sum( (vx_err)**2 / (algo_params['min_important_vx'] + np.sum(proj_label**2)) )

        vx_reg_loss = np.sum( (vx_pred @ P_normal)**2 )

        # factor the 'huberisation' out as well. the if/else cases above only have
        # to calculate vx_label_loss_quadratic.
        # this is always done; set d to something like 10 or 100 to 'disable' it.
        d = algo_params['vx_loss_d']
        vx_label_loss_huber = d**2 * 2 * (np.sqrt(1 + vx_label_loss_quadratic/d**2) - 1)

        smooth_huber_linear = vx_label_loss_huber / d**2 > 1
        aux_output['vx_loss_linear'] = underestimation & smooth_huber_linear

        vx_label_loss = jax.lax.select(use_quadratic_loss, vx_label_loss_quadratic, vx_label_loss_huber)

        # now the scaling is here again (only needed once)
        scaling = np.clip(np.exp(v_rel_err * algo_params['inv_vx_loss_fadeout']), 0., 1.)
        # scaling = jax.lax.select(
        #     v_rel_err > 0,
        #     np.exp(-(v_rel_err * algo_params['inv_vx_loss_fadeout'])**2),
        #     1.
        # )

        # cheat autodiff
        scaling = jax.lax.stop_gradient(scaling)
        vx_label_loss = vx_label_loss * scaling

        # reg loss defined in the if/else branches above
        vx_loss = vx_label_loss + algo_params['vx_normal_regularisation'] * vx_reg_loss

        assert nn_sobolev_weights.shape == (2,)
        normalised_weights = nn_sobolev_weights / np.sum(nn_sobolev_weights)
        sobolev_losses = np.array([v_loss, vx_loss])
        loss = normalised_weights @ sobolev_losses

        lossterms = dict()
        lossterms['v'] = v_loss
        lossterms['vx_reg'] = vx_reg_loss
        lossterms['vx_label'] = vx_label_loss
        lossterms['total_loss'] = loss
        aux_output['lossterms'] = lossterms

        return loss, aux_output

    # vmap the loss across a batch and take its mean.
    # tuple output so that the gradient below is only taken of the first element (with has_aux=True).
    def sobolev_loss_batch_mean(self, k, params, ys, problem_params, algo_params):

        # use the size of the actual batch, not what algo_params says.
        # then we can use the same function e.g. for evaluating the loss on the test set.
        ks = jax.random.split(k, ys['x'].shape[0])

        losses, loss_terms = jax.vmap(self.sobolev_loss, in_axes=(0, 0, None, None, None))(ks, ys, params, problem_params, algo_params)

        # mean across batch dim.
        return np.mean(losses), jtm(lambda n: n.mean(axis=0), loss_terms)

    def sobolev_loss_with_prior_batch_mean(self, k, params, ys, v_prior, prior_extent, problem_params, algo_params):

        # use the size of the actual batch, not what algo_params says.
        # then we can use the same function e.g. for evaluating the loss on the test set.
        ks = jax.random.split(k, ys['x'].shape[0])

        losses, loss_terms = jax.vmap(self.sobolev_loss_with_prior, in_axes=(0, 0, None, None, None, None, None))(
            ks, ys, params, v_prior, prior_extent, problem_params, algo_params
        )

        # mean across batch dim.
        # return np.mean(losses), np.mean(loss_terms, axis=0)
        return np.mean(losses, axis=0), jtm(lambda n: n.mean(axis=0), loss_terms)

    # @filter_jit
    def train_sobolev(self, key, ys, vk, vnext, nn_params, problem_params, algo_params, ys_test=None):
        '''
        new training method. main changes wrt self.train:

        - data format is now this:
          ys, a dict with keys:
              'x': (N_pts, nx) array of state space points, just like before.
              'v': (N_pts, 1) array of value function evaluations at those x.
              'vx': (N_pts, nx) array of value gradient = costate evaluations.
          optionally:
              'vxx': (N_pts, nx, nx) array of value hessians = costate jacobian evaluations.
          this should make it easy to train with or without hessians with the same code.
          maybe we can also initially train with v and vx and only "fine-tune" with the hessian?

        - the loss (optionally) includes the hessian error, specifically a stochastic
          approximation of it: a hessian-vector product with a randomly chosen direction
          vector. idea from Czarnecki et al.: https://arxiv.org/abs/1706.04859

        - test set generation is not in here. do it using train_test_split in this file.
          pass ys_test to evaluate test loss during training (full test dataset every step!)

          TODO: as of now the test set is just a random subset of all points.
          should we instead take a couple of entire trajectories as test set? because
          if 90% of the points on some trajectory are in the training set, it is kind of
          not a huge feat to have low loss on the remaining 10%.
          if OTOH we test with entire trajectories unseen in training, the test loss
          is more meaningful...
        '''

        # first the actually meaningful things: set up the prior loss.

        # prior value function: just a lot higher than the rest.
        # v_prior = algo_params['v_prior_factor'] * np.clip(ys['v'].max(), 1., np.inf)

        # extent of the box-shaped prior domain. here we take a minimum of 10,
        # otherwise a factor times the data min/max extent. the factor
        # determines the loss strength too! we do want the prior to act
        # "mostly" in the region where data is not available, thus this factor
        # must be > 1. volume ratio ~ factor ** nx!! i think this is good: it
        # makes it unlikely that the prior acts in the data region even for
        # relatively small factors.
        # prior_extent = np.clip(algo_params['prior_extent_factor'] * np.abs(ys['x']).max(axis=0), 1., np.inf)

        # update: don't use any of that.
        v_prior = None
        prior_extent = np.array([20, 20, 0., 0., 20, 20, 10])

        # make sure it is of correct shape?
        testset_exists = ys_test is not None

        N_datapts = ys['x'].shape[0]
        batchsize = algo_params['nn_batchsize']
        N_epochs = algo_params['nn_N_epochs']

        # we want: total_iters * batchsize == N_epochs * N_datapts. therefore:
        total_iters = (N_epochs * N_datapts) // batchsize
        if total_iters < 20000:
            total_iters = 20000

        # exponential decay. this will go down from lr_init to lr_final over
        # the whole training duration.
        # if lr_staircase, then instead of smooth decay we have stepwise decay
        # with N_lr_steps steps.
        N_lr_steps = algo_params['lr_staircase_steps']
        total_decay = algo_params['lr_final'] / algo_params['lr_init']

        # regardless of whether or not steps are used, the decay rate
        # sets the decay *per transition step*.
        lr_schedule = optax.exponential_decay(
            init_value = algo_params['lr_init'],
            transition_steps = total_iters // N_lr_steps,
            decay_rate = (total_decay) ** (1/N_lr_steps),
            end_value = algo_params['lr_final'],
            staircase = algo_params['lr_staircase']
        )
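        # (per transition step the lr shrinks by total_decay**(1/N_lr_steps), so over all
        # N_lr_steps transitions it shrinks by total_decay, i.e. from lr_init down to
        # lr_final, whether staircase or smooth; end_value clamps it there.)
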
        lr_schedule_exp = lr_schedule

        # if we do a sweep from vk to vnext (and don't circumvent it by setting vk=vnext)
        # then we want a constant, low learning rate instead.
        # this only works when disabling the outer jit's, otherwise jax thinks it can
        # treat vk, vnext as traced values but they will hit this python dynamism here.
        # if we circumvent the dynamism here, it will instead hit it at
        # ./venv/lib/python3.10/site-packages/optax/schedules/_schedule.py:209
        if algo_params['nn_value_sweep'] and vnext > vk:
            lr_schedule = optax.constant_schedule(algo_params['lr_final'])

        if algo_params['weight_decay'] > 0:
            # default weight_decay=0.0001
            optim = optax.adamw(learning_rate=lr_schedule, weight_decay=algo_params['weight_decay'])
        else:
            optim = optax.adam(learning_rate=lr_schedule)

        # noodling around. amsgrad seems to achieve very low loss in the
        # tail of training more easily, about 10x less than adam/adamw. in
        # the 1st run this does also translate to better test loss. however,
        # there is no weight decay and so we cannot expect warmstarting to work.
        # optim = optax.amsgrad(learning_rate=lr_schedule)

        opt_state = optim.init(nn_params)

        def update_step(key, ys, opt_state, params):

            # differentiate the whole thing wrt argument 1 = nn params.
            if algo_params['prior_strength'] > 0:
                (loss, loss_terms), grad = jax.value_and_grad(self.sobolev_loss_with_prior_batch_mean, argnums=1, has_aux=True)(
                    key, params, ys, v_prior, prior_extent, problem_params, algo_params
                )
            else:
                (loss, loss_terms), grad = jax.value_and_grad(self.sobolev_loss_batch_mean, argnums=1, has_aux=True)(
                    key, params, ys, problem_params, algo_params
                )

            if algo_params['weight_decay'] > 0:
                # adamw wants params here too.
                updates, opt_state = optim.update(grad, opt_state, params)
            else:
                updates, opt_state = optim.update(grad, opt_state)

            params = optax.apply_updates(params, updates)

            return opt_state, params, loss_terms

        def f_scan(carry, input_slice):

            # unpack the 'carry' state
            nn_params, opt_state, key = carry
            batch_key, loss_key, test_key, new_key = jax.random.split(key, 4)

            # obtain minibatch.
            # here we could specify something like p ~ v <= v_sweep...
            # it would neatly fit in the input_slice, which atm is not used for much...
            if algo_params['nn_value_sweep']:
                # sweep the sublevel set from vk to vnext.
                # sadly, doing this with linspace on the outside and then having
                #     v_upper = input_slice
                # here breaks jax_tqdm. so instead we pass the step k in here and
                # calculate v_upper like this:

                # constant sweep
                step = input_slice
                frac = step / total_iters

                # sweep with staircase thing:
                # with N_steps subintervals, over each subinterval this will
                # first increase twice as fast as frac, then stay constant.
                # N_steps = 10
                # width = 1 / N_steps
                # frac = frac + (width/2) - np.abs(width/2 - frac % width)

                v_upper = frac * vnext + (1-frac) * vk
                do_sample = ys['v'] <= v_upper

                # also correspondingly sweep up the lower bound
                # if algo_params['thin_data']:
                #     v_lower = v_upper / algo_params['thin_data_denominator']
                #     do_sample = do_sample & (ys['v'] >= v_lower)

                ps = do_sample / do_sample.sum()
                batch_idx = jax.random.choice(batch_key, N_datapts, (batchsize,), p=ps)
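                # (ps is uniform over the points with v <= v_upper and zero elsewhere,
                # so the minibatch is drawn only from the current sublevel set.)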
            else:
                batch_idx = jax.random.choice(batch_key, N_datapts, (batchsize,))

            ys_batch = jax.tree_util.tree_map(lambda node: node[batch_idx], ys)

            # do the thing!!1!1!!1!
            opt_state_new, nn_params_new, aux_output = update_step(
                loss_key, ys_batch, opt_state, nn_params
            )

            aux_output['lr'] = lr_schedule(opt_state[0].count)

            node_squared_norms = jtm(lambda w: np.sum(w**2), nn_params_new)
            tree_squared_norm = jax.tree_util.tree_reduce(operator.add, node_squared_norms)
            aux_output['weight_norm'] = np.sqrt(tree_squared_norm)

            if algo_params['nn_value_sweep']:
                aux_output['v_sweep'] = v_upper

            # if given, calculate the test loss.
            # probably quite expensive to do this every iteration though...
            # this "if" is resolved at compile time.
            if ys_test is not None:

                # test_key = jax.random.PRNGKey(0)  # just one sample. nicer plots :)

                if algo_params['prior_strength'] > 0:
                    test_loss, test_aux_output = self.sobolev_loss_with_prior_batch_mean(
                        test_key, nn_params_new, ys_test, v_prior, prior_extent, problem_params, algo_params
                    )
                else:
                    test_loss, test_aux_output = self.sobolev_loss_batch_mean(
                        test_key, nn_params_new, ys_test, problem_params, algo_params
                    )

                aux_output['test_loss_terms'] = test_aux_output['lossterms']

            new_carry = (nn_params_new, opt_state_new, new_key)
            return new_carry, aux_output

        if algo_params['nn_progressbar']:
            # this used to give an error from within the library -- NOT ANYMORE, thanks patrick!!
            # https://github.com/mbjd/approximate_optimal_control/issues/1
            f_scan = jax_tqdm.scan_tqdm(n=total_iters)(f_scan)

        # the training loop!
        # currently the scan input argument is only used as the sweep step counter.
        # we could also put the PRNG key there, or sobolev loss weights if we decide
        # to change them during training...
        init_carry = (nn_params, opt_state, key)
        final_carry, outputs = jax.lax.scan(f_scan, init_carry, np.arange(total_iters))

        nn_params, _, _ = final_carry

        return nn_params, outputs

    # @filter_jit
    def train_sobolev_ensemble(self, key, ys, vk, vnext, problem_params, algo_params, ys_test=None):

        # train the ensemble by vmapping the whole training procedure over different prng keys.
        # the key affects both the initialisation and the batch selection for each nn.

        init_key, train_key = jax.random.split(key)
        init_keys = jax.random.split(init_key, algo_params['nn_ensemble_size'])
        train_keys = jax.random.split(train_key, algo_params['nn_ensemble_size'])

        vmap_params_init = jax.vmap(self.nn.init, in_axes=(0, None))(init_keys, np.zeros(problem_params['nx']))

        # to work around the optional argument. there is probably a neater way...
        train_with_key_and_params = lambda k, params_init: self.train_sobolev(k, ys, vk, vnext, params_init, problem_params, algo_params, ys_test=ys_test)

        return jax.vmap(train_with_key_and_params, in_axes=(0, 0))(train_keys, vmap_params_init)

    # @filter_jit
    def train_sobolev_ensemble_warmstarted(self, key, ys, vk, vnext, init_params_vmap, problem_params, algo_params, ys_test=None):

        # train the ensemble by vmapping the whole training procedure over
        # different prng keys AND over the vmapped initial params.

        # is this implemented in some very wrong way? does this add another
        # axis of vmapping? suspiciously slow atm, but only when passing test
        # data... is that the reason? if the test dataset is much larger than
        # the batches we would kind of expect that tbh.

        # adjust algo_params for the warmstart situation. do this in a neater way if it works.
        algo_params_warmstart = algo_params.copy()
        portion = algo_params['nn_warmstart_fraction']  # repeat the last "portion" of the usual training loop.
        algo_params_warmstart['nn_N_epochs'] = int(algo_params['nn_N_epochs'] * portion)

        # algo_params_warmstart['lr_init'] = algo_params['lr_final'] * (algo_params['lr_init'] / algo_params['lr_final']) ** portion
        # seemed to make it worse :(

        keys = jax.random.split(key, algo_params['nn_ensemble_size'])

        # vmap over key and parameters.
        train_with_key_and_params = lambda k, params: self.train_sobolev(k, ys, vk, vnext, params, problem_params, algo_params_warmstart, ys_test=ys_test)

        return jax.vmap(train_with_key_and_params, in_axes=(0, 0))(keys, init_params_vmap)
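
# End-to-end usage sketch (illustration only, never called from the library code).
# Every key below is one that the code above actually reads, but all numerical values
# and the synthetic v(x) = |x|^2 dataset are my own assumptions, not recommended
# settings. Note that train_sobolev floors total_iters at 20000, so even this tiny
# example runs 20000 optimisation steps.
def _example_end_to_end_training():
    problem_params = {'nx': 2, 'm': None, 'system_name': 'orbits', 'x_eq': np.zeros(2)}
    algo_params = {
        'nn_n_layers': 2, 'nn_layer_dim': 32, 'nn_type': 'softplus',
        'nn_batchsize': 32, 'nn_N_epochs': 1, 'nn_progressbar': False,
        'lr_init': 1e-2, 'lr_final': 1e-4, 'lr_staircase': False, 'lr_staircase_steps': 5,
        'weight_decay': 0., 'prior_strength': 0., 'nn_value_sweep': False,
        'v_loss_d': 1., 'vx_loss_d': 1., 'min_important_v': 1e-2, 'min_important_vx': 1e-2,
        'inv_vx_loss_fadeout': 1., 'vx_normal_regularisation': 0.,
        'nn_sobolev_weight_v': 1., 'nn_sobolev_weight_vx': 1.,
    }

    # synthetic value function data: v(x) = |x|^2, vx(x) = 2x.
    xs = jax.random.normal(jax.random.PRNGKey(0), (500, problem_params['nx']))
    ys = {'x': xs, 'v': np.sum(xs**2, axis=1), 'vx': 2 * xs}
    train_ys, test_ys = train_test_split(ys)

    wrapper = nn_wrapper(problem_params, algo_params)
    params = wrapper.init_nn_params(jax.random.PRNGKey(1))

    # no sublevel-set sweep here, so vk == vnext.
    params, outputs = wrapper.train_sobolev(
        jax.random.PRNGKey(2), train_ys, vk=0., vnext=0., nn_params=params,
        problem_params=problem_params, algo_params=algo_params, ys_test=test_ys,
    )
    return params, outputs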