Skip to content

Commit

Permalink
Dice loss modified to eliminate the batch size affects (issue ivadome…
Browse files Browse the repository at this point in the history
  • Loading branch information
gbazad93 authored Aug 16, 2021
1 parent 93b9e73 commit 69428b0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ivadomed/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def __init__(self, smooth=1.0):
self.smooth = smooth

def forward(self, prediction, target):
iflat = prediction.reshape(-1)
tflat = target.reshape(-1)
intersection = (iflat * tflat).sum()
iflat = prediction.reshape(prediction.shape[0], -1)
tflat = target.reshape(target.shape[0], -1)
intersection = (iflat * tflat).sum(dim = 1)

return - (2.0 * intersection + self.smooth) / (iflat.sum() + tflat.sum() + self.smooth)
return - ((2.0 * intersection + self.smooth) / (iflat.sum(dim = 1) + tflat.sum(dim = 1) + self.smooth)).mean()


class BinaryCrossEntropyLoss(nn.Module):
Expand Down

0 comments on commit 69428b0

Please sign in to comment.