diff --git a/yolov8/README.md b/yolov8/README.md
index e4762e2f..0ceb1b5e 100644
--- a/yolov8/README.md
+++ b/yolov8/README.md
@@ -129,9 +129,10 @@ sudo ./yolov8_pose -d yolov8n-pose.engine ../images g //gpu postprocess
 ```
 // install python-tensorrt, pycuda, etc.
 // ensure the yolov8n.engine and libmyplugins.so have been built
-python yolov8_det.py  # Detection
-python yolov8_seg.py  # Segmentation
-python yolov8_cls.py  # Classification
+python yolov8_det_trt.py  # Detection
+python yolov8_seg_trt.py  # Segmentation
+python yolov8_cls_trt.py  # Classification
+python yolov8_pose_trt.py  # Pose Estimation
 ```
 
 # INT8 Quantization
diff --git a/yolov8/include/model.h b/yolov8/include/model.h
index 6546aa54..82586da1 100644
--- a/yolov8/include/model.h
+++ b/yolov8/include/model.h
@@ -25,3 +25,7 @@ nvinfer1::IHostMemory* buildEngineYolov8Seg(nvinfer1::IBuilder* builder, nvinfer
 nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
                                              nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
                                              int& max_channels);
+
+nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
+                                               nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
+                                               int& max_channels);
diff --git a/yolov8/src/model.cpp b/yolov8/src/model.cpp
index 4cc4088e..a5f7e8e5 100644
--- a/yolov8/src/model.cpp
+++ b/yolov8/src/model.cpp
@@ -1448,9 +1448,6 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe
                       1, 0, "model.22.dfl.conv.weight");
 
     // det0
-    std::cout << "conv15->getOutput(0)->getDimensions().d[0] : " << conv15->getOutput(0)->getDimensions().d[0]
-              << " (kInputH / strides[0]) * (kInputW / strides[0]) : "
-              << (kInputH / strides[0]) * (kInputW / strides[0]) << std::endl;
     auto shuffle_conv15 = cv4_conv_combined(network, weightMap, *conv15->getOutput(0), "model.22.cv4.0",
                                             (kInputH / strides[0]) * (kInputW / strides[0]), gw, "pose");
@@ -1530,3 +1527,333 @@ nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfe
     }
     return serialized_model;
 }
+
+nvinfer1::IHostMemory* buildEngineYolov8PoseP6(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
+                                               nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
+                                               int& max_channels) {
+    std::map<std::string, nvinfer1::Weights> weightMap = loadWeights(wts_path);
+    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U);
+    /*******************************************************************************************************
+    ******************************************  YOLOV8 INPUT  **********************************************
+    *******************************************************************************************************/
+    nvinfer1::ITensor* data = network->addInput(kInputTensorName, dt, nvinfer1::Dims3{3, kInputH, kInputW});
+    assert(data);
+    /*******************************************************************************************************
+    *****************************************  YOLOV8 BACKBONE ********************************************
+    *******************************************************************************************************/
+    nvinfer1::IElementWiseLayer* conv0 =
+            convBnSiLU(network, weightMap, *data, get_width(64, gw, max_channels), 3, 2, 1, "model.0");
+    nvinfer1::IElementWiseLayer* conv1 =
+            convBnSiLU(network, weightMap, *conv0->getOutput(0), get_width(128, gw, max_channels), 3, 2, 1, "model.1");
+    // 11233
+    nvinfer1::IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), get_width(128, gw, max_channels),
+                                             get_width(128, gw, max_channels), get_depth(3, gd), true, 0.5, "model.2");
+    nvinfer1::IElementWiseLayer* conv3 =
+            convBnSiLU(network, weightMap, *conv2->getOutput(0), get_width(256, gw, max_channels), 3, 2, 1, "model.3");
+    // 22466
+    nvinfer1::IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), get_width(256, gw, max_channels),
+                                             get_width(256, gw, max_channels), get_depth(6, gd), true, 0.5, "model.4");
+    nvinfer1::IElementWiseLayer* conv5 =
+            convBnSiLU(network, weightMap, *conv4->getOutput(0), get_width(512, gw, max_channels), 3, 2, 1, "model.5");
+    // 22466
+    nvinfer1::IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), get_width(512, gw, max_channels),
+                                             get_width(512, gw, max_channels), get_depth(6, gd), true, 0.5, "model.6");
+
+    nvinfer1::IElementWiseLayer* conv7 =
+            convBnSiLU(network, weightMap, *conv6->getOutput(0), get_width(768, gw, max_channels), 3, 2, 1, "model.7");
+    nvinfer1::IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), get_width(768, gw, max_channels),
+                                             get_width(768, gw, max_channels), get_depth(3, gd), true, 0.5, "model.8");
+
+    nvinfer1::IElementWiseLayer* conv9 =
+            convBnSiLU(network, weightMap, *conv8->getOutput(0), get_width(1024, gw, max_channels), 3, 2, 1, "model.9");
+    nvinfer1::IElementWiseLayer* conv10 =
+            C2F(network, weightMap, *conv9->getOutput(0), get_width(1024, gw, max_channels),
+                get_width(1024, gw, max_channels), get_depth(3, gd), true, 0.5, "model.10");
+
+    nvinfer1::IElementWiseLayer* conv11 =
+            SPPF(network, weightMap, *conv10->getOutput(0), get_width(1024, gw, max_channels),
+                 get_width(1024, gw, max_channels), 5, "model.11");
+
+    /*******************************************************************************************************
+    *********************************************  YOLOV8 HEAD ********************************************
+    *******************************************************************************************************/
+    // Head
+    float scale[] = {1.0, 2.0, 2.0};  // scale used for upsampling
+
+    // P5
+    nvinfer1::IResizeLayer* upsample12 = network->addResize(*conv11->getOutput(0));
+    upsample12->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
+    upsample12->setScales(scale, 3);
+    nvinfer1::ITensor* concat13_inputs[] = {upsample12->getOutput(0), conv8->getOutput(0)};
+    nvinfer1::IConcatenationLayer* concat13 = network->addConcatenation(concat13_inputs, 2);
+    nvinfer1::IElementWiseLayer* conv14 =
+            C2(network, weightMap, *concat13->getOutput(0), get_width(768, gw, max_channels),
+               get_width(768, gw, max_channels), get_depth(3, gd), false, 0.5, "model.14");
+
+    // P4
+    nvinfer1::IResizeLayer* upsample15 = network->addResize(*conv14->getOutput(0));
+    upsample15->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
+    upsample15->setScales(scale, 3);
+    nvinfer1::ITensor* concat16_inputs[] = {upsample15->getOutput(0), conv6->getOutput(0)};
+    nvinfer1::IConcatenationLayer* concat16 = network->addConcatenation(concat16_inputs, 2);
+    nvinfer1::IElementWiseLayer* conv17 =
+            C2(network, weightMap, *concat16->getOutput(0), get_width(512, gw, max_channels),
+               get_width(512, gw, max_channels), get_depth(3, gd), false, 0.5, "model.17");
+
+    // P3
+    nvinfer1::IResizeLayer* upsample18 = network->addResize(*conv17->getOutput(0));
+    upsample18->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
+    upsample18->setScales(scale, 3);
+    nvinfer1::ITensor* concat19_inputs[] = {upsample18->getOutput(0), conv4->getOutput(0)};
+    nvinfer1::IConcatenationLayer* concat19 = network->addConcatenation(concat19_inputs, 2);
+    nvinfer1::IElementWiseLayer* conv20 =
+            C2(network, weightMap, *concat19->getOutput(0), get_width(256, gw, max_channels),
+               get_width(256, gw, max_channels), get_depth(3, gd), false, 0.5, "model.20");
+
+    // Additional layers for P4, P5, P6
+    // P4/16-medium
+    nvinfer1::IElementWiseLayer* conv21 = convBnSiLU(network, weightMap, *conv20->getOutput(0),
+                                                     get_width(256, gw, max_channels), 3, 2, 1, "model.21");
+    nvinfer1::ITensor* concat22_inputs[] = {conv21->getOutput(0), conv17->getOutput(0)};
+    nvinfer1::IConcatenationLayer* concat22 = network->addConcatenation(concat22_inputs, 2);
+    nvinfer1::IElementWiseLayer* conv23 =
+            C2(network, weightMap, *concat22->getOutput(0), get_width(512, gw, max_channels),
+               get_width(512, gw, max_channels), get_depth(3, gd), false, 0.5, "model.23");
+
+    // P5/32-large
+    nvinfer1::IElementWiseLayer* conv24 = convBnSiLU(network, weightMap, *conv23->getOutput(0),
+                                                     get_width(512, gw, max_channels), 3, 2, 1, "model.24");
+    nvinfer1::ITensor* concat25_inputs[] = {conv24->getOutput(0), conv14->getOutput(0)};
+    nvinfer1::IConcatenationLayer* concat25 = network->addConcatenation(concat25_inputs, 2);
+    nvinfer1::IElementWiseLayer* conv26 =
+            C2(network, weightMap, *concat25->getOutput(0), get_width(768, gw, max_channels),
+               get_width(768, gw, max_channels), get_depth(3, gd), false, 0.5, "model.26");
+
+    // P6/64-xlarge
+    nvinfer1::IElementWiseLayer* conv27 = convBnSiLU(network, weightMap, *conv26->getOutput(0),
+                                                     get_width(768, gw, max_channels), 3, 2, 1, "model.27");
+    nvinfer1::ITensor* concat28_inputs[] = {conv27->getOutput(0), conv11->getOutput(0)};
+    nvinfer1::IConcatenationLayer* concat28 = network->addConcatenation(concat28_inputs, 2);
+    nvinfer1::IElementWiseLayer* conv29 =
+            C2(network, weightMap, *concat28->getOutput(0), get_width(1024, gw, max_channels),
+               get_width(1024, gw, max_channels), get_depth(3, gd), false, 0.5, "model.29");
+
+    /*******************************************************************************************************
+    *********************************************  YOLOV8 OUTPUT ******************************************
+    *******************************************************************************************************/
+    int base_in_channel = (gw == 1.25) ? 80 : 64;
+    int base_out_channel = (gw == 0.25) ? std::max(64, std::min(kNumClass, 100)) : get_width(256, gw, max_channels);
+
+    // output0
+    nvinfer1::IElementWiseLayer* conv30_cv2_0_0 =
+            convBnSiLU(network, weightMap, *conv20->getOutput(0), base_in_channel, 3, 1, 1, "model.30.cv2.0.0");
+    nvinfer1::IElementWiseLayer* conv30_cv2_0_1 =
+            convBnSiLU(network, weightMap, *conv30_cv2_0_0->getOutput(0), base_in_channel, 3, 1, 1, "model.30.cv2.0.1");
+    nvinfer1::IConvolutionLayer* conv30_cv2_0_2 =
+            network->addConvolutionNd(*conv30_cv2_0_1->getOutput(0), 64, nvinfer1::DimsHW{1, 1},
+                                      weightMap["model.30.cv2.0.2.weight"], weightMap["model.30.cv2.0.2.bias"]);
+    conv30_cv2_0_2->setStrideNd(nvinfer1::DimsHW{1, 1});
+    conv30_cv2_0_2->setPaddingNd(nvinfer1::DimsHW{0, 0});
+    nvinfer1::IElementWiseLayer* conv30_cv3_0_0 =
+            convBnSiLU(network, weightMap, *conv20->getOutput(0), base_out_channel, 3, 1, 1, "model.30.cv3.0.0");
+    nvinfer1::IElementWiseLayer* conv30_cv3_0_1 = convBnSiLU(network, weightMap, *conv30_cv3_0_0->getOutput(0),
+                                                             base_out_channel, 3, 1, 1, "model.30.cv3.0.1");
+    nvinfer1::IConvolutionLayer* conv30_cv3_0_2 =
+            network->addConvolutionNd(*conv30_cv3_0_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1},
+                                      weightMap["model.30.cv3.0.2.weight"], weightMap["model.30.cv3.0.2.bias"]);
+    conv30_cv3_0_2->setStrideNd(nvinfer1::DimsHW{1, 1});
+    conv30_cv3_0_2->setPaddingNd(nvinfer1::DimsHW{0, 0});
+    nvinfer1::ITensor* inputTensor30_0[] = {conv30_cv2_0_2->getOutput(0), conv30_cv3_0_2->getOutput(0)};
+    nvinfer1::IConcatenationLayer* cat30_0 = network->addConcatenation(inputTensor30_0, 2);
+
+    // output1
+    nvinfer1::IElementWiseLayer* conv30_cv2_1_0 =
+            convBnSiLU(network, weightMap, *conv23->getOutput(0), base_in_channel, 3, 1, 1, "model.30.cv2.1.0");
+    nvinfer1::IElementWiseLayer* conv30_cv2_1_1 =
+            convBnSiLU(network, weightMap, *conv30_cv2_1_0->getOutput(0), base_in_channel, 3, 1, 1, "model.30.cv2.1.1");
+    nvinfer1::IConvolutionLayer* conv30_cv2_1_2 =
+            network->addConvolutionNd(*conv30_cv2_1_1->getOutput(0), 64, nvinfer1::DimsHW{1, 1},
+                                      weightMap["model.30.cv2.1.2.weight"], weightMap["model.30.cv2.1.2.bias"]);
+    conv30_cv2_1_2->setStrideNd(nvinfer1::DimsHW{1, 1});
+    conv30_cv2_1_2->setPaddingNd(nvinfer1::DimsHW{0, 0});
+    nvinfer1::IElementWiseLayer* conv30_cv3_1_0 =
+            convBnSiLU(network, weightMap, *conv23->getOutput(0), base_out_channel, 3, 1, 1, "model.30.cv3.1.0");
+    nvinfer1::IElementWiseLayer* conv30_cv3_1_1 = convBnSiLU(network, weightMap, *conv30_cv3_1_0->getOutput(0),
+                                                             base_out_channel, 3, 1, 1, "model.30.cv3.1.1");
+    nvinfer1::IConvolutionLayer* conv30_cv3_1_2 =
+            network->addConvolutionNd(*conv30_cv3_1_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1},
+                                      weightMap["model.30.cv3.1.2.weight"], weightMap["model.30.cv3.1.2.bias"]);
+    conv30_cv3_1_2->setStrideNd(nvinfer1::DimsHW{1, 1});
+    conv30_cv3_1_2->setPaddingNd(nvinfer1::DimsHW{0, 0});
+    nvinfer1::ITensor* inputTensor30_1[] = {conv30_cv2_1_2->getOutput(0), conv30_cv3_1_2->getOutput(0)};
+    nvinfer1::IConcatenationLayer* cat30_1 = network->addConcatenation(inputTensor30_1, 2);
+
+    // output2
+    nvinfer1::IElementWiseLayer* conv30_cv2_2_0 =
+            convBnSiLU(network, weightMap, *conv26->getOutput(0), base_in_channel, 3, 1, 1, "model.30.cv2.2.0");
+    nvinfer1::IElementWiseLayer* conv30_cv2_2_1 =
+            convBnSiLU(network, weightMap, *conv30_cv2_2_0->getOutput(0), base_in_channel, 3, 1, 1, "model.30.cv2.2.1");
+    nvinfer1::IConvolutionLayer* conv30_cv2_2_2 =
+            network->addConvolutionNd(*conv30_cv2_2_1->getOutput(0), 64, nvinfer1::DimsHW{1, 1},
+                                      weightMap["model.30.cv2.2.2.weight"], weightMap["model.30.cv2.2.2.bias"]);
+    conv30_cv2_2_2->setStrideNd(nvinfer1::DimsHW{1, 1});
+    conv30_cv2_2_2->setPaddingNd(nvinfer1::DimsHW{0, 0});
+    nvinfer1::IElementWiseLayer* conv30_cv3_2_0 =
+            convBnSiLU(network, weightMap, *conv26->getOutput(0), base_out_channel, 3, 1, 1, "model.30.cv3.2.0");
+    nvinfer1::IElementWiseLayer* conv30_cv3_2_1 = convBnSiLU(network, weightMap, *conv30_cv3_2_0->getOutput(0),
+                                                             base_out_channel, 3, 1, 1, "model.30.cv3.2.1");
+    nvinfer1::IConvolutionLayer* conv30_cv3_2_2 =
+            network->addConvolutionNd(*conv30_cv3_2_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1},
+                                      weightMap["model.30.cv3.2.2.weight"], weightMap["model.30.cv3.2.2.bias"]);
+    conv30_cv3_2_2->setStrideNd(nvinfer1::DimsHW{1, 1});
+    conv30_cv3_2_2->setPaddingNd(nvinfer1::DimsHW{0, 0});
+    nvinfer1::ITensor* inputTensor30_2[] = {conv30_cv2_2_2->getOutput(0), conv30_cv3_2_2->getOutput(0)};
+    nvinfer1::IConcatenationLayer* cat30_2 = network->addConcatenation(inputTensor30_2, 2);
+
+    // output3
+    nvinfer1::IElementWiseLayer* conv30_cv2_3_0 =
+            convBnSiLU(network, weightMap, *conv29->getOutput(0), base_in_channel, 3, 1, 1, "model.30.cv2.3.0");
+    nvinfer1::IElementWiseLayer* conv30_cv2_3_1 =
+            convBnSiLU(network, weightMap, *conv30_cv2_3_0->getOutput(0), base_in_channel, 3, 1, 1, "model.30.cv2.3.1");
+    nvinfer1::IConvolutionLayer* conv30_cv2_3_2 =
+            network->addConvolutionNd(*conv30_cv2_3_1->getOutput(0), 64, nvinfer1::DimsHW{1, 1},
+                                      weightMap["model.30.cv2.3.2.weight"], weightMap["model.30.cv2.3.2.bias"]);
+    conv30_cv2_3_2->setStrideNd(nvinfer1::DimsHW{1, 1});
+    conv30_cv2_3_2->setPaddingNd(nvinfer1::DimsHW{0, 0});
+    nvinfer1::IElementWiseLayer* conv30_cv3_3_0 =
+            convBnSiLU(network, weightMap, *conv29->getOutput(0), base_out_channel, 3, 1, 1, "model.30.cv3.3.0");
+    nvinfer1::IElementWiseLayer* conv30_cv3_3_1 = convBnSiLU(network, weightMap, *conv30_cv3_3_0->getOutput(0),
+                                                             base_out_channel, 3, 1, 1, "model.30.cv3.3.1");
+    nvinfer1::IConvolutionLayer* conv30_cv3_3_2 =
+            network->addConvolutionNd(*conv30_cv3_3_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1},
+                                      weightMap["model.30.cv3.3.2.weight"], weightMap["model.30.cv3.3.2.bias"]);
+    conv30_cv3_3_2->setStrideNd(nvinfer1::DimsHW{1, 1});
+    conv30_cv3_3_2->setPaddingNd(nvinfer1::DimsHW{0, 0});
+    nvinfer1::ITensor* inputTensor30_3[] = {conv30_cv2_3_2->getOutput(0), conv30_cv3_3_2->getOutput(0)};
+    nvinfer1::IConcatenationLayer* cat30_3 = network->addConcatenation(inputTensor30_3, 2);
+
+    /*******************************************************************************************************
+    *********************************************  YOLOV8 DETECT ******************************************
+    *******************************************************************************************************/
+    nvinfer1::IElementWiseLayer* conv_layers[] = {conv3, conv5, conv7, conv9};
+    int strides[sizeof(conv_layers) / sizeof(conv_layers[0])];
+    calculateStrides(conv_layers, sizeof(conv_layers) / sizeof(conv_layers[0]), kInputH, strides);
+    int stridesLength = sizeof(strides) / sizeof(int);
+
+    // P3 processing steps (remains unchanged)
+    nvinfer1::IShuffleLayer* shuffle30_0 =
+            network->addShuffle(*cat30_0->getOutput(0));  // Reusing the previous cat30_0 as P3 concatenation layer
+    shuffle30_0->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[0]) * (kInputW / strides[0])});
+    nvinfer1::ISliceLayer* split30_0_0 = network->addSlice(
+            *shuffle30_0->getOutput(0), nvinfer1::Dims2{0, 0},
+            nvinfer1::Dims2{64, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1});
+    nvinfer1::ISliceLayer* split30_0_1 = network->addSlice(
+            *shuffle30_0->getOutput(0), nvinfer1::Dims2{64, 0},
+            nvinfer1::Dims2{kNumClass, (kInputH / strides[0]) * (kInputW / strides[0])}, nvinfer1::Dims2{1, 1});
+    nvinfer1::IShuffleLayer* dfl30_0 =
+            DFL(network, weightMap, *split30_0_0->getOutput(0), 4, (kInputH / strides[0]) * (kInputW / strides[0]), 1,
+                1, 0, "model.30.dfl.conv.weight");
+
+    // det0
+    auto shuffle_conv20 = cv4_conv_combined(network, weightMap, *conv20->getOutput(0), "model.30.cv4.0",
+                                            (kInputH / strides[0]) * (kInputW / strides[0]), gw, "pose");
+    nvinfer1::ITensor* inputTensor30_dfl_0[] = {dfl30_0->getOutput(0), split30_0_1->getOutput(0),
+                                                shuffle_conv20->getOutput(0)};
+    nvinfer1::IConcatenationLayer* cat30_dfl_0 = network->addConcatenation(inputTensor30_dfl_0, 3);
+
+    // P4 processing steps (remains unchanged)
+    nvinfer1::IShuffleLayer* shuffle30_1 =
+            network->addShuffle(*cat30_1->getOutput(0));  // Reusing the previous cat30_1 as P4 concatenation layer
+    shuffle30_1->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[1]) * (kInputW / strides[1])});
+    nvinfer1::ISliceLayer* split30_1_0 = network->addSlice(
+            *shuffle30_1->getOutput(0), nvinfer1::Dims2{0, 0},
+            nvinfer1::Dims2{64, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1});
+    nvinfer1::ISliceLayer* split30_1_1 = network->addSlice(
+            *shuffle30_1->getOutput(0), nvinfer1::Dims2{64, 0},
+            nvinfer1::Dims2{kNumClass, (kInputH / strides[1]) * (kInputW / strides[1])}, nvinfer1::Dims2{1, 1});
+    nvinfer1::IShuffleLayer* dfl30_1 =
+            DFL(network, weightMap, *split30_1_0->getOutput(0), 4, (kInputH / strides[1]) * (kInputW / strides[1]), 1,
+                1, 0, "model.30.dfl.conv.weight");
+
+    // det1
+    auto shuffle_conv23 = cv4_conv_combined(network, weightMap, *conv23->getOutput(0), "model.30.cv4.1",
+                                            (kInputH / strides[1]) * (kInputW / strides[1]), gw, "pose");
+    nvinfer1::ITensor* inputTensor30_dfl_1[] = {dfl30_1->getOutput(0), split30_1_1->getOutput(0),
+                                                shuffle_conv23->getOutput(0)};
+    nvinfer1::IConcatenationLayer* cat30_dfl_1 = network->addConcatenation(inputTensor30_dfl_1, 3);
+
+    // P5 processing steps (remains unchanged)
+    nvinfer1::IShuffleLayer* shuffle30_2 =
+            network->addShuffle(*cat30_2->getOutput(0));  // Reusing the previous cat30_2 as P5 concatenation layer
+    shuffle30_2->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[2]) * (kInputW / strides[2])});
+    nvinfer1::ISliceLayer* split30_2_0 = network->addSlice(
+            *shuffle30_2->getOutput(0), nvinfer1::Dims2{0, 0},
+            nvinfer1::Dims2{64, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1});
+    nvinfer1::ISliceLayer* split30_2_1 = network->addSlice(
+            *shuffle30_2->getOutput(0), nvinfer1::Dims2{64, 0},
+            nvinfer1::Dims2{kNumClass, (kInputH / strides[2]) * (kInputW / strides[2])}, nvinfer1::Dims2{1, 1});
+    nvinfer1::IShuffleLayer* dfl30_2 =
+            DFL(network, weightMap, *split30_2_0->getOutput(0), 4, (kInputH / strides[2]) * (kInputW / strides[2]), 1,
+                1, 0, "model.30.dfl.conv.weight");
+
+    // det2
+    auto shuffle_conv26 = cv4_conv_combined(network, weightMap, *conv26->getOutput(0), "model.30.cv4.2",
+                                            (kInputH / strides[2]) * (kInputW / strides[2]), gw, "pose");
+    nvinfer1::ITensor* inputTensor30_dfl_2[] = {dfl30_2->getOutput(0), split30_2_1->getOutput(0),
+                                                shuffle_conv26->getOutput(0)};
+    nvinfer1::IConcatenationLayer* cat30_dfl_2 = network->addConcatenation(inputTensor30_dfl_2, 3);
+
+    // P6 processing steps
+    nvinfer1::IShuffleLayer* shuffle30_3 = network->addShuffle(*cat30_3->getOutput(0));
+    shuffle30_3->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / strides[3]) * (kInputW / strides[3])});
+    nvinfer1::ISliceLayer* split30_3_0 = network->addSlice(
+            *shuffle30_3->getOutput(0), nvinfer1::Dims2{0, 0},
+            nvinfer1::Dims2{64, (kInputH / strides[3]) * (kInputW / strides[3])}, nvinfer1::Dims2{1, 1});
+    nvinfer1::ISliceLayer* split30_3_1 = network->addSlice(
+            *shuffle30_3->getOutput(0), nvinfer1::Dims2{64, 0},
+            nvinfer1::Dims2{kNumClass, (kInputH / strides[3]) * (kInputW / strides[3])}, nvinfer1::Dims2{1, 1});
+    nvinfer1::IShuffleLayer* dfl30_3 =
+            DFL(network, weightMap, *split30_3_0->getOutput(0), 4, (kInputH / strides[3]) * (kInputW / strides[3]), 1,
+                1, 0, "model.30.dfl.conv.weight");
+
+    // det3
+    auto shuffle_conv29 = cv4_conv_combined(network, weightMap, *conv29->getOutput(0), "model.30.cv4.3",
+                                            (kInputH / strides[3]) * (kInputW / strides[3]), gw, "pose");
+    nvinfer1::ITensor* inputTensor30_dfl_3[] = {dfl30_3->getOutput(0), split30_3_1->getOutput(0),
+                                                shuffle_conv29->getOutput(0)};
+    nvinfer1::IConcatenationLayer* cat30_dfl_3 = network->addConcatenation(inputTensor30_dfl_3, 3);
+
+    nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(
+            network, std::vector<nvinfer1::IConcatenationLayer*>{cat30_dfl_0, cat30_dfl_1, cat30_dfl_2, cat30_dfl_3},
+            strides, stridesLength, false, true);  // final flag marks a pose model so the plugin decodes keypoints
+    yolo->getOutput(0)->setName(kOutputTensorName);
+    network->markOutput(*yolo->getOutput(0));
+
+    builder->setMaxBatchSize(kBatchSize);
+    config->setMaxWorkspaceSize(16 * (1 << 20));
+
+#if defined(USE_FP16)
+    config->setFlag(nvinfer1::BuilderFlag::kFP16);
+#elif defined(USE_INT8)
+    std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;
+    assert(builder->platformHasFastInt8());
+    config->setFlag(nvinfer1::BuilderFlag::kINT8);
+    auto* calibrator =
+            new Int8EntropyCalibrator2(1, kInputW, kInputH, "../coco_calib/", "int8calib.table", kInputTensorName);
+    config->setInt8Calibrator(calibrator);
+#endif
+
+    std::cout << "Building engine, please wait for a while..." << std::endl;
+    nvinfer1::IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config);
+    std::cout << "Build engine successfully!" << std::endl;
+
+    delete network;
+
+    for (auto& mem : weightMap) {
+        free((void*)(mem.second.values));
+    }
+    return serialized_model;
+}
diff --git a/yolov8/yolov8_pose.cpp b/yolov8/yolov8_pose.cpp
index 2b354e46..84d35aea 100644
--- a/yolov8/yolov8_pose.cpp
+++ b/yolov8/yolov8_pose.cpp
@@ -20,7 +20,7 @@ void serialize_engine(std::string& wts_name, std::string& engine_name, int& is_p
     IHostMemory* serialized_engine = nullptr;
 
     if (is_p == 6) {
-        std::cout << "p6 is not supported right nowe" << std::endl;
+        serialized_engine = buildEngineYolov8PoseP6(builder, config, DataType::kFLOAT, wts_name, gd, gw, max_channels);
     } else if (is_p == 2) {
         std::cout << "p2 is not supported right now" << std::endl;
     } else {
diff --git a/yolov8/yolov8_pose_trt.py b/yolov8/yolov8_pose_trt.py
new file mode 100644
index 00000000..f56a61f2
--- /dev/null
+++ b/yolov8/yolov8_pose_trt.py
@@ -0,0 +1,500 @@
+"""
+An example that uses TensorRT's Python API to make inferences.
+""" +import ctypes +import os +import shutil +import random +import sys +import threading +import time +import cv2 +import numpy as np +import pycuda.autoinit # noqa: F401 +import pycuda.driver as cuda +import tensorrt as trt + + +CONF_THRESH = 0.5 +IOU_THRESHOLD = 0.4 + +keypoint_pairs = [ + (0, 1), (0, 2), (0, 5), (0, 6), (1, 2), + (1, 3), (2, 4), (5, 6), (5, 7), (5, 11), + (6, 8), (6, 12), (7, 9), (8, 10), (11, 12), + (11, 13), (12, 14), (13, 15), (14, 16) +] + + +def get_img_path_batches(batch_size, img_dir): + ret = [] + batch = [] + for root, dirs, files in os.walk(img_dir): + for name in files: + if len(batch) == batch_size: + ret.append(batch) + batch = [] + batch.append(os.path.join(root, name)) + if len(batch) > 0: + ret.append(batch) + return ret + + +def plot_one_box(x, img, color=None, label=None, line_thickness=None): + """ + description: Plots one bounding box on image img, + this function comes from YoLov8 project. + param: + x: a box likes [x1,y1,x2,y2] + img: a opencv image object + color: color to draw rectangle, such as (0,255,0) + label: str + line_thickness: int + return: + no return + + """ + tl = ( + line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 + ) # line/font thickness + color = color or [random.randint(0, 255) for _ in range(3)] + c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) + cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) + if label: + tf = max(tl - 1, 1) # font thickness + t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] + c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 + cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + img, + label, + (c1[0], c1[1] - 2), + 0, + tl / 3, + [225, 255, 255], + thickness=tf, + lineType=cv2.LINE_AA, + ) + + +class YoLov8TRT(object): + """ + description: A YOLOv8 class that warps TensorRT ops, preprocess and postprocess ops. + """ + + def __init__(self, engine_file_path): + # Create a Context on this device, + self.ctx = cuda.Device(0).make_context() + stream = cuda.Stream() + TRT_LOGGER = trt.Logger(trt.Logger.INFO) + runtime = trt.Runtime(TRT_LOGGER) + + # Deserialize the engine from file + with open(engine_file_path, "rb") as f: + engine = runtime.deserialize_cuda_engine(f.read()) + context = engine.create_execution_context() + + host_inputs = [] + cuda_inputs = [] + host_outputs = [] + cuda_outputs = [] + bindings = [] + + for binding in engine: + print('bingding:', binding, engine.get_binding_shape(binding)) + size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size + dtype = trt.nptype(engine.get_binding_dtype(binding)) + # Allocate host and device buffers + host_mem = cuda.pagelocked_empty(size, dtype) + cuda_mem = cuda.mem_alloc(host_mem.nbytes) + # Append the device buffer to device bindings. + bindings.append(int(cuda_mem)) + # Append to the appropriate list. 
+            if engine.binding_is_input(binding):
+                self.input_w = engine.get_binding_shape(binding)[-1]
+                self.input_h = engine.get_binding_shape(binding)[-2]
+                host_inputs.append(host_mem)
+                cuda_inputs.append(cuda_mem)
+            else:
+                host_outputs.append(host_mem)
+                cuda_outputs.append(cuda_mem)
+
+        # Store
+        self.stream = stream
+        self.context = context
+        self.host_inputs = host_inputs
+        self.cuda_inputs = cuda_inputs
+        self.host_outputs = host_outputs
+        self.cuda_outputs = cuda_outputs
+        self.bindings = bindings
+        self.batch_size = engine.max_batch_size
+        # 1 count value followed by up to 1000 detections * 89 floats each (see post_process)
+        self.det_output_size = 89001
+
+    def infer(self, raw_image_generator):
+        # Make self the active context, pushing it on top of the context stack.
+        self.ctx.push()
+        # Restore
+        stream = self.stream
+        context = self.context
+        host_inputs = self.host_inputs
+        cuda_inputs = self.cuda_inputs
+        host_outputs = self.host_outputs
+        cuda_outputs = self.cuda_outputs
+        bindings = self.bindings
+        # Do image preprocess
+        batch_image_raw = []
+        batch_origin_h = []
+        batch_origin_w = []
+        batch_input_image = np.empty(shape=[self.batch_size, 3, self.input_h, self.input_w])
+        for i, image_raw in enumerate(raw_image_generator):
+            input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw)
+            batch_image_raw.append(image_raw)
+            batch_origin_h.append(origin_h)
+            batch_origin_w.append(origin_w)
+            np.copyto(batch_input_image[i], input_image)
+        batch_input_image = np.ascontiguousarray(batch_input_image)
+
+        # Copy input image to host buffer
+        np.copyto(host_inputs[0], batch_input_image.ravel())
+        start = time.time()
+        # Transfer input data to the GPU.
+        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
+        # Run inference.
+        context.execute_async(batch_size=self.batch_size, bindings=bindings, stream_handle=stream.handle)
+        # Transfer predictions back from the GPU.
+        cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
+        # Synchronize the stream
+        stream.synchronize()
+        end = time.time()
+        # Remove any context from the top of the context stack, deactivating it.
+        self.ctx.pop()
+        # host_outputs[0] holds the flat detections for the whole batch
+        output = host_outputs[0]
+        # Do postprocess
+        for i in range(self.batch_size):
+            result_boxes, result_scores, result_classid, keypoints = self.post_process(
+                output[i * (self.det_output_size): (i + 1) * (self.det_output_size)],
+                batch_origin_h[i], batch_origin_w[i]
+            )
+
+            # Draw rectangles and labels on the original image
+            for j in range(len(result_boxes)):
+                box = result_boxes[j]
+                plot_one_box(
+                    box,
+                    batch_image_raw[i],
+                    label="{}:{:.2f}".format(
+                        categories[int(result_classid[j])], result_scores[j]
+                    ),
+                )
+
+                num_keypoints = len(keypoints[j]) // 3
+                points = []
+                for k in range(num_keypoints):
+                    x = keypoints[j][k * 3]
+                    y = keypoints[j][k * 3 + 1]
+                    confidence = keypoints[j][k * 3 + 2]
+                    if confidence > 0:
+                        points.append((int(x), int(y)))
+                    else:
+                        points.append(None)
+
+                # Draw skeleton lines between the keypoint pairs
+                for pair in keypoint_pairs:
+                    partA, partB = pair
+                    if points[partA] and points[partB]:
+                        cv2.line(batch_image_raw[i], points[partA], points[partB], (0, 255, 0), 2)
+
+        return batch_image_raw, end - start
+
+    def destroy(self):
+        # Remove any context from the top of the context stack, deactivating it.
+        self.ctx.pop()
+
+    def get_raw_image(self, image_path_batch):
+        """
+        description: Read an image from image path
+        """
+        for img_path in image_path_batch:
+            yield cv2.imread(img_path)
+
+    def get_raw_image_zeros(self, image_path_batch=None):
+        """
+        description: Prepare zero-filled dummy data for warmup
+        """
+        for _ in range(self.batch_size):
+            yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8)
+
+    def preprocess_image(self, raw_bgr_image):
+        """
+        description: Convert BGR image to RGB,
+                     resize and pad it to target size, normalize to [0,1],
+                     transform to NCHW format.
+        param:
+            raw_bgr_image: numpy array, the raw BGR image
+        return:
+            image: the processed image
+            image_raw: the original image
+            h: original height
+            w: original width
+        """
+        image_raw = raw_bgr_image
+        h, w, c = image_raw.shape
+        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
+        # Calculate width, height and paddings
+        r_w = self.input_w / w
+        r_h = self.input_h / h
+        if r_h > r_w:
+            tw = self.input_w
+            th = int(r_w * h)
+            tx1 = tx2 = 0
+            ty1 = int((self.input_h - th) / 2)
+            ty2 = self.input_h - th - ty1
+        else:
+            tw = int(r_h * w)
+            th = self.input_h
+            tx1 = int((self.input_w - tw) / 2)
+            tx2 = self.input_w - tw - tx1
+            ty1 = ty2 = 0
+        # Resize the image with long side while maintaining ratio
+        image = cv2.resize(image, (tw, th))
+        # Pad the short side with (128,128,128)
+        image = cv2.copyMakeBorder(
+            image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (128, 128, 128)
+        )
+        image = image.astype(np.float32)
+        # Normalize to [0,1]
+        image /= 255.0
+        # HWC to CHW format:
+        image = np.transpose(image, [2, 0, 1])
+        # CHW to NCHW format
+        image = np.expand_dims(image, axis=0)
+        # Convert the image to row-major order, also known as "C order":
+        image = np.ascontiguousarray(image)
+        return image, image_raw, h, w
+
+    def xywh2xyxy_with_keypoints(self, origin_h, origin_w, boxes, keypoints):
+        """
+        description: Convert boxes from [cx, cy, w, h] to [x1, y1, x2, y2] and map
+                     both boxes and keypoints from the letterboxed input back to
+                     the original image coordinates.
+        """
+        n = len(boxes)
+        box_array = np.zeros_like(boxes)
+        keypoint_array = np.zeros_like(keypoints)
+        r_w = self.input_w / origin_w
+        r_h = self.input_h / origin_h
+        for i in range(n):
+            if r_h > r_w:
+                box = boxes[i]
+                lmk = keypoints[i]
+                box_array[i, 0] = box[0] / r_w
+                box_array[i, 2] = box[2] / r_w
+                box_array[i, 1] = (box[1] - (self.input_h - r_w * origin_h) / 2) / r_w
+                box_array[i, 3] = (box[3] - (self.input_h - r_w * origin_h) / 2) / r_w
+                for j in range(0, len(lmk), 3):
+                    keypoint_array[i, j] = lmk[j] / r_w
+                    keypoint_array[i, j + 1] = (lmk[j + 1] - (self.input_h - r_w * origin_h) / 2) / r_w
+                    keypoint_array[i, j + 2] = lmk[j + 2]
+            else:
+                box = boxes[i]
+                lmk = keypoints[i]
+                box_array[i, 0] = (box[0] - (self.input_w - r_h * origin_w) / 2) / r_h
+                box_array[i, 2] = (box[2] - (self.input_w - r_h * origin_w) / 2) / r_h
+                box_array[i, 1] = box[1] / r_h
+                box_array[i, 3] = box[3] / r_h
+                for j in range(0, len(lmk), 3):
+                    keypoint_array[i, j] = (lmk[j] - (self.input_w - r_h * origin_w) / 2) / r_h
+                    keypoint_array[i, j + 1] = lmk[j + 1] / r_h
+                    keypoint_array[i, j + 2] = lmk[j + 2]
+
+        return box_array, keypoint_array
+
+    def post_process(self, output, origin_h, origin_w):
+        """
+        description: Post-process the prediction to include pose keypoints
+        param:
+            output: A numpy array like [num_boxes, cx, cy, w, h, conf,
+                    cls_id, px1, py1, pconf1, ..., px17, py17, pconf17] where p denotes pose keypoint
+            origin_h: Height of original image
+            origin_w: Width of original image
+        return:
+            result_boxes: Final boxes, a numpy array, each row is a box [x1, y1, x2, y2]
+            result_scores: Final scores, a numpy array, each element is the score corresponding to box
+            result_classid: Final classID, a numpy array, each element is the classid corresponding to box
+            result_keypoints: Final keypoints, a list of numpy arrays,
+                              each element represents keypoints for a box, shaped as (#keypoints, 3)
+        """
+        # Number of values per detection: 38 base values (4 bbox + 1 conf + 1 class id
+        # + 32 mask slots, matching the C++ Detection struct) + 17 keypoints * 3 values each
+        num_values_per_detection = 38 + 17 * 3
+        # Get the number of boxes detected
+        num = int(output[0])
+        # Reshape to a two-dimensional ndarray with the full detection shape
+        pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
+
+        # Perform non-maximum suppression to filter the detections
+        boxes = self.non_max_suppression(
+            pred[:, :num_values_per_detection], origin_h, origin_w,
+            conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
+
+        # Extract the bounding boxes, confidence scores, and class IDs
+        result_boxes = boxes[:, :4] if len(boxes) else np.array([])
+        result_scores = boxes[:, 4] if len(boxes) else np.array([])
+        result_classid = boxes[:, 5] if len(boxes) else np.array([])
+        result_keypoints = boxes[:, -51:] if len(boxes) else np.array([])
+
+        # Return the post-processed results including keypoints
+        return result_boxes, result_scores, result_classid, result_keypoints
+
+    def bbox_iou(self, box1, box2, x1y1x2y2=True):
+        """
+        description: compute the IoU of two bounding boxes
+        param:
+            box1: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h))
+            box2: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h))
+            x1y1x2y2: select the coordinate format
+        return:
+            iou: computed iou
+        """
+        if not x1y1x2y2:
+            # Transform from center and width to exact coordinates
+            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
+            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
+            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
+            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
+        else:
+            # Get the coordinates of bounding boxes
+            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
+            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
+
+        # Get the coordinates of the intersection rectangle
+        inter_rect_x1 = np.maximum(b1_x1, b2_x1)
+        inter_rect_y1 = np.maximum(b1_y1, b2_y1)
+        inter_rect_x2 = np.minimum(b1_x2, b2_x2)
+        inter_rect_y2 = np.minimum(b1_y2, b2_y2)
+        # Intersection area
+        inter_area = np.clip(
+            inter_rect_x2 - inter_rect_x1 + 1, 0, None) * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None)
+        # Union Area
+        b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
+        b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
+
+        iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
+
+        return iou
+
+    def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
+        """
+        description: Removes detections with lower object confidence score than 'conf_thres' and performs
+                     Non-Maximum Suppression to further filter detections.
+        param:
+            prediction: detections, (x1, y1, x2, y2, conf, cls_id)
+            origin_h: original image height
+            origin_w: original image width
+            conf_thres: a confidence threshold to filter detections
+            nms_thres: an IoU threshold to filter detections
+        return:
+            boxes: output after nms with the shape (x1, y1, x2, y2, conf, cls_id, keypoints)
+        """
+        # Keep the boxes whose score >= conf_thres
+        boxes = prediction[prediction[:, 4] >= conf_thres]
+        # Transform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
+        res_array = np.copy(boxes)
+        box_pred_deep_copy = np.copy(boxes[:, :4])
+        keypoints_pred_deep_copy = np.copy(boxes[:, -51:])
+        res_box, res_keypoints = self.xywh2xyxy_with_keypoints(
+            origin_h, origin_w, box_pred_deep_copy, keypoints_pred_deep_copy)
+        res_array[:, :4] = res_box
+        res_array[:, -51:] = res_keypoints
+        # clip the coordinates
+        res_array[:, 0] = np.clip(res_array[:, 0], 0, origin_w - 1)
+        res_array[:, 2] = np.clip(res_array[:, 2], 0, origin_w - 1)
+        res_array[:, 1] = np.clip(res_array[:, 1], 0, origin_h - 1)
+        res_array[:, 3] = np.clip(res_array[:, 3], 0, origin_h - 1)
+        # Object confidence
+        confs = res_array[:, 4]
+        # Sort by the confs
+        res_array = res_array[np.argsort(-confs)]
+        # Perform non-maximum suppression
+        keep_res_array = []
+        while res_array.shape[0]:
+            large_overlap = self.bbox_iou(np.expand_dims(res_array[0, :4], 0), res_array[:, :4]) > nms_thres
+            label_match = res_array[0, 5] == res_array[:, 5]
+            invalid = large_overlap & label_match
+            keep_res_array.append(res_array[0])
+            res_array = res_array[~invalid]
+
+        res_array = np.stack(keep_res_array, 0) if len(keep_res_array) else np.array([])
+        return res_array
+
+
+class inferThread(threading.Thread):
+    def __init__(self, yolov8_wrapper, image_path_batch):
+        threading.Thread.__init__(self)
+        self.yolov8_wrapper = yolov8_wrapper
+        self.image_path_batch = image_path_batch
+
+    def run(self):
+        batch_image_raw, use_time = self.yolov8_wrapper.infer(self.yolov8_wrapper.get_raw_image(self.image_path_batch))
+        for i, img_path in enumerate(self.image_path_batch):
+            parent, filename = os.path.split(img_path)
+            save_name = os.path.join('output', filename)
+            # Save image
+            cv2.imwrite(save_name, batch_image_raw[i])
+        print('input->{}, time->{:.2f}ms, saving into output/'.format(self.image_path_batch, use_time * 1000))
+
+
+class warmUpThread(threading.Thread):
+    def __init__(self, yolov8_wrapper):
+        threading.Thread.__init__(self)
+        self.yolov8_wrapper = yolov8_wrapper
+
+    def run(self):
+        batch_image_raw, use_time = self.yolov8_wrapper.infer(self.yolov8_wrapper.get_raw_image_zeros())
+        print('warm_up->{}, time->{:.2f}ms'.format(batch_image_raw[0].shape, use_time * 1000))
+
+
+if __name__ == "__main__":
+    # load custom plugin and engine
+    PLUGIN_LIBRARY = "./build/libmyplugins.so"
+    engine_file_path = "yolov8n-pose.engine"
+
+    if len(sys.argv) > 1:
+        engine_file_path = sys.argv[1]
+    if len(sys.argv) > 2:
+        PLUGIN_LIBRARY = sys.argv[2]
+
+    ctypes.CDLL(PLUGIN_LIBRARY)
+
+    # load labels (the pose model only detects the person class)
+    categories = ["person"]
+
+    if os.path.exists('output/'):
+        shutil.rmtree('output/')
+    os.makedirs('output/')
+    # a YoLov8TRT instance
+    yolov8_wrapper = YoLov8TRT(engine_file_path)
+    try:
+        print('batch size is', yolov8_wrapper.batch_size)
+
+        image_dir = "samples/"
+        image_path_batches = get_img_path_batches(yolov8_wrapper.batch_size, image_dir)
+
+        for i in range(10):
+            # create a new thread to do warm_up
+            thread1 = warmUpThread(yolov8_wrapper)
+            thread1.start()
+            thread1.join()
+        for batch in image_path_batches:
+            # create a new thread to do inference
+            thread1 = inferThread(yolov8_wrapper, batch)
+            thread1.start()
+            thread1.join()
+    finally:
+        # destroy the instance
+        yolov8_wrapper.destroy()
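
For a quick smoke test of the new script outside the threaded demo, the wrapper can also be driven directly. Below is a minimal sketch, assuming an engine built with batch size 1, the same `yolov8n-pose.engine` / `./build/libmyplugins.so` paths used in `__main__`, and a hypothetical `samples/bus.jpg` test image. Note that `categories` is defined under the `if __name__ == "__main__":` guard, so it has to be injected when the file is imported as a module:

```python
import ctypes

import cv2

# The YOLO decode plugin must be loaded before the engine is deserialized.
ctypes.CDLL("./build/libmyplugins.so")

import yolov8_pose_trt

# infer() reads the module-level name `categories`, which only exists when the
# script runs as __main__, so inject it here before inference.
yolov8_pose_trt.categories = ["person"]

wrapper = yolov8_pose_trt.YoLov8TRT("yolov8n-pose.engine")
try:
    images, seconds = wrapper.infer(wrapper.get_raw_image(["samples/bus.jpg"]))
    cv2.imwrite("pose_result.jpg", images[0])
    print("inference took {:.2f} ms".format(seconds * 1000))
finally:
    wrapper.destroy()
```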