Skip to content

Commit

Permalink
Modify get_rect_adapt_landmark to adaptively decode any number of key…
Browse files Browse the repository at this point in the history
… points abd repair gen_wts.py
  • Loading branch information
lindsayshuo committed May 7, 2024
1 parent 6d36a3c commit a3ceb43
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions yolov8/gen_wts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys
import sys # noqa: F401
import argparse
import os
import struct
Expand All @@ -12,7 +12,7 @@ def parse_args():
parser.add_argument(
'-o', '--output', help='Output (.wts) file path (optional)')
parser.add_argument(
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg'],
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'],
help='determines the model is detection/classification')
args = parser.parse_args()
if not os.path.isfile(args.weights):
Expand All @@ -39,7 +39,7 @@ def parse_args():
# Load model
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32

if m_type in ['detect', 'seg']:
if m_type in ['detect', 'seg', 'pose']:
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]

delattr(model.model[-1], 'anchors')
Expand Down
8 changes: 4 additions & 4 deletions yolov8/src/postprocess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ cv::Rect get_rect(cv::Mat& img, float bbox[4]) {
return cv::Rect(round(l), round(t), round(r - l), round(b - t));
}

cv::Rect get_rect_adapt_landmark(cv::Mat& img, float bbox[4], float lmk[51]) {
cv::Rect get_rect_adapt_landmark(cv::Mat& img, float bbox[4], float lmk[kNumberOfPoints * 3]) {
int l, r, t, b;
float r_w = kInputW / (img.cols * 1.0);
float r_h = kInputH / (img.rows * 1.0);
Expand All @@ -37,7 +37,7 @@ cv::Rect get_rect_adapt_landmark(cv::Mat& img, float bbox[4], float lmk[51]) {
r = bbox[2] / r_w;
t = (bbox[1] - (kInputH - r_w * img.rows) / 2) / r_w;
b = (bbox[3] - (kInputH - r_w * img.rows) / 2) / r_w;
for (int i = 0; i < 51; i += 3) {
for (int i = 0; i < kNumberOfPoints * 3; i += 3) {
lmk[i] /= r_w;
lmk[i + 1] = (lmk[i + 1] - (kInputH - r_w * img.rows) / 2) / r_w;
// lmk[i + 2]
Expand All @@ -47,7 +47,7 @@ cv::Rect get_rect_adapt_landmark(cv::Mat& img, float bbox[4], float lmk[51]) {
r = (bbox[2] - (kInputW - r_h * img.cols) / 2) / r_h;
t = bbox[1] / r_h;
b = bbox[3] / r_h;
for (int i = 0; i < 51; i += 3) {
for (int i = 0; i < kNumberOfPoints * 3; i += 3) {
lmk[i] = (lmk[i] - (kInputW - r_h * img.cols) / 2) / r_h;
lmk[i + 1] /= r_h;
// lmk[i + 2]
Expand Down Expand Up @@ -172,7 +172,7 @@ void draw_bbox_keypoints_line(std::vector<cv::Mat>& img_batch, std::vector<std::
cv::putText(img, std::to_string((int)res[j].class_id), cv::Point(r.x, r.y - 1), cv::FONT_HERSHEY_PLAIN, 1.2,
cv::Scalar(0xFF, 0xFF, 0xFF), 2);

for (int k = 0; k < 51; k += 3) {
for (int k = 0; k < kNumberOfPoints * 3; k += 3) {
if (res[j].keypoints[k + 2] > 0.5) {
cv::circle(img, cv::Point((int)res[j].keypoints[k], (int)res[j].keypoints[k + 1]), 3,
cv::Scalar(0, 0x27, 0xC1), -1);
Expand Down

0 comments on commit a3ceb43

Please sign in to comment.