Skip to content

Commit

Permalink
[FIX] db loss TF and PT also for training with rotated samples (#1396)
Browse files: browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Jan 16, 2024
1 parent e5b3f46 commit ff9982b
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 123 deletions.
46 changes: 24 additions & 22 deletions doctr/models/detection/differentiable_binarization/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
assume_straight_pages: bool = True,
) -> None:
super().__init__(box_thresh, bin_thresh, assume_straight_pages)
self.unclip_ratio = 1.5 if assume_straight_pages else 2.2
self.unclip_ratio = 1.5

def polygon_to_box(
self,
Expand Down Expand Up @@ -93,7 +93,7 @@ def bitmap_to_boxes(
pred: np.ndarray,
bitmap: np.ndarray,
) -> np.ndarray:
"""Compute boxes from a bitmap/pred_map
"""Compute boxes from a bitmap/pred_map: find connected components then filter boxes
Args:
----
Expand All @@ -108,7 +108,7 @@ def bitmap_to_boxes(
containing x, y, w, h, score for the box
"""
height, width = bitmap.shape[:2]
min_size_box = 1 + int(height / 512)
min_size_box = 2
boxes: List[Union[np.ndarray, List[float]]] = []
# get contours from connected components on the bitmap
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
Expand Down Expand Up @@ -180,7 +180,7 @@ def compute_distance(
ys: np.ndarray,
a: np.ndarray,
b: np.ndarray,
eps: float = 1e-7,
eps: float = 1e-6,
) -> float:
"""Compute the distance for each point of the map (xs, ys) to the (a, b) segment
Expand All @@ -201,9 +201,10 @@ def compute_distance(
square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1])
square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1])
cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps)
cosin = np.clip(cosin, -1.0, 1.0)
square_sin = 1 - np.square(cosin)
square_sin = np.nan_to_num(square_sin)
result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist)
result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist + eps)
result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0]
return result

Expand Down Expand Up @@ -265,7 +266,10 @@ def draw_thresh_map(

# Fill the canvas with the distances computed inside the valid padded polygon
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
1 - distance_map[ymin_valid - ymin : ymax_valid - ymin + 1, xmin_valid - xmin : xmax_valid - xmin + 1],
1
- distance_map[
ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width
],
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
)

Expand All @@ -274,7 +278,7 @@ def draw_thresh_map(
def build_target(
self,
target: List[Dict[str, np.ndarray]],
output_shape: Tuple[int, int, int, int],
output_shape: Tuple[int, int, int],
channels_last: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
Expand All @@ -284,23 +288,24 @@ def build_target(

input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32

h: int
w: int
if channels_last:
h, w = output_shape[1:-1]
target_shape = (output_shape[0], output_shape[-1], h, w) # (Batch_size, num_classes, h, w)
h, w, num_classes = output_shape
else:
h, w = output_shape[-2:]
target_shape = output_shape # (Batch_size, num_classes, h, w)
num_classes, h, w = output_shape
target_shape = (len(target), num_classes, h, w)

seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32)
thresh_mask: np.ndarray = np.ones(target_shape, dtype=np.uint8)
thresh_mask: np.ndarray = np.zeros(target_shape, dtype=np.uint8)

for idx, tgt in enumerate(target):
for class_idx, _tgt in enumerate(tgt.values()):
# Draw each polygon on gt
if _tgt.shape[0] == 0:
# Empty image, full masked
# seg_mask[idx, :, :, class_idx] = False
seg_mask[idx, class_idx] = False

# Absolute bounding boxes
Expand All @@ -326,10 +331,9 @@ def build_target(
)
boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])

for box, box_size, poly in zip(abs_boxes, boxes_size, polys):
for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
# Mask boxes that are too small
if box_size < self.min_size_box:
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue

Expand All @@ -339,19 +343,17 @@ def build_target(
subject = [tuple(coor) for coor in poly]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrinked = padding.Execute(-distance)
shrunken = padding.Execute(-distance)

# Draw polygon on gt if it is valid
if len(shrinked) == 0:
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
if len(shrunken) == 0:
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
shrinked = np.array(shrinked[0]).reshape(-1, 2)
if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
shrunken = np.array(shrunken[0]).reshape(-1, 2)
if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
cv2.fillPoly(seg_target[idx, class_idx], [shrinked.astype(np.int32)], 1.0) # type: ignore[call-overload]
cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]

# Draw on both thresh map and thresh mask
poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
Expand Down
66 changes: 34 additions & 32 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -213,7 +213,15 @@ def forward(

return out

def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
def compute_loss(
self,
out_map: torch.Tensor,
thresh_map: torch.Tensor,
target: List[np.ndarray],
gamma: float = 2.0,
alpha: float = 0.5,
eps: float = 1e-8,
) -> torch.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
and a list of masks for each image. From there it computes the loss with the model output
Expand All @@ -222,56 +230,50 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target:
out_map: output feature map of the model of shape (N, C, H, W)
thresh_map: threshold map of shape (N, C, H, W)
target: list of dictionary where each dict has a `boxes` and a `flags` entry
gamma: modulating factor in the focal loss formula
alpha: balancing factor in the focal loss formula
eps: epsilon factor in dice loss
Returns:
-------
A loss tensor
"""
if gamma < 0:
raise ValueError("Value of gamma should be greater than or equal to zero.")

prob_map = torch.sigmoid(out_map)
thresh_map = torch.sigmoid(thresh_map)

targets = self.build_target(target, prob_map.shape, False) # type: ignore[arg-type]
targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]

seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)

# Compute balanced BCE loss for proba_map
bce_scale = 5.0
balanced_bce_loss = torch.zeros(1, device=out_map.device)
dice_loss = torch.zeros(1, device=out_map.device)
l1_loss = torch.zeros(1, device=out_map.device)
if torch.any(seg_mask):
bce_loss = F.binary_cross_entropy_with_logits(
out_map,
seg_target,
reduction="none",
)[seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = seg_target[seg_mask].sum()
negative_count = torch.minimum(neg_target.sum(), 3.0 * positive_count)
negative_loss = bce_loss * neg_target
negative_loss = negative_loss.sort().values[-int(negative_count.item()) :]
sum_losses = torch.sum(bce_loss * seg_target[seg_mask]) + torch.sum(negative_loss)
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)

# Compute dice loss for approxbin_map
bin_map = 1 / (1 + torch.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))

bce_min = bce_loss.min()
weights = (bce_loss - bce_min) / (bce_loss.max() - bce_min) + 1.0
inter = torch.sum(bin_map * seg_target[seg_mask] * weights)
union = torch.sum(bin_map) + torch.sum(seg_target[seg_mask]) + 1e-8 # type: ignore[call-overload]
dice_loss = 1 - 2.0 * inter / union
# Focal loss
focal_scale = 10.0
bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")

p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
# Unreduced version
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
# Class reduced
focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum()

# Compute dice loss for approx binary_map
binary_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
inter = (seg_mask * binary_map * seg_target).sum() # type: ignore[attr-defined]
cardinality = (seg_mask * (binary_map + seg_target)).sum() # type: ignore[attr-defined]
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
if torch.any(thresh_mask):
l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps)

return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss # type: ignore[return-value]
return l1_loss + focal_scale * focal_loss + dice_loss


def _dbnet(
Expand Down
61 changes: 33 additions & 28 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -166,6 +166,9 @@ def compute_loss(
out_map: tf.Tensor,
thresh_map: tf.Tensor,
target: List[Dict[str, np.ndarray]],
gamma: float = 2.0,
alpha: float = 0.5,
eps: float = 1e-8,
) -> tf.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
and a list of masks for each image. From there it computes the loss with the model output
Expand All @@ -175,53 +178,55 @@ def compute_loss(
out_map: output feature map of the model of shape (N, H, W, C)
thresh_map: threshold map of shape (N, H, W, C)
target: list of dictionary where each dict has a `boxes` and a `flags` entry
gamma: modulating factor in the focal loss formula
alpha: balancing factor in the focal loss formula
eps: epsilon factor in dice loss
Returns:
-------
A loss tensor
"""
if gamma < 0:
raise ValueError("Value of gamma should be greater than or equal to zero.")

prob_map = tf.math.sigmoid(out_map)
thresh_map = tf.math.sigmoid(thresh_map)

seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape, True)
seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
seg_mask = tf.cast(seg_mask, tf.float32)
thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)

# Compute balanced BCE loss for proba_map
bce_scale = 5.0
bce_loss = tf.keras.losses.binary_crossentropy(
seg_target[..., None],
out_map[..., None],
from_logits=True,
)[seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = tf.math.reduce_sum(seg_target[seg_mask])
negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
negative_loss = bce_loss * neg_target
negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)

# Compute dice loss for approxbin_map
bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))

bce_min = tf.math.reduce_min(bce_loss)
weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
dice_loss = 1 - 2.0 * inter / union
# Focal loss
focal_scale = 10.0
bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)

# Convert logits to prob, compute gamma factor
p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
# Unreduced loss
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
# Class reduced
focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))

# Compute dice loss for approx binary_map
binary_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
inter = tf.reduce_sum(seg_mask * binary_map * seg_target, (0, 1, 2, 3))
cardinality = tf.reduce_sum((binary_map + seg_target), (0, 1, 2, 3))
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
if tf.reduce_any(thresh_mask):
l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
thresh_mask = tf.cast(thresh_mask, tf.float32)
l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / (
tf.reduce_sum(thresh_mask) + eps
)
else:
l1_loss = tf.constant(0.0)

return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
return l1_loss + focal_scale * focal_loss + dice_loss

def call(
self,
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
assume_straight_pages: bool = True,
) -> None:
super().__init__(box_thresh, bin_thresh, assume_straight_pages)
self.unclip_ratio = 1.2
self.unclip_ratio = 1.5

def polygon_to_box(
self,
Expand Down
Loading

0 comments on commit ff9982b

Please sign in to comment.