Skip to content

Commit

Permalink
fix strides (#1492)
Browse files Browse the repository at this point in the history
* yolov8 p2

* yolov8 p2

* yolov8 p2

* yolov8 p2

* Update yolov8_det.cpp

* Update model.cpp

* Update model.cpp

* Update model.cpp

* Update model.cpp

* fix strides
  • Loading branch information
lindsayshuo authored Apr 23, 2024
1 parent d1a184f commit d033a63
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions yolov8/src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ static int get_depth(int x, float gd) {
return std::max<int>(r, 1);
}

void calculateStrides(nvinfer1::IElementWiseLayer* conv_layers[], int size, int reference_size, int strides[]) {
for (int i = 0; i < size; ++i) {
nvinfer1::ILayer* layer = conv_layers[i];
nvinfer1::Dims dims = layer->getOutput(0)->getDimensions();
int feature_map_size = dims.d[1];
strides[i] = reference_size / feature_map_size;
}
}

static nvinfer1::IElementWiseLayer* Proto(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights>& weightMap, nvinfer1::ITensor& input,
std::string lname, float gw, int max_channels) {
Expand Down Expand Up @@ -220,7 +229,9 @@ nvinfer1::IHostMemory* buildEngineYolov8Det(nvinfer1::IBuilder* builder, nvinfer
********************************************* YOLOV8 DETECT ******************************************
*******************************************************************************************************/

int strides[] = {8, 16, 32};
nvinfer1::IElementWiseLayer* conv_layers[] = {conv3, conv5, conv7};
int strides[sizeof(conv_layers) / sizeof(conv_layers[0])];
calculateStrides(conv_layers, sizeof(conv_layers) / sizeof(conv_layers[0]), kInputH, strides);
int stridesLength = sizeof(strides) / sizeof(int);

nvinfer1::IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0));
Expand Down Expand Up @@ -507,7 +518,9 @@ nvinfer1::IHostMemory* buildEngineYolov8DetP6(nvinfer1::IBuilder* builder, nvinf
/*******************************************************************************************************
********************************************* YOLOV8 DETECT ******************************************
*******************************************************************************************************/
int strides[] = {8, 16, 32, 64};
nvinfer1::IElementWiseLayer* conv_layers[] = {conv3, conv5, conv7, conv9};
int strides[sizeof(conv_layers) / sizeof(conv_layers[0])];
calculateStrides(conv_layers, sizeof(conv_layers) / sizeof(conv_layers[0]), kInputH, strides);
int stridesLength = sizeof(strides) / sizeof(int);

// P3 processing steps (remains unchanged)
Expand Down Expand Up @@ -817,7 +830,9 @@ nvinfer1::IHostMemory* buildEngineYolov8DetP2(nvinfer1::IBuilder* builder, nvinf
********************************************* YOLOV8 DETECT ******************************************
*******************************************************************************************************/

int strides[] = {4, 8, 16, 32};
nvinfer1::IElementWiseLayer* conv_layers[] = {conv1, conv3, conv5, conv7};
int strides[sizeof(conv_layers) / sizeof(conv_layers[0])];
calculateStrides(conv_layers, sizeof(conv_layers) / sizeof(conv_layers[0]), kInputH, strides);
int stridesLength = sizeof(strides) / sizeof(int);

// P2 processing steps (remains unchanged)
Expand Down Expand Up @@ -1148,7 +1163,9 @@ nvinfer1::IHostMemory* buildEngineYolov8Seg(nvinfer1::IBuilder* builder, nvinfer
********************************************* YOLOV8 DETECT ******************************************
*******************************************************************************************************/

int strides[] = {8, 16, 32};
nvinfer1::IElementWiseLayer* conv_layers[] = {conv3, conv5, conv7};
int strides[sizeof(conv_layers) / sizeof(conv_layers[0])];
calculateStrides(conv_layers, sizeof(conv_layers) / sizeof(conv_layers[0]), kInputH, strides);
int stridesLength = sizeof(strides) / sizeof(int);

nvinfer1::IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0));
Expand Down

0 comments on commit d033a63

Please sign in to comment.