Skip to content

Commit

Permalink
ibaiGorordo#8 Add support for the official YOLOv7 ONNX NMS format
Browse files Browse the repository at this point in the history
  • Loading branch information
ibaiGorordo committed Nov 19, 2022
1 parent 572b95f commit a1eb903
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 20 deletions.
23 changes: 15 additions & 8 deletions YOLOv7/YOLOv7.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import numpy as np
import onnxruntime

from .utils import xywh2xyxy, nms, draw_detections
from yolov7.utils import xywh2xyxy, nms, draw_detections


class YOLOv7:

def __init__(self, path, conf_thres=0.7, iou_thres=0.5):
def __init__(self, path, conf_thres=0.7, iou_thres=0.5, official_nms=False):
self.conf_threshold = conf_thres
self.iou_threshold = iou_thres
self.official_nms = official_nms

# Initialize model
self.initialize_model(path)
Expand All @@ -26,7 +27,7 @@ def initialize_model(self, path):
self.get_input_details()
self.get_output_details()

self.has_postprocess = 'score' in self.output_names
self.has_postprocess = 'score' in self.output_names or self.official_nms


def detect_objects(self, image):
Expand Down Expand Up @@ -64,7 +65,7 @@ def inference(self, input_tensor):
start = time.perf_counter()
outputs = self.session.run(self.output_names, {self.input_names[0]: input_tensor})

# print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
return outputs

def process_output(self, output):
Expand Down Expand Up @@ -101,8 +102,13 @@ def process_output(self, output):

def parse_processed_output(self, outputs):

scores = np.squeeze(outputs[0], axis=1)
predictions = outputs[1]
#Pinto's postprocessing is different from the official nms version
if self.official_nms:
scores = outputs[0][:,-1]
predictions = outputs[0][:, [0,5,1,2,3,4]]
else:
scores = np.squeeze(outputs[0], axis=1)
predictions = outputs[1]
# Filter out object scores below threshold
valid_scores = scores > self.conf_threshold
predictions = predictions[valid_scores, :]
Expand All @@ -114,11 +120,12 @@ def parse_processed_output(self, outputs):
# Extract the boxes and class ids
# TODO: Separate based on batch number
batch_number = predictions[:, 0]
class_ids = predictions[:, 1]
class_ids = predictions[:, 1].astype(int)
boxes = predictions[:, 2:]

# In postprocess, the x,y are the y,x
boxes = boxes[:, [1, 0, 3, 2]]
if not self.official_nms:
boxes = boxes[:, [1, 0, 3, 2]]

# Rescale boxes to original image dimensions
boxes = self.rescale_boxes(boxes)
Expand Down
10 changes: 4 additions & 6 deletions comparison_with_yolov5_v6.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import cv2
import pafy
from cap_from_youtube import cap_from_youtube

from YOLOv7 import YOLOv7
from YOLOv7.utils import draw_comparison
from yolov7 import YOLOv7
from yolov7.utils import draw_comparison

# Initialize video
# cap = cv2.VideoCapture("input.mp4")

videoUrl = 'https://youtu.be/zPre8MgmcHY'
videoPafy = pafy.new(videoUrl)
print(videoPafy.streams)
cap = cv2.VideoCapture(videoPafy.streams[-1].url)
cap = cap_from_youtube(videoUrl)
start_time = 0 # skip first {start_time} seconds
cap.set(cv2.CAP_PROP_POS_FRAMES, start_time * 60)

Expand Down
8 changes: 4 additions & 4 deletions image_object_detection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import cv2
from imread_from_url import imread_from_url

from YOLOv7 import YOLOv7
from yolov7 import YOLOv7

# Initialize YOLOv7 object detector
model_path = "models/yolov7_736x1280.onnx"
yolov7_detector = YOLOv7(model_path, conf_thres=0.3, iou_thres=0.5)
# Initialize yolov7 object detector
model_path = "models/yolov7_post_640x640.onnx"
yolov7_detector = YOLOv7(model_path, conf_thres=0.2, iou_thres=0.3)

# Read image
img_url = "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
Expand Down
2 changes: 1 addition & 1 deletion video_object_detection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import cv2
from cap_from_youtube import cap_from_youtube

from YOLOv7 import YOLOv7
from yolov7 import YOLOv7

# # Initialize video
# cap = cv2.VideoCapture("input.mp4")
Expand Down
2 changes: 1 addition & 1 deletion webcam_object_detection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import cv2

from YOLOv7 import YOLOv7
from yolov7 import YOLOv7

# Initialize the webcam
cap = cv2.VideoCapture(0)
Expand Down

0 comments on commit a1eb903

Please sign in to comment.