diff --git a/tools/demo.py b/tools/demo.py
index 5799e8a..4c8b8c9 100644
--- a/tools/demo.py
+++ b/tools/demo.py
@@ -6,6 +6,7 @@
 import cv2
 import numpy as np
 import torch
+from tqdm import tqdm
 from loguru import logger
 from PIL import Image
@@ -23,7 +24,6 @@ class Infer():
 
     def __init__(self, config, infer_size=[640,640], device='cuda', output_dir='./', ckpt=None, end2end=False):
-
         self.ckpt_path = ckpt
         suffix = ckpt.split('.')[-1]
         if suffix == 'onnx':
@@ -66,9 +66,7 @@ def _pad_image(self, img, target_size):
         return ImageList(pad_imgs, img_sizes, pad_sizes)
 
 
-
     def _build_engine(self, config, engine_type):
-
         print(f'Inference with {engine_type} engine!')
         if engine_type == 'torch':
             model = build_local_model(config, self.device)
@@ -88,7 +86,6 @@ def _build_engine(self, config, engine_type):
         return model
 
     def build_tensorRT_engine(self, trt_path):
-
         import tensorrt as trt
         from cuda import cuda
         loggert = trt.Logger(trt.Logger.INFO)
@@ -145,11 +142,7 @@ def predict(batch): # result gets copied into output
         return predict
 
-
-
-
     def build_onnx_engine(self, onnx_path):
-
         import onnxruntime
 
         session = onnxruntime.InferenceSession(onnx_path)
 
@@ -166,7 +159,6 @@ def build_onnx_engine(self, onnx_path):
 
 
     def preprocess(self, origin_img):
-
         img = transform_img(origin_img, 0,
                             **self.config.test.augment.transform,
                             infer_size=self.infer_size)
@@ -178,10 +170,8 @@ def preprocess(self, origin_img):
         return img, (ow, oh)
 
     def postprocess(self, preds, image, origin_shape=None):
-
         if self.engine_type == 'torch':
             output = preds
-
         elif self.engine_type == 'onnx':
             scores = torch.Tensor(preds[0])
             bboxes = torch.Tensor(preds[1])
@@ -225,9 +215,7 @@ def postprocess(self, preds, image, origin_shape=None):
 
         return bboxes, scores, cls_inds
 
-
     def forward(self, origin_image):
-
         image, origin_shape = self.preprocess(origin_image)
 
         if self.engine_type == 'torch':
@@ -301,7 +289,10 @@ def make_parser():
                         default=True,
                         type=bool,
                         help='whether save visualization results')
-
+    parser.add_argument('--vid_stride',
+                        default=1,
+                        type=int,
+                        help='video frame-rate stride')
     return parser
@@ -334,22 +325,24 @@ def main():
         vid_writer = cv2.VideoWriter(
             save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps,
             (int(width), int(height)))
-        while True:
-            ret_val, frame = cap.read()
-            if ret_val:
-                bboxes, scores, cls_inds = infer_engine.forward(frame)
-                result_frame = infer_engine.visualize(frame, bboxes, scores, cls_inds, conf=args.conf, save_result=False)
-                if args.save_result:
-                    vid_writer.write(result_frame)
+        for i in tqdm(range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))):
+            # ret_val, frame = cap.read()
+            cap.grab()  # .read() = .grab() followed by .retrieve()
+            if i % args.vid_stride == 0:
+                ret_val, frame = cap.retrieve()
+                if ret_val:
+                    bboxes, scores, cls_inds = infer_engine.forward(frame)
+                    result_frame = infer_engine.visualize(frame, bboxes, scores, cls_inds, conf=args.conf, save_result=False)
+                    if args.save_result:
+                        vid_writer.write(result_frame)
+                    else:
+                        cv2.namedWindow("DAMO-YOLO", cv2.WINDOW_NORMAL)
+                        cv2.imshow("DAMO-YOLO", result_frame)
+                    ch = cv2.waitKey(1)
+                    if ch == 27 or ch == ord("q") or ch == ord("Q"):
+                        break
                 else:
-                    cv2.namedWindow("DAMO-YOLO", cv2.WINDOW_NORMAL)
-                    cv2.imshow("DAMO-YOLO", result_frame)
-                ch = cv2.waitKey(1)
-                if ch == 27 or ch == ord("q") or ch == ord("Q"):
                     break
-            else:
-                break
-
 
 
 if __name__ == '__main__':
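
Below is a minimal standalone sketch of the grab()/retrieve() striding pattern this patch introduces, for trying it outside tools/demo.py. It assumes only OpenCV and tqdm; the function name read_strided and its arguments are illustrative and not part of the patch.

import cv2
from tqdm import tqdm

def read_strided(video_path, vid_stride=1):
    """Yield every vid_stride-th frame of a video.

    cap.read() is equivalent to cap.grab() followed by cap.retrieve().
    grab() only advances the stream position, so skipped frames never
    pay the (comparatively expensive) decode cost of retrieve().
    """
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    for i in tqdm(range(total)):
        cap.grab()                   # advance one frame, no decode
        if i % vid_stride != 0:
            continue                 # strided-out frame: skip decoding
        ret, frame = cap.retrieve()  # decode the grabbed frame
        if not ret:
            break
        yield frame
    cap.release()

# Example: run inference on every 2nd frame only.
# for frame in read_strided('input.mp4', vid_stride=2):
#     bboxes, scores, cls_inds = infer_engine.forward(frame)

With vid_stride=1 this behaves exactly like a plain read() loop, which is why the new flag defaults to 1 and leaves existing behavior unchanged.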