diff --git a/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py b/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py index 1baf6a07f..59b122233 100644 --- a/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py +++ b/lib/sycamore/sycamore/transforms/table_structure/table_transformers.py @@ -144,7 +144,7 @@ def refine_rows(rows, tokens, score_threshold): """ if len(tokens) > 0: - rows = nms_by_containment(rows, tokens, overlap_threshold=0.5) + rows = nms_by_containment(rows, tokens, overlap_threshold=0.5, _early_exit_vertical=True) remove_objects_without_content(tokens, rows) else: rows = nms(rows, match_criteria="object2_overlap", match_threshold=0.5, keep_higher=True) @@ -171,9 +171,11 @@ def refine_columns(columns, tokens, score_threshold): return columns -def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5): +def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5, _early_exit_vertical=False): """ Non-maxima suppression (NMS) of objects based on shared containment of other objects. + + _early_exit_vertical: see `slot_into_containers` """ container_objects = sort_objects_by_score(container_objects) num_objects = len(container_objects) @@ -185,6 +187,7 @@ def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5 overlap_threshold=overlap_threshold, unique_assignment=True, forced_assignment=False, + _early_exit_vertical=_early_exit_vertical, ) for object2_num in range(1, num_objects): @@ -202,11 +205,21 @@ def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5 def slot_into_containers( - container_objects, package_objects, overlap_threshold=0.5, unique_assignment=True, forced_assignment=False + container_objects, + package_objects, + overlap_threshold=0.5, + unique_assignment=True, + forced_assignment=False, + _early_exit_vertical=False, # yes. see docstring. ): """ Slot a collection of objects into the container they occupy most (the container which holds the largest fraction of the object). + + _early_exit_vertical controls the dimension along which to sort + container objects for the purposes of optimizing the quadratic loop. + True -> sort by y-coord, False -> sort by x-coord. We only really + set this to True when dealing with rows. """ best_match_scores = [] @@ -216,6 +229,11 @@ def slot_into_containers( if len(container_objects) == 0 or len(package_objects) == 0: return container_assignments, package_assignments, best_match_scores + if _early_exit_vertical: + sorted_co = sorted(enumerate(container_objects), key=lambda x: x[1]["bbox"][1]) + else: + sorted_co = sorted(enumerate(container_objects), key=lambda x: x[1]["bbox"][0]) + match_scores = defaultdict(dict) for package_num, package in enumerate(package_objects): match_scores = [] @@ -223,13 +241,22 @@ def slot_into_containers( if package_rect.is_empty(): continue package_area = package_rect.area - for container_num, container in enumerate(container_objects): + for container_num, container in sorted_co: + # If the container starts after the package ends, break + if not _early_exit_vertical and container["bbox"][0] > package["bbox"][2]: + break + elif _early_exit_vertical and container["bbox"][1] > package["bbox"][3]: + break container_rect = BoundingBox(*container["bbox"]) intersect_area = container_rect.intersect(package_rect).area overlap_fraction = intersect_area / package_area match_scores.append({"container": container, "container_num": container_num, "score": overlap_fraction}) - sorted_match_scores = sort_objects_by_score(match_scores) + # Don't sort if you don't have to + if unique_assignment: + sorted_match_scores = [max(match_scores, key=lambda x: x["score"])] + else: + sorted_match_scores = sort_objects_by_score(match_scores) best_match_score = sorted_match_scores[0]