[TF/PT] Add FAST detection model (#1443)
felixdittrich92 authored Feb 28, 2024
1 parent 347e8a5 commit 15d0dbf
Showing 18 changed files with 1,140 additions and 19 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -197,6 +197,7 @@ Credits where it's due: this repository is implementing, among others, architect

- DBNet: [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/pdf/1911.08947.pdf).
- LinkNet: [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/pdf/1707.03718.pdf)
- FAST: [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://arxiv.org/pdf/2111.02394.pdf)

### Text Recognition

3 changes: 3 additions & 0 deletions demo/backend/pytorch.py
@@ -16,6 +16,9 @@
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
    "fast_tiny",
    "fast_small",
    "fast_base",
]
RECO_ARCHS = [
    "crnn_vgg16_bn",
3 changes: 3 additions & 0 deletions demo/backend/tensorflow.py
@@ -15,6 +15,9 @@
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
    "fast_tiny",
    "fast_small",
    "fast_base",
]
RECO_ARCHS = [
    "crnn_vgg16_bn",
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -42,6 +42,7 @@ Text detection models
"""""""""""""""""""""
* DBNet from `"Real-time Scene Text Detection with Differentiable Binarization" <https://arxiv.org/pdf/1911.08947.pdf>`_
* LinkNet from `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" <https://arxiv.org/pdf/1707.03718.pdf>`_
* FAST from `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" <https://arxiv.org/pdf/2111.02394.pdf>`_

Text recognition models
"""""""""""""""""""""""
6 changes: 6 additions & 0 deletions docs/source/modules/models.rst
@@ -55,6 +55,12 @@ doctr.models.detection

.. autofunction:: doctr.models.detection.db_mobilenet_v3_large

.. autofunction:: doctr.models.detection.fast_tiny

.. autofunction:: doctr.models.detection.fast_small

.. autofunction:: doctr.models.detection.fast_base

.. autofunction:: doctr.models.detection.detection_predictor


15 changes: 15 additions & 0 deletions docs/source/using_doctr/using_models.rst
@@ -27,6 +27,9 @@ The following architectures are currently supported:
* :py:meth:`linknet_resnet50 <doctr.models.detection.linknet_resnet50>`
* :py:meth:`db_resnet50 <doctr.models.detection.db_resnet50>`
* :py:meth:`db_mobilenet_v3_large <doctr.models.detection.db_mobilenet_v3_large>`
* :py:meth:`fast_tiny <doctr.models.detection.fast_tiny>`
* :py:meth:`fast_small <doctr.models.detection.fast_small>`
* :py:meth:`fast_base <doctr.models.detection.fast_base>`

For a comprehensive comparison, we have compiled a detailed benchmark on publicly available datasets:

@@ -46,6 +49,12 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | linknet_resnet50 | (1024, 1024, 3) | 28.8 M | 80.70 | 83.51 | 86.46 | 84.94 | 1.1 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | fast_tiny | (1024, 1024, 3) | 13.5 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | fast_small | (1024, 1024, 3) | 14.7 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | fast_base | (1024, 1024, 3) | 16.3 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | db_resnet34 | (1024, 1024, 3) | 22.4 M | 82.76 | 76.75 | 89.20 | 71.74 | 0.8 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | db_resnet50 | (1024, 1024, 3) | 25.4 M | 83.56 | 86.68 | 92.61 | 86.39 | 1.1 |
@@ -58,6 +67,12 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | linknet_resnet50 | (1024, 1024, 3) | 28.8 M | 81.78 | 82.47 | 87.29 | 85.54 | 1.0 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_tiny | (1024, 1024, 3) | 13.5 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_small | (1024, 1024, 3) | 14.7 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_base | (1024, 1024, 3) | 16.3 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+


All text detection models above have been evaluated using both the training and evaluation sets of FUNSD and CORD (cf. :ref:`datasets`).
1 change: 1 addition & 0 deletions doctr/models/detection/__init__.py
@@ -1,3 +1,4 @@
from .differentiable_binarization import *
from .linknet import *
from .fast import *
from .zoo import *
6 changes: 6 additions & 0 deletions doctr/models/detection/fast/__init__.py
@@ -0,0 +1,6 @@
from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
    from .tensorflow import *
elif is_torch_available():
    from .pytorch import *  # type: ignore[assignment]
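
The `if`/`elif` in this `__init__.py` means the TensorFlow check is evaluated first, and the PyTorch implementation is only imported when that check fails. The sketch below illustrates the dispatch order only. `pick_backend` and `_installed` are hypothetical helpers, and approximating "available" by "importable" via `importlib.util.find_spec` is an assumption; the real `is_tf_available` / `is_torch_available` in `doctr.file_utils` may apply further checks.

```python
from importlib.util import find_spec
from typing import Callable, Optional


def pick_backend(
    tf_available: Callable[[], bool],
    torch_available: Callable[[], bool],
) -> Optional[str]:
    """Mirror the if/elif order of the package __init__: the TensorFlow check comes first."""
    if tf_available():
        return "tensorflow"
    elif torch_available():
        return "pytorch"
    return None


def _installed(module_name: str) -> bool:
    # Assumption: "available" is approximated here by "importable"
    return find_spec(module_name) is not None


backend = pick_backend(lambda: _installed("tensorflow"), lambda: _installed("torch"))
print(f"selected backend: {backend}")
```

If neither check passes, nothing is imported and the `fast_*` builders are simply not exported from the package.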
256 changes: 256 additions & 0 deletions doctr/models/detection/fast/base.py
@@ -0,0 +1,256 @@
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization

from typing import Dict, List, Tuple, Union

import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon

from doctr.models.core import BaseModel

from ..core import DetectionPostProcessor

__all__ = ["_FAST", "FASTPostProcessor"]


class FASTPostProcessor(DetectionPostProcessor):
    """Implements a post processor for FAST model.

    Args:
    ----
        bin_thresh: threshold used to binarize p_map at inference time
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: whether the inputs were expected to have horizontal text elements
    """

    def __init__(
        self,
        bin_thresh: float = 0.3,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
    ) -> None:
        super().__init__(box_thresh, bin_thresh, assume_straight_pages)
        self.unclip_ratio = 1.0

    def polygon_to_box(
        self,
        points: np.ndarray,
    ) -> np.ndarray:
        """Expand a polygon (points) by a factor unclip_ratio, and return the box enclosing the result

        Args:
        ----
            points: vertices of the polygon to expand, as an (N, 2) array

        Returns:
        -------
            a box in absolute coordinates (x, y, w, h) if assume_straight_pages, else a (4, 2) array (quadrangle)
        """
        if not self.assume_straight_pages:
            # Compute the rectangle polygon enclosing the raw polygon
            rect = cv2.minAreaRect(points)
            points = cv2.boxPoints(rect)
            # Add 1 pixel to correct cv2 approx
            area = (rect[1][0] + 1) * (1 + rect[1][1])
            length = 2 * (rect[1][0] + rect[1][1]) + 2
        else:
            poly = Polygon(points)
            area = poly.area
            length = poly.length
        distance = area * self.unclip_ratio / length  # compute distance to expand polygon
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        _points = offset.Execute(distance)
        # Take biggest stack of points
        idx = 0
        if len(_points) > 1:
            max_size = 0
            for _idx, p in enumerate(_points):
                if len(p) > max_size:
                    idx = _idx
                    max_size = len(p)
            # We ensure that _points can be correctly casted to a ndarray
            _points = [_points[idx]]
        expanded_points: np.ndarray = np.asarray(_points)  # expand polygon
        if len(expanded_points) < 1:
            return None  # type: ignore[return-value]
        return (
            cv2.boundingRect(expanded_points)  # type: ignore[return-value]
            if self.assume_straight_pages
            else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
        )

    def bitmap_to_boxes(
        self,
        pred: np.ndarray,
        bitmap: np.ndarray,
    ) -> np.ndarray:
        """Compute boxes from a bitmap/pred_map: find connected components then filter boxes

        Args:
        ----
            pred: Pred map from the FAST model output
            bitmap: Bitmap map computed from pred (binarized)

        Returns:
        -------
            np tensor of boxes for the bitmap, in relative coordinates: an (N, 5) array of
            (xmin, ymin, xmax, ymax, score) if assume_straight_pages, else an (N, 4, 2) array of quadrangles
        """
        height, width = bitmap.shape[:2]
        boxes: List[Union[np.ndarray, List[float]]] = []
        # get contours from connected components on the bitmap
        contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            # Check whether smallest enclosing bounding box is not too small
            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
                continue
            # Compute objectness
            if self.assume_straight_pages:
                x, y, w, h = cv2.boundingRect(contour)
                points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
                score = self.box_score(pred, points, assume_straight_pages=True)
            else:
                score = self.box_score(pred, contour, assume_straight_pages=False)

            if score < self.box_thresh:  # remove polygons with a weak objectness
                continue

            if self.assume_straight_pages:
                _box = self.polygon_to_box(points)
            else:
                _box = self.polygon_to_box(np.squeeze(contour))

            if self.assume_straight_pages:
                # compute relative polygon to get rid of img shape
                x, y, w, h = _box
                xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
                boxes.append([xmin, ymin, xmax, ymax, score])
            else:
                # compute relative box to get rid of img shape
                _box[:, 0] /= width
                _box[:, 1] /= height
                boxes.append(_box)

        if not self.assume_straight_pages:
            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
        else:
            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
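
The arithmetic behind the post-processor can be checked by hand. The sketch below recomputes the expansion distance `area * unclip_ratio / length` used in `polygon_to_box` and the relative-coordinate conversion used in `bitmap_to_boxes` for the straight-page case. It is a plain-Python illustration, not the library code: `polygon_area_and_perimeter`, `unclip_distance` and `to_relative` are hypothetical helpers, and the sample box values are made up.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def polygon_area_and_perimeter(points: List[Point]) -> Tuple[float, float]:
    """Shoelace area and perimeter of a simple closed polygon."""
    twice_area = 0.0
    perimeter = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:] + points[:1]):
        twice_area += x0 * y1 - x1 * y0
        perimeter += ((x1 - x0) ** 2 + (y1 - y0) ** 2) ** 0.5
    return abs(twice_area) / 2.0, perimeter


def unclip_distance(points: List[Point], unclip_ratio: float = 1.0) -> float:
    """Offset distance handed to the polygon offsetter: area * ratio / perimeter."""
    area, length = polygon_area_and_perimeter(points)
    return area * unclip_ratio / length


def to_relative(x: float, y: float, w: float, h: float, width: int, height: int) -> Tuple[float, ...]:
    """Convert an absolute (x, y, w, h) box to (xmin, ymin, xmax, ymax) clipped to [0, 1]."""
    rel = (x / width, y / height, (x + w) / width, (y + h) / height)
    return tuple(min(max(v, 0.0), 1.0) for v in rel)


# A 100 x 20 px box, vertices ordered (x, y), (x, y + h), (x + w, y + h), (x + w, y)
box = [(0.0, 0.0), (0.0, 20.0), (100.0, 20.0), (100.0, 0.0)]
distance = unclip_distance(box)  # area 2000, perimeter 240
print(f"expansion distance: {distance:.3f} px")

# Made-up expanded box that overshoots the top-left corner of a 1024 x 1024 page
print(to_relative(-8, -8, 116, 36, 1024, 1024))
```

With `unclip_ratio = 1.0`, the expansion distance is simply the area-to-perimeter ratio of the detected region, so thin text lines are expanded by roughly half their height.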


class _FAST(BaseModel):
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_.
    """

    min_size_box: int = 3
    assume_straight_pages: bool = True
    shrink_ratio = 0.1

    def build_target(
        self,
        target: List[Dict[str, np.ndarray]],
        output_shape: Tuple[int, int, int],
        channels_last: bool = True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Build the target and its mask, to be used for loss computation.

        Args:
        ----
            target: target coming from dataset
            output_shape: shape of the output of the model without batch_size
            channels_last: whether channels are last or not

        Returns:
        -------
            the new formatted target, mask and shrunken text kernel
        """
        if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
            raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
        if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
            raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")

        h: int
        w: int
        if channels_last:
            h, w, num_classes = output_shape
        else:
            num_classes, h, w = output_shape
        target_shape = (len(target), num_classes, h, w)

        seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
        seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
        shrunken_kernel: np.ndarray = np.zeros(target_shape, dtype=np.uint8)

        for idx, tgt in enumerate(target):
            for class_idx, _tgt in enumerate(tgt.values()):
                # Draw each polygon on gt
                if _tgt.shape[0] == 0:
                    # Empty image, full masked
                    seg_mask[idx, class_idx] = False

                # Absolute bounding boxes
                abs_boxes = _tgt.copy()

                if abs_boxes.ndim == 3:
                    abs_boxes[:, :, 0] *= w
                    abs_boxes[:, :, 1] *= h
                    polys = abs_boxes
                    boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
                    abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
                else:
                    abs_boxes[:, [0, 2]] *= w
                    abs_boxes[:, [1, 3]] *= h
                    abs_boxes = abs_boxes.round().astype(np.int32)
                    polys = np.stack(
                        [
                            abs_boxes[:, [0, 1]],
                            abs_boxes[:, [0, 3]],
                            abs_boxes[:, [2, 3]],
                            abs_boxes[:, [2, 1]],
                        ],
                        axis=1,
                    )
                    boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])

                for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
                    # Mask boxes that are too small
                    if box_size < self.min_size_box:
                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                        continue

                    # Negative shrink for gt, as described in paper
                    polygon = Polygon(poly)
                    distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
                    subject = [tuple(coor) for coor in poly]
                    padding = pyclipper.PyclipperOffset()
                    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
                    shrunken = padding.Execute(-distance)

                    # Draw polygon on gt if it is valid
                    if len(shrunken) == 0:
                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                        continue
                    shrunken = np.array(shrunken[0]).reshape(-1, 2)
                    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                        continue
                    cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
                    # draw the original polygon on the segmentation target
                    cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)  # type: ignore[call-overload]

        # Don't forget to switch back to channel last if Tensorflow is used
        if channels_last:
            seg_target = seg_target.transpose((0, 2, 3, 1))
            seg_mask = seg_mask.transpose((0, 2, 3, 1))
            shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))

        return seg_target, seg_mask, shrunken_kernel
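
To get a feel for how far `build_target` shrinks a ground-truth box into its text kernel, the distance `area * (1 - shrink_ratio ** 2) / length` can be evaluated for a simple case. The sketch below is an illustration only: `shrink_distance` and `shrunken_rect_size` are hypothetical helpers, and the kernel size assumes that an inward offset of an axis-aligned rectangle moves each side inward by exactly `distance` (the actual code delegates the offset to pyclipper, which works on integer coordinates).

```python
from typing import Tuple


def shrink_distance(area: float, length: float, shrink_ratio: float = 0.1) -> float:
    """Inward offset applied to a ground-truth polygon to obtain its kernel."""
    return area * (1 - shrink_ratio**2) / length


def shrunken_rect_size(w: float, h: float, shrink_ratio: float = 0.1) -> Tuple[float, float]:
    """Kernel size for a w x h axis-aligned rectangle (assumes each side moves inward by the distance)."""
    d = shrink_distance(w * h, 2 * (w + h), shrink_ratio)
    return w - 2 * d, h - 2 * d


w, h = 100.0, 20.0
d = shrink_distance(w * h, 2 * (w + h))  # 2000 * 0.99 / 240, about 8.25 px
kernel_w, kernel_h = shrunken_rect_size(w, h)
print(f"shrink distance: {d:.2f} px, kernel: {kernel_w:.1f} x {kernel_h:.1f} px")
```

Under these assumptions a 100 x 20 px box keeps a kernel of about 83.5 x 3.5 px, which shows how thin the kernel becomes for short text lines with `shrink_ratio = 0.1`.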