-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathEvaluation.py
43 lines (31 loc) · 1.28 KB
/
Evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import matplotlib.pyplot as plt
from utils import Adder, get_miou, get_biou
def evaluation(args, model, test_loader, visualize=True):
    """Evaluate a segmentation model on ``test_loader`` and print mIoU / bIoU.

    Loads the checkpoint at ``args.load`` (a mapping whose ``'model'`` entry is
    passed to ``model.load_state_dict``), moves the model to ``args.device``,
    runs it over every batch with gradients disabled, feeds the scores from
    ``get_miou`` / ``get_biou`` into two ``Adder`` accumulators, and prints
    their averages.

    Args:
        args: Object providing ``load`` (checkpoint path) and ``device``
            (device the model and batches are moved to).
        model: Module supporting ``load_state_dict``, ``to`` and ``eval``;
            its parameters must match the checkpoint's ``'model'`` entry.
        test_loader: Iterable yielding ``(name_id, img, inputs, masks)``.
        visualize: If True (default, same as the original behaviour), print
            each sample's ``name_id`` and show the original image, ground
            truth and prediction side by side. Pass False to run headless.

    Returns:
        None; results are printed.
    """
    # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded
    # onto args.device (e.g. a CPU-only machine) instead of failing here.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints.
    state = torch.load(args.load, map_location=args.device)
    model.load_state_dict(state['model'])
    model = model.to(args.device)

    miou_meter = Adder()
    biou_meter = Adder()
    model.eval()
    with torch.no_grad():
        for name_id, img, inputs, masks in test_loader:
            inputs, masks = inputs.to(args.device), masks.to(args.device)
            outputs = model(inputs)
            miou_meter(get_miou(outputs, masks))
            biou_meter(get_biou(outputs, masks))
            if visualize:
                _show_prediction(name_id, img, masks, outputs)
    print('mIOU:%.3f' % miou_meter.average())
    print('bIOU:%.3f' % biou_meter.average())
    print('Done!')
    return


def _show_prediction(name_id, img, masks, outputs):
    """Print ``name_id`` and plot image, ground-truth mask and predicted mask.

    The figure is closed after ``plt.show()`` so figures do not accumulate
    over the test set.
    """
    # assumes batch size 1 and logits shaped (1, C, H, W), so squeeze() gives
    # (C, H, W) and argmax over dim 0 picks the class per pixel — TODO confirm
    # against the model / loader.
    pred = torch.argmax(outputs.cpu().squeeze(), dim=0)
    print(name_id)
    fig = plt.figure(figsize=(12, 8))
    panels = (
        (img.squeeze(), None, 'Original'),
        (masks.cpu().squeeze(), 'gray', 'GroundTruth'),
        (pred, 'gray', 'Output'),
    )
    for position, (image, cmap, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image, cmap=cmap)
        plt.axis('off')
        plt.title(title)
    plt.show()
    # Release the figure so memory does not grow with the number of samples.
    plt.close(fig)