Skip to content

Commit

Permalink
Add min_area to SanitizeBoundingBox (#7735)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
antoinebrl and NicolasHug authored Jun 4, 2024
1 parent f7d9e75 commit 1023987
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 18 deletions.
27 changes: 15 additions & 12 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5805,7 +5805,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):


class TestSanitizeBoundingBoxes:
def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10, min_area=10):
boxes_and_validity = [
([0, 1, 10, 1], False), # Y1 == Y2
([0, 1, 0, 20], False), # X1 == X2
Expand All @@ -5816,17 +5816,16 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
([-1, 1, 10, 20], False), # any < 0
([0, 0, -1, 20], False), # any < 0
([0, 0, -10, -1], False), # any < 0
([0, 0, min_size, 10], True), # H < min_size
([0, 0, 10, min_size], True), # W < min_size
([0, 0, W, H], True), # TODO: Is that actually OK?? Should it be -1?
([1, 1, 30, 20], True),
([0, 0, 10, 10], True),
([1, 1, 30, 20], True),
([0, 0, min_size, 10], min_size * 10 >= min_area), # H < min_size
([0, 0, 10, min_size], min_size * 10 >= min_area), # W < min_size
([0, 0, W, H], W * H >= min_area),
([1, 1, 30, 20], 29 * 19 >= min_area),
([0, 0, 10, 10], 9 * 9 >= min_area),
([1, 1, 30, 20], 29 * 19 >= min_area),
]

random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases
boxes, expected_valid_mask = zip(*boxes_and_validity)

boxes = tv_tensors.BoundingBoxes(
boxes,
format=tv_tensors.BoundingBoxFormat.XYXY,
Expand All @@ -5835,7 +5834,7 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):

return boxes, expected_valid_mask

@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("min_size, min_area", ((1, 1), (10, 1), (10, 101)))
@pytest.mark.parametrize(
"labels_getter",
(
Expand All @@ -5848,15 +5847,15 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
),
)
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_transform(self, min_size, labels_getter, sample_type):
def test_transform(self, min_size, min_area, labels_getter, sample_type):

if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# doesn't work if the input is a tuple.
return

H, W = 256, 128
boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size, min_area=min_area)
valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]

labels = torch.arange(boxes.shape[0])
Expand All @@ -5880,7 +5879,9 @@ def test_transform(self, min_size, labels_getter, sample_type):
img = sample.pop("image")
sample = (img, sample)

out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
out = transforms.SanitizeBoundingBoxes(min_size=min_size, min_area=min_area, labels_getter=labels_getter)(
sample
)

if sample_type is tuple:
out_image = out[0]
Expand Down Expand Up @@ -5977,6 +5978,8 @@ def test_errors_transform(self):

with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBoxes(min_size=0)
with pytest.raises(ValueError, match="min_area must be >= 1"):
transforms.SanitizeBoundingBoxes(min_area=0)
with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
transforms.SanitizeBoundingBoxes(labels_getter=12)

Expand Down
12 changes: 10 additions & 2 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class SanitizeBoundingBoxes(Transform):
This transform removes bounding boxes and their associated labels/masks that:
- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.
Expand All @@ -359,7 +359,8 @@ class SanitizeBoundingBoxes(Transform):
cases.
Args:
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
min_area (float, optional): The area below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
(or anything else that needs to be sanitized along with the bounding boxes).
By default, this will try to find a "labels" key in the input (case-insensitive), if
Expand All @@ -379,6 +380,7 @@ class SanitizeBoundingBoxes(Transform):
def __init__(
self,
min_size: float = 1.0,
min_area: float = 1.0,
labels_getter: Union[Callable[[Any], Any], str, None] = "default",
) -> None:
super().__init__()
Expand All @@ -387,6 +389,10 @@ def __init__(
raise ValueError(f"min_size must be >= 1, got {min_size}.")
self.min_size = min_size

if min_area < 1:
raise ValueError(f"min_area must be >= 1, got {min_area}.")
self.min_area = min_area

self.labels_getter = labels_getter
self._labels_getter = _parse_labels_getter(labels_getter)

Expand Down Expand Up @@ -422,7 +428,9 @@ def forward(self, *inputs: Any) -> Any:
format=boxes.format,
canvas_size=boxes.canvas_size,
min_size=self.min_size,
min_area=self.min_area,
)

params = dict(valid=valid, labels=labels)
flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]

Expand Down
15 changes: 11 additions & 4 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,13 @@ def sanitize_bounding_boxes(
format: Optional[tv_tensors.BoundingBoxFormat] = None,
canvas_size: Optional[Tuple[int, int]] = None,
min_size: float = 1.0,
min_area: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.
This removes bounding boxes that:
- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.
Expand All @@ -346,6 +347,7 @@ def sanitize_bounding_boxes(
(size of the corresponding image/video).
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
min_area (float, optional) The area below which bounding boxes are removed. Default is 1.
Returns:
out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
Expand All @@ -361,7 +363,7 @@ def sanitize_bounding_boxes(
if isinstance(format, str):
format = tv_tensors.BoundingBoxFormat[format.upper()]
valid = _get_sanitize_bounding_boxes_mask(
bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size
bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size, min_area=min_area
)
bounding_boxes = bounding_boxes[valid]
else:
Expand All @@ -374,7 +376,11 @@ def sanitize_bounding_boxes(
"Leave those to None or pass bounding_boxes as a pure tensor."
)
valid = _get_sanitize_bounding_boxes_mask(
bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size
bounding_boxes,
format=bounding_boxes.format,
canvas_size=bounding_boxes.canvas_size,
min_size=min_size,
min_area=min_area,
)
bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes)

Expand All @@ -386,6 +392,7 @@ def _get_sanitize_bounding_boxes_mask(
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
min_size: float = 1.0,
min_area: float = 1.0,
) -> torch.Tensor:

bounding_boxes = _convert_bounding_box_format(
Expand All @@ -394,7 +401,7 @@ def _get_sanitize_bounding_boxes_mask(

image_h, image_w = canvas_size
ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1]
valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1)
valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) & (ws * hs >= min_area)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = canvas_size
Expand Down

0 comments on commit 1023987

Please sign in to comment.