From b3909cbdb124335baec2194ab95691cc6900ee87 Mon Sep 17 00:00:00 2001 From: Babak Azad Date: Wed, 18 Aug 2021 08:18:04 +0200 Subject: [PATCH] Sample-wise normalization added as an option --- ivadomed/losses.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ivadomed/losses.py b/ivadomed/losses.py index 662de24da..fd400e2c1 100644 --- a/ivadomed/losses.py +++ b/ivadomed/losses.py @@ -51,11 +51,14 @@ class DiceLoss(nn.Module): Attributes: smooth (float): Value to avoid division by zero when images and predictions are empty. """ - def __init__(self, smooth=1.0): + def __init__(self, smooth=1.0, sample_wise=False): super(DiceLoss, self).__init__() self.smooth = smooth + self.sample_wise = sample_wise def forward(self, prediction, target): + if not self.sample_wise: + prediction, target = prediction.unsqueeze(dim=0), target.unsqueeze(dim=0) iflat = prediction.reshape(prediction.shape[0], -1) tflat = target.reshape(target.shape[0], -1) intersection = (iflat * tflat).sum(dim = 1)