-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
64 lines (48 loc) · 1.61 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import sys
import time
import signal
import numpy as np
"""Various utility functions."""
def log(data, numit=None):
    """Write one timestamped line to stdout and flush it.

    With ``numit`` left as None, ``data`` is used as the message text
    itself. Otherwise ``data`` is indexed by each of its keys and the line
    carries the zero-padded iteration counter followed by ``key = value``
    pairs (values in %.3e notation) separated by single spaces.
    """
    if numit is not None:
        fields = []
        for key in data.keys():
            fields.append(key + (' = %.3e' % data[key]))
        message = '[%06d] ' % numit + ' '.join(fields)
    else:
        message = data
    stamp = time.strftime("%Y-%m-%d %H:%M:%S ")
    sys.stdout.write(stamp + message + "\n")
    sys.stdout.flush()
def getstop():
    """Install a one-shot ctrl+c handler and return its flag.

    Returns a one-element list ``stop``. ``stop[0]`` starts as False and is
    set to True the first time SIGINT is delivered. That first SIGINT also
    puts back the handler that was active before this call (or the default
    action when no restorable previous handler is known), so a later
    ctrl+c is no longer routed here.
    """
    stop = [False]
    previous = [None]  # filled in below, once signal.signal() has returned

    def handler(_signum, _frame):
        stop[0] = True
        # signal.signal() returns None for a previous handler that was not
        # installed from Python, and it rejects None as a handler with a
        # TypeError. Fall back to the default action in that case.
        restore = previous[0]
        if restore is None:
            restore = signal.SIG_DFL
        signal.signal(signal.SIGINT, restore)

    previous[0] = signal.signal(signal.SIGINT, handler)
    return stop
def saveopt(fname, opt, iters):
    """Write the optimizer's weights and the iteration count to ``fname``.

    Each entry of ``opt.get_weights()`` is stored under its decimal
    position ('0', '1', ...); ``iters`` is stored as an int64 under the
    key 'iters'. The archive is written with ``np.savez``.
    """
    arrays = {}
    for idx, value in enumerate(opt.get_weights()):
        arrays['%d' % idx] = value
    arrays['iters'] = np.int64(iters)
    np.savez(fname, **arrays)
def savemodel(fname, model):
    """Write the model's weights to ``fname`` with ``np.savez``.

    Each entry of ``model.get_weights()`` is stored under its decimal
    position in that list ('0', '1', ...).
    """
    arrays = dict(
        ('%d' % pos, value)
        for pos, value in enumerate(model.get_weights())
    )
    np.savez(fname, **arrays)
def loadmodel(fname, model):
    """Restore model weights from an .npz archive.

    Reads the arrays stored under the keys '0', '1', ... (one per entry in
    the archive) and passes them, in that order, to ``model.set_weights``.
    This is the layout that ``savemodel`` writes.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once the arrays have been read into memory.
    with np.load(fname) as npz:
        weights = [npz['%d' % i] for i in range(len(npz.files))]
    model.set_weights(weights)
def loadopt(fname, opt, model):
    """Restore optimizer state from an .npz archive.

    All entries but one are read as weights from the keys '0', '1', ...;
    the remaining entry is expected under the key 'iters'. This is the
    layout that ``saveopt`` writes.

    ``opt._create_all_weights(model.trainable_variables)`` is called before
    ``opt.set_weights`` -- presumably so the optimizer's variables exist
    before they are assigned; confirm against the optimizer class in use.

    Returns the value stored under 'iters'.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it on the way out, after 'iters' has been read.
    with np.load(fname) as npz:
        weights = [npz['%d' % i] for i in range(len(npz.files) - 1)]
        opt._create_all_weights(model.trainable_variables)
        opt.set_weights(weights)
        return npz['iters']