mnist.py (forked from mjzyle/DeepShift)
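# Trains or evaluates a simple MNIST classifier (fully-connected or convolutional), optionally
# converting its layers to DeepShift shift layers, and profiles CPU/RAM usage of the run
# (the "MATT ADDITIONS" block below) into Excel workbooks.
# Illustrative invocation (flags are defined in main() below):
#   python mnist.py --type conv --shift-depth 2 --epochs 10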
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import optim
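# NOTE: 'optim' is assumed to be the repo-local optimizer package (providing RAdam/Ranger below), not torch.optim.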
from torchvision import datasets, transforms
import csv
import distutils
import os
from contextlib import redirect_stdout
import time
import torchsummary
import mnist
import copy
import deepshift
import unoptimized
from deepshift.convert import convert_to_shift, round_shift_weights, count_layer_type
from unoptimized.convert import convert_to_unoptimized
## MATT ADDITIONS ############################################################################
import psutil
import threading
import pandas as pd
import datetime as dt
from distutils import util
is_training = False
is_testing = False
training_perf = pd.DataFrame()
testing_perf = pd.DataFrame()
pid = os.getpid()
proc = psutil.Process(pid=pid)
proc.cpu_affinity([0]) # Limit number of CPUs used for processing
model_name = ''
loc_performance_profile_training = r"/home/alex/DeepShift/pytorch/performance_profiles_training.xlsx"
loc_performance_accuracy_testing = r"/home/alex/DeepShift/pytorch/performance_v_accuracy_testing.xlsx"
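# NOTE (assumptions): the two workbooks above are hard-coded absolute paths and are expected to
# already exist with an 'IterationID' column, since the organize_*() helpers below read them before appending.
# The DataFrame.append() calls below also assume pandas < 2.0, where that method still exists.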

def t_report_usage_training(name):
    wait = 0.1
    while True:
        global training_perf, proc
        training_perf = training_perf.append({
            'Time': time.ctime(time.time()),
            'CPU%': proc.cpu_percent()/len(proc.cpu_affinity()),
            'RAM%': proc.memory_percent(),
            'NumCPUs': len(proc.cpu_affinity())
        }, ignore_index=True)
        time.sleep(wait)
        global is_training
        if not is_training:
            break

def report_usage_training():
    t = threading.Thread(target=t_report_usage_training, args=(1,))
    t.start()
    return t

def organize_performance_profile_training():
    global training_perf, loc_performance_profile_training, model_name
    training_perf_full = pd.read_excel(loc_performance_profile_training)
    iterations = list(set(training_perf_full['IterationID']))
    iterations.sort()
    # Add an appropriate iteration ID to the full data
    if len(iterations) == 0:
        iterID = 0
    else:
        iterID = len(iterations)
    # Add iteration-specific identifiers
    training_perf['IterationID'] = iterID
    training_perf['Dataset'] = 'MNIST'
    training_perf['Model'] = model_name
    # Normalize time data
    training_perf['Time'] = pd.to_datetime(training_perf['Time'])
    time_start = training_perf.loc[0, 'Time']
    training_perf['TimeAdjusted'] = training_perf['Time'].apply(lambda x: (x - time_start) + dt.datetime(1900, 1, 1))
    # Merge iteration data with full data
    training_perf_full = pd.concat([training_perf_full, training_perf])
    training_perf_full.to_excel(loc_performance_profile_training, index=False)

def organize_performance_accuracy_testing():
    global testing_perf, loc_performance_accuracy_testing, model_name
    testing_perf_full = pd.read_excel(loc_performance_accuracy_testing)
    iterations = list(set(testing_perf_full['IterationID']))
    iterations.sort()
    # Add an appropriate iteration ID to the full data
    if len(iterations) == 0:
        iterID = 0
    else:
        iterID = len(iterations)
    # Add iteration-specific identifiers
    testing_perf['IterationID'] = iterID
    testing_perf['Dataset'] = 'MNIST'
    testing_perf['Model'] = model_name
    # Merge iteration data with full data
    testing_perf_full = pd.concat([testing_perf_full, testing_perf])
    testing_perf_full.to_excel(loc_performance_accuracy_testing, index=False)
#############################################################################################
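# Two baseline MNIST architectures: a three-layer MLP with dropout (LinearMNIST) and a small
# LeNet-style CNN (ConvMNIST).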
class LinearMNIST(nn.Module):
    def __init__(self):
        super(LinearMNIST, self).__init__()
        self.fc1 = nn.Linear(1*28*28, 512)
        self.dropout1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(512, 512)
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(512, 10)

    def forward(self, x):
        x = x.view(-1, 1 * 28 * 28)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
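# ConvMNIST shape bookkeeping: 28x28 -> conv(5x5) -> 24x24 -> pool(2) -> 12x12 -> conv(5x5) -> 8x8
# -> pool(2) -> 4x4 with 50 channels, hence the 4*4*50 flatten in forward().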
class ConvMNIST(nn.Module):
    def __init__(self):
        super(ConvMNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
def train(args, model, device, train_loader, loss_fn, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return loss.item()
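# train() returns the loss of the final batch only; test() below returns the average test loss
# and the raw count of correct predictions (not a percentage).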
def test(args, model, device, test_loader, loss_fn):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_fn(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--type', default='linear',
                        choices=['linear', 'conv'],
                        help='model architecture type: ' +
                             ' | '.join(['linear', 'conv']) +
                             ' (default: linear)')
    parser.add_argument('--model', default='', type=str, metavar='MODEL_PATH',
                        help='path to model file to load both its architecture and weights (default: none)')
    parser.add_argument('--weights', default='', type=str, metavar='WEIGHTS_PATH',
                        help='path to file to load its weights (default: none)')
    parser.add_argument('--shift-depth', type=int, default=0,
                        help='how many layers to convert to shift')
    parser.add_argument('-st', '--shift-type', default='PS', choices=['Q', 'PS'],
                        help='type of DeepShift method for training and representing weights (default: PS)')
    parser.add_argument('-r', '--rounding', default='deterministic', choices=['deterministic', 'stochastic'],
                        help='type of rounding (default: deterministic)')
    parser.add_argument('-wb', '--weight-bits', type=int, default=5,
                        help='number of bits to represent the shift weights')
    parser.add_argument('-j', '--workers', default=1, type=int, metavar='N',
                        help='number of data loading workers (default: 1)')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('-opt', '--optimizer', metavar='OPT', default="SGD",
                        help='optimizer algorithm (default: SGD)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.0, metavar='M',
                        help='SGD momentum (default: 0.0)')
    parser.add_argument('--resume', default='', type=str, metavar='CHECKPOINT_PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='only evaluate model on validation set')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--pretrained', dest='pretrained', default=False, type=lambda x: bool(distutils.util.strtobool(x)),
                        help='use pre-trained weights of the full conv or fc model')
    parser.add_argument('--save-model', default=True, type=lambda x: bool(distutils.util.strtobool(x)),
                        help='save the current model (default: True)')
    parser.add_argument('--print-weights', default=True, type=lambda x: bool(distutils.util.strtobool(x)),
                        help='log the model weights to text files (default: True)')
    parser.add_argument('--desc', type=str, default=None,
                        help='description to append to model directory name')
    parser.add_argument('--use-kernel', type=lambda x: bool(distutils.util.strtobool(x)), default=False,
                        help='whether to use the custom shift kernel')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if (args.evaluate is False and args.use_kernel is True):
        raise ValueError('Our custom kernel currently supports inference only, not training.')
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}
    # Load training MNIST data
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))  # transforms.Normalize((0,), (255,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    # Load testing MNIST data
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))  # transforms.Normalize((0,), (255,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
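    # (0.1307, 0.3081) above are the conventional MNIST mean/std normalization constants.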
    # Use an existing model (directory path provided)
    if args.model:
        if args.type or args.pretrained:
            print("WARNING: Ignoring arguments \"type\" and \"pretrained\" when creating model...")
        model = None
        saved_checkpoint = torch.load(args.model)
        if isinstance(saved_checkpoint, nn.Module):
            model = saved_checkpoint
        elif "model" in saved_checkpoint:
            model = saved_checkpoint["model"]
        else:
            raise Exception("Unable to load model from " + args.model)
    # Generate a new model (linear or convolutional)
    else:
        if args.type == 'linear':
            model = LinearMNIST().to(device)
        elif args.type == 'conv':
            model = ConvMNIST().to(device)
        if args.pretrained:
            model.load_state_dict(torch.load("./models/mnist/simple_" + args.type + "/shift_0/weights.pth"))
            model = model.to(device)
    model_rounded = None
    if args.weights:
        saved_weights = torch.load(args.weights)
        if isinstance(saved_weights, nn.Module):
            state_dict = saved_weights.state_dict()
        elif "state_dict" in saved_weights:
            state_dict = saved_weights["state_dict"]
        else:
            state_dict = saved_weights
        model.load_state_dict(state_dict)
    if args.shift_depth > 0:
        model, _ = convert_to_shift(model, args.shift_depth, args.shift_type,
                                    convert_all_linear=(args.type != 'linear'),
                                    convert_weights=True, use_kernel=args.use_kernel,
                                    use_cuda=use_cuda, rounding=args.rounding,
                                    weight_bits=args.weight_bits)
        model = model.to(device)
    elif args.use_kernel and args.shift_depth == 0:
        model = convert_to_unoptimized(model)
        model = model.to(device)
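    # convert_to_shift replaces the last `shift_depth` eligible layers with DeepShift shift layers;
    # roughly, 'Q' rounds trained weights to signed powers of two, while 'PS' trains the shift and
    # sign parameters directly (see the DeepShift paper for details).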
    loss_fn = F.cross_entropy  # F.nll_loss
    # define optimizer
    optimizer = None
    if (args.optimizer.lower() == "sgd"):
        optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum)
    elif (args.optimizer.lower() == "adadelta"):
        optimizer = torch.optim.Adadelta(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "adagrad"):
        optimizer = torch.optim.Adagrad(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "adam"):
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "rmsprop"):
        optimizer = torch.optim.RMSprop(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "radam"):
        optimizer = optim.RAdam(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "ranger"):
        optimizer = optim.Ranger(model.parameters(), args.lr)
    else:
        raise ValueError("Optimizer type: " + args.optimizer + " is not supported or known")
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}'".format(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if (args.shift_depth > 0):
        if (args.shift_type == 'Q'):
            shift_label = "shift_q"
        else:
            shift_label = "shift_ps"
    else:
        shift_label = "shift"
    # name model sub-directory "shift_all" if all layers are converted to shift layers
    conv2d_layers_count = count_layer_type(model, nn.Conv2d) + count_layer_type(model, unoptimized.UnoptimizedConv2d)
    linear_layers_count = count_layer_type(model, nn.Linear) + count_layer_type(model, unoptimized.UnoptimizedLinear)
    if (conv2d_layers_count == 0 and linear_layers_count == 0):
        shift_label += "_all"
    else:
        shift_label += "_%s" % (args.shift_depth)
    if (args.shift_depth > 0):
        shift_label += "_wb_%s" % (args.weight_bits)
    if (args.desc is not None and len(args.desc) > 0):
        desc_label = "_%s" % (args.desc)
    else:
        desc_label = ""
    global model_name  # MZ addition
    model_name = 'simple_%s/%s%s' % (args.type, shift_label, desc_label)
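    # e.g. --type linear --shift-depth 3 --weight-bits 5 yields model_name 'simple_linear/shift_ps_3_wb_5'
    # (assuming the default PS shift type and that some unconverted layers remain).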
    # if evaluating, round weights to ensure that the results are due to powers-of-2 weights
    if (args.evaluate):
        model = round_shift_weights(model)
    model_summary = None
    try:
        model_summary, model_params_info = torchsummary.summary_string(model, input_size=(1, 28, 28))
        print(model_summary)
        print("WARNING: The summary function reports duplicate parameters for multi-GPU case")
    except Exception:
        print("WARNING: Unable to obtain summary of model")
    model_dir = os.path.join(os.getcwd(), "models", "mnist", model_name)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir, exist_ok=True)
    if (args.save_model):
        with open(os.path.join(model_dir, 'command_args.txt'), 'w') as command_args_file:
            for arg, value in sorted(vars(args).items()):
                command_args_file.write(arg + ": " + str(value) + "\n")
        with open(os.path.join(model_dir, 'model_summary.txt'), 'w') as summary_file:
            with redirect_stdout(summary_file):
                if (model_summary is not None):
                    print(model_summary)
                    print("WARNING: The summary function reports duplicate parameters for multi-GPU case")
                else:
                    print("WARNING: Unable to obtain summary of model")
    # del model_tmp_copy
    start_time = time.time()
    if args.evaluate:
        test_loss, correct = test(args, model, device, test_loader, loss_fn)
        test_log = [(test_loss, correct/1e4)]
        with open(os.path.join(model_dir, "test_log.csv"), "w") as test_log_file:
            test_log_csv = csv.writer(test_log_file)
            test_log_csv.writerow(['test_loss', 'correct'])
            test_log_csv.writerows(test_log)
    else:
        ###################################################################################################################################
        # Start recording training usage metrics
        global is_training
        is_training = True
        t = report_usage_training()
        ###################################################################################################################################
        train_log = []
        ###################################################################################################################################
        test_start = None
        test_end = None
        ###################################################################################################################################
        for epoch in range(1, args.epochs + 1):
            train_loss = train(args, model, device, train_loader, loss_fn, optimizer, epoch)
            test_start = time.time()
            test_loss, correct = test(args, model, device, test_loader, loss_fn)
            test_end = time.time()
            ###################################################################################################################################
            # Record test-specific metrics
            test_set_size = len(test_loader.dataset)
            eval_time = test_end - test_start
            global testing_perf
            testing_perf = testing_perf.append({
                'TestSetSize': test_set_size,
                'EvaluationTime': eval_time,
                'Loss': test_loss,
                'Correct%': correct/test_set_size,
                'Epoch': epoch
            }, ignore_index=True)
            ###################################################################################################################################
            if (args.print_weights):
                with open(os.path.join(model_dir, 'weights_log_' + str(epoch) + '.txt'), 'w') as weights_log_file:
                    with redirect_stdout(weights_log_file):
                        # Log model's state_dict
                        print("Model's state_dict:")
                        # TODO: Use checkpoint above
                        for param_tensor in model.state_dict():
                            print(param_tensor, "\t", model.state_dict()[param_tensor].size())
                            print(model.state_dict()[param_tensor])
                            print("")
            train_log.append((epoch, train_loss, test_loss, correct/1e4))
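            # correct/1e4 above relies on the MNIST test set having exactly 10,000 images.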
        ###################################################################################################################################
        # Stop recording training usage metrics
        is_training = False
        t.join()
        global training_perf
        training_perf.to_excel(os.path.join(model_dir, 'train_performance.xlsx'), index=False)
        organize_performance_profile_training()
        organize_performance_accuracy_testing()
        ###################################################################################################################################
        with open(os.path.join(model_dir, "train_log.csv"), "w") as train_log_file:
            train_log_csv = csv.writer(train_log_file)
            train_log_csv.writerow(['epoch', 'train_loss', 'test_loss', 'test_accuracy'])
            train_log_csv.writerows(train_log)
        if (args.save_model):
            model_rounded = round_shift_weights(model, clone=True)
            torch.save(model_rounded, os.path.join(model_dir, "model.pth"))
            torch.save(model_rounded.state_dict(), os.path.join(model_dir, "weights.pth"))
    end_time = time.time()
    print("Total Time:", end_time - start_time)
    if (args.print_weights):
        if (model_rounded is None):
            model_rounded = round_shift_weights(model, clone=True)
        with open(os.path.join(model_dir, 'weights_log.txt'), 'w') as weights_log_file:
            with redirect_stdout(weights_log_file):
                # Log model's state_dict
                print("Model's state_dict:")
                # TODO: Use checkpoint above
                for param_tensor in model_rounded.state_dict():
                    print(param_tensor, "\t", model_rounded.state_dict()[param_tensor].size())
                    print(model_rounded.state_dict()[param_tensor])
                    print("")
if __name__ == '__main__':
    main()
    torch.cuda.empty_cache()
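    # empty_cache() is a no-op here if CUDA was never initialized (e.g. when running with --no-cuda).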