Skip to content

Commit

Permalink
TLDR-462 - test on GPU work
Browse files Browse the repository at this point in the history
  • Loading branch information
NastyBoget committed Nov 9, 2023
1 parent 3dc8b61 commit 95dfbce
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 1 deletion.
12 changes: 12 additions & 0 deletions dedoc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import sys
from typing import Any, Optional

import GPUtil

logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s - %(pathname)s - %(levelname)s - %(message)s")

DEBUG_MODE = False
Expand All @@ -22,6 +24,10 @@
# number of parallel jobs in some tasks as OCR
n_jobs=1,

# --------------------------------------------GPU SETTINGS-------------------------------------------------------
# set gpu in XGBoost and torch models
on_gpu=False,

# ---------------------------------------------API SETTINGS---------------------------------------------------------
# max file size in bytes
max_content_length=512 * 1024 * 1024,
Expand Down Expand Up @@ -81,6 +87,12 @@ def __init_config(self, args: Optional[Any] = None) -> None:
else:
self.__config = _config

gpus = GPUtil.getGPUs()
if self.__config.get("on_gpu", False) and len(gpus) == 0:
logger = self.__config.get("logger", logging.getLogger())
logger.warning("No gpu device available! Changing configuration on_gpu to False!")
self.__config["on_gpu"] = False

def get_config(self, args: Optional[Any] = None) -> dict:
if self.__config is None or args is not None:
self.__init_config(args)
Expand Down
10 changes: 10 additions & 0 deletions dedoc/readers/pdf_reader/pdf_auto_reader/txtlayer_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pickle
from typing import List

import GPUtil
from xgboost import XGBClassifier

from dedoc.config import get_config
Expand Down Expand Up @@ -37,6 +38,15 @@ def __get_model(self) -> XGBClassifier:
with gzip.open(self.path, "rb") as f:
self.__model = pickle.load(f)

gpus = GPUtil.getGPUs()
if self.config.get("on_gpu", False) and len(gpus) > 0:
gpu_params = dict(predictor="gpu_predictor", tree_method="auto", gpu_id=0)
self.__model.set_params(**gpu_params)
self.__model.get_booster().set_param(gpu_params)
elif self.config.get("on_gpu", False) and len(gpus) == 0:
self.logger.warning("No gpu device availiable! Changing configuration on_gpu to False!")
self.config["on_gpu"] = False

return self.__model

def predict(self, lines: List[LineWithMeta]) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def _set_device(self, on_gpu: bool) -> None:
self.device = torch.device("cpu")
self.location = "cpu"

self.logger.warning(f"Classifier is set to device {self.device}")

def _load_weights(self, net: ClassificationModelTorch) -> None:
path_checkpoint = path.join(self.checkpoint_path, "scan_orientation_efficient_net_b0.pth")
if not path.isfile(path_checkpoint):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import gzip
import logging
import os
import pickle
from typing import List

import GPUtil
from xgboost import XGBClassifier

from dedoc.config import get_config
Expand All @@ -18,6 +20,7 @@ class ScanParagraphClassifierExtractor(object):

def __init__(self, *, config: dict) -> None:
super().__init__()
self.logger = config.get("logger", logging.getLogger())
self.path = os.path.join(get_config()["resources_path"], "paragraph_classifier.pkl.gz")
self.config = config
self._feature_extractor = None
Expand All @@ -44,6 +47,15 @@ def _unpickle(self) -> None:
self._classifier, parameters = pickle.load(file)
self._feature_extractor = ParagraphFeatureExtractor(**parameters, config=self.config)

gpus = GPUtil.getGPUs()
if self.config.get("on_gpu", False) and len(gpus) > 0:
gpu_params = dict(predictor="gpu_predictor", tree_method="auto", gpu_id=0)
self._classifier.set_params(**gpu_params)
self._classifier.get_booster().set_param(gpu_params)
elif self.config.get("on_gpu", False) and len(gpus) == 0:
self.logger.warning("No gpu device availiable! Changing configuration on_gpu to False!")
self.config["on_gpu"] = False

def extract(self, lines_with_links: List[LineWithLocation]) -> List[LineWithLocation]:
data = self.feature_extractor.transform([lines_with_links])
if any((data[col].isna().all() for col in data.columns)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(self, *, config: dict) -> None:
"""
super().__init__(config=config)
self.scew_corrector = SkewCorrector()
self.column_orientation_classifier = ColumnsOrientationClassifier(on_gpu=False, checkpoint_path=get_config()["resources_path"], config=config)
self.column_orientation_classifier = ColumnsOrientationClassifier(on_gpu=self.config.get("on_gpu", False),
checkpoint_path=get_config()["resources_path"], config=config)
self.binarizer = AdaptiveBinarizer()
self.ocr = OCRLineExtractor(config=config)
self.logger = config.get("logger", logging.getLogger())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import gzip
import logging
import os
import pickle
from abc import ABC
from typing import Tuple

import GPUtil
from xgboost import XGBClassifier

from dedoc.download_models import download_from_hub
Expand All @@ -14,6 +16,7 @@ class AbstractPickledLineTypeClassifier(AbstractLineTypeClassifier, ABC):

def __init__(self, *, config: dict) -> None:
super().__init__(config=config)
self.logger = self.config.get("logger", logging.getLogger())

def load(self, classifier_type: str, path: str) -> Tuple[XGBClassifier, dict]:
if not os.path.isfile(path):
Expand All @@ -22,6 +25,16 @@ def load(self, classifier_type: str, path: str) -> Tuple[XGBClassifier, dict]:

with gzip.open(path) as file:
classifier, feature_extractor_parameters = pickle.load(file)

gpus = GPUtil.getGPUs()
if self.config.get("on_gpu", False) and len(gpus) > 0:
gpu_params = dict(predictor="gpu_predictor", tree_method="auto", gpu_id=0)
classifier.set_params(**gpu_params)
classifier.get_booster().set_param(gpu_params)
elif self.config.get("on_gpu", False) and len(gpus) == 0:
self.logger.warning("No gpu device availiable! Changing configuration on_gpu to False!")
self.config["on_gpu"] = False

return classifier, feature_extractor_parameters

def save(self, path_out: str, parameters: object) -> str:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Cython>=0.29.28,<=3.0.2
docx==0.2.4
dedoc-utils==0.3.5
fastapi>=0.77.0,<=0.103.0
GPUtil>=1.4.0
huggingface-hub>=0.14.1,<=0.16.4
imutils==0.5.4
itsdangerous>=2.1.0,<=2.1.2
Expand Down
70 changes: 70 additions & 0 deletions tests/unit_tests/test_on_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os

import cv2
from dedocutils.data_structures import BBox

from dedoc.data_structures.line_metadata import LineMetadata
from dedoc.data_structures.line_with_meta import LineWithMeta
from dedoc.readers.pdf_reader.data_classes.line_with_location import LineWithLocation
from dedoc.readers.pdf_reader.data_classes.tables.location import Location
from dedoc.readers.pdf_reader.pdf_auto_reader.txtlayer_classifier import TxtlayerClassifier
from dedoc.readers.pdf_reader.pdf_image_reader.columns_orientation_classifier.columns_orientation_classifier import ColumnsOrientationClassifier
from dedoc.readers.pdf_reader.pdf_image_reader.paragraph_extractor.scan_paragraph_classifier_extractor import ScanParagraphClassifierExtractor
from dedoc.structure_extractors.concrete_structure_extractors.law_structure_excractor import LawStructureExtractor
from tests.api_tests.abstract_api_test import AbstractTestApiDocReader
from tests.test_utils import get_test_config


class TestOnGpu(AbstractTestApiDocReader):
config = dict(on_gpu=True, n_jobs=1)

def test_line_type_classifier(self) -> None:
"""
Loads AbstractPickledLineTypeClassifier
"""
law_extractor = LawStructureExtractor(config=self.config)
lines = [
LineWithMeta(" З А К О Н", metadata=LineMetadata(page_id=0, line_id=0)),
LineWithMeta("\n", metadata=LineMetadata(page_id=0, line_id=1)),
LineWithMeta(" ГОРОДА МОСКВЫ", metadata=LineMetadata(page_id=0, line_id=2))
]
predictions = law_extractor.classifier.predict(lines)
self.assertListEqual(predictions, ["header", "header", "cellar"])

def test_orientation_classifier(self) -> None:
checkpoint_path = get_test_config()["resources_path"]
orientation_classifier = ColumnsOrientationClassifier(on_gpu=self.config.get("on_gpu", False), checkpoint_path=checkpoint_path, config=self.config)
imgs_path = [f"../data/skew_corrector/rotated_{i}.jpg" for i in range(1, 5)]

for i in range(len(imgs_path)):
path = os.path.join(os.path.dirname(__file__), imgs_path[i])
image = cv2.imread(path)
_, orientation = orientation_classifier.predict(image)
self.assertEqual(orientation, 0)

def test_txtlayer_classifier(self) -> None:
classify_lines = TxtlayerClassifier(config=self.config)
lines = [LineWithMeta("Line1"), LineWithMeta("Line 2 is a bit longer")]
self.assertEqual(classify_lines.predict(lines), True)

def test_scan_paragraph_classifier_extractor(self) -> None:
classify_lines_with_location = ScanParagraphClassifierExtractor(config=self.config)
metadata = LineMetadata(page_id=1, line_id=1)
metadata2 = LineMetadata(page_id=1, line_id=2)
bbox = BBox(x_top_left=0, y_top_left=0, width=100, height=20)
bbox2 = BBox(x_top_left=50, y_top_left=50, width=100, height=20)
location = Location(page_number=1, bbox=bbox)
location2 = Location(page_number=1, bbox=bbox2)
lines = [
LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location),
LineWithLocation(line="Example line 2", metadata=metadata2, annotations=[], location=location2)
]
data = classify_lines_with_location.feature_extractor.transform([lines])

if any((data[col].isna().all() for col in data.columns)):
labels = ["not_paragraph"] * len(lines)
else:
labels = classify_lines_with_location.classifier.predict(data)

self.assertEqual(labels[0], "paragraph")
self.assertEqual(labels[1], "paragraph")

0 comments on commit 95dfbce

Please sign in to comment.