Skip to content

Commit

Permalink
TLDR-462 -- style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
raxtemur committed Nov 8, 2023
1 parent 2fdb601 commit 8e508f3
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions tests/unit_tests/test_on_gpu.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import cv2
import pandas as pd
from dedocutils.data_structures import BBox

from dedoc.data_structures.line_metadata import LineMetadata
Expand All @@ -24,11 +23,11 @@ def test_line_type_classifier(self) -> None:
Loads AbstractPickledLineTypeClassifier
"""
law_extractor = LawStructureExtractor(config=self.config)

lines = [LineWithMeta(" З А К О Н", metadata=LineMetadata(page_id=0, line_id=0)),
LineWithMeta("\n", metadata=LineMetadata(page_id=0, line_id=1)),
LineWithMeta(" ГОРОДА МОСКВЫ", metadata=LineMetadata(page_id=0, line_id=2))]
features = law_extractor.classifier.feature_extractor.transform([lines])
lines = [
LineWithMeta(" З А К О Н", metadata=LineMetadata(page_id=0, line_id=0)),
LineWithMeta("\n", metadata=LineMetadata(page_id=0, line_id=1)),
LineWithMeta(" ГОРОДА МОСКВЫ", metadata=LineMetadata(page_id=0, line_id=2))
]
predictions = law_extractor.classifier.predict(lines)
self.assertListEqual(predictions, ["header", "header", "cellar"])

Expand Down Expand Up @@ -56,8 +55,10 @@ def test_scan_paragraph_classifier_extractor(self) -> None:
bbox2 = BBox(x_top_left=50, y_top_left=50, width=100, height=20)
location = Location(page_number=1, bbox=bbox)
location2 = Location(page_number=1, bbox=bbox2)
lines = [LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location),
LineWithLocation(line="Example line 2", metadata=metadata2, annotations=[], location=location2)]
lines = [
LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location),
LineWithLocation(line="Example line 2", metadata=metadata2, annotations=[], location=location2)
]
data = classify_lines_with_location.feature_extractor.transform([lines])

if any((data[col].isna().all() for col in data.columns)):
Expand All @@ -67,6 +68,3 @@ def test_scan_paragraph_classifier_extractor(self) -> None:

self.assertEqual(labels[0], "paragraph")
self.assertEqual(labels[1], "paragraph")



0 comments on commit 8e508f3

Please sign in to comment.