Skip to content

Commit

Permalink
add optimization from training development
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Lindeman <[email protected]>
  • Loading branch information
HenryL27 committed Oct 8, 2024
1 parent c7f3ebb commit 22d43df
Showing 1 changed file with 32 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def refine_rows(rows, tokens, score_threshold):
"""

if len(tokens) > 0:
rows = nms_by_containment(rows, tokens, overlap_threshold=0.5)
rows = nms_by_containment(rows, tokens, overlap_threshold=0.5, _early_exit_vertical=True)
remove_objects_without_content(tokens, rows)
else:
rows = nms(rows, match_criteria="object2_overlap", match_threshold=0.5, keep_higher=True)
Expand All @@ -171,9 +171,11 @@ def refine_columns(columns, tokens, score_threshold):
return columns


def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5):
def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5, _early_exit_vertical=False):
"""
Non-maxima suppression (NMS) of objects based on shared containment of other objects.
_early_exit_vertical: see `slot_into_containers`
"""
container_objects = sort_objects_by_score(container_objects)
num_objects = len(container_objects)
Expand All @@ -185,6 +187,7 @@ def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5
overlap_threshold=overlap_threshold,
unique_assignment=True,
forced_assignment=False,
_early_exit_vertical=_early_exit_vertical,
)

for object2_num in range(1, num_objects):
Expand All @@ -202,11 +205,21 @@ def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5


def slot_into_containers(
container_objects, package_objects, overlap_threshold=0.5, unique_assignment=True, forced_assignment=False
container_objects,
package_objects,
overlap_threshold=0.5,
unique_assignment=True,
forced_assignment=False,
_early_exit_vertical=False, # yes. see docstring.
):
"""
Slot a collection of objects into the container they occupy most
(the container which holds the largest fraction of the object).
_early_exit_vertical controls the dimension along which to sort
container objects for the purposes of optimizing the quadratic loop.
True -> sort by y-coord, False -> sort by x-coord. We only really
set this to True when dealing with rows.
"""
best_match_scores = []

Expand All @@ -216,20 +229,34 @@ def slot_into_containers(
if len(container_objects) == 0 or len(package_objects) == 0:
return container_assignments, package_assignments, best_match_scores

if _early_exit_vertical:
sorted_co = sorted(enumerate(container_objects), key=lambda x: x[1]["bbox"][1])
else:
sorted_co = sorted(enumerate(container_objects), key=lambda x: x[1]["bbox"][0])

match_scores = defaultdict(dict)
for package_num, package in enumerate(package_objects):
match_scores = []
package_rect = BoundingBox(*package["bbox"])
if package_rect.is_empty():
continue
package_area = package_rect.area
for container_num, container in enumerate(container_objects):
for container_num, container in sorted_co:
# If the container starts after the package ends, break
if not _early_exit_vertical and container["bbox"][0] > package["bbox"][2]:
break
elif _early_exit_vertical and container["bbox"][1] > package["bbox"][3]:
break
container_rect = BoundingBox(*container["bbox"])
intersect_area = container_rect.intersect(package_rect).area
overlap_fraction = intersect_area / package_area
match_scores.append({"container": container, "container_num": container_num, "score": overlap_fraction})

sorted_match_scores = sort_objects_by_score(match_scores)
# Don't sort if you don't have to
if unique_assignment:
sorted_match_scores = [max(match_scores, key=lambda x: x["score"])]
else:
sorted_match_scores = sort_objects_by_score(match_scores)

best_match_score = sorted_match_scores[0]

Expand Down

0 comments on commit 22d43df

Please sign in to comment.