diff --git a/marker/v2/builders/layout.py b/marker/v2/builders/layout.py index 92a6206..72d7718 100644 --- a/marker/v2/builders/layout.py +++ b/marker/v2/builders/layout.py @@ -5,12 +5,11 @@ from marker.settings import settings from marker.v2.builders import BaseBuilder -from marker.v2.providers.pdf import PdfProvider +from marker.v2.providers.pdf import PdfPageProviderLines, PdfProvider from marker.v2.schema.blocks import LAYOUT_BLOCK_REGISTRY, Block, Text from marker.v2.schema.document import Document from marker.v2.schema.groups.page import PageGroup from marker.v2.schema.polygon import PolygonBox -from marker.v2.schema.text.line import Line class LayoutBuilder(BaseBuilder): @@ -24,7 +23,7 @@ def __init__(self, layout_model, config=None): def __call__(self, document: Document, provider: PdfProvider): layout_results = self.surya_layout(document.pages) self.add_blocks_to_pages(document.pages, layout_results) - self.merge_blocks(document.pages, provider) + self.merge_blocks(document.pages, provider.page_lines) def get_batch_size(self): if self.batch_size is not None: @@ -53,12 +52,11 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size) page.add_structure(layout_block) - def merge_blocks(self, document_pages: List[PageGroup], provider: PdfProvider): - provider_page_lines = provider.page_lines - for idx, (document_page, provider_lines) in enumerate(zip(document_pages, provider_page_lines.values())): - all_line_idxs = set(range(len(provider_lines))) + def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: PdfPageProviderLines): + for document_page, provider_lines in zip(document_pages, provider_page_lines.values()): + provider_line_idxs = set(range(len(provider_lines))) max_intersections = {} - for line_idx, line in enumerate(provider_lines): + for line_idx, (line, spans) in enumerate(zip(*provider_lines)): for block_idx, block in enumerate(document_page.children): intersection_pct = line.polygon.intersection_pct(block.polygon) if line_idx not in max_intersections: @@ -67,18 +65,21 @@ def merge_blocks(self, document_pages: List[PageGroup], provider: PdfProvider): max_intersections[line_idx] = (intersection_pct, block_idx) assigned_line_idxs = set() - for line_idx, line in enumerate(provider_lines): + for line_idx, (line, spans) in enumerate(zip(*provider_lines)): if line_idx in max_intersections and max_intersections[line_idx][0] > 0.0: document_page.add_full_block(line) block_idx = max_intersections[line_idx][1] block: Block = document_page.children[block_idx] block.add_structure(line) assigned_line_idxs.add(line_idx) + for span in spans: + document_page.add_full_block(span) + line.add_structure(span) - for line_idx in all_line_idxs.difference(assigned_line_idxs): + for line_idx in provider_line_idxs.difference(assigned_line_idxs): min_dist = None min_dist_idx = None - line: Line = provider_lines[line_idx] + (line, spans) = provider_lines[line_idx] for block_idx, block in enumerate(document_page.children): dist = line.polygon.center_distance(block.polygon) if min_dist_idx is None or dist < min_dist: @@ -90,10 +91,15 @@ def merge_blocks(self, document_pages: List[PageGroup], provider: PdfProvider): nearest_block = document_page.children[min_dist_idx] nearest_block.add_structure(line) assigned_line_idxs.add(line_idx) + for span in spans: + document_page.add_full_block(span) + line.add_structure(span) - for line_idx in all_line_idxs.difference(assigned_line_idxs): - line: Line = provider_lines[line_idx] + for line_idx in provider_line_idxs.difference(assigned_line_idxs): + line, spans = provider_lines[line_idx] document_page.add_full_block(line) - # How do we add structure for when layout doesn't detect text?, squeeze between nearest block? text_block = document_page.add_block(Text, polygon=line.polygon) text_block.add_structure(line) + for span in spans: + document_page.add_full_block(span) + text_block.add_structure(span) diff --git a/marker/v2/providers/pdf.py b/marker/v2/providers/pdf.py index 94a406e..4c9b391 100644 --- a/marker/v2/providers/pdf.py +++ b/marker/v2/providers/pdf.py @@ -1,4 +1,5 @@ import functools +from typing import Dict, List, Tuple from typing import Dict, List, Optional import pypdfium2 as pdfium @@ -10,6 +11,9 @@ from marker.v2.schema.polygon import PolygonBox from marker.v2.schema.text.line import Line, Span +PdfPageProviderLine = Tuple[List[Line], List[List[Span]]] +PdfPageProviderLines = Dict[int, PdfPageProviderLine] + class PdfProvider(BaseProvider): page_range: List[int] | None = None @@ -19,7 +23,7 @@ class PdfProvider(BaseProvider): def __init__(self, filepath: str, config: Optional[BaseModel] = None): super().__init__(filepath, config) - self.page_lines: Dict[int, List[Line]] = {} + self.page_lines: PdfPageProviderLines = {} self.doc: pdfium.PdfDocument self.setup() @@ -79,6 +83,7 @@ def setup(self): for page in page_char_blocks: page_id = page["page"] lines: List[Line] = [] + line_spans: List[List[Span]] = [] for block in page["blocks"]: for line in block["lines"]: spans: List[Span] = [] @@ -101,11 +106,11 @@ def setup(self): lines.append( Line( polygon=PolygonBox.from_bbox(line["bbox"]), - spans=spans, page_id=page_id, ) ) - self.page_lines[page_id] = lines + line_spans.append(spans) + self.page_lines[page_id] = (lines, line_spans) @ functools.lru_cache(maxsize=None) def get_image(self, idx: int, dpi: int) -> Image.Image: @@ -118,5 +123,5 @@ def get_page_bbox(self, idx: int) -> List[float]: page = self.doc[idx] return page.get_bbox() - def get_page_lines(self, idx: int) -> List[Line]: + def get_page_lines(self, idx: int) -> PdfPageProviderLine: return self.page_lines[idx] diff --git a/marker/v2/schema/blocks/base.py b/marker/v2/schema/blocks/base.py index 7e37da1..3a0d1db 100644 --- a/marker/v2/schema/blocks/base.py +++ b/marker/v2/schema/blocks/base.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, ConfigDict -from marker.v2.schema import PolygonBox +from marker.v2.schema.polygon import PolygonBox class Block(BaseModel): diff --git a/marker/v2/schema/text/line.py b/marker/v2/schema/text/line.py index c06c7d2..a9234d7 100644 --- a/marker/v2/schema/text/line.py +++ b/marker/v2/schema/text/line.py @@ -6,5 +6,3 @@ class Line(Block): block_type: str = "Line" - - spans: List[Span] diff --git a/tests/test_document_builder.py b/tests/test_document_builder.py index a891abb..a51d289 100644 --- a/tests/test_document_builder.py +++ b/tests/test_document_builder.py @@ -25,29 +25,23 @@ def test_document_builder(layout_model): assert len(document.pages) == len(provider) first_page = document.pages[0] - assert first_page.structure[0] == '/page/0/block/0' + assert first_page.structure[0] == '/page/0/Section-header/0' - first_block = first_page.get_block('/page/0/block/0') + first_block = first_page.get_block(first_page.structure[0]) assert first_block.block_type == 'Section-header' - assert first_block.structure[0] == '/page/0/block/15' - - first_text_block: Line = first_page.get_block('/page/0/block/15') + first_text_block: Line = first_page.get_block(first_block.structure[0]) assert first_text_block.block_type == 'Line' - - first_span = first_text_block.spans[0] + first_span = first_page.get_block(first_text_block.structure[0]) assert first_span.block_type == 'Span' assert first_span.text == 'Subspace Adversarial Training' assert first_span.font == 'NimbusRomNo9L-Medi' assert first_span.formats == ['plain'] - last_block = first_page.get_block('/page/0/block/14') + last_block = first_page.get_block(first_page.structure[-1]) assert last_block.block_type == 'Text' - assert last_block.structure[-1] == '/page/0/block/106' - - last_text_block: Line = first_page.get_block('/page/0/block/106') + last_text_block: Line = first_page.get_block(last_block.structure[-1]) assert last_text_block.block_type == 'Line' - - last_span = last_text_block.spans[-1] + last_span = first_page.get_block(last_text_block.structure[-1]) assert last_span.block_type == 'Span' assert last_span.text == 'prove the quality of single-step AT solutions. However,' assert last_span.font == 'NimbusRomNo9L-Regu' diff --git a/tests/test_pdf_provider.py b/tests/test_pdf_provider.py index fa766ed..57b5732 100644 --- a/tests/test_pdf_provider.py +++ b/tests/test_pdf_provider.py @@ -17,8 +17,9 @@ def test_pdf_provider(): assert len(provider) == 12 assert provider.get_image(0, 72).size == (612, 792) assert provider.get_image(0, 96).size == (816, 1056) - line = provider.get_page_lines(0)[0] - spans = line.spans + lines, spans_list = provider.get_page_lines(0) + assert len(spans_list) == 93 + spans = spans_list[0] assert len(spans) == 1 assert spans[0].text == "Subspace Adversarial Training" assert spans[0].font == "NimbusRomNo9L-Medi"