-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
102 lines (79 loc) · 2.89 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch.optim as optim
from organnet.dataloader import MICCAI
from organnet.model import OrganNet
from organnet.loss import FocalLoss, DiceLoss
import torch
import sys, os
# ---------------------------------------------------------------------------
# Model set-up
# ---------------------------------------------------------------------------
# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoint to evaluate: the first command-line argument if given, otherwise
# the last entry of the 'models' directory in sorted name order.
# os.listdir() returns entries in arbitrary order, so the listing is sorted to
# make the fallback deterministic.
if len(sys.argv) > 1:
    LOAD_PATH = sys.argv[1]
else:
    checkpoints = sorted(os.listdir('models'))
    if not checkpoints:
        raise FileNotFoundError(
            "no checkpoint found in 'models'; pass a checkpoint path as the "
            "first command-line argument"
        )
    LOAD_PATH = os.path.join('models', checkpoints[-1])

# Focal-loss parameters handed to FocalLoss below.  ALPHA holds ten values
# laid out as (1, 10, 1, 1, 1) -- presumably one weight per entry of
# `organs`; confirm against organnet.loss.FocalLoss.
ALPHA = torch.tensor(
    [0.5, 1.0, 4.0, 1.0, 4.0, 4.0, 1.0, 1.0, 3.0, 3.0]
).reshape(1, 10, 1, 1, 1)
GAMMA = 2

# Display names used when printing the results and labelling the box plot;
# the position in this list is the class index.
organs = [
    'Background',
    'Brain Stem',
    'Opt. Chiasm',
    'Mandible',
    'Opt. Ner. L',
    'Opt. Ner. R',
    'Parotid L',
    'Parotid R',
    'Subman. L',
    'Subman. R',
]

net = OrganNet().to(DEVICE)
# load_checkpoint() is called with an optimizer and a learning rate, so one is
# created here; this script never takes an optimisation step with it.
optimizer = optim.Adam(net.parameters(), lr=0.001)
net.load_checkpoint(LOAD_PATH, optimizer, 0.001)
# Switch to evaluation mode so that layers such as dropout and batch
# normalisation use their inference behaviour; torch.no_grad() alone does not
# do this.  (OrganNet is used as a torch.nn.Module here: .to(), .parameters().)
net.eval()

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
dset = 'test_offsite'
# Ask MICCAI to load the data set when a '<dset>.pickle' entry already exists
# in the 'data' directory.
load_data_set = dset + '.pickle' in os.listdir('data')
test_dataloader = DataLoader(MICCAI(dset, load=load_data_set), batch_size=1, shuffle=True)

# ---------------------------------------------------------------------------
# Loss: focal loss + dice loss
# ---------------------------------------------------------------------------
criterion_focal = FocalLoss(GAMMA, ALPHA)
criterion_dice = DiceLoss()

# NOTE(review): these two lists are not referenced anywhere in the visible
# part of this script; kept so the module-level names stay available.
losses = []
val_losses = []
def dice_score(inputs, targets):
    """Return the per-channel Dice score, in percent, for a single volume.

    The prediction is hardened first: every voxel is assigned to the channel
    holding the largest value in ``inputs`` (argmax over the channel axis).
    The resulting 0/1 masks are then compared with ``targets`` channel by
    channel.

    Args:
        inputs: Tensor of shape ``(1, c, h, w, d)`` with one score per channel
            and voxel.
        targets: Tensor that can be reshaped to ``(c, h, w, d)``.  Presumably
            a one-hot ground-truth mask -- confirm against the data loader.

    Returns:
        1-D tensor of length ``c`` holding
        ``100 * 2 * |P * T| / (|P| + |T| + 1)`` for every channel, located on
        the device of ``targets``.

    Raises:
        ValueError: If ``inputs`` is not 5-dimensional or its batch size is
            not 1.
    """
    if inputs.dim() != 5 or inputs.shape[0] != 1:
        raise ValueError(
            "dice_score expects inputs of shape (1, c, h, w, d), got "
            f"{tuple(inputs.shape)}"
        )
    _, c, h, w, d = inputs.shape
    smooth = 1.0

    inputs = inputs.reshape((c, h, w, d))
    targets = targets.reshape((c, h, w, d))

    # Index of the winning channel for every voxel, shape (h, w, d).
    winner = torch.argmax(inputs, 0)

    # Build all c binary masks in one broadcast comparison, on the device the
    # data already lives on, instead of filling a CPU buffer channel by
    # channel and moving it afterwards.
    class_ids = torch.arange(c, device=winner.device).reshape(c, 1, 1, 1)
    hard = (winner.unsqueeze(0) == class_ids).to(
        dtype=torch.get_default_dtype(), device=targets.device
    )

    intersection = (hard * targets).sum([1, 2, 3])
    # NOTE: the smoothing term is added to the denominator only, so a channel
    # that is empty in both prediction and target scores 0, not 100.
    dice = (2. * intersection) / (hard.sum([1, 2, 3]) + targets.sum([1, 2, 3]) + smooth)
    return dice * 100
# Dice percentages collected per class; keys are the class indices as strings.
DSC = {str(class_idx): [] for class_idx in range(10)}

# Run the network over the whole test set without tracking gradients.
with torch.no_grad():
    test_loss = 0
    for batch in test_dataloader:
        inputs = batch[0].to(DEVICE)
        labels = batch[1].to(DEVICE)
        outputs = net(inputs)

        # Total loss for this sample: Dice loss plus focal loss.
        loss_dice = criterion_dice(outputs, labels)
        loss_focal = criterion_focal(outputs, labels)
        test_loss += (loss_dice + loss_focal).item()

        # Record the Dice percentage of every class for this sample.
        per_class = dice_score(outputs, labels)
        for class_idx, score in enumerate(per_class):
            DSC[str(class_idx)].append(float(score.item()))

print(f'TEST LOSS: {test_loss/len(test_dataloader)}')

# Mean Dice percentage per class, under the same keys as DSC.
DSC_avg = {key: sum(scores) / len(scores) for key, scores in DSC.items()}

for position, key in enumerate(DSC_avg):
    print("Organ:", key, organs[position], 'DSC:', round(DSC_avg[key], 1))

overall_mean = sum(DSC_avg.values()) / len(DSC_avg)
print('DSC AVERAGE = ', round(overall_mean, 1))

# Box plot of the per-sample Dice percentages, one box per class.
fig, ax = plt.subplots(1)
ax.boxplot(DSC.values())
ax.set_xticklabels(organs, rotation=25, fontsize=10)
ax.set_xlabel('Organs')
ax.set_ylabel('DSC')
ax.set_title('Boxplot DSC per Organ', fontsize=14, fontweight='bold')
plt.show()