Skip to content

Commit ff9982b

Browse files
[FIX] db loss TF and PT also for training with rotated samples (#1396)
1 parent e5b3f46 commit ff9982b

File tree

5 files changed: +128 additions, -123 deletions

doctr/models/detection/differentiable_binarization/base.py

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ def __init__(
3838
assume_straight_pages: bool = True,
3939
) -> None:
4040
super().__init__(box_thresh, bin_thresh, assume_straight_pages)
41-
self.unclip_ratio = 1.5 if assume_straight_pages else 2.2
41+
self.unclip_ratio = 1.5
4242

4343
def polygon_to_box(
4444
self,
@@ -93,7 +93,7 @@ def bitmap_to_boxes(
9393
pred: np.ndarray,
9494
bitmap: np.ndarray,
9595
) -> np.ndarray:
96-
"""Compute boxes from a bitmap/pred_map
96+
"""Compute boxes from a bitmap/pred_map: find connected components then filter boxes
9797
9898
Args:
9999
----
@@ -108,7 +108,7 @@ def bitmap_to_boxes(
108108
containing x, y, w, h, score for the box
109109
"""
110110
height, width = bitmap.shape[:2]
111-
min_size_box = 1 + int(height / 512)
111+
min_size_box = 2
112112
boxes: List[Union[np.ndarray, List[float]]] = []
113113
# get contours from connected components on the bitmap
114114
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -180,7 +180,7 @@ def compute_distance(
180180
ys: np.ndarray,
181181
a: np.ndarray,
182182
b: np.ndarray,
183-
eps: float = 1e-7,
183+
eps: float = 1e-6,
184184
) -> float:
185185
"""Compute the distance for each point of the map (xs, ys) to the (a, b) segment
186186
@@ -201,9 +201,10 @@ def compute_distance(
201201
square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1])
202202
square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1])
203203
cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps)
204+
cosin = np.clip(cosin, -1.0, 1.0)
204205
square_sin = 1 - np.square(cosin)
205206
square_sin = np.nan_to_num(square_sin)
206-
result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist)
207+
result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist + eps)
207208
result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0]
208209
return result
209210

@@ -265,7 +266,10 @@ def draw_thresh_map(
265266

266267
# Fill the canvas with the distances computed inside the valid padded polygon
267268
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
268-
1 - distance_map[ymin_valid - ymin : ymax_valid - ymin + 1, xmin_valid - xmin : xmax_valid - xmin + 1],
269+
1
270+
- distance_map[
271+
ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width
272+
],
269273
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
270274
)
271275

@@ -274,7 +278,7 @@ def draw_thresh_map(
274278
def build_target(
275279
self,
276280
target: List[Dict[str, np.ndarray]],
277-
output_shape: Tuple[int, int, int, int],
281+
output_shape: Tuple[int, int, int],
278282
channels_last: bool = True,
279283
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
280284
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
@@ -284,23 +288,24 @@ def build_target(
284288

285289
input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32
286290

291+
h: int
292+
w: int
287293
if channels_last:
288-
h, w = output_shape[1:-1]
289-
target_shape = (output_shape[0], output_shape[-1], h, w) # (Batch_size, num_classes, h, w)
294+
h, w, num_classes = output_shape
290295
else:
291-
h, w = output_shape[-2:]
292-
target_shape = output_shape # (Batch_size, num_classes, h, w)
296+
num_classes, h, w = output_shape
297+
target_shape = (len(target), num_classes, h, w)
298+
293299
seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
294300
seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
295301
thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32)
296-
thresh_mask: np.ndarray = np.ones(target_shape, dtype=np.uint8)
302+
thresh_mask: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
297303

298304
for idx, tgt in enumerate(target):
299305
for class_idx, _tgt in enumerate(tgt.values()):
300306
# Draw each polygon on gt
301307
if _tgt.shape[0] == 0:
302308
# Empty image, full masked
303-
# seg_mask[idx, :, :, class_idx] = False
304309
seg_mask[idx, class_idx] = False
305310

306311
# Absolute bounding boxes
@@ -326,10 +331,9 @@ def build_target(
326331
)
327332
boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
328333

329-
for box, box_size, poly in zip(abs_boxes, boxes_size, polys):
334+
for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
330335
# Mask boxes that are too small
331336
if box_size < self.min_size_box:
332-
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
333337
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
334338
continue
335339

@@ -339,19 +343,17 @@ def build_target(
339343
subject = [tuple(coor) for coor in poly]
340344
padding = pyclipper.PyclipperOffset()
341345
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
342-
shrinked = padding.Execute(-distance)
346+
shrunken = padding.Execute(-distance)
343347

344348
# Draw polygon on gt if it is valid
345-
if len(shrinked) == 0:
346-
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
349+
if len(shrunken) == 0:
347350
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
348351
continue
349-
shrinked = np.array(shrinked[0]).reshape(-1, 2)
350-
if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
351-
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
352+
shrunken = np.array(shrunken[0]).reshape(-1, 2)
353+
if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
352354
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
353355
continue
354-
cv2.fillPoly(seg_target[idx, class_idx], [shrinked.astype(np.int32)], 1.0) # type: ignore[call-overload]
356+
cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
355357

356358
# Draw on both thresh map and thresh mask
357359
poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(

doctr/models/detection/differentiable_binarization/pytorch.py

Lines changed: 34 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,15 @@ def forward(
213213

214214
return out
215215

216-
def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
216+
def compute_loss(
217+
self,
218+
out_map: torch.Tensor,
219+
thresh_map: torch.Tensor,
220+
target: List[np.ndarray],
221+
gamma: float = 2.0,
222+
alpha: float = 0.5,
223+
eps: float = 1e-8,
224+
) -> torch.Tensor:
217225
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
218226
and a list of masks for each image. From there it computes the loss with the model output
219227
@@ -222,56 +230,50 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target:
222230
out_map: output feature map of the model of shape (N, C, H, W)
223231
thresh_map: threshold map of shape (N, C, H, W)
224232
target: list of dictionary where each dict has a `boxes` and a `flags` entry
233+
gamma: modulating factor in the focal loss formula
234+
alpha: balancing factor in the focal loss formula
235+
eps: epsilon factor in dice loss
225236
226237
Returns:
227238
-------
228239
A loss tensor
229240
"""
241+
if gamma < 0:
242+
raise ValueError("Value of gamma should be greater than or equal to zero.")
243+
230244
prob_map = torch.sigmoid(out_map)
231245
thresh_map = torch.sigmoid(thresh_map)
232246

233-
targets = self.build_target(target, prob_map.shape, False) # type: ignore[arg-type]
247+
targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
234248

235249
seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
236250
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
237251
thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
238252
thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)
239253

240-
# Compute balanced BCE loss for proba_map
241-
bce_scale = 5.0
242-
balanced_bce_loss = torch.zeros(1, device=out_map.device)
243-
dice_loss = torch.zeros(1, device=out_map.device)
244-
l1_loss = torch.zeros(1, device=out_map.device)
245254
if torch.any(seg_mask):
246-
bce_loss = F.binary_cross_entropy_with_logits(
247-
out_map,
248-
seg_target,
249-
reduction="none",
250-
)[seg_mask]
251-
252-
neg_target = 1 - seg_target[seg_mask]
253-
positive_count = seg_target[seg_mask].sum()
254-
negative_count = torch.minimum(neg_target.sum(), 3.0 * positive_count)
255-
negative_loss = bce_loss * neg_target
256-
negative_loss = negative_loss.sort().values[-int(negative_count.item()) :]
257-
sum_losses = torch.sum(bce_loss * seg_target[seg_mask]) + torch.sum(negative_loss)
258-
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)
259-
260-
# Compute dice loss for approxbin_map
261-
bin_map = 1 / (1 + torch.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))
262-
263-
bce_min = bce_loss.min()
264-
weights = (bce_loss - bce_min) / (bce_loss.max() - bce_min) + 1.0
265-
inter = torch.sum(bin_map * seg_target[seg_mask] * weights)
266-
union = torch.sum(bin_map) + torch.sum(seg_target[seg_mask]) + 1e-8 # type: ignore[call-overload]
267-
dice_loss = 1 - 2.0 * inter / union
255+
# Focal loss
256+
focal_scale = 10.0
257+
bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
258+
259+
p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
260+
alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
261+
# Unreduced version
262+
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
263+
# Class reduced
264+
focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum()
265+
266+
# Compute dice loss for approx binary_map
267+
binary_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
268+
inter = (seg_mask * binary_map * seg_target).sum() # type: ignore[attr-defined]
269+
cardinality = (seg_mask * (binary_map + seg_target)).sum() # type: ignore[attr-defined]
270+
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
268271

269272
# Compute l1 loss for thresh_map
270-
l1_scale = 10.0
271273
if torch.any(thresh_mask):
272-
l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
274+
l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps)
273275

274-
return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss # type: ignore[return-value]
276+
return l1_loss + focal_scale * focal_loss + dice_loss
275277

276278

277279
def _dbnet(

doctr/models/detection/differentiable_binarization/tensorflow.py

Lines changed: 33 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,9 @@ def compute_loss(
166166
out_map: tf.Tensor,
167167
thresh_map: tf.Tensor,
168168
target: List[Dict[str, np.ndarray]],
169+
gamma: float = 2.0,
170+
alpha: float = 0.5,
171+
eps: float = 1e-8,
169172
) -> tf.Tensor:
170173
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
171174
and a list of masks for each image. From there it computes the loss with the model output
@@ -175,53 +178,55 @@ def compute_loss(
175178
out_map: output feature map of the model of shape (N, H, W, C)
176179
thresh_map: threshold map of shape (N, H, W, C)
177180
target: list of dictionary where each dict has a `boxes` and a `flags` entry
181+
gamma: modulating factor in the focal loss formula
182+
alpha: balancing factor in the focal loss formula
183+
eps: epsilon factor in dice loss
178184
179185
Returns:
180186
-------
181187
A loss tensor
182188
"""
189+
if gamma < 0:
190+
raise ValueError("Value of gamma should be greater than or equal to zero.")
191+
183192
prob_map = tf.math.sigmoid(out_map)
184193
thresh_map = tf.math.sigmoid(thresh_map)
185194

186-
seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape, True)
195+
seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
187196
seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
188197
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
198+
seg_mask = tf.cast(seg_mask, tf.float32)
189199
thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
190200
thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)
191201

192-
# Compute balanced BCE loss for proba_map
193-
bce_scale = 5.0
194-
bce_loss = tf.keras.losses.binary_crossentropy(
195-
seg_target[..., None],
196-
out_map[..., None],
197-
from_logits=True,
198-
)[seg_mask]
199-
200-
neg_target = 1 - seg_target[seg_mask]
201-
positive_count = tf.math.reduce_sum(seg_target[seg_mask])
202-
negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
203-
negative_loss = bce_loss * neg_target
204-
negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
205-
sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
206-
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)
207-
208-
# Compute dice loss for approxbin_map
209-
bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))
210-
211-
bce_min = tf.math.reduce_min(bce_loss)
212-
weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
213-
inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
214-
union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
215-
dice_loss = 1 - 2.0 * inter / union
202+
# Focal loss
203+
focal_scale = 10.0
204+
bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
205+
206+
# Convert logits to prob, compute gamma factor
207+
p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
208+
alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
209+
# Unreduced loss
210+
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
211+
# Class reduced
212+
focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))
213+
214+
# Compute dice loss for approx binary_map
215+
binary_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
216+
inter = tf.reduce_sum(seg_mask * binary_map * seg_target, (0, 1, 2, 3))
217+
cardinality = tf.reduce_sum((binary_map + seg_target), (0, 1, 2, 3))
218+
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
216219

217220
# Compute l1 loss for thresh_map
218-
l1_scale = 10.0
219221
if tf.reduce_any(thresh_mask):
220-
l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
222+
thresh_mask = tf.cast(thresh_mask, tf.float32)
223+
l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / (
224+
tf.reduce_sum(thresh_mask) + eps
225+
)
221226
else:
222227
l1_loss = tf.constant(0.0)
223228

224-
return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
229+
return l1_loss + focal_scale * focal_loss + dice_loss
225230

226231
def call(
227232
self,

doctr/models/detection/linknet/base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def __init__(
3636
assume_straight_pages: bool = True,
3737
) -> None:
3838
super().__init__(box_thresh, bin_thresh, assume_straight_pages)
39-
self.unclip_ratio = 1.2
39+
self.unclip_ratio = 1.5
4040

4141
def polygon_to_box(
4242
self,

0 commit comments

Comments
 (0)