diff --git a/dedoc/config.py b/dedoc/config.py index 3932b3fb..1e7cacc0 100644 --- a/dedoc/config.py +++ b/dedoc/config.py @@ -21,9 +21,9 @@ # --------------------------------------------JOBLIB SETTINGS------------------------------------------------------- # number of parallel jobs in some tasks as OCR n_jobs=1, - + # --------------------------------------------GPU SETTINGS------------------------------------------------------- - # set gpu in XGBoost and torch models + # set gpu in XGBoost and torch models on_gpu=False, # ---------------------------------------------API SETTINGS--------------------------------------------------------- diff --git a/dedoc/readers/pdf_reader/pdf_auto_reader/txtlayer_classifier.py b/dedoc/readers/pdf_reader/pdf_auto_reader/txtlayer_classifier.py index b91ce2b8..56386bdd 100644 --- a/dedoc/readers/pdf_reader/pdf_auto_reader/txtlayer_classifier.py +++ b/dedoc/readers/pdf_reader/pdf_auto_reader/txtlayer_classifier.py @@ -38,10 +38,9 @@ def __get_model(self) -> XGBClassifier: self.__model = pickle.load(f) if self.config.get("on_gpu", False): - self.__model.set_params(predictor="gpu_predictor", tree_method='auto', n_gpus=1, gpu_id=0) + self.__model.set_params(predictor="gpu_predictor", tree_method="auto", n_gpus=1, gpu_id=0) self.__model.get_booster().set_param(self.__model.get_params()) - return self.__model def predict(self, lines: List[LineWithMeta]) -> bool: diff --git a/dedoc/readers/pdf_reader/pdf_image_reader/paragraph_extractor/scan_paragraph_classifier_extractor.py b/dedoc/readers/pdf_reader/pdf_image_reader/paragraph_extractor/scan_paragraph_classifier_extractor.py index dbad92c6..221dea3e 100644 --- a/dedoc/readers/pdf_reader/pdf_image_reader/paragraph_extractor/scan_paragraph_classifier_extractor.py +++ b/dedoc/readers/pdf_reader/pdf_image_reader/paragraph_extractor/scan_paragraph_classifier_extractor.py @@ -45,7 +45,7 @@ def _unpickle(self) -> None: self._feature_extractor = ParagraphFeatureExtractor(**parameters, config=self.config) if self.config.get("on_gpu", False): - self._classifier.set_params(predictor="gpu_predictor", tree_method='auto', n_gpus=1, gpu_id=0) + self._classifier.set_params(predictor="gpu_predictor", tree_method="auto", n_gpus=1, gpu_id=0) self._classifier.get_booster().set_param(self._classifier.get_params()) def extract(self, lines_with_links: List[LineWithLocation]) -> List[LineWithLocation]: diff --git a/tests/unit_tests/test_my_gpu_tests.py b/tests/unit_tests/test_my_gpu_tests.py index e80b43da..652d4578 100644 --- a/tests/unit_tests/test_my_gpu_tests.py +++ b/tests/unit_tests/test_my_gpu_tests.py @@ -7,15 +7,14 @@ from dedoc.metadata_extractors.concrete_metadata_extractors.base_metadata_extractor import BaseMetadataExtractor from dedoc.readers.txt_reader.raw_text_reader import RawTextReader from dedoc.structure_extractors.concrete_structure_extractors.law_structure_excractor import LawStructureExtractor -from tests.api_tests.abstract_api_test import AbstractTestApiDocReader from dedoc.readers.pdf_reader.pdf_image_reader.columns_orientation_classifier.columns_orientation_classifier import ColumnsOrientationClassifier - +from tests.api_tests.abstract_api_test import AbstractTestApiDocReader from tests.test_utils import get_test_config + @unittest.skip("Should load gpu") class MyGPUTests(AbstractTestApiDocReader): config = dict(on_gpu=True) - def _get_abs_path(self, file_name: str) -> str: return os.path.join(self.data_directory_path, "laws", file_name) @@ -37,7 +36,6 @@ def test_law_document_spaces_correctness(self) -> None: self.assertListEqual([], document.attachments) self.assertListEqual([], document.tables) - def test_skew_corrector(self) -> None: checkpoint_path = get_test_config()["resources_path"] orientation_classifier = ColumnsOrientationClassifier(on_gpu=self.config.get("on_gpu", False), checkpoint_path=checkpoint_path, config=self.config)