# Demo7_2D_TransUnet_Segmentation.py (forked from Ainimal/Aini_Modules)
import torch
import torch.nn as nn
from transformers import ViTModel
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor, load_weights
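

# TransUNet in brief: a ResNet encoder extracts multi-scale features, a
# pretrained ViT re-encodes the deepest feature map as a global-context neck,
# a U-Net decoder fuses the neck output with the encoder skips, and a
# multi-task segmentation head emits one logits map per label category.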
class TransUNet(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()

        # encoder: a 4-stage ResNet that halves resolution at every stage
        Encoder_f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=Encoder_f_channel_list,
            stage_middle_channels=Encoder_f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',  # 1x1-3x3-1x1 bottleneck residual blocks
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
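
        # with four 0.5x stages, a 256x256 input yields feature maps of
        # 128, 64, 32, and 16 pixels per side, so the deepest map fed to
        # the ViT neck below is [B, 512, 16, 16]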

        # neck: a ViT pretrained on ImageNet-21k, reconfigured to consume the
        # deepest CNN feature map (512 channels, resized to 8x8) as 1x1 patches
        neck_out_channel = 768
        transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
        configuration = transformer.config
        self.trans_downsample_size = configuration.image_size = [8, 8]
        configuration.patch_size = [1, 1]
        configuration.num_channels = Encoder_f_channel_list[-1]
        configuration.encoder_stride = 1  # only used by the MAE decoder; otherwise this parameter is unused
        self.neck = ViTModel(configuration, add_pooling_layer=False)
        # reuse the pretrained weights, except the embedding tensors whose
        # shapes no longer match the new patch size and channel count
        pretrained_weights = transformer.state_dict()
        pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[
            'embeddings.position_embeddings']
        pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[
            'embeddings.patch_embeddings.projection.weight']
        pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[
            'embeddings.patch_embeddings.projection.bias']
        self.neck = load_weights(self.neck, pretrained_weights)  # reload pretrained weights

        # decoder: a U-Net decoder taking skip connections from the encoder
        Decoder_f_channel_list = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=Encoder_f_channel_list[:-1] + [neck_out_channel],
            skip_connection=[True, True, True],
            out_channels_list=Decoder_f_channel_list,
            dim=dim)

        # seg head: one output branch per task in label_category_dict
        self.seg_head = SegmentationHead(
            label_category_dict,
            Decoder_f_channel_list[0],
            dim=dim)

    def forward(self, x):
        # encoder forward: list of multi-scale feature maps, shallow to deep
        multi_scale_encoder = self.encoder(x)

        # neck forward: resize the deepest feature map to 8x8, run the ViT,
        # then fold the token sequence back into a 2D feature map
        f_neck = self.neck(resizeTensor(multi_scale_encoder[-1], size=self.trans_downsample_size))
        f_neck = f_neck.last_hidden_state
        f_neck = f_neck[:, 1:]  # remove the class token
        f_neck = f_neck.permute(0, 2, 1)  # [B, tokens, C] -> [B, C, tokens]
        f_neck = f_neck.reshape(
            f_neck.shape[0],
            f_neck.shape[1],
            self.trans_downsample_size[0],
            self.trans_downsample_size[1])  # [B, C, 8, 8]
        f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:])
        multi_scale_encoder[-1] = f_neck  # swap in the ViT feature for the deepest CNN feature

        # decoder forward
        multi_scale_decoder = self.decoder(multi_scale_encoder)
        f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])

        # seg head forward: one logits tensor per task
        logits = self.seg_head(f_for_seg)
        return logits


if __name__ == '__main__':
    x = torch.ones([2, 1, 256, 256])  # a batch of two single-channel 256x256 images
    label_category_dict = dict(organ=3, tumor=4)  # two tasks: 3-class organ, 4-class tumor
    model = TransUNet(in_channel=1, label_category_dict=label_category_dict, dim=2)
    with torch.no_grad():
        logits = model(x)
    print('multi-label predicted logits')
    _ = [print('logits of', key, ':', logits[key].shape) for key in logits.keys()]
    # expected output:
    # multi-label predicted logits
    # logits of organ : torch.Size([2, 3, 256, 256])
    # logits of tumor : torch.Size([2, 4, 256, 256])
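
    # A minimal sketch (not in the original demo, assumptions noted inline) of
    # turning the multi-task logits dict into a training loss: each head emits
    # a [B, num_classes, H, W] tensor, so a per-task CrossEntropyLoss against
    # [B, H, W] integer masks works; the random targets below are placeholders.
    criterion = nn.CrossEntropyLoss()
    dummy_targets = {key: torch.randint(0, n, (2, 256, 256))  # placeholder masks
                     for key, n in label_category_dict.items()}
    loss = sum(criterion(logits[key], dummy_targets[key]) for key in logits.keys())
    print('total multi-task loss:', loss.item())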