Skip to content

Commit

Permalink
TLDR-462 - review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
raxtemur authored and NastyBoget committed Nov 9, 2023
1 parent 0390649 commit c45b758
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 23 deletions.
9 changes: 2 additions & 7 deletions dedoc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,11 @@ def __init_config(self, args: Optional[Any] = None) -> None:
self.__config = config_module._config
else:
self.__config = _config

gpus = GPUtil.getGPUs()
if self.__config.get("on_gpu", False) and len(gpus) == 0:
self.__config["logger"].warning("No gpu device availiable! Changing configuration on_gpu to False!")
self.__config["on_gpu"] = False


gpus = GPUtil.getGPUs()
if self.__config.get("on_gpu", False) and len(gpus) == 0:
self.__config["logger"].warning("No gpu device availiable! Changing configuration on_gpu to False!")
logger = self.__config.get("logger", logging.getLogger())
logger.warning("No gpu device available! Changing configuration on_gpu to False!")
self.__config["on_gpu"] = False

def get_config(self, args: Optional[Any] = None) -> dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gzip
import logging
import os
import pickle
from typing import List
Expand All @@ -19,6 +20,7 @@ class ScanParagraphClassifierExtractor(object):

def __init__(self, *, config: dict) -> None:
super().__init__()
self.logger = config.get("logger", logging.getLogger())
self.path = os.path.join(get_config()["resources_path"], "paragraph_classifier.pkl.gz")
self.config = config
self._feature_extractor = None
Expand Down
39 changes: 23 additions & 16 deletions tests/unit_tests/test_on_gpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os

import cv2
Expand All @@ -11,30 +10,26 @@
from dedoc.readers.pdf_reader.pdf_auto_reader.txtlayer_classifier import TxtlayerClassifier
from dedoc.readers.pdf_reader.pdf_image_reader.columns_orientation_classifier.columns_orientation_classifier import ColumnsOrientationClassifier
from dedoc.readers.pdf_reader.pdf_image_reader.paragraph_extractor.scan_paragraph_classifier_extractor import ScanParagraphClassifierExtractor
from dedoc.readers.txt_reader.raw_text_reader import RawTextReader
from dedoc.structure_extractors.concrete_structure_extractors.law_structure_excractor import LawStructureExtractor
from tests.api_tests.abstract_api_test import AbstractTestApiDocReader
from tests.test_utils import get_test_config


class TestOnGpu(AbstractTestApiDocReader):
config = dict(on_gpu=True,
n_jobs=1)
logger = logging.getLogger()
config = dict(on_gpu=True, n_jobs=1)

def test_line_type_classifier(self) -> None:
"""
Loads AbstractPickledLineTypeClassifier
"""
txt_reader = RawTextReader(config=self.config)
law_extractor = LawStructureExtractor(config=self.config)

path = os.path.join(self.data_directory_path, "laws", "коап_москвы_8_7_2015_utf.txt")
document = txt_reader.read(path=path, document_type="law", parameters={})
document = law_extractor.extract_structure(document, {})

self.assertListEqual([], document.attachments)
self.assertListEqual([], document.tables)
lines = [
LineWithMeta(" З А К О Н", metadata=LineMetadata(page_id=0, line_id=0)),
LineWithMeta("\n", metadata=LineMetadata(page_id=0, line_id=1)),
LineWithMeta(" ГОРОДА МОСКВЫ", metadata=LineMetadata(page_id=0, line_id=2))
]
predictions = law_extractor.classifier.predict(lines)
self.assertListEqual(predictions, ["header", "header", "cellar"])

def test_orientation_classifier(self) -> None:
checkpoint_path = get_test_config()["resources_path"]
Expand All @@ -55,9 +50,21 @@ def test_txtlayer_classifier(self) -> None:
def test_scan_paragraph_classifier_extractor(self) -> None:
classify_lines_with_location = ScanParagraphClassifierExtractor(config=self.config)
metadata = LineMetadata(page_id=1, line_id=1)

metadata2 = LineMetadata(page_id=1, line_id=2)
bbox = BBox(x_top_left=0, y_top_left=0, width=100, height=20)
bbox2 = BBox(x_top_left=50, y_top_left=50, width=100, height=20)
location = Location(page_number=1, bbox=bbox)
lines = [LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location)]
location2 = Location(page_number=1, bbox=bbox2)
lines = [
LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location),
LineWithLocation(line="Example line 2", metadata=metadata2, annotations=[], location=location2)
]
data = classify_lines_with_location.feature_extractor.transform([lines])

if any((data[col].isna().all() for col in data.columns)):
labels = ["not_paragraph"] * len(lines)
else:
labels = classify_lines_with_location.classifier.predict(data)

self.assertEqual(classify_lines_with_location.extract(lines), lines)
self.assertEqual(labels[0], "paragraph")
self.assertEqual(labels[1], "paragraph")

0 comments on commit c45b758

Please sign in to comment.