Skip to content

Commit f088ec7

Browse files
committed
fix typo, add latency log
1 parent bf1d851 commit f088ec7

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

yolov5/yolov5_trt.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
import torch
1717
import torchvision
1818

19-
INPUT_W = 608
20-
INPUT_H = 608
21-
CONF_THRESH = 0.1
19+
INPUT_W = 640
20+
INPUT_H = 640
21+
CONF_THRESH = 0.5
2222
IOU_THRESHOLD = 0.4
2323

2424

@@ -66,7 +66,7 @@ class YoLov5TRT(object):
6666

6767
def __init__(self, engine_file_path):
6868
# Create a Context on this device,
69-
self.cfx = cuda.Device(0).make_context()
69+
self.ctx = cuda.Device(0).make_context()
7070
stream = cuda.Stream()
7171
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
7272
runtime = trt.Runtime(TRT_LOGGER)
@@ -111,7 +111,7 @@ def __init__(self, engine_file_path):
111111
def infer(self, input_image_path):
112112
threading.Thread.__init__(self)
113113
# Make self the active context, pushing it on top of the context stack.
114-
self.cfx.push()
114+
self.ctx.push()
115115
# Restore
116116
stream = self.stream
117117
context = self.context
@@ -127,6 +127,7 @@ def infer(self, input_image_path):
127127
)
128128
# Copy input image to host buffer
129129
np.copyto(host_inputs[0], input_image.ravel())
130+
start = time.time()
130131
# Transfer input data to the GPU.
131132
cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
132133
# Run inference.
@@ -135,8 +136,9 @@ def infer(self, input_image_path):
135136
cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
136137
# Synchronize the stream
137138
stream.synchronize()
139+
end = time.time()
138140
# Remove any context from the top of the context stack, deactivating it.
139-
self.cfx.pop()
141+
self.ctx.pop()
140142
# Here we use the first row of output in that batch_size = 1
141143
output = host_outputs[0]
142144
# Do postprocess
@@ -155,12 +157,13 @@ def infer(self, input_image_path):
155157
)
156158
parent, filename = os.path.split(input_image_path)
157159
save_name = os.path.join(parent, "output_" + filename)
158-
#  Save image
160+
# Save image
159161
cv2.imwrite(save_name, image_raw)
162+
print('{:.2f}ms, saving {}'.format((end - start) * 1000, save_name))
160163

161164
def destroy(self):
162165
# Remove any context from the top of the context stack, deactivating it.
163-
self.cfx.pop()
166+
self.ctx.pop()
164167

165168
def preprocess_image(self, input_image_path):
166169
"""
@@ -308,8 +311,7 @@ def run(self):
308311
# a YoLov5TRT instance
309312
yolov5_wrapper = YoLov5TRT(engine_file_path)
310313

311-
# from https://github.com/ultralytics/yolov5/tree/master/inference/images
312-
input_image_paths = ["zidane.jpg", "bus.jpg"]
314+
input_image_paths = ["samples/zidane.jpg", "samples/bus.jpg"]
313315

314316
for input_image_path in input_image_paths:
315317
# create a new thread to do inference

0 commit comments

Comments
 (0)