Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distinguish the numclass of pose and obb in config. h and add the Oriented Bounding Boxes (OBB) Estimation algorithm #1593

Merged
merged 24 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e8417b0
Merge pull request #1 from wang-xinyu/master
lindsayshuo Apr 30, 2024
e9e7adb
Add the generation of multi-class pose engines
May 2, 2024
e92bf89
Merge pull request #2 from lindsayshuo/lindsay
lindsayshuo May 2, 2024
6d36a3c
Merge pull request #3 from wang-xinyu/master
lindsayshuo May 7, 2024
46dd2c4
Merge pull request #4 from wang-xinyu/master
lindsayshuo May 9, 2024
e71a2ce
Merge pull request #5 from wang-xinyu/master
lindsayshuo May 14, 2024
9be6384
Change grids in forwardGpu to one-dimensional arrays
lindsayshuo May 14, 2024
3f182fe
Merge pull request #6 from lindsayshuo/shuo
lindsayshuo May 14, 2024
15e2e95
Update README.md
lindsayshuo May 14, 2024
66c6c22
Merge pull request #7 from lindsayshuo/shuo
lindsayshuo May 14, 2024
4bb7662
Merge pull request #8 from wang-xinyu/master
lindsayshuo May 15, 2024
434f3f2
Merge pull request #9 from wang-xinyu/master
lindsayshuo May 15, 2024
f5b29bf
Merge pull request #10 from wang-xinyu/master
lindsayshuo Aug 28, 2024
045d739
Update types.h
lindsayshuo Aug 28, 2024
67c4e7a
Merge pull request #11 from wang-xinyu/master
lindsayshuo Sep 9, 2024
e4bba88
yolov8_5u_det(YOLOv5u with the anchor-free, objectness-free split hea…
lindsayshuo Sep 19, 2024
84147aa
update
lindsayshuo Sep 19, 2024
817ed26
fix code style
lindsayshuo Sep 19, 2024
6794858
yolov8_5u_det model download link
lindsayshuo Sep 19, 2024
8a35c98
yolov8_5u_det model download link
lindsayshuo Sep 19, 2024
8641cbd
Merge pull request #12 from wang-xinyu/master
lindsayshuo Oct 18, 2024
46c20d9
Merge pull request #13 from wang-xinyu/master
lindsayshuo Oct 22, 2024
ad34101
Distinguish the numclass of pose and obb in config. h and add the Ori…
lindsayshuo Oct 22, 2024
1667ad4
fix code style
lindsayshuo Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions yolov8/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ else()
link_directories(/usr/local/cuda/lib64)

# tensorrt
include_directories(/home/lindsay/TensorRT-8.4.1.5/include)
link_directories(/home/lindsay/TensorRT-8.4.1.5/lib)
include_directories(/home/lindsay/TensorRT-8.6.1.6/include)
link_directories(/home/lindsay/TensorRT-6.1.6/lib)
wang-xinyu marked this conversation as resolved.
Show resolved Hide resolved
# include_directories(/home/lindsay/TensorRT-7.2.3.4/include)
# link_directories(/home/lindsay/TensorRT-7.2.3.4/lib)

Expand Down Expand Up @@ -60,3 +60,6 @@ target_link_libraries(yolov8_cls nvinfer cudart myplugins ${OpenCV_LIBS})

add_executable(yolov8_5u_det ${PROJECT_SOURCE_DIR}/yolov8_5u_det.cpp ${SRCS})
target_link_libraries(yolov8_5u_det nvinfer cudart myplugins ${OpenCV_LIBS})

add_executable(yolov8_obb ${PROJECT_SOURCE_DIR}/yolov8_obb.cpp ${SRCS})
target_link_libraries(yolov8_obb nvinfer cudart myplugins ${OpenCV_LIBS})
27 changes: 25 additions & 2 deletions yolov8/README.md
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ wget -O coco.txt https://raw.githubusercontent.com/amikelive/coco-labels/master/
```
cd {tensorrtx}/yolov8/
// Download inference images
wget https://github.com/lindsayshuo/infer_pic/blob/main/1709970363.6990473rescls.jpg
wget https://github.com/lindsayshuo/infer_pic/releases/download/pics/1709970363.6990473rescls.jpg
mkdir samples
cp -r 1709970363.6990473rescls.jpg samples
// Download ImageNet labels
Expand All @@ -130,7 +130,7 @@ sudo ./yolov8_cls -d yolov8n-cls.engine ../samples
### Pose Estimation
```
cd {tensorrtx}/yolov8/
// update "kNumClass = 1" in config.h
// update "kPOseNumClass = 1" in config.h
wang-xinyu marked this conversation as resolved.
Show resolved Hide resolved
mkdir build
cd build
cp {ultralytics}/ultralytics/yolov8-pose.wts {tensorrtx}/yolov8/build
Expand All @@ -146,6 +146,28 @@ sudo ./yolov8_pose -d yolov8n-pose.engine ../images g //gpu postprocess
```


### Oriented Bounding Boxes (OBB) Estimation
```
cd {tensorrtx}/yolov8/
// update "kObbNumClass = 15" "kInputH = 1024" "kInputW = 1024" in config.h
wget https://github.com/lindsayshuo/infer_pic/releases/download/pics/obb.png
mkdir images
mv obb.png ./images
mkdir build
cd build
cp {ultralytics}/ultralytics/yolov8-obb.wts {tensorrtx}/yolov8/build
cmake ..
make
sudo ./yolov8_obb -s [.wts] [.engine] [n/s/m/l/x/n2/s2/m2/l2/x2/n6/s6/m6/l6/x6] // serialize model to plan file
sudo ./yolov8_obb -d [.engine] [image folder] [c/g] // deserialize and run inference, the images in [image folder] will be processed.

// For example yolov8-obb
sudo ./yolov8_obb -s yolov8n-obb.wts yolov8n-obb.engine n
sudo ./yolov8_obb -d yolov8n-obb.engine ../images c //cpu postprocess
sudo ./yolov8_obb -d yolov8n-obb.engine ../images g //gpu postprocess
```


4. optional, load and run the tensorrt model in python

```
Expand All @@ -156,6 +178,7 @@ python yolov8_seg_trt.py # Segmentation
python yolov8_cls_trt.py # Classification
python yolov8_pose_trt.py # Pose Estimation
python yolov8_5u_det_trt.py # yolov8_5u_det(YOLOv5u with the anchor-free, objectness-free split head structure based on YOLOv8 features) model
python yolov8_obb_trt.py # Oriented Bounding Boxes (OBB) Estimation
```

# INT8 Quantization
Expand Down
4 changes: 2 additions & 2 deletions yolov8/gen_wts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def parse_args():
parser.add_argument(
'-o', '--output', help='Output (.wts) file path (optional)')
parser.add_argument(
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'],
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose', 'obb'],
help='determines the model is detection/classification')
args = parser.parse_args()
if not os.path.isfile(args.weights):
Expand All @@ -39,7 +39,7 @@ def parse_args():
# Load model
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32

if m_type in ['detect', 'seg', 'pose']:
if m_type in ['detect', 'seg', 'pose', 'obb']:
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]

delattr(model.model[-1], 'anchors')
Expand Down
2 changes: 1 addition & 1 deletion yolov8/include/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std

nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network,
std::vector<nvinfer1::IConcatenationLayer*> dets, const int* px_arry,
int px_arry_num, bool is_segmentation, bool is_pose);
int px_arry_num, int NumClass, bool is_segmentation, bool is_pose, bool is_obb);
wang-xinyu marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 7 additions & 1 deletion yolov8/include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
const static char* kInputTensorName = "images";
const static char* kOutputTensorName = "output";
const static int kNumClass = 80;
const static int kNumberOfPoints = 17; // number of keypoints total
const static int kBatchSize = 1;
const static int kGpuId = 0;
const static int kInputH = 640;
Expand All @@ -23,3 +22,10 @@ constexpr static int kClsNumClass = 1000;
// Classfication model's input shape
constexpr static int kClsInputH = 224;
constexpr static int kClsInputW = 224;

// pose model's number of classes
constexpr static int kPOseNumClass = 1;
const static int kNumberOfPoints = 17; // number of keypoints total

// obb model's number of classes
constexpr static int kObbNumClass = 15;
4 changes: 4 additions & 0 deletions yolov8/include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ nvinfer1::IHostMemory* buildEngineYolov8_5uDet(nvinfer1::IBuilder* builder, nvin
nvinfer1::IHostMemory* buildEngineYolov8_5uDetP6(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
nvinfer1::DataType dt, const std::string& wts_path, float& gd,
float& gw, int& max_channels);

nvinfer1::IHostMemory* buildEngineYolov8Obb(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
int& max_channels);
33 changes: 22 additions & 11 deletions yolov8/include/postprocess.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,38 @@
#include "NvInfer.h"
#include "types.h"

// Preprocessing functions
cv::Rect get_rect(cv::Mat& img, float bbox[4]);

void nms(std::vector<Detection>& res, float* output, float conf_thresh, float nms_thresh = 0.5);

void batch_nms(std::vector<std::vector<Detection>>& batch_res, float* output, int batch_size, int output_size,
float conf_thresh, float nms_thresh = 0.5);

void draw_bbox(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);

void draw_bbox_keypoints_line(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);

// Processing functions
void batch_process(std::vector<std::vector<Detection>>& res_batch, const float* decode_ptr_host, int batch_size,
int bbox_element, const std::vector<cv::Mat>& img_batch);

void batch_process_obb(std::vector<std::vector<Detection>>& res_batch, const float* decode_ptr_host, int batch_size,
int bbox_element, const std::vector<cv::Mat>& img_batch);
void process_decode_ptr_host(std::vector<Detection>& res, const float* decode_ptr_host, int bbox_element, cv::Mat& img,
int count);
void process_decode_ptr_host_obb(std::vector<Detection>& res, const float* decode_ptr_host, int bbox_element,
cv::Mat& img, int count);

// NMS functions
void nms(std::vector<Detection>& res, float* output, float conf_thresh, float nms_thresh = 0.5);
void batch_nms(std::vector<std::vector<Detection>>& batch_res, float* output, int batch_size, int output_size,
float conf_thresh, float nms_thresh = 0.5);
void nms_obb(std::vector<Detection>& res, float* output, float conf_thresh, float nms_thresh = 0.5);
void batch_nms_obb(std::vector<std::vector<Detection>>& batch_res, float* output, int batch_size, int output_size,
float conf_thresh, float nms_thresh = 0.5);

// CUDA-related functions
void cuda_decode(float* predict, int num_bboxes, float confidence_threshold, float* parray, int max_objects,
cudaStream_t stream);

void cuda_nms(float* parray, float nms_threshold, int max_objects, cudaStream_t stream);
void cuda_decode_obb(float* predict, int num_bboxes, float confidence_threshold, float* parray, int max_objects,
cudaStream_t stream);
void cuda_nms_obb(float* parray, float nms_threshold, int max_objects, cudaStream_t stream);

// Drawing functions
void draw_bbox(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);
void draw_bbox_obb(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);
void draw_bbox_keypoints_line(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);
void draw_mask_bbox(cv::Mat& img, std::vector<Detection>& dets, std::vector<cv::Mat>& masks,
std::unordered_map<int, std::string>& labels_map);
1 change: 1 addition & 0 deletions yolov8/include/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ struct alignas(float) Detection {
float class_id;
float mask[32];
float keypoints[kNumberOfPoints * 3]; // keypoints array with dynamic size based on kNumberOfPoints
float angle; // obb angle
};

struct AffineMatrix {
Expand Down
66 changes: 45 additions & 21 deletions yolov8/plugin/yololayer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ __device__ float sigmoid(float x) {

namespace nvinfer1 {
YoloLayerPlugin::YoloLayerPlugin(int classCount, int numberofpoints, float confthreshkeypoints, int netWidth,
int netHeight, int maxOut, bool is_segmentation, bool is_pose, const int* strides,
int stridesLength) {
int netHeight, int maxOut, bool is_segmentation, bool is_pose, bool is_obb,
const int* strides, int stridesLength) {

mClassCount = classCount;
mNumberofpoints = numberofpoints;
Expand All @@ -40,6 +40,7 @@ YoloLayerPlugin::YoloLayerPlugin(int classCount, int numberofpoints, float conft
memcpy(mStrides, strides, stridesLength * sizeof(int));
is_segmentation_ = is_segmentation;
is_pose_ = is_pose;
is_obb_ = is_obb;
}

YoloLayerPlugin::~YoloLayerPlugin() {
Expand All @@ -66,6 +67,7 @@ YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) {
}
read(d, is_segmentation_);
read(d, is_pose_);
read(d, is_obb_);

assert(d == a + length);
}
Expand All @@ -87,14 +89,15 @@ void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
}
write(d, is_segmentation_);
write(d, is_pose_);
write(d, is_obb_);

assert(d == a + getSerializationSize());
}

size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT {
return sizeof(mClassCount) + sizeof(mNumberofpoints) + sizeof(mConfthreshkeypoints) + sizeof(mThreadCount) +
sizeof(mYoloV8netHeight) + sizeof(mYoloV8NetWidth) + sizeof(mMaxOutObject) + sizeof(mStridesLength) +
sizeof(int) * mStridesLength + sizeof(is_segmentation_) + sizeof(is_pose_);
sizeof(int) * mStridesLength + sizeof(is_segmentation_) + sizeof(is_pose_) + sizeof(is_obb_);
}

int YoloLayerPlugin::initialize() TRT_NOEXCEPT {
Expand Down Expand Up @@ -156,7 +159,7 @@ nvinfer1::IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT {

YoloLayerPlugin* p =
new YoloLayerPlugin(mClassCount, mNumberofpoints, mConfthreshkeypoints, mYoloV8NetWidth, mYoloV8netHeight,
mMaxOutObject, is_segmentation_, is_pose_, mStrides, mStridesLength);
mMaxOutObject, is_segmentation_, is_pose_, is_obb_, mStrides, mStridesLength);
p->setPluginNamespace(mPluginNamespace);
return p;
}
Expand All @@ -174,14 +177,14 @@ __device__ float Logist(float data) {

__global__ void CalDetection(const float* input, float* output, int numElements, int maxoutobject, const int grid_h,
int grid_w, const int stride, int classes, int nk, float confkeypoints, int outputElem,
bool is_segmentation, bool is_pose) {
bool is_segmentation, bool is_pose, bool is_obb) {
int idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx >= numElements)
return;

const int N_kpts = nk;
int total_grid = grid_h * grid_w;
int info_len = 4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0);
int info_len = 4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0) + (is_obb ? 1 : 0);
int batchIdx = idx / total_grid;
int elemIdx = idx % total_grid;
const float* curInput = input + batchIdx * total_grid * info_len;
Expand Down Expand Up @@ -218,15 +221,16 @@ __global__ void CalDetection(const float* input, float* output, int numElements,

if (is_segmentation) {
for (int k = 0; k < 32; ++k) {
det->mask[k] = curInput[elemIdx + (4 + classes + k) * total_grid];
det->mask[k] =
curInput[elemIdx + (4 + classes + (is_pose ? N_kpts * 3 : 0) + (is_obb ? 1 : 0) + k) * total_grid];
}
}

if (is_pose) {
for (int kpt = 0; kpt < N_kpts; kpt++) {
int kpt_x_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3) * total_grid;
int kpt_y_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3 + 1) * total_grid;
int kpt_conf_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3 + 2) * total_grid;
int kpt_x_idx = (4 + classes + (is_segmentation ? 32 : 0) + (is_obb ? 1 : 0) + kpt * 3) * total_grid;
int kpt_y_idx = (4 + classes + (is_segmentation ? 32 : 0) + (is_obb ? 1 : 0) + kpt * 3 + 1) * total_grid;
int kpt_conf_idx = (4 + classes + (is_segmentation ? 32 : 0) + (is_obb ? 1 : 0) + kpt * 3 + 2) * total_grid;

float kpt_confidence = sigmoid(curInput[elemIdx + kpt_conf_idx]);

Expand All @@ -247,24 +251,43 @@ __global__ void CalDetection(const float* input, float* output, int numElements,
}
}
}

if (is_obb) {
double pi = 3.141592653589793f;
wang-xinyu marked this conversation as resolved.
Show resolved Hide resolved
auto angle_inx = curInput[elemIdx + (4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0) +
0) * total_grid];
auto angle = (sigmoid(angle_inx) - 0.25f) * pi;

auto cos1 = cos(angle);
auto sin1 = sin(angle);
auto xf = (curInput[elemIdx + 2 * total_grid] - curInput[elemIdx + 0 * total_grid]) / 2;
auto yf = (curInput[elemIdx + 3 * total_grid] - curInput[elemIdx + 1 * total_grid]) / 2;

auto x = xf * cos1 - yf * sin1;
auto y = xf * sin1 + yf * cos1;

float cx = (col + 0.5f + x) * stride;
float cy = (row + 0.5f + y) * stride;

float w1 = (curInput[elemIdx + 0 * total_grid] + curInput[elemIdx + 2 * total_grid]) * stride;
float h1 = (curInput[elemIdx + 1 * total_grid] + curInput[elemIdx + 3 * total_grid]) * stride;
det->bbox[0] = cx;
det->bbox[1] = cy;
det->bbox[2] = w1;
det->bbox[3] = h1;
det->angle = angle;
}
}

void YoloLayerPlugin::forwardGpu(const float* const* inputs, float* output, cudaStream_t stream, int mYoloV8netHeight,
int mYoloV8NetWidth, int batchSize) {

int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);
cudaMemsetAsync(output, 0, sizeof(float), stream);
for (int idx = 0; idx < batchSize; ++idx) {
CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));
}
int numElem = 0;

// const int maxGrids = mStridesLength;
// int grids[maxGrids][2];
// for (int i = 0; i < maxGrids; ++i) {
// grids[i][0] = mYoloV8netHeight / mStrides[i];
// grids[i][1] = mYoloV8NetWidth / mStrides[i];
// }

int maxGrids = mStridesLength;
int flatGridsLen = 2 * maxGrids;
int* flatGrids = new int[flatGridsLen];
Expand All @@ -286,7 +309,7 @@ void YoloLayerPlugin::forwardGpu(const float* const* inputs, float* output, cuda
// The CUDA kernel call remains unchanged
CalDetection<<<(numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>(
inputs[i], output, numElem, mMaxOutObject, grid_h, grid_w, stride, mClassCount, mNumberofpoints,
mConfthreshkeypoints, outputElem, is_segmentation_, is_pose_);
mConfthreshkeypoints, outputElem, is_segmentation_, is_pose_, is_obb_);
}

delete[] flatGrids;
Expand Down Expand Up @@ -317,7 +340,7 @@ IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFi
assert(fc->nbFields == 1);
assert(strcmp(fc->fields[0].name, "combinedInfo") == 0);
const int* combinedInfo = static_cast<const int*>(fc->fields[0].data);
int netinfo_count = 8;
int netinfo_count = 9;
int class_count = combinedInfo[0];
int numberofpoints = combinedInfo[1];
float confthreshkeypoints = combinedInfo[2];
Expand All @@ -326,11 +349,12 @@ IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFi
int max_output_object_count = combinedInfo[5];
bool is_segmentation = combinedInfo[6];
bool is_pose = combinedInfo[7];
bool is_obb = combinedInfo[8];
const int* px_arry = combinedInfo + netinfo_count;
int px_arry_length = fc->fields[0].length - netinfo_count;
YoloLayerPlugin* obj =
new YoloLayerPlugin(class_count, numberofpoints, confthreshkeypoints, input_w, input_h,
max_output_object_count, is_segmentation, is_pose, px_arry, px_arry_length);
max_output_object_count, is_segmentation, is_pose, is_obb, px_arry, px_arry_length);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
Expand Down
3 changes: 2 additions & 1 deletion yolov8/plugin/yololayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace nvinfer1 {
class API YoloLayerPlugin : public IPluginV2IOExt {
public:
YoloLayerPlugin(int classCount, int numberofpoints, float confthreshkeypoints, int netWidth, int netHeight,
int maxOut, bool is_segmentation, bool is_pose, const int* strides, int stridesLength);
int maxOut, bool is_segmentation, bool is_pose, bool is_obb, const int* strides, int stridesLength);

YoloLayerPlugin(const void* data, size_t length);
~YoloLayerPlugin();
Expand Down Expand Up @@ -75,6 +75,7 @@ class API YoloLayerPlugin : public IPluginV2IOExt {
int mMaxOutObject;
bool is_segmentation_;
bool is_pose_;
bool is_obb_;
int* mStrides;
int mStridesLength;
};
Expand Down
7 changes: 4 additions & 3 deletions yolov8/src/block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,21 +258,22 @@ nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std

nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network,
std::vector<nvinfer1::IConcatenationLayer*> dets, const int* px_arry,
int px_arry_num, bool is_segmentation, bool is_pose) {
int px_arry_num, int NumClass, bool is_segmentation, bool is_pose, bool is_obb) {
auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");
const int netinfo_count = 8; // Assuming the first 5 elements are for netinfo as per existing code.
const int netinfo_count = 9; // Assuming the first 5 elements are for netinfo as per existing code.
const int total_count = netinfo_count + px_arry_num; // Total number of elements for netinfo and px_arry combined.

std::vector<int> combinedInfo(total_count);
// Fill in the first 5 elements as per existing netinfo.
combinedInfo[0] = kNumClass;
combinedInfo[0] = NumClass;
combinedInfo[1] = kNumberOfPoints;
combinedInfo[2] = kConfThreshKeypoints;
combinedInfo[3] = kInputW;
combinedInfo[4] = kInputH;
combinedInfo[5] = kMaxNumOutputBbox;
combinedInfo[6] = is_segmentation;
combinedInfo[7] = is_pose;
combinedInfo[8] = is_obb;

// Copy the contents of px_arry into the combinedInfo vector after the initial
// 5 elements.
Expand Down
Loading
Loading