16
16
import torch
17
17
import torchvision
18
18
19
# Module-level settings. This change moves INPUT_W / INPUT_H from 608 to 640
# and CONF_THRESH from 0.1 to 0.5; IOU_THRESHOLD is unchanged at 0.4.
# NOTE(review): presumably INPUT_W / INPUT_H must match the input size the
# TensorRT engine was built with -- confirm against the engine file.
- INPUT_W = 608
20
- INPUT_H = 608
21
- CONF_THRESH = 0.1
19
+ INPUT_W = 640
20
+ INPUT_H = 640
21
+ CONF_THRESH = 0.5
22
22
IOU_THRESHOLD = 0.4
23
23
24
24
@@ -66,7 +66,7 @@ class YoLov5TRT(object):
66
66
67
67
def __init__ (self , engine_file_path ):
68
68
# Create a CUDA context on device 0.
69
- self .cfx = cuda .Device (0 ).make_context ()
69
+ self .ctx = cuda .Device (0 ).make_context ()
70
70
stream = cuda .Stream ()
71
71
TRT_LOGGER = trt .Logger (trt .Logger .INFO )
72
72
runtime = trt .Runtime (TRT_LOGGER )
@@ -111,7 +111,7 @@ def __init__(self, engine_file_path):
111
111
def infer (self , input_image_path ):
112
112
threading .Thread .__init__ (self )
113
113
# Make self the active context, pushing it on top of the context stack.
114
- self .cfx .push ()
114
+ self .ctx .push ()
115
115
# Restore the state stored on self (stream, context, ...) into locals.
116
116
stream = self .stream
117
117
context = self .context
@@ -127,6 +127,7 @@ def infer(self, input_image_path):
127
127
)
128
128
# Copy input image to host buffer
129
129
np .copyto (host_inputs [0 ], input_image .ravel ())
130
+ start = time .time ()
130
131
# Transfer input data to the GPU.
131
132
cuda .memcpy_htod_async (cuda_inputs [0 ], host_inputs [0 ], stream )
132
133
# Run inference.
@@ -135,8 +136,9 @@ def infer(self, input_image_path):
135
136
cuda .memcpy_dtoh_async (host_outputs [0 ], cuda_outputs [0 ], stream )
136
137
# Synchronize the stream
137
138
stream .synchronize ()
139
+ end = time .time ()
138
140
# Remove any context from the top of the context stack, deactivating it.
139
- self .cfx .pop ()
141
+ self .ctx .pop ()
140
142
# Here we use only the first row of output, since batch_size = 1
141
143
output = host_outputs [0 ]
142
144
# Do postprocess
@@ -155,12 +157,13 @@ def infer(self, input_image_path):
155
157
)
156
158
parent , filename = os .path .split (input_image_path )
157
159
save_name = os .path .join (parent , "output_" + filename )
158
- # Save image
160
+ # Save image
159
161
cv2 .imwrite (save_name , image_raw )
162
+ print ('{:.2f}ms, saving {}' .format ((end - start ) * 1000 , save_name ))
160
163
161
164
def destroy (self ):
# Deactivate this instance's CUDA context by popping it off the context stack.
# NOTE(review): presumably this balances the context made current by
# make_context() in __init__ -- confirm against the PyCUDA documentation.
162
165
# Remove any context from the top of the context stack, deactivating it.
163
- self .cfx .pop ()
166
+ self .ctx .pop ()
164
167
165
168
def preprocess_image (self , input_image_path ):
166
169
"""
@@ -308,8 +311,7 @@ def run(self):
308
311
# a YoLov5TRT instance
309
312
yolov5_wrapper = YoLov5TRT (engine_file_path )
310
313
311
- # from https://github.com/ultralytics/yolov5/tree/master/inference/images
312
- input_image_paths = ["zidane.jpg" , "bus.jpg" ]
314
+ input_image_paths = ["samples/zidane.jpg" , "samples/bus.jpg" ]
313
315
314
316
for input_image_path in input_image_paths :
315
317
# create a new thread to do inference
0 commit comments