From d27b8743fd55d75e4f6a67f418a81f4deef74586 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 14 Nov 2024 10:58:28 -0500 Subject: [PATCH 1/5] Refactor config --- marker/utils.py | 3 +++ marker/v2/builders/__init__.py | 6 +++--- marker/v2/builders/document.py | 4 ++-- marker/v2/converters/__init__.py | 6 +++--- marker/v2/converters/pdf.py | 3 +-- marker/v2/processors/__init__.py | 6 +++--- marker/v2/providers/__init__.py | 15 ++++++--------- marker/v2/providers/pdf.py | 21 ++++++++++++--------- marker/v2/renderers/__init__.py | 10 +++++++++- marker/v2/schema/config/pdf.py | 8 -------- marker/v2/schema/config/provider.py | 9 --------- marker/v2/util.py | 4 ++++ tests/conftest.py | 3 +-- tests/test_document_builder.py | 3 +-- tests/test_pdf_provider.py | 3 +-- tests/utils.py | 3 +-- 16 files changed, 50 insertions(+), 57 deletions(-) delete mode 100644 marker/v2/schema/config/pdf.py delete mode 100644 marker/v2/schema/config/provider.py create mode 100644 marker/v2/util.py diff --git a/marker/utils.py b/marker/utils.py index aa83ec10..57c3b5e6 100644 --- a/marker/utils.py +++ b/marker/utils.py @@ -5,3 +5,6 @@ def flush_cuda_memory(): if settings.TORCH_DEVICE_MODEL == "cuda": torch.cuda.empty_cache() + + + diff --git a/marker/v2/builders/__init__.py b/marker/v2/builders/__init__.py index 4b000896..54c05b8b 100644 --- a/marker/v2/builders/__init__.py +++ b/marker/v2/builders/__init__.py @@ -2,12 +2,12 @@ from pydantic import BaseModel +from marker.v2.util import assign_config + class BaseBuilder: def __init__(self, config: Optional[BaseModel] = None): - if config: - for k in config.model_fields: - setattr(self, k, config[k]) + assign_config(self, config) def __call__(self, data, *args, **kwargs): raise NotImplementedError diff --git a/marker/v2/builders/document.py b/marker/v2/builders/document.py index f260d7c1..1493119d 100644 --- a/marker/v2/builders/document.py +++ b/marker/v2/builders/document.py @@ -14,10 +14,10 @@ def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder): return document def build_document(self, provider: PdfProvider): - if provider.config.page_range is None: + if provider.page_range is None: page_range = range(len(provider)) else: - page_range = provider.config.page_range + page_range = provider.page_range assert max(page_range) < len(provider) and min(page_range) >= 0, "Invalid page range" initial_pages = [ diff --git a/marker/v2/converters/__init__.py b/marker/v2/converters/__init__.py index 9ffded3d..34863d02 100644 --- a/marker/v2/converters/__init__.py +++ b/marker/v2/converters/__init__.py @@ -2,12 +2,12 @@ from pydantic import BaseModel +from marker.v2.util import assign_config + class BaseConverter: def __init__(self, config: Optional[BaseModel] = None): - if config: - for k in config.model_fields: - setattr(self, k, config[k]) + assign_config(self, config) def __call__(self): raise NotImplementedError \ No newline at end of file diff --git a/marker/v2/converters/pdf.py b/marker/v2/converters/pdf.py index fde473d2..e997ef3a 100644 --- a/marker/v2/converters/pdf.py +++ b/marker/v2/converters/pdf.py @@ -8,7 +8,6 @@ from marker.v2.builders.structure import StructureBuilder from marker.v2.converters import BaseConverter from marker.v2.providers.pdf import PdfProvider -from marker.v2.schema.config.pdf import PdfProviderConfig class PdfConverter(BaseConverter): @@ -16,7 +15,7 @@ class PdfConverter(BaseConverter): page_range: List[int] | None = None def __call__(self): - pdf_provider = PdfProvider(self.config.filepath, PdfProviderConfig()) + pdf_provider = PdfProvider(self.filepath) layout_model = load_model() layout_model.processor = load_processor() diff --git a/marker/v2/processors/__init__.py b/marker/v2/processors/__init__.py index ff123fd0..763f5507 100644 --- a/marker/v2/processors/__init__.py +++ b/marker/v2/processors/__init__.py @@ -2,9 +2,9 @@ from pydantic import BaseModel +from marker.v2.util import assign_config + class BaseProcessor: def __init__(self, config: Optional[BaseModel] = None): - if config: - for k in config.model_fields: - setattr(self, k, config[k]) \ No newline at end of file + assign_config(self, config) \ No newline at end of file diff --git a/marker/v2/providers/__init__.py b/marker/v2/providers/__init__.py index d8d63ffb..bfb3b7e4 100644 --- a/marker/v2/providers/__init__.py +++ b/marker/v2/providers/__init__.py @@ -1,22 +1,19 @@ -from typing import List +from typing import List, Optional + +from pydantic import BaseModel -from marker.v2.schema.config.provider import ProviderConfig from marker.v2.schema.text.line import Line +from marker.v2.util import assign_config class BaseProvider: - def __init__(self, filepath: str, config: ProviderConfig): + def __init__(self, filepath: str, config: Optional[BaseModel] = None): + assign_config(self, config) self.filepath = filepath - self.config = config - - self.setup() def __len__(self): pass - def setup(self): - pass - def get_image(self, idx: int, dpi: int): pass diff --git a/marker/v2/providers/pdf.py b/marker/v2/providers/pdf.py index 7ee23402..94a406ea 100644 --- a/marker/v2/providers/pdf.py +++ b/marker/v2/providers/pdf.py @@ -1,22 +1,25 @@ import functools -from typing import Dict, List +from typing import Dict, List, Optional import pypdfium2 as pdfium from pdftext.extraction import dictionary_output from PIL import Image +from pydantic import BaseModel from marker.v2.providers import BaseProvider -from marker.v2.schema.config.pdf import PdfProviderConfig from marker.v2.schema.polygon import PolygonBox from marker.v2.schema.text.line import Line, Span class PdfProvider(BaseProvider): - def __init__(self, filepath: str, config: PdfProviderConfig): - self.filepath = filepath - self.config = config - self.page_lines: Dict[int, List[Line]] = {} + page_range: List[int] | None = None + pdftext_workers: int = 4 + flatten_pdf: bool = True + + def __init__(self, filepath: str, config: Optional[BaseModel] = None): + super().__init__(filepath, config) + self.page_lines: Dict[int, List[Line]] = {} self.doc: pdfium.PdfDocument self.setup() @@ -68,10 +71,10 @@ def setup(self): self.doc = pdfium.PdfDocument(self.filepath) page_char_blocks = dictionary_output( self.filepath, - page_range=self.config.page_range, + page_range=self.page_range, keep_chars=False, - workers=self.config.pdftext_workers, - flatten_pdf=self.config.flatten_pdf + workers=self.pdftext_workers, + flatten_pdf=self.flatten_pdf ) for page in page_char_blocks: page_id = page["page"] diff --git a/marker/v2/renderers/__init__.py b/marker/v2/renderers/__init__.py index 197d27d0..4ffa9b5e 100644 --- a/marker/v2/renderers/__init__.py +++ b/marker/v2/renderers/__init__.py @@ -1,2 +1,10 @@ +from typing import Optional + +from pydantic import BaseModel + + class BaseRenderer: - pass \ No newline at end of file + def __init__(self, config: Optional[BaseModel] = None): + if config: + for k in config.model_fields: + setattr(self, k, config[k]) \ No newline at end of file diff --git a/marker/v2/schema/config/pdf.py b/marker/v2/schema/config/pdf.py deleted file mode 100644 index 7aa39da9..00000000 --- a/marker/v2/schema/config/pdf.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Optional - -from marker.v2.schema.config.provider import ProviderConfig - - -class PdfProviderConfig(ProviderConfig): - pdftext_workers: int = 4 - flatten_pdf: bool = True diff --git a/marker/v2/schema/config/provider.py b/marker/v2/schema/config/provider.py deleted file mode 100644 index d6f1b90e..00000000 --- a/marker/v2/schema/config/provider.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel, ConfigDict - - -class ProviderConfig(BaseModel): - page_range: Optional[range] = None - - model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/marker/v2/util.py b/marker/v2/util.py new file mode 100644 index 00000000..3887a9a5 --- /dev/null +++ b/marker/v2/util.py @@ -0,0 +1,4 @@ +def assign_config(cls, config): + if config: + for k in config.model_fields: + setattr(cls, k, config[k]) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index cbf1c2a8..c9a361f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,6 @@ from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder from marker.v2.providers.pdf import PdfProvider -from marker.v2.schema.config.pdf import PdfProviderConfig from marker.v2.schema.document import Document @@ -34,7 +33,7 @@ def pdf_document(request, layout_model) -> Document: temp_pdf.write(dataset['pdf'][idx]) temp_pdf.flush() - provider = PdfProvider(temp_pdf.name, PdfProviderConfig()) + provider = PdfProvider(temp_pdf.name) layout_builder = LayoutBuilder(layout_model) builder = DocumentBuilder() document = builder(provider, layout_builder) diff --git a/tests/test_document_builder.py b/tests/test_document_builder.py index aa26ffe9..e87758fe 100644 --- a/tests/test_document_builder.py +++ b/tests/test_document_builder.py @@ -6,7 +6,6 @@ from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder -from marker.v2.schema.config.pdf import PdfProviderConfig def test_document_builder(layout_model): @@ -17,7 +16,7 @@ def test_document_builder(layout_model): temp_pdf.write(dataset['pdf'][idx]) temp_pdf.flush() - provider = PdfProvider(temp_pdf.name, PdfProviderConfig()) + provider = PdfProvider(temp_pdf.name) layout_builer = LayoutBuilder(layout_model) builder = DocumentBuilder() document = builder(provider, layout_builer) diff --git a/tests/test_pdf_provider.py b/tests/test_pdf_provider.py index 37d47dd9..fa766ed4 100644 --- a/tests/test_pdf_provider.py +++ b/tests/test_pdf_provider.py @@ -3,7 +3,6 @@ import datasets from marker.v2.providers.pdf import PdfProvider -from marker.v2.schema.config.pdf import PdfProviderConfig def test_pdf_provider(): @@ -14,7 +13,7 @@ def test_pdf_provider(): temp_pdf.write(dataset['pdf'][idx]) temp_pdf.flush() - provider = PdfProvider(temp_pdf.name, PdfProviderConfig()) + provider = PdfProvider(temp_pdf.name) assert len(provider) == 12 assert provider.get_image(0, 72).size == (612, 792) assert provider.get_image(0, 96).size == (816, 1056) diff --git a/tests/utils.py b/tests/utils.py index 1ad0e30d..4b2813b8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,6 @@ from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder from marker.v2.providers.pdf import PdfProvider -from marker.v2.schema.config.pdf import PdfProviderConfig from marker.v2.schema.document import Document @@ -21,7 +20,7 @@ def setup_pdf_document(filename: str) -> Document: layout_model = load_model() layout_model.processor = load_processor() - provider = PdfProvider(temp_pdf.name, PdfProviderConfig()) + provider = PdfProvider(temp_pdf.name) layout_builder = LayoutBuilder(layout_model) builder = DocumentBuilder() document = builder(provider, layout_builder) From c1d97f03ab31e2d7bb9349118f72b216fdf1d2c2 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 14 Nov 2024 11:10:34 -0500 Subject: [PATCH 2/5] Config is basemodel or dict --- marker/v2/builders/__init__.py | 2 +- marker/v2/converters/__init__.py | 4 ++-- marker/v2/converters/pdf.py | 16 +++++++++------- marker/v2/processors/__init__.py | 9 +++++++-- marker/v2/processors/equation.py | 9 +++++++++ marker/v2/providers/__init__.py | 2 +- marker/v2/renderers/__init__.py | 2 +- marker/v2/schema/groups/__init__.py | 11 +++++++++-- marker/v2/schema/util.py | 10 ---------- marker/v2/util.py | 14 +++++++++++--- 10 files changed, 50 insertions(+), 29 deletions(-) delete mode 100644 marker/v2/schema/util.py diff --git a/marker/v2/builders/__init__.py b/marker/v2/builders/__init__.py index 54c05b8b..ac43c757 100644 --- a/marker/v2/builders/__init__.py +++ b/marker/v2/builders/__init__.py @@ -6,7 +6,7 @@ class BaseBuilder: - def __init__(self, config: Optional[BaseModel] = None): + def __init__(self, config: Optional[BaseModel | dict] = None): assign_config(self, config) def __call__(self, data, *args, **kwargs): diff --git a/marker/v2/converters/__init__.py b/marker/v2/converters/__init__.py index 34863d02..787b8a85 100644 --- a/marker/v2/converters/__init__.py +++ b/marker/v2/converters/__init__.py @@ -6,8 +6,8 @@ class BaseConverter: - def __init__(self, config: Optional[BaseModel] = None): + def __init__(self, config: Optional[BaseModel | dict] = None): assign_config(self, config) - def __call__(self): + def __call__(self, *args, **kwargs): raise NotImplementedError \ No newline at end of file diff --git a/marker/v2/converters/pdf.py b/marker/v2/converters/pdf.py index e997ef3a..e7140d6d 100644 --- a/marker/v2/converters/pdf.py +++ b/marker/v2/converters/pdf.py @@ -1,5 +1,6 @@ -from typing import List +from typing import List, Optional +from pydantic import BaseModel from surya.model.layout.model import load_model from surya.model.layout.processor import load_processor @@ -11,16 +12,17 @@ class PdfConverter(BaseConverter): - filepath: str - page_range: List[int] | None = None - - def __call__(self): - pdf_provider = PdfProvider(self.filepath) + def __init__(self, config: Optional[BaseModel] = None): + super().__init__(config) layout_model = load_model() layout_model.processor = load_processor() - layout_builder = LayoutBuilder(layout_model) + self.layout_model = layout_model + + def __call__(self, filepath: str, page_range: List[int] | None = None): + pdf_provider = PdfProvider(filepath, {"page_range": page_range}) + layout_builder = LayoutBuilder(self.layout_model) document = DocumentBuilder()(pdf_provider, layout_builder) StructureBuilder()(document) diff --git a/marker/v2/processors/__init__.py b/marker/v2/processors/__init__.py index 763f5507..dc519ee3 100644 --- a/marker/v2/processors/__init__.py +++ b/marker/v2/processors/__init__.py @@ -6,5 +6,10 @@ class BaseProcessor: - def __init__(self, config: Optional[BaseModel] = None): - assign_config(self, config) \ No newline at end of file + block_type: str | None = None # What block type this processor is responsible for + + def __init__(self, config: Optional[BaseModel | dict] = None): + assign_config(self, config) + + def __call__(self, *args, **kwargs): + raise NotImplementedError \ No newline at end of file diff --git a/marker/v2/processors/equation.py b/marker/v2/processors/equation.py index e5c00519..8b919ca1 100644 --- a/marker/v2/processors/equation.py +++ b/marker/v2/processors/equation.py @@ -1,6 +1,15 @@ +from typing import Optional + +from pydantic import BaseModel + from marker.v2.processors import BaseProcessor class EquationProcessor(BaseProcessor): block_type = "Equation" + def __init__(self, texify_model, config: Optional[BaseModel] = None): + super().__init__(config) + + self.texify_model = texify_model + diff --git a/marker/v2/providers/__init__.py b/marker/v2/providers/__init__.py index bfb3b7e4..04c9826a 100644 --- a/marker/v2/providers/__init__.py +++ b/marker/v2/providers/__init__.py @@ -7,7 +7,7 @@ class BaseProvider: - def __init__(self, filepath: str, config: Optional[BaseModel] = None): + def __init__(self, filepath: str, config: Optional[BaseModel | dict] = None): assign_config(self, config) self.filepath = filepath diff --git a/marker/v2/renderers/__init__.py b/marker/v2/renderers/__init__.py index 4ffa9b5e..d932ebfc 100644 --- a/marker/v2/renderers/__init__.py +++ b/marker/v2/renderers/__init__.py @@ -4,7 +4,7 @@ class BaseRenderer: - def __init__(self, config: Optional[BaseModel] = None): + def __init__(self, config: Optional[BaseModel | dict] = None): if config: for k in config.model_fields: setattr(self, k, config[k]) \ No newline at end of file diff --git a/marker/v2/schema/groups/__init__.py b/marker/v2/schema/groups/__init__.py index d170ef2c..bc2c2091 100644 --- a/marker/v2/schema/groups/__init__.py +++ b/marker/v2/schema/groups/__init__.py @@ -1,7 +1,14 @@ +from marker.v2.schema import Block from marker.v2.schema.groups.figure import FigureGroup from marker.v2.schema.groups.table import TableGroup from marker.v2.schema.groups.list import ListGroup from marker.v2.schema.groups.picture import PictureGroup -from marker.v2.schema.util import build_block_registry -GROUP_BLOCK_REGISTRY = build_block_registry() +GROUP_BLOCK_REGISTRY = { + v.model_fields['block_type'].default: v for k, v in locals().items() + if isinstance(v, type) + and issubclass(v, Block) + and v != Block # Exclude the base Block class + and v.model_fields['block_type'].default.endswith("Group") +} + diff --git a/marker/v2/schema/util.py b/marker/v2/schema/util.py deleted file mode 100644 index 4f03950c..00000000 --- a/marker/v2/schema/util.py +++ /dev/null @@ -1,10 +0,0 @@ -from marker.v2.schema import Block - - -def build_block_registry(): - return { - v.block_type: v for k, v in locals().items() - if isinstance(v, type) - and issubclass(v, Block) - and v != Block # Exclude the base Block class - } \ No newline at end of file diff --git a/marker/v2/util.py b/marker/v2/util.py index 3887a9a5..c7819c22 100644 --- a/marker/v2/util.py +++ b/marker/v2/util.py @@ -1,4 +1,12 @@ -def assign_config(cls, config): - if config: +from pydantic import BaseModel + + +def assign_config(cls, config: BaseModel | dict | None): + if config is None: + return + elif isinstance(config, BaseModel): for k in config.model_fields: - setattr(cls, k, config[k]) \ No newline at end of file + setattr(cls, k, config[k]) + elif isinstance(config, dict): + for k, v in config.items(): + setattr(cls, k, v) \ No newline at end of file From 8cbfea5bce063eaf225c2e649e7f3a773cd2f079 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 14 Nov 2024 14:39:51 -0500 Subject: [PATCH 3/5] Add processors --- marker/v2/builders/layout.py | 11 +-- marker/v2/processors/__init__.py | 3 +- marker/v2/processors/equation.py | 87 ++++++++++++++++++- marker/v2/processors/table.py | 102 ++++++++++++++++++++++- marker/v2/schema/__init__.py | 24 ------ marker/v2/schema/blocks/__init__.py | 53 +++++++++++- marker/v2/schema/blocks/caption.py | 2 +- marker/v2/schema/blocks/code.py | 2 +- marker/v2/schema/blocks/equation.py | 3 +- marker/v2/schema/blocks/figure.py | 2 +- marker/v2/schema/blocks/footnote.py | 2 +- marker/v2/schema/blocks/form.py | 2 +- marker/v2/schema/blocks/handwriting.py | 2 +- marker/v2/schema/blocks/inlinemath.py | 2 +- marker/v2/schema/blocks/listitem.py | 2 +- marker/v2/schema/blocks/pagefooter.py | 2 +- marker/v2/schema/blocks/pageheader.py | 2 +- marker/v2/schema/blocks/picture.py | 2 +- marker/v2/schema/blocks/sectionheader.py | 2 +- marker/v2/schema/blocks/table.py | 7 +- marker/v2/schema/blocks/text.py | 2 +- marker/v2/schema/blocks/toc.py | 2 +- marker/v2/schema/groups/__init__.py | 2 +- marker/v2/schema/groups/figure.py | 2 +- marker/v2/schema/groups/list.py | 2 +- marker/v2/schema/groups/page.py | 2 +- marker/v2/schema/groups/picture.py | 2 +- marker/v2/schema/groups/table.py | 2 +- marker/v2/schema/polygon.py | 2 +- marker/v2/schema/text/line.py | 2 +- marker/v2/schema/text/span.py | 2 +- 31 files changed, 278 insertions(+), 58 deletions(-) diff --git a/marker/v2/builders/layout.py b/marker/v2/builders/layout.py index 89aa4884..e9cedf87 100644 --- a/marker/v2/builders/layout.py +++ b/marker/v2/builders/layout.py @@ -14,6 +14,8 @@ class LayoutBuilder(BaseBuilder): + batch_size = None + def __init__(self, layout_model, config=None): self.layout_model = layout_model @@ -24,10 +26,9 @@ def __call__(self, document: Document, provider: PdfProvider): self.add_blocks_to_pages(document.pages, layout_results) self.merge_blocks(document.pages, provider) - @classmethod - def get_batch_size(cls): - if settings.LAYOUT_BATCH_SIZE is not None: - return settings.LAYOUT_BATCH_SIZE + def get_batch_size(self): + if self.batch_size is not None: + return self.batch_size elif settings.TORCH_DEVICE_MODEL == "cuda": return 6 return 6 @@ -38,7 +39,7 @@ def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]: [p.lowres_image for p in pages], self.layout_model, processor, - batch_size=int(LayoutBuilder.get_batch_size()) + batch_size=int(self.get_batch_size()) ) return layout_results diff --git a/marker/v2/processors/__init__.py b/marker/v2/processors/__init__.py index dc519ee3..53dde666 100644 --- a/marker/v2/processors/__init__.py +++ b/marker/v2/processors/__init__.py @@ -2,6 +2,7 @@ from pydantic import BaseModel +from marker.v2.schema.document import Document from marker.v2.util import assign_config @@ -11,5 +12,5 @@ class BaseProcessor: def __init__(self, config: Optional[BaseModel | dict] = None): assign_config(self, config) - def __call__(self, *args, **kwargs): + def __call__(self, document: Document, *args, **kwargs): raise NotImplementedError \ No newline at end of file diff --git a/marker/v2/processors/equation.py b/marker/v2/processors/equation.py index 8b919ca1..88b5dce5 100644 --- a/marker/v2/processors/equation.py +++ b/marker/v2/processors/equation.py @@ -1,15 +1,100 @@ -from typing import Optional +from typing import Optional, List from pydantic import BaseModel +from tqdm import tqdm +from marker.settings import settings from marker.v2.processors import BaseProcessor +from marker.v2.schema.document import Document + +from texify.inference import batch_inference class EquationProcessor(BaseProcessor): block_type = "Equation" + model_max_length = 384 + batch_size = None + token_buffer = 256 def __init__(self, texify_model, config: Optional[BaseModel] = None): super().__init__(config) self.texify_model = texify_model + def __call__(self, document: Document): + equation_data = [] + + for page in document.pages: + for block in page.children: + if block.block_type != self.block_type: + continue + image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.lowres_image.size) + image = page.lowres_image.crop(image_poly.bbox).convert("RGB") + raw_text = block.raw_text(document) + token_count = self.get_total_texify_tokens(raw_text) + + equation_data.append({ + "image": image, + "block_id": block._id, + "token_count": token_count + }) + + predictions = self.get_latex_batched(equation_data) + for prediction, equation_d in zip(predictions, equation_data): + conditions = [ + self.get_total_texify_tokens(prediction) < self.model_max_length, + # Make sure we didn't get to the overall token max, indicates run-on + len(prediction) > equation_d["token_count"] * .4, + len(prediction.strip()) > 0 + ] + if not all(conditions): + continue + + block = document.get_block_by_id(equation_d["block_id"]) + block.latex = prediction + + def get_batch_size(self): + if self.batch_size is not None: + return self.batch_size + elif settings.TORCH_DEVICE_MODEL == "cuda": + return 6 + elif settings.TORCH_DEVICE_MODEL == "mps": + return 6 + return 2 + + def get_latex_batched(self, equation_data: List[dict]): + predictions = [""] * len(equation_data) + batch_size = self.get_batch_size() + + for i in tqdm(range(0, len(equation_data), batch_size), desc="Recognizing equations"): + # Dynamically set max length to save inference time + min_idx = i + max_idx = min(min_idx + batch_size, len(equation_data)) + + batch_equations = equation_data[min_idx:max_idx] + max_length = max([eq["token_count"] for eq in batch_equations]) + max_length = min(max_length, self.model_max_length) + max_length += self.token_buffer + + batch_images = [eq["image"] for eq in batch_equations] + + model_output = batch_inference( + batch_images, + self.texify_model, + self.texify_model.processor, + max_tokens=max_length + ) + + for j, output in enumerate(model_output): + token_count = self.get_total_texify_tokens(output) + if token_count >= max_length - 1: + output = "" + + image_idx = i + j + predictions[image_idx] = output + return predictions + + def get_total_texify_tokens(self, text): + tokenizer = self.texify_model.processor.tokenizer + tokens = tokenizer(text) + return len(tokens["input_ids"]) \ No newline at end of file diff --git a/marker/v2/processors/table.py b/marker/v2/processors/table.py index 89ab3953..ee06105e 100644 --- a/marker/v2/processors/table.py +++ b/marker/v2/processors/table.py @@ -1,5 +1,105 @@ +from typing import Optional + +from pydantic import BaseModel +from surya.input.pdflines import get_page_text_lines +from tabled.assignment import assign_rows_columns +from tabled.inference.recognition import get_cells, recognize_tables + +from marker.settings import settings from marker.v2.processors import BaseProcessor +from marker.v2.schema.document import Document class TableProcessor(BaseProcessor): - pass \ No newline at end of file + block_type = "Table" + detect_boxes = False + detector_batch_size = None + table_rec_batch_size = None + ocr_batch_size = None + + def __init__(self, detection_model, ocr_model, table_rec_model, config: Optional[BaseModel] = None): + super().__init__(config) + + self.detection_model = detection_model + self.ocr_model = ocr_model + self.table_rec_model = table_rec_model + + def __call__(self, document: Document): + filepath = document.filepath # Path to original pdf file + + table_data = [] + for page in document.pages: + for block in page.children: + if block.block_type != self.block_type: + continue + image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.highres_image.size) + image = page.highres_image.crop(image_poly.bbox).convert("RGB") + + if block.text_extraction_method == "ocr": + text_lines = None + else: + text_lines = get_page_text_lines( + filepath, + [page.page_id], + page.highres_image.size, + flatten_pdf=True + ) + + table_data.append({ + "block_id": block._id, + "table_image": image, + "table_bbox": image_poly.bbox, + "text_lines": text_lines, + "img_size": page.highres_image.size + }) + + lst_format = zip( + *(t[key] for key in ["table_image", "table_bbox", "img_size", "text_lines"]) + for t in table_data + ) + + cells, needs_ocr = get_cells( + *lst_format, + [self.detection_model, self.detection_model.processor], + detect_boxes=self.detect_boxes, + detector_batch_size=self.get_detector_batch_size() + ) + + tables = recognize_tables( + [t["table_image"] for t in table_data], + cells, + needs_ocr, + [self.table_rec_model, self.table_rec_model.processor, self.ocr_model, self.ocr_model.processor], + table_rec_batch_size=self.get_table_rec_batch_size(), + ocr_batch_size=self.get_ocr_batch_size() + ) + + for table_d, table_res in zip(table_data, tables): + block = document.get_block_by_id(table_d["block_id"]) + cells = assign_rows_columns(table_res, table_d["img_size"]) + block.cells = cells + + def get_detector_batch_size(self): + if self.detector_batch_size is not None: + return self.detector_batch_size + elif settings.TORCH_DEVICE_MODEL == "cuda": + return 4 + return 4 + + def get_table_rec_batch_size(self): + if self.table_rec_batch_size is not None: + return self.table_rec_batch_size + elif settings.TORCH_DEVICE_MODEL == "mps": + return 16 + elif settings.TORCH_DEVICE_MODEL == "cuda": + return 64 + return 8 + + def get_ocr_batch_size(self): + if self.ocr_batch_size is not None: + return self.ocr_batch_size + elif settings.TORCH_DEVICE_MODEL == "mps": + return 32 + elif settings.TORCH_DEVICE_MODEL == "cuda": + return 128 + return 32 \ No newline at end of file diff --git a/marker/v2/schema/__init__.py b/marker/v2/schema/__init__.py index facada45..8a84fc47 100644 --- a/marker/v2/schema/__init__.py +++ b/marker/v2/schema/__init__.py @@ -1,29 +1,5 @@ from __future__ import annotations -from typing import Optional, List -from pydantic import BaseModel, ConfigDict from marker.v2.schema.polygon import PolygonBox -class Block(BaseModel): - polygon: PolygonBox - block_type: Optional[str] = None - block_id: Optional[int] = None - page_id: Optional[int] = None - structure: List[str] | None = None # The top-level page structure, which is the block ids in order - - model_config = ConfigDict(arbitrary_types_allowed=True) - - @property - def _id(self): - page_path = f"/page/{self.page_id}" - if self.block_id is not None: - return f"{page_path}/block/{self.block_id}" - else: - return page_path - - def add_structure(self, block: Block): - if self.structure is None: - self.structure = [block._id] - else: - self.structure.append(block._id) diff --git a/marker/v2/schema/blocks/__init__.py b/marker/v2/schema/blocks/__init__.py index 779fd7e8..f5341362 100644 --- a/marker/v2/schema/blocks/__init__.py +++ b/marker/v2/schema/blocks/__init__.py @@ -1,3 +1,10 @@ +from __future__ import annotations + +from typing import Optional, List + +from pydantic import BaseModel, ConfigDict + +from marker.v2.schema import PolygonBox from marker.v2.schema.blocks.caption import Caption from marker.v2.schema.blocks.code import Code from marker.v2.schema.blocks.figure import Figure @@ -14,11 +21,55 @@ from marker.v2.schema.blocks.table import Table from marker.v2.schema.blocks.text import Text from marker.v2.schema.blocks.toc import TableOfContents -from marker.v2.schema import Block +from marker.v2.schema.document import Document +from marker.v2.schema.text.line import Line +from marker.v2.schema.text.span import Span + + +class Block(BaseModel): + polygon: PolygonBox + block_type: Optional[str] = None + block_id: Optional[int] = None + page_id: Optional[int] = None + structure: List[str] | None = None # The top-level page structure, which is the block ids in order + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @property + def _id(self): + page_path = f"/page/{self.page_id}" + if self.block_id is not None: + return f"{page_path}/{self.block_type}/{self.block_id}" + else: + return page_path + + def add_structure(self, block: Block): + if self.structure is None: + self.structure = [block._id] + else: + self.structure.append(block._id) + + def raw_text(self, document: Document): + if self.structure is None: + return 0 + + text = "" + for block_id in self.structure: + block = document.get_block(block_id) + if isinstance(block, Span): + text += block.text + else: + text += block.raw_text(document) + if isinstance(block, Line): + text += "\n" + return text + + LAYOUT_BLOCK_REGISTRY = { v.model_fields['block_type'].default: v for k, v in locals().items() if isinstance(v, type) and issubclass(v, Block) and v != Block # Exclude the base Block class + and not v.model_fields['block_type'].default.endswith("Group") } diff --git a/marker/v2/schema/blocks/caption.py b/marker/v2/schema/blocks/caption.py index 1a7bedde..4793e0fc 100644 --- a/marker/v2/schema/blocks/caption.py +++ b/marker/v2/schema/blocks/caption.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Caption(Block): diff --git a/marker/v2/schema/blocks/code.py b/marker/v2/schema/blocks/code.py index d07074b8..74b07a3a 100644 --- a/marker/v2/schema/blocks/code.py +++ b/marker/v2/schema/blocks/code.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Code(Block): diff --git a/marker/v2/schema/blocks/equation.py b/marker/v2/schema/blocks/equation.py index 89dc6409..b6df4ae8 100644 --- a/marker/v2/schema/blocks/equation.py +++ b/marker/v2/schema/blocks/equation.py @@ -1,5 +1,6 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Equation(Block): block_type: str = "Equation" + latex: str | None = None diff --git a/marker/v2/schema/blocks/figure.py b/marker/v2/schema/blocks/figure.py index efa40af3..9f25b068 100644 --- a/marker/v2/schema/blocks/figure.py +++ b/marker/v2/schema/blocks/figure.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Figure(Block): diff --git a/marker/v2/schema/blocks/footnote.py b/marker/v2/schema/blocks/footnote.py index 33044b4f..4a87b8a8 100644 --- a/marker/v2/schema/blocks/footnote.py +++ b/marker/v2/schema/blocks/footnote.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Footnote(Block): diff --git a/marker/v2/schema/blocks/form.py b/marker/v2/schema/blocks/form.py index bf931d17..6e62ad29 100644 --- a/marker/v2/schema/blocks/form.py +++ b/marker/v2/schema/blocks/form.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Form(Block): diff --git a/marker/v2/schema/blocks/handwriting.py b/marker/v2/schema/blocks/handwriting.py index ffdc95ba..053d9ecc 100644 --- a/marker/v2/schema/blocks/handwriting.py +++ b/marker/v2/schema/blocks/handwriting.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Handwriting(Block): diff --git a/marker/v2/schema/blocks/inlinemath.py b/marker/v2/schema/blocks/inlinemath.py index 865ab317..4cc76380 100644 --- a/marker/v2/schema/blocks/inlinemath.py +++ b/marker/v2/schema/blocks/inlinemath.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class InlineMath(Block): diff --git a/marker/v2/schema/blocks/listitem.py b/marker/v2/schema/blocks/listitem.py index 6b26b158..0e3f67ec 100644 --- a/marker/v2/schema/blocks/listitem.py +++ b/marker/v2/schema/blocks/listitem.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class ListItem(Block): diff --git a/marker/v2/schema/blocks/pagefooter.py b/marker/v2/schema/blocks/pagefooter.py index 36b0c05a..a676a987 100644 --- a/marker/v2/schema/blocks/pagefooter.py +++ b/marker/v2/schema/blocks/pagefooter.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class PageFooter(Block): diff --git a/marker/v2/schema/blocks/pageheader.py b/marker/v2/schema/blocks/pageheader.py index b41a5d68..e48dc217 100644 --- a/marker/v2/schema/blocks/pageheader.py +++ b/marker/v2/schema/blocks/pageheader.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class PageHeader(Block): diff --git a/marker/v2/schema/blocks/picture.py b/marker/v2/schema/blocks/picture.py index 3ee7feac..212ba5f1 100644 --- a/marker/v2/schema/blocks/picture.py +++ b/marker/v2/schema/blocks/picture.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Picture(Block): diff --git a/marker/v2/schema/blocks/sectionheader.py b/marker/v2/schema/blocks/sectionheader.py index a3f149f1..01f0b741 100644 --- a/marker/v2/schema/blocks/sectionheader.py +++ b/marker/v2/schema/blocks/sectionheader.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class SectionHeader(Block): diff --git a/marker/v2/schema/blocks/table.py b/marker/v2/schema/blocks/table.py index 34bd2976..39b4e258 100644 --- a/marker/v2/schema/blocks/table.py +++ b/marker/v2/schema/blocks/table.py @@ -1,5 +1,10 @@ -from marker.v2.schema import Block +from typing import List + +from tabled.schema import SpanTableCell + +from marker.v2.schema.blocks import Block class Table(Block): block_type: str = "Table" + cells: List[SpanTableCell] | None = None \ No newline at end of file diff --git a/marker/v2/schema/blocks/text.py b/marker/v2/schema/blocks/text.py index fbeae118..3bee9470 100644 --- a/marker/v2/schema/blocks/text.py +++ b/marker/v2/schema/blocks/text.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Text(Block): diff --git a/marker/v2/schema/blocks/toc.py b/marker/v2/schema/blocks/toc.py index bd4ba6fb..11796336 100644 --- a/marker/v2/schema/blocks/toc.py +++ b/marker/v2/schema/blocks/toc.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class TableOfContents(Block): diff --git a/marker/v2/schema/groups/__init__.py b/marker/v2/schema/groups/__init__.py index bc2c2091..269440e6 100644 --- a/marker/v2/schema/groups/__init__.py +++ b/marker/v2/schema/groups/__init__.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block from marker.v2.schema.groups.figure import FigureGroup from marker.v2.schema.groups.table import TableGroup from marker.v2.schema.groups.list import ListGroup diff --git a/marker/v2/schema/groups/figure.py b/marker/v2/schema/groups/figure.py index 23bfdf96..03209110 100644 --- a/marker/v2/schema/groups/figure.py +++ b/marker/v2/schema/groups/figure.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class FigureGroup(Block): diff --git a/marker/v2/schema/groups/list.py b/marker/v2/schema/groups/list.py index e31ac664..c480a76b 100644 --- a/marker/v2/schema/groups/list.py +++ b/marker/v2/schema/groups/list.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class ListGroup(Block): diff --git a/marker/v2/schema/groups/page.py b/marker/v2/schema/groups/page.py index 9625c7bb..0e270632 100644 --- a/marker/v2/schema/groups/page.py +++ b/marker/v2/schema/groups/page.py @@ -2,7 +2,7 @@ from PIL import Image -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block from marker.v2.schema.polygon import PolygonBox diff --git a/marker/v2/schema/groups/picture.py b/marker/v2/schema/groups/picture.py index 65f692cf..ba750ee7 100644 --- a/marker/v2/schema/groups/picture.py +++ b/marker/v2/schema/groups/picture.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class PictureGroup(Block): diff --git a/marker/v2/schema/groups/table.py b/marker/v2/schema/groups/table.py index 451493b1..9f80647c 100644 --- a/marker/v2/schema/groups/table.py +++ b/marker/v2/schema/groups/table.py @@ -1,4 +1,4 @@ -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class TableGroup(Block): diff --git a/marker/v2/schema/polygon.py b/marker/v2/schema/polygon.py index 75a3edfc..9548f48e 100644 --- a/marker/v2/schema/polygon.py +++ b/marker/v2/schema/polygon.py @@ -103,7 +103,7 @@ def rescale(self, old_size, new_size): for corner in new_corners: corner[0] = corner[0] * width_scaler corner[1] = corner[1] * height_scaler - self.polygon = new_corners + return PolygonBox(polygon=new_corners) def fit_to_bounds(self, bounds): new_corners = copy.deepcopy(self.polygon) diff --git a/marker/v2/schema/text/line.py b/marker/v2/schema/text/line.py index 09d6a0a2..c06c7d2d 100644 --- a/marker/v2/schema/text/line.py +++ b/marker/v2/schema/text/line.py @@ -1,6 +1,6 @@ from typing import List -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block from marker.v2.schema.text.span import Span diff --git a/marker/v2/schema/text/span.py b/marker/v2/schema/text/span.py index 568b82c5..a0af73b2 100644 --- a/marker/v2/schema/text/span.py +++ b/marker/v2/schema/text/span.py @@ -1,6 +1,6 @@ from typing import List, Literal -from marker.v2.schema import Block +from marker.v2.schema.blocks import Block class Span(Block): From 6cb51df692ecc0598af569184a8fb39ad0c79e16 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 14 Nov 2024 15:04:40 -0500 Subject: [PATCH 4/5] Add tests, cleanup impls --- marker/v2/models.py | 60 +++++++++++++++++++++++++++++ marker/v2/processors/table.py | 8 ++-- marker/v2/schema/blocks/__init__.py | 56 ++------------------------- marker/v2/schema/blocks/base.py | 49 +++++++++++++++++++++++ marker/v2/schema/groups/__init__.py | 2 +- tests/conftest.py | 39 ++++++++++++++++--- tests/test_equation_processor.py | 12 ++++++ tests/test_structure.py | 1 - tests/test_table_processor.py | 16 ++++++++ 9 files changed, 179 insertions(+), 64 deletions(-) create mode 100644 marker/v2/models.py create mode 100644 marker/v2/schema/blocks/base.py create mode 100644 tests/test_equation_processor.py create mode 100644 tests/test_table_processor.py diff --git a/marker/v2/models.py b/marker/v2/models.py new file mode 100644 index 00000000..4936cc42 --- /dev/null +++ b/marker/v2/models.py @@ -0,0 +1,60 @@ +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS + + +from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor +from surya.model.layout.model import load_model as load_layout_model +from surya.model.layout.processor import load_processor as load_layout_processor +from texify.model.model import load_model as load_texify_model +from texify.model.processor import load_processor as load_texify_processor +from marker.settings import settings +from surya.model.recognition.model import load_model as load_recognition_model +from surya.model.recognition.processor import load_processor as load_recognition_processor +from surya.model.table_rec.model import load_model as load_table_model +from surya.model.table_rec.processor import load_processor as load_table_processor + + +def setup_table_rec_model(device=None, dtype=None): + if device: + table_model = load_table_model(device=device, dtype=dtype) + else: + table_model = load_table_model() + table_model.processor = load_table_processor() + return table_model + + +def setup_recognition_model(device=None, dtype=None): + if device: + rec_model = load_recognition_model(device=device, dtype=dtype) + else: + rec_model = load_recognition_model() + rec_model.processor = load_recognition_processor() + return rec_model + + +def setup_detection_model(device=None, dtype=None): + if device: + model = load_detection_model(device=device, dtype=dtype) + else: + model = load_detection_model() + model.processor = load_detection_processor() + return model + + +def setup_texify_model(device=None, dtype=None): + if device: + texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype) + else: + texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE) + texify_model.processor = load_texify_processor() + return texify_model + + +def setup_layout_model(device=None, dtype=None): + if device: + model = load_layout_model(device=device, dtype=dtype) + else: + model = load_layout_model() + model.processor = load_layout_processor() + return model \ No newline at end of file diff --git a/marker/v2/processors/table.py b/marker/v2/processors/table.py index ee06105e..135cf122 100644 --- a/marker/v2/processors/table.py +++ b/marker/v2/processors/table.py @@ -53,10 +53,10 @@ def __call__(self, document: Document): "img_size": page.highres_image.size }) - lst_format = zip( - *(t[key] for key in ["table_image", "table_bbox", "img_size", "text_lines"]) - for t in table_data - ) + lst_format = zip(*( + [t[key] for t in table_data] + for key in ["table_image", "table_bbox", "img_size", "text_lines"] + )) cells, needs_ocr = get_cells( *lst_format, diff --git a/marker/v2/schema/blocks/__init__.py b/marker/v2/schema/blocks/__init__.py index f5341362..120b3ad4 100644 --- a/marker/v2/schema/blocks/__init__.py +++ b/marker/v2/schema/blocks/__init__.py @@ -1,10 +1,6 @@ from __future__ import annotations -from typing import Optional, List - -from pydantic import BaseModel, ConfigDict - -from marker.v2.schema import PolygonBox +from marker.v2.schema.blocks.base import Block from marker.v2.schema.blocks.caption import Caption from marker.v2.schema.blocks.code import Code from marker.v2.schema.blocks.figure import Figure @@ -21,55 +17,11 @@ from marker.v2.schema.blocks.table import Table from marker.v2.schema.blocks.text import Text from marker.v2.schema.blocks.toc import TableOfContents -from marker.v2.schema.document import Document -from marker.v2.schema.text.line import Line -from marker.v2.schema.text.span import Span - - -class Block(BaseModel): - polygon: PolygonBox - block_type: Optional[str] = None - block_id: Optional[int] = None - page_id: Optional[int] = None - structure: List[str] | None = None # The top-level page structure, which is the block ids in order - - model_config = ConfigDict(arbitrary_types_allowed=True) - - @property - def _id(self): - page_path = f"/page/{self.page_id}" - if self.block_id is not None: - return f"{page_path}/{self.block_type}/{self.block_id}" - else: - return page_path - - def add_structure(self, block: Block): - if self.structure is None: - self.structure = [block._id] - else: - self.structure.append(block._id) - - def raw_text(self, document: Document): - if self.structure is None: - return 0 - - text = "" - for block_id in self.structure: - block = document.get_block(block_id) - if isinstance(block, Span): - text += block.text - else: - text += block.raw_text(document) - if isinstance(block, Line): - text += "\n" - return text - - LAYOUT_BLOCK_REGISTRY = { v.model_fields['block_type'].default: v for k, v in locals().items() if isinstance(v, type) - and issubclass(v, Block) - and v != Block # Exclude the base Block class - and not v.model_fields['block_type'].default.endswith("Group") + and issubclass(v, Block) + and v != Block # Exclude the base Block class + and not v.model_fields['block_type'].default.endswith("Group") } diff --git a/marker/v2/schema/blocks/base.py b/marker/v2/schema/blocks/base.py new file mode 100644 index 00000000..efebd196 --- /dev/null +++ b/marker/v2/schema/blocks/base.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import Optional, List + +from pydantic import BaseModel, ConfigDict + +from marker.v2.schema import PolygonBox + + +class Block(BaseModel): + polygon: PolygonBox + block_type: Optional[str] = None + block_id: Optional[int] = None + page_id: Optional[int] = None + structure: List[str] | None = None # The top-level page structure, which is the block ids in order + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @property + def _id(self): + page_path = f"/page/{self.page_id}" + if self.block_id is not None: + return f"{page_path}/{self.block_type}/{self.block_id}" + else: + return page_path + + def add_structure(self, block: Block): + if self.structure is None: + self.structure = [block._id] + else: + self.structure.append(block._id) + + def raw_text(self, document) -> str: + from marker.v2.schema.text.line import Line + from marker.v2.schema.text.span import Span + + if self.structure is None: + return "" + + text = "" + for block_id in self.structure: + block = document.get_block(block_id) + if isinstance(block, Span): + text += block.text + else: + text += block.raw_text(document) + if isinstance(block, Line): + text += "\n" + return text diff --git a/marker/v2/schema/groups/__init__.py b/marker/v2/schema/groups/__init__.py index 269440e6..dcf6a170 100644 --- a/marker/v2/schema/groups/__init__.py +++ b/marker/v2/schema/groups/__init__.py @@ -1,4 +1,4 @@ -from marker.v2.schema.blocks import Block +from marker.v2.schema.blocks.base import Block from marker.v2.schema.groups.figure import FigureGroup from marker.v2.schema.groups.table import TableGroup from marker.v2.schema.groups.list import ListGroup diff --git a/tests/conftest.py b/tests/conftest.py index c9a361f4..e919e5b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,9 @@ import datasets import pytest -from surya.model.layout.model import load_model -from surya.model.layout.processor import load_processor +from marker.v2.models import setup_layout_model, setup_texify_model, setup_recognition_model, setup_table_rec_model, \ + setup_detection_model from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder from marker.v2.providers.pdf import PdfProvider @@ -13,10 +13,37 @@ @pytest.fixture(scope="session") def layout_model(): - layout_model = load_model() - layout_model.processor = load_processor() - yield layout_model - del layout_model + layout_m = setup_layout_model() + yield layout_m + del layout_m + + +@pytest.fixture(scope="session") +def detection_model(): + detection_m = setup_detection_model() + yield detection_m + del detection_m + + +@pytest.fixture(scope="session") +def texify_model(): + texify_m = setup_texify_model() + yield texify_m + del texify_m + + +@pytest.fixture(scope="session") +def recognition_model(): + ocr_m = setup_recognition_model() + yield ocr_m + del ocr_m + + +@pytest.fixture(scope="session") +def table_rec_model(): + table_rec_m = setup_table_rec_model() + yield table_rec_m + del table_rec_m @pytest.fixture(scope="session") diff --git a/tests/test_equation_processor.py b/tests/test_equation_processor.py new file mode 100644 index 00000000..870ba545 --- /dev/null +++ b/tests/test_equation_processor.py @@ -0,0 +1,12 @@ +from marker.v2.processors.equation import EquationProcessor + + +def test_equation_processor(pdf_document, texify_model): + processor = EquationProcessor(texify_model) + + pdf_document.pages = [pdf_document.pages[0]] + processor(pdf_document) + + for block in pdf_document.pages[0].children: + if block.block_type == "Equation": + assert block.latex is not None diff --git a/tests/test_structure.py b/tests/test_structure.py index 8e65ada6..836048a5 100644 --- a/tests/test_structure.py +++ b/tests/test_structure.py @@ -1,5 +1,4 @@ from marker.v2.builders.structure import StructureBuilder -from tests.utils import setup_pdf_document def test_structure_builder(pdf_document): diff --git a/tests/test_table_processor.py b/tests/test_table_processor.py new file mode 100644 index 00000000..7233c6e3 --- /dev/null +++ b/tests/test_table_processor.py @@ -0,0 +1,16 @@ +from tabled.schema import SpanTableCell + +from marker.v2.processors.table import TableProcessor + + +def test_table_processor(pdf_document, detection_model, recognition_model, table_rec_model): + processor = TableProcessor(detection_model, recognition_model, table_rec_model) + + pdf_document.pages = [pdf_document.pages[5]] + processor(pdf_document) + + for block in pdf_document.pages[0].children: + if block.block_type == "Table": + assert block.cells is not None + assert len(block.cells) > 0 + assert isinstance(block.cells[0], SpanTableCell) From aa66523f498b871a42267ba1cb7fa88667d61ed8 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 14 Nov 2024 15:07:13 -0500 Subject: [PATCH 5/5] Update structure --- marker/v2/schema/blocks/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/marker/v2/schema/blocks/base.py b/marker/v2/schema/blocks/base.py index efebd196..7e37da1f 100644 --- a/marker/v2/schema/blocks/base.py +++ b/marker/v2/schema/blocks/base.py @@ -25,6 +25,8 @@ def _id(self): return page_path def add_structure(self, block: Block): + self.polygon = self.polygon.merge([block.polygon]) + if self.structure is None: self.structure = [block._id] else: