-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
56 lines (49 loc) · 2.49 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
# Restrict the visible GPUs, defaulting to device "2". Use setdefault so a
# CUDA_VISIBLE_DEVICES value already exported by the caller is respected
# rather than silently overwritten. Kept ahead of `import torch` so the
# setting is in place before CUDA is initialized.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')
import numpy as np
import torch
from PIL import Image
from classification import (Classification, cvtColor, letterbox_image,
preprocess_input)
from utils.utils import letterbox_image
from utils.utils_metrics import evaluteTop1_5
# Annotation file listing the test image paths together with their labels.
test_annotation_path = 'cls_test.txt'

# Folder into which the evaluation metrics are saved.
metrics_out_path = "metrics_out"
class Eval_Classification(Classification):
    """Evaluation-time variant of ``Classification``.

    Overrides ``detect_image`` so it returns the softmax probability vector
    for an image. An instance of this class is passed to ``evaluteTop1_5``
    by the script entry point (presumably that helper calls
    ``detect_image`` per test image — confirm in utils.utils_metrics).
    """

    def detect_image(self, image):
        """Return softmax class probabilities for a single image.

        Args:
            image: the input image; it is passed through ``cvtColor`` so that
                non-RGB inputs (e.g. grayscale) are converted to RGB first.
                Presumably a PIL image — TODO confirm against the caller.

        Returns:
            numpy.ndarray: ``softmax`` over the last axis of the model output
            for the first element of the (size-1) batch, moved to the CPU.
        """
        # Only RGB prediction is supported; convert everything else to RGB so
        # grayscale images do not raise during inference.
        rgb_image = cvtColor(image)

        # NOTE(review): no resize/letterbox step is applied here, so the image
        # reaches the model at its native size. letterbox_image is imported at
        # module level but never used in this file — confirm whether the
        # resize step was removed on purpose.

        # Normalize (via preprocess_input), add a leading batch axis, then move
        # the last axis to position 1 (presumably HWC -> NCHW for PyTorch).
        pixels = preprocess_input(np.array(rgb_image, np.float32))
        batch = np.expand_dims(pixels, 0)
        batch = np.transpose(batch, (0, 3, 1, 2))

        with torch.no_grad():
            inputs = torch.from_numpy(batch).type(torch.FloatTensor)
            if self.cuda:
                inputs = inputs.cuda()
            # Forward pass; keep the first (only) batch entry and turn its
            # scores into probabilities along the last dimension.
            logits = self.model(inputs)[0]
            probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
        return probabilities
if __name__ == "__main__":
    # Script entry point: evaluate the classifier on the test annotation list
    # and print top-1 accuracy and mean recall.

    # Create the metrics folder if needed; exist_ok replaces the racy
    # exists-then-create check.
    os.makedirs(metrics_out_path, exist_ok=True)

    classification = Eval_Classification()

    # Read the annotation lines from the configured path (previously a
    # hard-coded duplicate "./cls_test.txt", so editing test_annotation_path
    # had no effect).
    with open(test_annotation_path, "r") as f:
        lines = f.readlines()

    top1, top5, recall, precision = evaluteTop1_5(classification, lines, metrics_out_path)
    print("top-1 accuracy = %.2f%%" % (top1 * 100))
    # top-5 accuracy and precision are computed above but not reported;
    # uncomment to print them.
    # print("top-5 accuracy = %.2f%%" % (top5 * 100))
    # recall/precision are averaged with np.mean (presumably per-class values).
    print("mean Recall = %.2f%%" % (np.mean(recall) * 100))
    # print("mean Precision = %.2f%%" % (np.mean(precision) * 100))