from enum import Enum

import torch
from torch import nn
from torch.nn import functional as F


class ChannelSELayer3D(nn.Module):
"""
3D extension of Squeeze-and-Excitation (SE) block described in:
*Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
    *Zhu et al., AnatomyNet, arXiv:1808.05238*
"""
def __init__(self, num_channels, reduction_ratio=2):
"""
        :param num_channels: number of input channels
        :param reduction_ratio: factor by which num_channels is reduced in the bottleneck
"""
super(ChannelSELayer3D, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool3d(1)
num_channels_reduced = num_channels // reduction_ratio
self.reduction_ratio = reduction_ratio
self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, input_tensor):
"""
:param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
:return: output tensor
"""
batch_size, num_channels, D, H, W = input_tensor.size()
# Average along each channel
squeeze_tensor = self.avg_pool(input_tensor)
# channel excitation
fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
output_tensor = torch.mul(input_tensor, fc_out_2.view(batch_size, num_channels, 1, 1, 1))
return output_tensor
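

# A minimal usage sketch (added for illustration, not part of the original module):
# it shows that ChannelSELayer3D preserves the input shape while re-weighting channels.
# The tensor sizes below are arbitrary assumptions chosen for the example.
def _demo_channel_se():
    x = torch.randn(2, 16, 8, 8, 8)   # (batch_size, num_channels, D, H, W)
    layer = ChannelSELayer3D(num_channels=16, reduction_ratio=2)
    y = layer(x)
    assert y.shape == x.shape          # channel recalibration keeps the shape
    return y

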
class SpatialSELayer3D(nn.Module):
"""
3D extension of SE block -- squeezing spatially and exciting channel-wise described in:
*Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
"""
def __init__(self, num_channels):
"""
        :param num_channels: number of input channels
"""
super(SpatialSELayer3D, self).__init__()
self.conv = nn.Conv3d(num_channels, 1, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, input_tensor, weights=None):
"""
:param weights: weights for few shot learning
:param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
:return: output_tensor
"""
# channel squeeze
batch_size, channel, D, H, W = input_tensor.size()
        if weights is not None:
            # reshape externally supplied per-channel weights into a 1x1x1 conv kernel:
            # (out_channels=1, in_channels, kD, kH, kW)
            weights = weights.view(1, channel, 1, 1, 1)
            out = F.conv3d(input_tensor, weights)
else:
out = self.conv(input_tensor)
squeeze_tensor = self.sigmoid(out)
# spatial excitation
output_tensor = torch.mul(input_tensor, squeeze_tensor.view(batch_size, 1, D, H, W))
return output_tensor
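

# Sketch for illustration (not in the original file): exercises both code paths of
# SpatialSELayer3D -- the layer's own 1x1x1 convolution and externally supplied weights,
# e.g. for few-shot use. The shapes are example assumptions.
def _demo_spatial_se():
    x = torch.randn(2, 16, 8, 8, 8)
    layer = SpatialSELayer3D(num_channels=16)
    y_internal = layer(x)                       # uses the layer's own conv
    external_w = torch.randn(16)                # one weight per input channel
    y_external = layer(x, weights=external_w)   # reshaped to (1, 16, 1, 1, 1) inside forward
    assert y_internal.shape == y_external.shape == x.shape
    return y_internal, y_external

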
class ChannelSpatialSELayer3D(nn.Module):
"""
3D extension of concurrent spatial and channel squeeze & excitation:
*Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*
"""
def __init__(self, num_channels, reduction_ratio=2):
"""
        :param num_channels: number of input channels
        :param reduction_ratio: factor by which num_channels is reduced in the bottleneck
"""
super(ChannelSpatialSELayer3D, self).__init__()
self.cSE = ChannelSELayer3D(num_channels, reduction_ratio)
self.sSE = SpatialSELayer3D(num_channels)
def forward(self, input_tensor):
"""
:param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
:return: output_tensor
"""
output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
return output_tensor
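

# Illustration only (a sketch added for this file, not original code): the concurrent
# block is the element-wise maximum of the channel (cSE) and spatial (sSE) branches,
# exactly as implemented in forward above.
def _demo_channel_spatial_se():
    x = torch.randn(2, 16, 8, 8, 8)
    layer = ChannelSpatialSELayer3D(num_channels=16, reduction_ratio=2)
    y = layer(x)
    expected = torch.max(layer.cSE(x), layer.sSE(x))
    assert torch.allclose(y, expected)
    return y

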
class ProjectExciteLayer(nn.Module):
"""
Project & Excite Module, specifically designed for 3D inputs
    *Rickmann et al., 'Project & Excite' Modules for Segmentation of Volumetric Medical Scans, MICCAI 2019, arXiv:1906.04649*
"""
def __init__(self, num_channels, reduction_ratio=2):
"""
        :param num_channels: number of input channels
        :param reduction_ratio: factor by which num_channels is reduced in the bottleneck
"""
super(ProjectExciteLayer, self).__init__()
num_channels_reduced = num_channels // reduction_ratio
self.reduction_ratio = reduction_ratio
self.relu = nn.ReLU()
self.conv_c = nn.Conv3d(in_channels=num_channels, out_channels=num_channels_reduced, kernel_size=1, stride=1)
self.conv_cT = nn.Conv3d(in_channels=num_channels_reduced, out_channels=num_channels, kernel_size=1, stride=1)
self.sigmoid = nn.Sigmoid()
def forward(self, input_tensor):
"""
:param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
:return: output tensor
"""
batch_size, num_channels, D, H, W = input_tensor.size()
# Project:
# Average along channels and different axes
squeeze_tensor_w = F.adaptive_avg_pool3d(input_tensor, (1, 1, W))
squeeze_tensor_h = F.adaptive_avg_pool3d(input_tensor, (1, H, 1))
squeeze_tensor_d = F.adaptive_avg_pool3d(input_tensor, (D, 1, 1))
# tile tensors to original size and add:
final_squeeze_tensor = sum([squeeze_tensor_w.view(batch_size, num_channels, 1, 1, W),
squeeze_tensor_h.view(batch_size, num_channels, 1, H, 1),
squeeze_tensor_d.view(batch_size, num_channels, D, 1, 1)])
# Excitation:
final_squeeze_tensor = self.sigmoid(self.conv_cT(self.relu(self.conv_c(final_squeeze_tensor))))
output_tensor = torch.mul(input_tensor, final_squeeze_tensor)
return output_tensor
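

# Sketch for illustration (not in the original file): Project & Excite pools the input
# along each pair of spatial axes, broadcasts the three projections back to
# (B, C, D, H, W), and uses their sum to gate the input, so the output shape is unchanged.
# The non-cubic volume below is an arbitrary example that keeps the three projections distinct.
def _demo_project_excite():
    x = torch.randn(2, 16, 4, 6, 8)
    layer = ProjectExciteLayer(num_channels=16, reduction_ratio=2)
    y = layer(x)
    assert y.shape == x.shape
    return y

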
class SELayer3D(Enum):
"""
    Enum restricting the type of SE blocks available, so that type checking can be added
    when adding these blocks to a neural network::
if self.se_block_type == se.SELayer3D.CSE3D.value:
self.SELayer = se.ChannelSELayer3D(params['num_filters'])
elif self.se_block_type == se.SELayer3D.SSE3D.value:
self.SELayer = se.SpatialSELayer3D(params['num_filters'])
elif self.se_block_type == se.SELayer3D.CSSE3D.value:
self.SELayer = se.ChannelSpatialSELayer3D(params['num_filters'])
elif self.se_block_type == se.SELayer3D.PE.value:
            self.SELayer = se.ProjectExciteLayer(params['num_filters'])
"""
NONE = 'NONE'
CSE3D = 'CSE3D'
SSE3D = 'SSE3D'
CSSE3D = 'CSSE3D'
PE = 'PE'
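

# Sketch (illustrative, not part of the original module): one way to turn the SELayer3D
# enum into a block factory, mirroring the if/elif example in the enum's docstring.
# `create_se_block` and `num_filters` are hypothetical names taken from/added for that example.
def create_se_block(se_block_type, num_filters):
    if se_block_type == SELayer3D.CSE3D.value:
        return ChannelSELayer3D(num_filters)
    elif se_block_type == SELayer3D.SSE3D.value:
        return SpatialSELayer3D(num_filters)
    elif se_block_type == SELayer3D.CSSE3D.value:
        return ChannelSpatialSELayer3D(num_filters)
    elif se_block_type == SELayer3D.PE.value:
        return ProjectExciteLayer(num_filters)
    return None                                # SELayer3D.NONE or unrecognised type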