forked from NVIDIA-AI-IOT/tf_to_trt_image_classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_trt.cu
353 lines (296 loc) · 9.39 KB
/
test_trt.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
/**
* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
* Full license terms provided in LICENSE.md file.
*/
#include <iostream>
#include <string>
#include <vector>
#include <sstream>
#include <chrono>
#include <stdexcept>
#include <fstream>
#include <opencv2/opencv.hpp>
#include <NvInfer.h>
#define MS_PER_SEC 1000.0
using namespace std;
using namespace nvinfer1;
// Forward declaration so that test() can be declared before the class body.
class TestConfig;
// Signature shared by the in-place preprocessing routines selectable at run time
// (see TestConfig::PreprocessFn).
typedef void (*preprocess_fn_t)(float *input, size_t channels, size_t height, size_t width);
// Converts an OpenCV image into a newly allocated planar float buffer.
float * imageToTensor(const cv::Mat & image);
// Subtracts a fixed per-channel mean from a planar float tensor, in place.
void preprocessVgg(float *input, size_t channels, size_t height, size_t width);
// Rescales every element from [0, 255] to [-1, 1], in place.
void preprocessInception(float *input, size_t channels, size_t height, size_t width);
// Returns the index of the largest of the first `numel` elements.
size_t argmax(float *input, size_t numel);
// Runs the timed inference described by testConfig and appends the result to the stats file.
void test(const TestConfig &testConfig);
class TestConfig
{
public:
string imagePath;
string planPath;
string inputNodeName;
string outputNodeName;
string preprocessFnName;
string inputHeight;
string inputWidth;
string numOutputCategories;
string dataType;
string maxBatchSize;
string workspaceSize;
string numRuns;
string useMappedMemory;
string statsPath;
TestConfig(int argc, char * argv[])
{
imagePath = argv[1];
planPath = argv[2];
inputNodeName = argv[3];
inputHeight = argv[4];
inputWidth = argv[5];
outputNodeName = argv[6];
numOutputCategories = argv[7];
preprocessFnName = argv[8];
numRuns = argv[9];
dataType = argv[10];
maxBatchSize = argv[11];
workspaceSize = argv[12];
useMappedMemory = argv[13];
statsPath = argv[14];
}
static string UsageString()
{
string s = "";
s += "imagePath: \n";
s += "planPath: \n";
s += "inputNodeName: \n";
s += "inputHeight: \n";
s += "inputWidth: \n";
s += "outputNodeName: \n";
s += "numOutputCategories: \n";
s += "preprocessFnName: \n";
s += "numRuns: \n";
s += "dataType: \n";
s += "maxBatchSize: \n";
s += "workspaceSize: \n";
s += "useMappedMemory: \n";
s += "statsPath: \n";
return s;
}
string ToString()
{
string s = "";
s += "imagePath: " + imagePath + "\n";
s += "planPath: " + planPath + "\n";
s += "inputNodeName: " + inputNodeName + "\n";
s += "inputHeight: " + inputHeight + "\n";
s += "inputWidth: " + inputWidth + "\n";
s += "outputNodeName: " + outputNodeName + "\n";
s += "numOutputCategories: " + numOutputCategories + "\n";
s += "preprocessFnName: " + preprocessFnName + "\n";
s += "numRuns: " + numRuns + "\n";
s += "dataType: " + dataType + "\n";
s += "maxBatchSize: " + maxBatchSize + "\n";
s += "workspaceSize: " + workspaceSize + "\n";
s += "useMappedMemory: " + useMappedMemory + "\n";
s += "statsPath: " + statsPath + "\n";
return s;
}
static int ToInteger(string value)
{
int valueInt;
stringstream ss;
ss << value;
ss >> valueInt;
return valueInt;
}
preprocess_fn_t PreprocessFn() const {
if (preprocessFnName == "preprocess_vgg")
return preprocessVgg;
else if (preprocessFnName == "preprocess_inception")
return preprocessInception;
else
throw runtime_error("Invalid preprocessing function name.");
}
int InputWidth() const { return ToInteger(inputWidth); }
int InputHeight() const { return ToInteger(inputHeight); }
int NumOutputCategories() const { return ToInteger(numOutputCategories); }
nvinfer1::DataType DataType() const {
if (dataType == "float")
return nvinfer1::DataType::kFLOAT;
else if (dataType == "half")
return nvinfer1::DataType::kHALF;
else
throw runtime_error("Invalid data type.");
}
int MaxBatchSize() const { return ToInteger(maxBatchSize); }
int WorkspaceSize() const { return ToInteger(workspaceSize); }
int NumRuns() const { return ToInteger(numRuns); }
int UseMappedMemory() const { return ToInteger(useMappedMemory); }
};
class Logger : public ILogger
{
void log(Severity severity, const char * msg) override
{
cout << msg << endl;
}
} gLogger;
int main(int argc, char * argv[])
{
if (argc != 15)
{
cout << TestConfig::UsageString() << endl;
return 0;
}
TestConfig testConfig(argc, argv);
cout << "\ntestConfig: \n" << testConfig.ToString() << endl;
test(testConfig);
return 0;
}
/**
 * Copies an 8-bit interleaved image (row, column, channel order) into a newly
 * allocated planar float buffer (channel, row, column order).
 *
 * The buffer is allocated with cudaHostAlloc(..., cudaHostAllocMapped) and must
 * be released by the caller with cudaFreeHost().
 *
 * Throws runtime_error when the image is empty, is not 8 bits per channel, or
 * when the allocation fails.
 */
float *imageToTensor(const cv::Mat & image)
{
    if (image.empty())
        throw runtime_error("imageToTensor: image is empty.");
    // Pixels are read as single bytes below, so any other depth would be misinterpreted.
    if (image.depth() != CV_8U)
        throw runtime_error("imageToTensor: image must have 8-bit channels.");

    const size_t height = static_cast<size_t>(image.rows);
    const size_t width = static_cast<size_t>(image.cols);
    const size_t channels = static_cast<size_t>(image.channels());
    const size_t planeSize = height * width;
    const size_t numel = planeSize * channels;

    float *tensor = nullptr;
    const cudaError_t status =
        cudaHostAlloc((void**)&tensor, numel * sizeof(float), cudaHostAllocMapped);
    if (status != cudaSuccess)
        throw runtime_error(string("imageToTensor: cudaHostAlloc failed: ")
                            + cudaGetErrorString(status));

    for (size_t row = 0; row < height; row++)
    {
        // Use the per-row pointer so images whose rows are padded or that are
        // views into a larger matrix are read correctly.
        const unsigned char *src = image.ptr<unsigned char>(static_cast<int>(row));
        for (size_t col = 0; col < width; col++)
        {
            for (size_t ch = 0; ch < channels; ch++)
            {
                const size_t offset = ch * planeSize + row * width + col;
                tensor[offset] = static_cast<float>(src[col * channels + ch]);
            }
        }
    }
    return tensor;
}
/**
 * Subtracts a fixed per-channel mean from a planar (channel, row, column)
 * float tensor, in place.
 *
 * tensor   - buffer of channels * height * width floats.
 * channels - number of channel planes; at most 3, because only three means are defined.
 *
 * Throws std::runtime_error when channels exceeds 3 (previously this read past
 * the end of the mean table).
 */
void preprocessVgg(float * tensor, size_t channels, size_t height, size_t width)
{
    // NOTE(review): presumably the per-channel RGB means of the training set — confirm.
    const float mean[3] = { 123.68f, 116.78f, 103.94f };
    const size_t numMeans = sizeof(mean) / sizeof(mean[0]);
    if (channels > numMeans)
        throw std::runtime_error("preprocessVgg supports at most 3 channels.");

    const size_t planeSize = height * width;
    for (size_t k = 0; k < channels; k++)
    {
        // Each channel occupies one contiguous plane of height * width values.
        float *plane = tensor + k * planeSize;
        for (size_t idx = 0; idx < planeSize; idx++)
            plane[idx] -= mean[k];
    }
}
/**
 * Rescales every element of the tensor from the range [0, 255] to [-1, 1],
 * in place: x -> 2 * (x / 255 - 0.5).
 *
 * tensor - buffer of channels * height * width floats; the layout is irrelevant
 *          because every element is treated identically.
 */
void preprocessInception(float * tensor, size_t channels, size_t height, size_t width)
{
    const size_t numel = channels * height * width;
    // size_t index: an int counter would overflow before reaching a large numel.
    for (size_t i = 0; i < numel; i++)
    {
        // Evaluated in double, then narrowed, to keep results identical to the
        // original expression.
        tensor[i] = static_cast<float>(2.0 * (tensor[i] / 255.0 - 0.5));
    }
}
/**
 * Returns the index of the largest of the first `numel` elements of `tensor`.
 *
 * When several elements share the maximum, the lowest index is returned.
 * Returns 0 when numel is 0 (tensor is not read in that case).
 */
size_t argmax(float * tensor, size_t numel)
{
    if (numel == 0)
        return 0;
    size_t maxIndex = 0;
    float max = tensor[0];
    // Start at 1: element 0 is already the running maximum. The index is a
    // size_t so it cannot overflow before reaching numel.
    for (size_t i = 1; i < numel; i++)
    {
        if (tensor[i] > max)
        {
            maxIndex = i;
            max = tensor[i];
        }
    }
    return maxIndex;
}
void test(const TestConfig &testConfig)
{
ifstream planFile(testConfig.planPath);
stringstream planBuffer;
planBuffer << planFile.rdbuf();
string plan = planBuffer.str();
IRuntime *runtime = createInferRuntime(gLogger);
ICudaEngine *engine = runtime->deserializeCudaEngine((void*)plan.data(),
plan.size(), nullptr);
IExecutionContext *context = engine->createExecutionContext();
int inputBindingIndex, outputBindingIndex;
inputBindingIndex = engine->getBindingIndex(testConfig.inputNodeName.c_str());
outputBindingIndex = engine->getBindingIndex(testConfig.outputNodeName.c_str());
// load and preprocess image
cv::Mat image = cv::imread(testConfig.imagePath, CV_LOAD_IMAGE_COLOR);
cv::cvtColor(image, image, cv::COLOR_BGR2RGB, 3);
cv::resize(image, image, cv::Size(testConfig.InputWidth(), testConfig.InputHeight()));
float *input = imageToTensor(image);
testConfig.PreprocessFn()(input, 3, testConfig.InputHeight(), testConfig.InputWidth());
// allocate memory on host / device for input / output
float *output;
float *inputDevice;
float *outputDevice;
size_t inputSize = testConfig.InputHeight() * testConfig.InputWidth() * 3 * sizeof(float);
cudaHostAlloc(&output, testConfig.NumOutputCategories() * sizeof(float), cudaHostAllocMapped);
if (testConfig.UseMappedMemory())
{
cudaHostGetDevicePointer(&inputDevice, input, 0);
cudaHostGetDevicePointer(&outputDevice, output, 0);
}
else
{
cudaMalloc(&inputDevice, inputSize);
cudaMalloc(&outputDevice, testConfig.NumOutputCategories() * sizeof(float));
}
float *bindings[2];
bindings[inputBindingIndex] = inputDevice;
bindings[outputBindingIndex] = outputDevice;
// run and compute average time over numRuns iterations
double avgTime = 0;
for (int i = 0; i < testConfig.NumRuns() + 1; i++)
{
chrono::duration<double> diff;
if (testConfig.UseMappedMemory())
{
auto t0 = chrono::steady_clock::now();
context->execute(1, (void**)bindings);
auto t1 = chrono::steady_clock::now();
diff = t1 - t0;
}
else
{
auto t0 = chrono::steady_clock::now();
cudaMemcpy(inputDevice, input, inputSize, cudaMemcpyHostToDevice);
context->execute(1, (void**)bindings);
cudaMemcpy(output, outputDevice, testConfig.NumOutputCategories() * sizeof(float), cudaMemcpyDeviceToHost);
auto t1 = chrono::steady_clock::now();
diff = t1 - t0;
}
if (i != 0)
avgTime += MS_PER_SEC * diff.count();
}
avgTime /= testConfig.NumRuns();
// save results to file
int maxCategoryIndex = argmax(output, testConfig.NumOutputCategories()) + 1001 - testConfig.NumOutputCategories();
cout << "Most likely category id is " << maxCategoryIndex << endl;
cout << "Average execution time in ms is " << avgTime << endl;
ofstream outfile;
outfile.open(testConfig.statsPath, ios_base::app);
outfile << "\n" << testConfig.planPath
<< " " << avgTime;
// << " " << maxCategoryIndex
// << " " << testConfig.InputWidth()
// << " " << testConfig.InputHeight()
// << " " << testConfig.MaxBatchSize()
// << " " << testConfig.WorkspaceSize()
// << " " << testConfig.dataType
// << " " << testConfig.NumRuns()
// << " " << testConfig.UseMappedMemory();
outfile.close();
cudaFree(inputDevice);
cudaFree(outputDevice);
cudaFreeHost(input);
cudaFreeHost(output);
engine->destroy();
context->destroy();
runtime->destroy();
}