diff --git a/test/test_utils.py b/test/test_utils.py index f2597162905..cb6aa7cf6d1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -11,6 +11,7 @@ import torchvision.utils as utils from common_utils import assert_equal, cpu_and_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageColor +from torchvision.transforms.v2.functional import to_dtype PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) @@ -246,6 +247,26 @@ def test_draw_segmentation_masks(colors, alpha, device): torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0) +def test_draw_segmentation_masks_dtypes(): + num_masks, h, w = 2, 100, 100 + + masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool) + + img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8) + out_uint8 = utils.draw_segmentation_masks(img_uint8, masks) + + assert img_uint8 is not out_uint8 + assert out_uint8.dtype == torch.uint8 + + img_float = to_dtype(img_uint8, torch.float, scale=True) + out_float = utils.draw_segmentation_masks(img_float, masks) + + assert img_float is not out_float + assert out_float.is_floating_point() + + torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) + + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_draw_segmentation_masks_errors(device): h, w = 10, 10 diff --git a/torchvision/utils.py b/torchvision/utils.py index 6ec19a0e0a1..630eada5cbc 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -10,6 +10,7 @@ import torch from PIL import Image, ImageColor, ImageDraw, ImageFont + __all__ = [ "make_grid", "save_image", @@ -262,10 +263,10 @@ def draw_segmentation_masks( """ Draws segmentation masks on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The image values should be uint8 in [0, 255] or float in [0, 1]. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. alpha (float): Float number between 0 and 1 denoting the transparency of the masks. 0 means full transparency, 1 means no transparency. @@ -282,8 +283,8 @@ def draw_segmentation_masks( _log_api_usage_once(draw_segmentation_masks) if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif not (image.dtype == torch.uint8 or image.is_floating_point()): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: @@ -303,10 +304,10 @@ def draw_segmentation_masks( warnings.warn("masks doesn't contain any mask. No mask was drawn") return image - out_dtype = torch.uint8 + original_dtype = image.dtype colors = [ - torch.tensor(color, dtype=out_dtype, device=image.device) - for color in _parse_colors(colors, num_objects=num_masks) + torch.tensor(color, dtype=original_dtype, device=image.device) + for color in _parse_colors(colors, num_objects=num_masks, dtype=original_dtype) ] img_to_draw = image.detach().clone() @@ -315,7 +316,8 @@ def draw_segmentation_masks( img_to_draw[:, mask] = color[:, None] out = image * (1 - alpha) + img_to_draw * alpha - return out.to(out_dtype) + # Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype + return out.to(original_dtype) @torch.no_grad() @@ -516,6 +518,7 @@ def _parse_colors( colors: Union[None, str, Tuple[int, int, int], List[Union[str, Tuple[int, int, int]]]], *, num_objects: int, + dtype: torch.dtype = torch.uint8, ) -> List[Tuple[int, int, int]]: """ Parses a specification of colors for a set of objects. @@ -552,7 +555,10 @@ def _parse_colors( else: # colors specifies a single color for all objects colors = [colors] * num_objects - return [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors] + colors = [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors] + if dtype.is_floating_point: # [0, 255] -> [0, 1] + colors = [tuple(v / 255 for v in color) for color in colors] + return colors def _log_api_usage_once(obj: Any) -> None: