@@ -12,6 +12,7 @@
 import torch.backends.cudnn as cudnn
 import torch.distributed as dist
 import torch.optim
+from torch.optim.lr_scheduler import StepLR
 import torch.multiprocessing as mp
 import torch.utils.data
 import torch.utils.data.distributed
@@ -24,8 +25,8 @@
                      and callable(models.__dict__[name]))
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
-parser.add_argument('data', metavar='DIR',
-                    help='path to dataset')
+parser.add_argument('data', metavar='DIR', default='imagenet',
+                    help='path to dataset (default: imagenet)')
 parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                     choices=model_names,
                     help='model architecture: ' +
@@ -148,7 +149,7 @@ def main_worker(gpu, ngpus_per_node, args):
             model.cuda(args.gpu)
             # When using a single GPU per process and per
             # DistributedDataParallel, we need to divide the batch size
-            # ourselves based on the total number of GPUs we have
+            # ourselves based on the total number of GPUs of the current node.
             args.batch_size = int(args.batch_size / ngpus_per_node)
             args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
             model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
@@ -168,13 +169,16 @@ def main_worker(gpu, ngpus_per_node, args):
         else:
             model = torch.nn.DataParallel(model).cuda()
 
-    # define loss function (criterion) and optimizer
+    # define loss function (criterion), optimizer, and learning rate scheduler
     criterion = nn.CrossEntropyLoss().cuda(args.gpu)
 
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
-
+
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
+
     # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
@@ -192,6 +196,7 @@ def main_worker(gpu, ngpus_per_node, args):
                 best_acc1 = best_acc1.to(args.gpu)
             model.load_state_dict(checkpoint['state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer'])
+            scheduler.load_state_dict(checkpoint['scheduler'])
             print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
         else:
@@ -240,14 +245,16 @@ def main_worker(gpu, ngpus_per_node, args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        adjust_learning_rate(optimizer, epoch, args)
 
         # train for one epoch
         train(train_loader, model, criterion, optimizer, epoch, args)
 
         # evaluate on validation set
         acc1 = validate(val_loader, model, criterion, args)
+
+        scheduler.step()
 
+
         # remember best acc@1 and save checkpoint
         is_best = acc1 > best_acc1
         best_acc1 = max(acc1, best_acc1)
@@ -260,6 +267,7 @@ def main_worker(gpu, ngpus_per_node, args):
                 'state_dict': model.state_dict(),
                 'best_acc1': best_acc1,
                 'optimizer' : optimizer.state_dict(),
+                'scheduler' : scheduler.state_dict()
             }, is_best)
 
 
@@ -425,14 +433,6 @@ def _get_batch_fmtstr(self, num_batches):
         fmt = '{:' + str(num_digits) + 'd}'
         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
 
-
-def adjust_learning_rate(optimizer, epoch, args):
-    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    lr = args.lr * (0.1 ** (epoch // 30))
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = lr
-
-
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
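For context, outside the diff itself: the manual `adjust_learning_rate` helper is replaced by `torch.optim.lr_scheduler.StepLR`, which yields the same schedule (the initial LR decayed by a factor of 10 every 30 epochs) while also exposing `state_dict()`/`load_state_dict()` so the schedule can be checkpointed alongside the optimizer. Below is a minimal sketch of that equivalence; the model and base LR are hypothetical stand-ins, not the script's own `args`-driven setup.

```python
import torch
from torch.optim.lr_scheduler import StepLR

# Stand-ins for illustration only; the real script builds these from args.
model = torch.nn.Linear(10, 2)
base_lr = 0.1

optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(90):
    # LR in effect for this epoch matches the removed helper:
    # lr = base_lr * (0.1 ** (epoch // 30))
    expected = base_lr * (0.1 ** (epoch // 30))
    assert abs(optimizer.param_groups[0]['lr'] - expected) < 1e-12

    # ... train(...) and validate(...) would run here ...
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # One scheduler step per epoch, after validation, as in the diff above.
    scheduler.step()
```

Because `scheduler.state_dict()` is stored in the checkpoint and restored on `--resume`, a restarted run continues the decay from the correct epoch instead of resetting to the base learning rate.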