Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] db loss TF and PT also for training with rotated samples #1396

Merged
merged 13 commits into from
Jan 16, 2024
46 changes: 24 additions & 22 deletions doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
assume_straight_pages: bool = True,
) -> None:
super().__init__(box_thresh, bin_thresh, assume_straight_pages)
self.unclip_ratio = 1.5 if assume_straight_pages else 2.2
self.unclip_ratio = 1.5

def polygon_to_box(
self,
Expand Down Expand Up @@ -93,7 +93,7 @@ def bitmap_to_boxes(
pred: np.ndarray,
bitmap: np.ndarray,
) -> np.ndarray:
"""Compute boxes from a bitmap/pred_map
"""Compute boxes from a bitmap/pred_map: find connected components then filter boxes

Args:
----
Expand All @@ -108,7 +108,7 @@ def bitmap_to_boxes(
containing x, y, w, h, score for the box
"""
height, width = bitmap.shape[:2]
min_size_box = 1 + int(height / 512)
min_size_box = 2
boxes: List[Union[np.ndarray, List[float]]] = []
# get contours from connected components on the bitmap
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
Expand Down Expand Up @@ -180,7 +180,7 @@ def compute_distance(
ys: np.ndarray,
a: np.ndarray,
b: np.ndarray,
eps: float = 1e-7,
eps: float = 1e-6,
) -> float:
"""Compute the distance for each point of the map (xs, ys) to the (a, b) segment

Expand All @@ -201,9 +201,10 @@ def compute_distance(
square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1])
square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1])
cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps)
cosin = np.clip(cosin, -1.0, 1.0)
odulcy-mindee marked this conversation as resolved.
Show resolved Hide resolved
square_sin = 1 - np.square(cosin)
square_sin = np.nan_to_num(square_sin)
result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist)
result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist + eps)
result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0]
return result

Expand Down Expand Up @@ -265,7 +266,10 @@ def draw_thresh_map(

# Fill the canvas with the distances computed inside the valid padded polygon
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
1 - distance_map[ymin_valid - ymin : ymax_valid - ymin + 1, xmin_valid - xmin : xmax_valid - xmin + 1],
1
- distance_map[
ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width
],
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
)

Expand All @@ -274,7 +278,7 @@ def draw_thresh_map(
def build_target(
self,
target: List[Dict[str, np.ndarray]],
output_shape: Tuple[int, int, int, int],
output_shape: Tuple[int, int, int],
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved
channels_last: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
Expand All @@ -284,23 +288,24 @@ def build_target(

input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32

h: int
w: int
if channels_last:
h, w = output_shape[1:-1]
target_shape = (output_shape[0], output_shape[-1], h, w) # (Batch_size, num_classes, h, w)
h, w, num_classes = output_shape
else:
h, w = output_shape[-2:]
target_shape = output_shape # (Batch_size, num_classes, h, w)
num_classes, h, w = output_shape
target_shape = (len(target), num_classes, h, w)

seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32)
thresh_mask: np.ndarray = np.ones(target_shape, dtype=np.uint8)
thresh_mask: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
odulcy-mindee marked this conversation as resolved.
Show resolved Hide resolved

for idx, tgt in enumerate(target):
for class_idx, _tgt in enumerate(tgt.values()):
# Draw each polygon on gt
if _tgt.shape[0] == 0:
# Empty image, fully masked
# seg_mask[idx, :, :, class_idx] = False
seg_mask[idx, class_idx] = False

# Absolute bounding boxes
Expand All @@ -326,10 +331,9 @@ def build_target(
)
boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])

for box, box_size, poly in zip(abs_boxes, boxes_size, polys):
for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
# Mask boxes that are too small
if box_size < self.min_size_box:
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue

Expand All @@ -339,19 +343,17 @@ def build_target(
subject = [tuple(coor) for coor in poly]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrinked = padding.Execute(-distance)
shrunken = padding.Execute(-distance)

# Draw polygon on gt if it is valid
if len(shrinked) == 0:
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
if len(shrunken) == 0:
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
shrinked = np.array(shrinked[0]).reshape(-1, 2)
if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
shrunken = np.array(shrunken[0]).reshape(-1, 2)
if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
cv2.fillPoly(seg_target[idx, class_idx], [shrinked.astype(np.int32)], 1.0) # type: ignore[call-overload]
cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]

# Draw on both thresh map and thresh mask
poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
Expand Down
66 changes: 34 additions & 32 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,15 @@

return out

def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
def compute_loss(
self,
out_map: torch.Tensor,
thresh_map: torch.Tensor,
target: List[np.ndarray],
gamma: float = 2.0,
alpha: float = 0.5,
eps: float = 1e-8,
) -> torch.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
and a list of masks for each image. From there it computes the loss with the model output

Expand All @@ -222,56 +230,50 @@
out_map: output feature map of the model of shape (N, C, H, W)
thresh_map: threshold map of shape (N, C, H, W)
target: list of dictionaries, where each dict has a `boxes` and a `flags` entry
gamma: modulating factor in the focal loss formula
alpha: balancing factor in the focal loss formula
eps: small constant added for numerical stability in the dice loss and L1 loss terms

Returns:
-------
A loss tensor
"""
if gamma < 0:
raise ValueError("Value of gamma should be greater than or equal to zero.")

Check warning on line 242 in doctr/models/detection/differentiable_binarization/pytorch.py

View check run for this annotation

Codecov / codecov/patch

doctr/models/detection/differentiable_binarization/pytorch.py#L242

Added line #L242 was not covered by tests

prob_map = torch.sigmoid(out_map)
thresh_map = torch.sigmoid(thresh_map)

targets = self.build_target(target, prob_map.shape, False) # type: ignore[arg-type]
targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]

seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)

# Compute balanced BCE loss for proba_map
bce_scale = 5.0
balanced_bce_loss = torch.zeros(1, device=out_map.device)
dice_loss = torch.zeros(1, device=out_map.device)
l1_loss = torch.zeros(1, device=out_map.device)
if torch.any(seg_mask):
bce_loss = F.binary_cross_entropy_with_logits(
out_map,
seg_target,
reduction="none",
)[seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = seg_target[seg_mask].sum()
negative_count = torch.minimum(neg_target.sum(), 3.0 * positive_count)
negative_loss = bce_loss * neg_target
negative_loss = negative_loss.sort().values[-int(negative_count.item()) :]
sum_losses = torch.sum(bce_loss * seg_target[seg_mask]) + torch.sum(negative_loss)
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)

# Compute dice loss for approxbin_map
bin_map = 1 / (1 + torch.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))

bce_min = bce_loss.min()
weights = (bce_loss - bce_min) / (bce_loss.max() - bce_min) + 1.0
inter = torch.sum(bin_map * seg_target[seg_mask] * weights)
union = torch.sum(bin_map) + torch.sum(seg_target[seg_mask]) + 1e-8 # type: ignore[call-overload]
dice_loss = 1 - 2.0 * inter / union
# Focal loss
focal_scale = 10.0
bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")

p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
# Unreduced version
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
# Class reduced
focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum()

# Compute dice loss for approx binary_map
binary_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
inter = (seg_mask * binary_map * seg_target).sum() # type: ignore[attr-defined]
cardinality = (seg_mask * (binary_map + seg_target)).sum() # type: ignore[attr-defined]
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
if torch.any(thresh_mask):
l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps)

return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss # type: ignore[return-value]
return l1_loss + focal_scale * focal_loss + dice_loss


def _dbnet(
Expand Down
61 changes: 33 additions & 28 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@
out_map: tf.Tensor,
thresh_map: tf.Tensor,
target: List[Dict[str, np.ndarray]],
gamma: float = 2.0,
alpha: float = 0.5,
eps: float = 1e-8,
) -> tf.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
and a list of masks for each image. From there it computes the loss with the model output
Expand All @@ -175,53 +178,55 @@
out_map: output feature map of the model of shape (N, H, W, C)
thresh_map: threshold map of shape (N, H, W, C)
target: list of dictionaries, where each dict has a `boxes` and a `flags` entry
gamma: modulating factor in the focal loss formula
alpha: balancing factor in the focal loss formula
eps: small constant added for numerical stability in the dice loss and L1 loss terms

Returns:
-------
A loss tensor
"""
if gamma < 0:
raise ValueError("Value of gamma should be greater than or equal to zero.")

Check warning on line 190 in doctr/models/detection/differentiable_binarization/tensorflow.py

View check run for this annotation

Codecov / codecov/patch

doctr/models/detection/differentiable_binarization/tensorflow.py#L190

Added line #L190 was not covered by tests

prob_map = tf.math.sigmoid(out_map)
thresh_map = tf.math.sigmoid(thresh_map)

seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape, True)
seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
seg_mask = tf.cast(seg_mask, tf.float32)
thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)

# Compute balanced BCE loss for proba_map
bce_scale = 5.0
bce_loss = tf.keras.losses.binary_crossentropy(
seg_target[..., None],
out_map[..., None],
from_logits=True,
)[seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = tf.math.reduce_sum(seg_target[seg_mask])
negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
negative_loss = bce_loss * neg_target
negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)

# Compute dice loss for approxbin_map
bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))

bce_min = tf.math.reduce_min(bce_loss)
weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
dice_loss = 1 - 2.0 * inter / union
# Focal loss
focal_scale = 10.0
bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)

# Convert logits to prob, compute gamma factor
p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
# Unreduced loss
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
# Class reduced
focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))

# Compute dice loss for approx binary_map
binary_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
inter = tf.reduce_sum(seg_mask * binary_map * seg_target, (0, 1, 2, 3))
cardinality = tf.reduce_sum((binary_map + seg_target), (0, 1, 2, 3))
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
if tf.reduce_any(thresh_mask):
l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
thresh_mask = tf.cast(thresh_mask, tf.float32)
l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / (
tf.reduce_sum(thresh_mask) + eps
)
else:
l1_loss = tf.constant(0.0)

return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
return l1_loss + focal_scale * focal_loss + dice_loss

def call(
self,
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
assume_straight_pages: bool = True,
) -> None:
super().__init__(box_thresh, bin_thresh, assume_straight_pages)
self.unclip_ratio = 1.2
self.unclip_ratio = 1.5

def polygon_to_box(
self,
Expand Down
Loading
Loading