diff --git a/.conda/meta.yaml b/.conda/meta.yaml
index 5eb28efa9c..f194539df1 100644
--- a/.conda/meta.yaml
+++ b/.conda/meta.yaml
@@ -24,7 +24,6 @@ requirements:
- numpy >=1.16.0, <2.0.0
- scipy >=1.4.0, <2.0.0
- pillow >=9.2.0
- - matplotlib >=3.1.0
- h5py >=3.1.0, <4.0.0
- opencv >=4.5.0, <5.0.0
- pypdfium2-team::pypdfium2_helpers >=4.0.0, <5.0.0
@@ -33,10 +32,7 @@ requirements:
- langdetect >=1.0.9, <2.0.0
- rapidfuzz >=3.0.0, <4.0.0
- huggingface_hub >=0.20.0, <1.0.0
- - matplotlib >=3.1.0
- - weasyprint >=55.0
- defusedxml >=0.7.0
- - mplcursors >=0.3
- anyascii >=0.3.2
- tqdm >=4.30.0
diff --git a/.github/workflows/builds.yml b/.github/workflows/builds.yml
index 8f45a8ad2c..725579f373 100644
--- a/.github/workflows/builds.yml
+++ b/.github/workflows/builds.yml
@@ -41,12 +41,12 @@ jobs:
name: Install package (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
- if: matrix.framework == 'pytorch'
name: Install package (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
- name: Import package
run: python -c "import doctr; print(doctr.__version__)"
diff --git a/.github/workflows/demo.yml b/.github/workflows/demo.yml
index d62c7d79e0..0fc7f203ff 100644
--- a/.github/workflows/demo.yml
+++ b/.github/workflows/demo.yml
@@ -42,13 +42,13 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
pip install -r demo/tf-requirements.txt
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
pip install -r demo/pt-requirements.txt
- if: matrix.framework == 'tensorflow'
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 8efd563559..79965b560c 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -e .[tf]
+ pip install -e .[tf,viz,html]
pip install -e .[docs]
- name: Build documentation
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index ecfb6a0e3a..c7c5120ef9 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -28,7 +28,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
pip install -e .[testing]
- name: Run unittests
run: |
@@ -61,7 +61,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
pip install -e .[testing]
- name: Run unittests
run: |
@@ -94,7 +94,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
pip install -e .[testing]
- name: Run unittests
diff --git a/.github/workflows/public_docker_images.yml b/.github/workflows/public_docker_images.yml
index 61a974e8fa..bf1cb4ea38 100644
--- a/.github/workflows/public_docker_images.yml
+++ b/.github/workflows/public_docker_images.yml
@@ -23,7 +23,7 @@ jobs:
matrix:
# Must match version at https://www.python.org/ftp/python/
python: ["3.9.18", "3.10.13", "3.11.8"]
- framework: ["tf", "torch"]
+ framework: ["tf", "torch", "tf,viz,html", "torch,viz,html"]
system: ["cpu", "gpu"]
# Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
diff --git a/.github/workflows/pull_requests.yml b/.github/workflows/pull_requests.yml
index 87baef1616..045c4676f6 100644
--- a/.github/workflows/pull_requests.yml
+++ b/.github/workflows/pull_requests.yml
@@ -22,7 +22,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
pip install -e .[docs]
- name: Build documentation
diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml
index e9f7bd54dc..b579c79e9c 100644
--- a/.github/workflows/references.yml
+++ b/.github/workflows/references.yml
@@ -42,14 +42,14 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
pip install -r references/requirements.txt
sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
pip install -r references/requirements.txt
sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y
- if: matrix.framework == 'tensorflow'
@@ -94,13 +94,13 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
pip install -r references/requirements.txt
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
pip install -r references/requirements.txt
- name: Download and extract detection toy set
run: |
@@ -160,13 +160,13 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
pip install -r references/requirements.txt
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
pip install -r references/requirements.txt
- name: Download and extract toy set
run: |
@@ -211,12 +211,12 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
- if: matrix.framework == 'tensorflow'
name: Evaluate text recognition (TF)
run: python references/recognition/evaluate_tensorflow.py crnn_mobilenet_v3_small --dataset IIIT5K -b 32
@@ -255,12 +255,12 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
- if: matrix.framework == 'tensorflow'
name: Benchmark latency (TF)
run: python references/recognition/latency_tensorflow.py crnn_mobilenet_v3_small --it 5
@@ -303,13 +303,13 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
pip install -r references/requirements.txt
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
pip install -r references/requirements.txt
- name: Download and extract toy set
run: |
@@ -354,13 +354,13 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
pip install -r references/requirements.txt
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
pip install -r references/requirements.txt
- if: matrix.framework == 'tensorflow'
name: Evaluate text detection (TF)
@@ -400,12 +400,12 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
- if: matrix.framework == 'tensorflow'
name: Benchmark latency (TF)
run: python references/detection/latency_tensorflow.py db_mobilenet_v3_large --it 5 --size 512
@@ -444,12 +444,12 @@ jobs:
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
- if: matrix.framework == 'pytorch'
name: Benchmark latency (PT)
run: python references/obj_detection/latency_pytorch.py fasterrcnn_mobilenet_v3_large_fpn --it 5 --size 512
diff --git a/.github/workflows/scripts.yml b/.github/workflows/scripts.yml
index a91d4cf5e6..37b0ce3652 100644
--- a/.github/workflows/scripts.yml
+++ b/.github/workflows/scripts.yml
@@ -41,12 +41,12 @@ jobs:
name: Install package (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
- if: matrix.framework == 'pytorch'
name: Install package (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
- name: Run analysis script
run: |
@@ -87,12 +87,12 @@ jobs:
name: Install package (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
- if: matrix.framework == 'pytorch'
name: Install package (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
- name: Run detection script
run: |
@@ -133,12 +133,12 @@ jobs:
name: Install package (TF)
run: |
python -m pip install --upgrade pip
- pip install -e .[tf] --upgrade
+ pip install -e .[tf,viz,html] --upgrade
- if: matrix.framework == 'pytorch'
name: Install package (PT)
run: |
python -m pip install --upgrade pip
- pip install -e .[torch] --upgrade
+ pip install -e .[torch,viz,html] --upgrade
- name: Run evaluation script
run: |
python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10
diff --git a/README.md b/README.md
index 82cc3c2480..d82ef8344b 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ from doctr.io import DocumentFile
pdf_doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
# Image
single_img_doc = DocumentFile.from_images("path/to/your/img.jpg")
-# Webpage
+# Webpage (requires `weasyprint` to be installed)
webpage_doc = DocumentFile.from_url("https://www.yoursite.com")
# Multiple page images
multi_img_doc = DocumentFile.from_images(["path/to/page1.jpg", "path/to/page2.jpg"])
@@ -75,6 +75,7 @@ If both options are set to False, the predictor will always fit and return rotat
To interpret your model's predictions, you can visualize them interactively as follows:
```python
+# Display the result (requires matplotlib & mplcursors to be installed)
result.show()
```
@@ -136,16 +137,6 @@ The KIE predictor results per page are in a dictionary format with each key repr
Python 3.9 (or higher) and [pip](https://pip.pypa.io/en/stable/) are required to install docTR.
-Since we use [weasyprint](https://weasyprint.org/), you will need extra dependencies if you are not running Linux.
-
-For MacOS users, you can install them as follows:
-
-```shell
-brew install cairo pango gdk-pixbuf libffi
-```
-
-For Windows users, those dependencies are included in GTK. You can find the latest installer over [here](https://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer/releases).
-
### Latest release
You can then install the latest release of the package using [pypi](https://pypi.org/project/python-doctr/) as follows:
@@ -350,9 +341,15 @@ Your API should now be running locally on your port 8002. Access your automatica
```python
import requests
+
+headers = {"accept": "application/json"}
+params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn"}
+
with open('/path/to/your/doc.jpg', 'rb') as f:
- data = f.read()
-response = requests.post("http://localhost:8002/ocr", files={'file': data}).json()
+ files = [ # application/pdf, image/jpeg, image/png supported
+ ("files", ("doc.jpg", f.read(), "image/jpeg")),
+ ]
+print(requests.post("http://localhost:8002/ocr", headers=headers, params=params, files=files).json())
```
### Example notebooks
diff --git a/docs/README.md b/docs/README.md
index 972bda511d..bf92c7bc45 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -8,6 +8,6 @@ To install only the `docs` environment, you can do:
```bash
# Make sure you are at the root of the repository before executing these commands
python -m pip install --upgrade pip
-pip install -e .[tf] # or .[torch]
+pip install -e .[tf,viz,html] # or .[torch,viz,html]
pip install -e .[docs]
```
diff --git a/docs/source/getting_started/installing.rst b/docs/source/getting_started/installing.rst
index d746ebc0b4..a00451514a 100644
--- a/docs/source/getting_started/installing.rst
+++ b/docs/source/getting_started/installing.rst
@@ -19,17 +19,6 @@ For MacBooks with M1 chip, you will need some additional packages or specific ve
* `TensorFlow 2 Metal Plugin `_
* `PyTorch >= 1.12.0 `_
-If you are running another OS than Linux, you will need a few extra dependencies.
-
-For MacOS users, you can install them using `Homebrew `_ as follows:
-
-.. code:: shell
-
- brew install cairo pango gdk-pixbuf libffi
-
-For Windows users, those dependencies are included in GTK. You can find the latest installer over `here `_.
-
-
Via Python Package
==================
@@ -49,12 +38,18 @@ We strive towards reducing framework-specific dependencies to a minimum, but som
.. code:: bash
pip install "python-doctr[tf]"
+ # or with preinstalled packages for visualization & html support
+ pip install "python-doctr[tf,viz,html]"
.. tab:: PyTorch
.. code:: bash
pip install "python-doctr[torch]"
+ # or with preinstalled packages for visualization & html support
+ pip install "python-doctr[torch,viz,html]"
+
+
Via Conda (Only for Linux)
diff --git a/doctr/file_utils.py b/doctr/file_utils.py
index 93f12a55cb..68e9dfffac 100644
--- a/doctr/file_utils.py
+++ b/doctr/file_utils.py
@@ -5,21 +5,16 @@
# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
+import importlib.metadata
import importlib.util
import logging
import os
-import sys
+from typing import Optional
CLASS_NAME: str = "words"
-if sys.version_info < (3, 8): # pragma: no cover
- import importlib_metadata
-else:
- import importlib.metadata as importlib_metadata
-
-
-__all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME"]
+__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
@@ -32,9 +27,9 @@
_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
try:
- _torch_version = importlib_metadata.version("torch")
+ _torch_version = importlib.metadata.version("torch")
logging.info(f"PyTorch version {_torch_version} available.")
- except importlib_metadata.PackageNotFoundError: # pragma: no cover
+ except importlib.metadata.PackageNotFoundError: # pragma: no cover
_torch_available = False
else: # pragma: no cover
logging.info("Disabling PyTorch because USE_TF is set")
@@ -59,9 +54,9 @@
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
for pkg in candidates:
try:
- _tf_version = importlib_metadata.version(pkg)
+ _tf_version = importlib.metadata.version(pkg)
break
- except importlib_metadata.PackageNotFoundError:
+ except importlib.metadata.PackageNotFoundError:
pass
_tf_available = _tf_version is not None
if _tf_available:
@@ -82,6 +77,25 @@
)
+def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover
+ """
+ Check that a package is installed, raising an ImportError with an install hint otherwise
+
+ Args:
+ ----
+ name: name of the package
+ extra_message: additional message to display if the package is not found
+ """
+ try:
+ _pkg_version = importlib.metadata.version(name)
+ logging.info(f"{name} version {_pkg_version} available.")
+ except importlib.metadata.PackageNotFoundError:
+ raise ImportError(
+ f"\n\n{extra_message if extra_message is not None else ''} "
+ f"\nPlease install it with the following command: pip install {name}\n"
+ )
+
+
def is_torch_available():
"""Whether PyTorch is installed."""
return _torch_available
diff --git a/doctr/io/elements.py b/doctr/io/elements.py
index 4862b17b6b..b6d2cb9b5b 100644
--- a/doctr/io/elements.py
+++ b/doctr/io/elements.py
@@ -12,14 +12,19 @@
from xml.etree.ElementTree import Element as ETElement
from xml.etree.ElementTree import SubElement
-import matplotlib.pyplot as plt
import numpy as np
import doctr
+from doctr.file_utils import requires_package
from doctr.utils.common_types import BoundingBox
from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
+from doctr.utils.reconstitution import synthesize_kie_page, synthesize_page
from doctr.utils.repr import NestedObject
-from doctr.utils.visualization import synthesize_kie_page, synthesize_page, visualize_kie_page, visualize_page
+
+try: # optional dependency for visualization
+ from doctr.utils.visualization import visualize_kie_page, visualize_page
+except ModuleNotFoundError:
+ pass
__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"]
@@ -282,6 +287,10 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **
preserve_aspect_ratio: pass True if you passed True to the predictor
**kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method
"""
+ requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed")
+ requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
+ import matplotlib.pyplot as plt
+
visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
plt.show(**kwargs)
@@ -457,6 +466,10 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **
preserve_aspect_ratio: pass True if you passed True to the predictor
**kwargs: keyword arguments passed to the matplotlib.pyplot.show method
"""
+ requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed")
+ requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
+ import matplotlib.pyplot as plt
+
visualize_kie_page(
self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio
)
diff --git a/doctr/io/html.py b/doctr/io/html.py
index dd2d00c102..f8a8da237d 100644
--- a/doctr/io/html.py
+++ b/doctr/io/html.py
@@ -5,8 +5,6 @@
from typing import Any
-from weasyprint import HTML
-
__all__ = ["read_html"]
@@ -25,4 +23,6 @@ def read_html(url: str, **kwargs: Any) -> bytes:
-------
decoded PDF file as a bytes stream
"""
+ from weasyprint import HTML
+
return HTML(url, **kwargs).write_pdf()
diff --git a/doctr/io/reader.py b/doctr/io/reader.py
index 37af393e46..76f7317cb1 100644
--- a/doctr/io/reader.py
+++ b/doctr/io/reader.py
@@ -8,6 +8,7 @@
import numpy as np
+from doctr.file_utils import requires_package
from doctr.utils.common_types import AbstractFile
from .html import read_html
@@ -54,6 +55,11 @@ def from_url(cls, url: str, **kwargs) -> List[np.ndarray]:
-------
the list of pages decoded as numpy ndarray of shape H x W x 3
"""
+ requires_package(
+ "weasyprint",
+ "`.from_url` requires weasyprint installed.\n"
+ + "Installation instructions: https://doc.courtbouillon.org/weasyprint/stable/first_steps.html#installation",
+ )
pdf_stream = read_html(url)
return cls.from_pdf(pdf_stream, **kwargs)
diff --git a/doctr/utils/reconstitution.py b/doctr/utils/reconstitution.py
new file mode 100644
index 0000000000..82ae20cdd0
--- /dev/null
+++ b/doctr/utils/reconstitution.py
@@ -0,0 +1,126 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+from typing import Any, Dict, Optional
+
+import numpy as np
+from anyascii import anyascii
+from PIL import Image, ImageDraw
+
+from .fonts import get_font
+
+__all__ = ["synthesize_page", "synthesize_kie_page"]
+
+
+def synthesize_page(
+ page: Dict[str, Any],
+ draw_proba: bool = False,
+ font_family: Optional[str] = None,
+) -> np.ndarray:
+ """Draw the content of the element page (OCR response) on a blank page.
+
+ Args:
+ ----
+ page: exported Page object to represent
+ draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0
+ font_family: family of the font
+ (the font size is not a parameter: it is derived from the height of each word box)
+
+ Returns:
+ -------
+ the synthesized page
+ """
+ # Draw template
+ h, w = page["dimensions"]
+ response = 255 * np.ones((h, w, 3), dtype=np.int32)
+
+ # Draw each word
+ for block in page["blocks"]:
+ for line in block["lines"]:
+ for word in line["words"]:
+ # Get absolute word geometry
+ (xmin, ymin), (xmax, ymax) = word["geometry"]
+ xmin, xmax = int(round(w * xmin)), int(round(w * xmax))
+ ymin, ymax = int(round(h * ymin)), int(round(h * ymax))
+
+ # White drawing context adapted to font size, 0.75 factor to convert pts --> pix
+ font = get_font(font_family, int(0.75 * (ymax - ymin)))
+ img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255))
+ d = ImageDraw.Draw(img)
+ # Draw in black the value of the word
+ try:
+ d.text((0, 0), word["value"], font=font, fill=(0, 0, 0))
+ except UnicodeEncodeError:
+ # When character cannot be encoded, use its anyascii version
+ d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0))
+
+ # Colorize if draw_proba
+ if draw_proba:
+ p = int(255 * word["confidence"])
+ mask = np.where(np.array(img) == 0, 1, 0)
+ proba: np.ndarray = np.array([255 - p, 0, p])
+ color = mask * proba[np.newaxis, np.newaxis, :]
+ white_mask = 255 * (1 - mask)
+ img = color + white_mask
+
+ # Write to response page
+ response[ymin:ymax, xmin:xmax, :] = np.array(img)
+
+ return response
+
+
+def synthesize_kie_page(
+ page: Dict[str, Any],
+ draw_proba: bool = False,
+ font_family: Optional[str] = None,
+) -> np.ndarray:
+ """Draw the content of the element page (OCR response) on a blank page.
+
+ Args:
+ ----
+ page: exported Page object to represent
+ draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0
+ font_family: family of the font
+ (the font size is not a parameter: it is derived from the height of each prediction box)
+
+ Returns:
+ -------
+ the synthesized page
+ """
+ # Draw template
+ h, w = page["dimensions"]
+ response = 255 * np.ones((h, w, 3), dtype=np.int32)
+
+ # Draw each word
+ for predictions in page["predictions"].values():
+ for prediction in predictions:
+ # Get absolute word geometry
+ (xmin, ymin), (xmax, ymax) = prediction["geometry"]
+ xmin, xmax = int(round(w * xmin)), int(round(w * xmax))
+ ymin, ymax = int(round(h * ymin)), int(round(h * ymax))
+
+ # White drawing context adapted to font size, 0.75 factor to convert pts --> pix
+ font = get_font(font_family, int(0.75 * (ymax - ymin)))
+ img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255))
+ d = ImageDraw.Draw(img)
+ # Draw in black the value of the word
+ try:
+ d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0))
+ except UnicodeEncodeError:
+ # When character cannot be encoded, use its anyascii version
+ d.text((0, 0), anyascii(prediction["value"]), font=font, fill=(0, 0, 0))
+
+ # Colorize if draw_proba
+ if draw_proba:
+ p = int(255 * prediction["confidence"])
+ mask = np.where(np.array(img) == 0, 1, 0)
+ proba: np.ndarray = np.array([255 - p, 0, p])
+ color = mask * proba[np.newaxis, np.newaxis, :]
+ white_mask = 255 * (1 - mask)
+ img = color + white_mask
+
+ # Write to response page
+ response[ymin:ymax, xmin:xmax, :] = np.array(img)
+
+ return response
diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py
index 66d9d9a943..4e97f751fe 100644
--- a/doctr/utils/visualization.py
+++ b/doctr/utils/visualization.py
@@ -9,16 +9,12 @@
import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
-import mplcursors
import numpy as np
-from anyascii import anyascii
from matplotlib.figure import Figure
-from PIL import Image, ImageDraw
from .common_types import BoundingBox, Polygon4P
-from .fonts import get_font
-__all__ = ["visualize_page", "synthesize_page", "visualize_kie_page", "synthesize_kie_page", "draw_boxes"]
+__all__ = ["visualize_page", "visualize_kie_page", "draw_boxes"]
def rect_patch(
@@ -281,6 +277,8 @@ def visualize_page(
artists.append(rect)
if interactive:
+ import mplcursors
+
# Create mlp Cursor to hover patches in artists
mplcursors.Cursor(artists, hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label()))
fig.tight_layout(pad=0.0)
@@ -288,63 +286,6 @@ def visualize_page(
return fig
-def synthesize_page(
- page: Dict[str, Any],
- draw_proba: bool = False,
- font_family: Optional[str] = None,
-) -> np.ndarray:
- """Draw a the content of the element page (OCR response) on a blank page.
-
- Args:
- ----
- page: exported Page object to represent
- draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0
- font_size: size of the font, default font = 13
- font_family: family of the font
-
- Returns:
- -------
- the synthesized page
- """
- # Draw template
- h, w = page["dimensions"]
- response = 255 * np.ones((h, w, 3), dtype=np.int32)
-
- # Draw each word
- for block in page["blocks"]:
- for line in block["lines"]:
- for word in line["words"]:
- # Get aboslute word geometry
- (xmin, ymin), (xmax, ymax) = word["geometry"]
- xmin, xmax = int(round(w * xmin)), int(round(w * xmax))
- ymin, ymax = int(round(h * ymin)), int(round(h * ymax))
-
- # White drawing context adapted to font size, 0.75 factor to convert pts --> pix
- font = get_font(font_family, int(0.75 * (ymax - ymin)))
- img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255))
- d = ImageDraw.Draw(img)
- # Draw in black the value of the word
- try:
- d.text((0, 0), word["value"], font=font, fill=(0, 0, 0))
- except UnicodeEncodeError:
- # When character cannot be encoded, use its anyascii version
- d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0))
-
- # Colorize if draw_proba
- if draw_proba:
- p = int(255 * word["confidence"])
- mask = np.where(np.array(img) == 0, 1, 0)
- proba: np.ndarray = np.array([255 - p, 0, p])
- color = mask * proba[np.newaxis, np.newaxis, :]
- white_mask = 255 * (1 - mask)
- img = color + white_mask
-
- # Write to response page
- response[ymin:ymax, xmin:xmax, :] = np.array(img)
-
- return response
-
-
def visualize_kie_page(
page: Dict[str, Any],
image: np.ndarray,
@@ -413,6 +354,8 @@ def visualize_kie_page(
artists.append(rect)
if interactive:
+ import mplcursors
+
# Create mlp Cursor to hover patches in artists
mplcursors.Cursor(artists, hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label()))
fig.tight_layout(pad=0.0)
@@ -420,62 +363,6 @@ def visualize_kie_page(
return fig
-def synthesize_kie_page(
- page: Dict[str, Any],
- draw_proba: bool = False,
- font_family: Optional[str] = None,
-) -> np.ndarray:
- """Draw a the content of the element page (OCR response) on a blank page.
-
- Args:
- ----
- page: exported Page object to represent
- draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0
- font_size: size of the font, default font = 13
- font_family: family of the font
-
- Returns:
- -------
- the synthesized page
- """
- # Draw template
- h, w = page["dimensions"]
- response = 255 * np.ones((h, w, 3), dtype=np.int32)
-
- # Draw each word
- for predictions in page["predictions"].values():
- for prediction in predictions:
- # Get aboslute word geometry
- (xmin, ymin), (xmax, ymax) = prediction["geometry"]
- xmin, xmax = int(round(w * xmin)), int(round(w * xmax))
- ymin, ymax = int(round(h * ymin)), int(round(h * ymax))
-
- # White drawing context adapted to font size, 0.75 factor to convert pts --> pix
- font = get_font(font_family, int(0.75 * (ymax - ymin)))
- img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255))
- d = ImageDraw.Draw(img)
- # Draw in black the value of the word
- try:
- d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0))
- except UnicodeEncodeError:
- # When character cannot be encoded, use its anyascii version
- d.text((0, 0), anyascii(prediction["value"]), font=font, fill=(0, 0, 0))
-
- # Colorize if draw_proba
- if draw_proba:
- p = int(255 * prediction["confidence"])
- mask = np.where(np.array(img) == 0, 1, 0)
- proba: np.ndarray = np.array([255 - p, 0, p])
- color = mask * proba[np.newaxis, np.newaxis, :]
- white_mask = 255 * (1 - mask)
- img = color + white_mask
-
- # Write to response page
- response[ymin:ymax, xmin:xmax, :] = np.array(img)
-
- return response
-
-
def draw_boxes(boxes: np.ndarray, image: np.ndarray, color: Optional[Tuple[int, int, int]] = None, **kwargs) -> None:
"""Draw an array of relative straight boxes on an image
diff --git a/pyproject.toml b/pyproject.toml
index 60d9a7cfeb..06c3655699 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,6 @@ classifiers=[
]
dynamic = ["version"]
dependencies = [
- "importlib_metadata",
# For proper typing, mypy needs numpy>=1.20.0 (cf. https://github.com/numpy/numpy/pull/16515)
# Additional typing support is brought by numpy>=1.22.4, but core build sticks to >=1.16.0
"numpy>=1.16.0,<2.0.0",
@@ -45,11 +44,8 @@ dependencies = [
"langdetect>=1.0.9,<2.0.0",
"rapidfuzz>=3.0.0,<4.0.0",
"huggingface-hub>=0.20.0,<1.0.0",
- "matplotlib>=3.1.0",
- "weasyprint>=55.0",
"Pillow>=9.2.0",
"defusedxml>=0.7.0",
- "mplcursors>=0.3",
"anyascii>=0.3.2",
"tqdm>=4.30.0",
]
@@ -66,6 +62,13 @@ torch = [
"torchvision>=0.13.0",
"onnx>=1.12.0,<3.0.0",
]
+html = [
+ "weasyprint>=55.0",
+]
+viz = [
+ "matplotlib>=3.1.0",
+ "mplcursors>=0.3",
+]
testing = [
"pytest>=5.3.2",
"coverage[toml]>=4.5.4",
@@ -99,6 +102,10 @@ dev = [
"torch>=1.12.0,<3.0.0",
"torchvision>=0.13.0",
"onnx>=1.12.0,<3.0.0",
+ # Extras
+ "weasyprint>=55.0",
+ "matplotlib>=3.1.0",
+ "mplcursors>=0.3",
# Testing
"pytest>=5.3.2",
"coverage[toml]>=4.5.4",
diff --git a/references/requirements.txt b/references/requirements.txt
index bb72bdd47b..679df79518 100644
--- a/references/requirements.txt
+++ b/references/requirements.txt
@@ -3,3 +3,4 @@ tqdm
wandb>=0.10.31
psutil>=5.9.0
clearml>=1.11.1
+matplotlib>=3.1.0
diff --git a/tests/common/test_core.py b/tests/common/test_core.py
index cd2758b0df..2fb8d1b9b0 100644
--- a/tests/common/test_core.py
+++ b/tests/common/test_core.py
@@ -1,6 +1,7 @@
import pytest
import doctr
+from doctr.file_utils import requires_package
def test_version():
@@ -15,3 +16,9 @@ def test_is_tf_available():
@pytest.mark.skipif(doctr.is_torch_available() and doctr.is_tf_available(), reason="torch and tf are available")
def test_is_torch_available():
assert not doctr.is_torch_available()
+
+
+def test_requires_package():
+ requires_package("numpy") # available
+ with pytest.raises(ImportError): # not available
+ requires_package("non_existent_package")
diff --git a/tests/common/test_utils_reconstitution.py b/tests/common/test_utils_reconstitution.py
new file mode 100644
index 0000000000..3b70e67070
--- /dev/null
+++ b/tests/common/test_utils_reconstitution.py
@@ -0,0 +1,12 @@
+import numpy as np
+from test_io_elements import _mock_pages
+
+from doctr.utils import reconstitution
+
+
+def test_synthesize_page():
+ pages = _mock_pages()
+ reconstitution.synthesize_page(pages[0].export(), draw_proba=False)
+ render = reconstitution.synthesize_page(pages[0].export(), draw_proba=True)
+ assert isinstance(render, np.ndarray)
+ assert render.shape == (*pages[0].dimensions, 3)
diff --git a/tests/common/test_utils_visualization.py b/tests/common/test_utils_visualization.py
index b1e9998698..ae232ebb85 100644
--- a/tests/common/test_utils_visualization.py
+++ b/tests/common/test_utils_visualization.py
@@ -21,14 +21,6 @@ def test_visualize_page():
visualization.create_obj_patch((1, 2, 3, 4, 5), (100, 100))
-def test_synthesize_page():
- pages = _mock_pages()
- visualization.synthesize_page(pages[0].export(), draw_proba=False)
- render = visualization.synthesize_page(pages[0].export(), draw_proba=True)
- assert isinstance(render, np.ndarray)
- assert render.shape == (*pages[0].dimensions, 3)
-
-
def test_draw_boxes():
image = np.ones((256, 256, 3), dtype=np.float32)
boxes = [