// compiler.cpp (forked from pytorch/TensorRT)
#include <iostream>
#include <memory>
#include <sstream>
#include <vector>
#include <cuda_runtime.h>
#include "NvInfer.h"
#include "ATen/core/function_schema.h"
#include "ATen/core/jit_type.h"
#include "torch/csrc/jit/frontend/function_schema_parser.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/loop_unrolling.h"
#include "torch/csrc/jit/passes/lower_graph.h"
#include "torch/csrc/jit/passes/pass_manager.h"
#include "torch/custom_class.h"
#include "core/compiler.h"
#include "core/conversion/conversion.h"
#include "core/lowering/lowering.h"
#include "core/partitioning/partitioning.h"
#include "core/runtime/runtime.h"
namespace torch_tensorrt {
namespace core {
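// Registers a serialized TensorRT engine as a custom-class attribute of `mod` and
// emits the graph nodes needed to execute it: the graph inputs are gathered into a
// tensor list, passed to tensorrt::execute_engine together with the engine, and the
// resulting tensor list is unpacked into the graph outputs.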
void AddEngineToGraph(
torch::jit::script::Module mod,
std::shared_ptr<torch::jit::Graph>& g,
const std::string& serialized_engine,
runtime::RTDevice& device_info,
const std::vector<std::string>& input_binding_names,
const std::vector<std::string>& output_binding_names,
std::string engine_id = "",
bool fallback = false) {
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(
mod._ivalue()->name() + "_engine_" + engine_id,
serialized_engine,
device_info,
input_binding_names,
output_binding_names);
// Get required metadata about the engine out
auto num_io = engine_ptr->num_io;
auto name = engine_ptr->name;
// Add the engine as an attribute of the module so that it can be
// serialized and deserialized with the module
mod.register_attribute(
name,
c10::getCustomClassType<c10::intrusive_ptr<runtime::TRTEngine>>(),
c10::IValue(std::move(engine_ptr)),
false);
// Add the module as an input into the graph
auto self = g->addInput("self_1");
self->setType(mod.type());
// Start by retrieving the engine from the module attribute list
auto engine_node = g->createGetAttr(self, name);
g->block()->appendNode(engine_node);
// Add inputs to the graph corresponding to the number of input tensors
// expected by the engine. Also store those inputs in a vector so that they
// can be coalesced into a single list at runtime
std::vector<torch::jit::Value*> engine_inputs;
for (uint64_t i = 0; i < num_io.first; i++) {
auto in_val = g->addInput(std::string("input_") + std::to_string(i));
in_val->setType(c10::TensorType::get());
engine_inputs.push_back(in_val);
}
// Create a node that will merge all of the input tensors into a single list
// argument to the trt::execute_engine op. Creates: prim::ListConstruct(<input
// tensors>)
auto input_list_node = g->createList(c10::TensorType::get(), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs));
g->block()->appendNode(input_list_node);
// Make a list of inputs to the actual trt::execute_engine op
// Note: The list comes before the engine because the engine, which carries all
// the metadata needed for execution, can then be popped off first
std::vector<torch::jit::Value*> execute_node_inputs;
execute_node_inputs.push_back(input_list_node->outputs()[0]);
execute_node_inputs.push_back(engine_node->outputs()[0]);
// Create the actual execution node trt::execute_engine using the assembled
// inputs
auto execute_node = g->create(
c10::Symbol::fromQualString("tensorrt::execute_engine"),
torch::jit::ArrayRef<torch::jit::Value*>(execute_node_inputs),
1);
g->block()->appendNode(execute_node);
execute_node->outputs()[0]->setType(c10::ListType::ofTensors());
// Create a node to unpack the list into separate tensors. If there is only
// one tensor it is returned directly, otherwise the tensors are returned as a
// tuple. Creates: prim::ListUnpack(<engine output>)
auto unpack_node = g->createListUnpack(execute_node->outputs()[0], num_io.second);
g->block()->appendNode(unpack_node);
// If there are multiple output tensors from TensorRT, wrap them in a tuple to
// return. The tuple is only constructed when there is a single segmented graph
// (i.e. no fallback)
if (!fallback && unpack_node->outputs().size() > 1) {
// Creates prim::TupleConstruct(<output tensors>) using outputs of the
// unpack node
auto return_tuple_node = g->createTuple(unpack_node->outputs());
g->block()->appendNode(return_tuple_node);
// Set the output as the produced tuple
g->registerOutput(return_tuple_node->outputs()[0]);
} else {
// if fallback is enabled, multiple outputs will be registered
for (size_t i = 0; i < unpack_node->outputs().size(); ++i) {
g->registerOutput(unpack_node->outputs()[i]);
}
}
LOG_DEBUG(*g << "(AddEngineToGraph)\n");
return;
}
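// Lowers the requested method and reports whether every operator in the resulting
// graph has a registered TensorRT converter.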
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
// Go through Lowering to simplify graph
auto graph_and_parameters = lowering::Lower(mod, method_name, lowering::LowerInfo());
auto g = graph_and_parameters.first;
LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
return conversion::VerifyConverterSupportForBlock(g->block());
}
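// Partitions the block into TensorRT and Torch segments, converts each TensorRT
// segment into an engine that is embedded via AddEngineToGraph, and stitches the
// segments back together into a single graph.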
partitioning::GraphAndMapping BuildHybridGraph(
torch::jit::script::Module& new_mod,
torch::jit::Block* block,
CompileSpec cfg,
ir::StaticParams static_params,
ir::CollectionTypeMap first_use_types,
bool expect_full_compilation = false) {
auto convert_info = cfg.convert_info;
auto partitioning_info = cfg.partitioning_info;
auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
partitioning_ctx.input_types_map = first_use_types;
// Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
// TODO: Combine this within partition call
partitioning::populateInputIValues(&partitioning_ctx);
partitioning::partition(&partitioning_ctx, expect_full_compilation);
for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
int num_torch_segments = 0;
int num_trt_segments = 0;
for (auto& seg_block : segmented_blocks) {
LOG_INFO("Block segment: " << seg_block);
std::ostringstream trt_engine_id;
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
num_trt_segments++;
auto inputs = seg_block.construct_inputs_spec();
// Update the input ranges for each segment
convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
// TODO: map the input IValues to the flattened representation here
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_info.engine_settings.device;
auto cuda_device = runtime::RTDevice(device_spec.gpu_id, device_spec.device_type);
AddEngineToGraph(
new_mod,
temp_g,
engine,
cuda_device,
std::vector<std::string>(),
std::vector<std::string>(),
trt_engine_id.str(),
true);
seg_block.update_graph(temp_g);
} else {
num_torch_segments++;
// If full compilation is expected, ensure that all operators in Torch blocks are
// for collections processing
if (expect_full_compilation) {
for (auto torch_node : seg_block.block()->nodes()) {
if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
TORCHTRT_THROW_ERROR(
"Full compilation specified but node "
<< *torch_node
<< " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
<< " Try recompiling with require_full_compilation=False.");
}
}
}
}
}
// If full compilation is expected, cannot have more than 2 Torch segments
// (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
TORCHTRT_THROW_ERROR(
"Full compilation was requested but unable to convert all operations to TensorRT."
<< " Try recompiling with require_full_compilation=False.");
}
}
return partitioning::stitch(&partitioning_ctx, block);
}
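// Associates the user-provided input specs with the graph inputs and resolves a
// dtype for each one: user-specified dtypes take precedence, types inferred from
// the input's first use are taken otherwise, and Float32 is the fallback when
// neither is available. Returns a map from each input Value to its resolved
// ScalarType.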
ir::TypeMap MapInputsAndDetermineDTypes(
CompileSpec& cfg,
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
ir::CollectionTypeMap& first_use_type_map,
bool requires_collection_handling = false) {
cfg.convert_info.collection_input_spec_map =
std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
cfg.partitioning_info.collection_input_spec_map =
ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
ir::TypeMap inferred_dtypes;
auto collection_inputs = ir::get_collection_inputs(g, static_params);
LOG_DEBUG(
"In MapInputsAndDetermineDTypes, the g->inputs() size is "
<< g->inputs().size() << ", CollectionInputSpecMap size is " << collection_inputs.size());
for (auto in : collection_inputs) {
std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
std::vector<c10::optional<at::ScalarType>> est_type_opt;
auto est_it = first_use_type_map.find(in);
if (est_it != first_use_type_map.end()) {
est_type_opt = first_use_type_map.find(in)->second;
}
// traverse elements in est_type_opt and spec
for (size_t i = 0; i < est_type_opt.size(); i++) {
if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
// If we can calculate the type from the graph and the type was not defined by the user then use the calculated
// type
LOG_INFO(
"Since input type is not explicitly defined, inferring using first tensor calculation\n Inferred input "
<< in->debugName() << " has type " << est_type_opt[i].value());
spec[i].dtype = est_type_opt[i].value();
} else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
// If we cannot calculate the type and the user did not define the type, then default to FP32
LOG_WARNING(
"Cannot infer input type from calculations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
spec[i].dtype = at::kFloat;
} else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || requires_collection_handling)) {
if (!est_type_opt[i]) {
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ". The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
} else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) {
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt[i].value() << std::endl;
ss << "The compiler is going to use the user setting "
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
}
} else {
// The user defined the type so no changes are necessary
}
// Insert entry for Value pointer and determined ScalarType
inferred_dtypes.insert({in, {spec[i].dtype}});
}
}
return inferred_dtypes;
}
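// Lowers the requested method, verifies that every operator can be converted, and
// converts the whole block into a single serialized TensorRT engine.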
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
auto g = graph_and_parameters.first;
TORCHTRT_CHECK(
conversion::VerifyConverterSupportForBlock(g->block()),
"Not all operations in graph are supported by the compiler");
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
// Ensure none of the specified types are accepted input types that are incompatible with TRT
// Currently, only at::kLong is an accepted, though TRT-incompatible, type
for (auto value_to_dtypes : first_use_types) {
for (auto dtype : value_to_dtypes.second) {
TORCHTRT_CHECK(
!dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT");
}
}
auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
return engine;
}
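// Returns true if the user explicitly forced any modules or operators to run in PyTorch.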
bool userRequestedFallback(CompileSpec& cfg) {
return cfg.lower_info.forced_fallback_modules.size() != 0 ||
cfg.partitioning_info.forced_fallback_operators.size() != 0;
}
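// Top-level compilation entry point: lowers the module's forward method, decides
// between full TensorRT conversion and hybrid (partitioned) compilation, and
// returns a new module whose forward method runs the generated engine(s).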
torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
auto device_spec = cfg.convert_info.engine_settings.device;
auto cuda_device = runtime::RTDevice(device_spec.gpu_id, device_spec.device_type);
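// Only the forward method is compiled; it is rebuilt as a new graph on the returned module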
for (const torch::jit::Method& method : mod.get_methods()) {
if (method.name().compare("forward") == 0) {
auto new_g = std::make_shared<torch::jit::Graph>();
auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
auto g = graph_and_parameters.first;
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
// Determine if the block is convertible/has collection output, and based on the result,
// whether full compilation can be expected
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto inputIsCollection = conversion::InputIsCollection(g->block());
auto outputIsCollection = conversion::OutputIsCollection(g->block());
auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));
// Determine whether user specifications necessitate partitioning
auto isFallbackRequested = userRequestedFallback(cfg);
// Extract map of IValue to DType
auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);
// Check whether any of the input types are Long
bool user_requested_long = false;
for (auto dtype : type_map) {
user_requested_long |= dtype.second && (dtype.second.value() == at::kLong);
}
// Use dtype map to autocast Tensor-type inputs to Long dtype as necessary
if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) {
auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
user_requested_long &= (casts_inserted > 0);
}
// Partitioning is required if:
// 1. User requested some modules/operators fallback
// 2. The block (graph) cannot be converted due to operator coverage
// 3. The output of the graph is a collection
// 4. The user requested a non-TRT data type input
auto isPartitioningRequired =
(isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
// The user did not require full compilation, but the model can be fully compiled
if (cfg.partitioning_info.enabled && !isPartitioningRequired) {
LOG_INFO("Skipping partitioning since model is fully supported");
}
// The user did not require full compilation, and the model cannot be fully compiled,
// or the user required full compilation but the I/O of the graph uses collections
if ((cfg.partitioning_info.enabled && isPartitioningRequired) || requires_collection_handling) {
// If the model is fully-compilable and the user has specified full compilation, run partitioning
// to generate collection-processing code in Torch
auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);
auto graph_and_mapping =
BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
new_g = graph_and_mapping.first;
// Rename the graph inputs after fallback so that PyTorch can deserialize the graph correctly
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
new_g->inputs()[i]->setDebugName(std::string("input_") + std::to_string(i));
}
LOG_INFO(*new_g << "(GraphAfterFallback)");
// If the fallback graph does not take the module ("self") as an input, no TensorRT
// engine was embedded, so no conversion happened and we just return the initial module
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
return mod;
}
} else {
TORCHTRT_CHECK(
conversion::VerifyConverterSupportForBlock(g->block()),
"Not all operations in graph are supported by the compiler");
// TODO find the right
auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
AddEngineToGraph(new_mod, new_g, engine, cuda_device, std::vector<std::string>(), std::vector<std::string>());
}
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
new_mod.type()->addMethod(new_method);
new_method->setSchema(schema);
}
}
return new_mod;
}
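// Wraps a pre-built serialized engine in a fresh TorchScript module whose forward
// method simply executes the engine on the given device.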
torch::jit::script::Module EmbedEngineInNewModule(
const std::string& engine,
runtime::RTDevice cuda_device,
const std::vector<std::string>& input_binding_names,
const std::vector<std::string>& output_binding_names) {
std::ostringstream engine_id;
engine_id << reinterpret_cast<const int*>(&engine);
torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
auto new_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, new_g, engine, cuda_device, input_binding_names, output_binding_names);
auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
new_mod.type()->addMethod(new_method);
new_method->setSchema(schema);
return new_mod;
}
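// Makes the given GPU the active CUDA device, asserting on failure.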
void set_device(const int gpu_id) {
TORCHTRT_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
}
} // namespace core
} // namespace torch_tensorrt