Skip to content

Commit e381ae5

Browse files
authored
Retinaface python (wang-xinyu#458)
* Create retinaface_trt.py An example that uses TensorRT's Python api to make inferences. * Update retinaface_trt.py * Update README.md
1 parent a7f43a3 commit e381ae5

File tree

2 files changed

+311
-0
lines changed

2 files changed

+311
-0
lines changed

Diff for: retinaface/README.md

+10
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ sudo ./retina_r50 -d // deserialize model file and run inference.
4444

4545
3. check the images generated, as follows. 0_result.jpg
4646

47+
4. We also provide a Python script that runs inference with the TensorRT engine:
48+
49+
```
50+
// install python-tensorrt, pycuda, etc.
51+
// ensure the retina_r50.engine and libdecodeplugin.so have been built
52+
python retinaface_trt.py
53+
```
54+
55+
56+
4757
# INT8 Quantization
4858

4959
1. Prepare calibration images. You can randomly select on the order of 1000 images from your training set. For WIDER FACE, you can also download my calibration images `widerface_calib` from [GoogleDrive](https://drive.google.com/drive/folders/1s7jE9DtOngZMzJC1uL307J2MiaGwdRSI?usp=sharing) or [BaiduPan](https://pan.baidu.com/s/1GOm_-JobpyLMAqZWCDUhKg) (pwd: a9wh).

Diff for: retinaface/retinaface_trt.py

+301
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
"""
2+
Use TensorRT's Python api to make inferences.
3+
"""
4+
# -*- coding: utf-8 -*-
5+
import ctypes
6+
import os
7+
import random
8+
import sys
9+
import threading
10+
import time
11+
12+
import cv2
13+
import numpy as np
14+
import pycuda.autoinit
15+
import pycuda.driver as cuda
16+
import tensorrt as trt
17+
import torch
18+
import torchvision
19+
20+
# Network input height and width, as defined in decode.h.
INPUT_H = 480
INPUT_W = 640
# Only detections whose score is strictly greater than this are kept.
CONF_THRESH = 0.4
# IoU threshold handed to torchvision.ops.nms.
IOU_THRESHOLD = 0.1
# Print numpy arrays in full instead of eliding the middle.
np.set_printoptions(threshold=np.inf)
25+
26+
def plot_one_box(x, landmark, img, color=None, label=None, line_thickness=None):
    """
    description: Plots one bounding box and its five landmark points on image img.
                 The image is modified in place.

    param:
        x:              a box likes [x1,y1,x2,y2]
        landmark:       ten values [x1,y1,x2,y2,...,x5,y5], one (x, y) pair per point
        img:            a opencv image object
        color:          color to draw rectangle, such as (0,255,0);
                        a random color is picked when omitted
        label:          str, drawn on a filled background above the box when given
        line_thickness: int; derived from the image size when omitted
    return:
        no return

    """
    tl = (
        line_thickness or round(0.001 * (img.shape[0] + img.shape[1]) / 2) + 1
    )  # line/font thickness

    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)

    # One fixed color per landmark point, in the order the points are stored.
    landmark_colors = (
        (0, 0, 255),
        (0, 255, 255),
        (255, 0, 255),
        (0, 255, 0),
        (255, 0, 0),
    )
    for i, point_color in enumerate(landmark_colors):
        center = (int(landmark[2 * i]), int(landmark[2 * i + 1]))
        cv2.circle(img, center, 1, point_color, 4)

    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        # Re-use c2 as the far corner of the label background, sized to the text.
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(
            img,
            label,
            (c1[0], c1[1] - 2),
            0,
            tl / 3,
            [225, 255, 255],
            thickness=tf,
            lineType=cv2.LINE_AA,
        )
69+
70+
71+
class Retinaface_trt(object):
    """
    description: A RetinaFace class that wraps TensorRT ops, preprocess and postprocess ops.
    """

    def __init__(self, engine_file_path):
        """
        description: Create a CUDA context, deserialize the TensorRT engine and
                     allocate host/device buffers for every binding.
        param:
            engine_file_path: str, path of the serialized engine file
        """
        # Create a Context on this device; make_context() also makes it current,
        # so it stays on the context stack until destroy() pops it.
        self.cfx = cuda.Device(0).make_context()
        try:
            stream = cuda.Stream()
            TRT_LOGGER = trt.Logger(trt.Logger.INFO)
            runtime = trt.Runtime(TRT_LOGGER)

            # Deserialize the engine from file
            with open(engine_file_path, "rb") as f:
                engine = runtime.deserialize_cuda_engine(f.read())
            if engine is None:
                raise RuntimeError(
                    "Failed to deserialize engine: {}".format(engine_file_path)
                )
            context = engine.create_execution_context()

            host_inputs = []
            cuda_inputs = []
            host_outputs = []
            cuda_outputs = []
            bindings = []

            for binding in engine:
                size = (
                    trt.volume(engine.get_binding_shape(binding))
                    * engine.max_batch_size
                )
                dtype = trt.nptype(engine.get_binding_dtype(binding))
                # Allocate host and device buffers
                host_mem = cuda.pagelocked_empty(size, dtype)
                cuda_mem = cuda.mem_alloc(host_mem.nbytes)
                # Append the device buffer to device bindings.
                bindings.append(int(cuda_mem))
                # Append to the appropriate list.
                if engine.binding_is_input(binding):
                    host_inputs.append(host_mem)
                    cuda_inputs.append(cuda_mem)
                else:
                    host_outputs.append(host_mem)
                    cuda_outputs.append(cuda_mem)
        except BaseException:
            # Do not leave the freshly created context on the stack when
            # construction fails; nobody would be able to call destroy().
            self.cfx.pop()
            raise

        # Store
        self.stream = stream
        self.context = context
        self.engine = engine
        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings

    def infer(self, input_image_path):
        """
        description: Run detection on one image, print the elapsed time of
                     inference plus postprocess, draw the results and save the
                     annotated image next to the input as "output_<filename>".
        param:
            input_image_path: str, image path
        return:
            no return
        """
        # Do image preprocess (CPU only, so it needs no active CUDA context)
        input_image, image_raw, origin_h, origin_w = self.preprocess_image(
            input_image_path
        )

        # Make self the active context, pushing it on top of the context stack.
        self.cfx.push()
        try:
            a = time.time()
            # Copy input image to host buffer
            np.copyto(self.host_inputs[0], input_image.ravel())
            # Transfer input data to the GPU.
            cuda.memcpy_htod_async(
                self.cuda_inputs[0], self.host_inputs[0], self.stream
            )
            # Run inference.
            self.context.execute_async(
                bindings=self.bindings, stream_handle=self.stream.handle
            )
            # Transfer predictions back from the GPU.
            cuda.memcpy_dtoh_async(
                self.host_outputs[0], self.cuda_outputs[0], self.stream
            )
            # Synchronize the stream
            self.stream.synchronize()
        finally:
            # Remove the context from the top of the context stack even when
            # inference raised, so the stack stays balanced.
            self.cfx.pop()

        # Here we use the first row of output in that batch_size = 1
        output = self.host_outputs[0]

        # Do postprocess
        result_boxes, result_scores, result_landmark = self.post_process(
            output, origin_h, origin_w
        )
        b = time.time() - a
        print(b)

        # Draw rectangles, landmarks and labels on the original image
        for box, score, landmark in zip(
            result_boxes, result_scores, result_landmark
        ):
            plot_one_box(
                box,
                landmark,
                image_raw,
                label="{}:{:.2f}".format("Face", score),
            )

        # Save image
        parent, filename = os.path.split(input_image_path)
        save_name = os.path.join(parent, "output_" + filename)
        cv2.imwrite(save_name, image_raw)

    def destroy(self):
        """
        description: Pop the context that __init__ left on the context stack.
        """
        # Remove any context from the top of the context stack, deactivating it.
        self.cfx.pop()

    def preprocess_image(self, input_image_path):
        """
        description: Read an image from image path, convert it to RGB,
                     resize it to (INPUT_W, INPUT_H) without padding,
                     subtract the per-channel mean (104, 117, 123),
                     transform to NCHW format.
        param:
            input_image_path: str, image path
        return:
            image:  the processed image, float32, shape (1, 3, INPUT_H, INPUT_W)
            image_raw: the original image
            h: original height
            w: original width
        raise:
            ValueError: the image could not be read
        """
        image_raw = cv2.imread(input_image_path)
        if image_raw is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError("Failed to read image: {}".format(input_image_path))
        h, w = image_raw.shape[:2]
        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (INPUT_W, INPUT_H))

        image = image.astype(np.float32)

        # Subtract the per-channel mean.
        # NOTE(review): the means are applied after the BGR->RGB swap; confirm
        # this channel order matches what the engine expects.
        image -= (104, 117, 123)
        # HWC to CHW format:
        image = np.transpose(image, [2, 0, 1])
        # CHW to NCHW format
        image = np.expand_dims(image, axis=0)
        # Convert the image to row-major order, also known as "C order":
        image = np.ascontiguousarray(image)
        return image, image_raw, h, w

    def xywh2xyxy(self, origin_h, origin_w, x, landmark):
        """
        description: Rescale boxes and landmarks from network-input coordinates
                     to original-image coordinates. Despite the name, the box
                     layout is not converted; only the scale changes.
        param:
            origin_h: height of original image
            origin_w: width of original image
            x:        boxes, a tensor or ndarray, each row is [x1, y1, x2, y2]
            landmark: landmarks, each row is [x1, y1, ..., x5, y5];
                      rescaled in place
        return:
            y:        the rescaled boxes (a new tensor/ndarray like x)
            landmark: the same object that was passed in, rescaled
        """
        y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)

        r_w = INPUT_W / origin_w
        r_h = INPUT_H / origin_h

        # Even columns hold x coordinates, odd columns hold y coordinates.
        y[:, 0:4:2] = x[:, 0:4:2] / r_w
        y[:, 1:4:2] = x[:, 1:4:2] / r_h

        landmark[:, 0:10:2] = landmark[:, 0:10:2] / r_w
        landmark[:, 1:10:2] = landmark[:, 1:10:2] / r_h

        return y, landmark

    def post_process(self, output, origin_h, origin_w):
        """
        description: postprocess the prediction
        param:
            output:     A flat array likes [num_boxes,x1,y1,x2,y2,conf,landmark_x1,landmark_y1,
                        landmark_x2,landmark_y2,...], 15 values per box after num_boxes
            origin_h:   height of original image
            origin_w:   width of original image
        return:
            result_boxes: finally boxes, a boxes tensor, each row is a box [x1, y1, x2, y2]
            result_scores: finally scores, a tensor, each element is the score corresponding to box
            result_landmark: finally landmarks, a tensor, each row is the ten
                             landmark coordinates corresponding to box
        """
        # Get the num of boxes detected
        num = int(output[0])
        # Reshape to a two dimensional ndarray
        pred = np.reshape(output[1:], (-1, 15))[:num, :]
        # to torch Tensor
        pred = torch.Tensor(pred).cuda()
        # Choose those boxes that score > CONF_THRESH
        si = pred[:, 4] > CONF_THRESH
        boxes = pred[si, :4]
        scores = pred[si, 4]
        landmark = pred[si, 5:15]

        # Rescale boxes and landmark to the original image
        boxes, landmark = self.xywh2xyxy(origin_h, origin_w, boxes, landmark)
        # Do nms; keep the indices on the same device as the tensors they index
        indices = torchvision.ops.nms(boxes, scores, iou_threshold=IOU_THRESHOLD)
        result_boxes = boxes[indices, :].cpu()
        result_scores = scores[indices].cpu()
        result_landmark = landmark[indices].cpu()
        return result_boxes, result_scores, result_landmark
275+
276+
class myThread(threading.Thread):
    """
    description: A thread that calls func(*args) when it is started.
    """

    def __init__(self, func, args):
        """
        param:
            func: callable to run in the thread
            args: iterable of positional arguments unpacked into func
        """
        super(myThread, self).__init__()
        self.func = func
        self.args = args

    def run(self):
        # Invoked by start() in the new thread; the return value is discarded.
        self.func(*self.args)
284+
285+
if __name__ == "__main__":
    # Load the custom plugin library before the engine is deserialized;
    # make sure it has been built already.
    PLUGIN_LIBRARY = "build/libdecodeplugin.so"
    ctypes.CDLL(PLUGIN_LIBRARY)
    engine_file_path = "build/retina_r50.engine"

    retinaface = Retinaface_trt(engine_file_path)
    try:
        input_image_paths = ["zidane.jpg"]
        for _ in range(10):
            for input_image_path in input_image_paths:
                # create a new thread to do inference
                thread = myThread(retinaface.infer, [input_image_path])
                thread.start()
                thread.join()
    finally:
        # destroy the instance even if something above raised, so the CUDA
        # context created by the constructor is always popped
        retinaface.destroy()

0 commit comments

Comments
 (0)