diff --git a/models/yolo.py b/models/yolo.py index 0e37e1e..f4fde4b 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -272,8 +272,8 @@ def forward(self, x): bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() - x_det = x[i][..., :6].clone().detach() - x_kpt = x[i][..., 6:].clone().detach() + x_det = x[i][..., :5+self.nc].clone().detach() + x_kpt = x[i][..., 5+self.nc:].clone().detach() if not self.training: # inference if self.grid[i].shape[2:4] != x[i].shape[2:4]: @@ -295,12 +295,13 @@ def forward(self, x): x_kpt[..., 2::3] = x_kpt[..., 2::3].sigmoid() y = torch.cat((xy, wh, y[..., 4:], x_kpt), dim = -1) - + else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh if self.nkpt != 0: y[..., 6:] = (y[..., 6:] * 2. - 0.5 + self.grid[i].repeat((1,1,1,1,self.nkpt))) * self.stride[i] # xy + y = torch.cat((xy, wh, y[..., 4:]), -1) z.append(y.view(bs, -1, self.no)) diff --git a/test.py b/test.py index 8cc53e4..c8f51e1 100644 --- a/test.py +++ b/test.py @@ -166,7 +166,7 @@ def test(data, if nl: stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls)) continue - + # Predictions if single_cls: pred[:, 5] = 0 diff --git a/utils/general.py b/utils/general.py index e90c3b5..d724536 100644 --- a/utils/general.py +++ b/utils/general.py @@ -513,7 +513,6 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non if nc is None: nc = prediction.shape[2] - 5 if not kpt_label else prediction.shape[2] - 5 - kpt_label * 3 # number of classes xc = prediction[..., 4] > conf_thres # candidates - # Settings min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height #max_det = 300 # maximum number of detections per image @@ -524,6 +523,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non merge = False # use merge-NMS t = time.time() + ki = 5 + nc # kpt start index output = [torch.zeros((0,6), device=prediction.device)] * prediction.shape[0] for xi, x in enumerate(prediction): # image index, image inference # Apply constraints @@ -544,23 +544,18 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non continue # Compute conf - x[:, 5:5+nc] *= x[:, 4:5] # conf = obj_conf * cls_conf + x[:, 5:ki] *= x[:, 4:5] # conf = obj_conf * cls_conf # Box (center x, center y, width, height) to (x1, y1, x2, y2) box = xywh2xyxy(x[:, :4]) - + kpt_label = x[:, ki:] # zero columns if no kpt # Detections matrix nx6 (xyxy, conf, cls) if multi_label: - i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T - x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + i, j = (x[:, 5:ki] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float(), kpt_label[i]), 1) else: # best class only - if not kpt_label: - conf, j = x[:, 5:].max(1, keepdim=True) - x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] - else: - kpts = x[:, 6:] - conf, j = x[:, 5:6].max(1, keepdim=True) - x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] + conf, j = x[:, 5:ki].max(1, keepdim=True) + x = torch.cat((box, conf, j.float(), kpt_label), 1)[conf.view(-1) > conf_thres] # Filter by class diff --git a/utils/loss.py b/utils/loss.py index e481e6d..17cd850 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -163,9 +163,9 @@ def __call__(self, p, targets): # predictions, targets, model lbox += (1.0 - iou).mean() # iou loss if self.kpt_label: #Direct kpt prediction - pkpt_x = ps[:, 6::3] * 2. - 0.5 - pkpt_y = ps[:, 7::3] * 2. - 0.5 - pkpt_score = ps[:, 8::3] + pkpt_x = ps[:, self.nc+5::3] * 2. - 0.5 + pkpt_y = ps[:, self.nc+6::3] * 2. - 0.5 + pkpt_score = ps[:, self.nc+7::3] #mask kpt_mask = (tkpt[i][:, 0::2] != 0) lkptv += self.BCEcls(pkpt_score, kpt_mask.float()) @@ -177,10 +177,9 @@ def __call__(self, p, targets): # predictions, targets, model # Classification if self.nc > 1: # cls loss (only if multiple classes) - t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets + t = torch.full_like(ps[:, 5:self.nc+5], self.cn, device=device) # targets t[range(n), tcls[i]] = self.cp - lcls += self.BCEcls(ps[:, 5:], t) # BCE - + lcls += self.BCEcls(ps[:, 5:self.nc+5], t) # BCE # Append targets to text file # with open('targets.txt', 'a') as file: # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]