forked from NVIDIA-AI-IOT/tf_to_trt_image_classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path uff_to_plan.cpp
105 lines (89 loc) · 2.5 KB
/
uff_to_plan.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
/**
* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
* Full license terms provided in LICENSE.md file.
*/
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <NvInfer.h>
#include <NvUffParser.h>
using namespace std;
using namespace nvinfer1;
using namespace nvuffparser;
class Logger : public ILogger
{
void log(Severity severity, const char * msg) override
{
cout << msg << endl;
}
} gLogger;
/* Parse a command-line argument as a decimal int.
 *
 * Leading and trailing whitespace is allowed. Throws std::runtime_error when
 * the text is empty, is not a number, has trailing non-space characters
 * (e.g. "12abc"), or does not fit in an int. Previously such input silently
 * became 0, a truncated prefix, or a clamped value and was passed on to
 * TensorRT.
 *
 * @param value  the argument text
 * @return       the parsed integer
 * @throws std::runtime_error on malformed or out-of-range input
 */
int toInteger(const std::string &value)
{
    std::istringstream ss(value);
    int valueInteger = 0;
    char trailing;
    // First extraction fails on non-numeric text and on int overflow (failbit);
    // the second succeeds only if a non-whitespace character remains.
    if (!(ss >> valueInteger) || (ss >> trailing))
        throw std::runtime_error("Invalid integer argument: '" + value + "'");
    return valueInteger;
}
/* Map the <data_type> command-line argument onto a TensorRT DataType.
 * Only the exact strings "float" (-> kFLOAT) and "half" (-> kHALF) are
 * accepted; anything else throws runtime_error. */
DataType toDataType(string value)
{
    if (value == "half")
        return DataType::kHALF;
    if (value != "float")
        throw runtime_error("Unsupported data type");
    return DataType::kFLOAT;
}
int main(int argc, char *argv[])
{
if (argc != 10)
{
cout << "Usage: <uff_filename> <plan_filename> <input_name> <input_height> <input_width>"
<< " <output_name> <max_batch_size> <max_workspace_size> <data_type>\n";
return 1;
}
/* parse command line arguments */
string uffFilename = argv[1];
string planFilename = argv[2];
string inputName = argv[3];
int inputHeight = toInteger(argv[4]);
int inputWidth = toInteger(argv[5]);
string outputName = argv[6];
int maxBatchSize = toInteger(argv[7]);
int maxWorkspaceSize = toInteger(argv[8]);
DataType dataType = toDataType(argv[9]);
/* parse uff */
IBuilder *builder = createInferBuilder(gLogger);
INetworkDefinition *network = builder->createNetwork();
IUffParser *parser = createUffParser();
parser->registerInput(inputName.c_str(), DimsCHW(3, inputHeight, inputWidth));
parser->registerOutput(outputName.c_str());
if (!parser->parse(uffFilename.c_str(), *network, dataType))
{
cout << "Failed to parse UFF\n";
builder->destroy();
parser->destroy();
network->destroy();
return 1;
}
/* build engine */
if (dataType == DataType::kHALF)
builder->setHalf2Mode(true);
builder->setMaxBatchSize(maxBatchSize);
builder->setMaxWorkspaceSize(maxWorkspaceSize);
ICudaEngine *engine = builder->buildCudaEngine(*network);
/* serialize engine and write to file */
ofstream planFile;
planFile.open(planFilename);
IHostMemory *serializedEngine = engine->serialize();
planFile.write((char *)serializedEngine->data(), serializedEngine->size());
planFile.close();
/* break down */
builder->destroy();
parser->destroy();
network->destroy();
engine->destroy();
serializedEngine->destroy();
return 0;
}