Skip to content

Commit

Permalink
fix orient train and modify for page (#1559)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Apr 19, 2024
1 parent 99eb5a3 commit 0285711
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 69 deletions.
2 changes: 1 addition & 1 deletion docs/source/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ doctr.models.classification

.. autofunction:: doctr.models.classification.mobilenet_v3_large_r

.. autofunction:: doctr.models.classification.mobilenet_v3_small_orientation
.. autofunction:: doctr.models.classification.mobilenet_v3_small_crop_orientation

.. autofunction:: doctr.models.classification.magc_resnet31

Expand Down
21 changes: 14 additions & 7 deletions doctr/models/classification/mobilenet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"mobilenet_v3_small_r",
"mobilenet_v3_large",
"mobilenet_v3_large_r",
"mobilenet_v3_small_orientation",
"mobilenet_v3_small_crop_orientation",
]

default_cfgs: Dict[str, Dict[str, Any]] = {
Expand Down Expand Up @@ -51,13 +51,20 @@
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0",
},
"mobilenet_v3_small_orientation": {
"mobilenet_v3_small_crop_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 128, 128),
"classes": [0, 90, 180, 270],
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-24f8ff57.pt&src=0",
},
"mobilenet_v3_small_page_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 512, 512),
"classes": [0, -90, 180, 90],
"url": None,
},
}


Expand Down Expand Up @@ -212,14 +219,14 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
)


def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
"""MobileNetV3-Small architecture as described in
`"Searching for MobileNetV3",
<https://arxiv.org/pdf/1905.02244.pdf>`_.
>>> import torch
>>> from doctr.models import mobilenet_v3_small_orientation
>>> model = mobilenet_v3_small_orientation(pretrained=False)
>>> from doctr.models import mobilenet_v3_small_crop_orientation
>>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
>>> out = model(input_tensor)
Expand All @@ -233,7 +240,7 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> m
a torch.nn.Module
"""
return _mobilenet_v3(
"mobilenet_v3_small_orientation",
"mobilenet_v3_small_crop_orientation",
pretrained,
ignore_keys=["classifier.3.weight", "classifier.3.bias"],
**kwargs,
Expand Down
21 changes: 14 additions & 7 deletions doctr/models/classification/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"mobilenet_v3_small_r",
"mobilenet_v3_large",
"mobilenet_v3_large_r",
"mobilenet_v3_small_orientation",
"mobilenet_v3_small_crop_orientation",
]


Expand Down Expand Up @@ -54,13 +54,20 @@
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
},
"mobilenet_v3_small_orientation": {
"mobilenet_v3_small_crop_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (128, 128, 3),
"classes": [0, 90, 180, 270],
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
},
"mobilenet_v3_small_page_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (512, 512, 3),
"classes": [0, -90, 180, 90],
"url": None,
},
}


Expand Down Expand Up @@ -386,14 +393,14 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)


def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
`"Searching for MobileNetV3",
<https://arxiv.org/pdf/1905.02244.pdf>`_.
>>> import tensorflow as tf
>>> from doctr.models import mobilenet_v3_small_orientation
>>> model = mobilenet_v3_small_orientation(pretrained=False)
>>> from doctr.models import mobilenet_v3_small_crop_orientation
>>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)
Expand All @@ -406,4 +413,4 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> M
-------
a keras.Model
"""
return _mobilenet_v3("mobilenet_v3_small_orientation", pretrained, include_top=True, **kwargs)
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
10 changes: 3 additions & 7 deletions doctr/models/classification/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from doctr.models.preprocessor import PreProcessor
from doctr.models.utils import set_device_and_dtype

__all__ = ["CropOrientationPredictor"]
__all__ = ["OrientationPredictor"]


class CropOrientationPredictor(nn.Module):
class OrientationPredictor(nn.Module):
"""Implements an object able to detect the reading direction of a text box.
4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
Expand Down Expand Up @@ -57,11 +57,7 @@ def forward(
predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]

class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
# Keep unified with page orientation range (counter clock rotation => negative) so 270 -> -90
classes = [
int(self.model.cfg["classes"][idx]) if int(self.model.cfg["classes"][idx]) != 270 else -90
for idx in class_idxs
]
classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
confs = [round(float(p), 2) for prob in probs for p in prob]

return [class_idxs, classes, confs]
10 changes: 3 additions & 7 deletions doctr/models/classification/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from doctr.models.preprocessor import PreProcessor
from doctr.utils.repr import NestedObject

__all__ = ["CropOrientationPredictor"]
__all__ = ["OrientationPredictor"]


class CropOrientationPredictor(NestedObject):
class OrientationPredictor(NestedObject):
"""Implements an object able to detect the reading direction of a text box.
4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
Expand Down Expand Up @@ -52,11 +52,7 @@ def __call__(
predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches]

class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
# Keep unified with page orientation range (counter clock rotation => negative) so 270 -> -90
classes = [
int(self.model.cfg["classes"][idx]) if int(self.model.cfg["classes"][idx]) != 270 else -90
for idx in class_idxs
]
classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
confs = [round(float(p), 2) for prob in probs for p in prob]

return [class_idxs, classes, confs]
18 changes: 9 additions & 9 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .. import classification
from ..preprocessor import PreProcessor
from .predictor import CropOrientationPredictor
from .predictor import OrientationPredictor

__all__ = ["crop_orientation_predictor"]

Expand All @@ -31,10 +31,10 @@
"vit_s",
"vit_b",
]
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation"]


def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor:
def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

Expand All @@ -44,15 +44,15 @@ def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> C
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128)
input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
predictor = CropOrientationPredictor(
predictor = OrientationPredictor(
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
)
return predictor


def crop_orientation_predictor(
arch: str = "mobilenet_v3_small_orientation", pretrained: bool = False, **kwargs: Any
) -> CropOrientationPredictor:
arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
) -> OrientationPredictor:
"""Orientation classification architecture.
>>> import numpy as np
Expand All @@ -65,10 +65,10 @@ def crop_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small')
pretrained: If True, returns a model pre-trained on our recognition crops dataset
**kwargs: keyword arguments to be passed to the CropOrientationPredictor
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
CropOrientationPredictor
OrientationPredictor
"""
return _crop_orientation_predictor(arch, pretrained, **kwargs)
return _orientation_predictor(arch, pretrained, **kwargs)
4 changes: 2 additions & 2 deletions doctr/models/kie_predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from doctr.models.builder import KIEDocumentBuilder

from ..classification.predictor import CropOrientationPredictor
from ..classification.predictor import OrientationPredictor
from ..predictor.base import _OCRPredictor

__all__ = ["_KIEPredictor"]
Expand All @@ -28,7 +28,7 @@ class _KIEPredictor(_OCRPredictor):
kwargs: keyword args of `DocumentBuilder`
"""

crop_orientation_predictor: Optional[CropOrientationPredictor]
crop_orientation_predictor: Optional[OrientationPredictor]

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .._utils import rectify_crops, rectify_loc_preds
from ..classification import crop_orientation_predictor
from ..classification.predictor import CropOrientationPredictor
from ..classification.predictor import OrientationPredictor

__all__ = ["_OCRPredictor"]

Expand All @@ -32,7 +32,7 @@ class _OCRPredictor:
**kwargs: keyword args of `DocumentBuilder`
"""

crop_orientation_predictor: Optional[CropOrientationPredictor]
crop_orientation_predictor: Optional[OrientationPredictor]

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions references/classification/train_pytorch_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from doctr.models.utils import export_model_to_onnx
from utils import EarlyStopper, plot_recorder, plot_samples

CLASSES = [0, 90, 180, 270]
CLASSES = [0, -90, 180, 90]


def rnd_rotate(img: torch.Tensor, target):
Expand Down Expand Up @@ -191,7 +191,7 @@ def main(args):

torch.backends.cudnn.benchmark = True

input_size = (256, 256) if args.type == "page" else (32, 32)
input_size = (512, 512) if args.type == "page" else (256, 256)

# Load val data generator
st = time.time()
Expand Down
4 changes: 2 additions & 2 deletions references/classification/train_tensorflow_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from doctr.transforms.functional import rotated_img_tensor
from utils import EarlyStopper, plot_recorder, plot_samples

CLASSES = [0, 90, 180, 270]
CLASSES = [0, -90, 180, 90]


def rnd_rotate(img: tf.Tensor, target):
Expand Down Expand Up @@ -147,7 +147,7 @@ def main(args):
if not isinstance(args.workers, int):
args.workers = min(16, mp.cpu_count())

input_size = (256, 256) if args.type == "page" else (32, 32)
input_size = (512, 512) if args.type == "page" else (256, 256)

# AMP
if args.amp:
Expand Down
23 changes: 12 additions & 11 deletions tests/pytorch/test_models_classification_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch

from doctr.models import classification
from doctr.models.classification.predictor import CropOrientationPredictor
from doctr.models.classification.predictor import OrientationPredictor
from doctr.models.utils import export_model_to_onnx


Expand Down Expand Up @@ -60,7 +60,7 @@ def test_classification_architectures(arch_name, input_shape, output_size):
@pytest.mark.parametrize(
"arch_name, input_shape",
[
["mobilenet_v3_small_orientation", (3, 128, 128)],
["mobilenet_v3_small_crop_orientation", (3, 128, 128)],
],
)
def test_classification_models(arch_name, input_shape):
Expand All @@ -80,7 +80,7 @@ def test_classification_models(arch_name, input_shape):
@pytest.mark.parametrize(
"arch_name",
[
"mobilenet_v3_small_orientation",
"mobilenet_v3_small_crop_orientation",
],
)
def test_classification_zoo(arch_name):
Expand All @@ -92,7 +92,7 @@ def test_classification_zoo(arch_name):
with pytest.raises(ValueError):
predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False)
# object check
assert isinstance(predictor, CropOrientationPredictor)
assert isinstance(predictor, OrientationPredictor)
input_tensor = torch.rand((batch_size, 3, 128, 128))
if torch.cuda.is_available():
predictor.model.cuda()
Expand All @@ -112,14 +112,15 @@ def test_classification_zoo(arch_name):

def test_crop_orientation_model(mock_text_box):
text_box_0 = cv2.imread(mock_text_box)
text_box_90 = np.rot90(text_box_0, 1)
# rotates counter-clockwise
text_box_270 = np.rot90(text_box_0, 1)
text_box_180 = np.rot90(text_box_0, 2)
text_box_270 = np.rot90(text_box_0, 3)
classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
assert classifier([text_box_0, text_box_90, text_box_180, text_box_270])[0] == [0, 1, 2, 3]
text_box_90 = np.rot90(text_box_0, 3)
classifier = classification.crop_orientation_predictor("mobilenet_v3_small_crop_orientation", pretrained=True)
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3]
# 270 degrees is equivalent to -90 degrees
assert classifier([text_box_0, text_box_90, text_box_180, text_box_270])[1] == [0, 90, 180, -90]
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_90, text_box_180, text_box_270])[2])
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])


@pytest.mark.parametrize(
Expand All @@ -134,7 +135,7 @@ def test_crop_orientation_model(mock_text_box):
["magc_resnet31", (3, 32, 32), (126,)],
["mobilenet_v3_small", (3, 32, 32), (126,)],
["mobilenet_v3_large", (3, 32, 32), (126,)],
["mobilenet_v3_small_orientation", (3, 128, 128), (4,)],
["mobilenet_v3_small_crop_orientation", (3, 128, 128), (4,)],
["vit_s", (3, 32, 32), (126,)],
["vit_b", (3, 32, 32), (126,)],
["textnet_tiny", (3, 32, 32), (126,)],
Expand Down
Loading

0 comments on commit 0285711

Please sign in to comment.