-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy patheval.py
101 lines (79 loc) · 3.45 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import math
import numpy as np
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes (all reset to 0.0 by ``reset``):
        val: the most recent value passed to ``update``.
        sum: running sum of ``val * n`` over all updates.
        count: running sum of the weights ``n``.
        avg: ``sum / count``; stays at its previous value (0.0 initially)
            while ``count`` is zero.
    """

    def __init__(self):
        # Single source of truth for the initial state.
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``.

        Args:
            val: the value to record (e.g. a per-batch mean).
            n: weight of ``val`` in the running average (e.g. batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: a zero-weight update on a fresh meter would
        # otherwise raise ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor indexed as ``output[sample, class]``; the top-k
            classes are taken along dim 1.
        target: ground-truth class indices, one per sample. A 0-d tensor
            (a single sample, e.g. after ``squeeze`` on a batch of one) is
            accepted and treated as a batch of one.
        topk: the k values to evaluate. ``max(topk)`` must not exceed the
            number of classes (``torch.topk`` requirement).

    Returns:
        A list with one 1-element float tensor per entry of ``topk`` (same
        order), each holding ``100 * (#samples whose target is among the
        top-k predictions) / batch_size``.
    """
    maxk = max(topk)
    # Flatten so a 0-d target still has a batch dimension.
    target = target.reshape(-1)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    # (batch, maxk) -> (maxk, batch): row i holds every sample's i-th best class.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # reshape, not view: `correct` follows the non-contiguous layout of the
        # transposed `pred`, and view(-1) raises on it for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
if __name__ == "__main__":
    # Restrict the visible GPUs before anything initialises CUDA (including the
    # project import below); CUDA reads this variable only at initialisation.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

    from mixnet import MixNet

    data_root = "/home/liuhuijun/Datasets/ImageNet"
    val_dir = os.path.join(data_root, 'val')
    arch = "l"
    # Per-arch (crop size, crop fraction): the shorter image side is resized to
    # int(crop_size / crop_fraction) (256 for 224 / 0.875), then center-cropped.
    img_preparam = {"s": (224, 0.875), "m": (224, 0.875), "l": (224, 0.875)}
    crop_size, crop_fraction = img_preparam[arch]
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    valid_dataset = datasets.ImageFolder(val_dir, transforms.Compose([
        transforms.Resize(int(crop_size / crop_fraction), Image.BICUBIC),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ]))
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=128, shuffle=False,
                                               num_workers=16, pin_memory=False)

    model = MixNet(arch=arch, num_classes=1000).cuda()
    used_gpus = list(range(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model, device_ids=used_gpus)

    # NOTE(review): the checkpoint name hard-codes "78.6"; presumably only the
    # "l" checkpoint uses that tag — confirm before switching `arch`.
    ckpt_path = "/home/liuhuijun/TrainLog/release/imagenet/mixnet_{}_top1v_78.6.pkl".format(arch)
    # Load onto CPU first; load_state_dict copies tensors onto the model's device.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    pre_weight = checkpoint['model_state']
    model_dict = model.state_dict()
    # Checkpoint keys are stored without DataParallel's "module." prefix; keys
    # with no counterpart in the model are dropped.
    pretrained_dict = {"module." + k: v for k, v in pre_weight.items() if "module." + k in model_dict}
    num_missing = len(model_dict) - len(pretrained_dict)
    if num_missing:
        # Surface mismatches instead of silently evaluating untrained tensors.
        print("> Warning: {} of {} model tensors were not found in the checkpoint "
              "and keep their initial values".format(num_missing, len(model_dict)))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.eval()

    top1 = AverageMeter()
    top5 = AverageMeter()
    with torch.no_grad():
        # tqdm takes its total from len(valid_loader) and advances once per batch.
        pbar = tqdm(valid_loader, desc="> Eval")
        for images, labels in pbar:
            images = images.cuda()
            # Labels are already 1-D; squeezing would make a batch of one 0-d.
            labels = labels.cuda()
            net_out = model(images)
            prec1, prec5 = accuracy(net_out, labels, topk=(1, 5))
            batch_size = images.size(0)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)
            pbar.set_postfix(Top1=top1.avg, Top5=top5.avg)
        pbar.close()
    print("> Top1: {:.3f}%, Top5: {:.3f}%".format(top1.avg, top5.avg))