Skip to content

Commit

Permalink
Fix multi-batch problem in int8 quantization (#1604)
Browse files Browse the repository at this point in the history
* Use the configuration of FP16 by default.

* Fix multi-batch problem in int8 quantization
  • Loading branch information
mpj1234 authored Dec 5, 2024
1 parent b34d799 commit 62c1680
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions yolo11/src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ nvinfer1::IHostMemory* buildEngineYolo11Cls(nvinfer1::IBuilder* builder, nvinfer
std::cout << "Your platform supports int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;
assert(builder->platformHasFastInt8());
config->setFlag(nvinfer1::BuilderFlag::kINT8);
auto* calibrator = new Int8EntropyCalibrator2(1, kClsInputW, kClsInputH, kInputQuantizationFolder,
auto* calibrator = new Int8EntropyCalibrator2(kBatchSize, kClsInputW, kClsInputH, kInputQuantizationFolder,
"int8calib.table", kInputTensorName);
config->setInt8Calibrator(calibrator);
#endif
Expand Down Expand Up @@ -392,8 +392,8 @@ nvinfer1::IHostMemory* buildEngineYolo11Det(nvinfer1::IBuilder* builder, nvinfer
std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;
assert(builder->platformHasFastInt8());
config->setFlag(nvinfer1::BuilderFlag::kINT8);
auto* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, kInputQuantizationFolder, "int8calib.table",
kInputTensorName);
auto* calibrator = new Int8EntropyCalibrator2(kBatchSize, kInputW, kInputH, kInputQuantizationFolder,
"int8calib.table", kInputTensorName);
config->setInt8Calibrator(calibrator);
#endif

Expand Down Expand Up @@ -775,8 +775,8 @@ nvinfer1::IHostMemory* buildEngineYolo11Seg(nvinfer1::IBuilder* builder, nvinfer
std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;
assert(builder->platformHasFastInt8());
config->setFlag(nvinfer1::BuilderFlag::kINT8);
auto* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, kInputQuantizationFolder, "int8calib.table",
kInputTensorName);
auto* calibrator = new Int8EntropyCalibrator2(kBatchSize, kInputW, kInputH, kInputQuantizationFolder,
"int8calib.table", kInputTensorName);
config->setInt8Calibrator(calibrator);
#endif

Expand Down Expand Up @@ -1066,8 +1066,8 @@ nvinfer1::IHostMemory* buildEngineYolo11Pose(nvinfer1::IBuilder* builder, nvinfe
std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;
assert(builder->platformHasFastInt8());
config->setFlag(nvinfer1::BuilderFlag::kINT8);
auto* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, kInputQuantizationFolder, "int8calib.table",
kInputTensorName);
auto* calibrator = new Int8EntropyCalibrator2(kBatchSize, kInputW, kInputH, kInputQuantizationFolder,
"int8calib.table", kInputTensorName);
config->setInt8Calibrator(calibrator);
#endif

Expand Down

0 comments on commit 62c1680

Please sign in to comment.