-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathExampleNet_CAM.py
46 lines (38 loc) · 1.26 KB
/
ExampleNet_CAM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from ExampleNet import ExampleNet
from CAM_categorical import CAM
class ExampleNetCAM(ExampleNet):
    """ExampleNet backbone whose output is post-processed by a ``CAM`` module.

    ``forward`` runs the input through the backbone layers (``self.layers``,
    set up by ``ExampleNet``). It then passes the backbone output, a
    centre-cropped copy of the input scan, the atlas scan and the atlas label
    to ``self.cam``.
    """

    def __init__(
        self,
        out_channels,
    ):
        """Build the backbone and the CAM head.

        Args:
            out_channels: forwarded to ``ExampleNet`` as ``out_channels`` and
                to ``CAM`` as ``nclasses``, so both use the same class count.
        """
        super().__init__(
            out_channels=out_channels,
        )
        self.cam = CAM(
            nclasses=out_channels,
        )

    def forward(
        self,
        img,
        atlas,
        atlas_label,
    ):
        """Run the backbone on ``img`` and refine its output with ``self.cam``.

        img [B 1 W+padding H+padding D+padding]: target scan intensity map after standardization.
        atlas [B 1 W H D]: atlas scan intensity map after standardization.
        atlas_label [B W H D C]: probabilistic atlas label

        Returns:
            Whatever ``self.cam.forward`` returns for the backbone output.

        Raises:
            ValueError: if the backbone output is spatially larger than
                ``img`` along any of dims 2, 3 or 4.
        """
        # NOTE(review): self.layers comes from ExampleNet (not visible here);
        # assumed to be an ordered iterable of callables applied in sequence.
        x = img
        for layer in self.layers:
            x = layer(x)

        # Centre-crop img along dims 2..4 to exactly the backbone output's size.
        # Slicing start:start + size (instead of start:-start) stays correct
        # when nothing was trimmed (start == 0, where ``0:-0`` would give an
        # empty tensor). It also stays correct when the size difference is odd
        # (the extra voxel is dropped from the end).
        crop = []
        for dim in (2, 3, 4):
            out_size = x.shape[dim]
            diff = img.shape[dim] - out_size
            if diff < 0:
                raise ValueError(
                    f"backbone output spatial size {tuple(x.shape[2:])} "
                    f"exceeds input spatial size {tuple(img.shape[2:])}"
                )
            start = diff // 2
            crop.append(slice(start, start + out_size))

        # NOTE(review): if CAM is a torch.nn.Module, self.cam(...) would also
        # run registered hooks; .forward is kept to preserve current behaviour.
        predictions = self.cam.forward(
            unary=x,
            img=img[:, :, crop[0], crop[1], crop[2]],
            atlas=atlas,
            atlas_label=atlas_label,
        )
        return predictions