diff --git a/dedoc/config.py b/dedoc/config.py index c220d035..0bb15ce3 100644 --- a/dedoc/config.py +++ b/dedoc/config.py @@ -86,16 +86,11 @@ def __init_config(self, args: Optional[Any] = None) -> None: self.__config = config_module._config else: self.__config = _config - - gpus = GPUtil.getGPUs() - if self.__config.get("on_gpu", False) and len(gpus) == 0: - self.__config["logger"].warning("No gpu device availiable! Changing configuration on_gpu to False!") - self.__config["on_gpu"] = False - gpus = GPUtil.getGPUs() if self.__config.get("on_gpu", False) and len(gpus) == 0: - self.__config["logger"].warning("No gpu device availiable! Changing configuration on_gpu to False!") + logger = self.__config.get("logger", logging.getLogger()) + logger.warning("No gpu device available! Changing configuration on_gpu to False!") self.__config["on_gpu"] = False def get_config(self, args: Optional[Any] = None) -> dict: diff --git a/dedoc/readers/pdf_reader/pdf_image_reader/paragraph_extractor/scan_paragraph_classifier_extractor.py b/dedoc/readers/pdf_reader/pdf_image_reader/paragraph_extractor/scan_paragraph_classifier_extractor.py index c54c2f89..9b6e5ba6 100644 --- a/dedoc/readers/pdf_reader/pdf_image_reader/paragraph_extractor/scan_paragraph_classifier_extractor.py +++ b/dedoc/readers/pdf_reader/pdf_image_reader/paragraph_extractor/scan_paragraph_classifier_extractor.py @@ -1,4 +1,5 @@ import gzip +import logging import os import pickle from typing import List @@ -19,6 +20,7 @@ class ScanParagraphClassifierExtractor(object): def __init__(self, *, config: dict) -> None: super().__init__() + self.logger = config.get("logger", logging.getLogger()) self.path = os.path.join(get_config()["resources_path"], "paragraph_classifier.pkl.gz") self.config = config self._feature_extractor = None diff --git a/tests/unit_tests/test_on_gpu.py b/tests/unit_tests/test_on_gpu.py index 36c8d89d..ea2001ae 100644 --- a/tests/unit_tests/test_on_gpu.py +++ b/tests/unit_tests/test_on_gpu.py @@ -1,4 +1,3 @@ -import logging import os import cv2 @@ -11,30 +10,26 @@ from dedoc.readers.pdf_reader.pdf_auto_reader.txtlayer_classifier import TxtlayerClassifier from dedoc.readers.pdf_reader.pdf_image_reader.columns_orientation_classifier.columns_orientation_classifier import ColumnsOrientationClassifier from dedoc.readers.pdf_reader.pdf_image_reader.paragraph_extractor.scan_paragraph_classifier_extractor import ScanParagraphClassifierExtractor -from dedoc.readers.txt_reader.raw_text_reader import RawTextReader from dedoc.structure_extractors.concrete_structure_extractors.law_structure_excractor import LawStructureExtractor from tests.api_tests.abstract_api_test import AbstractTestApiDocReader from tests.test_utils import get_test_config class TestOnGpu(AbstractTestApiDocReader): - config = dict(on_gpu=True, - n_jobs=1) - logger = logging.getLogger() + config = dict(on_gpu=True, n_jobs=1) def test_line_type_classifier(self) -> None: """ Loads AbstractPickledLineTypeClassifier """ - txt_reader = RawTextReader(config=self.config) law_extractor = LawStructureExtractor(config=self.config) - - path = os.path.join(self.data_directory_path, "laws", "коап_москвы_8_7_2015_utf.txt") - document = txt_reader.read(path=path, document_type="law", parameters={}) - document = law_extractor.extract_structure(document, {}) - - self.assertListEqual([], document.attachments) - self.assertListEqual([], document.tables) + lines = [ + LineWithMeta(" З А К О Н", metadata=LineMetadata(page_id=0, line_id=0)), + LineWithMeta("\n", metadata=LineMetadata(page_id=0, line_id=1)), + LineWithMeta(" ГОРОДА МОСКВЫ", metadata=LineMetadata(page_id=0, line_id=2)) + ] + predictions = law_extractor.classifier.predict(lines) + self.assertListEqual(predictions, ["header", "header", "cellar"]) def test_orientation_classifier(self) -> None: checkpoint_path = get_test_config()["resources_path"] @@ -55,9 +50,21 @@ def test_txtlayer_classifier(self) -> None: def test_scan_paragraph_classifier_extractor(self) -> None: classify_lines_with_location = ScanParagraphClassifierExtractor(config=self.config) metadata = LineMetadata(page_id=1, line_id=1) - + metadata2 = LineMetadata(page_id=1, line_id=2) bbox = BBox(x_top_left=0, y_top_left=0, width=100, height=20) + bbox2 = BBox(x_top_left=50, y_top_left=50, width=100, height=20) location = Location(page_number=1, bbox=bbox) - lines = [LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location)] + location2 = Location(page_number=1, bbox=bbox2) + lines = [ + LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location), + LineWithLocation(line="Example line 2", metadata=metadata2, annotations=[], location=location2) + ] + data = classify_lines_with_location.feature_extractor.transform([lines]) + + if any((data[col].isna().all() for col in data.columns)): + labels = ["not_paragraph"] * len(lines) + else: + labels = classify_lines_with_location.classifier.predict(data) - self.assertEqual(classify_lines_with_location.extract(lines), lines) + self.assertEqual(labels[0], "paragraph") + self.assertEqual(labels[1], "paragraph")