-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
197 lines (160 loc) · 7.4 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
class ResNet(nn.Module):
    """Truncated torchvision ResNet that serves as a feature extractor.

    Only the first seven child modules of the selected torchvision model are
    kept. In the first block of the last kept stage, the strides of
    ``conv1``, ``conv2`` and the downsample convolution are reset to
    ``(1, 1)``.

    Attributes set in ``__init__``:
        out_channels: list of six channel counts associated with the chosen
            backbone.
        feature_extractor: ``nn.Sequential`` holding the retained layers.
    """

    def __init__(self, backbone='resnet50', backbone_path=None):
        """Build the truncated backbone.

        Args:
            backbone: 'resnet18', 'resnet34', 'resnet50' or 'resnet101';
                every other value selects resnet152.
            backbone_path: optional path to a state dict for the complete
                torchvision model. When it is falsy the torchvision factory
                is called with ``pretrained=True`` instead.
        """
        super().__init__()

        # name -> (torchvision factory, channel counts); built per call so
        # each instance owns its own list object.
        known = {
            'resnet18': (resnet18, [256, 512, 512, 256, 256, 128]),
            'resnet34': (resnet34, [256, 512, 512, 256, 256, 256]),
            'resnet50': (resnet50, [1024, 512, 512, 256, 256, 256]),
            'resnet101': (resnet101, [1024, 512, 512, 256, 256, 256]),
        }
        # Unrecognised names are treated as 'resnet152'.
        fallback = (resnet152, [1024, 512, 512, 256, 256, 256])
        factory, channels = known.get(backbone, fallback)

        network = factory(pretrained=not backbone_path)
        self.out_channels = channels

        if backbone_path:
            # NOTE(review): torch.load unpickles the file -- only point this
            # at checkpoints from a trusted source.
            network.load_state_dict(torch.load(backbone_path))

        kept_layers = list(network.children())[:7]
        self.feature_extractor = nn.Sequential(*kept_layers)

        # Remove the striding from the first block of the last kept stage.
        first_block = self.feature_extractor[-1][0]
        strided_convs = (
            first_block.conv1,
            first_block.conv2,
            first_block.downsample[0],
        )
        for conv in strided_convs:
            conv.stride = (1, 1)

    def forward(self, x):
        """Return the feature map produced by the retained layers for ``x``."""
        return self.feature_extractor(x)
class SSD300(nn.Module):
    """SSD detection heads stacked on a feature-extractor backbone.

    The backbone output is passed through five additional blocks, giving six
    feature maps. Each map gets one 3x3 localization conv and one 3x3
    confidence conv; their outputs are flattened and concatenated along the
    last axis.

    Attributes set in ``__init__``:
        feature_extractor: the backbone module; must expose ``out_channels``
            (six channel counts) and map an image batch to a feature map.
        label_num: number of classes predicted per default box.
        num_defaults: default boxes per location for each of the six maps.
        additional_blocks, loc, conf: ``nn.ModuleList`` containers.
    """

    def __init__(self, backbone=None, *, label_num=81):
        """Build the extra feature layers and the prediction heads.

        Args:
            backbone: feature extractor module. When ``None`` a fresh
                ``ResNet('resnet50')`` is created for this instance. (A
                module instance as the default value would be built once at
                import time and shared by every SSD300.)
            label_num: number of classes; the default 81 is the value the
                original code labels "number of COCO classes".
        """
        super().__init__()

        if backbone is None:
            backbone = ResNet('resnet50')
        self.feature_extractor = backbone

        self.label_num = label_num
        self._build_additional_features(self.feature_extractor.out_channels)
        self.num_defaults = [4, 6, 6, 6, 4, 4]

        # Heads are created pairwise (loc, conf) per feature map so that the
        # order of random initialisation matches the module order.
        loc_heads = []
        conf_heads = []
        for nd, oc in zip(self.num_defaults, self.feature_extractor.out_channels):
            loc_heads.append(nn.Conv2d(oc, nd * 4, kernel_size=3, padding=1))
            conf_heads.append(
                nn.Conv2d(oc, nd * self.label_num, kernel_size=3, padding=1)
            )

        self.loc = nn.ModuleList(loc_heads)
        self.conf = nn.ModuleList(conf_heads)
        self._init_weights()

    def _build_additional_features(self, input_size):
        """Create the five blocks that follow the backbone.

        Args:
            input_size: sequence of channel counts; consecutive pairs give
                the input/output channels of each block.

        Each block is 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU. The
        first three blocks use ``padding=1, stride=2`` in the 3x3 conv; the
        remaining ones use an unpadded, stride-1 3x3 conv.
        """
        mid_channels = [256, 256, 128, 128, 128]
        blocks = []
        pairs = zip(input_size[:-1], input_size[1:], mid_channels)
        for i, (in_ch, out_ch, mid_ch) in enumerate(pairs):
            conv3_kwargs = {'padding': 1, 'stride': 2} if i < 3 else {}
            blocks.append(nn.Sequential(
                nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_ch, out_ch, kernel_size=3, bias=False, **conv3_kwargs),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ))

        self.additional_blocks = nn.ModuleList(blocks)

    def _init_weights(self):
        """Xavier-initialise every parameter with more than one dimension.

        Applies to the additional blocks and both sets of heads; 1-D
        parameters keep their default initialisation.
        """
        layers = [*self.additional_blocks, *self.loc, *self.conf]
        for layer in layers:
            for param in layer.parameters():
                if param.dim() > 1:
                    nn.init.xavier_uniform_(param)

    def bbox_view(self, src, loc, conf):
        """Apply the heads to the feature maps and flatten per box.

        Args:
            src: iterable of feature maps, one per head.
            loc: iterable of localization convs.
            conf: iterable of confidence convs.

        Returns:
            ``(locs, confs)`` with shapes ``(N, 4, total)`` and
            ``(N, label_num, total)``, where ``total`` sums the boxes of all
            feature maps.
        """
        locs = []
        confs = []
        for feat, loc_head, conf_head in zip(src, loc, conf):
            batch = feat.size(0)
            # The head's channel axis is split into (4 | label_num, boxes).
            locs.append(loc_head(feat).view(batch, 4, -1))
            confs.append(conf_head(feat).view(batch, self.label_num, -1))

        return torch.cat(locs, 2).contiguous(), torch.cat(confs, 2).contiguous()

    def forward(self, x):
        """Run the detector on an image batch.

        Returns:
            ``(locs, confs)`` as produced by ``bbox_view``.
        """
        x = self.feature_extractor(x)

        detection_feed = [x]
        for block in self.additional_blocks:
            x = block(x)
            detection_feed.append(x)

        # Feature Map 38x38x4, 19x19x6, 10x10x6, 5x5x6, 3x3x4, 1x1x4
        locs, confs = self.bbox_view(detection_feed, self.loc, self.conf)

        # With those six maps this is nbatch x 4 x 8732 for locs and
        # nbatch x label_num x 8732 for confs.
        return locs, confs
class Loss(nn.Module):
    """SSD training loss.

    The loss is the sum of:
        1. Confidence loss: cross entropy over the positive boxes plus the
           hardest negatives (hard negative mining, three negatives per
           positive).
        2. Localization loss: smooth L1, on positive boxes only.

    The default boxes are stored with shape ``(1, 4, P)``; the original
    notes assume ``P == 8732``.
    """

    def __init__(self, dboxes):
        """Store the default boxes and the encoding scales.

        Args:
            dboxes: default-box object exposing ``scale_xy`` and
                ``scale_wh`` and callable as ``dboxes(order="xywh")`` to
                obtain a ``(P, 4)`` tensor of boxes.
        """
        super().__init__()
        self.scale_xy = 1.0 / dboxes.scale_xy
        self.scale_wh = 1.0 / dboxes.scale_wh

        self.sl1_loss = nn.SmoothL1Loss(reduction='none')
        # Kept as a non-trainable Parameter (not a buffer) so existing
        # checkpoints and parameter lists stay compatible.
        self.dboxes = nn.Parameter(
            dboxes(order="xywh").transpose(0, 1).unsqueeze(dim=0),
            requires_grad=False,
        )
        # Per the original author's note, the two scale factors come from:
        # http://jany.st/post/2017-11-05-single-shot-detector-ssd-from-scratch-in-tensorflow.html
        self.con_loss = nn.CrossEntropyLoss(reduction='none')

    def _loc_vec(self, loc):
        """Encode boxes relative to the default boxes.

        Args:
            loc: ``(N, 4, P)`` tensor; the first two channels are centre
                coordinates and the last two are sizes, matching the "xywh"
                order of the stored default boxes.

        Returns:
            ``(N, 4, P)`` tensor: scaled centre offsets divided by the
            default-box size, followed by the scaled log size ratios.
        """
        dxy = self.dboxes[:, :2, :]
        dwh = self.dboxes[:, 2:, :]
        gxy = self.scale_xy * (loc[:, :2, :] - dxy) / dwh
        gwh = self.scale_wh * (loc[:, 2:, :] / dwh).log()
        return torch.cat((gxy, gwh), dim=1).contiguous()

    def forward(self, ploc, plabel, gloc, glabel):
        """Compute the batch-averaged loss.

        Args:
            ploc, plabel: Nx4xP, Nxlabel_numxP
                predicted locations and class scores.
            gloc, glabel: Nx4xP, NxP
                ground-truth locations and labels; label 0 is treated as
                "no object".

        Returns:
            Scalar tensor: per-sample (localization + confidence) loss
            divided by that sample's positive count, averaged over the
            batch. Samples without positives contribute zero.
        """
        mask = glabel > 0
        pos_num = mask.sum(dim=1)

        vec_gd = self._loc_vec(gloc)

        # Sum over the four coordinates, then keep positive boxes only.
        sl1 = self.sl1_loss(ploc, vec_gd).sum(dim=1)
        sl1 = (mask.float() * sl1).sum(dim=1)

        # Hard negative mining.
        con = self.con_loss(plabel, glabel)

        # Zero the positives so they sort behind every negative.
        con_neg = con.masked_fill(mask, 0)
        # Sorting the sort indices yields each box's rank by descending loss.
        _, con_idx = con_neg.sort(dim=1, descending=True)
        _, con_rank = con_idx.sort(dim=1)

        # Keep three negatives per positive, capped at the number of boxes.
        neg_num = torch.clamp(3 * pos_num, max=mask.size(1)).unsqueeze(-1)
        neg_mask = con_rank < neg_num

        # Union of the two masks: when the negative quota exceeds the number
        # of negatives a box can be in both, and it must be counted once.
        selected = mask | neg_mask
        closs = (con * selected.float()).sum(dim=1)

        # Guard against samples without any positive box.
        total_loss = sl1 + closs
        num_mask = (pos_num > 0).float()
        pos_num = pos_num.float().clamp(min=1e-6)
        return (total_loss * num_mask / pos_num).mean(dim=0)