Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug for multilabel training #82

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def forward(self, x):

bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
x_det = x[i][..., :6].clone().detach()
x_kpt = x[i][..., 6:].clone().detach()
x_det = x[i][..., :5+self.nc].clone().detach()
x_kpt = x[i][..., 5+self.nc:].clone().detach()

if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
Expand All @@ -295,12 +295,13 @@ def forward(self, x):
x_kpt[..., 2::3] = x_kpt[..., 2::3].sigmoid()

y = torch.cat((xy, wh, y[..., 4:], x_kpt), dim = -1)

else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
if self.nkpt != 0:
y[..., 6:] = (y[..., 6:] * 2. - 0.5 + self.grid[i].repeat((1,1,1,1,self.nkpt))) * self.stride[i] # xy

y = torch.cat((xy, wh, y[..., 4:]), -1)

z.append(y.view(bs, -1, self.no))
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test(data,
if nl:
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
continue

# Predictions
if single_cls:
pred[:, 5] = 0
Expand Down
19 changes: 7 additions & 12 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,6 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
if nc is None:
nc = prediction.shape[2] - 5 if not kpt_label else prediction.shape[2] - 5 - kpt_label * 3 # number of classes
xc = prediction[..., 4] > conf_thres # candidates

# Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
#max_det = 300 # maximum number of detections per image
Expand All @@ -524,6 +523,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
merge = False # use merge-NMS

t = time.time()
ki = 5 + nc # kpt start index
output = [torch.zeros((0,6), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
Expand All @@ -544,23 +544,18 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
continue

# Compute conf
x[:, 5:5+nc] *= x[:, 4:5] # conf = obj_conf * cls_conf
x[:, 5:ki] *= x[:, 4:5] # conf = obj_conf * cls_conf

# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(x[:, :4])

kpt_label = x[:, ki:] # zero columns if no kpt
# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
i, j = (x[:, 5:ki] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float(), kpt_label[i]), 1)
else: # best class only
if not kpt_label:
conf, j = x[:, 5:].max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
else:
kpts = x[:, 6:]
conf, j = x[:, 5:6].max(1, keepdim=True)
x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres]
conf, j = x[:, 5:ki].max(1, keepdim=True)
x = torch.cat((box, conf, j.float(), kpt_label), 1)[conf.view(-1) > conf_thres]


# Filter by class
Expand Down
11 changes: 5 additions & 6 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def __call__(self, p, targets): # predictions, targets, model
lbox += (1.0 - iou).mean() # iou loss
if self.kpt_label:
#Direct kpt prediction
pkpt_x = ps[:, 6::3] * 2. - 0.5
pkpt_y = ps[:, 7::3] * 2. - 0.5
pkpt_score = ps[:, 8::3]
pkpt_x = ps[:, self.nc+5::3] * 2. - 0.5
pkpt_y = ps[:, self.nc+6::3] * 2. - 0.5
pkpt_score = ps[:, self.nc+7::3]
#mask
kpt_mask = (tkpt[i][:, 0::2] != 0)
lkptv += self.BCEcls(pkpt_score, kpt_mask.float())
Expand All @@ -177,10 +177,9 @@ def __call__(self, p, targets): # predictions, targets, model

# Classification
if self.nc > 1: # cls loss (only if multiple classes)
t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
t = torch.full_like(ps[:, 5:self.nc+5], self.cn, device=device) # targets
t[range(n), tcls[i]] = self.cp
lcls += self.BCEcls(ps[:, 5:], t) # BCE

lcls += self.BCEcls(ps[:, 5:self.nc+5], t) # BCE
# Append targets to text file
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
Expand Down