@@ -38,7 +38,7 @@ def __init__(
38
38
assume_straight_pages : bool = True ,
39
39
) -> None :
40
40
super ().__init__ (box_thresh , bin_thresh , assume_straight_pages )
41
- self .unclip_ratio = 1.5 if assume_straight_pages else 2.2
41
+ self .unclip_ratio = 1.5
42
42
43
43
def polygon_to_box (
44
44
self ,
@@ -93,7 +93,7 @@ def bitmap_to_boxes(
93
93
pred : np .ndarray ,
94
94
bitmap : np .ndarray ,
95
95
) -> np .ndarray :
96
- """Compute boxes from a bitmap/pred_map
96
+ """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
97
97
98
98
Args:
99
99
----
@@ -108,7 +108,7 @@ def bitmap_to_boxes(
108
108
containing x, y, w, h, score for the box
109
109
"""
110
110
height , width = bitmap .shape [:2 ]
111
- min_size_box = 1 + int ( height / 512 )
111
+ min_size_box = 2
112
112
boxes : List [Union [np .ndarray , List [float ]]] = []
113
113
# get contours from connected components on the bitmap
114
114
contours , _ = cv2 .findContours (bitmap .astype (np .uint8 ), cv2 .RETR_EXTERNAL , cv2 .CHAIN_APPROX_SIMPLE )
@@ -180,7 +180,7 @@ def compute_distance(
180
180
ys : np .ndarray ,
181
181
a : np .ndarray ,
182
182
b : np .ndarray ,
183
- eps : float = 1e-7 ,
183
+ eps : float = 1e-6 ,
184
184
) -> float :
185
185
"""Compute the distance for each point of the map (xs, ys) to the (a, b) segment
186
186
@@ -201,9 +201,10 @@ def compute_distance(
201
201
square_dist_2 = np .square (xs - b [0 ]) + np .square (ys - b [1 ])
202
202
square_dist = np .square (a [0 ] - b [0 ]) + np .square (a [1 ] - b [1 ])
203
203
cosin = (square_dist - square_dist_1 - square_dist_2 ) / (2 * np .sqrt (square_dist_1 * square_dist_2 ) + eps )
204
+ cosin = np .clip (cosin , - 1.0 , 1.0 )
204
205
square_sin = 1 - np .square (cosin )
205
206
square_sin = np .nan_to_num (square_sin )
206
- result = np .sqrt (square_dist_1 * square_dist_2 * square_sin / square_dist )
207
+ result = np .sqrt (square_dist_1 * square_dist_2 * square_sin / square_dist + eps )
207
208
result [cosin < 0 ] = np .sqrt (np .fmin (square_dist_1 , square_dist_2 ))[cosin < 0 ]
208
209
return result
209
210
@@ -265,7 +266,10 @@ def draw_thresh_map(
265
266
266
267
# Fill the canvas with the distances computed inside the valid padded polygon
267
268
canvas [ymin_valid : ymax_valid + 1 , xmin_valid : xmax_valid + 1 ] = np .fmax (
268
- 1 - distance_map [ymin_valid - ymin : ymax_valid - ymin + 1 , xmin_valid - xmin : xmax_valid - xmin + 1 ],
269
+ 1
270
+ - distance_map [
271
+ ymin_valid - ymin : ymax_valid - ymax + height , xmin_valid - xmin : xmax_valid - xmax + width
272
+ ],
269
273
canvas [ymin_valid : ymax_valid + 1 , xmin_valid : xmax_valid + 1 ],
270
274
)
271
275
@@ -274,7 +278,7 @@ def draw_thresh_map(
274
278
def build_target (
275
279
self ,
276
280
target : List [Dict [str , np .ndarray ]],
277
- output_shape : Tuple [int , int , int , int ],
281
+ output_shape : Tuple [int , int , int ],
278
282
channels_last : bool = True ,
279
283
) -> Tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray ]:
280
284
if any (t .dtype != np .float32 for tgt in target for t in tgt .values ()):
@@ -284,23 +288,24 @@ def build_target(
284
288
285
289
input_dtype = next (iter (target [0 ].values ())).dtype if len (target ) > 0 else np .float32
286
290
291
+ h : int
292
+ w : int
287
293
if channels_last :
288
- h , w = output_shape [1 :- 1 ]
289
- target_shape = (output_shape [0 ], output_shape [- 1 ], h , w ) # (Batch_size, num_classes, h, w)
294
+ h , w , num_classes = output_shape
290
295
else :
291
- h , w = output_shape [- 2 :]
292
- target_shape = output_shape # (Batch_size, num_classes, h, w)
296
+ num_classes , h , w = output_shape
297
+ target_shape = (len (target ), num_classes , h , w )
298
+
293
299
seg_target : np .ndarray = np .zeros (target_shape , dtype = np .uint8 )
294
300
seg_mask : np .ndarray = np .ones (target_shape , dtype = bool )
295
301
thresh_target : np .ndarray = np .zeros (target_shape , dtype = np .float32 )
296
- thresh_mask : np .ndarray = np .ones (target_shape , dtype = np .uint8 )
302
+ thresh_mask : np .ndarray = np .zeros (target_shape , dtype = np .uint8 )
297
303
298
304
for idx , tgt in enumerate (target ):
299
305
for class_idx , _tgt in enumerate (tgt .values ()):
300
306
# Draw each polygon on gt
301
307
if _tgt .shape [0 ] == 0 :
302
308
# Empty image, full masked
303
- # seg_mask[idx, :, :, class_idx] = False
304
309
seg_mask [idx , class_idx ] = False
305
310
306
311
# Absolute bounding boxes
@@ -326,10 +331,9 @@ def build_target(
326
331
)
327
332
boxes_size = np .minimum (abs_boxes [:, 2 ] - abs_boxes [:, 0 ], abs_boxes [:, 3 ] - abs_boxes [:, 1 ])
328
333
329
- for box , box_size , poly in zip (abs_boxes , boxes_size , polys ):
334
+ for poly , box , box_size in zip (polys , abs_boxes , boxes_size ):
330
335
# Mask boxes that are too small
331
336
if box_size < self .min_size_box :
332
- # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
333
337
seg_mask [idx , class_idx , box [1 ] : box [3 ] + 1 , box [0 ] : box [2 ] + 1 ] = False
334
338
continue
335
339
@@ -339,19 +343,17 @@ def build_target(
339
343
subject = [tuple (coor ) for coor in poly ]
340
344
padding = pyclipper .PyclipperOffset ()
341
345
padding .AddPath (subject , pyclipper .JT_ROUND , pyclipper .ET_CLOSEDPOLYGON )
342
- shrinked = padding .Execute (- distance )
346
+ shrunken = padding .Execute (- distance )
343
347
344
348
# Draw polygon on gt if it is valid
345
- if len (shrinked ) == 0 :
346
- # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
349
+ if len (shrunken ) == 0 :
347
350
seg_mask [idx , class_idx , box [1 ] : box [3 ] + 1 , box [0 ] : box [2 ] + 1 ] = False
348
351
continue
349
- shrinked = np .array (shrinked [0 ]).reshape (- 1 , 2 )
350
- if shrinked .shape [0 ] <= 2 or not Polygon (shrinked ).is_valid :
351
- # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
352
+ shrunken = np .array (shrunken [0 ]).reshape (- 1 , 2 )
353
+ if shrunken .shape [0 ] <= 2 or not Polygon (shrunken ).is_valid :
352
354
seg_mask [idx , class_idx , box [1 ] : box [3 ] + 1 , box [0 ] : box [2 ] + 1 ] = False
353
355
continue
354
- cv2 .fillPoly (seg_target [idx , class_idx ], [shrinked .astype (np .int32 )], 1.0 ) # type: ignore[call-overload]
356
+ cv2 .fillPoly (seg_target [idx , class_idx ], [shrunken .astype (np .int32 )], 1.0 ) # type: ignore[call-overload]
355
357
356
358
# Draw on both thresh map and thresh mask
357
359
poly , thresh_target [idx , class_idx ], thresh_mask [idx , class_idx ] = self .draw_thresh_map (
0 commit comments