Skip to content

Commit

Permalink
More flexibility with model devices
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed May 16, 2024
1 parent 7edbece commit 7a6a82c
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ experiments
 test_data
 training
 wandb
+*.dat

 # Byte-compiled / optimized / DLL files
 __pycache__/
Expand Down
4 changes: 2 additions & 2 deletions marker/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def convert_single_pdf(
     }

     if filetype == "other": # We can't process this file
-        return "", out_meta
+        return "", {}, out_meta

     # Get initial text blocks from the pdf
     doc = pdfium.PdfDocument(fname)
Expand All @@ -85,7 +85,7 @@ def convert_single_pdf(
     out_meta["ocr_stats"] = ocr_stats
     if len([b for p in pages for b in p.blocks]) == 0:
         print(f"Could not extract any text blocks for {fname}")
-        return "", out_meta
+        return "", {}, out_meta

     surya_layout(doc, pages, layout_model, batch_multiplier=batch_multiplier)
     flush_cuda_memory()
Expand Down
53 changes: 36 additions & 17 deletions marker/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,69 @@
from surya.model.ordering.processor import load_processor as load_order_processor


-def setup_recognition_model(langs):
-    rec_model = load_recognition_model(langs=langs)
+def setup_recognition_model(langs, device, dtype):
+    if device:
+        rec_model = load_recognition_model(langs=langs, device=device, dtype=dtype)
+    else:
+        rec_model = load_recognition_model(langs=langs)
     rec_processor = load_recognition_processor()
     rec_model.processor = rec_processor
     return rec_model


-def setup_detection_model():
-    model = segformer.load_model()
+def setup_detection_model(device, dtype):
+    if device:
+        model = segformer.load_model(device=device, dtype=dtype)
+    else:
+        model = segformer.load_model()
+
     processor = segformer.load_processor()
     model.processor = processor
     return model


-def setup_texify_model():
-    texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
+def setup_texify_model(device, dtype):
+    if device:
+        texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
+    else:
+        texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
     texify_processor = load_texify_processor()
     texify_model.processor = texify_processor
     return texify_model


-def setup_layout_model():
-    model = segformer.load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
+def setup_layout_model(device, dtype):
+    if device:
+        model = segformer.load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT, device=device, dtype=dtype)
+    else:
+        model = segformer.load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
     processor = segformer.load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
     model.processor = processor
     return model


-def setup_order_model():
-    model = load_order_model()
+def setup_order_model(device, dtype):
+    if device:
+        model = load_order_model(device=device, dtype=dtype)
+    else:
+        model = load_order_model()
     processor = load_order_processor()
     model.processor = processor
     return model


-def load_all_models(langs=None):
+def load_all_models(langs=None, device=None, dtype=None, force_load_ocr=False):
+    if device is not None:
+        assert dtype is not None, "Must provide dtype if device is provided"
+
     # langs is optional list of languages to prune from recognition MoE model
-    detection = setup_detection_model()
-    layout = setup_layout_model()
-    order = setup_order_model()
-    edit = load_editing_model()
+    detection = setup_detection_model(device, dtype)
+    layout = setup_layout_model(device, dtype)
+    order = setup_order_model(device, dtype)
+    edit = load_editing_model(device, dtype)

     # Only load recognition model if we'll need it for all pdfs
-    ocr = setup_recognition_model(langs) if (settings.OCR_ENGINE == "surya" and settings.OCR_ALL_PAGES) else None
-    texify = setup_texify_model()
+    ocr = setup_recognition_model(langs, device, dtype) if ((settings.OCR_ENGINE == "surya" and settings.OCR_ALL_PAGES) or force_load_ocr) else None
+    texify = setup_texify_model(device, dtype)
     model_lst = [texify, layout, order, edit, detection, ocr]
     return model_lst
15 changes: 11 additions & 4 deletions marker/postprocessors/editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@ def get_batch_size():
return 6


-def load_editing_model():
+def load_editing_model(device, dtype):
     if not settings.ENABLE_EDITOR_MODEL:
         return None

-    model = T5ForTokenClassification.from_pretrained(
+    if device:
+        model = T5ForTokenClassification.from_pretrained(
         settings.EDITOR_MODEL_NAME,
-        torch_dtype=settings.MODEL_DTYPE,
-    ).to(settings.TORCH_DEVICE_MODEL)
+            torch_dtype=dtype,
+            device=device,
+        )
+    else:
+        model = T5ForTokenClassification.from_pretrained(
+            settings.EDITOR_MODEL_NAME,
+            torch_dtype=settings.MODEL_DTYPE,
+        ).to(settings.TORCH_DEVICE_MODEL)
     model.eval()

model.config.label2id = {
Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "marker-pdf"
-version = "0.2.5"
+version = "0.2.6"
 description = "Convert PDF to markdown with high speed and accuracy."
 authors = ["Vik Paruchuri <[email protected]>"]
 readme = "README.md"
Expand Down

0 comments on commit 7a6a82c

Please sign in to comment.