Skip to content

Commit

Permalink
added argument checking to classify_image
Browse files Browse the repository at this point in the history
  • Loading branch information
jaybdub committed Mar 6, 2018
1 parent 54aac51 commit 9c9b213
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
34 changes: 34 additions & 0 deletions examples/classify_image/classify_image.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ int main(int argc, char *argv[])
/* load the engine */
cout << "Loading TensorRT engine from plan file..." << endl;
ifstream planFile(planFilename);

if (!planFile.is_open())
{
cout << "Could not open plan file." << endl;
return 1;
}

stringstream planBuffer;
planBuffer << planFile.rdbuf();
string plan = planBuffer.str();
Expand All @@ -63,6 +70,19 @@ int main(int argc, char *argv[])
int inputBindingIndex, outputBindingIndex;
inputBindingIndex = engine->getBindingIndex(inputName.c_str());
outputBindingIndex = engine->getBindingIndex(outputName.c_str());

if (inputBindingIndex < 0)
{
cout << "Invalid input name." << endl;
return 1;
}

if (outputBindingIndex < 0)
{
cout << "Invalid output name." << endl;
return 1;
}

Dims inputDims, outputDims;
inputDims = engine->getBindingDimensions(inputBindingIndex);
outputDims = engine->getBindingDimensions(outputBindingIndex);
Expand All @@ -73,6 +93,13 @@ int main(int argc, char *argv[])
/* read image, convert color, and resize */
cout << "Preprocessing input..." << endl;
cv::Mat image = cv::imread(imageFilename, CV_LOAD_IMAGE_COLOR);

if (image.data == NULL)
{
cout << "Could not read image from file." << endl;
return 1;
}

cv::cvtColor(image, image, cv::COLOR_BGR2RGB, 3);
cv::resize(image, image, cv::Size(inputWidth, inputHeight));

Expand Down Expand Up @@ -119,6 +146,13 @@ int main(int argc, char *argv[])
cout << sortedIndices[i] << " ";

ifstream labelsFile(labelFilename);

if (!labelsFile.is_open())
{
cout << "\nCould not open label file." << endl;
return 1;
}

vector<string> labelMap;
string label;
while(getline(labelsFile, label))
Expand Down
2 changes: 1 addition & 1 deletion scripts/convert_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def frozenToPlan(frozen_graph_filename, plan_filename, input_name, input_height,
frozen_file=frozen_graph_filename,
output_nodes=[output_name],
output_filename=TMP_UFF_FILENAME,
text=False
text=False,
)

# convert frozen graph to engine (plan)
Expand Down

0 comments on commit 9c9b213

Please sign in to comment.