|
19 | 19 | import torchvision.transforms as transforms
|
20 | 20 | import torchvision.datasets as datasets
|
21 | 21 | import torchvision.models as models
|
| 22 | +from torch.utils.data import Subset |
22 | 23 |
|
23 | 24 | model_names = sorted(name for name in models.__dict__
|
24 | 25 | if name.islower() and not name.startswith("__")
|
@@ -219,24 +220,29 @@ def main_worker(gpu, ngpus_per_node, args):
|
219 | 220 | normalize,
|
220 | 221 | ]))
|
221 | 222 |
|
| 223 | + val_dataset = datasets.ImageFolder( |
| 224 | + valdir, |
| 225 | + transforms.Compose([ |
| 226 | + transforms.Resize(256), |
| 227 | + transforms.CenterCrop(224), |
| 228 | + transforms.ToTensor(), |
| 229 | + normalize, |
| 230 | + ])) |
| 231 | + |
222 | 232 | if args.distributed:
|
223 | 233 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
| 234 | + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True) |
224 | 235 | else:
|
225 | 236 | train_sampler = None
|
| 237 | + val_sampler = None |
226 | 238 |
|
227 | 239 | train_loader = torch.utils.data.DataLoader(
|
228 | 240 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
|
229 | 241 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
|
230 | 242 |
|
231 | 243 | val_loader = torch.utils.data.DataLoader(
|
232 |
| - datasets.ImageFolder(valdir, transforms.Compose([ |
233 |
| - transforms.Resize(256), |
234 |
| - transforms.CenterCrop(224), |
235 |
| - transforms.ToTensor(), |
236 |
| - normalize, |
237 |
| - ])), |
238 |
| - batch_size=args.batch_size, shuffle=False, |
239 |
| - num_workers=args.workers, pin_memory=True) |
| 244 | + val_dataset, batch_size=args.batch_size, shuffle=False, |
| 245 | + num_workers=args.workers, pin_memory=True, sampler=val_sampler) |
240 | 246 |
|
241 | 247 | if args.evaluate:
|
242 | 248 | validate(val_loader, model, criterion, args)
|
@@ -315,48 +321,64 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
|
315 | 321 | end = time.time()
|
316 | 322 |
|
317 | 323 | if i % args.print_freq == 0:
|
318 |
| - progress.display(i) |
| 324 | + progress.display(i + 1) |
319 | 325 |
|
320 | 326 |
|
def validate(val_loader, model, criterion, args):
    """Evaluate `model` on the validation set and return the top-1 accuracy.

    In distributed mode the per-rank accuracy sums/counts are all-reduced,
    and any samples the distributed sampler dropped (its drop_last tail) are
    evaluated separately so the reported numbers cover the full dataset.
    """

    # One evaluation pass over `loader`, accumulating into the meters of the
    # enclosing scope. `base_progress` offsets the printed batch index when
    # this is called a second time on the leftover-sample loader below.
    def run_validate(loader, base_progress=0):
        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(loader):
                i = base_progress + i
                if args.gpu is not None:
                    images = images.cuda(args.gpu, non_blocking=True)
                if torch.cuda.is_available():
                    target = target.cuda(args.gpu, non_blocking=True)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i + 1)

    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        # Add one to the displayed batch total when a leftover-sample pass
        # will run (i.e. the distributed sampler covers fewer samples than
        # the dataset holds); bool arithmetic yields +0 or +1 here.
        len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    run_validate(val_loader)
    if args.distributed:
        # Aggregate accuracy sums/counts across ranks BEFORE evaluating the
        # leftover samples: every rank then adds the identical tail exactly
        # once on top of the already-global totals, so the tail is not
        # multiplied by the number of ranks. (losses is deliberately left
        # per-rank; only the accuracies feed the summary/return value.)
        top1.all_reduce()
        top5.all_reduce()

    if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
        # The val sampler is created with drop_last=True (see main_worker),
        # which discards the last len(dataset) - len(sampler) * world_size
        # samples; evaluate that tail here with a plain (non-distributed)
        # loader so accuracy covers the whole validation set.
        aux_val_dataset = Subset(val_loader.dataset,
                                 range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
        aux_val_loader = torch.utils.data.DataLoader(
            aux_val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        run_validate(aux_val_loader, len(val_loader))

    progress.display_summary()

    return top1.avg
|
362 | 384 |
|
@@ -392,6 +414,12 @@ def update(self, val, n=1):
|
392 | 414 | self.count += n
|
393 | 415 | self.avg = self.sum / self.count
|
394 | 416 |
|
def all_reduce(self):
    """Sum this meter's `sum` and `count` across all distributed ranks.

    Performs a SUM all-reduce on [sum, count] and recomputes `avg`, so every
    rank ends up holding the same global average.
    """
    # The NCCL backend (the standard backend for multi-GPU training) can
    # only reduce CUDA tensors; a CPU FloatTensor here fails or hangs under
    # NCCL, so place the working tensor on the GPU when one is available.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
    dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
    self.sum, self.count = total.tolist()
    self.avg = self.sum / self.count
def __str__(self):
    """Render as '<name> <current value> (<running average>)' using the
    format spec this meter was constructed with (e.g. ':6.2f')."""
    spec = self.fmt
    template = '{name} {val%s} ({avg%s})' % (spec, spec)
    return template.format(**self.__dict__)
|
|
0 commit comments