Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLDR-462 gpu for 1.1 #365

Merged
merged 11 commits into from
Nov 10, 2023
8 changes: 8 additions & 0 deletions dedoc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import sys
from typing import Any, Optional

from dedoc.utils.parameter_utils import get_param_gpu_available

logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s - %(pathname)s - %(levelname)s - %(message)s")

DEBUG_MODE = False
Expand All @@ -22,6 +24,10 @@
# number of parallel jobs in some tasks, such as OCR
n_jobs=1,

# --------------------------------------------GPU SETTINGS-------------------------------------------------------
# run XGBoost and torch models on GPU (reset to False at config initialisation when no GPU is detected)
on_gpu=False,

# ---------------------------------------------API SETTINGS---------------------------------------------------------
# max file size in bytes
max_content_length=512 * 1024 * 1024,
Expand Down Expand Up @@ -81,6 +87,8 @@ def __init_config(self, args: Optional[Any] = None) -> None:
else:
self.__config = _config

get_param_gpu_available(self.__config, self.__config.get("logger", logging.getLogger()))

def get_config(self, args: Optional[Any] = None) -> dict:
if self.__config is None or args is not None:
self.__init_config(args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dedoc.data_structures import LineWithMeta
from dedoc.download_models import download_from_hub
from dedoc.readers.pdf_reader.pdf_auto_reader.txtlayer_feature_extractor import TxtlayerFeatureExtractor
from dedoc.utils.parameter_utils import get_param_gpu_available


class TxtlayerClassifier:
Expand Down Expand Up @@ -37,6 +38,11 @@ def __get_model(self) -> XGBClassifier:
with gzip.open(self.path, "rb") as f:
self.__model = pickle.load(f)

if get_param_gpu_available(self.config, self.logger):
gpu_params = dict(predictor="gpu_predictor", tree_method="auto", gpu_id=0)
self.__model.set_params(**gpu_params)
self.__model.get_booster().set_param(gpu_params)

return self.__model

def predict(self, lines: List[LineWithMeta]) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def _set_device(self, on_gpu: bool) -> None:
self.device = torch.device("cpu")
self.location = "cpu"

self.logger.warning(f"Classifier is set to device {self.device}")

def _load_weights(self, net: ClassificationModelTorch) -> None:
path_checkpoint = path.join(self.checkpoint_path, "scan_orientation_efficient_net_b0.pth")
if not path.isfile(path_checkpoint):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gzip
import logging
import os
import pickle
from typing import List
Expand All @@ -9,6 +10,7 @@
from dedoc.download_models import download_from_hub
from dedoc.readers.pdf_reader.data_classes.line_with_location import LineWithLocation
from dedoc.readers.pdf_reader.pdf_image_reader.paragraph_extractor.paragraph_features import ParagraphFeatureExtractor
from dedoc.utils.parameter_utils import get_param_gpu_available


class ScanParagraphClassifierExtractor(object):
Expand All @@ -18,6 +20,7 @@ class ScanParagraphClassifierExtractor(object):

def __init__(self, *, config: dict) -> None:
super().__init__()
self.logger = config.get("logger", logging.getLogger())
self.path = os.path.join(get_config()["resources_path"], "paragraph_classifier.pkl.gz")
self.config = config
self._feature_extractor = None
Expand All @@ -44,6 +47,13 @@ def _unpickle(self) -> None:
self._classifier, parameters = pickle.load(file)
self._feature_extractor = ParagraphFeatureExtractor(**parameters, config=self.config)

if get_param_gpu_available(self.config, self.logger):
gpu_params = dict(predictor="gpu_predictor", tree_method="auto", gpu_id=0)
self._classifier.set_params(**gpu_params)
self._classifier.get_booster().set_param(gpu_params)

return self._classifier

def extract(self, lines_with_links: List[LineWithLocation]) -> List[LineWithLocation]:
data = self.feature_extractor.transform([lines_with_links])
if any((data[col].isna().all() for col in data.columns)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(self, *, config: dict) -> None:
"""
super().__init__(config=config)
self.scew_corrector = SkewCorrector()
self.column_orientation_classifier = ColumnsOrientationClassifier(on_gpu=False, checkpoint_path=get_config()["resources_path"], config=config)
self.column_orientation_classifier = ColumnsOrientationClassifier(on_gpu=self.config.get("on_gpu", False),
checkpoint_path=get_config()["resources_path"], config=config)
self.binarizer = AdaptiveBinarizer()
self.ocr = OCRLineExtractor(config=config)
self.logger = config.get("logger", logging.getLogger())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import gzip
import logging
import os
import pickle
from abc import ABC
from typing import Tuple

from xgboost import XGBClassifier


from dedoc.download_models import download_from_hub
from dedoc.structure_extractors.line_type_classifiers.abstract_line_type_classifier import AbstractLineTypeClassifier
from dedoc.utils.parameter_utils import get_param_gpu_available


class AbstractPickledLineTypeClassifier(AbstractLineTypeClassifier, ABC):

def __init__(self, *, config: dict) -> None:
super().__init__(config=config)
self.logger = self.config.get("logger", logging.getLogger())

def load(self, classifier_type: str, path: str) -> Tuple[XGBClassifier, dict]:
if not os.path.isfile(path):
Expand All @@ -22,6 +26,12 @@ def load(self, classifier_type: str, path: str) -> Tuple[XGBClassifier, dict]:

with gzip.open(path) as file:
classifier, feature_extractor_parameters = pickle.load(file)

if get_param_gpu_available(self.config, self.logger):
gpu_params = dict(predictor="gpu_predictor", tree_method="auto", gpu_id=0)
classifier.set_params(**gpu_params)
classifier.get_booster().set_param(gpu_params)

return classifier, feature_extractor_parameters

def save(self, path_out: str, parameters: object) -> str:
Expand Down
40 changes: 36 additions & 4 deletions dedoc/utils/parameter_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import subprocess
from logging import Logger
from typing import Any, Dict, Optional, Tuple


Expand All @@ -17,21 +19,27 @@ def get_param_language(parameters: Optional[dict]) -> str:
def get_param_orient_analysis_cells(parameters: Optional[dict]) -> bool:
    """
    Read the "orient_analysis_cells" flag from the parameters.

    Args:
        parameters (Optional[dict]): Parameters dictionary, or None.

    Returns:
        bool: True only when the value, compared case-insensitively, equals "true".
        False when parameters is None or the key is missing.
    """
    if parameters is None:
        return False
    # str() lets a real bool (or any other non-string value) be handled instead of failing on .lower()
    return str(parameters.get("orient_analysis_cells", "False")).lower() == "true"


def get_param_need_header_footers_analysis(parameters: Optional[dict]) -> bool:
    """
    Read the "need_header_footer_analysis" flag from the parameters.

    Args:
        parameters (Optional[dict]): Parameters dictionary, or None.

    Returns:
        bool: True only when the value, compared case-insensitively, equals "true".
        False when parameters is None or the key is missing.
    """
    if parameters is None:
        return False
    # str() lets a real bool (or any other non-string value) be handled instead of failing on .lower()
    return str(parameters.get("need_header_footer_analysis", "False")).lower() == "true"


def get_param_need_pdf_table_analysis(parameters: Optional[dict]) -> bool:
    """
    Read the "need_pdf_table_analysis" flag from the parameters.

    Args:
        parameters (Optional[dict]): Parameters dictionary, or None.

    Returns:
        bool: True when the value, compared case-insensitively, equals "true".
        A missing key defaults to True, but parameters being None gives False.
    """
    if parameters is None:
        return False
    # str() lets a real bool (or any other non-string value) be handled instead of failing on .lower()
    return str(parameters.get("need_pdf_table_analysis", "True")).lower() == "true"


Expand Down Expand Up @@ -96,7 +104,6 @@ def get_param_image_document_page(parameters: Optional[dict]) -> str:


def get_param_table_type(parameters: Optional[dict]) -> str:

if parameters is None:
return ""

Expand All @@ -119,3 +126,28 @@ def get_param_page_slice(parameters: Dict[str, Any]) -> Tuple[Optional[int], Opt
return first_page, last_page
except Exception:
raise ValueError(f"Error input parameter 'pages'. Bad page limit {pages}")


def get_param_gpu_available(parameters: Optional[dict], logger: Logger) -> bool:
    """
    Check if GPU usage is requested and a GPU is available, and update the configuration accordingly.

    Availability is probed by running the `nvidia-smi` command. When the command is missing, cannot be
    started, or exits with a non-zero code, a warning is logged and `parameters["on_gpu"]` is set to False
    in place.

    Args:
        parameters (Optional[dict]): A dictionary containing the parameters for the function. Usually supposed to be config.
            None is treated as "GPU not requested".
        logger (Logger): An instance of the logger used to report the fallback.

    Returns:
        bool: True if "on_gpu" is set and the `nvidia-smi` probe succeeded, False otherwise.
    """
    # the signature allows None, so guard before calling .get
    if parameters is None or not parameters.get("on_gpu", False):
        return False

    try:
        subprocess.run(["nvidia-smi"], check=True, stdout=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (no such binary) as well as PermissionError and other launch failures
        logger.warning("No gpu device available! Changing configuration on_gpu to False!")
        parameters["on_gpu"] = False
        return False

    return True
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Cython>=0.29.28,<=3.0.2
docx==0.2.4
dedoc-utils==0.3.5
fastapi>=0.77.0,<=0.103.0
GPUtil>=1.4.0
NastyBoget marked this conversation as resolved.
Show resolved Hide resolved
huggingface-hub>=0.14.1,<=0.16.4
imutils==0.5.4
itsdangerous>=2.1.0,<=2.1.2
Expand Down
65 changes: 65 additions & 0 deletions tests/unit_tests/test_misc_on_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os

import cv2
from dedocutils.data_structures import BBox

from dedoc.data_structures.line_metadata import LineMetadata
from dedoc.data_structures.line_with_meta import LineWithMeta
from dedoc.readers.pdf_reader.data_classes.line_with_location import LineWithLocation
from dedoc.readers.pdf_reader.data_classes.tables.location import Location
from dedoc.readers.pdf_reader.pdf_auto_reader.txtlayer_classifier import TxtlayerClassifier
from dedoc.readers.pdf_reader.pdf_image_reader.columns_orientation_classifier.columns_orientation_classifier import ColumnsOrientationClassifier
from dedoc.readers.pdf_reader.pdf_image_reader.paragraph_extractor.scan_paragraph_classifier_extractor import ScanParagraphClassifierExtractor
from dedoc.structure_extractors.concrete_structure_extractors.law_structure_excractor import LawStructureExtractor
from tests.api_tests.abstract_api_test import AbstractTestApiDocReader
from tests.test_utils import get_test_config


class TestOnGpu(AbstractTestApiDocReader):
    """
    Build each model that accepts the on_gpu setting with on_gpu=True and check its predictions on small inputs.
    """
    config = {"on_gpu": True, "n_jobs": 1}

    def test_line_type_classifier(self) -> None:
        """
        Loads AbstractPickledLineTypeClassifier (through the classifier of LawStructureExtractor)
        """
        texts = [" З А К О Н", "\n", " ГОРОДА МОСКВЫ"]
        classifier = LawStructureExtractor(config=self.config).classifier
        lines = [LineWithMeta(text, metadata=LineMetadata(page_id=0, line_id=line_id)) for line_id, text in enumerate(texts)]
        self.assertListEqual(classifier.predict(lines), ["header", "header", "cellar"])

    def test_orientation_classifier(self) -> None:
        """
        Every rotated_1..rotated_4 sample image is expected to get orientation 0
        """
        resources_path = get_test_config()["resources_path"]
        use_gpu = self.config.get("on_gpu", False)
        classifier = ColumnsOrientationClassifier(on_gpu=use_gpu, checkpoint_path=resources_path, config=self.config)
        tests_dir = os.path.dirname(__file__)

        for img_number in range(1, 5):
            img_path = os.path.join(tests_dir, f"../data/skew_corrector/rotated_{img_number}.jpg")
            _, orientation = classifier.predict(cv2.imread(img_path))
            self.assertEqual(orientation, 0)

    def test_txtlayer_classifier(self) -> None:
        """
        Two short text lines are expected to be classified as True
        """
        classifier = TxtlayerClassifier(config=self.config)
        lines = [LineWithMeta(text) for text in ("Line1", "Line 2 is a bit longer")]
        self.assertEqual(classifier.predict(lines), True)

    def test_scan_paragraph_classifier_extractor(self) -> None:
        """
        Two located lines on page 1 are expected to come back with can_be_multiline equal to False
        """
        extractor = ScanParagraphClassifierExtractor(config=self.config)
        # (text, line_id, top-left corner of the bounding box); every box is 100x20
        line_specs = [("Example line", 1, (0, 0)), ("Example line 2", 2, (50, 50))]

        lines = []
        for text, line_id, (x, y) in line_specs:
            bbox = BBox(x_top_left=x, y_top_left=y, width=100, height=20)
            lines.append(LineWithLocation(line=text,
                                          metadata=LineMetadata(page_id=1, line_id=line_id),
                                          annotations=[],
                                          location=Location(page_number=1, bbox=bbox)))

        extracted = extractor.extract(lines)
        for index in (0, 1):
            self.assertEqual(extracted[index].metadata.tag_hierarchy_level.can_be_multiline, False)
Loading