Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get input image shape statically in the model construction function #101

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ To run demo type this in the command line:
1. Use yolov3-spp
6. `--ckpt_file`
1. Output checkpoint file
7. `--size`
1. Input image size
2. convert_weights_pb.py:
1. `--class_names`
1. Path to the class names file
Expand All @@ -51,6 +53,8 @@ To run demo type this in the command line:
1. Use yolov3-spp
6. `--output_graph`
1. Location to write the output .pb graph to
7. `--size`
1. Input image size
3. demo.py
1. `--class_names`
1. Path to the class names file
Expand All @@ -68,3 +72,5 @@ To run demo type this in the command line:
1. Desired iou threshold
8. `--gpu_memory_fraction`
1. Fraction of gpu memory to work with
9. `--size`
1. Input image size
5 changes: 4 additions & 1 deletion convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
'spp', False, 'Use SPP version of YOLOv3')
tf.app.flags.DEFINE_string(
'ckpt_file', './saved_model/model.ckpt', 'Chceckpoint file')
tf.app.flags.DEFINE_integer(
'size', 416, 'Input Image size')


def main(argv=None):
Expand All @@ -39,7 +41,8 @@ def main(argv=None):

with tf.variable_scope('detector'):
detections = model(inputs, len(classes),
data_format=FLAGS.data_format)
data_format=FLAGS.data_format,
img_size=[FLAGS.size, FLAGS.size])
load_ops = load_weights(tf.global_variables(
scope='detector'), FLAGS.weights_file)

Expand Down
5 changes: 3 additions & 2 deletions convert_weights_pb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
'tiny', False, 'Use tiny version of YOLOv3')
tf.app.flags.DEFINE_bool(
'spp', False, 'Use SPP version of YOLOv3')

tf.app.flags.DEFINE_integer(
'size', 416, 'Image size')
'size', 416, 'Input image size')



Expand All @@ -42,7 +43,7 @@ def main(argv=None):
inputs = tf.placeholder(tf.float32, [None, FLAGS.size, FLAGS.size, 3], "inputs")

with tf.variable_scope('detector'):
detections = model(inputs, len(classes), data_format=FLAGS.data_format)
detections = model(inputs, len(classes), data_format=FLAGS.data_format, img_size=[FLAGS.size, FLAGS.size])
load_ops = load_weights(tf.global_variables(scope='detector'), FLAGS.weights_file)

# Sets the output nodes in the current session
Expand Down
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'spp', False, 'Use SPP version of YOLOv3')

tf.app.flags.DEFINE_integer(
'size', 416, 'Image size')
'size', 416, 'Input image size')

tf.app.flags.DEFINE_float(
'conf_threshold', 0.5, 'Confidence threshold')
Expand Down
9 changes: 4 additions & 5 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def get_boxes_and_inputs(model, num_classes, size, data_format):

with tf.variable_scope('detector'):
detections = model(inputs, num_classes,
data_format=data_format)
data_format=data_format,
img_size=[size, size])

boxes = detections_boxes(detections)

Expand Down Expand Up @@ -185,10 +186,8 @@ def non_max_suppression(predictions_with_boxes, confidence_threshold, iou_thresh

result = {}
for i, image_pred in enumerate(predictions):
shape = image_pred.shape
non_zero_idxs = np.nonzero(image_pred)
image_pred = image_pred[non_zero_idxs]
image_pred = image_pred.reshape(-1, shape[-1])
# Drop predictions whose entire prediction vector is zero
image_pred = image_pred[np.any(image_pred, axis=-1)]

bbox_attrs = image_pred[:, :5]
classes = image_pred[:, 5:]
Expand Down
8 changes: 3 additions & 5 deletions yolo_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _upsample(inputs, out_shape, data_format='NCHW'):
return inputs


def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False):
def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False, img_size=[416, 416]):
"""
Creates YOLO v3 model.

Expand All @@ -213,8 +213,6 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
:param with_spp: whether or not to use the SPP layer.
:return:
"""
# it will be needed later on
img_size = inputs.get_shape().as_list()[1:3]

# transpose the inputs to NCHW
if data_format == 'NCHW':
Expand Down Expand Up @@ -277,7 +275,7 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
return detections


def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False):
def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, img_size=[416, 416]):
"""
Creates YOLO v3 with SPP model.

Expand All @@ -289,4 +287,4 @@ def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reus
:param reuse: whether or not the network and its variables should be reused.
:return:
"""
return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True)
return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True, img_size=img_size)
4 changes: 1 addition & 3 deletions yolo_v3_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
(81, 82), (135, 169), (344, 319)]


def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False):
def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, img_size=[416, 416]):
"""
Creates YOLO v3 tiny model.

Expand All @@ -27,8 +27,6 @@ def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reu
:param reuse: whether or not the network and its variables should be reused.
:return:
"""
# it will be needed later on
img_size = inputs.get_shape().as_list()[1:3]

# transpose the inputs to NCHW
if data_format == 'NCHW':
Expand Down