Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[func add] add vid_stride argument #118

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 21 additions & 28 deletions tools/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import cv2
import numpy as np
import torch
from tqdm import tqdm
from loguru import logger
from PIL import Image

Expand All @@ -23,7 +24,6 @@

class Infer():
def __init__(self, config, infer_size=[640,640], device='cuda', output_dir='./', ckpt=None, end2end=False):

self.ckpt_path = ckpt
suffix = ckpt.split('.')[-1]
if suffix == 'onnx':
Expand Down Expand Up @@ -66,9 +66,7 @@ def _pad_image(self, img, target_size):

return ImageList(pad_imgs, img_sizes, pad_sizes)


def _build_engine(self, config, engine_type):

print(f'Inference with {engine_type} engine!')
if engine_type == 'torch':
model = build_local_model(config, self.device)
Expand All @@ -88,7 +86,6 @@ def _build_engine(self, config, engine_type):
return model

def build_tensorRT_engine(self, trt_path):

import tensorrt as trt
from cuda import cuda
loggert = trt.Logger(trt.Logger.INFO)
Expand Down Expand Up @@ -145,11 +142,7 @@ def predict(batch): # result gets copied into output

return predict




def build_onnx_engine(self, onnx_path):

import onnxruntime

session = onnxruntime.InferenceSession(onnx_path)
Expand All @@ -166,7 +159,6 @@ def build_onnx_engine(self, onnx_path):


def preprocess(self, origin_img):

img = transform_img(origin_img, 0,
**self.config.test.augment.transform,
infer_size=self.infer_size)
Expand All @@ -178,10 +170,8 @@ def preprocess(self, origin_img):
return img, (ow, oh)

def postprocess(self, preds, image, origin_shape=None):

if self.engine_type == 'torch':
output = preds

elif self.engine_type == 'onnx':
scores = torch.Tensor(preds[0])
bboxes = torch.Tensor(preds[1])
Expand Down Expand Up @@ -225,9 +215,7 @@ def postprocess(self, preds, image, origin_shape=None):

return bboxes, scores, cls_inds


def forward(self, origin_image):

image, origin_shape = self.preprocess(origin_image)

if self.engine_type == 'torch':
Expand Down Expand Up @@ -301,7 +289,10 @@ def make_parser():
default=True,
type=bool,
help='whether save visualization results')

parser.add_argument('--vid_stride',
default=1,
type=int,
help='video frame-rate stride')

return parser

Expand Down Expand Up @@ -334,22 +325,24 @@ def main():
vid_writer = cv2.VideoWriter(
save_path, cv2.VideoWriter_fourcc(*"mp4v"),
fps, (int(width), int(height)))
while True:
ret_val, frame = cap.read()
if ret_val:
bboxes, scores, cls_inds = infer_engine.forward(frame)
result_frame = infer_engine.visualize(frame, bboxes, scores, cls_inds, conf=args.conf, save_result=False)
if args.save_result:
vid_writer.write(result_frame)
for i in tqdm(range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))):
# ret_val, frame = cap.read()
cap.grab() # .read() = .grab() followed by .retrieve()
if i % args.vid_stride == 0:
ret_val, frame = cap.retrieve()
if ret_val:
bboxes, scores, cls_inds = infer_engine.forward(frame)
result_frame = infer_engine.visualize(frame, bboxes, scores, cls_inds, conf=args.conf, save_result=False)
if args.save_result:
vid_writer.write(result_frame)
else:
cv2.namedWindow("DAMO-YOLO", cv2.WINDOW_NORMAL)
cv2.imshow("DAMO-YOLO", result_frame)
ch = cv2.waitKey(1)
if ch == 27 or ch == ord("q") or ch == ord("Q"):
break
else:
cv2.namedWindow("DAMO-YOLO", cv2.WINDOW_NORMAL)
cv2.imshow("DAMO-YOLO", result_frame)
ch = cv2.waitKey(1)
if ch == 27 or ch == ord("q") or ch == ord("Q"):
break
else:
break



if __name__ == '__main__':
Expand Down