Skip to content

Commit d547876

Browse files
authored
If the validation dataset size is not evenly divisible by world_size, the distributed sampler pads or truncates per-rank shards and the reported validation accuracy is incorrect. We solve this by dropping the uneven tail from the sharded pass and evaluating those leftover samples with an auxiliary validation loader. (#980)
1 parent 2bf23f1 commit d547876

File tree

1 file changed

+62
-34
lines changed

1 file changed

+62
-34
lines changed

imagenet/main.py

+62-34
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torchvision.transforms as transforms
2020
import torchvision.datasets as datasets
2121
import torchvision.models as models
22+
from torch.utils.data import Subset
2223

2324
model_names = sorted(name for name in models.__dict__
2425
if name.islower() and not name.startswith("__")
@@ -219,24 +220,29 @@ def main_worker(gpu, ngpus_per_node, args):
219220
normalize,
220221
]))
221222

223+
val_dataset = datasets.ImageFolder(
224+
valdir,
225+
transforms.Compose([
226+
transforms.Resize(256),
227+
transforms.CenterCrop(224),
228+
transforms.ToTensor(),
229+
normalize,
230+
]))
231+
222232
if args.distributed:
223233
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
234+
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
224235
else:
225236
train_sampler = None
237+
val_sampler = None
226238

227239
train_loader = torch.utils.data.DataLoader(
228240
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
229241
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
230242

231243
val_loader = torch.utils.data.DataLoader(
232-
datasets.ImageFolder(valdir, transforms.Compose([
233-
transforms.Resize(256),
234-
transforms.CenterCrop(224),
235-
transforms.ToTensor(),
236-
normalize,
237-
])),
238-
batch_size=args.batch_size, shuffle=False,
239-
num_workers=args.workers, pin_memory=True)
244+
val_dataset, batch_size=args.batch_size, shuffle=False,
245+
num_workers=args.workers, pin_memory=True, sampler=val_sampler)
240246

241247
if args.evaluate:
242248
validate(val_loader, model, criterion, args)
@@ -315,48 +321,64 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
315321
end = time.time()
316322

317323
if i % args.print_freq == 0:
318-
progress.display(i)
324+
progress.display(i + 1)
319325

320326

321327
def validate(val_loader, model, criterion, args):
    """Evaluate `model` on the validation set and return global top-1 accuracy.

    Distributed correctness: the DistributedSampler is created with
    drop_last=True, so each rank sees an equally-sized shard and the tail
    samples (dataset size modulo world_size * num_replicas) are dropped from
    the sharded pass. After cross-rank reduction, those leftover samples are
    evaluated once more on *every* rank via an auxiliary loader; since each
    rank adds the identical tail statistics on top of the already-reduced
    sums, every rank ends with the same exact full-dataset average.
    """

    def run_validate(loader, base_progress=0):
        # Shared inner loop for the main (sharded) pass and the auxiliary
        # (leftover-tail) pass; base_progress offsets the printed batch index
        # so the second pass continues the numbering of the first.
        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(loader):
                i = base_progress + i
                if args.gpu is not None:
                    images = images.cuda(args.gpu, non_blocking=True)
                if torch.cuda.is_available():
                    target = target.cuda(args.gpu, non_blocking=True)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i + 1)

    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        # One extra display slot when the auxiliary tail pass will run
        # (bool cleanly adds as 0/1 here).
        len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    run_validate(val_loader)
    if args.distributed:
        # FIX: also reduce the loss meter — previously only top1/top5 were
        # reduced, so the printed summary mixed a rank-local loss with
        # global accuracies.
        losses.all_reduce()
        top1.all_reduce()
        top5.all_reduce()

    if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
        # drop_last=True truncated the tail of the dataset; re-run exactly the
        # dropped index range [num_sharded, len(dataset)) on every rank.
        aux_val_dataset = Subset(val_loader.dataset,
                                 range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
        aux_val_loader = torch.utils.data.DataLoader(
            aux_val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        run_validate(aux_val_loader, len(val_loader))

    progress.display_summary()

    return top1.avg
362384

@@ -392,6 +414,12 @@ def update(self, val, n=1):
392414
self.count += n
393415
self.avg = self.sum / self.count
394416

417+
def all_reduce(self):
    """Sum this meter's running totals (`sum`, `count`) across all ranks
    and recompute `avg` so every rank reports the same global average.

    FIX: build the reduction tensor on the current CUDA device when one is
    available. The original `torch.FloatTensor([...])` allocates on CPU,
    and `dist.all_reduce` with the NCCL backend — the standard backend for
    the multi-GPU path this script sets up — cannot reduce CPU tensors.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
    dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
    self.sum, self.count = total.tolist()
    self.avg = self.sum / self.count
422+
395423
def __str__(self):
    """Render the meter as 'Name current (average)' using `self.fmt`
    as the format spec for both the current value and the running average."""
    template = ''.join(
        ['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
    return template.format(**self.__dict__)

0 commit comments

Comments (0)