From bc32daa837e400a44706c62366540de67bab0939 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 1 Dec 2023 15:23:52 +0100 Subject: [PATCH 01/13] up --- .../differentiable_binarization/base.py | 36 ++++++++++--------- .../differentiable_binarization/pytorch.py | 2 +- .../differentiable_binarization/tensorflow.py | 2 +- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index acb0bb3144..3a62eed3a3 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -201,9 +201,10 @@ def compute_distance( square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1]) square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1]) cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps) + cosin = np.clip(cosin, -1.0, 1.0) square_sin = 1 - np.square(cosin) square_sin = np.nan_to_num(square_sin) - result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist) + result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist + eps) result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0] return result @@ -265,7 +266,10 @@ def draw_thresh_map( # Fill the canvas with the distances computed inside the valid padded polygon canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax( - 1 - distance_map[ymin_valid - ymin : ymax_valid - ymin + 1, xmin_valid - xmin : xmax_valid - xmin + 1], + 1 + - distance_map[ + ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width + ], canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1], ) @@ -274,7 +278,7 @@ def draw_thresh_map( def build_target( self, target: List[Dict[str, np.ndarray]], - output_shape: Tuple[int, int, int, int], + output_shape: Tuple[int, int, int], channels_last: bool = True, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): @@ -284,12 +288,14 @@ def build_target( input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32 + h: int + w: int if channels_last: - h, w = output_shape[1:-1] - target_shape = (output_shape[0], output_shape[-1], h, w) # (Batch_size, num_classes, h, w) + h, w, num_classes = output_shape else: - h, w = output_shape[-2:] - target_shape = output_shape # (Batch_size, num_classes, h, w) + num_classes, h, w = output_shape + target_shape = (len(target), num_classes, h, w) + seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8) seg_mask: np.ndarray = np.ones(target_shape, dtype=bool) thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32) @@ -300,7 +306,6 @@ def build_target( # Draw each polygon on gt if _tgt.shape[0] == 0: # Empty image, full masked - # seg_mask[idx, :, :, class_idx] = False seg_mask[idx, class_idx] = False # Absolute bounding boxes @@ -326,10 +331,9 @@ def build_target( ) boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) - for box, box_size, poly in zip(abs_boxes, boxes_size, polys): + for poly, box, box_size in zip(polys, abs_boxes, boxes_size): # Mask boxes that are too small if box_size < self.min_size_box: - # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue @@ -339,19 +343,17 @@ def build_target( subject = [tuple(coor) for coor in poly] padding = pyclipper.PyclipperOffset() padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) - shrinked = padding.Execute(-distance) + shrunken = padding.Execute(-distance) # Draw polygon on gt if it is valid - if len(shrinked) == 0: - # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + if len(shrunken) == 0: seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - shrinked = np.array(shrinked[0]).reshape(-1, 2) - if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid: - # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + shrunken = np.array(shrunken[0]).reshape(-1, 2) + if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - cv2.fillPoly(seg_target[idx, class_idx], [shrinked.astype(np.int32)], 1.0) # type: ignore[call-overload] + cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] # Draw on both thresh map and thresh mask poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map( diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 99adcd0e05..a13a2887e7 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -230,7 +230,7 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: prob_map = torch.sigmoid(out_map) thresh_map = torch.sigmoid(thresh_map) - targets = self.build_target(target, prob_map.shape, False) # type: ignore[arg-type] + targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 21943f9479..88265cb9e4 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -183,7 +183,7 @@ def compute_loss( prob_map = tf.math.sigmoid(out_map) thresh_map = tf.math.sigmoid(thresh_map) - seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape, True) + seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True) seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype) From de80ad854af085574f79622958ad2066751c3e5b Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 4 Dec 2023 11:24:21 +0100 Subject: [PATCH 02/13] update --- doctr/models/detection/differentiable_binarization/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 3a62eed3a3..0d261d2991 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -299,7 +299,7 @@ def build_target( seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8) seg_mask: np.ndarray = np.ones(target_shape, dtype=bool) thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32) - thresh_mask: np.ndarray = np.ones(target_shape, dtype=np.uint8) + thresh_mask: np.ndarray = np.zeros(target_shape, dtype=np.uint8) for idx, tgt in enumerate(target): for class_idx, _tgt in enumerate(tgt.values()): From 9ca46db5a4d5f1eff85810b5e4ce3f4506a27695 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 4 Dec 2023 19:21:49 +0100 Subject: [PATCH 03/13] up --- .../differentiable_binarization/pytorch.py | 77 ++++++++++++------- 1 file changed, 51 insertions(+), 26 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index a13a2887e7..611308cb0b 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -229,7 +229,9 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: """ prob_map = torch.sigmoid(out_map) thresh_map = torch.sigmoid(thresh_map) + binary_mask = torch.reciprocal(1.0 + torch.exp(-50 * (prob_map - thresh_map))) + # TODO: needs also some checks again shrunk masks targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) @@ -237,41 +239,64 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3]) thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device) - # Compute balanced BCE loss for proba_map - bce_scale = 5.0 balanced_bce_loss = torch.zeros(1, device=out_map.device) dice_loss = torch.zeros(1, device=out_map.device) l1_loss = torch.zeros(1, device=out_map.device) + + # TODO: Still in progress @Oliver @Charles + if torch.any(seg_mask): + # Compute balanced bce loss bce_loss = F.binary_cross_entropy_with_logits( out_map, seg_target, reduction="none", - )[seg_mask] - - neg_target = 1 - seg_target[seg_mask] - positive_count = seg_target[seg_mask].sum() - negative_count = torch.minimum(neg_target.sum(), 3.0 * positive_count) - negative_loss = bce_loss * neg_target - negative_loss = negative_loss.sort().values[-int(negative_count.item()) :] - sum_losses = torch.sum(bce_loss * seg_target[seg_mask]) + torch.sum(negative_loss) - balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6) - - # Compute dice loss for approxbin_map - bin_map = 1 / (1 + torch.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask]))) - - bce_min = bce_loss.min() - weights = (bce_loss - bce_min) / (bce_loss.max() - bce_min) + 1.0 - inter = torch.sum(bin_map * seg_target[seg_mask] * weights) - union = torch.sum(bin_map) + torch.sum(seg_target[seg_mask]) + 1e-8 # type: ignore[call-overload] - dice_loss = 1 - 2.0 * inter / union - - # Compute l1 loss for thresh_map - l1_scale = 10.0 - if torch.any(thresh_mask): - l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask])) + ) + + positive = (seg_target * seg_mask).float() + negative = ((1 - seg_target) * seg_mask).float() + positive_count = int(positive.sum()) + if positive_count == 0: + negative_count = min(int(negative.sum()), 0) + else: + negative_count = min(int(negative.sum()), int(positive_count * 3)) + + positive_loss = bce_loss * positive + negative_loss = bce_loss * negative + + negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) - return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss # type: ignore[return-value] + balanced_bce_loss = ( + (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + 1e-6) * 5.0 + ) + + # Compute dice loss + binary_mask = binary_mask.contiguous().view(binary_mask.size(0), -1) + seg_target = seg_target.contiguous().view(seg_target.size(0), -1) + + seg_mask = seg_mask.contiguous().view(seg_mask.size(0), -1) + binary_mask = binary_mask * seg_mask + seg_target = seg_target * seg_mask + + dice_coeff = (2 * (binary_mask * seg_target).sum()) / (binary_mask.sum() + seg_target.sum() + 1e-6) + + dice_loss = 1 - dice_coeff + + if torch.any(thresh_mask): + # Compute l1 loss + assert thresh_map.size() == thresh_target.size() and thresh_target.numel() > 0 + assert thresh_mask.size() == thresh_target.size() + x = thresh_map * thresh_mask + y = thresh_target * thresh_mask + + loss = torch.zeros_like(thresh_target) + diff = torch.abs(x - y) + mask_beta = diff < 1 + loss[mask_beta] = 0.5 * torch.square(diff)[mask_beta] / 1 + loss[~mask_beta] = diff[~mask_beta] - 0.5 * 1 + l1_loss = loss.sum() / (thresh_mask.sum() + 1e-6) * 10.0 + + return (balanced_bce_loss + dice_loss + l1_loss) / 3 def _dbnet( From 2efb7973ba1866a78be9fd989814c02cb0a61d02 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 8 Jan 2024 12:55:16 +0100 Subject: [PATCH 04/13] update loss tf and pt --- .../differentiable_binarization/pytorch.py | 92 +++++++------------ .../differentiable_binarization/tensorflow.py | 50 +++++----- 2 files changed, 60 insertions(+), 82 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 611308cb0b..2230b73875 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -213,7 +213,15 @@ def forward( return out - def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor: + def compute_loss( + self, + out_map: torch.Tensor, + thresh_map: torch.Tensor, + target: List[np.ndarray], + gamma: float = 2.0, + alpha: float = 0.5, + eps: float = 1e-8, + ) -> torch.Tensor: """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes and a list of masks for each image. From there it computes the loss with the model output @@ -222,6 +230,9 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: out_map: output feature map of the model of shape (N, C, H, W) thresh_map: threshold map of shape (N, C, H, W) target: list of dictionary where each dict has a `boxes` and a `flags` entry + gamma: modulating factor in the focal loss formula + alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss Returns: ------- @@ -229,9 +240,7 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: """ prob_map = torch.sigmoid(out_map) thresh_map = torch.sigmoid(thresh_map) - binary_mask = torch.reciprocal(1.0 + torch.exp(-50 * (prob_map - thresh_map))) - # TODO: needs also some checks again shrunk masks targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) @@ -239,64 +248,33 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3]) thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device) - balanced_bce_loss = torch.zeros(1, device=out_map.device) + focal_loss = torch.zeros(1, device=out_map.device) dice_loss = torch.zeros(1, device=out_map.device) l1_loss = torch.zeros(1, device=out_map.device) - - # TODO: Still in progress @Oliver @Charles - if torch.any(seg_mask): - # Compute balanced bce loss - bce_loss = F.binary_cross_entropy_with_logits( - out_map, - seg_target, - reduction="none", - ) - - positive = (seg_target * seg_mask).float() - negative = ((1 - seg_target) * seg_mask).float() - positive_count = int(positive.sum()) - if positive_count == 0: - negative_count = min(int(negative.sum()), 0) - else: - negative_count = min(int(negative.sum()), int(positive_count * 3)) - - positive_loss = bce_loss * positive - negative_loss = bce_loss * negative - - negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) - - balanced_bce_loss = ( - (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + 1e-6) * 5.0 - ) - - # Compute dice loss - binary_mask = binary_mask.contiguous().view(binary_mask.size(0), -1) - seg_target = seg_target.contiguous().view(seg_target.size(0), -1) - - seg_mask = seg_mask.contiguous().view(seg_mask.size(0), -1) - binary_mask = binary_mask * seg_mask - seg_target = seg_target * seg_mask - - dice_coeff = (2 * (binary_mask * seg_target).sum()) / (binary_mask.sum() + seg_target.sum() + 1e-6) - - dice_loss = 1 - dice_coeff - + # Focal loss + bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") + + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target) + alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target) + # Unreduced version + focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss + # Class reduced + focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum() + + # Dice loss + inter = (seg_mask * prob_map * seg_target).sum() + cardinality = (seg_mask * (prob_map + seg_target)).sum() + dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) + + # Compute l1 loss for thresh_map + l1_scale = 10.0 if torch.any(thresh_mask): - # Compute l1 loss - assert thresh_map.size() == thresh_target.size() and thresh_target.numel() > 0 - assert thresh_mask.size() == thresh_target.size() - x = thresh_map * thresh_mask - y = thresh_target * thresh_mask - - loss = torch.zeros_like(thresh_target) - diff = torch.abs(x - y) - mask_beta = diff < 1 - loss[mask_beta] = 0.5 * torch.square(diff)[mask_beta] / 1 - loss[~mask_beta] = diff[~mask_beta] - 0.5 * 1 - l1_loss = loss.sum() / (thresh_mask.sum() + 1e-6) * 10.0 - - return (balanced_bce_loss + dice_loss + l1_loss) / 3 + l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask])) + + return l1_scale * l1_loss + focal_loss + dice_loss # type: ignore[return-value] def _dbnet( diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 88265cb9e4..9e58356544 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -166,6 +166,9 @@ def compute_loss( out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[Dict[str, np.ndarray]], + gamma: float = 2.0, + alpha: float = 0.5, + eps: float = 1e-8, ) -> tf.Tensor: """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes and a list of masks for each image. From there it computes the loss with the model output @@ -175,6 +178,9 @@ def compute_loss( out_map: output feature map of the model of shape (N, H, W, C) thresh_map: threshold map of shape (N, H, W, C) target: list of dictionary where each dict has a `boxes` and a `flags` entry + gamma: modulating factor in the focal loss formula + alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss Returns: ------- @@ -186,33 +192,27 @@ def compute_loss( seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True) seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) + seg_mask = tf.cast(seg_mask, tf.float32) thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype) thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool) - # Compute balanced BCE loss for proba_map - bce_scale = 5.0 - bce_loss = tf.keras.losses.binary_crossentropy( - seg_target[..., None], - out_map[..., None], - from_logits=True, - )[seg_mask] - - neg_target = 1 - seg_target[seg_mask] - positive_count = tf.math.reduce_sum(seg_target[seg_mask]) - negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count]) - negative_loss = bce_loss * neg_target - negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32)) - sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss) - balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6) - - # Compute dice loss for approxbin_map - bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask]))) - - bce_min = tf.math.reduce_min(bce_loss) - weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0 - inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights) - union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8 - dice_loss = 1 - 2.0 * inter / union + bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) + + # Focal loss + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + # Convert logits to prob, compute gamma factor + p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map)) + alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha) + # Unreduced loss + focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss + # Class reduced + focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3)) + + # Dice loss + inter = tf.math.reduce_sum(seg_mask * prob_map * seg_target, (0, 1, 2, 3)) + cardinality = tf.math.reduce_sum((prob_map + seg_target), (0, 1, 2, 3)) + dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map l1_scale = 10.0 @@ -221,7 +221,7 @@ def compute_loss( else: l1_loss = tf.constant(0.0) - return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss + return l1_scale * l1_loss + focal_loss + dice_loss def call( self, From 9ec4ce5d243bfa4211ec9dd63697ce5187e38059 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 9 Jan 2024 13:00:25 +0100 Subject: [PATCH 05/13] update --- .../differentiable_binarization/pytorch.py | 12 +-- .../differentiable_binarization/tensorflow.py | 18 +++-- doctr/models/detection/linknet/tensorflow.py | 76 +++++++++---------- 3 files changed, 54 insertions(+), 52 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 2230b73875..f0349669c4 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -253,6 +253,7 @@ def compute_loss( l1_loss = torch.zeros(1, device=out_map.device) if torch.any(seg_mask): # Focal loss + focal_scale = 5.0 bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") if gamma < 0: @@ -264,17 +265,18 @@ def compute_loss( # Class reduced focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum() - # Dice loss - inter = (seg_mask * prob_map * seg_target).sum() - cardinality = (seg_mask * (prob_map + seg_target)).sum() + # Compute dice loss for approx binary_map + binary_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) + inter = (seg_mask * binary_map * seg_target).sum() + cardinality = (seg_mask * (binary_map + seg_target)).sum() dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map l1_scale = 10.0 if torch.any(thresh_mask): - l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask])) + l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps) - return l1_scale * l1_loss + focal_loss + dice_loss # type: ignore[return-value] + return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss # type: ignore[return-value] def _dbnet( diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 9e58356544..31a9907c72 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -196,9 +196,9 @@ def compute_loss( thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype) thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool) - bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) - # Focal loss + focal_scale = 5.0 + bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) if gamma < 0: raise ValueError("Value of gamma should be greater than or equal to zero.") # Convert logits to prob, compute gamma factor @@ -209,19 +209,23 @@ def compute_loss( # Class reduced focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3)) - # Dice loss - inter = tf.math.reduce_sum(seg_mask * prob_map * seg_target, (0, 1, 2, 3)) - cardinality = tf.math.reduce_sum((prob_map + seg_target), (0, 1, 2, 3)) + # Compute dice loss for approx binary_map + binary_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map))) + inter = tf.reduce_sum(seg_mask * binary_map * seg_target, (0, 1, 2, 3)) + cardinality = tf.reduce_sum((binary_map + seg_target), (0, 1, 2, 3)) dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map l1_scale = 10.0 if tf.reduce_any(thresh_mask): - l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask])) + thresh_mask = tf.cast(thresh_mask, tf.float32) + l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / ( + tf.reduce_sum(thresh_mask) + eps + ) else: l1_loss = tf.constant(0.0) - return l1_scale * l1_loss + focal_loss + dice_loss + return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss def call( self, diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index cfb15b3108..e339588726 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -46,22 +46,20 @@ def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential: """Creates a LinkNet decoder block""" - return Sequential( - [ - *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs), - layers.Conv2DTranspose( - filters=in_chan // 4, - kernel_size=3, - strides=stride, - padding="same", - use_bias=False, - kernel_initializer="he_normal", - ), - layers.BatchNormalization(), - layers.Activation("relu"), - *conv_sequence(out_chan, "relu", True, kernel_size=1), - ] - ) + return Sequential([ + *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs), + layers.Conv2DTranspose( + filters=in_chan // 4, + kernel_size=3, + strides=stride, + padding="same", + use_bias=False, + kernel_initializer="he_normal", + ), + layers.BatchNormalization(), + layers.Activation("relu"), + *conv_sequence(out_chan, "relu", True, kernel_size=1), + ]) class LinkNetFPN(Model, NestedObject): @@ -131,30 +129,28 @@ def __init__( self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape]) self.fpn.build(self.feat_extractor.output_shape) - self.classifier = Sequential( - [ - layers.Conv2DTranspose( - filters=32, - kernel_size=3, - strides=2, - padding="same", - use_bias=False, - kernel_initializer="he_normal", - input_shape=self.fpn.decoders[-1].output_shape[1:], - ), - layers.BatchNormalization(), - layers.Activation("relu"), - *conv_sequence(32, "relu", True, kernel_size=3, strides=1), - layers.Conv2DTranspose( - filters=num_classes, - kernel_size=2, - strides=2, - padding="same", - use_bias=True, - kernel_initializer="he_normal", - ), - ] - ) + self.classifier = Sequential([ + layers.Conv2DTranspose( + filters=32, + kernel_size=3, + strides=2, + padding="same", + use_bias=False, + kernel_initializer="he_normal", + input_shape=self.fpn.decoders[-1].output_shape[1:], + ), + layers.BatchNormalization(), + layers.Activation("relu"), + *conv_sequence(32, "relu", True, kernel_size=3, strides=1), + layers.Conv2DTranspose( + filters=num_classes, + kernel_size=2, + strides=2, + padding="same", + use_bias=True, + kernel_initializer="he_normal", + ), + ]) self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) From e380e7ad6993a6a25e3295a6b783c97a0d53eeb4 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 9 Jan 2024 16:30:10 +0100 Subject: [PATCH 06/13] update --- .../differentiable_binarization/base.py | 28 ++++++------------- .../differentiable_binarization/pytorch.py | 4 +-- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 0d261d2991..0290b2940c 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -33,12 +33,12 @@ class DBPostProcessor(DetectionPostProcessor): def __init__( self, + bin_thresh: float = 0.1, box_thresh: float = 0.1, - bin_thresh: float = 0.3, assume_straight_pages: bool = True, ) -> None: super().__init__(box_thresh, bin_thresh, assume_straight_pages) - self.unclip_ratio = 1.5 if assume_straight_pages else 2.2 + self.unclip_ratio = 1.2 def polygon_to_box( self, @@ -93,28 +93,27 @@ def bitmap_to_boxes( pred: np.ndarray, bitmap: np.ndarray, ) -> np.ndarray: - """Compute boxes from a bitmap/pred_map + """Compute boxes from a bitmap/pred_map: find connected components then filter boxes Args: ---- - pred: Pred map from differentiable binarization output + pred: Pred map from differentiable linknet output bitmap: Bitmap map computed from pred (binarized) angle_tol: Comparison tolerance of the angle with the median angle across the page ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop Returns: ------- - np tensor boxes for the bitmap, each box is a 5-element list - containing x, y, w, h, score for the box + np tensor boxes for the bitmap, each box is a 6-element list + containing x, y, w, h, alpha, score for the box """ height, width = bitmap.shape[:2] - min_size_box = 1 + int(height / 512) boxes: List[Union[np.ndarray, List[float]]] = [] # get contours from connected components on the bitmap contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) for contour in contours: # Check whether smallest enclosing bounding box is not too small - if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box): + if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): continue # Compute objectness if self.assume_straight_pages: @@ -132,22 +131,13 @@ def bitmap_to_boxes( else: _box = self.polygon_to_box(np.squeeze(contour)) - # Remove too small boxes if self.assume_straight_pages: - if _box is None or _box[2] < min_size_box or _box[3] < min_size_box: - continue - elif np.linalg.norm(_box[2, :] - _box[0, :], axis=-1) < min_size_box: - continue - - if self.assume_straight_pages: - x, y, w, h = _box # compute relative polygon to get rid of img shape + x, y, w, h = _box xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height boxes.append([xmin, ymin, xmax, ymax, score]) else: - # compute relative box to get rid of img shape, in that case _box is a 4pt polygon - if not isinstance(_box, np.ndarray) and _box.shape == (4, 2): - raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)") + # compute relative box to get rid of img shape _box[:, 0] /= width _box[:, 1] /= height boxes.append(_box) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index f0349669c4..6ba42f158a 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -267,8 +267,8 @@ def compute_loss( # Compute dice loss for approx binary_map binary_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) - inter = (seg_mask * binary_map * seg_target).sum() - cardinality = (seg_mask * (binary_map + seg_target)).sum() + inter = (seg_mask * binary_map * seg_target).sum() # type: ignore[attr-defined] + cardinality = (seg_mask * (binary_map + seg_target)).sum() # type: ignore[attr-defined] dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map From d4a2f365bfd9e10e0fae82ac6060434aa69369e9 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 9 Jan 2024 16:31:15 +0100 Subject: [PATCH 07/13] update --- doctr/models/detection/differentiable_binarization/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 0290b2940c..8f017af0ee 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -97,15 +97,15 @@ def bitmap_to_boxes( Args: ---- - pred: Pred map from differentiable linknet output + pred: Pred map from differentiable binarization output bitmap: Bitmap map computed from pred (binarized) angle_tol: Comparison tolerance of the angle with the median angle across the page ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop Returns: ------- - np tensor boxes for the bitmap, each box is a 6-element list - containing x, y, w, h, alpha, score for the box + np tensor boxes for the bitmap, each box is a 5-element list + containing x, y, w, h, score for the box """ height, width = bitmap.shape[:2] boxes: List[Union[np.ndarray, List[float]]] = [] From 8e858b911283e17172e1b5f63f170f09baaeb585 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 9 Jan 2024 16:38:56 +0100 Subject: [PATCH 08/13] thresh value --- doctr/models/detection/differentiable_binarization/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 8f017af0ee..055e578eb9 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -33,7 +33,7 @@ class DBPostProcessor(DetectionPostProcessor): def __init__( self, - bin_thresh: float = 0.1, + bin_thresh: float = 0.3, box_thresh: float = 0.1, assume_straight_pages: bool = True, ) -> None: From 2e6fc182d1ae0f1d3980b1da8ba1800cb1bb294f Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 10 Jan 2024 11:27:30 +0100 Subject: [PATCH 09/13] experimential test --- .../differentiable_binarization/pytorch.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 6ba42f158a..46a9e6385a 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -245,15 +245,15 @@ def compute_loss( seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) - thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3]) - thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device) + #thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3]) + #thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device) focal_loss = torch.zeros(1, device=out_map.device) dice_loss = torch.zeros(1, device=out_map.device) - l1_loss = torch.zeros(1, device=out_map.device) + # l1_loss = torch.zeros(1, device=out_map.device) if torch.any(seg_mask): # Focal loss - focal_scale = 5.0 + #focal_scale = 5.0 bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") if gamma < 0: @@ -272,11 +272,12 @@ def compute_loss( dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map - l1_scale = 10.0 - if torch.any(thresh_mask): - l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps) + #l1_scale = 10.0 + #if torch.any(thresh_mask): + # l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps) - return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss # type: ignore[return-value] + #return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss # type: ignore[return-value] + return focal_loss + dice_loss # type: ignore[return-value] def _dbnet( From 5e758e7c8b80989c85b5e74d525245ec159c002b Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 11 Jan 2024 13:50:46 +0100 Subject: [PATCH 10/13] update --- .../differentiable_binarization/base.py | 20 ++++++++++++++----- .../differentiable_binarization/pytorch.py | 18 +++++++---------- .../differentiable_binarization/tensorflow.py | 4 ++-- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 055e578eb9..c187cef287 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -33,8 +33,8 @@ class DBPostProcessor(DetectionPostProcessor): def __init__( self, - bin_thresh: float = 0.3, box_thresh: float = 0.1, + bin_thresh: float = 0.3, assume_straight_pages: bool = True, ) -> None: super().__init__(box_thresh, bin_thresh, assume_straight_pages) @@ -108,12 +108,13 @@ def bitmap_to_boxes( containing x, y, w, h, score for the box """ height, width = bitmap.shape[:2] + min_size_box = 2 boxes: List[Union[np.ndarray, List[float]]] = [] # get contours from connected components on the bitmap contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) for contour in contours: # Check whether smallest enclosing bounding box is not too small - if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): + if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box): continue # Compute objectness if self.assume_straight_pages: @@ -131,13 +132,22 @@ def bitmap_to_boxes( else: _box = self.polygon_to_box(np.squeeze(contour)) + # Remove too small boxes + if self.assume_straight_pages: + if _box is None or _box[2] < min_size_box or _box[3] < min_size_box: + continue + elif np.linalg.norm(_box[2, :] - _box[0, :], axis=-1) < min_size_box: + continue + if self.assume_straight_pages: - # compute relative polygon to get rid of img shape x, y, w, h = _box + # compute relative polygon to get rid of img shape xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height boxes.append([xmin, ymin, xmax, ymax, score]) else: - # compute relative box to get rid of img shape + # compute relative box to get rid of img shape, in that case _box is a 4pt polygon + if not isinstance(_box, np.ndarray) and _box.shape == (4, 2): + raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)") _box[:, 0] /= width _box[:, 1] /= height boxes.append(_box) @@ -170,7 +180,7 @@ def compute_distance( ys: np.ndarray, a: np.ndarray, b: np.ndarray, - eps: float = 1e-7, + eps: float = 1e-6, ) -> float: """Compute the distance for each point of the map (xs, ys) to the (a, b) segment diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 46a9e6385a..44405780ce 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -245,15 +245,12 @@ def compute_loss( seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) - #thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3]) - #thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device) + thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3]) + thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device) - focal_loss = torch.zeros(1, device=out_map.device) - dice_loss = torch.zeros(1, device=out_map.device) - # l1_loss = torch.zeros(1, device=out_map.device) if torch.any(seg_mask): # Focal loss - #focal_scale = 5.0 + focal_scale = 10.0 bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") if gamma < 0: @@ -272,12 +269,11 @@ def compute_loss( dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map - #l1_scale = 10.0 - #if torch.any(thresh_mask): - # l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps) + l1_scale = 1.0 + if torch.any(thresh_mask): + l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps) - #return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss # type: ignore[return-value] - return focal_loss + dice_loss # type: ignore[return-value] + return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss # type: ignore[return-value] def _dbnet( diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 31a9907c72..8487c7cfcf 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -197,7 +197,7 @@ def compute_loss( thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool) # Focal loss - focal_scale = 5.0 + focal_scale = 10.0 bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) if gamma < 0: raise ValueError("Value of gamma should be greater than or equal to zero.") @@ -216,7 +216,7 @@ def compute_loss( dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map - l1_scale = 10.0 + l1_scale = 1.0 if tf.reduce_any(thresh_mask): thresh_mask = tf.cast(thresh_mask, tf.float32) l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / ( From 391ffab82374fe3245236f9a86df1395098f16ea Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 11 Jan 2024 13:55:38 +0100 Subject: [PATCH 11/13] mypy --- doctr/models/detection/differentiable_binarization/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 44405780ce..0c8424808b 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -273,7 +273,7 @@ def compute_loss( if torch.any(thresh_mask): l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps) - return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss # type: ignore[return-value] + return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss def _dbnet( From 93f8d61b5ffaa48c33d8bbba491900572086412f Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 12 Jan 2024 10:59:34 +0100 Subject: [PATCH 12/13] unclip_ratio --- doctr/models/detection/differentiable_binarization/base.py | 2 +- doctr/models/detection/linknet/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index c187cef287..5f03a2e1bf 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -38,7 +38,7 @@ def __init__( assume_straight_pages: bool = True, ) -> None: super().__init__(box_thresh, bin_thresh, assume_straight_pages) - self.unclip_ratio = 1.2 + self.unclip_ratio = 1.5 def polygon_to_box( self, diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py index 8e60d06d45..986f57d6ad 100644 --- a/doctr/models/detection/linknet/base.py +++ b/doctr/models/detection/linknet/base.py @@ -36,7 +36,7 @@ def __init__( assume_straight_pages: bool = True, ) -> None: super().__init__(box_thresh, bin_thresh, assume_straight_pages) - self.unclip_ratio = 1.2 + self.unclip_ratio = 1.5 def polygon_to_box( self, From 0375ee8f1cd331ea68a1825b2b268df2950da691 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 16 Jan 2024 12:18:58 +0100 Subject: [PATCH 13/13] suggestions --- .../detection/differentiable_binarization/pytorch.py | 8 ++++---- .../detection/differentiable_binarization/tensorflow.py | 9 +++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 0c8424808b..c3011d3dae 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -238,6 +238,9 @@ def compute_loss( ------- A loss tensor """ + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + prob_map = torch.sigmoid(out_map) thresh_map = torch.sigmoid(thresh_map) @@ -253,8 +256,6 @@ def compute_loss( focal_scale = 10.0 bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") - if gamma < 0: - raise ValueError("Value of gamma should be greater than or equal to zero.") p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target) alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target) # Unreduced version @@ -269,11 +270,10 @@ def compute_loss( dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map - l1_scale = 1.0 if torch.any(thresh_mask): l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps) - return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss + return l1_loss + focal_scale * focal_loss + dice_loss def _dbnet( diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 8487c7cfcf..7d790088bb 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -186,6 +186,9 @@ def compute_loss( ------- A loss tensor """ + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + prob_map = tf.math.sigmoid(out_map) thresh_map = tf.math.sigmoid(thresh_map) @@ -199,8 +202,7 @@ def compute_loss( # Focal loss focal_scale = 10.0 bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) - if gamma < 0: - raise ValueError("Value of gamma should be greater than or equal to zero.") + # Convert logits to prob, compute gamma factor p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map)) alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha) @@ -216,7 +218,6 @@ def compute_loss( dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Compute l1 loss for thresh_map - l1_scale = 1.0 if tf.reduce_any(thresh_mask): thresh_mask = tf.cast(thresh_mask, tf.float32) l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / ( @@ -225,7 +226,7 @@ def compute_loss( else: l1_loss = tf.constant(0.0) - return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss + return l1_loss + focal_scale * focal_loss + dice_loss def call( self,