diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 5eb28efa9c..f194539df1 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -24,7 +24,6 @@ requirements: - numpy >=1.16.0, <2.0.0 - scipy >=1.4.0, <2.0.0 - pillow >=9.2.0 - - matplotlib >=3.1.0 - h5py >=3.1.0, <4.0.0 - opencv >=4.5.0, <5.0.0 - pypdfium2-team::pypdfium2_helpers >=4.0.0, <5.0.0 @@ -33,10 +32,7 @@ requirements: - langdetect >=1.0.9, <2.0.0 - rapidfuzz >=3.0.0, <4.0.0 - huggingface_hub >=0.20.0, <1.0.0 - - matplotlib >=3.1.0 - - weasyprint >=55.0 - defusedxml >=0.7.0 - - mplcursors >=0.3 - anyascii >=0.3.2 - tqdm >=4.30.0 diff --git a/.github/workflows/builds.yml b/.github/workflows/builds.yml index 8f45a8ad2c..725579f373 100644 --- a/.github/workflows/builds.yml +++ b/.github/workflows/builds.yml @@ -41,12 +41,12 @@ jobs: name: Install package (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade - if: matrix.framework == 'pytorch' name: Install package (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade - name: Import package run: python -c "import doctr; print(doctr.__version__)" diff --git a/.github/workflows/demo.yml b/.github/workflows/demo.yml index d62c7d79e0..0fc7f203ff 100644 --- a/.github/workflows/demo.yml +++ b/.github/workflows/demo.yml @@ -42,13 +42,13 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade pip install -r demo/tf-requirements.txt - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade pip install -r demo/pt-requirements.txt - if: matrix.framework == 'tensorflow' diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 8efd563559..79965b560c 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -27,7 +27,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e .[tf] + pip install -e .[tf,viz,html] pip install -e .[docs] - name: Build documentation diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ecfb6a0e3a..c7c5120ef9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade pip install -e .[testing] - name: Run unittests run: | @@ -61,7 +61,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade pip install -e .[testing] - name: Run unittests run: | @@ -94,7 +94,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade pip install -e .[testing] - name: Run unittests diff --git a/.github/workflows/public_docker_images.yml b/.github/workflows/public_docker_images.yml index 61a974e8fa..bf1cb4ea38 100644 --- a/.github/workflows/public_docker_images.yml +++ b/.github/workflows/public_docker_images.yml @@ -23,7 +23,7 @@ jobs: matrix: # Must match version at https://www.python.org/ftp/python/ python: ["3.9.18", "3.10.13", "3.11.8"] - framework: ["tf", "torch"] + framework: ["tf", "torch", "tf,viz,html", "torch,viz,html"] system: ["cpu", "gpu"] # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. diff --git a/.github/workflows/pull_requests.yml b/.github/workflows/pull_requests.yml index 87baef1616..045c4676f6 100644 --- a/.github/workflows/pull_requests.yml +++ b/.github/workflows/pull_requests.yml @@ -22,7 +22,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade pip install -e .[docs] - name: Build documentation diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml index e9f7bd54dc..b579c79e9c 100644 --- a/.github/workflows/references.yml +++ b/.github/workflows/references.yml @@ -42,14 +42,14 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade pip install -r references/requirements.txt sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade pip install -r references/requirements.txt sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y - if: matrix.framework == 'tensorflow' @@ -94,13 +94,13 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade pip install -r references/requirements.txt - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade pip install -r references/requirements.txt - name: Download and extract detection toy set run: | @@ -160,13 +160,13 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade pip install -r references/requirements.txt - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade pip install -r references/requirements.txt - name: Download and extract toy set run: | @@ -211,12 +211,12 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade - if: matrix.framework == 'tensorflow' name: Evaluate text recognition (TF) run: python references/recognition/evaluate_tensorflow.py crnn_mobilenet_v3_small --dataset IIIT5K -b 32 @@ -255,12 +255,12 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade - if: matrix.framework == 'tensorflow' name: Benchmark latency (TF) run: python references/recognition/latency_tensorflow.py crnn_mobilenet_v3_small --it 5 @@ -303,13 +303,13 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade pip install -r references/requirements.txt - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade pip install -r references/requirements.txt - name: Download and extract toy set run: | @@ -354,13 +354,13 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade pip install -r references/requirements.txt - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade pip install -r references/requirements.txt - if: matrix.framework == 'tensorflow' name: Evaluate text detection (TF) @@ -400,12 +400,12 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade - if: matrix.framework == 'tensorflow' name: Benchmark latency (TF) run: python references/detection/latency_tensorflow.py db_mobilenet_v3_large --it 5 --size 512 @@ -444,12 +444,12 @@ jobs: name: Install dependencies (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade - if: matrix.framework == 'pytorch' name: Install dependencies (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade - if: matrix.framework == 'pytorch' name: Benchmark latency (PT) run: python references/obj_detection/latency_pytorch.py fasterrcnn_mobilenet_v3_large_fpn --it 5 --size 512 diff --git a/.github/workflows/scripts.yml b/.github/workflows/scripts.yml index a91d4cf5e6..37b0ce3652 100644 --- a/.github/workflows/scripts.yml +++ b/.github/workflows/scripts.yml @@ -41,12 +41,12 @@ jobs: name: Install package (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade - if: matrix.framework == 'pytorch' name: Install package (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade - name: Run analysis script run: | @@ -87,12 +87,12 @@ jobs: name: Install package (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade - if: matrix.framework == 'pytorch' name: Install package (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade - name: Run detection script run: | @@ -133,12 +133,12 @@ jobs: name: Install package (TF) run: | python -m pip install --upgrade pip - pip install -e .[tf] --upgrade + pip install -e .[tf,viz,html] --upgrade - if: matrix.framework == 'pytorch' name: Install package (PT) run: | python -m pip install --upgrade pip - pip install -e .[torch] --upgrade + pip install -e .[torch,viz,html] --upgrade - name: Run evaluation script run: | python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10 diff --git a/README.md b/README.md index 82cc3c2480..d82ef8344b 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ from doctr.io import DocumentFile pdf_doc = DocumentFile.from_pdf("path/to/your/doc.pdf") # Image single_img_doc = DocumentFile.from_images("path/to/your/img.jpg") -# Webpage +# Webpage (requires `weasyprint` to be installed) webpage_doc = DocumentFile.from_url("https://www.yoursite.com") # Multiple page images multi_img_doc = DocumentFile.from_images(["path/to/page1.jpg", "path/to/page2.jpg"]) @@ -75,6 +75,7 @@ If both options are set to False, the predictor will always fit and return rotat To interpret your model's predictions, you can visualize them interactively as follows: ```python +# Display the result (requires matplotlib & mplcursors to be installed) result.show() ``` @@ -136,16 +137,6 @@ The KIE predictor results per page are in a dictionary format with each key repr Python 3.9 (or higher) and [pip](https://pip.pypa.io/en/stable/) are required to install docTR. -Since we use [weasyprint](https://weasyprint.org/), you will need extra dependencies if you are not running Linux. - -For MacOS users, you can install them as follows: - -```shell -brew install cairo pango gdk-pixbuf libffi -``` - -For Windows users, those dependencies are included in GTK. You can find the latest installer over [here](https://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer/releases). - ### Latest release You can then install the latest release of the package using [pypi](https://pypi.org/project/python-doctr/) as follows: @@ -350,9 +341,15 @@ Your API should now be running locally on your port 8002. Access your automatica ```python import requests + +headers = {"accept": "application/json"} +params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn"} + with open('/path/to/your/doc.jpg', 'rb') as f: - data = f.read() -response = requests.post("http://localhost:8002/ocr", files={'file': data}).json() + files = [ # application/pdf, image/jpeg, image/png supported + ("files", ("doc.jpg", f.read(), "image/jpeg")), + ] +print(requests.post("http://localhost:8080/ocr", headers=headers, params=params, files=files).json()) ``` ### Example notebooks diff --git a/docs/README.md b/docs/README.md index 972bda511d..bf92c7bc45 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,6 @@ To install only the `docs` environment, you can do: ```bash # Make sure you are at the root of the repository before executing these commands python -m pip install --upgrade pip -pip install -e .[tf] # or .[torch] +pip install -e .[tf,viz,html] # or .[torch,viz,html] pip install -e .[docs] ``` diff --git a/docs/source/getting_started/installing.rst b/docs/source/getting_started/installing.rst index d746ebc0b4..a00451514a 100644 --- a/docs/source/getting_started/installing.rst +++ b/docs/source/getting_started/installing.rst @@ -19,17 +19,6 @@ For MacBooks with M1 chip, you will need some additional packages or specific ve * `TensorFlow 2 Metal Plugin `_ * `PyTorch >= 1.12.0 `_ -If you are running another OS than Linux, you will need a few extra dependencies. - -For MacOS users, you can install them using `Homebrew `_ as follows: - -.. code:: shell - - brew install cairo pango gdk-pixbuf libffi - -For Windows users, those dependencies are included in GTK. You can find the latest installer over `here `_. - - Via Python Package ================== @@ -49,12 +38,18 @@ We strive towards reducing framework-specific dependencies to a minimum, but som .. code:: bash pip install "python-doctr[tf]" + # or with preinstalled packages for visualization & html support + pip install "python-doctr[tf,viz,html]" .. tab:: PyTorch .. code:: bash pip install "python-doctr[torch]" + # or with preinstalled packages for visualization & html support + pip install "python-doctr[torch,viz,html]" + + Via Conda (Only for Linux) diff --git a/doctr/file_utils.py b/doctr/file_utils.py index 93f12a55cb..68e9dfffac 100644 --- a/doctr/file_utils.py +++ b/doctr/file_utils.py @@ -5,21 +5,16 @@ # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py +import importlib.metadata import importlib.util import logging import os -import sys +from typing import Optional CLASS_NAME: str = "words" -if sys.version_info < (3, 8): # pragma: no cover - import importlib_metadata -else: - import importlib.metadata as importlib_metadata - - -__all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME"] +__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"] ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) @@ -32,9 +27,9 @@ _torch_available = importlib.util.find_spec("torch") is not None if _torch_available: try: - _torch_version = importlib_metadata.version("torch") + _torch_version = importlib.metadata.version("torch") logging.info(f"PyTorch version {_torch_version} available.") - except importlib_metadata.PackageNotFoundError: # pragma: no cover + except importlib.metadata.PackageNotFoundError: # pragma: no cover _torch_available = False else: # pragma: no cover logging.info("Disabling PyTorch because USE_TF is set") @@ -59,9 +54,9 @@ # For the metadata, we have to look for both tensorflow and tensorflow-cpu for pkg in candidates: try: - _tf_version = importlib_metadata.version(pkg) + _tf_version = importlib.metadata.version(pkg) break - except importlib_metadata.PackageNotFoundError: + except importlib.metadata.PackageNotFoundError: pass _tf_available = _tf_version is not None if _tf_available: @@ -82,6 +77,25 @@ ) +def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover + """ + package requirement helper + + Args: + ---- + name: name of the package + extra_message: additional message to display if the package is not found + """ + try: + _pkg_version = importlib.metadata.version(name) + logging.info(f"{name} version {_pkg_version} available.") + except importlib.metadata.PackageNotFoundError: + raise ImportError( + f"\n\n{extra_message if extra_message is not None else ''} " + f"\nPlease install it with the following command: pip install {name}\n" + ) + + def is_torch_available(): """Whether PyTorch is installed.""" return _torch_available diff --git a/doctr/io/elements.py b/doctr/io/elements.py index 4862b17b6b..b6d2cb9b5b 100644 --- a/doctr/io/elements.py +++ b/doctr/io/elements.py @@ -12,14 +12,19 @@ from xml.etree.ElementTree import Element as ETElement from xml.etree.ElementTree import SubElement -import matplotlib.pyplot as plt import numpy as np import doctr +from doctr.file_utils import requires_package from doctr.utils.common_types import BoundingBox from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox +from doctr.utils.reconstitution import synthesize_kie_page, synthesize_page from doctr.utils.repr import NestedObject -from doctr.utils.visualization import synthesize_kie_page, synthesize_page, visualize_kie_page, visualize_page + +try: # optional dependency for visualization + from doctr.utils.visualization import visualize_kie_page, visualize_page +except ModuleNotFoundError: + pass __all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"] @@ -282,6 +287,10 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, ** preserve_aspect_ratio: pass True if you passed True to the predictor **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method """ + requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed") + requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed") + import matplotlib.pyplot as plt + visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio) plt.show(**kwargs) @@ -457,6 +466,10 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, ** preserve_aspect_ratio: pass True if you passed True to the predictor **kwargs: keyword arguments passed to the matplotlib.pyplot.show method """ + requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed") + requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed") + import matplotlib.pyplot as plt + visualize_kie_page( self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio ) diff --git a/doctr/io/html.py b/doctr/io/html.py index dd2d00c102..f8a8da237d 100644 --- a/doctr/io/html.py +++ b/doctr/io/html.py @@ -5,8 +5,6 @@ from typing import Any -from weasyprint import HTML - __all__ = ["read_html"] @@ -25,4 +23,6 @@ def read_html(url: str, **kwargs: Any) -> bytes: ------- decoded PDF file as a bytes stream """ + from weasyprint import HTML + return HTML(url, **kwargs).write_pdf() diff --git a/doctr/io/reader.py b/doctr/io/reader.py index 37af393e46..76f7317cb1 100644 --- a/doctr/io/reader.py +++ b/doctr/io/reader.py @@ -8,6 +8,7 @@ import numpy as np +from doctr.file_utils import requires_package from doctr.utils.common_types import AbstractFile from .html import read_html @@ -54,6 +55,11 @@ def from_url(cls, url: str, **kwargs) -> List[np.ndarray]: ------- the list of pages decoded as numpy ndarray of shape H x W x 3 """ + requires_package( + "weasyprint", + "`.from_url` requires weasyprint installed.\n" + + "Installation instructions: https://doc.courtbouillon.org/weasyprint/stable/first_steps.html#installation", + ) pdf_stream = read_html(url) return cls.from_pdf(pdf_stream, **kwargs) diff --git a/doctr/utils/reconstitution.py b/doctr/utils/reconstitution.py new file mode 100644 index 0000000000..82ae20cdd0 --- /dev/null +++ b/doctr/utils/reconstitution.py @@ -0,0 +1,126 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. +from typing import Any, Dict, Optional + +import numpy as np +from anyascii import anyascii +from PIL import Image, ImageDraw + +from .fonts import get_font + +__all__ = ["synthesize_page", "synthesize_kie_page"] + + +def synthesize_page( + page: Dict[str, Any], + draw_proba: bool = False, + font_family: Optional[str] = None, +) -> np.ndarray: + """Draw a the content of the element page (OCR response) on a blank page. + + Args: + ---- + page: exported Page object to represent + draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 + font_size: size of the font, default font = 13 + font_family: family of the font + + Returns: + ------- + the synthesized page + """ + # Draw template + h, w = page["dimensions"] + response = 255 * np.ones((h, w, 3), dtype=np.int32) + + # Draw each word + for block in page["blocks"]: + for line in block["lines"]: + for word in line["words"]: + # Get absolute word geometry + (xmin, ymin), (xmax, ymax) = word["geometry"] + xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) + ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) + + # White drawing context adapted to font size, 0.75 factor to convert pts --> pix + font = get_font(font_family, int(0.75 * (ymax - ymin))) + img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) + d = ImageDraw.Draw(img) + # Draw in black the value of the word + try: + d.text((0, 0), word["value"], font=font, fill=(0, 0, 0)) + except UnicodeEncodeError: + # When character cannot be encoded, use its anyascii version + d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0)) + + # Colorize if draw_proba + if draw_proba: + p = int(255 * word["confidence"]) + mask = np.where(np.array(img) == 0, 1, 0) + proba: np.ndarray = np.array([255 - p, 0, p]) + color = mask * proba[np.newaxis, np.newaxis, :] + white_mask = 255 * (1 - mask) + img = color + white_mask + + # Write to response page + response[ymin:ymax, xmin:xmax, :] = np.array(img) + + return response + + +def synthesize_kie_page( + page: Dict[str, Any], + draw_proba: bool = False, + font_family: Optional[str] = None, +) -> np.ndarray: + """Draw a the content of the element page (OCR response) on a blank page. + + Args: + ---- + page: exported Page object to represent + draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 + font_size: size of the font, default font = 13 + font_family: family of the font + + Returns: + ------- + the synthesized page + """ + # Draw template + h, w = page["dimensions"] + response = 255 * np.ones((h, w, 3), dtype=np.int32) + + # Draw each word + for predictions in page["predictions"].values(): + for prediction in predictions: + # Get aboslute word geometry + (xmin, ymin), (xmax, ymax) = prediction["geometry"] + xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) + ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) + + # White drawing context adapted to font size, 0.75 factor to convert pts --> pix + font = get_font(font_family, int(0.75 * (ymax - ymin))) + img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) + d = ImageDraw.Draw(img) + # Draw in black the value of the word + try: + d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0)) + except UnicodeEncodeError: + # When character cannot be encoded, use its anyascii version + d.text((0, 0), anyascii(prediction["value"]), font=font, fill=(0, 0, 0)) + + # Colorize if draw_proba + if draw_proba: + p = int(255 * prediction["confidence"]) + mask = np.where(np.array(img) == 0, 1, 0) + proba: np.ndarray = np.array([255 - p, 0, p]) + color = mask * proba[np.newaxis, np.newaxis, :] + white_mask = 255 * (1 - mask) + img = color + white_mask + + # Write to response page + response[ymin:ymax, xmin:xmax, :] = np.array(img) + + return response diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py index 66d9d9a943..4e97f751fe 100644 --- a/doctr/utils/visualization.py +++ b/doctr/utils/visualization.py @@ -9,16 +9,12 @@ import cv2 import matplotlib.patches as patches import matplotlib.pyplot as plt -import mplcursors import numpy as np -from anyascii import anyascii from matplotlib.figure import Figure -from PIL import Image, ImageDraw from .common_types import BoundingBox, Polygon4P -from .fonts import get_font -__all__ = ["visualize_page", "synthesize_page", "visualize_kie_page", "synthesize_kie_page", "draw_boxes"] +__all__ = ["visualize_page", "visualize_kie_page", "draw_boxes"] def rect_patch( @@ -281,6 +277,8 @@ def visualize_page( artists.append(rect) if interactive: + import mplcursors + # Create mlp Cursor to hover patches in artists mplcursors.Cursor(artists, hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label())) fig.tight_layout(pad=0.0) @@ -288,63 +286,6 @@ def visualize_page( return fig -def synthesize_page( - page: Dict[str, Any], - draw_proba: bool = False, - font_family: Optional[str] = None, -) -> np.ndarray: - """Draw a the content of the element page (OCR response) on a blank page. - - Args: - ---- - page: exported Page object to represent - draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 - font_size: size of the font, default font = 13 - font_family: family of the font - - Returns: - ------- - the synthesized page - """ - # Draw template - h, w = page["dimensions"] - response = 255 * np.ones((h, w, 3), dtype=np.int32) - - # Draw each word - for block in page["blocks"]: - for line in block["lines"]: - for word in line["words"]: - # Get aboslute word geometry - (xmin, ymin), (xmax, ymax) = word["geometry"] - xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) - ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) - - # White drawing context adapted to font size, 0.75 factor to convert pts --> pix - font = get_font(font_family, int(0.75 * (ymax - ymin))) - img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) - d = ImageDraw.Draw(img) - # Draw in black the value of the word - try: - d.text((0, 0), word["value"], font=font, fill=(0, 0, 0)) - except UnicodeEncodeError: - # When character cannot be encoded, use its anyascii version - d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0)) - - # Colorize if draw_proba - if draw_proba: - p = int(255 * word["confidence"]) - mask = np.where(np.array(img) == 0, 1, 0) - proba: np.ndarray = np.array([255 - p, 0, p]) - color = mask * proba[np.newaxis, np.newaxis, :] - white_mask = 255 * (1 - mask) - img = color + white_mask - - # Write to response page - response[ymin:ymax, xmin:xmax, :] = np.array(img) - - return response - - def visualize_kie_page( page: Dict[str, Any], image: np.ndarray, @@ -413,6 +354,8 @@ def visualize_kie_page( artists.append(rect) if interactive: + import mplcursors + # Create mlp Cursor to hover patches in artists mplcursors.Cursor(artists, hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label())) fig.tight_layout(pad=0.0) @@ -420,62 +363,6 @@ def visualize_kie_page( return fig -def synthesize_kie_page( - page: Dict[str, Any], - draw_proba: bool = False, - font_family: Optional[str] = None, -) -> np.ndarray: - """Draw a the content of the element page (OCR response) on a blank page. - - Args: - ---- - page: exported Page object to represent - draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 - font_size: size of the font, default font = 13 - font_family: family of the font - - Returns: - ------- - the synthesized page - """ - # Draw template - h, w = page["dimensions"] - response = 255 * np.ones((h, w, 3), dtype=np.int32) - - # Draw each word - for predictions in page["predictions"].values(): - for prediction in predictions: - # Get aboslute word geometry - (xmin, ymin), (xmax, ymax) = prediction["geometry"] - xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) - ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) - - # White drawing context adapted to font size, 0.75 factor to convert pts --> pix - font = get_font(font_family, int(0.75 * (ymax - ymin))) - img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) - d = ImageDraw.Draw(img) - # Draw in black the value of the word - try: - d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0)) - except UnicodeEncodeError: - # When character cannot be encoded, use its anyascii version - d.text((0, 0), anyascii(prediction["value"]), font=font, fill=(0, 0, 0)) - - # Colorize if draw_proba - if draw_proba: - p = int(255 * prediction["confidence"]) - mask = np.where(np.array(img) == 0, 1, 0) - proba: np.ndarray = np.array([255 - p, 0, p]) - color = mask * proba[np.newaxis, np.newaxis, :] - white_mask = 255 * (1 - mask) - img = color + white_mask - - # Write to response page - response[ymin:ymax, xmin:xmax, :] = np.array(img) - - return response - - def draw_boxes(boxes: np.ndarray, image: np.ndarray, color: Optional[Tuple[int, int, int]] = None, **kwargs) -> None: """Draw an array of relative straight boxes on an image diff --git a/pyproject.toml b/pyproject.toml index 60d9a7cfeb..06c3655699 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ classifiers=[ ] dynamic = ["version"] dependencies = [ - "importlib_metadata", # For proper typing, mypy needs numpy>=1.20.0 (cf. https://github.com/numpy/numpy/pull/16515) # Additional typing support is brought by numpy>=1.22.4, but core build sticks to >=1.16.0 "numpy>=1.16.0,<2.0.0", @@ -45,11 +44,8 @@ dependencies = [ "langdetect>=1.0.9,<2.0.0", "rapidfuzz>=3.0.0,<4.0.0", "huggingface-hub>=0.20.0,<1.0.0", - "matplotlib>=3.1.0", - "weasyprint>=55.0", "Pillow>=9.2.0", "defusedxml>=0.7.0", - "mplcursors>=0.3", "anyascii>=0.3.2", "tqdm>=4.30.0", ] @@ -66,6 +62,13 @@ torch = [ "torchvision>=0.13.0", "onnx>=1.12.0,<3.0.0", ] +html = [ + "weasyprint>=55.0", +] +viz = [ + "matplotlib>=3.1.0", + "mplcursors>=0.3", +] testing = [ "pytest>=5.3.2", "coverage[toml]>=4.5.4", @@ -99,6 +102,10 @@ dev = [ "torch>=1.12.0,<3.0.0", "torchvision>=0.13.0", "onnx>=1.12.0,<3.0.0", + # Extras + "weasyprint>=55.0", + "matplotlib>=3.1.0", + "mplcursors>=0.3", # Testing "pytest>=5.3.2", "coverage[toml]>=4.5.4", diff --git a/references/requirements.txt b/references/requirements.txt index bb72bdd47b..679df79518 100644 --- a/references/requirements.txt +++ b/references/requirements.txt @@ -3,3 +3,4 @@ tqdm wandb>=0.10.31 psutil>=5.9.0 clearml>=1.11.1 +matplotlib>=3.1.0 diff --git a/tests/common/test_core.py b/tests/common/test_core.py index cd2758b0df..2fb8d1b9b0 100644 --- a/tests/common/test_core.py +++ b/tests/common/test_core.py @@ -1,6 +1,7 @@ import pytest import doctr +from doctr.file_utils import requires_package def test_version(): @@ -15,3 +16,9 @@ def test_is_tf_available(): @pytest.mark.skipif(doctr.is_torch_available() and doctr.is_tf_available(), reason="torch and tf are available") def test_is_torch_available(): assert not doctr.is_torch_available() + + +def test_requires_package(): + requires_package("numpy") # availbable + with pytest.raises(ImportError): # not available + requires_package("non_existent_package") diff --git a/tests/common/test_utils_reconstitution.py b/tests/common/test_utils_reconstitution.py new file mode 100644 index 0000000000..3b70e67070 --- /dev/null +++ b/tests/common/test_utils_reconstitution.py @@ -0,0 +1,12 @@ +import numpy as np +from test_io_elements import _mock_pages + +from doctr.utils import reconstitution + + +def test_synthesize_page(): + pages = _mock_pages() + reconstitution.synthesize_page(pages[0].export(), draw_proba=False) + render = reconstitution.synthesize_page(pages[0].export(), draw_proba=True) + assert isinstance(render, np.ndarray) + assert render.shape == (*pages[0].dimensions, 3) diff --git a/tests/common/test_utils_visualization.py b/tests/common/test_utils_visualization.py index b1e9998698..ae232ebb85 100644 --- a/tests/common/test_utils_visualization.py +++ b/tests/common/test_utils_visualization.py @@ -21,14 +21,6 @@ def test_visualize_page(): visualization.create_obj_patch((1, 2, 3, 4, 5), (100, 100)) -def test_synthesize_page(): - pages = _mock_pages() - visualization.synthesize_page(pages[0].export(), draw_proba=False) - render = visualization.synthesize_page(pages[0].export(), draw_proba=True) - assert isinstance(render, np.ndarray) - assert render.shape == (*pages[0].dimensions, 3) - - def test_draw_boxes(): image = np.ones((256, 256, 3), dtype=np.float32) boxes = [