Skip to content

Commit

Permalink
[Feat] Add torch.compile support
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Nov 21, 2024
1 parent b53517a commit 93b4a54
Show file tree
Hide file tree
Showing 18 changed files with 290 additions and 38 deletions.
77 changes: 72 additions & 5 deletions docs/source/using_doctr/using_model_export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ Advantages:
.. code:: python3
import torch
predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True).cuda().half()
predictor = ocr_predictor(
reco_arch="crnn_mobilenet_v3_small",
det_arch="linknet_resnet34",
pretrained=True
).cuda().half()
res = predictor(doc)
.. tab:: TensorFlow
Expand All @@ -41,8 +45,63 @@ Advantages:
import tensorflow as tf
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True)
predictor = ocr_predictor(
reco_arch="crnn_mobilenet_v3_small",
det_arch="linknet_resnet34",
pretrained=True
)
Compiling your models (PyTorch only)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**NOTE:** This feature is only available for PyTorch models.

**NOTE:** The recognition `master` architecture is not supported for model compilation yet.

- **What it does**:

Compiling your PyTorch models with `torch.compile` optimizes the model by converting it to a graph representation and applying backends that can improve performance.
This process can make inference faster and reduce memory overhead during execution.

Further information can be found in the `PyTorch documentation <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_.

.. code:: python3
import torch
from doctr.models import (
ocr_predictor,
vitstr_small,
fast_base,
mobilenet_v3_small_crop_orientation,
mobilenet_v3_small_page_orientation,
crop_orientation_predictor,
page_orientation_predictor
)
# Compile the models
detection_model = torch.compile(
fast_base(pretrained=True).eval()
)
recognition_model = torch.compile(
vitstr_small(pretrained=True).eval()
)
crop_orientation_model = torch.compile(
mobilenet_v3_small_crop_orientation(pretrained=True).eval()
)
page_orientation_model = torch.compile(
mobilenet_v3_small_page_orientation(pretrained=True).eval()
)
predictor = ocr_predictor(
detection_model, recognition_model, assume_straight_pages=False
)
# NOTE: Only required for non-straight pages (`assume_straight_pages=False`) and non-disabled orientation classification
# Set the orientation predictors
predictor.crop_orientation_predictor = crop_orientation_predictor(crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(page_orientation_model)
compiled_out = predictor(doc)
Export to ONNX
^^^^^^^^^^^^^^
Expand All @@ -64,7 +123,11 @@ It defines a common format for representing models, including the network struct
input_shape = (3, 32, 128)
model = vitstr_small(pretrained=True, exportable=True)
dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
model_path = export_model_to_onnx(model, model_name="vitstr.onnx, dummy_input=dummy_input)
model_path = export_model_to_onnx(
model,
model_name="vitstr.onnx",
dummy_input=dummy_input
)
.. tab:: TensorFlow

Expand All @@ -78,7 +141,11 @@ It defines a common format for representing models, including the network struct
input_shape = (32, 128, 3)
model = vitstr_small(pretrained=True, exportable=True)
dummy_input = [tf.TensorSpec([batch_size, *input_shape], tf.float32, name="input")]
model_path, output = export_model_to_onnx(model, model_name="vitstr.onnx", dummy_input=dummy_input)
model_path, output = export_model_to_onnx(
model,
model_name="vitstr.onnx",
dummy_input=dummy_input
)
Using your ONNX exported model
Expand Down
12 changes: 6 additions & 6 deletions docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ For instance, this snippet instantiates an end-to-end ocr_predictor working with

.. code:: python3
from doctr.model import ocr_predictor
from doctr.models import ocr_predictor
model = ocr_predictor('linknet_resnet18', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True)
Expand All @@ -309,7 +309,7 @@ Additionally, you can change the batch size of the underlying detection and reco

.. code:: python3
from doctr.model import ocr_predictor
from doctr.models import ocr_predictor
model = ocr_predictor(pretrained=True, det_bs=4, reco_bs=1024)
To modify the output structure you can pass the following arguments to the predictor which will be handled by the underlying `DocumentBuilder`:
Expand All @@ -322,7 +322,7 @@ For example to disable the automatic grouping of lines into blocks:

.. code:: python3
from doctr.model import ocr_predictor
from doctr.models import ocr_predictor
model = ocr_predictor(pretrained=True, resolve_blocks=False)
Expand Down Expand Up @@ -477,7 +477,7 @@ This will only have an effect with `assume_straight_pages=False` and/or `straigh

.. code:: python3
from doctr.model import ocr_predictor
from doctr.models import ocr_predictor
model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_page_orientation=True)
Expand All @@ -489,15 +489,15 @@ This will only have an effect with `assume_straight_pages=False` and/or `straigh

.. code:: python3
from doctr.model import ocr_predictor
from doctr.models import ocr_predictor
model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_crop_orientation=True)
* Add a hook to the `ocr_predictor` to manipulate the location predictions before the crops are passed to the recognition model.

.. code:: python3
from doctr.model import ocr_predictor
from doctr.models import ocr_predictor
class CustomHook:
def __call__(self, loc_preds):
Expand Down
10 changes: 8 additions & 2 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Any, List

from doctr.file_utils import is_tf_available
from doctr.file_utils import is_tf_available, is_torch_available

from .. import classification
from ..preprocessor import PreProcessor
Expand Down Expand Up @@ -48,7 +48,13 @@ def _orientation_predictor(
# Load directly classifier from backbone
_model = classification.__dict__[arch](pretrained=pretrained)
else:
if not isinstance(arch, classification.MobileNetV3):
allowed_archs = (classification.MobileNetV3,)
if is_torch_available():
import torch

allowed_archs += (torch._dynamo.eval_frame.OptimizedModule,)

if not isinstance(arch, allowed_archs):
raise ValueError(f"unknown architecture: {type(arch)}")
_model = arch

Expand Down
13 changes: 9 additions & 4 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,16 @@ def forward(
out["out_map"] = prob_map

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
def _postprocess(prob_map: torch.Tensor) -> List[Dict[str, Any]]:
return [
dict(zip(self.class_names, preds))
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
]

# Post-process boxes (keep only text predictions)
out["preds"] = [
dict(zip(self.class_names, preds))
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
]
out["preds"] = _postprocess(prob_map)

if target is not None:
thresh_map = self.thresh_head(feat_concat)
Expand Down
13 changes: 9 additions & 4 deletions doctr/models/detection/fast/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,16 @@ def forward(
out["out_map"] = prob_map

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
def _postprocess(prob_map: torch.Tensor) -> List[Dict[str, Any]]:
return [
dict(zip(self.class_names, preds))
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
]

# Post-process boxes (keep only text predictions)
out["preds"] = [
dict(zip(self.class_names, preds))
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
]
out["preds"] = _postprocess(prob_map)

if target is not None:
loss = self.compute_loss(logits, target)
Expand Down
15 changes: 10 additions & 5 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,16 @@ def forward(
out["out_map"] = prob_map

if target is None or return_preds:
# Post-process boxes
out["preds"] = [
dict(zip(self.class_names, preds))
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
]
# Disable for torch.compile compatibility
@torch.compiler.disable
def _postprocess(prob_map: torch.Tensor) -> List[Dict[str, Any]]:
return [
dict(zip(self.class_names, preds))
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
]

# Post-process boxes (keep only text predictions)
out["preds"] = _postprocess(prob_map)

if target is not None:
loss = self.compute_loss(logits, target)
Expand Down
7 changes: 6 additions & 1 deletion doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
if isinstance(_model, detection.FAST):
_model = reparameterize(_model)
else:
if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
allowed_archs = (detection.DBNet, detection.LinkNet, detection.FAST)
if is_torch_available():
import torch

allowed_archs += (torch._dynamo.eval_frame.OptimizedModule,)
if not isinstance(arch, allowed_archs):
raise ValueError(f"unknown architecture: {type(arch)}")

_model = arch
Expand Down
7 changes: 6 additions & 1 deletion doctr/models/recognition/crnn/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,13 @@ def forward(
out["out_map"] = logits

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
def _postprocess(logits: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(logits)

# Post-process boxes
out["preds"] = self.postprocessor(logits)
out["preds"] = _postprocess(logits)

if target is not None:
out["loss"] = self.compute_loss(logits, target)
Expand Down
8 changes: 7 additions & 1 deletion doctr/models/recognition/master/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,13 @@ def forward(
out["out_map"] = logits

if return_preds:
out["preds"] = self.postprocessor(logits)
# Disable for torch.compile compatibility
@torch.compiler.disable
def _postprocess(logits: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(logits)

# Post-process boxes
out["preds"] = _postprocess(logits)

return out

Expand Down
7 changes: 6 additions & 1 deletion doctr/models/recognition/parseq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,13 @@ def forward(
out["out_map"] = logits

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
def _postprocess(logits: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(logits)

# Post-process boxes
out["preds"] = self.postprocessor(logits)
out["preds"] = _postprocess(logits)

if target is not None:
out["loss"] = loss
Expand Down
7 changes: 6 additions & 1 deletion doctr/models/recognition/sar/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,13 @@ def forward(
out["out_map"] = decoded_features

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
def _postprocess(decoded_features: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(decoded_features)

# Post-process boxes
out["preds"] = self.postprocessor(decoded_features)
out["preds"] = _postprocess(decoded_features)

if target is not None:
out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
A merged character sequence.
Example::
>>> from doctr.model.recognition.utils import merge_sequences
>>> from doctr.models.recognition.utils import merge_strings
>>> merge_strings('abcd', 'cdefgh', 1.4)
'abcdefgh'
>>> merge_strings('abcdi', 'cdefgh', 1.4)
Expand Down Expand Up @@ -71,7 +71,7 @@ def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
A merged character sequence
Example::
>>> from doctr.model.recognition.utils import merge_multi_sequences
>>> from doctr.models.recognition.utils import merge_multi_strings
>>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
'abcdefghijkl'
"""
Expand Down
7 changes: 6 additions & 1 deletion doctr/models/recognition/vitstr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,13 @@ def forward(
out["out_map"] = decoded_features

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
def _postprocess(decoded_features: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(decoded_features)

# Post-process boxes
out["preds"] = self.postprocessor(decoded_features)
out["preds"] = _postprocess(decoded_features)

if target is not None:
out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
Expand Down
11 changes: 7 additions & 4 deletions doctr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Any, List

from doctr.file_utils import is_tf_available
from doctr.file_utils import is_tf_available, is_torch_available
from doctr.models.preprocessor import PreProcessor

from .. import recognition
Expand Down Expand Up @@ -35,9 +35,12 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
)
else:
if not isinstance(
arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
):
allowed_archs = (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
if is_torch_available():
import torch

allowed_archs += (torch._dynamo.eval_frame.OptimizedModule,)
if not isinstance(arch, allowed_archs):
raise ValueError(f"unknown architecture: {type(arch)}")
_model = arch

Expand Down
Loading

0 comments on commit 93b4a54

Please sign in to comment.