-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathutils.py
82 lines (68 loc) · 2.15 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os.path
import numpy as np
def create_dir_if_not_exists(dir):
    """Create a new numbered run sub-directory under ``dir`` and return its path.

    If ``dir`` does not exist, it is created together with ``dir/1``.
    Otherwise, the immediate sub-directories of ``dir`` whose names are
    decimal integers are treated as earlier runs. A new sub-directory numbered
    one past the largest of them is created (``1`` when there are none).
    Sub-directories with non-numeric names (e.g. ones added by other tools)
    are ignored instead of breaking the integer conversion.

    Args:
        dir: Base log directory. The name shadows the ``dir`` builtin but is
            kept so existing keyword callers keep working.

    Returns:
        The path of the newly created run directory.

    Raises:
        OSError: If ``dir`` exists but cannot be listed as a directory
            (e.g. it is a regular file).
        FileExistsError: If the chosen run directory already exists, e.g.
            because another process created it concurrently.
    """
    run_ids = []
    if os.path.exists(dir):
        # Only immediate sub-directories count as previous runs; plain files
        # are skipped. is_dir() follows symlinks, as os.walk's dirnames do.
        with os.scandir(dir) as entries:
            run_ids = [int(entry.name) for entry in entries
                       if entry.is_dir() and entry.name.isdecimal()]
    # os.path.join (rather than dir + '/...') keeps an empty base from
    # resolving to the filesystem root.
    run_dir = os.path.join(dir, str(max(run_ids, default=0) + 1))
    os.makedirs(run_dir)
    print('Logging to %s' % run_dir)
    return run_dir
def handle_args(args):
    """Turn parsed command-line arguments into model flags and a log directory.

    Always reads ``args.binary``, ``args.log_dir`` and ``args.batch_norm``.
    When ``args.binary`` is truthy it also reads ``args.last``,
    ``args.first`` and ``args.xnor``. When ``args.log_dir`` is truthy it also
    reads ``args.n_hidden``, ``args.batch_size``, ``args.keep_prob``,
    ``args.lr``, ``args.extra`` and, for binary nets, ``args.reg``.

    Returns:
        ``(log_path, binary, first, last, xnor, batch_norm)``.
        ``log_path`` is ``''`` when ``args.log_dir`` is falsy. Otherwise it is
        the directory returned by ``create_dir_if_not_exists`` for a path
        that encodes the hyper-parameters. ``first``, ``last`` and ``xnor``
        can only be True for binary nets.
    """
    binary = first = last = xnor = False
    if not args.binary:
        layer_dir, kernel_dir = '/fp/', ''
    else:
        print("Using 1-bit weights and activations")
        binary = True
        # First/last-layer binarization and the xnor kernel are only honoured
        # together with the binary flag.
        last = bool(args.last)
        first = bool(args.first)
        layer_dir = {
            (True, True): '/bin_all/',
            (True, False): '/bin_first/',
            (False, True): '/bin_last/',
            (False, False): '/bin/',
        }[(first, last)]
        if args.xnor:
            print("Using xnor xnor_gemm kernel")
            xnor = True
        kernel_dir = 'xnor/' if xnor else 'matmul/'

    log_dir = args.log_dir

    batch_norm = False
    if args.batch_norm:
        print("Using batch normalization")
        batch_norm = True

    if not log_dir:
        return '', binary, first, last, xnor, batch_norm

    pieces = [log_dir, layer_dir, kernel_dir, 'hid_', str(args.n_hidden), '/']
    if batch_norm:
        pieces.append('batch_norm/')
    pieces.extend(['bs_', str(args.batch_size),
                   '/keep_', str(args.keep_prob),
                   '/lr_', str(args.lr)])
    # reg only applies to binary nets, while dropout (keep_prob) is used for
    # both binary and full-precision runs.
    if binary:
        pieces.extend(['/reg_', str(args.reg)])
    if args.extra:
        pieces.extend(['/', args.extra])
    log_path = create_dir_if_not_exists(''.join(pieces))
    return log_path, binary, first, last, xnor, batch_norm