diff --git a/.gitignore b/.gitignore
index bc1ad2aa..6f036496 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ experiments
 test_data
 training
 wandb
+*.dat
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/marker/convert.py b/marker/convert.py
index 4c6c06a2..c74f4458 100644
--- a/marker/convert.py
+++ b/marker/convert.py
@@ -58,7 +58,7 @@ def convert_single_pdf(
     }
 
     if filetype == "other": # We can't process this file
-        return "", out_meta
+        return "", {}, out_meta
 
     # Get initial text blocks from the pdf
     doc = pdfium.PdfDocument(fname)
@@ -85,7 +85,7 @@
     out_meta["ocr_stats"] = ocr_stats
     if len([b for p in pages for b in p.blocks]) == 0:
         print(f"Could not extract any text blocks for {fname}")
-        return "", out_meta
+        return "", {}, out_meta
 
     surya_layout(doc, pages, layout_model, batch_multiplier=batch_multiplier)
     flush_cuda_memory()
diff --git a/marker/models.py b/marker/models.py
index 26bc0b19..3593006b 100644
--- a/marker/models.py
+++ b/marker/models.py
@@ -9,50 +9,69 @@ from surya.model.ordering.processor import load_processor as load_order_processor
 
 
-def setup_recognition_model(langs):
-    rec_model = load_recognition_model(langs=langs)
+def setup_recognition_model(langs, device=None, dtype=None):
+    if device:
+        rec_model = load_recognition_model(langs=langs, device=device, dtype=dtype)
+    else:
+        rec_model = load_recognition_model(langs=langs)
     rec_processor = load_recognition_processor()
     rec_model.processor = rec_processor
     return rec_model
 
 
-def setup_detection_model():
-    model = segformer.load_model()
+def setup_detection_model(device=None, dtype=None):
+    if device:
+        model = segformer.load_model(device=device, dtype=dtype)
+    else:
+        model = segformer.load_model()
+
     processor = segformer.load_processor()
     model.processor = processor
     return model
 
 
-def setup_texify_model():
-    texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
+def setup_texify_model(device=None, dtype=None):
+    if device:
+        texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
+    else:
+        texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
     texify_processor = load_texify_processor()
     texify_model.processor = texify_processor
     return texify_model
 
 
-def setup_layout_model():
-    model = segformer.load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
+def setup_layout_model(device=None, dtype=None):
+    if device:
+        model = segformer.load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT, device=device, dtype=dtype)
+    else:
+        model = segformer.load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
     processor = segformer.load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
     model.processor = processor
     return model
 
 
-def setup_order_model():
-    model = load_order_model()
+def setup_order_model(device=None, dtype=None):
+    if device:
+        model = load_order_model(device=device, dtype=dtype)
+    else:
+        model = load_order_model()
     processor = load_order_processor()
     model.processor = processor
     return model
 
 
-def load_all_models(langs=None):
+def load_all_models(langs=None, device=None, dtype=None, force_load_ocr=False):
+    if device is not None:
+        assert dtype is not None, "Must provide dtype if device is provided"
+
     # langs is optional list of languages to prune from recognition MoE model
-    detection = setup_detection_model()
-    layout = setup_layout_model()
-    order = setup_order_model()
-    edit = load_editing_model()
+    detection = setup_detection_model(device, dtype)
+    layout = setup_layout_model(device, dtype)
+    order = setup_order_model(device, dtype)
+    edit = load_editing_model(device, dtype)
 
     # Only load recognition model if we'll need it for all pdfs
-    ocr = setup_recognition_model(langs) if (settings.OCR_ENGINE == "surya" and settings.OCR_ALL_PAGES) else None
-    texify = setup_texify_model()
+    ocr = setup_recognition_model(langs, device, dtype) if ((settings.OCR_ENGINE == "surya" and settings.OCR_ALL_PAGES) or force_load_ocr) else None
+    texify = setup_texify_model(device, dtype)
 
     model_lst = [texify, layout, order, edit, detection, ocr]
     return model_lst
\ No newline at end of file
diff --git a/marker/postprocessors/editor.py b/marker/postprocessors/editor.py
index 43a4e48e..9b73bedc 100644
--- a/marker/postprocessors/editor.py
+++ b/marker/postprocessors/editor.py
@@ -16,14 +16,20 @@ def get_batch_size():
     return 6
 
 
-def load_editing_model():
+def load_editing_model(device=None, dtype=None):
     if not settings.ENABLE_EDITOR_MODEL:
         return None
 
-    model = T5ForTokenClassification.from_pretrained(
+    if device:
+        model = T5ForTokenClassification.from_pretrained(
         settings.EDITOR_MODEL_NAME,
-        torch_dtype=settings.MODEL_DTYPE,
-    ).to(settings.TORCH_DEVICE_MODEL)
+            torch_dtype=dtype,
+        ).to(device)
+    else:
+        model = T5ForTokenClassification.from_pretrained(
+            settings.EDITOR_MODEL_NAME,
+            torch_dtype=settings.MODEL_DTYPE,
+        ).to(settings.TORCH_DEVICE_MODEL)
     model.eval()
 
     model.config.label2id = {
diff --git a/poetry.lock b/poetry.lock
index bd76c2be..84be0da5 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3714,13 +3714,13 @@ streamlit = ">=0.63"
 
 [[package]]
 name = "surya-ocr"
-version = "0.4.4"
+version = "0.4.5"
 description = "OCR, layout, reading order, and line detection in 90+ languages"
 optional = false
 python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,!=3.8.*,>=3.9"
 files = [
-    {file = "surya_ocr-0.4.4-py3-none-any.whl", hash = "sha256:5ff52f11f0f13218566bcd402697bf84336a488cbc314ebc02ba234255df96ef"},
-    {file = "surya_ocr-0.4.4.tar.gz", hash = "sha256:00d6376186e9a68f366a2c149e233f3d68b6080df48065e756a97e739464fce9"},
+    {file = "surya_ocr-0.4.5-py3-none-any.whl", hash = "sha256:a670387307ffc2951d7a8b3a005dec6c80ababb6107ea7d0a097cb33f0c6f355"},
+    {file = "surya_ocr-0.4.5.tar.gz", hash = "sha256:1b59858c2caa476a6e27579f0747dafe62ad8fb866af7bdcdb19ee01b0d7915c"},
 ]
 
 [package.dependencies]
diff --git a/pyproject.toml b/pyproject.toml
index a3b7a436..c4fb46a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "marker-pdf"
-version = "0.2.5"
+version = "0.2.6"
 description = "Convert PDF to markdown with high speed and accuracy."
 authors = ["Vik Paruchuri "]
 readme = "README.md"