Skip to content

Commit

Permalink
[Cherry-pick for 0.17.1] Fix convert_bounding_box_format when passing…
Browse files Browse the repository at this point in the history
… strings (#8267)
  • Loading branch information
NicolasHug authored Feb 9, 2024
1 parent a0e8e6c commit 20610ed
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
17 changes: 17 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3397,6 +3397,23 @@ def test_transform(self, old_format, new_format, format_type):
make_bounding_boxes(format=old_format),
)

@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_strings(self, old_format, new_format):
# Non-regression test for https://github.com/pytorch/vision/issues/8258
input = tv_tensors.BoundingBoxes(torch.tensor([[10, 10, 20, 20]]), format=old_format, canvas_size=(50, 50))
expected = self._reference_convert_bounding_box_format(input, new_format)

old_format = old_format.name
new_format = new_format.name

out_functional = F.convert_bounding_box_format(input, new_format=new_format)
out_functional_tensor = F.convert_bounding_box_format(
input.as_subclass(torch.Tensor), old_format=old_format, new_format=new_format
)
out_transform = transforms.ConvertBoundingBoxFormat(new_format)(input)
for out in (out_functional, out_functional_tensor, out_transform):
assert_equal(out, expected)

def _reference_convert_bounding_box_format(self, bounding_boxes, new_format):
return tv_tensors.wrap(
torchvision.ops.box_convert(
Expand Down
4 changes: 1 addition & 3 deletions torchvision/transforms/v2/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ class ConvertBoundingBoxFormat(Transform):

def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None:
super().__init__()
if isinstance(format, str):
format = tv_tensors.BoundingBoxFormat[format]
self.format = format

def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value]
return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value, arg-type]


class ClampBoundingBoxes(Transform):
Expand Down
5 changes: 5 additions & 0 deletions torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ def convert_bounding_box_format(
if not torch.jit.is_scripting():
_log_api_usage_once(convert_bounding_box_format)

if isinstance(old_format, str):
old_format = BoundingBoxFormat[old_format.upper()]
if isinstance(new_format, str):
new_format = BoundingBoxFormat[new_format.upper()]

if torch.jit.is_scripting() or is_pure_tensor(inpt):
if old_format is None:
raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
Expand Down

0 comments on commit 20610ed

Please sign in to comment.