Skip to content

Commit ef4dde4

Browse files
authored
main code
fix gain for train_aux
1 parent 4cebf40 commit ef4dde4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

utils/loss.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1115,7 +1115,7 @@ def find_3_positive(self, p, targets):
11151115
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
11161116
na, nt = self.na, targets.shape[0] # number of anchors, targets
11171117
indices, anch = [], []
1118-
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
1118+
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
11191119
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
11201120
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
11211121

@@ -1561,7 +1561,7 @@ def find_5_positive(self, p, targets):
15611561
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
15621562
na, nt = self.na, targets.shape[0] # number of anchors, targets
15631563
indices, anch = [], []
1564-
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
1564+
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
15651565
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
15661566
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
15671567

@@ -1614,7 +1614,7 @@ def find_3_positive(self, p, targets):
16141614
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
16151615
na, nt = self.na, targets.shape[0] # number of anchors, targets
16161616
indices, anch = [], []
1617-
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
1617+
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
16181618
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
16191619
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
16201620

0 commit comments

Comments
 (0)