-
Notifications
You must be signed in to change notification settings - Fork 16
/
Net_Basic_V1.py
105 lines (87 loc) · 3.01 KB
/
Net_Basic_V1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import torchvision.models as backbone_
import torch.nn.functional as F
import encoding
class Net_Basic(nn.Module):
    """Inception-v3 convolutional trunk followed by a small embedding head.

    The trunk modules (``Conv2d_1a_3x3`` ... ``Mixed_7c``) are taken from
    torchvision's ``inception_v3``. The rest of that model (auxiliary
    classifier, final ``fc``) is discarded. The trunk output is globally
    average-pooled to a 2048-d vector and mapped by ``head_layer`` to an
    ``embedding_dim``-d embedding.

    ``head_layer`` is Normalize -> Linear -> Normalize, using
    ``encoding.nn.Normalize``. That is presumably L2 normalisation along the
    feature dimension (PyTorch-Encoding); confirm against the installed
    version.

    Args:
        pretrained: Forwarded to ``torchvision.models.inception_v3``.
            ``True`` (the original, default behaviour) loads the ImageNet
            weights, which may trigger a download. ``False`` builds a randomly
            initialised trunk, e.g. before loading a saved checkpoint.
        embedding_dim: Size of the output embedding (default 64, the value
            previously hard-coded).
    """

    # Trunk modules reused from torchvision, in forward order. They are
    # registered under the same attribute names (and in the same order) as
    # the original explicit assignments, so state_dict keys are unchanged.
    _BACKBONE_LAYERS = (
        "Conv2d_1a_3x3",
        "Conv2d_2a_3x3",
        "Conv2d_2b_3x3",
        "Conv2d_3b_1x1",
        "Conv2d_4a_3x3",
        "Mixed_5b",
        "Mixed_5c",
        "Mixed_5d",
        "Mixed_6a",
        "Mixed_6b",
        "Mixed_6c",
        "Mixed_6d",
        "Mixed_6e",
        "Mixed_7a",
        "Mixed_7b",
        "Mixed_7c",
    )

    def __init__(self, pretrained=True, embedding_dim=64):
        super(Net_Basic, self).__init__()
        backbone = backbone_.inception_v3(pretrained=pretrained)
        for layer_name in self._BACKBONE_LAYERS:
            # nn.Module.__setattr__ registers each one as a submodule.
            setattr(self, layer_name, getattr(backbone, layer_name))
        # 2048 = channel count of Mixed_7c's output (see shape trace in forward).
        self.head_layer = nn.Sequential(
            encoding.nn.Normalize(),
            nn.Linear(2048, embedding_dim),
            encoding.nn.Normalize(),
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch of shape (N, 3, H, W). The shape comments below
                trace the canonical 299 x 299 input.

        Returns:
            Tensor of shape (N, embedding_dim) produced by ``head_layer``.
        """
        # NOTE(review): torchvision's Inception3.forward first calls
        # ``_transform_input`` when built with ``transform_input=True`` (its
        # default when pretrained weights are requested). That step re-maps
        # ImageNet-normalised inputs. It is skipped here, so x reaches
        # Conv2d_1a_3x3 unchanged. Confirm the data pipeline matches what the
        # pretrained weights expect.
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8 -> global average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        return self.head_layer(x)

    def fix_weights(self):
        """Freeze every parameter of the network.

        Sets ``requires_grad = False`` on all of ``self.parameters()``. That
        covers the Inception trunk and ``head_layer`` alike, so afterwards no
        part of the model receives gradients.
        """
        for param in self.parameters():
            param.requires_grad = False
if __name__ == "__main__":
    # Smoke test: build the network, print its module tree, then run a
    # random batch of ten 3 x 299 x 299 images through it and report the
    # resulting embedding shape.
    net = Net_Basic()
    print('loaded')
    print(net)
    random_batch = torch.randn(10, 3, 299, 299)
    output = net(random_batch)
    print(output.shape)