Commit
[CVCUDA] Utilize CV-CUDA batch processing function (PaddlePaddle#1223)
* norm and permute batch processing

* move cache to mat, batch processors

* get batched tensor logic, resize on cpu logic

* fix cpu compile error

* remove vector mat api

* nits

* add comments

* nits

* fix batch size

* move initial resize on cpu option to use_cuda api

* fix pybind

* processor manager pybind

* rename mat and matbatch

* move initial resize on cpu to ppcls preprocessor

---------

Co-authored-by: Jason <[email protected]>
wang-xinyu and jiangjiajun authored Feb 7, 2023
1 parent 7c9bf11 commit d3d9148
Showing 29 changed files with 708 additions and 239 deletions.
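
Taken together, the changes route preprocessing through the batch-aware ProcessorManager / FDMatBatch path, with optional CV-CUDA execution and an opt-in CPU fallback for the initial resize. Below is a minimal C++ usage sketch, not part of this commit: it assumes ProcessorManager keeps the Run(std::vector<FDMat>*, std::vector<FDTensor>*) entry point that the old Python binding called, and the config/image paths are placeholders.

#include <opencv2/opencv.hpp>
#include "fastdeploy/vision.h"

int main() {
  namespace vis = fastdeploy::vision;
  // Config file path is a placeholder.
  vis::classification::PaddleClasPreprocessor preprocessor("inference_cls.yaml");

  // Enable CUDA preprocessing on GPU 0; `true` also turns on the CV-CUDA
  // batch processors added by this commit. Without this call the pipeline
  // stays on OpenCV (CPU).
  preprocessor.UseCuda(/*enable_cv_cuda=*/true, /*gpu_id=*/0);
  // For large inputs, keep the first Resize on CPU so the host-to-device
  // copy moves the already-downscaled image (see preprocessor.h below).
  preprocessor.InitialResizeOnCpu(true);

  std::vector<vis::FDMat> images = {vis::WrapMat(cv::imread("a.jpg")),
                                    vis::WrapMat(cv::imread("b.jpg"))};
  std::vector<fastdeploy::FDTensor> outputs;
  // Run() is assumed to wrap the mats into an FDMatBatch and call the
  // batch-aware Apply() shown in preprocessor.cc below.
  if (!preprocessor.Run(&images, &outputs)) return -1;
  // outputs[0] holds one batched tensor (batch size 2).
  return 0;
}
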
43 changes: 27 additions & 16 deletions fastdeploy/core/fd_tensor.cc
@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/core/float16.h"
#include "fastdeploy/utils/utils.h"

#include <algorithm>
#include <cstring>

#include "fastdeploy/core/float16.h"
#include "fastdeploy/utils/utils.h"
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif
@@ -142,6 +143,9 @@ void FDTensor::Resize(const std::vector<int64_t>& new_shape,
const FDDataType& data_type,
const std::string& tensor_name,
const Device& new_device) {
if (device != new_device) {
FreeFn();
}
external_data_ptr = nullptr;
name = tensor_name;
device = new_device;
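
The new device check above lets the same FDTensor be re-created on a different device without handing the old buffer to the wrong allocator: FreeFn() releases it first, and the usual (re)allocation path then runs on the new device. A small sketch of the pattern this protects, with illustrative shapes and tensor name:

fastdeploy::FDTensor t;
// First used as a CPU staging tensor.
t.Resize({1, 3, 224, 224}, fastdeploy::FDDataType::FP32, "preproc_out",
         fastdeploy::Device::CPU);
// Later reused as a GPU tensor. With the check above, the CPU buffer is
// freed before the GPU allocation, instead of being handed to the GPU
// realloc path that does not own it.
t.Resize({4, 3, 224, 224}, fastdeploy::FDDataType::FP32, "preproc_out",
         fastdeploy::Device::GPU);
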
@@ -269,9 +273,10 @@ bool FDTensor::ReallocFn(size_t nbytes) {
}
return buffer_ != nullptr;
#else
FDASSERT(false, "The FastDeploy FDTensor allocator didn't compile under "
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
FDASSERT(false,
"The FastDeploy FDTensor allocator didn't compile under "
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
#endif
} else {
if (is_pinned_memory) {
@@ -285,9 +290,10 @@ bool FDTensor::ReallocFn(size_t nbytes) {
}
return buffer_ != nullptr;
#else
FDASSERT(false, "The FastDeploy FDTensor allocator didn't compile under "
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
FDASSERT(false,
"The FastDeploy FDTensor allocator didn't compile under "
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
#endif
}
buffer_ = realloc(buffer_, nbytes);
@@ -296,8 +302,7 @@ bool FDTensor::ReallocFn(size_t nbytes) {
}

void FDTensor::FreeFn() {
if (external_data_ptr != nullptr)
external_data_ptr = nullptr;
if (external_data_ptr != nullptr) external_data_ptr = nullptr;
if (buffer_ != nullptr) {
if (device == Device::GPU) {
#ifdef WITH_GPU
@@ -381,13 +386,16 @@ FDTensor::FDTensor(const Scalar& scalar) {
(reinterpret_cast<double*>(Data()))[0] = scalar.to<double>();
break;
default:
break;
break;
}
}

FDTensor::FDTensor(const FDTensor& other)
: shape(other.shape), name(other.name), dtype(other.dtype),
device(other.device), external_data_ptr(other.external_data_ptr),
: shape(other.shape),
name(other.name),
dtype(other.dtype),
device(other.device),
external_data_ptr(other.external_data_ptr),
device_id(other.device_id) {
// Copy buffer
if (other.buffer_ == nullptr) {
@@ -401,9 +409,12 @@ FDTensor::FDTensor(const FDTensor& other)
}

FDTensor::FDTensor(FDTensor&& other)
: buffer_(other.buffer_), shape(std::move(other.shape)),
name(std::move(other.name)), dtype(other.dtype),
external_data_ptr(other.external_data_ptr), device(other.device),
: buffer_(other.buffer_),
shape(std::move(other.shape)),
name(std::move(other.name)),
dtype(other.dtype),
external_data_ptr(other.external_data_ptr),
device(other.device),
device_id(other.device_id) {
other.name = "";
// Note(zhoushunjie): Avoid double free.
32 changes: 6 additions & 26 deletions fastdeploy/vision/classification/ppcls/ppcls_pybind.cc
@@ -15,40 +15,20 @@

namespace fastdeploy {
void BindPaddleClas(pybind11::module& m) {
pybind11::class_<vision::classification::PaddleClasPreprocessor>(
m, "PaddleClasPreprocessor")
pybind11::class_<vision::classification::PaddleClasPreprocessor,
vision::ProcessorManager>(m, "PaddleClasPreprocessor")
.def(pybind11::init<std::string>())
.def("run",
[](vision::classification::PaddleClasPreprocessor& self,
std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
throw std::runtime_error(
"Failed to preprocess the input data in "
"PaddleClasPreprocessor.");
}
if (!self.CudaUsed()) {
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
}
return outputs;
})
.def("use_cuda",
[](vision::classification::PaddleClasPreprocessor& self,
bool enable_cv_cuda = false,
int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); })
.def("disable_normalize",
[](vision::classification::PaddleClasPreprocessor& self) {
self.DisableNormalize();
})
.def("disable_permute",
[](vision::classification::PaddleClasPreprocessor& self) {
self.DisablePermute();
})
.def("initial_resize_on_cpu",
[](vision::classification::PaddleClasPreprocessor& self, bool v) {
self.InitialResizeOnCpu(v);
});

pybind11::class_<vision::classification::PaddleClasPostprocessor>(
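
Passing vision::ProcessorManager as the second class_ template argument is what allows the hand-written run lambda above to be deleted: methods bound once on the base class are inherited by every subclass registered this way. A generic pybind11 sketch of the pattern; the struct and method names below are placeholders, not FastDeploy API:

#include <pybind11/pybind11.h>

struct Manager {                    // stands in for vision::ProcessorManager
  void UseCuda(bool enable_cv_cuda, int gpu_id) { /* ... */ }
};
struct ClsPreprocessor : Manager {  // stands in for PaddleClasPreprocessor
  void InitialResizeOnCpu(bool v) { /* ... */ }
};

PYBIND11_MODULE(example, m) {
  pybind11::class_<Manager>(m, "ProcessorManager")
      .def(pybind11::init<>())
      .def("use_cuda", &Manager::UseCuda);
  // Naming the base class here exposes `use_cuda` on the derived Python
  // class without re-binding it.
  pybind11::class_<ClsPreprocessor, Manager>(m, "ClsPreprocessor")
      .def(pybind11::init<>())
      .def("initial_resize_on_cpu", &ClsPreprocessor::InitialResizeOnCpu);
}
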
33 changes: 12 additions & 21 deletions fastdeploy/vision/classification/ppcls/preprocessor.cc
@@ -100,32 +100,23 @@ void PaddleClasPreprocessor::DisablePermute() {
}
}

bool PaddleClasPreprocessor::Apply(std::vector<FDMat>* images,
bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) {
for (size_t i = 0; i < images->size(); ++i) {
for (size_t j = 0; j < processors_.size(); ++j) {
bool ret = false;
ret = (*(processors_[j].get()))(&((*images)[i]));
if (!ret) {
FDERROR << "Failed to processs image:" << i << " in "
<< processors_[j]->Name() << "." << std::endl;
return false;
}
for (size_t j = 0; j < processors_.size(); ++j) {
ProcLib lib = ProcLib::DEFAULT;
if (initial_resize_on_cpu_ && j == 0 &&
processors_[j]->Name().find("Resize") == 0) {
lib = ProcLib::OPENCV;
}
if (!(*(processors_[j].get()))(image_batch, lib)) {
FDERROR << "Failed to processs image in " << processors_[j]->Name() << "."
<< std::endl;
return false;
}
}

outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
(*images)[i].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
(*outputs)[0] = std::move(*(image_batch->Tensor()));
(*outputs)[0].device_id = DeviceId();
return true;
}
13 changes: 11 additions & 2 deletions fastdeploy/vision/classification/ppcls/preprocessor.h
@@ -33,18 +33,26 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {

/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] image_batch The input image batch
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
virtual bool Apply(std::vector<FDMat>* images,
virtual bool Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs);

/// This function will disable normalize in preprocessing step.
void DisableNormalize();
/// This function will disable hwc2chw in preprocessing step.
void DisablePermute();

/** \brief When the initial operator is Resize, and input image size is large,
* maybe it's better to run resize on CPU, because the HostToDevice memcpy
* is time consuming. Set this true to run the initial resize on CPU.
*
* \param[in] v ture or false
*/
void InitialResizeOnCpu(bool v) { initial_resize_on_cpu_ = v; }

private:
bool BuildPreprocessPipelineFromConfig();
std::vector<std::shared_ptr<Processor>> processors_;
@@ -54,6 +62,7 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
bool disable_normalize_ = false;
// read config file
std::string config_file_;
bool initial_resize_on_cpu_ = false;
};

} // namespace classification
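
Other model preprocessors that adopt this interface override the same FDMatBatch-based Apply. A hedged sketch of such a subclass follows, assuming ProcessorManager::Run hands Apply an already-assembled batch and that the base class exposes DeviceId() as used in preprocessor.cc above; headers and pipeline construction are omitted.

namespace fastdeploy {
namespace vision {

class MyPreprocessor : public ProcessorManager {  // illustrative name
 public:
  bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs) override {
    // Run every processor on the whole batch; ProcLib::DEFAULT lets the
    // dispatcher in base.cc pick OpenCV, CUDA or CV-CUDA.
    for (auto& processor : processors_) {
      if (!(*processor)(image_batch, ProcLib::DEFAULT)) {
        FDERROR << "Failed in " << processor->Name() << "." << std::endl;
        return false;
      }
    }
    outputs->resize(1);
    // The batch tensor is already concatenated, so it is moved out directly.
    (*outputs)[0] = std::move(*(image_batch->Tensor()));
    (*outputs)[0].device_id = DeviceId();
    return true;
  }

 private:
  std::vector<std::shared_ptr<Processor>> processors_;
};

}  // namespace vision
}  // namespace fastdeploy
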
59 changes: 29 additions & 30 deletions fastdeploy/vision/common/processors/base.cc
@@ -20,7 +20,7 @@
namespace fastdeploy {
namespace vision {

bool Processor::operator()(Mat* mat, ProcLib lib) {
bool Processor::operator()(FDMat* mat, ProcLib lib) {
ProcLib target = lib;
if (lib == ProcLib::DEFAULT) {
target = DefaultProcLib::default_lib;
@@ -52,39 +52,38 @@ bool Processor::operator()(Mat* mat, ProcLib lib) {
return ImplByOpenCV(mat);
}

FDTensor* Processor::UpdateAndGetCachedTensor(
const std::vector<int64_t>& new_shape, const FDDataType& data_type,
const std::string& tensor_name, const Device& new_device,
const bool& use_pinned_memory) {
if (cached_tensors_.count(tensor_name) == 0) {
cached_tensors_[tensor_name] = FDTensor();
bool Processor::operator()(FDMatBatch* mat_batch, ProcLib lib) {
ProcLib target = lib;
if (lib == ProcLib::DEFAULT) {
target = DefaultProcLib::default_lib;
}
cached_tensors_[tensor_name].is_pinned_memory = use_pinned_memory;
cached_tensors_[tensor_name].Resize(new_shape, data_type, tensor_name,
new_device);
return &cached_tensors_[tensor_name];
}

FDTensor* Processor::CreateCachedGpuInputTensor(
Mat* mat, const std::string& tensor_name) {
if (target == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
return ImplByFlyCV(mat_batch);
#else
FDASSERT(false, "FastDeploy didn't compile with FlyCV.");
#endif
} else if (target == ProcLib::CUDA) {
#ifdef WITH_GPU
FDTensor* src = mat->Tensor();
if (src->device == Device::GPU) {
return src;
} else if (src->device == Device::CPU) {
FDTensor* tensor = UpdateAndGetCachedTensor(src->Shape(), src->Dtype(),
tensor_name, Device::GPU);
FDASSERT(cudaMemcpyAsync(tensor->Data(), src->Data(), tensor->Nbytes(),
cudaMemcpyHostToDevice, mat->Stream()) == 0,
"[ERROR] Error occurs while copy memory from CPU to GPU.");
return tensor;
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
}
FDASSERT(
mat_batch->Stream() != nullptr,
"CUDA processor requires cuda stream, please set stream for mat_batch");
return ImplByCuda(mat_batch);
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
} else if (target == ProcLib::CVCUDA) {
#ifdef ENABLE_CVCUDA
FDASSERT(mat_batch->Stream() != nullptr,
"CV-CUDA processor requires cuda stream, please set stream for "
"mat_batch");
return ImplByCvCuda(mat_batch);
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
FDASSERT(false, "FastDeploy didn't compile with CV-CUDA.");
#endif
return nullptr;
}
// DEFAULT & OPENCV
return ImplByOpenCV(mat_batch);
}

void EnableFlyCV() {
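
The FDMatBatch overload mirrors the single-mat dispatcher above: an explicit ProcLib overrides DefaultProcLib::default_lib, and the CUDA / CV-CUDA branches assert that a CUDA stream is attached to the batch. A call-site sketch follows; SetStream is assumed as the setter behind the Stream() accessor checked above, and construction of the concrete processor is omitted.

#ifdef WITH_GPU
#include <cuda_runtime_api.h>

// `op` can be any Processor in the pipeline, e.g. the normalize-and-permute
// processor this commit batches.
bool RunOnBatch(fastdeploy::vision::Processor* op,
                fastdeploy::vision::FDMatBatch* batch, cudaStream_t stream) {
  batch->SetStream(stream);  // assumed setter; the dispatcher only reads Stream()
  // Force CV-CUDA for this call; ProcLib::DEFAULT would fall back to
  // DefaultProcLib::default_lib instead.
  return (*op)(batch, fastdeploy::vision::ProcLib::CVCUDA);
}
#endif  // WITH_GPU
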
(The remaining 24 changed file diffs are not shown.)
