Skip to content

Commit b58a798

Browse files
authored
add a simple warm_up in yolov5_trt.py (wang-xinyu#513)
* add a simple warm_up in yolov5_trt.py * remove if-else condition in infer() * be careful with bgr and rgb image
1 parent e8653a7 commit b58a798

File tree

1 file changed

+70
-30
lines changed

1 file changed

+70
-30
lines changed

yolov5/yolov5_trt.py

+70-30
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def __init__(self, engine_file_path):
123123
self.bindings = bindings
124124
self.batch_size = engine.max_batch_size
125125

126-
def infer(self, image_path_batch):
126+
def infer(self, raw_image_generator):
127127
threading.Thread.__init__(self)
128128
# Make self the active context, pushing it on top of the context stack.
129129
self.ctx.push()
@@ -141,8 +141,8 @@ def infer(self, image_path_batch):
141141
batch_origin_h = []
142142
batch_origin_w = []
143143
batch_input_image = np.empty(shape=[self.batch_size, 3, self.input_h, self.input_w])
144-
for i, img_path in enumerate(image_path_batch):
145-
input_image, image_raw, origin_h, origin_w = self.preprocess_image(img_path)
144+
for i, image_raw in enumerate(raw_image_generator):
145+
input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw)
146146
batch_image_raw.append(image_raw)
147147
batch_origin_h.append(origin_h)
148148
batch_origin_w.append(origin_w)
@@ -166,7 +166,7 @@ def infer(self, image_path_batch):
166166
# Here we use the first row of output in that batch_size = 1
167167
output = host_outputs[0]
168168
# Do postprocess
169-
for i, img_path in enumerate(image_path_batch):
169+
for i in range(self.batch_size):
170170
result_boxes, result_scores, result_classid = self.post_process(
171171
output[i * 6001: (i + 1) * 6001], batch_origin_h[i], batch_origin_w[i]
172172
)
@@ -180,19 +180,29 @@ def infer(self, image_path_batch):
180180
categories[int(result_classid[j])], result_scores[j]
181181
),
182182
)
183-
parent, filename = os.path.split(img_path)
184-
save_name = os.path.join('output', filename)
185-
# Save image
186-
cv2.imwrite(save_name, batch_image_raw[i])
187-
print('input->{}, time->{:.2f}ms, saving into output/'.format(image_path_batch, (end - start) * 1000))
183+
return batch_image_raw, end - start
188184

189185
def destroy(self):
190186
# Remove any context from the top of the context stack, deactivating it.
191187
self.ctx.pop()
188+
189+
def get_raw_image(self, image_path_batch):
190+
"""
191+
description: Read an image from image path
192+
"""
193+
for img_path in image_path_batch:
194+
yield cv2.imread(img_path)
195+
196+
def get_raw_image_zeros(self, image_path_batch=None):
197+
"""
198+
description: Ready data for warmup
199+
"""
200+
for _ in range(self.batch_size):
201+
yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8)
192202

193-
def preprocess_image(self, input_image_path):
203+
def preprocess_image(self, raw_bgr_image):
194204
"""
195-
description: Read an image from image path, convert it to RGB,
205+
description: Convert BGR image to RGB,
196206
resize and pad it to target size, normalize to [0,1],
197207
transform to NCHW format.
198208
param:
@@ -203,7 +213,7 @@ def preprocess_image(self, input_image_path):
203213
h: original height
204214
w: original width
205215
"""
206-
image_raw = cv2.imread(input_image_path)
216+
image_raw = raw_bgr_image
207217
h, w, c = image_raw.shape
208218
image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
209219
# Calculate widht and height and paddings
@@ -305,22 +315,45 @@ def post_process(self, output, origin_h, origin_w):
305315
return result_boxes, result_scores, result_classid
306316

307317

308-
class myThread(threading.Thread):
309-
def __init__(self, func, args):
318+
class inferThread(threading.Thread):
319+
def __init__(self, yolov5_wrapper, image_path_batch):
310320
threading.Thread.__init__(self)
311-
self.func = func
312-
self.args = args
321+
self.yolov5_wrapper = yolov5_wrapper
322+
self.image_path_batch = image_path_batch
313323

314324
def run(self):
315-
self.func(*self.args)
325+
batch_image_raw, use_time = self.yolov5_wrapper.infer(self.yolov5_wrapper.get_raw_image(self.image_path_batch))
326+
for i, img_path in enumerate(self.image_path_batch):
327+
parent, filename = os.path.split(img_path)
328+
save_name = os.path.join('output', filename)
329+
# Save image
330+
cv2.imwrite(save_name, batch_image_raw[i])
331+
print('input->{}, time->{:.2f}ms, saving into output/'.format(self.image_path_batch, use_time * 1000))
332+
333+
334+
class warmUpThread(threading.Thread):
335+
def __init__(self, yolov5_wrapper):
336+
threading.Thread.__init__(self)
337+
self.yolov5_wrapper = yolov5_wrapper
338+
339+
def run(self):
340+
batch_image_raw, use_time = self.yolov5_wrapper.infer(self.yolov5_wrapper.get_raw_image_zeros())
341+
print('warm_up->{}, time->{:.2f}ms'.format(batch_image_raw[0].shape, use_time * 1000))
342+
316343

317344

318345
if __name__ == "__main__":
319346
# load custom plugins
320347
PLUGIN_LIBRARY = "build/libmyplugins.so"
321-
ctypes.CDLL(PLUGIN_LIBRARY)
322348
engine_file_path = "build/yolov5s.engine"
323349

350+
if len(sys.argv) > 1:
351+
engine_file_path = sys.argv[1]
352+
if len(sys.argv) > 2:
353+
PLUGIN_LIBRARY = sys.argv[2]
354+
355+
ctypes.CDLL(PLUGIN_LIBRARY)
356+
324357
# load coco labels
325358

326359
categories = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
@@ -338,15 +371,22 @@ def run(self):
338371
os.makedirs('output/')
339372
# a YoLov5TRT instance
340373
yolov5_wrapper = YoLov5TRT(engine_file_path)
341-
print('batch size is', yolov5_wrapper.batch_size)
342-
image_dir = "samples/"
343-
image_path_batches = get_img_path_batches(yolov5_wrapper.batch_size, image_dir)
344-
345-
for batch in image_path_batches:
346-
# create a new thread to do inference
347-
thread1 = myThread(yolov5_wrapper.infer, [batch])
348-
thread1.start()
349-
thread1.join()
350-
351-
# destroy the instance
352-
yolov5_wrapper.destroy()
374+
try:
375+
print('batch size is', yolov5_wrapper.batch_size)
376+
377+
image_dir = "samples/"
378+
image_path_batches = get_img_path_batches(yolov5_wrapper.batch_size, image_dir)
379+
380+
for i in range(10):
381+
# create a new thread to do warm_up
382+
thread1 = warmUpThread(yolov5_wrapper)
383+
thread1.start()
384+
thread1.join()
385+
for batch in image_path_batches:
386+
# create a new thread to do inference
387+
thread1 = inferThread(yolov5_wrapper, batch)
388+
thread1.start()
389+
thread1.join()
390+
finally:
391+
# destroy the instance
392+
yolov5_wrapper.destroy()

0 commit comments

Comments
 (0)