"""train.py -- semi-supervised co-training of two classifiers on CIFAR.

Two identically configured networks are trained side by side on labeled and
unlabeled batches.  Each network learns from its own artificial labels
("self-labeling") and from the artificial labels produced by the other
network ("co-labeling"); an optional EMA copy of the first network is used
for evaluation.
"""
import argparse
import logging
import math
import os
import random
import shutil
import time
from collections import OrderedDict
from datetime import datetime
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import sys
import itertools
from dataset.cifar import DATASET_GETTERS
from utils import AverageMeter, accuracy
# Module-level logger; its handlers/level are configured by logging.basicConfig() in main().
logger = logging.getLogger(__name__)
# Best test top-1 accuracy seen so far. Declared `global` in main() (restored
# from a checkpoint on resume) and in train() (updated after every evaluation).
best_acc = 0
def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar'):
    """Serialize ``state`` into the ``checkpoint`` directory.

    Args:
        state: object handed to ``torch.save``.
        is_best: when truthy, the freshly written file is additionally copied
            to ``model_best.pth.tar`` in the same directory.
        checkpoint: path of the directory that receives the file(s).
        filename: name of the checkpoint file inside ``checkpoint``.
    """
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)
    if not is_best:
        return
    best_copy = os.path.join(checkpoint, 'model_best.pth.tar')
    shutil.copyfile(target, best_copy)
def set_seed(args):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        args: namespace providing ``seed`` (the seed value) and ``n_gpu``;
            when ``n_gpu`` is positive the CUDA generators are seeded as well.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The learning-rate multiplier rises linearly from 0 to 1 over the first
    ``num_warmup_steps`` steps and afterwards follows
    ``cos(pi * num_cycles * progress)``, floored at 0, where ``progress`` is
    the fraction of the post-warmup steps already done.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        num_warmup_steps: number of warmup steps.
        num_training_steps: total number of training steps.
        num_cycles: scale applied to the cosine argument.
        last_epoch: forwarded to ``LambdaLR``.

    Returns:
        The configured ``LambdaLR`` scheduler.
    """
    def _multiplier(step):
        if step < num_warmup_steps:
            # Linear ramp during warmup.
            return float(step) / float(max(1, num_warmup_steps))
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(step - num_warmup_steps) / decay_span
        cosine = math.cos(math.pi * num_cycles * progress)
        return max(0., cosine)

    return LambdaLR(optimizer, _multiplier, last_epoch)
def interleave(x, size):
    """Reorder the leading axis of ``x`` by viewing it as ``(-1, size)`` and
    swapping those two axes.

    Args:
        x: tensor whose first dimension is divisible by ``size``.
        size: length of the inner group axis.

    Returns:
        Tensor with the same shape as ``x`` and rows reordered.
    """
    tail = list(x.shape)[1:]
    grouped = x.reshape([-1, size] + tail)
    swapped = grouped.transpose(0, 1)
    return swapped.reshape([-1] + tail)
def de_interleave(x, size):
    """Reorder the leading axis of ``x`` by viewing it as ``(size, -1)`` and
    swapping those two axes; this undoes ``interleave`` with the same ``size``.

    Args:
        x: tensor whose first dimension is divisible by ``size``.
        size: length of the outer group axis.

    Returns:
        Tensor with the same shape as ``x`` and rows reordered.
    """
    tail = list(x.shape)[1:]
    grouped = x.reshape([size, -1] + tail)
    swapped = grouped.transpose(0, 1)
    return swapped.reshape([-1] + tail)
def main():
    """Parse the command line, build data/models/optimizers and start training.

    Steps, in order: argument parsing, device / distributed setup, logging
    setup, optional seeding, dataset-dependent model hyper-parameters, data
    loaders, two independently initialised models with their own SGD
    optimizers and cosine schedulers, optional EMA model, optional resume
    from a checkpoint, optional apex AMP / DistributedDataParallel wrapping,
    and finally the call to ``train``.

    Raises:
        FileNotFoundError: if ``--resume`` is given but is not an existing file.
    """
    global best_acc

    parser = argparse.ArgumentParser(description='PyTorch CLS Training')
    parser.add_argument('--gpu-id', default='0', type=int,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'cifar100'],
                        help='dataset name')
    parser.add_argument('--num-labeled', type=int, default=4000,
                        help='number of labeled data')
    parser.add_argument('--arch', default='wideresnet', type=str,
                        choices=['wideresnet', 'resnext'],
                        help='dataset name')
    parser.add_argument('--total-steps', default=307200, type=int,
                        help='number of total steps to run')
    parser.add_argument('--eval-step', default=1024, type=int,
                        help='number of eval steps to run')
    parser.add_argument('--start-epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch-size', default=64, type=int,
                        help='train batchsize')
    parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                        help='initial learning rate')
    parser.add_argument('--warmup', default=0, type=float,
                        help='warmup epochs (unlabeled data based)')
    parser.add_argument('--wdecay', default=5e-4, type=float,
                        help='weight decay')
    # NOTE(review): store_true combined with default=True means the next two
    # flags can never be switched off from the command line.
    parser.add_argument('--nesterov', action='store_true', default=True,
                        help='use nesterov momentum')
    parser.add_argument('--use-ema', action='store_true', default=True,
                        help='use EMA model')
    parser.add_argument('--ema-decay', default=0.999, type=float,
                        help='EMA decay rate')
    parser.add_argument('--mu', default=8, type=int,
                        help='coefficient of unlabeled batch size')
    parser.add_argument('--lambda-u1', default=2, type=float,
                        help='coefficient of self-labeling loss')
    parser.add_argument('--lambda-u2', default=1, type=float,
                        help='coefficient of co-labeling loss')
    parser.add_argument('--Temp', default=1, type=float,
                        help='pseudo label temperature')
    parser.add_argument('--threshold', default=0.85, type=float,
                        help='weight threshold for exchange')
    parser.add_argument('--out', default='result',
                        help='directory to output the result')
    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--seed', default=None, type=int,
                        help="random seed")
    parser.add_argument("--amp", action="store_true",
                        help="use 16-bit (mixed) precision through NVIDIA apex AMP")
    parser.add_argument("--opt_level", type=str, default="O1",
                        help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--no-progress', action='store_true',
                        help="don't use progress bar")
    args = parser.parse_args()

    def create_model(args):
        """Instantiate the architecture selected by ``args.arch``."""
        if args.arch == 'wideresnet':
            import models.wideresnet as models
            model = models.build_wideresnet(depth=args.model_depth,
                                            widen_factor=args.model_width,
                                            dropout=0,
                                            num_classes=args.num_classes)
        elif args.arch == 'resnext':
            import models.resnext as models
            model = models.build_resnext(cardinality=args.model_cardinality,
                                         depth=args.model_depth,
                                         width=args.model_width,
                                         num_classes=args.num_classes)
        logger.info("Total params: {:.2f}M".format(
            sum(p.numel() for p in model.parameters())/1e6))
        return model

    # ---- device / distributed setup -------------------------------------
    if args.local_rank == -1:
        if torch.cuda.is_available():
            device = torch.device('cuda', args.gpu_id)
            args.n_gpu = torch.cuda.device_count()
        else:
            device = torch.device('cpu')
            args.n_gpu = 0
        args.world_size = 1
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.world_size = torch.distributed.get_world_size()
        args.n_gpu = 1
    args.device = device

    # Only the main process logs at INFO level.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        f"Process rank: {args.local_rank}, "
        f"device: {args.device}, "
        f"n_gpu: {args.n_gpu}, "
        f"distributed training: {bool(args.local_rank != -1)}, "
        f"16-bits training: {args.amp}",)
    logger.info(dict(args._get_kwargs()))

    if args.seed is not None:
        set_seed(args)

    if args.local_rank in [-1, 0]:
        os.makedirs(args.out, exist_ok=True)
        args.writer = SummaryWriter(args.out)

    # ---- dataset-dependent model hyper-parameters ------------------------
    if args.dataset == 'cifar10':
        args.num_classes = 10
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 2
        elif args.arch == 'resnext':
            args.model_cardinality = 4
            args.model_depth = 28
            args.model_width = 4
    elif args.dataset == 'cifar100':
        args.num_classes = 100
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 8
        elif args.arch == 'resnext':
            args.model_cardinality = 8
            args.model_depth = 29
            args.model_width = 64

    # ---- data ------------------------------------------------------------
    # Non-main ranks wait here until rank 0 has built the datasets; rank 0
    # releases them through the matching barrier below.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset](
        args, './data')
    if args.local_rank == 0:
        torch.distributed.barrier()

    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
    labeled_trainloader = DataLoader(
        labeled_dataset,
        sampler=train_sampler(labeled_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)
    unlabeled_trainloader = DataLoader(
        unlabeled_dataset,
        sampler=train_sampler(unlabeled_dataset),
        batch_size=args.batch_size*args.mu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)
    test_loader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers)

    # ---- models ----------------------------------------------------------
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    # create two independent models
    model = create_model(args)
    model_l = create_model(args)
    if args.local_rank == 0:
        torch.distributed.barrier()
    model.to(args.device)
    model_l.to(args.device)

    # ---- optimizers / schedulers ----------------------------------------
    # Parameters whose name contains 'bias' or 'bn' are excluded from weight decay.
    no_decay = ['bias', 'bn']

    def _param_groups(net):
        decayed, undecayed = [], []
        for name, param in net.named_parameters():
            if any(nd in name for nd in no_decay):
                undecayed.append(param)
            else:
                decayed.append(param)
        return [
            {'params': decayed, 'weight_decay': args.wdecay},
            {'params': undecayed, 'weight_decay': 0.0},
        ]

    grouped_parameters = _param_groups(model)
    grouped_parameters_l = _param_groups(model_l)
    optimizer_l = optim.SGD(grouped_parameters_l, lr=args.lr,
                            momentum=0.9, nesterov=args.nesterov)
    optimizer = optim.SGD(grouped_parameters, lr=args.lr,
                          momentum=0.9, nesterov=args.nesterov)

    args.epochs = math.ceil(args.total_steps / args.eval_step)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, args.warmup, args.total_steps)
    scheduler_l = get_cosine_schedule_with_warmup(
        optimizer_l, args.warmup, args.total_steps)

    ema_model = None
    if args.use_ema:
        from models.ema import ModelEMA
        ema_model = ModelEMA(args, model, args.ema_decay)

    # NOTE(review): this unconditionally overrides --start-epoch; the value
    # only becomes non-zero through the resume branch below.
    args.start_epoch = 0

    # ---- resume ----------------------------------------------------------
    if args.resume:
        logger.info("==> Resuming from checkpoint..")
        # Explicit raise instead of `assert`, which is stripped under -O.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError("Error: no checkpoint directory found!")
        args.out = os.path.dirname(args.resume)
        # map_location keeps every process on its own device instead of the
        # device the checkpoint was written from.
        checkpoint = torch.load(args.resume, map_location=args.device)
        best_acc = checkpoint['best_acc']
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        if args.use_ema:
            ema_model.ema.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # State of the second model is optional so that checkpoints which do
        # not contain it can still be loaded.
        if 'state_dict_l' in checkpoint:
            model_l.load_state_dict(checkpoint['state_dict_l'])
        else:
            logger.warning(
                "Checkpoint has no 'state_dict_l'; the second model starts "
                "from its fresh initialisation.")
        if 'optimizer_l' in checkpoint:
            optimizer_l.load_state_dict(checkpoint['optimizer_l'])
        if 'scheduler_l' in checkpoint:
            scheduler_l.load_state_dict(checkpoint['scheduler_l'])

    # ---- mixed precision / DDP ------------------------------------------
    if args.amp:
        from apex import amp
        # apex expects a single initialize() call; several models/optimizers
        # are passed as lists.
        (model, model_l), (optimizer, optimizer_l) = amp.initialize(
            [model, model_l], [optimizer, optimizer_l],
            opt_level=args.opt_level)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)
        model_l = torch.nn.parallel.DistributedDataParallel(
            model_l, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)

    logger.info("***** Running training *****")
    logger.info(f" Task = {args.dataset}@{args.num_labeled}")
    logger.info(f" Num Epochs = {args.epochs}")
    logger.info(f" Batch size per GPU = {args.batch_size}")
    logger.info(
        f" Total train batch size = {args.batch_size*args.world_size}")
    logger.info(f" Total optimization steps = {args.total_steps}")

    model.zero_grad()
    model_l.zero_grad()
    train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, model_l, optimizer, optimizer_l, ema_model, scheduler, scheduler_l)
def _next_batch(loader, batch_iter, sampler_epoch, distributed):
    """Fetch the next batch, restarting ``loader`` once it is exhausted.

    Returns ``(batch, batch_iter, sampler_epoch)``; the iterator and epoch
    counter are returned because both change when the loader is restarted.
    """
    try:
        return next(batch_iter), batch_iter, sampler_epoch
    except StopIteration:
        # Only exhaustion triggers a restart; real loading errors propagate.
        if distributed:
            sampler_epoch += 1
            loader.sampler.set_epoch(sampler_epoch)
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter, sampler_epoch


def _artificial_labels(logits_weak, temperature, threshold, log2_num_classes):
    """Derive artificial labels and their confidence from weak-view logits.

    Returns a dict with:
        'pos'         -- argmax class per sample (pseudo label),
        'neg'         -- argmin class per sample (complementary label),
        'neg_mask'    -- integer one-hot mask of the argmin class,
        'conf'        -- 1 - entropy / log2(num_classes) per sample,
        'mask'        -- 1.0 where 'conf' >= threshold, else 0.0,
        'conf_masked' -- 'conf' * 'mask' (weights used for label exchange).
    """
    probs = torch.softmax(logits_weak.detach() / temperature, dim=-1)
    _, pos = torch.max(probs, dim=-1)
    _, neg = torch.min(probs, dim=-1)

    # Clamp before the log: softmax can underflow to exactly 0, and
    # 0 * log2(0) evaluates to NaN, which would poison every loss term.
    safe_probs = probs.clamp_min(torch.finfo(probs.dtype).tiny)
    entropy = torch.sum(-(probs * torch.log2(safe_probs)), dim=-1)
    conf = 1.0 - entropy / log2_num_classes

    mask = conf.ge(threshold).float()

    # One-hot of the least likely class (the only complementary label used).
    neg_mask = torch.zeros_like(probs, dtype=torch.long)
    neg_mask.scatter_(1, neg.view(-1, 1), 1)

    return {'pos': pos, 'neg': neg, 'neg_mask': neg_mask,
            'conf': conf, 'mask': mask, 'conf_masked': conf * mask}


def _co_training_loss(logits_x, logits_u_s, targets_x, own, peer, args):
    """Combine supervised, self-labeling and co-labeling losses for one model.

    ``own`` / ``peer`` are the dicts returned by ``_artificial_labels`` for
    this model and for the other model.  Returns
    ``(total, supervised, self_positive, self_negative)``.
    """
    # supervised learning
    sup = F.cross_entropy(logits_x, targets_x.long(), reduction='mean')

    # log(1 - p) on the strong view, shared by both negative-learning terms
    complement = torch.clamp(1 - F.softmax(logits_u_s, dim=1), 1e-7, 1.0)
    log_complement = torch.log(complement)

    # self-labeling: own labels weighted by the unmasked confidence
    self_pl = (F.cross_entropy(logits_u_s, own['pos'], reduction='none')
               * own['conf']).mean()
    self_nl = torch.mean(
        (-torch.sum(log_complement * own['neg_mask'], dim=-1)) * own['conf'])

    # co-labeling: peer labels weighted by the thresholded confidence
    co_pl = (F.cross_entropy(logits_u_s, peer['pos'], reduction='none')
             * peer['conf_masked']).mean()
    co_nl = torch.mean(
        (-torch.sum(log_complement * peer['neg_mask'], dim=-1)) * peer['conf_masked'])

    total = (sup + args.lambda_u1*self_pl + args.lambda_u1*self_nl
             + args.lambda_u2*co_pl + args.lambda_u2*co_nl)
    return total, sup, self_pl, self_nl


def _backward(loss, optimizer, amp, nan_message):
    """Back-propagate ``loss`` exactly once.

    With apex (``amp`` is the module) the loss is scaled first.  Without it,
    a NaN loss is logged with ``nan_message`` and the backward pass skipped.
    """
    if amp is not None:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        return
    if np.isnan(loss.item()):
        logger.info(nan_message)
        return
    loss.backward()


def _squared_param_distance(params_a, params_b):
    """Sum of squared differences between two aligned parameter sequences."""
    total = 0.0
    with torch.no_grad():  # pure diagnostics, no autograd graph needed
        for a, b in zip(params_a, params_b):
            # Accumulate on-device in float64; a single host sync at the end.
            total = total + (a - b).pow(2).sum().double()
    return float(total)


def train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, model_l, optimizer, optimizer_l, ema_model, scheduler, scheduler_l):
    """Co-train ``model`` and ``model_l`` and evaluate after every epoch.

    Each epoch runs ``args.eval_step`` iterations.  Per iteration both models
    see the same labeled / weakly augmented / strongly augmented batch; each
    model is optimised with a supervised loss plus positive and negative
    learning on its own artificial labels (weight ``lambda_u1``) and on the
    other model's thresholded artificial labels (weight ``lambda_u2``).

    On the main process (``local_rank`` in ``[-1, 0]``) the EMA model (or
    ``model`` when EMA is off) is evaluated after each epoch, scalars are
    written to ``args.writer``, a checkpoint is saved to ``args.out`` and,
    every 10 epochs, a line is appended to ``log.txt``.

    Args:
        args: parsed options plus attributes set in ``main`` (``device``,
            ``world_size``, ``epochs``, ``num_classes``, ``writer`` ...).
        labeled_trainloader: yields ``(inputs, targets)``.
        unlabeled_trainloader: yields ``((weak, strong), _)``.
        test_loader: loader passed to ``test``.
        model, model_l: the two networks being co-trained.
        optimizer, optimizer_l: their optimizers.
        ema_model: EMA wrapper around ``model`` or ``None``.
        scheduler, scheduler_l: learning-rate schedulers, stepped per iteration.
    """
    global best_acc

    amp = None
    if args.amp:
        from apex import amp

    test_accs = []
    end = time.time()

    distributed = args.world_size > 1
    labeled_epoch = 0
    unlabeled_epoch = 0
    if distributed:
        labeled_trainloader.sampler.set_epoch(labeled_epoch)
        unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)

    show_progress = not args.no_progress
    is_main_process = args.local_rank in [-1, 0]
    log2_num_classes = math.log2(args.num_classes)  # entropy normaliser
    n_chunks = 2*args.mu+1  # 1 labeled + mu weak + mu strong sub-batches

    for epoch in range(args.start_epoch, args.epochs):
        labeled_iter = iter(labeled_trainloader)
        unlabeled_iter = iter(unlabeled_trainloader)
        model.train()
        model_l.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        losses_x = AverageMeter()
        losses_u = AverageMeter()
        mask_probs = AverageMeter()
        mask_NL = AverageMeter()
        mask_PL = AverageMeter()
        mask_EC = AverageMeter()
        dis_EMA = AverageMeter()
        dis_LEFT = AverageMeter()

        if show_progress:
            p_bar = tqdm(range(args.eval_step),
                         disable=args.local_rank not in [-1, 0])

        for batch_idx in range(args.eval_step):
            (inputs_x, targets_x), labeled_iter, labeled_epoch = _next_batch(
                labeled_trainloader, labeled_iter, labeled_epoch, distributed)
            ((inputs_u_w, inputs_u_s), _), unlabeled_iter, unlabeled_epoch = _next_batch(
                unlabeled_trainloader, unlabeled_iter, unlabeled_epoch, distributed)
            data_time.update(time.time() - end)

            batch_size = inputs_x.shape[0]
            inputs = interleave(
                torch.cat((inputs_x, inputs_u_w, inputs_u_s)), n_chunks).to(args.device)
            targets_x = targets_x.to(args.device)

            logits = model(inputs)
            logits_left = model_l(inputs)

            # label generation of model-1
            logits = de_interleave(logits, n_chunks)
            logits_x = logits[:batch_size]
            logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
            del logits
            labels_1 = _artificial_labels(
                logits_u_w, args.Temp, args.threshold, log2_num_classes)

            # label generation of model-2
            logits_left = de_interleave(logits_left, n_chunks)
            logits_x_left = logits_left[:batch_size]
            logits_u_w_left, logits_u_s_left = logits_left[batch_size:].chunk(2)
            del logits_left
            labels_2 = _artificial_labels(
                logits_u_w_left, args.Temp, args.threshold, log2_num_classes)

            # overlap of the artificial labels produced by the two models
            mask_PL_ratio = labels_2['pos'].eq(labels_1['pos']).float()
            mask_NL_ratio = labels_2['neg'].eq(labels_1['neg']).float()

            # losses: each model uses its own labels and the peer's labels
            loss, Lx, L_self_pl, L_self_nl = _co_training_loss(
                logits_x, logits_u_s, targets_x, labels_1, labels_2, args)
            loss1, _, _, _ = _co_training_loss(
                logits_x_left, logits_u_s_left, targets_x, labels_2, labels_1, args)

            # ---- update model-1 -----------------------------------------
            _backward(loss, optimizer, amp, 'Loss value is NaN!')
            losses.update(loss.item())
            losses_x.update(Lx.item())
            losses_u.update(L_self_pl.item()+L_self_nl.item())
            optimizer.step()
            scheduler.step()
            if args.use_ema:
                ema_model.update(model)

            # ---- update model-2 -----------------------------------------
            # loss1 only depends on model_l (all model-1 quantities entering
            # it are detached), so stepping model-1 first is safe.
            _backward(loss1, optimizer_l, amp, 'Loss1 value is NaN!')
            optimizer_l.step()
            scheduler_l.step()

            model.zero_grad()
            model_l.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()

            w_label = labels_1['conf_masked']
            mask = labels_1['mask']
            mask_left = labels_2['mask']
            mask_probs.update(w_label.mean().item())
            mask_PL.update(mask_PL_ratio.mean().item())
            mask_NL.update(mask_NL_ratio.mean().item())
            mask_EC.update((mask.mean().item()+mask_left.mean().item())/2.0)

            # Squared parameter distances (diagnostics only).
            diss_ema = 0.0
            if ema_model is not None:
                diss_ema = _squared_param_distance(
                    model.parameters(), ema_model.ema.parameters())
            diss_left = _squared_param_distance(
                model.parameters(), model_l.parameters())
            dis_EMA.update(diss_ema)
            dis_LEFT.update(diss_left)

            if show_progress:
                p_bar.set_description("Train Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. LR: {lr:.4f}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. Loss_x: {loss_x:.4f}. Loss_u: {loss_u:.4f}. Mask: {mask:.2f}. ".format(
                    epoch=epoch + 1,
                    epochs=args.epochs,
                    batch=batch_idx + 1,
                    iter=args.eval_step,
                    # get_last_lr() is the supported accessor; get_lr() is
                    # meant for internal use by the scheduler.
                    lr=scheduler.get_last_lr()[0],
                    data=data_time.avg,
                    bt=batch_time.avg,
                    loss=losses.avg,
                    loss_x=losses_x.avg,
                    loss_u=losses_u.avg,
                    mask=mask_probs.avg))
                p_bar.update()

        if show_progress:
            p_bar.close()

        if args.use_ema:
            test_model = ema_model.ema
        else:
            test_model = model

        if is_main_process:
            test_loss, test_acc = test(args, test_loader, test_model, epoch)

            args.writer.add_scalar('train/1.train_loss', losses.avg, epoch)
            args.writer.add_scalar('train/2.train_loss_x', losses_x.avg, epoch)
            args.writer.add_scalar('train/3.train_loss_u', losses_u.avg, epoch)
            args.writer.add_scalar('train/4.mask', mask_probs.avg, epoch)
            args.writer.add_scalar('train/5.PL', mask_PL.avg, epoch)
            args.writer.add_scalar('train/5.NL', mask_NL.avg, epoch)
            args.writer.add_scalar('train/5.EC', mask_EC.avg, epoch)
            args.writer.add_scalar('train/6.EMA', dis_EMA.avg, epoch)
            args.writer.add_scalar('train/6.Co-training', dis_LEFT.avg, epoch)
            args.writer.add_scalar('test/1.test_acc', test_acc, epoch)
            args.writer.add_scalar('test/2.test_loss', test_loss, epoch)

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)

            # Unwrap DistributedDataParallel before serialising.
            model_to_save = model.module if hasattr(model, "module") else model
            model_l_to_save = model_l.module if hasattr(model_l, "module") else model_l
            if args.use_ema:
                ema_to_save = ema_model.ema.module if hasattr(
                    ema_model.ema, "module") else ema_model.ema
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_to_save.state_dict(),
                'ema_state_dict': ema_to_save.state_dict() if args.use_ema else None,
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                # State of the second model, so a resumed run can continue
                # both networks instead of re-initialising one of them.
                'state_dict_l': model_l_to_save.state_dict(),
                'optimizer_l': optimizer_l.state_dict(),
                'scheduler_l': scheduler_l.state_dict(),
            }, is_best, args.out)

            test_accs.append(test_acc)
            logger.info('Best top-1 acc: {:.2f}'.format(best_acc))
            logger.info('Mean top-1 acc: {:.2f}\n'.format(
                np.mean(test_accs[-20:])))

            freq = 10
            if (epoch+1)%freq == 0:
                with open(os.path.join(args.out, 'log.txt'), 'a+') as ofile:
                    ofile.write(f'############################# PL Iteration ({freq} epoch): {(epoch+1)//freq} #############################\n')
                    ofile.write(f'Last Test Acc: {test_acc}, Best Test Acc: {best_acc}\n')

    if is_main_process:
        args.writer.close()
def test(args, test_loader, model, epoch):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Args:
        args: namespace providing ``device``, ``no_progress`` and ``local_rank``.
        test_loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate; it is switched to eval mode.
        epoch: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(average cross-entropy loss, average top-1 accuracy)``, both
        weighted by batch size.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    show_progress = not args.no_progress
    if show_progress:
        test_loader = tqdm(test_loader,
                           disable=args.local_rank not in [-1, 0])

    # Switching to eval mode is loop-invariant, so do it once up front
    # rather than on every batch.
    model.eval()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            data_time.update(time.time() - end)

            inputs = inputs.to(args.device)
            targets = targets.to(args.device)
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)

            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
            n_samples = inputs.shape[0]
            losses.update(loss.item(), n_samples)
            top1.update(prec1.item(), n_samples)
            top5.update(prec5.item(), n_samples)
            batch_time.update(time.time() - end)
            end = time.time()
            if show_progress:
                test_loader.set_description("Test Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. top1: {top1:.2f}. top5: {top5:.2f}. ".format(
                    batch=batch_idx + 1,
                    iter=len(test_loader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    loss=losses.avg,
                    top1=top1.avg,
                    top5=top5.avg,
                ))
        if show_progress:
            test_loader.close()

    logger.info("top-1 acc: {:.2f}".format(top1.avg))
    logger.info("top-5 acc: {:.2f}".format(top5.avg))
    return losses.avg, top1.avg
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()