From ed1457a3919a469061bde3ad39d69dc1f7c3913a Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 8 Jan 2024 12:55:16 +0100 Subject: [PATCH] update loss tf and pt --- .../differentiable_binarization/pytorch.py | 92 +++++++------------ .../differentiable_binarization/tensorflow.py | 50 +++++----- 2 files changed, 60 insertions(+), 82 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 611308cb0b..2230b73875 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -213,7 +213,15 @@ def forward( return out - def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor: + def compute_loss( + self, + out_map: torch.Tensor, + thresh_map: torch.Tensor, + target: List[np.ndarray], + gamma: float = 2.0, + alpha: float = 0.5, + eps: float = 1e-8, + ) -> torch.Tensor: """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes and a list of masks for each image. From there it computes the loss with the model output @@ -222,6 +230,9 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: out_map: output feature map of the model of shape (N, C, H, W) thresh_map: threshold map of shape (N, C, H, W) target: list of dictionary where each dict has a `boxes` and a `flags` entry + gamma: modulating factor in the focal loss formula + alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss Returns: ------- @@ -229,9 +240,7 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: """ prob_map = torch.sigmoid(out_map) thresh_map = torch.sigmoid(thresh_map) - binary_mask = torch.reciprocal(1.0 + torch.exp(-50 * (prob_map - thresh_map))) - # TODO: needs also some checks again shrunk masks targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) @@ -239,64 +248,33 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3]) thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device) - balanced_bce_loss = torch.zeros(1, device=out_map.device) + focal_loss = torch.zeros(1, device=out_map.device) dice_loss = torch.zeros(1, device=out_map.device) l1_loss = torch.zeros(1, device=out_map.device) - - # TODO: Still in progress @Oliver @Charles - if torch.any(seg_mask): - # Compute balanced bce loss - bce_loss = F.binary_cross_entropy_with_logits( - out_map, - seg_target, - reduction="none", - ) - - positive = (seg_target * seg_mask).float() - negative = ((1 - seg_target) * seg_mask).float() - positive_count = int(positive.sum()) - if positive_count == 0: - negative_count = min(int(negative.sum()), 0) - else: - negative_count = min(int(negative.sum()), int(positive_count * 3)) - - positive_loss = bce_loss * positive - negative_loss = bce_loss * negative - - negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) - - balanced_bce_loss = ( - (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + 1e-6) * 5.0 - ) - - # Compute dice loss - binary_mask = binary_mask.contiguous().view(binary_mask.size(0), -1) - seg_target = seg_target.contiguous().view(seg_target.size(0), -1) - - seg_mask = seg_mask.contiguous().view(seg_mask.size(0), -1) - binary_mask = binary_mask * seg_mask - seg_target = seg_target * seg_mask - - dice_coeff = (2 * (binary_mask * seg_target).sum()) / (binary_mask.sum() + seg_target.sum() + 1e-6) - - dice_loss = 1 - dice_coeff - + # Focal loss + bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") + + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target) + alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target) + # Unreduced version + focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss + # Class reduced + focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum() + + # Dice loss + inter = (seg_mask * prob_map * seg_target).sum() + cardinality = (seg_mask * (prob_map + seg_target)).sum() + dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) + + # Compute l1 loss for thresh_map + l1_scale = 10.0 if torch.any(thresh_mask): - # Compute l1 loss - assert thresh_map.size() == thresh_target.size() and thresh_target.numel() > 0 - assert thresh_mask.size() == thresh_target.size() - x = thresh_map * thresh_mask - y = thresh_target * thresh_mask - - loss = torch.zeros_like(thresh_target) - diff = torch.abs(x - y) - mask_beta = diff < 1 - loss[mask_beta] = 0.5 * torch.square(diff)[mask_beta] / 1 - loss[~mask_beta] = diff[~mask_beta] - 0.5 * 1 - l1_loss = loss.sum() / (thresh_mask.sum() + 1e-6) * 10.0 - - return (balanced_bce_loss + dice_loss + l1_loss) / 3 + l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask])) + + return l1_scale * l1_loss + focal_loss + dice_loss # type: ignore[return-value] def _dbnet( diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 88265cb9e4..9e58356544 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -166,6 +166,9 @@ def compute_loss( out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[Dict[str, np.ndarray]], + gamma: float = 2.0, + alpha: float = 0.5, + eps: float = 1e-8, ) -> tf.Tensor: """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes and a list of masks for each image. From there it computes the loss with the model output @@ -175,6 +178,9 @@ def compute_loss( out_map: output feature map of the model of shape (N, H, W, C) thresh_map: threshold map of shape (N, H, W, C) target: list of dictionary where each dict has a `boxes` and a `flags` entry + gamma: modulating factor in the focal loss formula + alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss Returns: ------- @@ -186,33 +192,27 @@ def compute_loss( seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True) seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) + seg_mask = tf.cast(seg_mask, tf.float32) thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype) thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool) - # Compute balanced BCE loss for proba_map - bce_scale = 5.0 - bce_loss = tf.keras.losses.binary_crossentropy( - seg_target[..., None], - out_map[..., None], - from_logits=True, - )[seg_mask] - - neg_target = 1 - seg_target[seg_mask] - positive_count = tf.math.reduce_sum(seg_target[seg_mask]) - negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count]) - negative_loss = bce_loss * neg_target - negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32)) - sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss) - balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6) - - # Compute dice loss for approxbin_map - bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask]))) - - bce_min = tf.math.reduce_min(bce_loss) - weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0 - inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights) - union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8 - dice_loss = 1 - 2.0 * inter / union + bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) + + # Focal loss + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + # Convert logits to prob, compute gamma factor + p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map)) + alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha) + # Unreduced loss + focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss + # Class reduced + focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3)) + + # Dice loss + inter = tf.math.reduce_sum(seg_mask * prob_map * seg_target, (0, 1, 2, 3)) + cardinality = tf.math.reduce_sum((prob_map + seg_target), (0, 1, 2, 3)) + dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map l1_scale = 10.0 @@ -221,7 +221,7 @@ def compute_loss( else: l1_loss = tf.constant(0.0) - return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss + return l1_scale * l1_loss + focal_loss + dice_loss def call( self,