@@ -123,7 +123,7 @@ def __init__(self, engine_file_path):
123
123
self .bindings = bindings
124
124
self .batch_size = engine .max_batch_size
125
125
126
- def infer (self , image_path_batch ):
126
+ def infer (self , raw_image_generator ):
127
127
threading .Thread .__init__ (self )
128
128
# Make self the active context, pushing it on top of the context stack.
129
129
self .ctx .push ()
@@ -141,8 +141,8 @@ def infer(self, image_path_batch):
141
141
batch_origin_h = []
142
142
batch_origin_w = []
143
143
batch_input_image = np .empty (shape = [self .batch_size , 3 , self .input_h , self .input_w ])
144
- for i , img_path in enumerate (image_path_batch ):
145
- input_image , image_raw , origin_h , origin_w = self .preprocess_image (img_path )
144
+ for i , image_raw in enumerate (raw_image_generator ):
145
+ input_image , image_raw , origin_h , origin_w = self .preprocess_image (image_raw )
146
146
batch_image_raw .append (image_raw )
147
147
batch_origin_h .append (origin_h )
148
148
batch_origin_w .append (origin_w )
@@ -166,7 +166,7 @@ def infer(self, image_path_batch):
166
166
# Here we use the first row of output in that batch_size = 1
167
167
output = host_outputs [0 ]
168
168
# Do postprocess
169
- for i , img_path in enumerate ( image_path_batch ):
169
+ for i in range ( self . batch_size ):
170
170
result_boxes , result_scores , result_classid = self .post_process (
171
171
output [i * 6001 : (i + 1 ) * 6001 ], batch_origin_h [i ], batch_origin_w [i ]
172
172
)
@@ -180,19 +180,29 @@ def infer(self, image_path_batch):
180
180
categories [int (result_classid [j ])], result_scores [j ]
181
181
),
182
182
)
183
- parent , filename = os .path .split (img_path )
184
- save_name = os .path .join ('output' , filename )
185
- # Save image
186
- cv2 .imwrite (save_name , batch_image_raw [i ])
187
- print ('input->{}, time->{:.2f}ms, saving into output/' .format (image_path_batch , (end - start ) * 1000 ))
183
+ return batch_image_raw , end - start
188
184
189
185
def destroy (self ):
190
186
# Remove any context from the top of the context stack, deactivating it.
191
187
self .ctx .pop ()
188
+
189
+ def get_raw_image (self , image_path_batch ):
190
+ """
191
+ description: Read an image from image path
192
+ """
193
+ for img_path in image_path_batch :
194
+ yield cv2 .imread (img_path )
195
+
196
+ def get_raw_image_zeros (self , image_path_batch = None ):
197
+ """
198
+ description: Ready data for warmup
199
+ """
200
+ for _ in range (self .batch_size ):
201
+ yield np .zeros ([self .input_h , self .input_w , 3 ], dtype = np .uint8 )
192
202
193
- def preprocess_image (self , input_image_path ):
203
+ def preprocess_image (self , raw_bgr_image ):
194
204
"""
195
- description: Read an image from image path, convert it to RGB,
205
+ description: Convert BGR image to RGB,
196
206
resize and pad it to target size, normalize to [0,1],
197
207
transform to NCHW format.
198
208
param:
@@ -203,7 +213,7 @@ def preprocess_image(self, input_image_path):
203
213
h: original height
204
214
w: original width
205
215
"""
206
- image_raw = cv2 . imread ( input_image_path )
216
+ image_raw = raw_bgr_image
207
217
h , w , c = image_raw .shape
208
218
image = cv2 .cvtColor (image_raw , cv2 .COLOR_BGR2RGB )
209
219
# Calculate widht and height and paddings
@@ -305,22 +315,45 @@ def post_process(self, output, origin_h, origin_w):
305
315
return result_boxes , result_scores , result_classid
306
316
307
317
308
- class myThread (threading .Thread ):
309
- def __init__ (self , func , args ):
318
+ class inferThread (threading .Thread ):
319
+ def __init__ (self , yolov5_wrapper , image_path_batch ):
310
320
threading .Thread .__init__ (self )
311
- self .func = func
312
- self .args = args
321
+ self .yolov5_wrapper = yolov5_wrapper
322
+ self .image_path_batch = image_path_batch
313
323
314
324
def run (self ):
315
- self .func (* self .args )
325
+ batch_image_raw , use_time = self .yolov5_wrapper .infer (self .yolov5_wrapper .get_raw_image (self .image_path_batch ))
326
+ for i , img_path in enumerate (self .image_path_batch ):
327
+ parent , filename = os .path .split (img_path )
328
+ save_name = os .path .join ('output' , filename )
329
+ # Save image
330
+ cv2 .imwrite (save_name , batch_image_raw [i ])
331
+ print ('input->{}, time->{:.2f}ms, saving into output/' .format (self .image_path_batch , use_time * 1000 ))
332
+
333
+
334
+ class warmUpThread (threading .Thread ):
335
+ def __init__ (self , yolov5_wrapper ):
336
+ threading .Thread .__init__ (self )
337
+ self .yolov5_wrapper = yolov5_wrapper
338
+
339
+ def run (self ):
340
+ batch_image_raw , use_time = self .yolov5_wrapper .infer (self .yolov5_wrapper .get_raw_image_zeros ())
341
+ print ('warm_up->{}, time->{:.2f}ms' .format (batch_image_raw [0 ].shape , use_time * 1000 ))
342
+
316
343
317
344
318
345
if __name__ == "__main__" :
319
346
# load custom plugins
320
347
PLUGIN_LIBRARY = "build/libmyplugins.so"
321
- ctypes .CDLL (PLUGIN_LIBRARY )
322
348
engine_file_path = "build/yolov5s.engine"
323
349
350
+ if len (sys .argv ) > 1 :
351
+ engine_file_path = sys .argv [1 ]
352
+ if len (sys .argv ) > 2 :
353
+ PLUGIN_LIBRARY = sys .argv [2 ]
354
+
355
+ ctypes .CDLL (PLUGIN_LIBRARY )
356
+
324
357
# load coco labels
325
358
326
359
categories = ["person" , "bicycle" , "car" , "motorcycle" , "airplane" , "bus" , "train" , "truck" , "boat" , "traffic light" ,
@@ -338,15 +371,22 @@ def run(self):
338
371
os .makedirs ('output/' )
339
372
# a YoLov5TRT instance
340
373
yolov5_wrapper = YoLov5TRT (engine_file_path )
341
- print ('batch size is' , yolov5_wrapper .batch_size )
342
- image_dir = "samples/"
343
- image_path_batches = get_img_path_batches (yolov5_wrapper .batch_size , image_dir )
344
-
345
- for batch in image_path_batches :
346
- # create a new thread to do inference
347
- thread1 = myThread (yolov5_wrapper .infer , [batch ])
348
- thread1 .start ()
349
- thread1 .join ()
350
-
351
- # destroy the instance
352
- yolov5_wrapper .destroy ()
374
+ try :
375
+ print ('batch size is' , yolov5_wrapper .batch_size )
376
+
377
+ image_dir = "samples/"
378
+ image_path_batches = get_img_path_batches (yolov5_wrapper .batch_size , image_dir )
379
+
380
+ for i in range (10 ):
381
+ # create a new thread to do warm_up
382
+ thread1 = warmUpThread (yolov5_wrapper )
383
+ thread1 .start ()
384
+ thread1 .join ()
385
+ for batch in image_path_batches :
386
+ # create a new thread to do inference
387
+ thread1 = inferThread (yolov5_wrapper , batch )
388
+ thread1 .start ()
389
+ thread1 .join ()
390
+ finally :
391
+ # destroy the instance
392
+ yolov5_wrapper .destroy ()
0 commit comments