forked from Ainimal/Aini_Modules
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Demo4_ResNetUnet_MultiLabelSegmentation.py
56 lines (48 loc) · 1.87 KB
/
Demo4_ResNetUnet_MultiLabelSegmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
class Model(nn.Module):
    """ResNet encoder + UNet decoder network with a multi-label segmentation head.

    Args:
        in_channel: number of channels of the input tensor, passed to the encoder.
        label_category_dict: mapping from label name to its number of
            categories, passed unchanged to ``SegmentationHead``. In the demo in
            this file, ``dict(organ=3, tumor=4)`` yields 3- and 4-channel logits
            keyed by ``'organ'`` and ``'tumor'``.
        dim: spatial dimensionality, passed to the encoder, decoder and head
            (defaults to 2; the demo uses 3).
    """

    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()

        # Per-stage encoder widths. This one list object is reused for the
        # stage output widths, the stage middle widths and the decoder inputs.
        encoder_channels = [64, 128, 256, 512]
        # Per-stage decoder widths. Entry 0 is the feature width that the
        # segmentation head is built for.
        decoder_channels = [32, 64, 128]

        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=encoder_channels,
            stage_middle_channels=encoder_channels,
            blocks=[1, 2, 3, 4],
            type='131',
            # NOTE(review): 'ration' is presumably the library's own keyword
            # spelling; check ResNetEncoder's signature before renaming it.
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim,
        )

        self.decoder = UNet_decoder(
            in_channels_list=encoder_channels,
            # Per-stage skip flags; exact semantics are defined by UNet_decoder.
            skip_connection=[False, True, True],
            out_channels_list=decoder_channels,
            dim=dim,
        )

        self.seg_head = SegmentationHead(
            label_category_dict,
            decoder_channels[0],
            dim=dim,
        )

    def forward(self, x):
        """Run encoder -> decoder -> segmentation head on ``x``.

        ``x`` is assumed to be laid out as (batch, channel, *spatial) (TODO
        confirm). The first decoder feature map is resized to ``x.shape[2:]``
        before it reaches the head. The head's output is returned as-is; in the
        demo in this file it is indexed by label name.
        """
        encoder_features = self.encoder(x)
        decoder_features = self.decoder(encoder_features)
        # Bring the head's input back to the input's spatial size so the
        # logits line up voxel-for-voxel with x.
        head_input = resizeTensor(decoder_features[0], size=x.shape[2:])
        return self.seg_head(head_input)
if __name__ == '__main__':
    # Demo: push a batch of 3-D volumes through a two-label model and print
    # the shape of each label's logits.
    x = torch.ones([2, 1, 128, 128, 128])
    label_category_dict = dict(organ=3, tumor=4)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    # Only the output shapes are inspected, so skip autograd bookkeeping;
    # otherwise every intermediate activation for the 128^3 volume is kept.
    with torch.no_grad():
        logits = model(x)
    print('multi-label predicted logits')
    # A plain loop, not a list comprehension, since printing is a side effect.
    for key in logits:
        print('logits of ', key, ':', logits[key].shape)
    # out
    # multi-label predicted logits
    # logits of organ : torch.Size([2, 3, 128, 128, 128])
    # logits of tumor : torch.Size([2, 4, 128, 128, 128])