-
Notifications
You must be signed in to change notification settings - Fork 25
/
yolo_v3.py
331 lines (253 loc) · 13.9 KB
/
yolo_v3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# Contains Yolo core definitions
import tensorflow as tf
# Momentum handed to every batch normalization layer in this module.
_BATCH_NORM_DECAY = 0.9
# Epsilon handed to every batch normalization layer in this module.
_BATCH_NORM_EPSILON = 1e-5
# Alpha (negative-side slope) of every leaky ReLU activation.
_LEAKY_RELU = 0.1
# Anchor box sizes, one row per detection scale. The model consumes them in
# slices of three: [0:3], [3:6] and [6:9].
_ANCHORS = [
    (10, 13), (16, 30), (33, 23),
    (30, 61), (62, 45), (59, 119),
    (116, 90), (156, 198), (373, 326),
]
# Applies batch normalization with the module-wide decay and epsilon
def batch_norm(inputs, training, data_format):
    """Batch-normalize `inputs` along its channel axis.

    Args:
        inputs: Tensor to normalize.
        training: Forwarded unchanged as the layer's `training` argument.
        data_format: 'channels_first' normalizes over axis 1; any other value
            normalizes over axis 3.

    Returns:
        The tensor produced by `tf.layers.batch_normalization`.
    """
    channel_axis = 3
    if data_format == 'channels_first':
        channel_axis = 1
    return tf.layers.batch_normalization(
        inputs=inputs,
        axis=channel_axis,
        momentum=_BATCH_NORM_DECAY,
        epsilon=_BATCH_NORM_EPSILON,
        scale=True,
        training=training,
    )
# Explicit spatial padding (ResNet-style fixed padding)
def fixed_padding(inputs, kernel_size, data_format):
    """Pad the two spatial dimensions by `kernel_size - 1` elements in total.

    The padding is split as evenly as possible; when the total is odd the
    extra element goes at the end of each dimension.

    Args:
        inputs: 4-D tensor to pad.
        kernel_size: Kernel size of the convolution that will follow.
        data_format: 'channels_first' pads dimensions 2 and 3; any other value
            pads dimensions 1 and 2.

    Returns:
        The padded tensor returned by `tf.pad`.
    """
    total = kernel_size - 1
    before = total // 2
    after = total - before

    spatial = [before, after]
    untouched = [0, 0]
    if data_format == 'channels_first':
        paddings = [untouched, untouched, spatial, spatial]
    else:
        paddings = [untouched, spatial, spatial, untouched]
    return tf.pad(inputs, paddings)
# Bias-free 2-D convolution whose padding does not depend on the input size
def conv2d_fixed_padding(inputs, filters, kernel_size, data_format, strides=1):
    """Convolve `inputs`, padding explicitly when the convolution is strided.

    For `strides > 1` the input is first padded with `fixed_padding` and the
    convolution then runs with 'VALID' padding; for `strides == 1` the
    convolution itself uses 'SAME' padding.

    Args:
        inputs: 4-D input tensor.
        filters: Number of output filters.
        kernel_size: Convolution kernel size.
        data_format: 'channels_first' or 'channels_last'.
        strides: Convolution stride (default 1).

    Returns:
        The output of `tf.layers.conv2d`, created without a bias term.
    """
    if strides > 1:
        inputs = fixed_padding(inputs, kernel_size, data_format)

    padding_mode = 'VALID'
    if strides == 1:
        padding_mode = 'SAME'

    return tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding_mode,
        use_bias=False,
        data_format=data_format,
    )
# Builds one Darknet residual block
def darknet53_residual_block(inputs, filters, training, data_format, strides=1):
    """Apply a 1x1 then a 3x3 convolution unit and add the block input back.

    Each unit is convolution -> batch normalization -> leaky ReLU. The 1x1
    unit produces `filters` channels and the 3x3 unit `2 * filters`.

    Args:
        inputs: Input tensor; it is also used as the shortcut.
        filters: Channel count of the 1x1 convolution.
        training: Forwarded to `batch_norm`.
        data_format: 'channels_first' or 'channels_last'.
        strides: Stride applied to both convolutions (default 1).
            NOTE(review): the shortcut is added unchanged, so any value other
            than 1 presumably breaks the shape match -- confirm before use.

    Returns:
        The sum of the second unit's output and the original `inputs`.
    """
    shortcut = inputs
    net = inputs

    # (filters, kernel_size) of the two units, in creation order.
    for n_filters, kernel in ((filters, 1), (2 * filters, 3)):
        net = conv2d_fixed_padding(
            net, filters=n_filters, kernel_size=kernel, strides=strides,
            data_format=data_format)
        net = batch_norm(net, training=training, data_format=data_format)
        net = tf.nn.leaky_relu(net, alpha=_LEAKY_RELU)

    return net + shortcut
# Builds the Darknet53 feature extractor
def darknet53(inputs, training, data_format):
    """Run `inputs` through the Darknet-53 backbone.

    The network is a 32-filter 3x3 stem followed by five stages. Each stage is
    one stride-2 3x3 convolution unit and then a run of residual blocks.

    Args:
        inputs: Input tensor laid out according to `data_format`.
        training: Forwarded to every batch normalization layer.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        Tuple `(route1, route2, inputs)`: the outputs of the third stage, the
        fourth stage and the final stage, in that order.
    """
    def conv_unit(tensor, filters, strides=1):
        # 3x3 convolution -> batch norm -> leaky ReLU.
        tensor = conv2d_fixed_padding(
            tensor, filters=filters, kernel_size=3, strides=strides,
            data_format=data_format)
        tensor = batch_norm(tensor, training=training, data_format=data_format)
        return tf.nn.leaky_relu(tensor, alpha=_LEAKY_RELU)

    # (downsampling filters, residual block filters, number of blocks).
    # Layers are created strictly in this order, which keeps the
    # auto-generated layer names stable.
    stages = (
        (64, 32, 1),
        (128, 64, 2),
        (256, 128, 8),
        (512, 256, 8),
        (1024, 512, 4),
    )

    net = conv_unit(inputs, 32)
    stage_outputs = []
    for down_filters, block_filters, n_blocks in stages:
        net = conv_unit(net, down_filters, strides=2)
        for _ in range(n_blocks):
            net = darknet53_residual_block(
                net, filters=block_filters, training=training,
                data_format=data_format)
        stage_outputs.append(net)

    route1 = stage_outputs[2]
    route2 = stage_outputs[3]
    return route1, route2, net
# Builds the convolution stack that follows Darknet at each detection scale
def yolo_convolution_block(inputs, filters, training, data_format):
    """Apply six convolution units alternating 1x1 and 3x3 kernels.

    Each unit is convolution -> batch normalization -> leaky ReLU. The 1x1
    units produce `filters` channels and the 3x3 units `2 * filters`.

    Args:
        inputs: Input tensor.
        filters: Channel count of the 1x1 units.
        training: Forwarded to `batch_norm`.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        Tuple `(route, inputs)`: the output of the fifth unit (the last 1x1
        one) and the output of the sixth and final unit.
    """
    def conv_unit(tensor, n_filters, kernel):
        tensor = conv2d_fixed_padding(
            tensor, filters=n_filters, kernel_size=kernel,
            data_format=data_format)
        tensor = batch_norm(tensor, training=training, data_format=data_format)
        return tf.nn.leaky_relu(tensor, alpha=_LEAKY_RELU)

    net = inputs
    # First five units; the output of the last one is the route tensor.
    for n_filters, kernel in ((filters, 1), (2 * filters, 3),
                              (filters, 1), (2 * filters, 3),
                              (filters, 1)):
        net = conv_unit(net, n_filters, kernel)
    route = net

    net = conv_unit(net, 2 * filters, 3)
    return route, net
# Creates Yolo final detection layer
def yolo_layer(inputs, n_classes, anchors, img_size, data_format):
    """Turn a feature map into decoded box predictions for one scale.

    A biased 1x1 convolution produces `len(anchors) * (5 + n_classes)`
    channels. These are reshaped to one row per (grid cell, anchor) and
    decoded: box centers get a sigmoid, the cell offset and the stride
    scaling; box shapes are `exp(prediction) * anchor`; confidence and class
    scores get a sigmoid.

    Args:
        inputs: 4-D feature map laid out according to `data_format`. Its two
            spatial dimensions must be statically known.
        n_classes: Number of class scores per box.
        anchors: Sequence of anchor size pairs for this scale.
        img_size: Two-element size of the model input, used to compute the
            stride of a grid cell.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        Tensor of shape `[batch, n_anchors * grid_rows * grid_cols,
        5 + n_classes]` whose last axis is box center (2), box shape (2),
        confidence (1) and class scores (`n_classes`).
    """
    n_anchors = len(anchors)
    inputs = tf.layers.conv2d(inputs, filters=n_anchors * (5 + n_classes),
                              kernel_size=1, strides=1, use_bias=True,
                              data_format=data_format)

    # grid_shape is (rows, cols), i.e. (height, width), of the feature map.
    shape = inputs.get_shape().as_list()
    grid_shape = shape[2:4] if data_format == 'channels_first' else shape[1:3]
    if data_format == 'channels_first':
        inputs = tf.transpose(inputs, [0, 2, 3, 1])
    # Row-major flattening: rows, then columns, then anchors.
    inputs = tf.reshape(
        inputs,
        [-1, n_anchors * grid_shape[0] * grid_shape[1], 5 + n_classes])

    # NOTE(review): strides are paired index-for-index with grid_shape and
    # then multiplied onto (x, y) centers. That only matters if the two
    # strides differ -- confirm the ordering convention of img_size.
    strides = (img_size[0] // grid_shape[0], img_size[1] // grid_shape[1])

    box_centers, box_shapes, confidence, classes = \
        tf.split(inputs, [2, 2, 1, n_classes], axis=-1)

    # x runs over columns and y over rows. With meshgrid's default 'xy'
    # indexing the outputs have shape (rows, cols), so flattening them matches
    # the row-major order of the reshape above. Taking the x range from
    # grid_shape[0] (the row count) was only correct for square grids.
    x = tf.range(grid_shape[1], dtype=tf.float32)
    y = tf.range(grid_shape[0], dtype=tf.float32)
    x_offset, y_offset = tf.meshgrid(x, y)
    x_offset = tf.reshape(x_offset, (-1, 1))
    y_offset = tf.reshape(y_offset, (-1, 1))
    x_y_offset = tf.concat([x_offset, y_offset], axis=-1)
    # Repeat each cell offset once per anchor.
    x_y_offset = tf.tile(x_y_offset, [1, n_anchors])
    x_y_offset = tf.reshape(x_y_offset, [1, -1, 2])

    box_centers = tf.nn.sigmoid(box_centers)
    box_centers = (box_centers + x_y_offset) * strides

    # Repeat the anchor list once per grid cell, matching the row order above.
    anchors = tf.tile(anchors, [grid_shape[0] * grid_shape[1], 1])
    box_shapes = tf.exp(box_shapes) * tf.to_float(anchors)

    confidence = tf.nn.sigmoid(confidence)
    classes = tf.nn.sigmoid(classes)

    inputs = tf.concat([box_centers, box_shapes, confidence, classes], axis=-1)
    return inputs
# Upsamples to `out_shape` using nearest neighbor interpolation
def upsample(inputs, out_shape, data_format):
    """Resize `inputs` to the spatial size given by `out_shape`.

    Args:
        inputs: 4-D tensor laid out according to `data_format`.
        out_shape: Full 4-element shape of the target tensor in the same
            layout as `inputs` (the model passes
            `route.get_shape().as_list()`). Only the height and width entries
            are read.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        The resized tensor, in the same layout as `inputs`.
    """
    channels_first = data_format == 'channels_first'

    if channels_first:
        # The resize below is applied to a channels-last tensor, so transpose
        # first and restore the layout afterwards.
        inputs = tf.transpose(inputs, [0, 2, 3, 1])
        # NCHW: height is index 2, width is index 3.
        new_height = out_shape[2]
        new_width = out_shape[3]
    else:
        # NHWC: height is index 1, width is index 2.
        new_height = out_shape[1]
        new_width = out_shape[2]

    # The size argument is (new_height, new_width). Reading the two from
    # swapped indices was harmless on square maps but gave a transposed size
    # on non-square ones.
    inputs = tf.image.resize_nearest_neighbor(inputs, (new_height, new_width))

    if channels_first:
        inputs = tf.transpose(inputs, [0, 3, 1, 2])
    return inputs
# Runs non-max suppression independently for each class of each image
def non_max_suppression(inputs, n_classes, max_output_size, iou_threshold, confidence_threshold):
    """Filter boxes by confidence and apply per-class non-max suppression.

    Args:
        inputs: Batched box tensor. Along the last axis, columns 0-3 are the
            box coordinates, column 4 the confidence and columns 5 onwards the
            class scores. It is split with `tf.unstack`.
        n_classes: Number of classes to iterate over.
        max_output_size: Forwarded to `tf.image.non_max_suppression`.
        iou_threshold: Forwarded to `tf.image.non_max_suppression`.
        confidence_threshold: Boxes whose confidence is not strictly greater
            than this value are dropped.

    Returns:
        A list with one dict per batch element. Each dict maps a class index
        to a tensor holding the first five columns (coordinates and
        confidence) of the boxes kept for that class.
    """
    results = []
    for image_boxes in tf.unstack(inputs):
        confident = image_boxes[:, 4] > confidence_threshold
        image_boxes = tf.boolean_mask(image_boxes, confident)

        # Replace the per-class scores with the index of the best class.
        class_ids = tf.argmax(image_boxes[:, 5:], axis=-1)
        class_ids = tf.expand_dims(tf.to_float(class_ids), axis=-1)
        image_boxes = tf.concat([image_boxes[:, :5], class_ids], axis=-1)

        per_class = dict()
        for cls in range(n_classes):
            is_cls = tf.equal(image_boxes[:, 5], cls)
            if is_cls.get_shape().ndims == 0:
                continue

            candidates = tf.boolean_mask(image_boxes, is_cls)
            coords, scores, _ = tf.split(candidates, [4, 1, -1], axis=-1)
            scores = tf.reshape(scores, [-1])
            selected = tf.image.non_max_suppression(
                coords, scores, max_output_size, iou_threshold)
            candidates = tf.gather(candidates, selected)
            per_class[cls] = candidates[:, :5]

        results.append(per_class)
    return results
# Converts center/size boxes into top-left and bottom-right corner points
def build_boxes(inputs):
    """Rewrite the first four columns of `inputs` from center form to corners.

    Args:
        inputs: Tensor whose last axis is center x, center y, width, height,
            confidence, followed by the class scores.

    Returns:
        Tensor with the same trailing columns but whose first four columns are
        top-left x, top-left y, bottom-right x and bottom-right y.
    """
    cx, cy, box_w, box_h, confidence, classes = tf.split(
        inputs, [1, 1, 1, 1, 1, -1], axis=-1)

    half_w = box_w / 2
    half_h = box_h / 2
    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]

    return tf.concat(corners + [confidence, classes], axis=-1)
# Yolo v3 model class
class Yolo_v3:
    """YOLO v3 detector: Darknet-53 backbone, three detection scales and
    per-class non-max suppression."""

    def __init__(self, n_classes, model_size, max_output_size, iou_threshold,
                 confidence_threshold, data_format=None):
        """Store the model configuration.

        Args:
            n_classes: Number of class labels.
            model_size: Input size of the model; passed to the detection
                layers as `img_size`.
            max_output_size: Forwarded to non-max suppression.
            iou_threshold: Forwarded to non-max suppression.
            confidence_threshold: Forwarded to non-max suppression.
            data_format: 'channels_first' or 'channels_last'. When falsy,
                'channels_first' is chosen if TensorFlow was built with CUDA
                and 'channels_last' otherwise.
        """
        if not data_format:
            data_format = ('channels_first' if tf.test.is_built_with_cuda()
                           else 'channels_last')

        self.n_classes = n_classes
        self.model_size = model_size
        self.max_output_size = max_output_size
        self.iou_threshold = iou_threshold
        self.confidence_threshold = confidence_threshold
        self.data_format = data_format

    # Adds the operations that detect boxes for a batch of input images
    def __call__(self, inputs, training):
        """Build the detection graph for `inputs`.

        Args:
            inputs: Batch of images. It is transposed with [0, 3, 1, 2] when
                the model runs channels-first, and divided by 255.
            training: Forwarded to every batch normalization layer.

        Returns:
            The list of per-image dicts produced by `non_max_suppression`.
        """
        data_format = self.data_format
        channels_first = data_format == 'channels_first'
        concat_axis = 1 if channels_first else 3

        with tf.variable_scope('yolo_v3_model'):
            net = inputs
            if channels_first:
                net = tf.transpose(net, [0, 3, 1, 2])
            net = net / 255

            route1, route2, net = darknet53(net, training=training,
                                            data_format=data_format)

            # (filters, anchors, backbone tensor merged in before the head).
            # The first head works directly on the backbone output.
            heads = (
                (512, _ANCHORS[6:9], None),
                (256, _ANCHORS[3:6], route2),
                (128, _ANCHORS[0:3], route1),
            )

            detections = []
            route = None
            for filters, anchors, skip in heads:
                if skip is not None:
                    # Shrink the previous head's route tensor, upsample it to
                    # the size of `skip` and join the two along channels.
                    net = conv2d_fixed_padding(route, filters=filters,
                                               kernel_size=1,
                                               data_format=data_format)
                    net = batch_norm(net, training=training,
                                     data_format=data_format)
                    net = tf.nn.leaky_relu(net, alpha=_LEAKY_RELU)
                    net = upsample(net,
                                   out_shape=skip.get_shape().as_list(),
                                   data_format=data_format)
                    net = tf.concat([net, skip], axis=concat_axis)

                route, net = yolo_convolution_block(
                    net, filters=filters, training=training,
                    data_format=data_format)
                detections.append(
                    yolo_layer(net, n_classes=self.n_classes,
                               anchors=anchors,
                               img_size=self.model_size,
                               data_format=data_format))

            boxes = build_boxes(tf.concat(detections, axis=1))

            return non_max_suppression(
                boxes, n_classes=self.n_classes,
                max_output_size=self.max_output_size,
                iou_threshold=self.iou_threshold,
                confidence_threshold=self.confidence_threshold)