-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodels.py
65 lines (55 loc) · 1.95 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
# Module-level Sigmoid instance. Swish below now calls torch.sigmoid directly;
# this name is kept in case other modules import it from here.
sigmoid = nn.Sigmoid()


class Swish(torch.autograd.Function):
    """Swish activation ``f(x) = x * sigmoid(x)`` with a hand-written backward.

    Only the input tensor is saved for the backward pass; ``sigmoid(x)`` is
    recomputed there rather than being stored from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and stash ``i`` for the backward pass."""
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        """Chain-rule gradient of Swish with respect to its input.

        d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x))
                        = s(x) * (1 + x * (1 - s(x)))
        """
        # `saved_tensors` is the supported accessor; `saved_variables` is
        # deprecated and not available in current PyTorch releases.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
class Swish_Module(nn.Module):
    """``nn.Module`` wrapper so the custom ``Swish`` function can be used as a layer.

    This lets the activation appear inside containers such as ``nn.Sequential``.
    """

    def forward(self, x):
        """Apply the Swish autograd function element-wise to ``x``."""
        activated = Swish.apply(x)
        return activated
class Effnet_Melanoma(nn.Module):
    """EfficientNet image model with an optional tabular-metadata branch.

    The backbone's own ``_fc`` head is replaced by ``nn.Identity`` so that
    ``self.enet(x)`` returns the features that previously fed that head.
    When ``n_meta_features > 0``, a small MLP embeds the metadata vector and
    its output is concatenated to the image features along dim 1. The final
    linear head ``myfc`` is applied under five independent dropout layers and
    the results are averaged (multi-sample dropout).

    Args:
        enet_type: Model name passed to ``EfficientNet.from_pretrained``.
        out_dim: Number of outputs produced by ``myfc``.
        n_meta_features: Width of the metadata input; 0 disables the branch.
        n_meta_dim: Hidden widths of the metadata MLP. Only the first two
            entries are used. Any indexable sequence works (list or tuple).
    """

    def __init__(self, enet_type, out_dim, n_meta_features=0, n_meta_dim=(512, 128)):
        super().__init__()
        self.n_meta_features = n_meta_features
        self.enet = EfficientNet.from_pretrained(enet_type)
        # Five parallel dropout layers used for multi-sample dropout in forward().
        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])
        # Width of the features the backbone fed into its original head.
        in_ch = self.enet._fc.in_features
        if n_meta_features > 0:
            self.meta = nn.Sequential(
                nn.Linear(n_meta_features, n_meta_dim[0]),
                nn.BatchNorm1d(n_meta_dim[0]),
                Swish_Module(),
                nn.Dropout(p=0.3),
                nn.Linear(n_meta_dim[0], n_meta_dim[1]),
                nn.BatchNorm1d(n_meta_dim[1]),
                Swish_Module(),
            )
            # The metadata embedding is concatenated to the image features.
            in_ch += n_meta_dim[1]
        self.myfc = nn.Linear(in_ch, out_dim)
        # Remove the backbone's classifier so enet(x) yields raw features.
        self.enet._fc = nn.Identity()

    def extract(self, x):
        """Run the backbone (with its head removed) and return its features."""
        return self.enet(x)

    def forward(self, x, x_meta=None):
        """Return outputs averaged over the multi-sample-dropout heads.

        Args:
            x: Input batch for the backbone.
            x_meta: Metadata batch; required when ``n_meta_features > 0``.
                Assumed to be (batch, n_meta_features) — TODO confirm with callers.

        Raises:
            ValueError: If the metadata branch is enabled but ``x_meta`` is None.
        """
        # NOTE(review): presumably guards against a backbone returning
        # (B, C, 1, 1) instead of (B, C); a no-op on 2-D features.
        x = self.extract(x).squeeze(-1).squeeze(-1)
        if self.n_meta_features > 0:
            if x_meta is None:
                raise ValueError(
                    "x_meta is required when the model was built with n_meta_features > 0"
                )
            x = torch.cat((x, self.meta(x_meta)), dim=1)
        # Multi-sample dropout: the same head under independent dropout masks,
        # accumulated in order and averaged. In eval mode nn.Dropout is the
        # identity, so every term is the same.
        out = sum(self.myfc(dropout(x)) for dropout in self.dropouts)
        return out / len(self.dropouts)