From a3ceb43b08ed9f4869db223ce3681b67216611be Mon Sep 17 00:00:00 2001 From: lindsayshuo Date: Tue, 7 May 2024 14:27:39 +0800 Subject: [PATCH] Modify get_rect_adapt_landmark to adaptively decode any number of key points abd repair gen_wts.py --- yolov8/gen_wts.py | 6 +++--- yolov8/src/postprocess.cpp | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/yolov8/gen_wts.py b/yolov8/gen_wts.py index bad2c281..5f037db2 100644 --- a/yolov8/gen_wts.py +++ b/yolov8/gen_wts.py @@ -1,4 +1,4 @@ -import sys +import sys # noqa: F401 import argparse import os import struct @@ -12,7 +12,7 @@ def parse_args(): parser.add_argument( '-o', '--output', help='Output (.wts) file path (optional)') parser.add_argument( - '-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg'], + '-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'], help='determines the model is detection/classification') args = parser.parse_args() if not os.path.isfile(args.weights): @@ -39,7 +39,7 @@ def parse_args(): # Load model model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32 -if m_type in ['detect', 'seg']: +if m_type in ['detect', 'seg', 'pose']: anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None] delattr(model.model[-1], 'anchors') diff --git a/yolov8/src/postprocess.cpp b/yolov8/src/postprocess.cpp index 25f86cd3..a7869145 100644 --- a/yolov8/src/postprocess.cpp +++ b/yolov8/src/postprocess.cpp @@ -28,7 +28,7 @@ cv::Rect get_rect(cv::Mat& img, float bbox[4]) { return cv::Rect(round(l), round(t), round(r - l), round(b - t)); } -cv::Rect get_rect_adapt_landmark(cv::Mat& img, float bbox[4], float lmk[51]) { +cv::Rect get_rect_adapt_landmark(cv::Mat& img, float bbox[4], float lmk[kNumberOfPoints * 3]) { int l, r, t, b; float r_w = kInputW / (img.cols * 1.0); float r_h = kInputH / (img.rows * 1.0); @@ -37,7 +37,7 @@ cv::Rect get_rect_adapt_landmark(cv::Mat& img, float bbox[4], float lmk[51]) { r = bbox[2] / r_w; t = (bbox[1] - (kInputH - r_w * img.rows) / 2) / r_w; b = (bbox[3] - (kInputH - r_w * img.rows) / 2) / r_w; - for (int i = 0; i < 51; i += 3) { + for (int i = 0; i < kNumberOfPoints * 3; i += 3) { lmk[i] /= r_w; lmk[i + 1] = (lmk[i + 1] - (kInputH - r_w * img.rows) / 2) / r_w; // lmk[i + 2] @@ -47,7 +47,7 @@ cv::Rect get_rect_adapt_landmark(cv::Mat& img, float bbox[4], float lmk[51]) { r = (bbox[2] - (kInputW - r_h * img.cols) / 2) / r_h; t = bbox[1] / r_h; b = bbox[3] / r_h; - for (int i = 0; i < 51; i += 3) { + for (int i = 0; i < kNumberOfPoints * 3; i += 3) { lmk[i] = (lmk[i] - (kInputW - r_h * img.cols) / 2) / r_h; lmk[i + 1] /= r_h; // lmk[i + 2] @@ -172,7 +172,7 @@ void draw_bbox_keypoints_line(std::vector& img_batch, std::vector 0.5) { cv::circle(img, cv::Point((int)res[j].keypoints[k], (int)res[j].keypoints[k + 1]), 3, cv::Scalar(0, 0x27, 0xC1), -1);