Commit
Merge remote-tracking branch 'origin/v2' into dev-mose/marker-v2
Showing 51 changed files with 494 additions and 148 deletions.
@@ -5,3 +5,6 @@
def flush_cuda_memory():
    if settings.TORCH_DEVICE_MODEL == "cuda":
        torch.cuda.empty_cache()
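The three lines this hunk adds are not visible above, so they are not reproduced here. Purely as a hedged sketch of what a device-aware flush helper can look like (an assumption, not the commit's actual content), an MPS-aware variant might read as follows, assuming torch.mps.empty_cache() exists in the installed PyTorch build:

import torch

from marker.settings import settings


def flush_device_memory():
    # Sketch only: release cached allocator memory on whichever accelerator is active.
    if settings.TORCH_DEVICE_MODEL == "cuda":
        torch.cuda.empty_cache()
    elif settings.TORCH_DEVICE_MODEL == "mps":
        torch.mps.empty_cache()  # assumption: available in recent PyTorch releases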
@@ -0,0 +1,60 @@
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS

from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.layout.model import load_model as load_layout_model
from surya.model.layout.processor import load_processor as load_layout_processor
from texify.model.model import load_model as load_texify_model
from texify.model.processor import load_processor as load_texify_processor
from marker.settings import settings
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor


def setup_table_rec_model(device=None, dtype=None):
    if device:
        table_model = load_table_model(device=device, dtype=dtype)
    else:
        table_model = load_table_model()
    table_model.processor = load_table_processor()
    return table_model


def setup_recognition_model(device=None, dtype=None):
    if device:
        rec_model = load_recognition_model(device=device, dtype=dtype)
    else:
        rec_model = load_recognition_model()
    rec_model.processor = load_recognition_processor()
    return rec_model


def setup_detection_model(device=None, dtype=None):
    if device:
        model = load_detection_model(device=device, dtype=dtype)
    else:
        model = load_detection_model()
    model.processor = load_detection_processor()
    return model


def setup_texify_model(device=None, dtype=None):
    if device:
        texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
    else:
        texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
    texify_model.processor = load_texify_processor()
    return texify_model


def setup_layout_model(device=None, dtype=None):
    if device:
        model = load_layout_model(device=device, dtype=dtype)
    else:
        model = load_layout_model()
    model.processor = load_layout_processor()
    return model
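For orientation, a minimal usage sketch of these helpers as a hypothetical call site (not part of this commit; the create_model_dict name and the dict keys are assumptions): load each model and processor once up front and reuse them across the pipeline.

def create_model_dict(device=None, dtype=None):
    # Hypothetical helper: build each model/processor pair once and hand them around as a dict.
    return {
        "layout_model": setup_layout_model(device, dtype),
        "texify_model": setup_texify_model(device, dtype),
        "recognition_model": setup_recognition_model(device, dtype),
        "table_rec_model": setup_table_rec_model(device, dtype),
        "detection_model": setup_detection_model(device, dtype),
    }


models = create_model_dict()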
@@ -1,6 +1,100 @@
from typing import Optional, List

from pydantic import BaseModel
from tqdm import tqdm

from marker.settings import settings
from marker.v2.processors import BaseProcessor
from marker.v2.schema.document import Document

from texify.inference import batch_inference


class EquationProcessor(BaseProcessor):
    block_type = "Equation"
    model_max_length = 384
    batch_size = None
    token_buffer = 256

    def __init__(self, texify_model, config: Optional[BaseModel] = None):
        super().__init__(config)

        self.texify_model = texify_model

    def __call__(self, document: Document):
        equation_data = []

        for page in document.pages:
            for block in page.children:
                if block.block_type != self.block_type:
                    continue
                image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.lowres_image.size)
                image = page.lowres_image.crop(image_poly.bbox).convert("RGB")
                raw_text = block.raw_text(document)
                token_count = self.get_total_texify_tokens(raw_text)

                equation_data.append({
                    "image": image,
                    "block_id": block._id,
                    "token_count": token_count
                })

        predictions = self.get_latex_batched(equation_data)
        for prediction, equation_d in zip(predictions, equation_data):
            conditions = [
                self.get_total_texify_tokens(prediction) < self.model_max_length,
                # Make sure we didn't get to the overall token max, indicates run-on
                len(prediction) > equation_d["token_count"] * .4,
                len(prediction.strip()) > 0
            ]
            if not all(conditions):
                continue

            block = document.get_block_by_id(equation_d["block_id"])
            block.latex = prediction

    def get_batch_size(self):
        if self.batch_size is not None:
            return self.batch_size
        elif settings.TORCH_DEVICE_MODEL == "cuda":
            return 6
        elif settings.TORCH_DEVICE_MODEL == "mps":
            return 6
        return 2

    def get_latex_batched(self, equation_data: List[dict]):
        predictions = [""] * len(equation_data)
        batch_size = self.get_batch_size()

        for i in tqdm(range(0, len(equation_data), batch_size), desc="Recognizing equations"):
            # Dynamically set max length to save inference time
            min_idx = i
            max_idx = min(min_idx + batch_size, len(equation_data))

            batch_equations = equation_data[min_idx:max_idx]
            max_length = max([eq["token_count"] for eq in batch_equations])
            max_length = min(max_length, self.model_max_length)
            max_length += self.token_buffer

            batch_images = [eq["image"] for eq in batch_equations]

            model_output = batch_inference(
                batch_images,
                self.texify_model,
                self.texify_model.processor,
                max_tokens=max_length
            )

            for j, output in enumerate(model_output):
                token_count = self.get_total_texify_tokens(output)
                if token_count >= max_length - 1:
                    output = ""

                image_idx = i + j
                predictions[image_idx] = output
        return predictions

    def get_total_texify_tokens(self, text):
        tokenizer = self.texify_model.processor.tokenizer
        tokens = tokenizer(text)
        return len(tokens["input_ids"])
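A brief usage sketch for the processor, hedged: document is assumed to be a marker.v2 Document produced by the pipeline's builders, which are not shown in this diff, and setup_texify_model comes from the model-setup module above.

texify_model = setup_texify_model()           # loads the texify checkpoint and attaches its processor
processor = EquationProcessor(texify_model)   # optional pydantic config omitted
processor(document)                           # writes block.latex on each Equation block that passes the checks

for page in document.pages:
    for block in page.children:
        if block.block_type == "Equation" and getattr(block, "latex", None):
            print(block.latex)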