diff --git a/convert_weights.py b/convert_weights.py index 01f656e..a752ded 100644 --- a/convert_weights.py +++ b/convert_weights.py @@ -21,6 +21,8 @@ 'spp', False, 'Use SPP version of YOLOv3') tf.app.flags.DEFINE_string( 'ckpt_file', './saved_model/model.ckpt', 'Chceckpoint file') +tf.app.flags.DEFINE_integer( + 'size', 416, 'Image size') def main(argv=None): @@ -39,7 +41,8 @@ def main(argv=None): with tf.variable_scope('detector'): detections = model(inputs, len(classes), - data_format=FLAGS.data_format) + data_format=FLAGS.data_format, + img_size=[FLAGS.size, FLAGS.size]) load_ops = load_weights(tf.global_variables( scope='detector'), FLAGS.weights_file) diff --git a/convert_weights_pb.py b/convert_weights_pb.py index 1ade65a..9aa4e5b 100644 --- a/convert_weights_pb.py +++ b/convert_weights_pb.py @@ -23,6 +23,7 @@ 'tiny', False, 'Use tiny version of YOLOv3') tf.app.flags.DEFINE_bool( 'spp', False, 'Use SPP version of YOLOv3') + tf.app.flags.DEFINE_integer( 'size', 416, 'Image size') @@ -42,7 +43,7 @@ def main(argv=None): inputs = tf.placeholder(tf.float32, [None, FLAGS.size, FLAGS.size, 3], "inputs") with tf.variable_scope('detector'): - detections = model(inputs, len(classes), data_format=FLAGS.data_format) + detections = model(inputs, len(classes), data_format=FLAGS.data_format, img_size=[FLAGS.size, FLAGS.size]) load_ops = load_weights(tf.global_variables(scope='detector'), FLAGS.weights_file) # Sets the output nodes in the current session diff --git a/utils.py b/utils.py index c4cce93..6e5bfc8 100644 --- a/utils.py +++ b/utils.py @@ -20,7 +20,8 @@ def get_boxes_and_inputs(model, num_classes, size, data_format): with tf.variable_scope('detector'): detections = model(inputs, num_classes, - data_format=data_format) + data_format=data_format, + img_size=[size, size]) boxes = detections_boxes(detections) diff --git a/yolo_v3.py b/yolo_v3.py index ef1d1a7..15992e5 100644 --- a/yolo_v3.py +++ b/yolo_v3.py @@ -200,7 +200,7 @@ def _upsample(inputs, out_shape, data_format='NCHW'): return inputs -def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False): +def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False, img_size=[416, 416]): """ Creates YOLO v3 model. @@ -213,8 +213,6 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa :param with_spp: whether or not is using spp layer. :return: """ - # it will be needed later on - img_size = inputs.get_shape().as_list()[1:3] # transpose the inputs to NCHW if data_format == 'NCHW': @@ -277,7 +275,7 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa return detections -def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False): +def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, img_size=[416, 416]): """ Creates YOLO v3 with SPP model. @@ -289,4 +287,4 @@ def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reus :param reuse: whether or not the network and its variables should be reused. :return: """ - return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True) + return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True, img_size=img_size) diff --git a/yolo_v3_tiny.py b/yolo_v3_tiny.py index 5f308f5..4ea996e 100644 --- a/yolo_v3_tiny.py +++ b/yolo_v3_tiny.py @@ -15,7 +15,7 @@ (81, 82), (135, 169), (344, 319)] -def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False): +def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, img_size=[416, 416]): """ Creates YOLO v3 tiny model. @@ -27,8 +27,6 @@ def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reu :param reuse: whether or not the network and its variables should be reused. :return: """ - # it will be needed later on - img_size = inputs.get_shape().as_list()[1:3] # transpose the inputs to NCHW if data_format == 'NCHW':