Skip to content

Commit

Permalink
default to fast base (#1588)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored May 15, 2024
1 parent 0588750 commit 3f116ad
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 21 deletions.
6 changes: 3 additions & 3 deletions demo/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from doctr.models.predictor import OCRPredictor

DET_ARCHS = [
"fast_base",
"fast_small",
"fast_tiny",
"db_resnet50",
"db_resnet34",
"db_mobilenet_v3_large",
"linknet_resnet18",
"linknet_resnet34",
"linknet_resnet50",
"fast_tiny",
"fast_small",
"fast_base",
]
RECO_ARCHS = [
"crnn_vgg16_bn",
Expand Down
6 changes: 3 additions & 3 deletions demo/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from doctr.models.predictor import OCRPredictor

DET_ARCHS = [
"fast_base",
"fast_small",
"fast_tiny",
"db_resnet50",
"db_mobilenet_v3_large",
"linknet_resnet18",
"linknet_resnet34",
"linknet_resnet50",
"fast_tiny",
"fast_small",
"fast_base",
]
RECO_ARCHS = [
"crnn_vgg16_bn",
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,


def detection_predictor(
arch: Any = "db_resnet50",
arch: Any = "fast_base",
pretrained: bool = False,
assume_straight_pages: bool = True,
**kwargs: Any,
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _predictor(


def ocr_predictor(
det_arch: Any = "db_resnet50",
det_arch: Any = "fast_base",
reco_arch: Any = "crnn_vgg16_bn",
pretrained: bool = False,
pretrained_backbone: bool = True,
Expand Down Expand Up @@ -175,7 +175,7 @@ def _kie_predictor(


def kie_predictor(
det_arch: Any = "db_resnet50",
det_arch: Any = "fast_base",
reco_arch: Any = "crnn_vgg16_bn",
pretrained: bool = False,
pretrained_backbone: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion scripts/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def parse_args():
)

parser.add_argument("path", type=str, help="Path to the input document (PDF or image)")
parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis")
parser.add_argument("--detection", type=str, default="fast_base", help="Text detection model to use for analysis")
parser.add_argument(
"--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis"
)
Expand Down
2 changes: 1 addition & 1 deletion scripts/detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def parse_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("path", type=str, help="Path to process: PDF, image, directory")
parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis")
parser.add_argument("--detection", type=str, default="fast_base", help="Text detection model to use for analysis")
parser.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.")
parser.add_argument("--box-thresh", type=float, default=0.1, help="Threshold for the detection boxes.")
parser.add_argument(
Expand Down
12 changes: 6 additions & 6 deletions tests/pytorch/test_models_zoo_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_trained_ocr_predictor(mock_payslip):
doc = DocumentFile.from_images(mock_payslip)

det_predictor = detection_predictor(
"db_resnet50",
"fast_base",
pretrained=True,
batch_size=2,
assume_straight_pages=True,
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_trained_ocr_predictor(mock_payslip):
assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised, rtol=0.05)

det_predictor = detection_predictor(
"db_resnet50",
"fast_base",
pretrained=True,
batch_size=2,
assume_straight_pages=True,
Expand Down Expand Up @@ -196,7 +196,7 @@ def test_trained_kie_predictor(mock_payslip):
doc = DocumentFile.from_images(mock_payslip)

det_predictor = detection_predictor(
"db_resnet50",
"fast_base",
pretrained=True,
batch_size=2,
assume_straight_pages=True,
Expand All @@ -222,12 +222,12 @@ def test_trained_kie_predictor(mock_payslip):
geometry_mr = np.array([[0.1083984375, 0.0634765625], [0.1494140625, 0.0859375]])
assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][0].geometry), geometry_mr, rtol=0.05)

assert out.pages[0].predictions[CLASS_NAME][4].value == "revised"
assert out.pages[0].predictions[CLASS_NAME][3].value == "revised"
geometry_revised = np.array([[0.7548828125, 0.126953125], [0.8388671875, 0.1484375]])
assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][4].geometry), geometry_revised, rtol=0.05)
assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][3].geometry), geometry_revised, rtol=0.05)

det_predictor = detection_predictor(
"db_resnet50",
"fast_base",
pretrained=True,
batch_size=2,
assume_straight_pages=True,
Expand Down
8 changes: 4 additions & 4 deletions tests/tensorflow/test_models_zoo_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_trained_ocr_predictor(mock_payslip):
doc = DocumentFile.from_images(mock_payslip)

det_predictor = detection_predictor(
"db_resnet50",
"fast_base",
pretrained=True,
batch_size=2,
assume_straight_pages=True,
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_trained_ocr_predictor(mock_payslip):
assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised, rtol=0.05)

det_predictor = detection_predictor(
"db_resnet50",
"fast_base",
pretrained=True,
batch_size=2,
assume_straight_pages=True,
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_trained_kie_predictor(mock_payslip):
doc = DocumentFile.from_images(mock_payslip)

det_predictor = detection_predictor(
"db_resnet50",
"fast_base",
pretrained=True,
batch_size=2,
assume_straight_pages=True,
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_trained_kie_predictor(mock_payslip):
assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][3].geometry), geometry_revised, rtol=0.05)

det_predictor = detection_predictor(
"db_resnet50",
"fast_base",
pretrained=True,
batch_size=2,
assume_straight_pages=True,
Expand Down

0 comments on commit 3f116ad

Please sign in to comment.