-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresult_check.py
101 lines (82 loc) · 3.46 KB
/
result_check.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import time
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
from model import FundusModel
from dataset import FundusDataset
def confusionMatrix(model_path, model_name, mode, output_class, plotclass, image_size,
                    dataset=FundusDataset, device=None, batch_size=256, num_workers=None):
    """Evaluate a saved model on one dataset split and return its confusion matrix.

    Args:
        model_path: Path to a ``state_dict`` file readable by ``torch.load``.
        model_name: Backbone name forwarded to ``FundusModel``.
        mode: Dataset split name forwarded to ``dataset`` (e.g. ``'val'``).
        output_class: Number of model outputs, forwarded to ``FundusModel``.
        plotclass: Number of class labels (``0 .. plotclass - 1``) in the matrix.
        image_size: Image size forwarded to ``dataset``.
        dataset: Dataset class, called as ``dataset(mode=mode, image_size=image_size)``.
            Each batch's element 0 is the model input and element 1 the labels.
        device: Torch device to run on; defaults to CUDA when available, else CPU.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count; defaults to ``2 * os.cpu_count()``.

    Returns:
        A ``(plotclass, plotclass)`` ndarray from ``sklearn.metrics.confusion_matrix``
        with ``normalize='true'`` (each row is normalized over the true label).

    Raises:
        ValueError: If the dataset split is empty.
    """
    device = torch.device(device if device is not None
                          else ('cuda' if torch.cuda.is_available() else 'cpu'))
    if num_workers is None:
        # os.cpu_count() may return None on some platforms.
        num_workers = 2 * (os.cpu_count() or 1)

    # NOTE(review): assumes FundusModel accepts these keyword arguments — confirm
    # against model.py. The original referenced undefined names here and could not run.
    model = FundusModel(model_name=model_name, output_class=output_class)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    print('model loaded')

    eval_set = dataset(mode=mode, image_size=image_size)
    if len(eval_set) == 0:
        raise ValueError(f'{mode} dataset is empty; cannot compute a confusion matrix')
    # Order does not affect the metrics, so there is no need to shuffle during evaluation.
    loader = DataLoader(eval_set, batch_size=batch_size, shuffle=False,
                        pin_memory=(device.type == 'cuda'), num_workers=num_workers)

    # Collect per-batch results and concatenate once, avoiding quadratic re-copying.
    target_batches, predict_batches = [], []
    with torch.no_grad():
        for i, batch in enumerate(loader):
            images, labels = batch[0].to(device), batch[1]
            outputs = model(images)
            predict_batches.append(outputs.argmax(dim=1).cpu())
            target_batches.append(labels.long().cpu())
            print(f'-- {i} batch--')

    targets = torch.cat(target_batches).numpy()
    predicts = torch.cat(predict_batches).numpy()

    correct_count = int((predicts == targets).sum())
    accuracy = 100 * correct_count / len(targets)
    print(f'\n Accuracy on {mode} set: {accuracy:.2f} % \n')

    return confusion_matrix(targets, predicts, normalize='true',
                            labels=list(range(plotclass)))
# Human-readable label for each class index, shown on both plot axes.
CLASS_LABELS = ('01~20', '21~40', '41~60', '61~80', '81~100')


def class_tick_labels(num_classes, labels=CLASS_LABELS):
    """Return one axis tick label per class index ``0 .. num_classes - 1``.

    Indices beyond ``labels`` fall back to their numeric index, so no tick is
    ever left with a missing or ``None`` label.
    """
    return [labels[i] if i < len(labels) else str(i) for i in range(num_classes)]


def plot_confusion_matrix(c_matrix, mode, num_classes, output_path):
    """Render ``c_matrix`` as an annotated heat map and save it to ``output_path``.

    Args:
        c_matrix: Square 2-D array of per-cell values (e.g. row-normalized rates).
        mode: Dataset split name, used in the title.
        num_classes: Number of classes; one tick is drawn per class.
        output_path: File path the figure is written to.
    """
    import matplotlib.pyplot as plt

    figure = plt.figure()
    axes = figure.add_subplot(111)
    # Draw the matrix once and reuse the returned image for the colorbar.
    image = axes.matshow(c_matrix, interpolation='nearest')
    figure.colorbar(image)
    axes.set_title(f'Confusion Matrix: {mode} set')
    axes.set(xlabel='Predicted', ylabel='Truth')

    # One tick per class. The original np.arange(0, plotclass - 1) dropped the last class.
    ticks = np.arange(num_classes)
    tick_labels = class_tick_labels(num_classes)
    axes.set_xticks(ticks)
    axes.set_yticks(ticks)
    axes.set_xticklabels(tick_labels)
    axes.set_yticklabels(tick_labels)

    for row_i, row in enumerate(c_matrix):
        for col_i, value in enumerate(row):
            axes.text(col_i, row_i, f'{value:.2f}', color='white',
                      ha='center', va='center')

    figure.savefig(output_path)
    plt.close(figure)


if __name__ == '__main__':
    start_time = time.time()
    image_size = 600
    mode = 'val'
    output_class = 5
    plotclass = 5
    trial = 38
    model_name = 'se_resnext101_32x4d'
    # model_name = 'resnet18'
    path = os.path.join('model_save', f'{trial}_{model_name}_best.pth')
    # path = os.path.join('model_save', '1_resnet18.pth')
    c_matrix = confusionMatrix(
        model_path=path, model_name=model_name,
        mode=mode, output_class=output_class,
        image_size=image_size, plotclass=plotclass)
    # NOTE(review): the "_800" suffix does not match image_size (600); kept as-is — confirm intent.
    plot_confusion_matrix(c_matrix, mode, plotclass, f'confusion_matrix_{mode}_800.png')
    print('--- %.1f sec ---' % (time.time() - start_time))