@@ -48,6 +48,8 @@ def parse_args():
     parser.add_argument('--val-epochs', nargs='*', type=int,
                         default=[40, 50, 55, 60, 65, 70, 75, 80],
                         help='epochs at which to evaluate in addition to --val-interval')
+    parser.add_argument('--batch-splits', type=int, default=1,
+                        help='Split batch to N steps (gradient accumulation)')
     parser.add_argument('--lr-decay-schedule', nargs='*', type=int,
                         default=[40, 50],
                         help='epochs at which to decay the learning rate')
@@ -282,6 +284,11 @@ def train300_mlperf_coco(args):
     mllogger.event(key=mllog_const.MODEL_BN_SPAN, value=args.batch_size)
     current_lr = args.lr * (global_batch_size / 32)
 
+    assert args.batch_size % args.batch_splits == 0, "--batch-size must be divisible by --batch-splits"
+    fragment_size = args.batch_size // args.batch_splits
+    if args.batch_splits != 1:
+        print("using gradient accumulation with fragments of size {}".format(fragment_size))
+
     current_momentum = 0.9
     optim = torch.optim.SGD(ssd300.parameters(), lr=current_lr,
                             momentum=current_momentum,
@@ -311,6 +318,8 @@ def train300_mlperf_coco(args):
         key=mllog_const.BLOCK_START,
         metadata={mllog_const.FIRST_EPOCH_NUM: 1,
                   mllog_const.EPOCH_COUNT: args.epochs})
+
+    optim.zero_grad()
     for epoch in range(args.epochs):
         mllogger.start(
             key=mllog_const.EPOCH_START,
@@ -327,27 +336,34 @@
                 param_group['lr'] = current_lr
 
         for nbatch, (img, img_id, img_size, bbox, label) in enumerate(train_dataloader):
-            if use_cuda:
-                img = img.cuda()
-            img = Variable(img, requires_grad=True)
-            ploc, plabel = ssd300(img)
-            trans_bbox = bbox.transpose(1,2).contiguous()
-            if use_cuda:
-                trans_bbox = trans_bbox.cuda()
-                label = label.cuda()
-            gloc, glabel = Variable(trans_bbox, requires_grad=False), \
-                Variable(label, requires_grad=False)
-            loss = loss_func(ploc, plabel, gloc, glabel)
+            current_batch_size = img.shape[0]
+            # Split batch for gradient accumulation
+            img = torch.split(img, fragment_size)
+            bbox = torch.split(bbox, fragment_size)
+            label = torch.split(label, fragment_size)
+
+            for (fimg, fbbox, flabel) in zip(img, bbox, label):
+                current_fragment_size = fimg.shape[0]
+                trans_bbox = fbbox.transpose(1,2).contiguous()
+                if use_cuda:
+                    fimg = fimg.cuda()
+                    trans_bbox = trans_bbox.cuda()
+                    flabel = flabel.cuda()
+                fimg = Variable(fimg, requires_grad=True)
+                ploc, plabel = ssd300(fimg)
+                gloc, glabel = Variable(trans_bbox, requires_grad=False), \
+                    Variable(flabel, requires_grad=False)
+                loss = loss_func(ploc, plabel, gloc, glabel)
+                loss = loss * (current_fragment_size / current_batch_size)  # weighted mean
+                loss.backward()
 
+            warmup_step(iter_num, current_lr)
+            optim.step()
+            optim.zero_grad()
 
             if not np.isinf(loss.item()): avg_loss = 0.999 * avg_loss + 0.001 * loss.item()
             if args.rank == 0 and args.log_interval and not iter_num % args.log_interval:
                 print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}" \
                     .format(iter_num, loss.item(), avg_loss))
-            optim.zero_grad()
-            loss.backward()
-            warmup_step(iter_num, current_lr)
-            optim.step()
-
             iter_num += 1
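
For reference, the pattern this patch introduces can be reduced to the following minimal, self-contained sketch. The model, loss, and tensor shapes are hypothetical stand-ins (a plain Linear layer, not the repository's ssd300 / loss_func); the point is the fragment loop, the loss scaling by fragment_size / batch_size so the accumulated gradient matches the full-batch mean, and the single optimizer step per batch.

    # Minimal gradient-accumulation sketch (hypothetical model and data,
    # not the repository's SSD training code).
    import torch

    model = torch.nn.Linear(10, 2)           # stand-in for ssd300
    loss_fn = torch.nn.CrossEntropyLoss()    # stand-in for loss_func
    optim = torch.optim.SGD(model.parameters(), lr=0.01)

    batch_size, batch_splits = 8, 2
    fragment_size = batch_size // batch_splits

    inputs = torch.randn(batch_size, 10)
    targets = torch.randint(0, 2, (batch_size,))

    optim.zero_grad()
    for finp, ftgt in zip(torch.split(inputs, fragment_size),
                          torch.split(targets, fragment_size)):
        loss = loss_fn(model(finp), ftgt)
        loss = loss * (finp.shape[0] / batch_size)  # weighted mean over fragments
        loss.backward()                             # gradients accumulate in .grad
    optim.step()                                    # one update per full batch
    optim.zero_grad()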