Merge remote-tracking branch 'origin/v2' into dev-mose/marker-v2
iammosespaulr committed Nov 14, 2024
2 parents c8e9850 + 5cfd486 commit 55e097a
Showing 51 changed files with 494 additions and 148 deletions.
3 changes: 3 additions & 0 deletions marker/utils.py
@@ -5,3 +5,6 @@
def flush_cuda_memory():
if settings.TORCH_DEVICE_MODEL == "cuda":
torch.cuda.empty_cache()



8 changes: 4 additions & 4 deletions marker/v2/builders/__init__.py
@@ -2,12 +2,12 @@

from pydantic import BaseModel

from marker.v2.util import assign_config


class BaseBuilder:
def __init__(self, config: Optional[BaseModel] = None):
if config:
for k in config.model_fields:
setattr(self, k, config[k])
def __init__(self, config: Optional[BaseModel | dict] = None):
assign_config(self, config)

def __call__(self, data, *args, **kwargs):
raise NotImplementedError
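
This hunk swaps the old inline loop, which copied Pydantic fields onto the builder with config[k] (plain BaseModel instances do not support item access), for a shared assign_config helper from marker.v2.util; the same helper is reused by BaseConverter and BaseProcessor below. The helper itself is not shown in this view, so the following is only a sketch of what it plausibly does, inferred from the loop it replaces and the new BaseModel | dict signature:

from typing import Optional

from pydantic import BaseModel


def assign_config(obj, config: Optional[BaseModel | dict]):
    # Hypothetical reconstruction -- the real marker.v2.util.assign_config is not shown in this diff.
    if config is None:
        return
    if isinstance(config, BaseModel):
        # Mirror the old inline loop: copy every declared field onto the instance.
        for k in config.model_fields:
            setattr(obj, k, getattr(config, k))
    else:
        for k, v in config.items():
            setattr(obj, k, v)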
4 changes: 2 additions & 2 deletions marker/v2/builders/document.py
@@ -14,10 +14,10 @@ def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder):
return document

def build_document(self, provider: PdfProvider):
if provider.config.page_range is None:
if provider.page_range is None:
page_range = range(len(provider))
else:
page_range = provider.config.page_range
page_range = provider.page_range
assert max(page_range) < len(provider) and min(page_range) >= 0, "Invalid page range"

initial_pages = [
11 changes: 6 additions & 5 deletions marker/v2/builders/layout.py
@@ -13,6 +13,8 @@


class LayoutBuilder(BaseBuilder):
batch_size = None

def __init__(self, layout_model, config=None):
self.layout_model = layout_model

@@ -23,10 +25,9 @@ def __call__(self, document: Document, provider: PdfProvider):
self.add_blocks_to_pages(document.pages, layout_results)
self.merge_blocks(document.pages, provider.page_lines)

@classmethod
def get_batch_size(cls):
if settings.LAYOUT_BATCH_SIZE is not None:
return settings.LAYOUT_BATCH_SIZE
def get_batch_size(self):
if self.batch_size is not None:
return self.batch_size
elif settings.TORCH_DEVICE_MODEL == "cuda":
return 6
return 6
@@ -37,7 +38,7 @@ def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
[p.lowres_image for p in pages],
self.layout_model,
processor,
batch_size=int(LayoutBuilder.get_batch_size())
batch_size=int(self.get_batch_size())
)
return layout_results

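Because assign_config can overwrite the new batch_size class attribute, the layout batch size is now tuned per builder instance through its config rather than through settings.LAYOUT_BATCH_SIZE. A usage sketch, assuming LayoutBuilder forwards its config argument to BaseBuilder.__init__ (the rest of __init__ is collapsed in this view):

# Assumed usage: override the device-default batch size via a dict config.
layout_builder = LayoutBuilder(layout_model, config={"batch_size": 4})
layout_builder.get_batch_size()  # returns 4 instead of the CUDA default of 6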
10 changes: 5 additions & 5 deletions marker/v2/converters/__init__.py
@@ -2,12 +2,12 @@

from pydantic import BaseModel

from marker.v2.util import assign_config


class BaseConverter:
def __init__(self, config: Optional[BaseModel] = None):
if config:
for k in config.model_fields:
setattr(self, k, config[k])
def __init__(self, config: Optional[BaseModel | dict] = None):
assign_config(self, config)

def __call__(self):
def __call__(self, *args, **kwargs):
raise NotImplementedError
17 changes: 9 additions & 8 deletions marker/v2/converters/pdf.py
@@ -1,5 +1,6 @@
from typing import List
from typing import List, Optional

from pydantic import BaseModel
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor

@@ -8,20 +9,20 @@
from marker.v2.builders.structure import StructureBuilder
from marker.v2.converters import BaseConverter
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema.config.pdf import PdfProviderConfig


class PdfConverter(BaseConverter):
filepath: str
page_range: List[int] | None = None

def __call__(self):
pdf_provider = PdfProvider(self.config.filepath, PdfProviderConfig())
def __init__(self, config: Optional[BaseModel] = None):
super().__init__(config)

layout_model = load_model()
layout_model.processor = load_processor()
layout_builder = LayoutBuilder(layout_model)
self.layout_model = layout_model

def __call__(self, filepath: str, page_range: List[int] | None = None):
pdf_provider = PdfProvider(filepath, {"page_range": page_range})

layout_builder = LayoutBuilder(self.layout_model)
document = DocumentBuilder()(pdf_provider, layout_builder)
StructureBuilder()(document)
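
PdfConverter now loads the layout model once in __init__ and takes the file path and page range per call, so a single converter instance can be reused across documents. A usage sketch (the file path is hypothetical, and what __call__ returns is not visible in this hunk):

# Assumed usage of the new call signature.
converter = PdfConverter()
converter("example.pdf", page_range=[0, 1, 2])  # builds and structures the document internally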

60 changes: 60 additions & 0 deletions marker/v2/models.py
@@ -0,0 +1,60 @@
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS


from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.layout.model import load_model as load_layout_model
from surya.model.layout.processor import load_processor as load_layout_processor
from texify.model.model import load_model as load_texify_model
from texify.model.processor import load_processor as load_texify_processor
from marker.settings import settings
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor


def setup_table_rec_model(device=None, dtype=None):
if device:
table_model = load_table_model(device=device, dtype=dtype)
else:
table_model = load_table_model()
table_model.processor = load_table_processor()
return table_model


def setup_recognition_model(device=None, dtype=None):
if device:
rec_model = load_recognition_model(device=device, dtype=dtype)
else:
rec_model = load_recognition_model()
rec_model.processor = load_recognition_processor()
return rec_model


def setup_detection_model(device=None, dtype=None):
if device:
model = load_detection_model(device=device, dtype=dtype)
else:
model = load_detection_model()
model.processor = load_detection_processor()
return model


def setup_texify_model(device=None, dtype=None):
if device:
texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
else:
texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
texify_model.processor = load_texify_processor()
return texify_model


def setup_layout_model(device=None, dtype=None):
if device:
model = load_layout_model(device=device, dtype=dtype)
else:
model = load_layout_model()
model.processor = load_layout_processor()
return model
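
Each setup_* helper pairs a model with its processor and only forwards device and dtype when a device is passed explicitly; otherwise it falls back to the library defaults (or, for texify, to the marker settings). A usage sketch, where the cuda/float16 combination is an assumed choice rather than anything this commit prescribes:

import torch

# Assumed usage: rely on defaults for layout, pass an explicit device/dtype pair for table recognition.
layout_model = setup_layout_model()
table_model = setup_table_rec_model(device="cuda", dtype=torch.float16)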
14 changes: 10 additions & 4 deletions marker/v2/processors/__init__.py
@@ -2,9 +2,15 @@

from pydantic import BaseModel

from marker.v2.schema.document import Document
from marker.v2.util import assign_config


class BaseProcessor:
def __init__(self, config: Optional[BaseModel] = None):
if config:
for k in config.model_fields:
setattr(self, k, config[k])
block_type: str | None = None # What block type this processor is responsible for

def __init__(self, config: Optional[BaseModel | dict] = None):
assign_config(self, config)

def __call__(self, document: Document, *args, **kwargs):
raise NotImplementedError
94 changes: 94 additions & 0 deletions marker/v2/processors/equation.py
@@ -1,6 +1,100 @@
from typing import Optional, List

from pydantic import BaseModel
from tqdm import tqdm

from marker.settings import settings
from marker.v2.processors import BaseProcessor
from marker.v2.schema.document import Document

from texify.inference import batch_inference


class EquationProcessor(BaseProcessor):
block_type = "Equation"
model_max_length = 384
batch_size = None
token_buffer = 256

def __init__(self, texify_model, config: Optional[BaseModel] = None):
super().__init__(config)

self.texify_model = texify_model

def __call__(self, document: Document):
equation_data = []

for page in document.pages:
for block in page.children:
if block.block_type != self.block_type:
continue
image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.lowres_image.size)
image = page.lowres_image.crop(image_poly.bbox).convert("RGB")
raw_text = block.raw_text(document)
token_count = self.get_total_texify_tokens(raw_text)

equation_data.append({
"image": image,
"block_id": block._id,
"token_count": token_count
})

predictions = self.get_latex_batched(equation_data)
for prediction, equation_d in zip(predictions, equation_data):
conditions = [
self.get_total_texify_tokens(prediction) < self.model_max_length,
# Make sure we didn't get to the overall token max, indicates run-on
len(prediction) > equation_d["token_count"] * .4,
len(prediction.strip()) > 0
]
if not all(conditions):
continue

block = document.get_block_by_id(equation_d["block_id"])
block.latex = prediction

def get_batch_size(self):
if self.batch_size is not None:
return self.batch_size
elif settings.TORCH_DEVICE_MODEL == "cuda":
return 6
elif settings.TORCH_DEVICE_MODEL == "mps":
return 6
return 2

def get_latex_batched(self, equation_data: List[dict]):
predictions = [""] * len(equation_data)
batch_size = self.get_batch_size()

for i in tqdm(range(0, len(equation_data), batch_size), desc="Recognizing equations"):
# Dynamically set max length to save inference time
min_idx = i
max_idx = min(min_idx + batch_size, len(equation_data))

batch_equations = equation_data[min_idx:max_idx]
max_length = max([eq["token_count"] for eq in batch_equations])
max_length = min(max_length, self.model_max_length)
max_length += self.token_buffer

batch_images = [eq["image"] for eq in batch_equations]

model_output = batch_inference(
batch_images,
self.texify_model,
self.texify_model.processor,
max_tokens=max_length
)

for j, output in enumerate(model_output):
token_count = self.get_total_texify_tokens(output)
if token_count >= max_length - 1:
output = ""

image_idx = i + j
predictions[image_idx] = output
return predictions

def get_total_texify_tokens(self, text):
tokenizer = self.texify_model.processor.tokenizer
tokens = tokenizer(text)
return len(tokens["input_ids"])
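
Like the builders, EquationProcessor keeps its tunables (model_max_length, batch_size, token_buffer) as class attributes that assign_config can override, and it receives the texify model explicitly instead of loading it. A usage sketch, assuming the model comes from setup_texify_model in marker/v2/models.py and the document from DocumentBuilder:

# Assumed wiring of the equation processor into the v2 pipeline.
texify_model = setup_texify_model()
equation_processor = EquationProcessor(texify_model, config={"batch_size": 4})
equation_processor(document)  # fills block.latex for Equation blocks that pass the sanity checks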
(Diff for the remaining changed files is not shown here.)