Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Works with yolov4-tiny and TRT7/8 #158

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions modules/calibrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ SOFTWARE.
#include <fstream>
#include <iostream>
#include <iterator>
#include <random>
#include <cuda_runtime.h>
#include <cuda.h>

Int8EntropyCalibrator::Int8EntropyCalibrator(const uint32_t& batchSize, const std::string& calibImages,
const std::string& calibImagesPath,
const std::string& calibTableFilePath,
const uint64_t& inputSize, const uint32_t& inputH,
const uint32_t& inputW, const std::string& inputBlobName,
const std::string &s_net_type_) :
const std::string& calibImagesPath,
const std::string& calibTableFilePath,
const uint64_t& inputSize, const uint32_t& inputH,
const uint32_t& inputW, const std::string& inputBlobName, const std::string &s_net_type_) :
m_BatchSize(batchSize),
m_InputH(inputH),
m_InputW(inputW),
Expand All @@ -42,22 +44,24 @@ Int8EntropyCalibrator::Int8EntropyCalibrator(const uint32_t& batchSize, const st
m_InputBlobName(inputBlobName),
m_CalibTableFilePath(calibTableFilePath),
m_ImageIndex(0),
_s_net_type(s_net_type_)
_s_net_type(s_net_type_)
{
if (!fileExists(m_CalibTableFilePath, false))
{
m_ImageList = loadImageList(calibImages, calibImagesPath);
m_ImageList.resize(static_cast<int>(m_ImageList.size() / m_BatchSize) * m_BatchSize);
std::random_shuffle(m_ImageList.begin(), m_ImageList.end(),
[](int i) { return rand() % i; });
std::random_device rng;
std::mt19937 urng(rng());

m_ImageList = loadImageList(calibImages, calibImagesPath);
m_ImageList.resize(static_cast<int>(m_ImageList.size() / m_BatchSize) * m_BatchSize);
std::shuffle(m_ImageList.begin(), m_ImageList.end(), urng);
}

NV_CUDA_CHECK(cudaMalloc(&m_DeviceInput, m_InputCount * sizeof(float)));
}

Int8EntropyCalibrator::~Int8EntropyCalibrator() { NV_CUDA_CHECK(cudaFree(m_DeviceInput)); }

bool Int8EntropyCalibrator::getBatch(void* bindings[], const char* names[], int nbBindings)
bool Int8EntropyCalibrator::getBatch(void* bindings[], const char* names[], int /*nbBindings*/) noexcept
{
if (m_ImageIndex + m_BatchSize >= m_ImageList.size()) return false;

Expand All @@ -69,16 +73,16 @@ bool Int8EntropyCalibrator::getBatch(void* bindings[], const char* names[], int
}
m_ImageIndex += m_BatchSize;

cv::Mat trtInput = blobFromDsImages(dsImages, m_InputH, m_InputW);
blobFromDsImages(dsImages, m_blob, m_InputH, m_InputW);

NV_CUDA_CHECK(cudaMemcpy(m_DeviceInput, trtInput.ptr<float>(0), m_InputCount * sizeof(float),
NV_CUDA_CHECK(cudaMemcpy(m_DeviceInput, m_blob.ptr<float>(0), m_InputCount * sizeof(float),
cudaMemcpyHostToDevice));
assert(!strcmp(names[0], m_InputBlobName.c_str()));
bindings[0] = m_DeviceInput;
return true;
}

const void* Int8EntropyCalibrator::readCalibrationCache(size_t& length)
const void* Int8EntropyCalibrator::readCalibrationCache(size_t& length) noexcept
{
void* output;
m_CalibrationCache.clear();
Expand All @@ -105,10 +109,10 @@ const void* Int8EntropyCalibrator::readCalibrationCache(size_t& length)
return output;
}

void Int8EntropyCalibrator::writeCalibrationCache(const void* cache, size_t length)
void Int8EntropyCalibrator::writeCalibrationCache(const void* cache, size_t length) noexcept
{
assert(!m_CalibTableFilePath.empty());
std::ofstream output(m_CalibTableFilePath, std::ios::binary);
output.write(reinterpret_cast<const char*>(cache), length);
output.close();
}
}
16 changes: 9 additions & 7 deletions modules/calibrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ class Int8EntropyCalibrator : public nvinfer1::IInt8EntropyCalibrator2
Int8EntropyCalibrator(const uint32_t& batchSize, const std::string& calibImages,
const std::string& calibImagesPath, const std::string& calibTableFilePath,
const uint64_t& inputSize, const uint32_t& inputH, const uint32_t& inputW,
const std::string& inputBlobName,const std::string &s_net_type_);
const std::string& inputBlobName, const std::string &s_net_type_);
virtual ~Int8EntropyCalibrator();

int getBatchSize() const override { return m_BatchSize; }
bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
const void* readCalibrationCache(size_t& length) override;
void writeCalibrationCache(const void* cache, size_t length) override;
int getBatchSize() const noexcept override { return m_BatchSize; }
bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override;
const void* readCalibrationCache(size_t& length) noexcept override;
void writeCalibrationCache(const void* cache, size_t length) noexcept override;

private:
const uint32_t m_BatchSize;
Expand All @@ -50,13 +50,15 @@ class Int8EntropyCalibrator : public nvinfer1::IInt8EntropyCalibrator2
const uint64_t m_InputSize;
const uint64_t m_InputCount;
const std::string m_InputBlobName;
const std::string _s_net_type;
const std::string _s_net_type;
const std::string m_CalibTableFilePath{nullptr};
uint32_t m_ImageIndex;
bool m_ReadCache{true};
void* m_DeviceInput{nullptr};
std::vector<std::string> m_ImageList;
std::vector<char> m_CalibrationCache;

cv::Mat m_blob;
};

#endif
#endif
107 changes: 70 additions & 37 deletions modules/chunk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,34 @@

namespace nvinfer1
{
Chunk::Chunk()
{

}
Chunk::Chunk(const void* buffer, size_t size)
{
assert(size == sizeof(_n_size_split));
_n_size_split = *reinterpret_cast<const int*>(buffer);
}
Chunk::~Chunk()
{
}

}
int Chunk::getNbOutputs() const
int Chunk::getNbOutputs() const noexcept
{
return 2;
}

Dims Chunk::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
Dims Chunk::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)noexcept
{
assert(nbInputDims == 1);
assert(index == 0 || index == 1);
return Dims3(inputs[0].d[0] / 2, inputs[0].d[1], inputs[0].d[2]);
}

int Chunk::initialize()
int Chunk::initialize() noexcept
{
return 0;
}

void Chunk::terminate()
void Chunk::terminate() noexcept
{
}

size_t Chunk::getWorkspaceSize(int maxBatchSize) const
size_t Chunk::getWorkspaceSize(int maxBatchSize) const noexcept
{
return 0;
}
Expand All @@ -60,81 +53,121 @@ namespace nvinfer1
const void* const* inputs,
void** outputs,
void* workspace,
cudaStream_t stream)
cudaStream_t stream)noexcept
{
return enqueue(batchSize, inputs, (void* const*)outputs, workspace, stream);
}

int Chunk::enqueue(int batchSize,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept
{
//batch
for (int b = 0; b < batchSize; ++b)
{
NV_CUDA_CHECK(cudaMemcpy((char*)outputs[0] + b * _n_size_split, (char*)inputs[0] + b * 2 * _n_size_split, _n_size_split, cudaMemcpyDeviceToDevice));
NV_CUDA_CHECK(cudaMemcpy((char*)outputs[1] + b * _n_size_split, (char*)inputs[0] + b * 2 * _n_size_split + _n_size_split, _n_size_split, cudaMemcpyDeviceToDevice));
}
// NV_CUDA_CHECK(cudaMemcpy(outputs[0], inputs[0], _n_size_split, cudaMemcpyDeviceToDevice));
// NV_CUDA_CHECK(cudaMemcpy(outputs[1], (void*)((char*)inputs[0] + _n_size_split), _n_size_split, cudaMemcpyDeviceToDevice));
return 0;
}

size_t Chunk::getSerializationSize() const
size_t Chunk::getSerializationSize() const noexcept
{
return sizeof(_n_size_split);
}

void Chunk::serialize(void *buffer)const
void Chunk::serialize(void *buffer)const noexcept
{
*reinterpret_cast<int*>(buffer) = _n_size_split;
}

const char* Chunk::getPluginType()const
const char* Chunk::getPluginType()const noexcept
{
return "CHUNK_TRT";
}
const char* Chunk::getPluginVersion() const
{
const char* Chunk::getPluginVersion() const noexcept
{
return "1.0";
}

void Chunk::destroy()
void Chunk::destroy() noexcept
{
delete this;
}

void Chunk::setPluginNamespace(const char* pluginNamespace)
void Chunk::setPluginNamespace(const char* pluginNamespace) noexcept
{
_s_plugin_namespace = pluginNamespace;
}

const char* Chunk::getPluginNamespace() const
const char* Chunk::getPluginNamespace() const noexcept
{
return _s_plugin_namespace.c_str();
}

DataType Chunk::getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
int nbInputs) const noexcept
{
assert(index == 0 || index == 1);
return DataType::kFLOAT;
}

bool Chunk::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
bool Chunk::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const noexcept
{
return false;
}

bool Chunk::canBroadcastInputAcrossBatch(int inputIndex) const
bool Chunk::canBroadcastInputAcrossBatch(int inputIndex) const noexcept
{
return false;
}

void Chunk::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) {}
void Chunk::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
{
}

void Chunk::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
{
_n_size_split = in->dims.d[0] / 2 * in->dims.d[1] * in->dims.d[2] *sizeof(float);
}
void Chunk::detachFromContext() {}
void Chunk::detachFromContext()
{
}

bool Chunk::supportsFormat(DataType type, PluginFormat format) const noexcept
{
return (type == DataType::kFLOAT && format == PluginFormat::kLINEAR);
}

void Chunk::configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept
{
size_t typeSize = sizeof(float);
switch (type)
{
case DataType::kFLOAT:
typeSize = sizeof(float);
break;
case DataType::kHALF:
typeSize = sizeof(float) / 2;
break;
case DataType::kINT8:
typeSize = 1;
break;
case DataType::kINT32:
typeSize = 4;
break;
case DataType::kBOOL:
typeSize = 1;
break;
}
_n_size_split = inputDims->d[0] / 2 * inputDims->d[1] * inputDims->d[2] * typeSize;
}

// Clone the plugin
IPluginV2IOExt* Chunk::clone() const
IPluginV2* Chunk::clone() const noexcept
{
Chunk *p = new Chunk();
p->_n_size_split = _n_size_split;
Expand All @@ -153,41 +186,41 @@ namespace nvinfer1
_fc.fields = _vec_plugin_attributes.data();
}

const char* ChunkPluginCreator::getPluginName() const
const char* ChunkPluginCreator::getPluginName() const noexcept
{
return "CHUNK_TRT";
}

const char* ChunkPluginCreator::getPluginVersion() const
const char* ChunkPluginCreator::getPluginVersion() const noexcept
{
return "1.0";
}

const PluginFieldCollection* ChunkPluginCreator::getFieldNames()
const PluginFieldCollection* ChunkPluginCreator::getFieldNames()noexcept
{
return &_fc;
}

IPluginV2IOExt* ChunkPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
IPluginV2* ChunkPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)noexcept
{
Chunk* obj = new Chunk();
obj->setPluginNamespace(_s_name_space.c_str());
return obj;
}

IPluginV2IOExt* ChunkPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
IPluginV2* ChunkPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)noexcept
{
Chunk* obj = new Chunk(serialData,serialLength);
obj->setPluginNamespace(_s_name_space.c_str());
return obj;
}

void ChunkPluginCreator::setPluginNamespace(const char* libNamespace)
void ChunkPluginCreator::setPluginNamespace(const char* libNamespace)noexcept
{
_s_name_space = libNamespace;
}

const char* ChunkPluginCreator::getPluginNamespace() const
const char* ChunkPluginCreator::getPluginNamespace() const noexcept
{
return _s_name_space.c_str();
}
Expand Down
Loading