-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathtest_fanbeam.py
111 lines (86 loc) · 4.37 KB
/
test_fanbeam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from .utils import generate_random_images, relative_error, circle_mask
import astra
from nose.tools import assert_less, assert_equal
import torch
import numpy as np
from torch_radon import RadonFanbeam
from parameterized import parameterized
import matplotlib.pyplot as plt
device = torch.device('cuda')

# Projection angle sets (float32 radians): a full turn, a limited wedge
# (0.2*pi .. 0.5*pi), a sparse full turn and a dense full turn.
full_angles = np.linspace(0, 2 * np.pi, 180).astype(np.float32)
limited_angles = np.linspace(0.2 * np.pi, 0.5 * np.pi, 50).astype(np.float32)
sparse_angles = np.linspace(0, 2 * np.pi, 60).astype(np.float32)
many_angles = np.linspace(0, 2 * np.pi, 800).astype(np.float32)

# Every combination of test settings, ordered exactly like nested loops with
# clip_to_circle varying fastest. Each case is the tuple
# (device, batch_size, image_size, angles, spacing, distances, det_count, clip_to_circle);
# distances and det_count are multiples of image_size and get scaled inside the tests.
params = [
    (device, bs, size, angle_set, det_spacing, dist_pair, det_frac, clip)
    for bs in [1, 8]
    for size in [128, 151]
    for angle_set in [full_angles, limited_angles, sparse_angles, many_angles]
    for det_spacing in [1.0, 0.5, 1.3, 2.0]
    for dist_pair in [(1.2, 1.2), (2.0, 2.0), (1.2, 3.0)]
    for det_frac in [1.0, 1.5]
    for clip in [False, True]
]

# Half-precision cases: keep only batch sizes divisible by 4 (of [1, 8] that is 8).
# NOTE(review): presumably the half-precision path needs batch % 4 == 0 — confirm.
half_params = list(filter(lambda case: case[1] % 4 == 0, params))
@parameterized(params)
def test_fanbeam_error(device, batch_size, image_size, angles, spacing, distances, det_count, clip_to_circle):
    """Compare RadonFanbeam forward/back projection against ASTRA's 'fanflat' projector.

    ``det_count`` and both entries of ``distances`` are multiples of
    ``image_size`` and are scaled to absolute values here. One random image is
    projected by ASTRA (reference) and by RadonFanbeam; the latter runs on
    ``batch_size`` copies of the image and only batch element 0 is compared.
    When ``clip_to_circle`` is set, ASTRA's backprojection is masked with the
    same circle radius used to generate the image.
    """
    det_count = int(det_count * image_size)
    mask_radius = det_count / 2.0 if clip_to_circle else -1
    # generate a single random image (take element 0 of a batch of 1)
    x = generate_random_images(1, image_size, mask_radius)[0]
    s_dist, d_dist = distances
    s_dist *= image_size
    d_dist *= image_size

    # astra reference
    vol_geom = astra.create_vol_geom(x.shape[0], x.shape[1])
    proj_geom = astra.create_proj_geom('fanflat', spacing, det_count, angles, s_dist, d_dist)
    proj_id = astra.create_projector('cuda', proj_geom, vol_geom)
    sino_id = bp_id = None
    try:
        sino_id, astra_y = astra.create_sino(x, proj_id)
        _bp_id_and_data = astra.create_backprojection(astra_y, proj_id)
        bp_id, astra_bp = _bp_id_and_data
    finally:
        # ASTRA keeps these objects alive until explicitly deleted; without this
        # every parameterized case would leak a projector and two data objects.
        # The returned numpy arrays are independent copies, so they stay valid.
        if sino_id is not None:
            astra.data2d.delete(sino_id)
        if bp_id is not None:
            astra.data2d.delete(bp_id)
        astra.projector.delete(proj_id)
    if clip_to_circle:
        astra_bp *= circle_mask(image_size, mask_radius)

    # our implementation
    radon = RadonFanbeam(image_size, angles, s_dist, d_dist, det_count=det_count, det_spacing=spacing, clip_to_circle=clip_to_circle)
    x = torch.FloatTensor(x).to(device).view(1, x.shape[0], x.shape[1])
    # repeat data to fill batch size
    x = torch.cat([x] * batch_size, dim=0)
    our_fp = radon.forward(x)
    our_bp = radon.backprojection(our_fp)

    forward_error = relative_error(astra_y, our_fp[0].cpu().numpy())
    back_error = relative_error(astra_bp, our_bp[0].cpu().numpy())

    print(np.max(our_fp.cpu().numpy()), np.max(our_bp.cpu().numpy()))
    print(
        f"batch: {batch_size}, size: {image_size}, angles: {len(angles)}, spacing: {spacing}, distances: {distances} circle: {clip_to_circle}, forward: {forward_error}, back: {back_error}")
    # TODO better checks
    assert_less(forward_error, 1e-2)
    assert_less(back_error, 5e-3)
@parameterized(half_params)
def test_half(device, batch_size, image_size, angles, spacing, distances, det_count, clip_to_circle):
    """Check that float16 fan-beam projections stay close to float32 ones.

    A batch of random images is forward-projected and backprojected twice with
    the same RadonFanbeam operator: once from float32 input and once from the
    input cast to float16. The relative errors of the half-precision sinogram
    and backprojection against the float32 results must both be below 1e-3.
    ``det_count`` and ``distances`` are multiples of ``image_size``.
    """
    n_detectors = int(det_count * image_size)
    radius = -1
    if clip_to_circle:
        radius = n_detectors / 2.0
    images = generate_random_images(batch_size, image_size, radius)

    source_dist = distances[0] * image_size
    detector_dist = distances[1] * image_size

    radon = RadonFanbeam(image_size, angles, source_dist, detector_dist,
                         det_count=n_detectors, det_spacing=spacing, clip_to_circle=clip_to_circle)
    float_images = torch.FloatTensor(images).to(device)
    n_angles = len(angles)

    # Both sinograms are scaled by 1 / n_angles before backprojection so the
    # half-precision backprojection does not overflow.
    sino_single = radon.forward(float_images) / n_angles
    bp_single = radon.backprojection(sino_single)
    sino_half = radon.forward(float_images.half()) / n_angles
    bp_half = radon.backprojection(sino_half)
    print(torch.min(bp_half).item(), torch.max(bp_half).item())

    forward_error = relative_error(sino_single.cpu().numpy(), sino_half.cpu().numpy())
    back_error = relative_error(bp_single.cpu().numpy(), bp_half.cpu().numpy())
    print(
        f"batch: {batch_size}, size: {image_size}, angles: {n_angles}, spacing: {spacing}, circle: {clip_to_circle}, forward: {forward_error}, back: {back_error}")
    for err in (forward_error, back_error):
        assert_less(err, 1e-3)