diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index c3c2b33068d..7ee8c7cd3c1 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3397,6 +3397,23 @@ def test_transform(self, old_format, new_format, format_type): make_bounding_boxes(format=old_format), ) + @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats) + def test_strings(self, old_format, new_format): + # Non-regression test for https://github.com/pytorch/vision/issues/8258 + input = tv_tensors.BoundingBoxes(torch.tensor([[10, 10, 20, 20]]), format=old_format, canvas_size=(50, 50)) + expected = self._reference_convert_bounding_box_format(input, new_format) + + old_format = old_format.name + new_format = new_format.name + + out_functional = F.convert_bounding_box_format(input, new_format=new_format) + out_functional_tensor = F.convert_bounding_box_format( + input.as_subclass(torch.Tensor), old_format=old_format, new_format=new_format + ) + out_transform = transforms.ConvertBoundingBoxFormat(new_format)(input) + for out in (out_functional, out_functional_tensor, out_transform): + assert_equal(out, expected) + def _reference_convert_bounding_box_format(self, bounding_boxes, new_format): return tv_tensors.wrap( torchvision.ops.box_convert( diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index badb630082c..01a356f46f5 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -17,12 +17,10 @@ class ConvertBoundingBoxFormat(Transform): def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None: super().__init__() - if isinstance(format, str): - format = tv_tensors.BoundingBoxFormat[format] self.format = format def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: - return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value] + return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value, arg-type] class ClampBoundingBoxes(Transform): diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index e27aa18fc60..b90e5fb7b5b 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -214,6 +214,11 @@ def convert_bounding_box_format( if not torch.jit.is_scripting(): _log_api_usage_once(convert_bounding_box_format) + if isinstance(old_format, str): + old_format = BoundingBoxFormat[old_format.upper()] + if isinstance(new_format, str): + new_format = BoundingBoxFormat[new_format.upper()] + if torch.jit.is_scripting() or is_pure_tensor(inpt): if old_format is None: raise ValueError("For pure tensor inputs, `old_format` has to be passed.")