-
Notifications
You must be signed in to change notification settings - Fork 4
/
exercise.py
197 lines (169 loc) · 5.68 KB
/
exercise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
from loaders import get_train_eval_test
from typing import Literal
from torch import nn
from torch.nn import functional as F
import torch
class Backbone(nn.Module):
    """A 2D UNet

    ```
    C -[conv xN]-> F ----------------------(cat)----------------------> 2*F -[conv xN]-> F -[conv]-> Cout
                   |                                                    ^
                   v                                                    |
                  F*m -[conv xN]-> F*m ---(cat)---> 2*F*m -[conv xN]-> F*m
                                    |                ^
                                    v                |
                                  F*m*m -[conv xN]-> F*m*m
    ```

    Each vertical arrow is a down/up-sampling step that also changes the
    number of features (F -> F*m on the way down, F*m -> F on the way up).
    The finest decoder stage is followed by a linear (activation-free) 3x3
    convolution mapping F features to `out_channels`, so outputs can be
    negative.
    """  # noqa: E501

    def __init__(
        self,
        inp_channels: int = 2,
        out_channels: int = 2,
        nb_features: int = 16,
        mul_features: int = 2,
        nb_levels: int = 3,
        nb_conv_per_level: int = 2,
        # Activation and down/up-sampling switches (both modes implemented).
        activation: Literal['ReLU', 'ELU'] = 'ReLU',
        pool: Literal['interpolate', 'conv'] = 'interpolate',
    ):
        """
        Parameters
        ----------
        inp_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        nb_features : int
            Number of features at the finest level
        mul_features : int
            Multiply the number of features by this number
            each time we go down one level.
        nb_levels : int
            Number of resolution levels, including the finest one.
            With `nb_levels=1`, no down/up-sampling is performed.
        nb_conv_per_level : int
            Number of convolutional layers at each level.
        activation : {'ReLU', 'ELU'}
            Type of activation
        pool : {'interpolate', 'conv'}
            Method used to go down/up one level.
            If `interpolate`, use `torch.nn.functional.interpolate`
            (bilinear), combined with a 1x1 convolution that changes
            the number of features.
            If `conv`, use strided convolutions on the way down, and
            transposed convolutions on the way up. In this mode each
            spatial size must be at least `2 ** (nb_levels - 1)`.

        Raises
        ------
        ValueError
            If `activation` or `pool` is not a supported value, or if
            `nb_levels` or `nb_conv_per_level` is smaller than 1.
        """
        super().__init__()
        activations = {'ReLU': nn.ReLU, 'ELU': nn.ELU}
        if activation not in activations:
            raise ValueError(f'Unknown activation: {activation!r}')
        if pool not in ('interpolate', 'conv'):
            raise ValueError(f'Unknown pool mode: {pool!r}')
        if nb_levels < 1:
            raise ValueError('nb_levels must be >= 1')
        if nb_conv_per_level < 1:
            raise ValueError('nb_conv_per_level must be >= 1')

        self.pool = pool
        act = activations[activation]
        # Number of features at each level, finest first: F, F*m, F*m*m, ...
        feats = [nb_features * mul_features ** level
                 for level in range(nb_levels)]

        def conv_block(cin, cout):
            # `nb_conv_per_level` 3x3 "same" convolutions, each followed by
            # the activation; only the first one changes the feature count.
            layers = []
            for i in range(nb_conv_per_level):
                layers.append(
                    nn.Conv2d(cin if i == 0 else cout, cout, 3, padding=1))
                layers.append(act())
            return nn.Sequential(*layers)

        # Encoder: level 0 maps the input channels to F; deeper levels keep
        # the feature count produced by the down-sampling step.
        self.encoders = nn.ModuleList(
            [conv_block(inp_channels, feats[0])]
            + [conv_block(f, f) for f in feats[1:]])
        # Decoder: one block per non-bottom level, fed with the concatenation
        # of the up-sampled features and the skip connection (2 * f).
        self.decoders = nn.ModuleList(
            [conv_block(2 * f, f) for f in feats[:-1]])

        # down[l] maps level l -> l + 1 ; up[l] maps level l + 1 -> l.
        pairs = list(zip(feats[:-1], feats[1:]))
        if pool == 'conv':
            # Kernel 2 / stride 2 halves the size; the transposed conv gets
            # `output_size` in `_upsample` so that odd sizes round-trip.
            self.down = nn.ModuleList(
                [nn.Conv2d(fine, coarse, 2, stride=2)
                 for fine, coarse in pairs])
            self.up = nn.ModuleList(
                [nn.ConvTranspose2d(coarse, fine, 2, stride=2)
                 for fine, coarse in pairs])
        else:
            # Resampling is done by interpolation; 1x1 convolutions only
            # change the number of features.
            self.down = nn.ModuleList(
                [nn.Conv2d(fine, coarse, 1) for fine, coarse in pairs])
            self.up = nn.ModuleList(
                [nn.Conv2d(coarse, fine, 1) for fine, coarse in pairs])

        # Linear head (no activation) so that outputs can be negative.
        self.final = nn.Conv2d(feats[0], out_channels, 3, padding=1)

    def forward(self, inp):
        """
        Parameters
        ----------
        inp : (B, inp_channels, X, Y)
            Input tensor

        Returns
        -------
        out : (B, out_channels, X, Y)
            Output tensor
        """
        skips = []
        x = inp
        for level, encoder in enumerate(self.encoders):
            if level > 0:
                x = self._downsample(level - 1, x)
            x = encoder(x)
            skips.append(x)
        # Walk back up, from the second-coarsest level to the finest one.
        for level in reversed(range(len(self.decoders))):
            skip = skips[level]
            x = self._upsample(level, x, skip.shape[-2:])
            x = self.decoders[level](torch.cat([x, skip], dim=1))
        return self.final(x)

    def _downsample(self, level, x):
        """Go from `level` to `level + 1` (half resolution, more features)."""
        if self.pool == 'conv':
            return self.down[level](x)
        size = [max(n // 2, 1) for n in x.shape[-2:]]
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        return self.down[level](x)

    def _upsample(self, level, x, size):
        """Go from `level + 1` to `level`, matching the spatial `size`."""
        if self.pool == 'conv':
            return self.up[level](x, output_size=size)
        x = self.up[level](x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
class VoxelMorph(nn.Module):
    """
    Construct a voxelmorph network with the given backbone
    """

    def __init__(self, **backbone_parameters):
        """
        Parameters
        ----------
        backbone_parameters : dict
            Keyword arguments forwarded to `Backbone`. Its input and output
            channel counts are fixed to 2 (stacked fixed/moving images in,
            2D displacement field out).
        """
        super().__init__()
        self.backbone = Backbone(2, 2, **backbone_parameters)

    def forward(self, fixmov):
        """
        Predict a displacement field from a fixed and moving images

        Parameters
        ----------
        fixmov : (B, 2, X, Y) tensor
            Input fixed and moving images, stacked along
            the channel dimension

        Returns
        -------
        disp : (B, 2, X, Y) tensor
            Predicted displacement field
        """
        return self.backbone(fixmov)

    def deform(self, mov, disp):
        """
        Deform the image `mov` using the displacement field `disp`

        The output is `moved[x, y] = mov[x + disp[0, x, y], y + disp[1, x, y]]`
        evaluated with bilinear interpolation; samples falling outside the
        image are zero.

        Parameters
        ----------
        mov : (B, 1, X, Y) tensor
            Moving image
        disp : (B, 2, X, Y) tensor
            Displacement field, in voxels (channel 0 along X,
            channel 1 along Y)

        Returns
        -------
        moved : (B, 1, X, Y) tensor
            Moved image
        """
        opt = dict(dtype=mov.dtype, device=mov.device)
        nx, ny = mov.shape[-2:]
        # Voxel coordinates of the output grid
        gx, gy = torch.meshgrid(torch.arange(nx, **opt),
                                torch.arange(ny, **opt), indexing='ij')
        # Absolute voxel coordinates at which `mov` is sampled
        x = disp[:, 0] + gx
        y = disp[:, 1] + gy
        # Normalize to (-1, 1) with the align_corners=False convention:
        # voxel i of an axis of size n sits at (2 * i + 1) / n - 1. For n > 1
        # this targets the same voxels as the align_corners=True formula
        # 2 * i / (n - 1) - 1, but it does not divide by zero when n == 1.
        x = (2 * x + 1) / nx - 1
        y = (2 * y + 1) / ny - 1
        # grid_sample wants a (B, X, Y, 2) grid whose last dimension is
        # ordered (last spatial axis, first spatial axis), i.e. (y, x).
        grid = torch.stack([y, x], dim=-1)
        return F.grid_sample(
            mov, grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False,
        )

    def membrane(self, disp):
        """
        Compute the membrane energy of the displacement field

        This is the sum, over the two spatial axes, of the mean squared
        forward finite difference of `disp` along that axis. An axis with
        a single voxel has no finite differences and contributes 0.
        """
        energy = disp.new_zeros(())
        for dim in (2, 3):
            if disp.shape[dim] > 1:
                energy = energy + torch.diff(disp, dim=dim).square().mean()
        return energy

    def loss(self, fix, mov, disp, lam=0.1):
        """
        Compute the regularized loss (mse + membrane * lam)

        Parameters
        ----------
        fix : (B, 1, X, Y) tensor
            Fixed image
        mov : (B, 1, X, Y) tensor
            Moving image
        disp : (B, 2, X, Y) tensor
            Displacement field
        lam : float
            Regularization

        Returns
        -------
        loss : scalar tensor
            Mean squared error between the moved and fixed images plus
            `lam` times the membrane energy of `disp`.
        """
        moved = self.deform(mov, disp)
        return F.mse_loss(moved, fix) + lam * self.membrane(disp)
# Module-level data loading: this call runs at import time, using
# `get_train_eval_test` from the project-local `loaders` module.
# NOTE(review): presumably returns the (train, eval, test) splits in that
# order, as the target names suggest — confirm against `loaders`.
trainset, evalset, testset = get_train_eval_test()
def train(*args, **kwargs):
    """
    Training entry point (placeholder, not implemented yet).

    Accepts any positional and keyword arguments and always raises
    ``NotImplementedError``.
    """
    message = 'Implement this function yourself'
    raise NotImplementedError(message)
def test(*args, **kwargs):
"""
A testing function
"""
raise NotImplementedError('Implement this function yourself')