Skip to content

Commit

Permalink
handle nms-postprocess in edge cases (#370)
Browse files Browse the repository at this point in the history
* handle the case when only 1 prediction is selected after NMS postprocess

* update

* handle postprocess when there are no predictions

* fix class agnostic nms

* handle the case when a single prediction is left after postprocess
  • Loading branch information
fcakyon authored Feb 12, 2022
1 parent fb0c9ba commit e01dad4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 171 deletions.
190 changes: 20 additions & 170 deletions sahi/postprocess/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@


def batched_nms(predictions: torch.tensor, match_metric: str = "IOU", match_threshold: float = 0.5):
"""
Apply non-maximum suppression to avoid detecting too many
overlapping bounding boxes for a given object.
Args:
predictions: (tensor) The location preds for the image
along with the class predscores, Shape: [num_boxes,5].
match_metric: (str) IOU or IOS
match_threshold: (float) The overlap thresh for
match metric.
Returns:
A list of filtered indexes, Shape: [ ,]
"""
scores = predictions[:, 4].squeeze()
category_ids = predictions[:, 5].squeeze()
keep_mask = torch.zeros_like(category_ids, dtype=torch.bool)
Expand All @@ -21,7 +33,9 @@ def batched_nms(predictions: torch.tensor, match_metric: str = "IOU", match_thre
curr_keep_indices = nms(predictions[curr_indices], match_metric, match_threshold)
keep_mask[curr_indices[curr_keep_indices]] = True
keep_indices = torch.where(keep_mask)[0]
return keep_indices[scores[keep_indices].sort(descending=True)[1]]
# sort selected indices by their scores
keep_indices = keep_indices[scores[keep_indices].sort(descending=True)[1]].tolist()
return keep_indices


def nms(
Expand Down Expand Up @@ -130,147 +144,6 @@ def nms(
return keep


def batched_nmm_torch0(
    object_predictions_as_tensor: torch.tensor,
    object_prediction_list: ObjectPredictionList,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """
    Run non-maximum merging independently per category (class-aware NMM).

    Args:
        object_predictions_as_tensor: (tensor) prediction rows; column 4 is the
            confidence score and column 5 the category id.
        object_prediction_list: (ObjectPredictionList) prediction objects to be merged.
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for match metric.
    Returns:
        Tuple of (kept indices sorted by descending score, updated prediction list).
    """
    labels = object_predictions_as_tensor[:, 5].squeeze()
    selected = torch.zeros_like(labels, dtype=torch.bool)
    for label in torch.unique(labels):
        # indices of all predictions belonging to the current category
        members = torch.where(labels == label)[0]
        member_predictions = object_prediction_list[members]
        kept_local, member_predictions = nmm_torch0(
            object_predictions_as_tensor[members], member_predictions, match_metric, match_threshold
        )
        # map category-local keep indices back to global positions
        selected[members[kept_local]] = True
        # persist merged predictions back into the shared list
        object_prediction_list[members] = member_predictions
    keep_indices = torch.where(selected)[0]
    # order surviving predictions by confidence, highest first
    scores = object_prediction_list.totensor()[keep_indices, 4].squeeze()
    return keep_indices[scores.sort(descending=True)[1]], object_prediction_list


def nmm_torch0(
    object_predictions_as_tensor: torch.tensor,
    object_prediction_list: ObjectPredictionList,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """
    Apply non-maximum merging to avoid detecting too many
    overlapping bounding boxes for a given object.
    Instead of discarding suppressed boxes (as NMS does), each box that
    overlaps a kept box above the threshold is merged into it via
    merge_object_prediction_pair.
    Args:
        object_predictions_as_tensor: (tensor) The location preds for the image
            along with the class pred scores, rows of
            [x1, y1, x2, y2, score, ...] (only the first 5 columns are read here).
        object_prediction_list: (ObjectPredictionList) Object prediction objects
            to be merged; mutated in place at the kept indices.
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for
            match metric.
    Returns:
        Tuple of (list of kept indexes, updated object_prediction_list with
        overlapping predictions merged into the kept ones).
    Raises:
        ValueError: if match_metric is neither "IOU" nor "IOS".
    """

    # we extract coordinates for every
    # prediction box present in P
    x1 = object_predictions_as_tensor[:, 0]
    y1 = object_predictions_as_tensor[:, 1]
    x2 = object_predictions_as_tensor[:, 2]
    y2 = object_predictions_as_tensor[:, 3]

    # we extract the confidence scores as well
    scores = object_predictions_as_tensor[:, 4]

    # calculate area of every block in P
    areas = (x2 - x1) * (y2 - y1)

    # sort the prediction boxes in P
    # according to their confidence scores (ascending; best is last)
    order = scores.argsort()

    # initialise an empty list for
    # filtered prediction boxes
    keep = []

    while len(order) > 0:
        # extract the index of the
        # prediction with highest score
        # we call this prediction S
        idx = order[-1]

        # push S in filtered predictions list
        # (.tolist() on a 0-dim tensor yields a plain Python scalar)
        keep.append(idx.tolist())

        # remove S from P
        order = order[:-1]

        # sanity check: nothing left to compare S against
        if len(order) == 0:
            break

        # select coordinates of BBoxes according to
        # the indices in order
        xx1 = torch.index_select(x1, dim=0, index=order)
        xx2 = torch.index_select(x2, dim=0, index=order)
        yy1 = torch.index_select(y1, dim=0, index=order)
        yy2 = torch.index_select(y2, dim=0, index=order)

        # find the coordinates of the intersection boxes
        xx1 = torch.max(xx1, x1[idx])
        yy1 = torch.max(yy1, y1[idx])
        xx2 = torch.min(xx2, x2[idx])
        yy2 = torch.min(yy2, y2[idx])

        # find height and width of the intersection boxes
        w = xx2 - xx1
        h = yy2 - yy1

        # take max with 0.0 to avoid negative w and h
        # due to non-overlapping boxes
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)

        # find the intersection area
        inter = w * h

        # find the areas of BBoxes according the indices in order
        rem_areas = torch.index_select(areas, dim=0, index=order)

        if match_metric == "IOU":
            # find the union of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            union = (rem_areas - inter) + areas[idx]
            # find the IoU of every prediction in P with S
            match_metric_value = inter / union

        elif match_metric == "IOS":
            # find the smaller area of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            smaller = torch.min(rem_areas, areas[idx])
            # find the IoS of every prediction in P with S
            match_metric_value = inter / smaller
        else:
            raise ValueError()

        # keep the boxes with IoU/IoS less than thresh_iou
        mask = match_metric_value < match_threshold
        # boxes that matched S (overlap >= threshold), in descending-score order
        matched_box_indices = order[(mask == False).nonzero().flatten()].flip(dims=(0,))
        unmatched_indices = order[(mask == True).nonzero().flatten()]

        # merge matching predictions into S, one pair at a time
        for matched_box_ind in matched_box_indices:
            object_prediction_list[idx] = merge_object_prediction_pair(
                object_prediction_list[idx].tolist(), object_prediction_list[matched_box_ind].tolist()
            )

        # update box pool: keep only unmatched boxes, re-sorted by score
        # NOTE(review): scores are not recomputed after merging — merged boxes keep
        # their original sort position; confirm this is intended
        order = unmatched_indices[scores[unmatched_indices].argsort()]
    return keep, object_prediction_list


def batched_greedy_nmm(
object_predictions_as_tensor: torch.tensor,
match_metric: str = "IOU",
Expand Down Expand Up @@ -608,34 +481,9 @@ def __call__(
object_predictions_as_torch, match_threshold=self.match_threshold, match_metric=self.match_metric
)

selected_object_predictions = object_prediction_list[keep.tolist()].tolist()

return selected_object_predictions


class NMMPostprocess0(PostprocessPredictions):
    """Postprocess that merges overlapping predictions via non-maximum merging."""

    def __call__(
        self,
        object_predictions: List[ObjectPrediction],
    ):
        """
        Merge overlapping predictions and return the surviving ones.

        Args:
            object_predictions: list of ObjectPrediction to be postprocessed.
        Returns:
            List of ObjectPrediction kept (and merged) after NMM.
        """
        object_prediction_list = ObjectPredictionList(object_predictions)
        object_predictions_as_torch = object_prediction_list.totensor()
        if self.class_agnostic:
            keep, object_prediction_list = nmm_torch0(
                object_predictions_as_torch,
                object_prediction_list,
                match_threshold=self.match_threshold,
                match_metric=self.match_metric,
            )
        else:
            keep, object_prediction_list = batched_nmm_torch0(
                object_predictions_as_torch,
                object_prediction_list,
                match_threshold=self.match_threshold,
                match_metric=self.match_metric,
            )

        # fix: removed a dead duplicate assignment that indexed with keep.tolist();
        # `keep` is already usable as an index here
        selected_object_predictions = object_prediction_list[keep].tolist()
        # when a single prediction survives, tolist() appears to return the bare
        # object instead of a list — wrap it so callers always get a list
        if not isinstance(selected_object_predictions, list):
            selected_object_predictions = [selected_object_predictions]

        return selected_object_predictions

Expand Down Expand Up @@ -744,5 +592,7 @@ def __call__(
)

selected_object_predictions = object_prediction_list[keep].tolist()
if not isinstance(selected_object_predictions, list):
selected_object_predictions = [selected_object_predictions]

return selected_object_predictions
3 changes: 2 additions & 1 deletion sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ def get_sliced_prediction(
)

# merge matching predictions
object_prediction_list = postprocess(object_prediction_list)
if len(object_prediction_list) > 0:
object_prediction_list = postprocess(object_prediction_list)

return PredictionResult(
image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
Expand Down

0 comments on commit e01dad4

Please sign in to comment.