Skip to content

Commit 2267955

Browse files
authored
main code
update loss
1 parent ef4dde4 commit 2267955

File tree

1 file changed

+42
-18
lines changed

1 file changed

+42
-18
lines changed

utils/loss.py

+42-18
Original file line numberDiff line numberDiff line change
@@ -1102,12 +1102,20 @@ def build_targets(self, p, targets, imgs):
11021102
matching_anchs[i].append(all_anch[layer_idx])
11031103

11041104
for i in range(nl):
1105-
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
1106-
matching_as[i] = torch.cat(matching_as[i], dim=0)
1107-
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
1108-
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
1109-
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
1110-
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
1105+
if matching_targets[i] != []:
1106+
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
1107+
matching_as[i] = torch.cat(matching_as[i], dim=0)
1108+
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
1109+
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
1110+
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
1111+
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
1112+
else:
1113+
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1114+
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1115+
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1116+
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1117+
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1118+
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
11111119

11121120
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
11131121

@@ -1403,12 +1411,20 @@ def build_targets(self, p, targets, imgs):
14031411
matching_anchs[i].append(all_anch[layer_idx])
14041412

14051413
for i in range(nl):
1406-
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
1407-
matching_as[i] = torch.cat(matching_as[i], dim=0)
1408-
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
1409-
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
1410-
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
1411-
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
1414+
if matching_targets[i] != []:
1415+
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
1416+
matching_as[i] = torch.cat(matching_as[i], dim=0)
1417+
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
1418+
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
1419+
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
1420+
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
1421+
else:
1422+
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1423+
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1424+
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1425+
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1426+
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1427+
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
14121428

14131429
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
14141430

@@ -1548,12 +1564,20 @@ def build_targets2(self, p, targets, imgs):
15481564
matching_anchs[i].append(all_anch[layer_idx])
15491565

15501566
for i in range(nl):
1551-
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
1552-
matching_as[i] = torch.cat(matching_as[i], dim=0)
1553-
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
1554-
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
1555-
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
1556-
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
1567+
if matching_targets[i] != []:
1568+
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
1569+
matching_as[i] = torch.cat(matching_as[i], dim=0)
1570+
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
1571+
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
1572+
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
1573+
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
1574+
else:
1575+
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1576+
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1577+
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1578+
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1579+
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
1580+
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
15571581

15581582
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
15591583

0 commit comments

Comments
 (0)