Skip to content

Commit

Permalink
TLDR-462 -- test_scan_paragraph_classifier_extractor finally works
Browse files Browse the repository at this point in the history
*now it uses gpu: before, for 1 line it automatically marked it like a non_paragraph
  • Loading branch information
raxtemur committed Nov 8, 2023
1 parent bd25cc9 commit 2fdb601
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions tests/unit_tests/test_on_gpu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import cv2
import pandas as pd
from dedocutils.data_structures import BBox

from dedoc.data_structures.line_metadata import LineMetadata
Expand All @@ -24,9 +25,11 @@ def test_line_type_classifier(self) -> None:
"""
law_extractor = LawStructureExtractor(config=self.config)

lines = [LineWithMeta(" З А К О Н"), LineWithMeta("\n"), LineWithMeta(" ГОРОДА МОСКВЫ")]
lines = [LineWithMeta(" З А К О Н", metadata=LineMetadata(page_id=0, line_id=0)),
LineWithMeta("\n", metadata=LineMetadata(page_id=0, line_id=1)),
LineWithMeta(" ГОРОДА МОСКВЫ", metadata=LineMetadata(page_id=0, line_id=2))]
features = law_extractor.classifier.feature_extractor.transform([lines])
predictions = law_extractor.classifier.predict(lines)

self.assertListEqual(predictions, ["header", "header", "cellar"])

def test_orientation_classifier(self) -> None:
Expand All @@ -48,14 +51,22 @@ def test_txtlayer_classifier(self) -> None:
def test_scan_paragraph_classifier_extractor(self) -> None:
classify_lines_with_location = ScanParagraphClassifierExtractor(config=self.config)
metadata = LineMetadata(page_id=1, line_id=1)

metadata2 = LineMetadata(page_id=1, line_id=2)
bbox = BBox(x_top_left=0, y_top_left=0, width=100, height=20)
bbox2 = BBox(x_top_left=50, y_top_left=50, width=100, height=20)
location = Location(page_number=1, bbox=bbox)
lines = [LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location)]
location2 = Location(page_number=1, bbox=bbox2)
lines = [LineWithLocation(line="Example line", metadata=metadata, annotations=[], location=location),
LineWithLocation(line="Example line 2", metadata=metadata2, annotations=[], location=location2)]
data = classify_lines_with_location.feature_extractor.transform([lines])

if any((data[col].isna().all() for col in data.columns)):
labels = ["not_paragraph"] * len(lines)
else:
labels = classify_lines_with_location.classifier.predict(data)

self.assertEqual(labels, ["not_paragraph", "not_paragraph", "not_paragraph"])
self.assertEqual(labels[0], "paragraph")
self.assertEqual(labels[1], "paragraph")



0 comments on commit 2fdb601

Please sign in to comment.