-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_4x.py
74 lines (65 loc) · 4.21 KB
/
evaluate_4x.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import os
import time
import numpy as np
import nibabel as nib
from dataset import loader_test, ImgTest, norm_01
from collections import OrderedDict
from network import InterpolationNetwork, DiscriminatorForVGG
from utils.util import print_current_losses, evaluate_2D, evaluate_3D
def _to_numpy_slice(tensor, axis, clip=False):
    """Return *tensor* as a NumPy array with a single size-1 dimension at *axis*.

    ``squeeze()`` drops every size-1 dimension, then ``unsqueeze(axis)``
    re-inserts exactly one at *axis* so slices can later be stacked with
    ``np.concatenate(..., axis=axis)``. When *clip* is true the values are
    clipped to [0, 1]; this is applied to network outputs only.
    """
    array = tensor.squeeze().unsqueeze(axis).cpu().numpy()
    if clip:
        array = np.clip(array, 0, 1)
    return array


def _output_path(out_dir, subject, tag):
    """Return ``<out_dir>/<stem>_<tag><ext>`` for the input file name *subject*.

    ``.nii.gz`` and ``.nii`` extensions are kept. Any other name gets
    ``_<tag>.nii.gz`` appended, so files written with different tags never
    share a path and cannot overwrite each other.
    """
    for ext in ('.nii.gz', '.nii'):
        if subject.endswith(ext):
            return os.path.join(out_dir, '{}_{}{}'.format(subject[:-len(ext)], tag, ext))
    return os.path.join(out_dir, '{}_{}.nii.gz'.format(subject, tag))


def _upsample_subject(model, loader, axis, device='cuda'):
    """Upsample one subject 4x along *axis* by recursive slice interpolation.

    For every ``(slice_0, slice_1)`` pair yielded by *loader*, the network is
    called three times::

        mid           = model(slice_0, slice_1)
        quarter       = model(slice_0, mid)
        three_quarter = model(mid, slice_1)

    ``slice_0, quarter, mid, three_quarter`` are emitted in that order. (That
    these sit at 0, 1/4, 1/2 and 3/4 of the gap assumes the network predicts
    the midpoint of its two inputs - TODO confirm.) Network outputs are
    clipped to [0, 1]; the loader's own slices are passed through unclipped.
    After the last pair, its ``slice_1`` is repeated four times.

    A baseline volume is built alongside it. Each pair contributes ``slice_0``
    four times and the final ``slice_1`` is repeated four times, i.e. a
    nearest-lower-slice replication of the input.

    Args:
        model: callable taking two batched slices (already on *device*) and
            returning an interpolated slice.
        loader: iterable of ``(slice_0, slice_1)`` tensor pairs.
        axis: axis along which slices are stacked.
        device: device the input slices are moved to before calling *model*.

    Returns:
        ``(upsampled, baseline)`` NumPy arrays, each concatenated along *axis*.

    Raises:
        ValueError: if *loader* yields no pairs, so there is no final slice to
            duplicate.
    """
    upsampled, baseline = [], []
    last_slice = None
    for slice_0, slice_1 in loader:
        with torch.no_grad():
            # Move each input slice to the device once and reuse it for all calls.
            img_0 = slice_0.unsqueeze(0).to(device)
            img_1 = slice_1.unsqueeze(0).to(device)
            mid = model(img_0, img_1)
            quarter = model(img_0, mid)
            three_quarter = model(mid, img_1)
        first = _to_numpy_slice(slice_0, axis)
        upsampled.extend([
            first,
            _to_numpy_slice(quarter, axis, clip=True),
            _to_numpy_slice(mid, axis, clip=True),
            _to_numpy_slice(three_quarter, axis, clip=True),
        ])
        # Baseline: replicate the lower slice of the pair four times.
        baseline.extend([first] * 4)
        last_slice = slice_1

    if last_slice is None:
        raise ValueError('loader yielded no slice pairs to interpolate')

    # The final input slice has no successor, so it is duplicated four times.
    last = _to_numpy_slice(last_slice, axis)
    upsampled.extend([last] * 4)
    baseline.extend([last] * 4)
    return np.concatenate(upsampled, axis=axis), np.concatenate(baseline, axis=axis)


def evaluate(args):
    """Evaluate 4x slice interpolation on every entry of ``args.data_path``.

    For each entry, the steps are:

    1. Build ``ImgTest(path, 4, args.thick_direction, simulate_lr=True)`` and
       upsample it with the checkpointed ``InterpolationNetwork``.
    2. Save the baseline (``*_input``) and prediction (``*_prediction``)
       volumes to ``results/<project_name>/evaluate_4x/``. Both use the
       affine and header of ``subject_dataset.img``.
    3. Score the prediction with ``evaluate_3D`` against the ``norm_01``
       normalised data of that image (``data_range=1``).

    Per-subject scores and the averages over all subjects are printed.

    Args:
        args: namespace; this function reads ``project_name``, ``resume``
            (checkpoint path), ``data_path`` and ``thick_direction``.

    Raises:
        ValueError: if a subject's loader yields no slice pairs.
    """
    out_dir = os.path.join('results', args.project_name, 'evaluate_4x')
    os.makedirs(out_dir, exist_ok=True)

    model = InterpolationNetwork()
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted
    # files (consider weights_only=True if the installed torch supports it).
    model.load_state_dict(torch.load(args.resume, map_location='cpu'))
    model = model.cuda()
    model.eval()

    psnr, ssim, mae = [], [], []
    # Sorted so subjects are processed and reported in a reproducible order.
    for subject in sorted(os.listdir(args.data_path)):
        subject_dataset = ImgTest(os.path.join(args.data_path, subject), 4,
                                  args.thick_direction, simulate_lr=True)
        try:
            upsampled_data, input_data = _upsample_subject(
                model, loader_test(subject_dataset), subject_dataset.axis)
        except ValueError as err:
            raise ValueError('Subject {}: {}'.format(subject, err)) from err

        gt_img = subject_dataset.img
        gt_data = norm_01(np.array(gt_img.dataobj))

        # NOTE(review): the outputs reuse the ground-truth header; if that
        # header stores an integer dtype, nibabel may cast/scale the float
        # data on save - confirm this is intended.
        nib.save(nib.Nifti1Image(input_data, gt_img.affine, gt_img.header),
                 _output_path(out_dir, subject, 'input'))
        nib.save(nib.Nifti1Image(upsampled_data, gt_img.affine, gt_img.header),
                 _output_path(out_dir, subject, 'prediction'))

        print(upsampled_data.shape, gt_data.shape)
        c_psnr, c_ssim, c_mae = evaluate_3D(upsampled_data, gt_data, data_range=1)
        psnr.append(c_psnr)
        ssim.append(c_ssim)
        mae.append(c_mae)
        print('Subject: {}, PSNR: {:.4f}, SSIM: {:.4f}, MAE: {:.4f}'.format(subject, c_psnr, c_ssim, c_mae))

    if not psnr:
        # Avoid np.mean([]) -> RuntimeWarning and a meaningless nan average.
        print('No subjects found in {}'.format(args.data_path))
        return
    print('Average PSNR: {:.4f}, Average SSIM: {:.4f}, Average MAE: {:.4f}'.format(np.mean(psnr), np.mean(ssim), np.mean(mae)))