-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrt_util.py
277 lines (238 loc) · 11 KB
/
trt_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
from collections import OrderedDict
from copy import copy
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import os
import math
from polygraphy import cuda
import random
from scipy import integrate
import tensorrt as trt
import torch
import requests
from io import BytesIO
from cuda import cudart
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import CreateConfig, Profile
from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine
from polygraphy.backend.trt import util as trt_util
from polygraphy import cuda
import ctypes
# Module-level TensorRT logger created with ERROR severity. It is passed to
# the plugin initialisation call below and to trt.Runtime in Engine.load2.
logger = trt.Logger(trt.Logger.ERROR)
# Initialise TensorRT's plugin library at import time; the second argument is
# an empty string (NOTE(review): presumably the default plugin namespace -
# confirm against the TensorRT docs).
trt.init_libnvinfer_plugins(logger, '')
# layernorm = ctypes.CDLL("./layerNormPlugin/layerNormKernel.so")
# layernorm = ctypes.CDLL("./LayerNormPlugin/LayerNormPlugin.so")
# Load the GroupNorm shared library at import time. The path is relative
# ("./..."), so it depends on the process's current working directory, and
# importing this module fails if the file cannot be loaded.
# NOTE(review): the `groupnorm` handle is not referenced again in the visible
# code; presumably loading the library is what makes a custom GroupNorm plugin
# available to TensorRT - TODO confirm.
groupnorm = ctypes.CDLL("./groupNormPlugin/groupNormKernel.so")
# Second ERROR-severity logger; used by Engine.refit when constructing the
# trt.Refitter.
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
# Map of numpy dtype -> torch dtype
numpy_to_torch_dtype_dict = {
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}

# The boolean key depends on the installed NumPy: from 1.24 on the entry is
# keyed by ``np.bool_``, on older releases by the legacy ``np.bool`` alias.
# Compare (major, minor) as integers: comparing the raw version strings is
# lexicographic and misorders multi-digit components (e.g. "1.9" vs "1.24").
_numpy_major_minor = tuple(
    int(part) for part in np.version.full_version.split(".")[:2]
)
if _numpy_major_minor >= (1, 24):
    numpy_to_torch_dtype_dict[np.bool_] = torch.bool
else:
    numpy_to_torch_dtype_dict[np.bool] = torch.bool

# Map of torch dtype -> numpy dtype
torch_to_numpy_dtype_dict = {
    torch_dtype: numpy_dtype
    for (numpy_dtype, torch_dtype) in numpy_to_torch_dtype_dict.items()
}
def CUASSERT(cuda_ret):
    """Check a cuda-python return tuple and unwrap its payload.

    ``cuda_ret`` is indexed as a sequence whose first element is the status
    code. Raises ``RuntimeError`` when the status is not ``cudaSuccess``;
    otherwise returns the second element if there is one, else ``None``.
    """
    status = cuda_ret[0]
    succeeded = status == cudart.cudaError_t.cudaSuccess
    if not succeeded:
        raise RuntimeError(f"CUDA ERROR: {status}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t")
    # Calls that produce a value put it right after the status code.
    return cuda_ret[1] if len(cuda_ret) > 1 else None
class Engine():
    """Wrapper around a single serialized TensorRT engine file.

    Stores the engine file path together with the deserialized engine, its
    execution context and the torch tensors bound to the engine's
    inputs/outputs. Judging by the attributes each method reads, the expected
    call order is ``build`` (optional) -> ``load``/``load2`` -> ``activate``
    -> ``allocate_buffers`` -> ``infer``.
    """

    def __init__(
        self,
        engine_path,
    ):
        """Record ``engine_path``; nothing is loaded or allocated here."""
        self.engine_path = engine_path
        self.engine = None
        self.context = None
        self.buffers = OrderedDict()
        self.tensors = OrderedDict()
        self.graph = None  # captured CUDA graph, set by infer()
        self.cuda_graph_instance = None  # cuda graph

    def __del__(self):
        """Free polygraphy device buffers and drop the engine-related attributes."""
        for buf in self.buffers.values():
            if isinstance(buf, cuda.DeviceArray):
                buf.free()
        del self.engine
        del self.context
        del self.buffers
        del self.tensors

    def refit(self, onnx_path, onnx_refit_path):
        """Refit the loaded engine with weights taken from ``onnx_refit_path``.

        Nodes of the original model (``onnx_path``) and of the refit model are
        paired by their position after a topological sort, and must have the
        same op at each position. Weight names found in the refit model are
        translated to the names used by the original model before being handed
        to the TensorRT refitter.

        Args:
            onnx_path: Path of the ONNX model the engine was built from.
            onnx_refit_path: Path of the ONNX model providing the new weights.

        Raises:
            SystemExit: With exit status 1 when ``refit_cuda_engine`` reports
                failure (after printing "Failed to refit!").
        """
        def convert_int64(arr):
            # TODO: smarter conversion
            if len(arr.shape) == 0:
                return np.int32(arr)
            return arr

        def add_to_map(refit_dict, name, values):
            # Only names the refitter reported as refittable are filled in.
            if name in refit_dict:
                assert refit_dict[name] is None
                if values.dtype == np.int64:
                    values = convert_int64(values)
                refit_dict[name] = values

        def weights_key(layer_name, role):
            # For specialized roles, use a unique name in the map.
            if role == trt.WeightsRole.KERNEL:
                return layer_name+"_TRTKERNEL"
            if role == trt.WeightsRole.BIAS:
                return layer_name+"_TRTBIAS"
            return layer_name

        def has_constant_input(node, index):
            # A Conv node without bias has only two inputs, so guard the
            # index before looking at the tensor type.
            return len(node.inputs) > index and isinstance(node.inputs[index], gs.Constant)

        print(f"Refitting TensorRT engine with {onnx_refit_path} weights")
        refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes

        # Construct mapping from weight names in refit model -> original model
        name_map = {}
        for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes):
            refit_node = refit_nodes[n]
            assert node.op == refit_node.op
            # Constant nodes in ONNX do not have inputs but have a constant output
            if node.op == "Constant":
                name_map[refit_node.outputs[0].name] = node.outputs[0].name
            # Handle scale and bias weights
            elif node.op == "Conv":
                if has_constant_input(node, 1):
                    name_map[refit_node.name+"_TRTKERNEL"] = node.name+"_TRTKERNEL"
                if has_constant_input(node, 2):
                    name_map[refit_node.name+"_TRTBIAS"] = node.name+"_TRTBIAS"
            # For all other nodes: find node inputs that are initializers (gs.Constant)
            else:
                for i, inp in enumerate(node.inputs):
                    if isinstance(inp, gs.Constant):
                        name_map[refit_node.inputs[i].name] = inp.name

        def map_name(name):
            # Names without a mapping are used unchanged.
            return name_map.get(name, name)

        # Construct refit dictionary: one (initially empty) slot per
        # refittable weight reported by the refitter.
        refit_dict = {}
        refitter = trt.Refitter(self.engine, TRT_LOGGER)
        all_weights = refitter.get_all()
        for layer_name, role in zip(all_weights[0], all_weights[1]):
            name = weights_key(layer_name, role)
            assert name not in refit_dict, "Found duplicate layer: " + name
            refit_dict[name] = None

        # Fill the slots with the weights found in the refit model.
        for n in refit_nodes:
            # Constant nodes in ONNX do not have inputs but have a constant output
            if n.op == "Constant":
                name = map_name(n.outputs[0].name)
                print(f"Add Constant {name}\n")
                add_to_map(refit_dict, name, n.outputs[0].values)
            # Handle scale and bias weights
            elif n.op == "Conv":
                if has_constant_input(n, 1):
                    name = map_name(n.name+"_TRTKERNEL")
                    add_to_map(refit_dict, name, n.inputs[1].values)
                if has_constant_input(n, 2):
                    name = map_name(n.name+"_TRTBIAS")
                    add_to_map(refit_dict, name, n.inputs[2].values)
            # For all other nodes: find node inputs that are initializers (AKA gs.Constant)
            else:
                for inp in n.inputs:
                    if isinstance(inp, gs.Constant):
                        add_to_map(refit_dict, map_name(inp.name), inp.values)

        # Hand every collected weight to the refitter.
        for layer_name, weights_role in zip(all_weights[0], all_weights[1]):
            custom_name = weights_key(layer_name, weights_role)

            # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model
            if layer_name.startswith("onnx::Trilu"):
                continue

            if refit_dict[custom_name] is not None:
                refitter.set_weights(layer_name, weights_role, refit_dict[custom_name])
            else:
                print(f"[W] No refit weights for layer: {layer_name}")

        if not refitter.refit_cuda_engine():
            print("Failed to refit!")
            # Exit with a non-zero status so the failure is visible to the
            # caller's shell; `exit(0)` would have signalled success.
            raise SystemExit(1)

    def build(self, onnx_path, fp16, input_profile=None, enable_refit=False, enable_preview=False, enable_all_tactics=False, timing_cache=None, workspace_size=0):
        """Build an engine from ``onnx_path`` and save it to ``self.engine_path``.

        Args:
            onnx_path: Path of the ONNX model to parse.
            fp16: Forwarded to ``CreateConfig(fp16=...)``.
            input_profile: Optional mapping ``input name -> (min, opt, max)``
                shapes added to the single optimization profile.
            enable_refit: Forwarded to ``CreateConfig(refittable=...)``.
            enable_preview: When true, additionally enables the
                ``FASTER_DYNAMIC_SHAPES_0805`` preview feature.
            enable_all_tactics: When false, ``tactic_sources`` is set to an
                empty list.
            timing_cache: Path used both to load and to save the timing cache.
            workspace_size: When > 0, used as the WORKSPACE memory pool limit.
        """
        print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
        p = Profile()
        if input_profile:
            for name, dims in input_profile.items():
                assert len(dims) == 3
                p.add(name, min=dims[0], opt=dims[1], max=dims[2])

        config_kwargs = {
            'preview_features': [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805],
        }
        if enable_preview:
            # Faster dynamic shapes made optional since it increases engine build time.
            config_kwargs['preview_features'].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
        if workspace_size > 0:
            config_kwargs['memory_pool_limits'] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
        if not enable_all_tactics:
            config_kwargs['tactic_sources'] = []

        engine = engine_from_network(
            network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
            config=CreateConfig(fp16=fp16,
                refittable=enable_refit,
                profiles=[p],
                load_timing_cache=timing_cache,
                **config_kwargs
            ),
            save_timing_cache=timing_cache
        )
        save_engine(engine, path=self.engine_path)

    def load(self):
        """Deserialize ``self.engine_path`` into ``self.engine`` via polygraphy."""
        print(f"Loading TensorRT engine: {self.engine_path}")
        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))

    def load2(self):
        """Deserialize ``self.engine_path`` into ``self.engine`` via ``trt.Runtime``."""
        with open(self.engine_path, 'rb') as planfile:
            self.engine = trt.Runtime(logger).deserialize_cuda_engine(planfile.read())

    def activate(self, reuse_device_memory=None):
        """Create ``self.context`` from the loaded engine.

        When ``reuse_device_memory`` is truthy, the context is created without
        its own device memory and ``reuse_device_memory`` is assigned to
        ``context.device_memory``; otherwise a regular context is created.
        """
        if reuse_device_memory:
            self.context = self.engine.create_execution_context_without_device_memory()
            self.context.device_memory = reuse_device_memory
        else:
            self.context = self.engine.create_execution_context()

    def allocate_buffers(self, shape_dict=None, device='cuda'):
        """Allocate one torch tensor per engine binding into ``self.tensors``.

        Args:
            shape_dict: Optional mapping ``binding name -> shape`` overriding
                the shape reported by the engine for that binding.
            device: Torch device the tensors are allocated on.
        """
        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
            binding = self.engine[idx]
            if shape_dict and binding in shape_dict:
                shape = shape_dict[binding]
            else:
                shape = self.engine.get_binding_shape(binding)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            if self.engine.binding_is_input(binding):
                self.context.set_binding_shape(idx, shape)
            # Allocate directly on the target device rather than creating a
            # CPU tensor first and copying it over.
            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype], device=device)
            self.tensors[binding] = tensor

    def infer(self, feed_dict, stream, use_cuda_graph=False):
        """Copy inputs into the bound tensors and run the engine on ``stream``.

        Args:
            feed_dict: Mapping ``tensor name -> tensor`` copied into the
                matching entry of ``self.tensors`` with ``copy_``.
            stream: Object exposing the CUDA stream handle as ``stream.ptr``.
            use_cuda_graph: When true, the first call runs inference once and
                then captures a CUDA graph; later calls launch the captured
                graph and synchronize the stream.

        Returns:
            ``self.tensors`` (all bound tensors, inputs and outputs).

        Raises:
            ValueError: If ``execute_async_v3`` reports failure.
            RuntimeError: If a CUDA runtime call fails (via ``CUASSERT``).
        """
        for name, buf in feed_dict.items():
            self.tensors[name].copy_(buf)

        for name, tensor in self.tensors.items():
            self.context.set_tensor_address(name, tensor.data_ptr())

        if use_cuda_graph:
            if self.cuda_graph_instance is not None:
                CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr))
                CUASSERT(cudart.cudaStreamSynchronize(stream.ptr))
            else:
                # do inference before CUDA graph capture
                noerror = self.context.execute_async_v3(stream.ptr)
                if not noerror:
                    raise ValueError("ERROR: inference failed.")
                # capture cuda graph
                CUASSERT(cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal))
                captured = self.context.execute_async_v3(stream.ptr)
                # Always end the capture so the stream is not left in capture
                # mode, then report a failed enqueue instead of ignoring it.
                self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr))
                if not captured:
                    raise ValueError("ERROR: inference failed.")
                self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0))
        else:
            noerror = self.context.execute_async_v3(stream.ptr)
            if not noerror:
                raise ValueError("ERROR: inference failed.")

        return self.tensors
def runEngine(engine, feed_dict, stream):
    """Run ``engine.infer`` on ``stream`` with CUDA graphs enabled.

    Returns whatever ``engine.infer(feed_dict, stream, use_cuda_graph=True)``
    returns.
    """
    outputs = engine.infer(feed_dict, stream, use_cuda_graph=True)
    return outputs