Commit 98f8b2a

[SSD] Added gradient accumulation support (mlcommons#421)
1 parent b06269d commit 98f8b2a

File tree: 1 file changed, +32 -16 lines

single_stage_detector/ssd/train.py (+32 -16)
@@ -48,6 +48,8 @@ def parse_args():
     parser.add_argument('--val-epochs', nargs='*', type=int,
                         default=[40, 50, 55, 60, 65, 70, 75, 80],
                         help='epochs at which to evaluate in addition to --val-interval')
+    parser.add_argument('--batch-splits', type=int, default=1,
+                        help='Split batch to N steps (gradient accumulation)')
     parser.add_argument('--lr-decay-schedule', nargs='*', type=int,
                         default=[40, 50],
                         help='epochs at which to decay the learning rate')
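The new --batch-splits flag splits each batch into N gradient-accumulation steps; the default of 1 preserves the original single-pass behaviour. As a hypothetical invocation (any other required flags elided), it would be appended to an existing training command, e.g.:

    python train.py --batch-size 128 --batch-splits 4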
@@ -282,6 +284,11 @@ def train300_mlperf_coco(args):
     mllogger.event(key=mllog_const.MODEL_BN_SPAN, value=args.batch_size)
     current_lr = args.lr * (global_batch_size / 32)
 
+    assert args.batch_size % args.batch_splits == 0, "--batch-size must be divisible by --batch-splits"
+    fragment_size = args.batch_size // args.batch_splits
+    if args.batch_splits != 1:
+        print("using gradient accumulation with fragments of size {}".format(fragment_size))
+
     current_momentum = 0.9
     optim = torch.optim.SGD(ssd300.parameters(), lr=current_lr,
                             momentum=current_momentum,
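Worked example: with --batch-size 128 and --batch-splits 4, fragment_size = 128 // 4 = 32, so each optimizer step accumulates gradients over four 32-image fragments; the assert rejects combinations that would leave a ragged final fragment.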
@@ -311,6 +318,8 @@ def train300_mlperf_coco(args):
         key=mllog_const.BLOCK_START,
         metadata={mllog_const.FIRST_EPOCH_NUM: 1,
                   mllog_const.EPOCH_COUNT: args.epochs})
+
+    optim.zero_grad()
     for epoch in range(args.epochs):
         mllogger.start(
             key=mllog_const.EPOCH_START,
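Moving optim.zero_grad() out of the batch loop is essential for accumulation: gradients must persist across the per-fragment backward() calls within a batch, so they are now cleared once before training starts and again right after each optim.step() (see the final hunk), instead of before every backward pass.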
@@ -327,27 +336,34 @@
                 param_group['lr'] = current_lr
 
         for nbatch, (img, img_id, img_size, bbox, label) in enumerate(train_dataloader):
-            if use_cuda:
-                img = img.cuda()
-            img = Variable(img, requires_grad=True)
-            ploc, plabel = ssd300(img)
-            trans_bbox = bbox.transpose(1,2).contiguous()
-            if use_cuda:
-                trans_bbox = trans_bbox.cuda()
-                label = label.cuda()
-            gloc, glabel = Variable(trans_bbox, requires_grad=False), \
-                Variable(label, requires_grad=False)
-            loss = loss_func(ploc, plabel, gloc, glabel)
+            current_batch_size = img.shape[0]
+            # Split batch for gradient accumulation
+            img = torch.split(img, fragment_size)
+            bbox = torch.split(bbox, fragment_size)
+            label = torch.split(label, fragment_size)
+
+            for (fimg, fbbox, flabel) in zip(img, bbox, label):
+                current_fragment_size = fimg.shape[0]
+                trans_bbox = fbbox.transpose(1,2).contiguous()
+                if use_cuda:
+                    fimg = fimg.cuda()
+                    trans_bbox = trans_bbox.cuda()
+                    flabel = flabel.cuda()
+                fimg = Variable(fimg, requires_grad=True)
+                ploc, plabel = ssd300(fimg)
+                gloc, glabel = Variable(trans_bbox, requires_grad=False), \
+                    Variable(flabel, requires_grad=False)
+                loss = loss_func(ploc, plabel, gloc, glabel)
+                loss = loss * (current_fragment_size / current_batch_size) # weighted mean
+                loss.backward()
 
+            warmup_step(iter_num, current_lr)
+            optim.step()
+            optim.zero_grad()
             if not np.isinf(loss.item()): avg_loss = 0.999*avg_loss + 0.001*loss.item()
             if args.rank == 0 and args.log_interval and not iter_num % args.log_interval:
                 print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}"\
                     .format(iter_num, loss.item(), avg_loss))
-            optim.zero_grad()
-            loss.backward()
-            warmup_step(iter_num, current_lr)
-            optim.step()
-
             iter_num += 1
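For readers less familiar with the pattern, the change is standard PyTorch gradient accumulation: backward() adds into each parameter's .grad, so calling it once per fragment and stepping the optimizer once per full batch yields (up to floating-point ordering) the same update as one large batch, provided each fragment's mean loss is re-weighted by current_fragment_size / current_batch_size. A minimal self-contained sketch of the same idea; the toy model, data, and sizes here are illustrative stand-ins, not from train.py:

    import torch

    # Illustrative stand-ins, not from the repository.
    torch.manual_seed(0)
    model = torch.nn.Linear(10, 1)
    loss_func = torch.nn.MSELoss()                  # returns a mean over the fragment
    optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    batch_size, batch_splits = 8, 4
    assert batch_size % batch_splits == 0
    fragment_size = batch_size // batch_splits

    img = torch.randn(batch_size, 10)               # stand-in for a batch of images
    label = torch.randn(batch_size, 1)              # stand-in for targets

    optim.zero_grad()                               # start from clean gradients
    current_batch_size = img.shape[0]
    for fimg, flabel in zip(torch.split(img, fragment_size),
                            torch.split(label, fragment_size)):
        loss = loss_func(model(fimg), flabel)
        # Scale each fragment's mean loss so the accumulated gradient equals
        # the gradient of the mean loss over the full, unsplit batch.
        loss = loss * (fimg.shape[0] / current_batch_size)
        loss.backward()                             # gradients accumulate in .grad
    optim.step()                                    # one parameter update per full batch
    optim.zero_grad()                               # clear before the next batch

One general caveat of the technique: layers that compute batch statistics (e.g. batch normalization) see fragments rather than full batches, so models using them may not match the unsplit baseline exactly.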
