diff --git a/tests/unit_tests/test_on_gpu.py b/tests/unit_tests/test_on_gpu.py index 2aecb5e8..ea2001ae 100644 --- a/tests/unit_tests/test_on_gpu.py +++ b/tests/unit_tests/test_on_gpu.py @@ -1,7 +1,6 @@ import os import cv2 -import pandas as pd from dedocutils.data_structures import BBox from dedoc.data_structures.line_metadata import LineMetadata @@ -24,11 +23,11 @@ def test_line_type_classifier(self) -> None: Loads AbstractPickledLineTypeClassifier """ law_extractor = LawStructureExtractor(config=self.config) - - lines = [LineWithMeta(" З А К О Н", metadata=LineMetadata(page_id=0, line_id=0)), - LineWithMeta("\n", metadata=LineMetadata(page_id=0, line_id=1)), - LineWithMeta(" ГОРОДА МОСКВЫ", metadata=LineMetadata(page_id=0, line_id=2))] - features = law_extractor.classifier.feature_extractor.transform([lines]) + lines = [ + LineWithMeta(" З А К О Н", metadata=LineMetadata(page_id=0, line_id=0)), + LineWithMeta("\n", metadata=LineMetadata(page_id=0, line_id=1)), + LineWithMeta(" ГОРОДА МОСКВЫ", metadata=LineMetadata(page_id=0, line_id=2)) + ] predictions = law_extractor.classifier.predict(lines) self.assertListEqual(predictions, ["header", "header", "cellar"]) @@ -56,8 +55,10 @@ def test_scan_paragraph_classifier_extractor(self) -> None: bbox2 = BBox(x_top_left=50, y_top_left=50, width=100, height=20) location = Location(page_number=1, bbox=bbox) location2 = Location(page_number=1, bbox=bbox2) - lines = [LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location), - LineWithLocation(line="Example line 2", metadata=metadata2, annotations=[], location=location2)] + lines = [ + LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location), + LineWithLocation(line="Example line 2", metadata=metadata2, annotations=[], location=location2) + ] data = classify_lines_with_location.feature_extractor.transform([lines]) if any((data[col].isna().all() for col in data.columns)): @@ -67,6 +68,3 @@ def test_scan_paragraph_classifier_extractor(self) -> None: self.assertEqual(labels[0], "paragraph") self.assertEqual(labels[1], "paragraph") - - -