diff --git a/README.md b/README.md index 5fbbe0f..b279db1 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ To run demo type this in the command line: 1. Download binary file with desired weights: 1. Full weights: `wget https://pjreddie.com/media/files/yolov3.weights` 1. Tiny weights: `wget https://pjreddie.com/media/files/yolov3-tiny.weights` + 1. SPP weights: `wget https://pjreddie.com/media/files/yolov3-spp.weights` 2. Run `python ./convert_weights.py` and `python ./convert_weights_pb.py` 3. Run `python ./demo.py --input_img --output_img --frozen_model ` @@ -33,7 +34,9 @@ To run demo type this in the command line: 1. `NCHW` (gpu only) or `NHWC` 4. `--tiny` 1. Use yolov3-tiny - 5. `--ckpt_file` + 5. `--spp` + 1. Use yolov3-spp + 6. `--ckpt_file` 1. Output checkpoint file 2. convert_weights_pb.py: 1. `--class_names` @@ -44,7 +47,9 @@ To run demo type this in the command line: 1. `NCHW` (gpu only) or `NHWC` 4. `--tiny` 1. Use yolov3-tiny - 5. `--output_graph` + 5. `--spp` + 1. Use yolov3-spp + 6. `--output_graph` 1. Location to write the output .pb graph to 3. demo.py 1. `--class_names` @@ -62,4 +67,4 @@ To run demo type this in the command line: 7. `--iou_threshold` 1. Desired iou threshold 8. `--gpu_memory_fraction` - 1. Fraction of gpu memory to work with \ No newline at end of file + 1. 
Fraction of gpu memory to work with diff --git a/convert_weights.py b/convert_weights.py index e652bca..01f656e 100644 --- a/convert_weights.py +++ b/convert_weights.py @@ -17,6 +17,8 @@ 'data_format', 'NCHW', 'Data format: NCHW (gpu only) / NHWC') tf.app.flags.DEFINE_bool( 'tiny', False, 'Use tiny version of YOLOv3') +tf.app.flags.DEFINE_bool( + 'spp', False, 'Use SPP version of YOLOv3') tf.app.flags.DEFINE_string( 'ckpt_file', './saved_model/model.ckpt', 'Chceckpoint file') @@ -24,6 +26,8 @@ def main(argv=None): if FLAGS.tiny: model = yolo_v3_tiny.yolo_v3_tiny + elif FLAGS.spp: + model = yolo_v3.yolo_v3_spp else: model = yolo_v3.yolo_v3 diff --git a/convert_weights_pb.py b/convert_weights_pb.py index 7f0af00..1ade65a 100644 --- a/convert_weights_pb.py +++ b/convert_weights_pb.py @@ -21,6 +21,8 @@ tf.app.flags.DEFINE_bool( 'tiny', False, 'Use tiny version of YOLOv3') +tf.app.flags.DEFINE_bool( + 'spp', False, 'Use SPP version of YOLOv3') tf.app.flags.DEFINE_integer( 'size', 416, 'Image size') @@ -29,6 +31,8 @@ def main(argv=None): if FLAGS.tiny: model = yolo_v3_tiny.yolo_v3_tiny + elif FLAGS.spp: + model = yolo_v3.yolo_v3_spp else: model = yolo_v3.yolo_v3 diff --git a/demo.py b/demo.py index 188908c..de7e931 100644 --- a/demo.py +++ b/demo.py @@ -29,6 +29,8 @@ 'frozen_model', '', 'Frozen tensorflow protobuf model') tf.app.flags.DEFINE_bool( 'tiny', False, 'Use tiny version of YOLOv3') +tf.app.flags.DEFINE_bool( + 'spp', False, 'Use SPP version of YOLOv3') tf.app.flags.DEFINE_integer( 'size', 416, 'Image size') @@ -71,6 +73,8 @@ def main(argv=None): else: if FLAGS.tiny: model = yolo_v3_tiny.yolo_v3_tiny + elif FLAGS.spp: + model = yolo_v3.yolo_v3_spp else: model = yolo_v3.yolo_v3 diff --git a/yolo_v3.py b/yolo_v3.py index cd5b6c0..ef1d1a7 100644 --- a/yolo_v3.py +++ b/yolo_v3.py @@ -63,6 +63,14 @@ def _darknet53_block(inputs, filters): return inputs +def _spp_block(inputs, data_format='NCHW'): + return tf.concat([slim.max_pool2d(inputs, 13, 1, 'SAME', data_format=data_format), + 
slim.max_pool2d(inputs, 9, 1, 'SAME', data_format=data_format), + slim.max_pool2d(inputs, 5, 1, 'SAME', data_format=data_format), + inputs], + axis=1 if data_format == 'NCHW' else 3) + + @tf.contrib.framework.add_arg_scope def _fixed_padding(inputs, kernel_size, *args, mode='CONSTANT', **kwargs): """ @@ -95,10 +103,15 @@ def _fixed_padding(inputs, kernel_size, *args, mode='CONSTANT', **kwargs): return padded_inputs -def _yolo_block(inputs, filters): +def _yolo_block(inputs, filters, data_format='NCHW', with_spp=False): inputs = _conv2d_fixed_padding(inputs, filters, 1) inputs = _conv2d_fixed_padding(inputs, filters * 2, 3) inputs = _conv2d_fixed_padding(inputs, filters, 1) + + if with_spp: + inputs = _spp_block(inputs, data_format) + inputs = _conv2d_fixed_padding(inputs, filters, 1) + inputs = _conv2d_fixed_padding(inputs, filters * 2, 3) inputs = _conv2d_fixed_padding(inputs, filters, 1) route = inputs @@ -187,7 +200,7 @@ def _upsample(inputs, out_shape, data_format='NCHW'): return inputs -def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False): +def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False): """ Creates YOLO v3 model. @@ -197,6 +210,7 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa :param is_training: whether is training or not. :param data_format: data format NCHW or NHWC. :param reuse: whether or not the network and its variables should be reused. + :param with_spp: whether to insert the SPP block into the 512-filter YOLO block. 
:return: """ # it will be needed later on @@ -228,7 +242,8 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa route_1, route_2, inputs = darknet53(inputs) with tf.variable_scope('yolo-v3'): - route, inputs = _yolo_block(inputs, 512) + route, inputs = _yolo_block(inputs, 512, data_format, with_spp) + detect_1 = _detection_layer( inputs, num_classes, _ANCHORS[6:9], img_size, data_format) detect_1 = tf.identity(detect_1, name='detect_1') @@ -260,3 +275,18 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa detections = tf.concat([detect_1, detect_2, detect_3], axis=1) detections = tf.identity(detections, name='detections') return detections + + +def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False): + """ + Creates YOLO v3 with SPP model. + + :param inputs: a 4-D tensor of size [batch_size, height, width, channels]. + Dimension batch_size may be undefined. The channel order is RGB. + :param num_classes: number of predicted classes. + :param is_training: whether is training or not. + :param data_format: data format NCHW or NHWC. + :param reuse: whether or not the network and its variables should be reused. + :return: the detections tensor returned by yolo_v3 called with with_spp=True. + """ + return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True)