Skip to content

Commit

Permalink
Merge pull request #311 from VikParuchuri/dev
Browse files — browse the repository at this point in the history
Fix layout bugs
  • Loading branch information
VikParuchuri authored Oct 22, 2024
2 parents 31f6ee6 + 6463bdf commit 6bee852
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 131 deletions.
8 changes: 5 additions & 3 deletions marker/cleaners/headings.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,15 @@ def bucket_headings(line_heights, num_levels=settings.HEADING_LEVEL_COUNT):
data_labels = np.concatenate([data, labels.reshape(-1, 1)], axis=1)
data_labels = np.sort(data_labels, axis=0)

cluster_means = {label: np.mean(data_labels[data_labels[:, 1] == label, 0]) for label in np.unique(labels)}
cluster_means = {int(label): float(np.mean(data_labels[data_labels[:, 1] == label, 0])) for label in np.unique(labels)}
label_max = None
label_min = None
heading_ranges = []
prev_cluster = None
for row in data_labels:
value, label = row
value = float(value)
label = int(label)
if prev_cluster is not None and label != prev_cluster:
prev_cluster_mean = cluster_means[prev_cluster]
cluster_mean = cluster_means[label]
Expand All @@ -93,7 +95,7 @@ def bucket_headings(line_heights, num_levels=settings.HEADING_LEVEL_COUNT):
if label_min is not None:
heading_ranges.append((label_min, label_max))

heading_ranges = sorted(heading_ranges, key=lambda x: x[0], reverse=True)
heading_ranges = sorted(heading_ranges, reverse=True)

return heading_ranges

Expand All @@ -114,7 +116,7 @@ def infer_heading_levels(pages: List[Page], height_tol=.99):
if block.block_type not in ["Title", "Section-header"]:
continue

block_heights = [l.height for l in block.lines] # Account for rotation
block_heights = [l.height for l in block.lines]
if len(block_heights) > 0:
avg_height = sum(block_heights) / len(block_heights)
for idx, (min_height, max_height) in enumerate(heading_ranges):
Expand Down
23 changes: 11 additions & 12 deletions marker/convert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import warnings

from marker.pdf.images import render_image

warnings.filterwarnings("ignore", category=UserWarning) # Filter torch pytree user warnings

import os
Expand Down Expand Up @@ -82,12 +85,16 @@ def convert_single_pdf(
for page_idx in range(start_page):
doc.del_page(0)

max_len = min(len(pages), len(doc))
lowres_images = [render_image(doc[pnum], dpi=settings.SURYA_DETECTOR_DPI) for pnum in range(max_len)]

# Unpack models from list
texify_model, layout_model, order_model, detection_model, ocr_model, table_rec_model = model_lst

# Identify text lines on pages
surya_detection(doc, pages, detection_model, batch_multiplier=batch_multiplier)
flush_cuda_memory()
# Identify text lines, layout, reading order
surya_detection(lowres_images, pages, detection_model, batch_multiplier=batch_multiplier)
surya_layout(lowres_images, pages, layout_model, batch_multiplier=batch_multiplier)
surya_order(lowres_images, pages, order_model, batch_multiplier=batch_multiplier)

# OCR pages as needed
pages, ocr_stats = run_ocr(doc, pages, langs, ocr_model, batch_multiplier=batch_multiplier, ocr_all_pages=ocr_all_pages)
Expand All @@ -98,21 +105,13 @@ def convert_single_pdf(
print(f"Could not extract any text blocks for {fname}")
return "", {}, out_meta

surya_layout(doc, pages, layout_model, batch_multiplier=batch_multiplier)
flush_cuda_memory()

# Find headers and footers
bad_span_ids = filter_header_footer(pages)
out_meta["block_stats"] = {"header_footer": len(bad_span_ids)}

# Add block types in
# Add block types from layout and sort from reading order
annotate_block_types(pages)

# Find reading order for blocks
# Sort blocks by reading order
surya_order(doc, pages, order_model, batch_multiplier=batch_multiplier)
sort_blocks_in_reading_order(pages)
flush_cuda_memory()

# Dump debug data if flags are set
draw_page_debug_images(fname, pages)
Expand Down
35 changes: 22 additions & 13 deletions marker/layout/layout.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import defaultdict
from collections import defaultdict, Counter
from typing import List

from surya.layout import batch_layout_detection
Expand All @@ -18,8 +18,7 @@ def get_batch_size():
return 6


def surya_layout(doc, pages: List[Page], layout_model, batch_multiplier=1):
images = [render_image(doc[pnum], dpi=settings.SURYA_LAYOUT_DPI) for pnum in range(len(pages))]
def surya_layout(images: list, pages: List[Page], layout_model, batch_multiplier=1):
text_detection_results = [p.text_lines for p in pages]

processor = layout_model.processor
Expand Down Expand Up @@ -54,7 +53,7 @@ def annotate_block_types(pages: List[Page]):
for i, block in enumerate(page.blocks):
if block.block_type is not None:
continue
min_dist = 0
min_dist = None
min_dist_idx = None
for j, block2 in enumerate(page.blocks):
if j == i or block2.block_type is None:
Expand All @@ -64,8 +63,8 @@ def annotate_block_types(pages: List[Page]):
min_dist = dist
min_dist_idx = j
for line in block2.lines:
dist = block2.distance(line.bbox)
if min_dist_idx is None or dist < min_dist:
dist = block.distance(line.bbox)
if dist < min_dist:
min_dist = dist
min_dist_idx = j

Expand All @@ -76,29 +75,39 @@ def annotate_block_types(pages: List[Page]):
if block.block_type is None:
block.block_type = "Text"

def get_layout_label(block_labels: List[str]):
counter = Counter(block_labels)
return counter.most_common(1)[0][0]

def generate_block(block, block_labels):
block.bbox = bbox_from_lines(block.lines)
block.block_type = get_layout_label(block_labels)
return block

# Merge blocks together, preserving pdf order
curr_layout_idx = None
curr_layout_block = None
curr_block_labels = []
new_blocks = []
for i in range(len(page.blocks)):
if i not in max_intersections:
if i not in max_intersections or max_intersections[i][0] == 0:
if curr_layout_block is not None:
curr_layout_block.bbox = bbox_from_lines(curr_layout_block.lines)
new_blocks.append(curr_layout_block)
new_blocks.append(generate_block(curr_layout_block, curr_block_labels))
curr_layout_block = None
curr_layout_idx = None
curr_block_labels = []
new_blocks.append(page.blocks[i])
elif max_intersections[i][1] != curr_layout_idx:
if curr_layout_block is not None:
curr_layout_block.bbox = bbox_from_lines(curr_layout_block.lines)
new_blocks.append(curr_layout_block)
new_blocks.append(generate_block(curr_layout_block, curr_block_labels))
curr_layout_block = page.blocks[i].copy()
curr_layout_idx = max_intersections[i][1]
curr_block_labels = [page.blocks[i].block_type]
else:
curr_layout_block.lines.extend(page.blocks[i].lines)
curr_block_labels.append(page.blocks[i].block_type)

if curr_layout_block is not None:
curr_layout_block.bbox = bbox_from_lines(curr_layout_block.lines)
new_blocks.append(curr_layout_block)
new_blocks.append(generate_block(curr_layout_block, curr_block_labels))

page.blocks = new_blocks
4 changes: 1 addition & 3 deletions marker/layout/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def get_batch_size():
return 6


def surya_order(doc, pages: List[Page], order_model, batch_multiplier=1):
images = [render_image(doc[pnum], dpi=settings.SURYA_ORDER_DPI) for pnum in range(len(pages))]

def surya_order(images: list, pages: List[Page], order_model, batch_multiplier=1):
# Get bboxes for all pages
bboxes = []
for page in pages:
Expand Down
4 changes: 1 addition & 3 deletions marker/ocr/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ def get_batch_size():
return 4


def surya_detection(doc: PdfDocument, pages: List[Page], det_model, batch_multiplier=1):
def surya_detection(images: list, pages: List[Page], det_model, batch_multiplier=1):
processor = det_model.processor
max_len = min(len(pages), len(doc))
images = [render_image(doc[pnum], dpi=settings.SURYA_DETECTOR_DPI) for pnum in range(max_len)]

predictions = batch_text_detection(images, det_model, processor, batch_size=int(get_batch_size() * batch_multiplier))
for (page, pred) in zip(pages, predictions):
Expand Down
Loading

0 comments on commit 6bee852

Please sign in to comment.