From f36dd2c94fd8084fec8908d03d51c0928df00135 Mon Sep 17 00:00:00 2001 From: Mithra Date: Mon, 11 Dec 2023 03:12:32 +0530 Subject: [PATCH 01/20] support for float32 for draw_* --- torchvision/utils.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 6ec19a0e0a1..e1353ba5cb3 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -164,11 +164,11 @@ def draw_bounding_boxes( """ Draws bounding boxes on given image. - The values of the input image should be uint8 between 0 and 255. + The values of the input image should be uint8/float32 between 0 and 255. If fill is True, Resulting Tensor should be saved as PNG image. Args: - image (Tensor): Tensor of shape (C x H x W) and dtype uint8. + image (Tensor): Tensor of shape (C x H x W) and dtype uint8/float32. boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`. @@ -185,15 +185,15 @@ def draw_bounding_boxes( font_size (int): The requested font size in points. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. + img (Tensor[C, H, W]): Image Tensor of dtype uint8/float32 with bounding boxes plotted. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(draw_bounding_boxes) if not isinstance(image, torch.Tensor): raise TypeError(f"Tensor expected, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif image.dtype != torch.uint8 or image.dtype != torch.float32: + raise ValueError(f"Tensor uint8/float32 expected, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size(0) not in {1, 3}: @@ -249,7 +249,7 @@ def draw_bounding_boxes( margin = width + 1 draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=image.dtype) @torch.no_grad() @@ -262,10 +262,10 @@ def draw_segmentation_masks( """ Draws segmentation masks on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The values of the input image should be uint8/float32 between 0 and 255. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8/float32. masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. alpha (float): Float number between 0 and 1 denoting the transparency of the masks. 0 means full transparency, 1 means no transparency. @@ -282,8 +282,8 @@ def draw_segmentation_masks( _log_api_usage_once(draw_segmentation_masks) if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dtype != torch.uint8 or image.dtype != torch.float32: + raise ValueError(f"The image dtype must be uint8/float32, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: @@ -303,7 +303,6 @@ def draw_segmentation_masks( warnings.warn("masks doesn't contain any mask. 
No mask was drawn") return image - out_dtype = torch.uint8 colors = [ torch.tensor(color, dtype=out_dtype, device=image.device) for color in _parse_colors(colors, num_objects=num_masks) @@ -315,7 +314,7 @@ def draw_segmentation_masks( img_to_draw[:, mask] = color[:, None] out = image * (1 - alpha) + img_to_draw * alpha - return out.to(out_dtype) + return out.to(image.dtype) @torch.no_grad() @@ -330,10 +329,10 @@ def draw_keypoints( """ Draws Keypoints on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The values of the input image should be uint8/float32 between 0 and 255. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8/float32. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, in the format [x, y]. connectivity (List[Tuple[int, int]]]): A List of tuple where, @@ -344,15 +343,15 @@ def draw_keypoints( width (int): Integer denoting width of line connecting keypoints. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. + img (Tensor[C, H, W]): Image Tensor of dtype uint8/float32 with keypoints drawn. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(draw_keypoints) if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dtype != torch.uint8 or image.dtype != torch.float32: + raise ValueError(f"The image dtype must be uint8/float32, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: @@ -387,7 +386,7 @@ def draw_keypoints( width=width, ) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=image.dtype) # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization From f091cc7f4aa58cd14167cf9f81cbaa17f7a1a826 Mon Sep 17 00:00:00 2001 From: Mithra Date: Mon, 11 Dec 2023 03:44:33 +0530 Subject: [PATCH 02/20] cleaning up --- torchvision/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index e1353ba5cb3..2594363477a 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -192,7 +192,7 @@ def draw_bounding_boxes( _log_api_usage_once(draw_bounding_boxes) if not isinstance(image, torch.Tensor): raise TypeError(f"Tensor expected, got {type(image)}") - elif image.dtype != torch.uint8 or image.dtype != torch.float32: + elif image.dtype not in {torch.uint8, torch.float32}: raise ValueError(f"Tensor uint8/float32 expected, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") @@ -282,7 +282,7 @@ def draw_segmentation_masks( _log_api_usage_once(draw_segmentation_masks) if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8 or image.dtype != torch.float32: + elif image.dtype not in {torch.uint8, torch.float32}: raise ValueError(f"The image dtype must be uint8/float32, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") @@ -350,7 +350,7 @@ def draw_keypoints( _log_api_usage_once(draw_keypoints) if not isinstance(image, 
torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8 or image.dtype != torch.float32: + elif image.dtype not in {torch.uint8, torch.float32}: raise ValueError(f"The image dtype must be uint8/float32, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") From 5672d8cfab71615be57d39d3f847efc9bfbe9ab3 Mon Sep 17 00:00:00 2001 From: Mithra Date: Mon, 11 Dec 2023 03:46:08 +0530 Subject: [PATCH 03/20] method description --- torchvision/utils.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 2594363477a..bf8d23e39bc 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -164,11 +164,11 @@ def draw_bounding_boxes( """ Draws bounding boxes on given image. - The values of the input image should be uint8/float32 between 0 and 255. + The image values should be in the uint8 or float32 that range between 0 and 255. If fill is True, Resulting Tensor should be saved as PNG image. Args: - image (Tensor): Tensor of shape (C x H x W) and dtype uint8/float32. + image (Tensor): Tensor of shape (C x H x W) and dtype uint8 or float32. boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`. @@ -185,7 +185,7 @@ def draw_bounding_boxes( font_size (int): The requested font size in points. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8/float32 with bounding boxes plotted. + img (Tensor[C, H, W]): Image Tensor of dtype uint8 or float32 with bounding boxes plotted. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -193,7 +193,7 @@ def draw_bounding_boxes( if not isinstance(image, torch.Tensor): raise TypeError(f"Tensor expected, got {type(image)}") elif image.dtype not in {torch.uint8, torch.float32}: - raise ValueError(f"Tensor uint8/float32 expected, got {image.dtype}") + raise ValueError(f"Tensor uint8 or float32 expected, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size(0) not in {1, 3}: @@ -262,10 +262,10 @@ def draw_segmentation_masks( """ Draws segmentation masks on given RGB image. - The values of the input image should be uint8/float32 between 0 and 255. + The image values should be in the uint8 or float32 that range between 0 and 255. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8/float32. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float32. masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. alpha (float): Float number between 0 and 1 denoting the transparency of the masks. 0 means full transparency, 1 means no transparency. @@ -283,7 +283,7 @@ def draw_segmentation_masks( if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") elif image.dtype not in {torch.uint8, torch.float32}: - raise ValueError(f"The image dtype must be uint8/float32, got {image.dtype}") + raise ValueError(f"The image dtype must be uint8 or float32, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: @@ -329,10 +329,10 @@ def draw_keypoints( """ Draws Keypoints on given RGB image. - The values of the input image should be uint8/float32 between 0 and 255. 
+ The image values should be in the uint8 or float32 that range between 0 and 255. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8/float32. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float32. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, in the format [x, y]. connectivity (List[Tuple[int, int]]]): A List of tuple where, @@ -343,7 +343,7 @@ def draw_keypoints( width (int): Integer denoting width of line connecting keypoints. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8/float32 with keypoints drawn. + img (Tensor[C, H, W]): Image Tensor of dtype uint8 or float32 with keypoints drawn. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -351,7 +351,7 @@ def draw_keypoints( if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") elif image.dtype not in {torch.uint8, torch.float32}: - raise ValueError(f"The image dtype must be uint8/float32, got {image.dtype}") + raise ValueError(f"The image dtype must be uint8 or float32, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: From 301e42ba357ee411818dc1b414b745ce0c2b1c07 Mon Sep 17 00:00:00 2001 From: Mithra <50359897+GsnMithra@users.noreply.github.com> Date: Tue, 12 Dec 2023 02:43:56 +0530 Subject: [PATCH 04/20] Update torchvision/utils.py Co-authored-by: Nicolas Hug --- torchvision/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index bf8d23e39bc..e8f45f57a0a 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -164,7 +164,7 @@ def draw_bounding_boxes( """ Draws bounding boxes on given image. - The image values should be in the uint8 or float32 that range between 0 and 255. + The image values should be uint8 in [0, 255] or float in [0, 1] If fill is True, Resulting Tensor should be saved as PNG image. Args: From 9146ebeecb17f7bd1f980447d3c4535de290c2f7 Mon Sep 17 00:00:00 2001 From: Mithra Date: Tue, 12 Dec 2023 02:50:28 +0530 Subject: [PATCH 05/20] floating point dtype check --- torchvision/utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index e8f45f57a0a..6b0f7dc81e3 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -164,7 +164,7 @@ def draw_bounding_boxes( """ Draws bounding boxes on given image. - The image values should be uint8 in [0, 255] or float in [0, 1] + The image values should be uint8 in [0, 255] or float in [0, 1]. If fill is True, Resulting Tensor should be saved as PNG image. Args: @@ -192,8 +192,8 @@ def draw_bounding_boxes( _log_api_usage_once(draw_bounding_boxes) if not isinstance(image, torch.Tensor): raise TypeError(f"Tensor expected, got {type(image)}") - elif image.dtype not in {torch.uint8, torch.float32}: - raise ValueError(f"Tensor uint8 or float32 expected, got {image.dtype}") + elif not (image.dtype == torch.uint8 or image.is_floating_point()): + raise ValueError(f"Tensor uint8 or float expected, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size(0) not in {1, 3}: @@ -262,7 +262,7 @@ def draw_segmentation_masks( """ Draws segmentation masks on given RGB image. - The image values should be in the uint8 or float32 that range between 0 and 255. + The image values should be uint8 in [0, 255] or float in [0, 1]. 
Args: image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float32. @@ -282,8 +282,8 @@ def draw_segmentation_masks( _log_api_usage_once(draw_segmentation_masks) if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype not in {torch.uint8, torch.float32}: - raise ValueError(f"The image dtype must be uint8 or float32, got {image.dtype}") + elif not (image.dtype == torch.uint8 or image.is_floating_point()): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: @@ -329,7 +329,7 @@ def draw_keypoints( """ Draws Keypoints on given RGB image. - The image values should be in the uint8 or float32 that range between 0 and 255. + The image values should be uint8 in [0, 255] or float in [0, 1]. Args: image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float32. @@ -350,8 +350,8 @@ def draw_keypoints( _log_api_usage_once(draw_keypoints) if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype not in {torch.uint8, torch.float32}: - raise ValueError(f"The image dtype must be uint8 or float32, got {image.dtype}") + elif not (image.dtype == torch.uint8 or image.is_floating_point()): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: From bb1cd08e93a5483defe779f8478a1a5a1f49748f Mon Sep 17 00:00:00 2001 From: Mithra Date: Tue, 12 Dec 2023 11:36:47 +0530 Subject: [PATCH 06/20] test_draw_segmentation_masks --- test/test_utils.py | 43 +++++++++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 2999e84bf8a..4b7e1a9c1fa 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,8 +7,9 @@ import numpy as np import pytest import torch -import torchvision.transforms.functional as F import torchvision.utils as utils +import torchvision.transforms.functional as F +from torchvision.transforms.v2.functional import to_type from common_utils import assert_equal, cpu_and_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageColor @@ -207,8 +208,10 @@ def test_draw_no_boxes(): def test_draw_segmentation_masks(colors, alpha, device): """This test makes sure that masks draw their corresponding color where they should""" num_masks, h, w = 2, 100, 100 - dtype = torch.uint8 - img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device) + + img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device) + img_float = to_dtype(img_uint8, torch.float32, scale=True) + masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device) # For testing we enforce that there's no overlap between the masks. 
The @@ -218,13 +221,25 @@ def test_draw_segmentation_masks(colors, alpha, device): overlap = masks[0] & masks[1] masks[:, overlap] = False - out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha) - assert out.dtype == dtype - assert out is not img + # Test for uint8 input + out_uint8 = utils.draw_segmentation_masks(img_uint8, masks, colors=colors, alpha=alpha) + assert out_uint8.dtype == torch.uint8 + assert out_uint8 is not img_uint8 + + # Test for float input + out_float = utils.draw_segmentation_masks(img_float, masks, colors=colors, alpha=alpha) + assert torch.is_floating_point(out_float) + assert out_float is not img_float + + # Converting the float output to uint8 with scaling and asserting the + # uint8 output is close to the converted float output. + out_float_uint8 = to_dtype(out_float, torch.uint8, scale=True) + torch.testing.assert_close(out_uint8, out_float_uint8, rtol=1e-5) # Make sure the image didn't change where there's no mask masked_pixels = masks[0] | masks[1] - assert_equal(img[:, ~masked_pixels], out[:, ~masked_pixels]) + assert_equal(img_uint8[:, ~masked_pixels], out_uint8[:, ~masked_pixels]) + torch.testing.assert_close(img_float[:, ~masked_pixels], out_float[:, ~masked_pixels], rtol=1e-5) if colors is None: colors = utils._generate_color_palette(num_masks) @@ -235,15 +250,19 @@ def test_draw_segmentation_masks(colors, alpha, device): for mask, color in zip(masks, colors): if isinstance(color, str): color = ImageColor.getrgb(color) - color = torch.tensor(color, dtype=dtype, device=device) + color = torch.tensor(color, dtype=torch.uint8, device=device) if alpha == 1: - assert (out[:, mask] == color[:, None]).all() + assert (out_uint8[:, mask] == color[:, None]).all() + torch.testing.assert_close(out_float_uint8[:, mask], color[:, None].to(out_float_uint8.dtype), rtol=1e-5) elif alpha == 0: - assert (out[:, mask] == img[:, mask]).all() + assert (out_uint8[:, mask] == img_uint8[:, mask]).all() + torch.testing.assert_close(out_float_uint8[:, mask], out_float[:, mask], rtol=1e-5) - interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype) - torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0) + interpolated_color = (img_uint8[:, mask].float() * (1 - alpha) + color[:, None] * alpha).to(torch.uint8) + torch.testing.assert_close(out_uint8[:, mask], interpolated_color, rtol=0.0, atol=1.0) + interpolated_color_float = (out_float[:, mask] * (1 - alpha) + color[:, None] * alpha) + torch.testing.assert_close(out_float_uint8[:, mask], interpolated_color_float, rtol=1e-5) @pytest.mark.parametrize("device", cpu_and_cuda()) From 8d486a8a1403f93d4c0c0e0bed5965cfc7690239 Mon Sep 17 00:00:00 2001 From: Mithra Date: Fri, 15 Dec 2023 00:53:12 +0530 Subject: [PATCH 07/20] unit test for draw_segmentation_masks dtypes --- test/test_utils.py | 61 ++++++++++++++++++++++---------------------- torchvision/utils.py | 13 +++++++--- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 4b7e1a9c1fa..b3c61a561f2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,7 +9,7 @@ import torch import torchvision.utils as utils import torchvision.transforms.functional as F -from torchvision.transforms.v2.functional import to_type +from torchvision.transforms.v2.functional import to_dtype from common_utils import assert_equal, cpu_and_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageColor @@ -208,10 +208,8 @@ def test_draw_no_boxes(): def 
test_draw_segmentation_masks(colors, alpha, device): """This test makes sure that masks draw their corresponding color where they should""" num_masks, h, w = 2, 100, 100 - - img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device) - img_float = to_dtype(img_uint8, torch.float32, scale=True) - + dtype = torch.uint8 + img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device) masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device) # For testing we enforce that there's no overlap between the masks. The @@ -221,25 +219,13 @@ def test_draw_segmentation_masks(colors, alpha, device): overlap = masks[0] & masks[1] masks[:, overlap] = False - # Test for uint8 input - out_uint8 = utils.draw_segmentation_masks(img_uint8, masks, colors=colors, alpha=alpha) - assert out_uint8.dtype == torch.uint8 - assert out_uint8 is not img_uint8 - - # Test for float input - out_float = utils.draw_segmentation_masks(img_float, masks, colors=colors, alpha=alpha) - assert torch.is_floating_point(out_float) - assert out_float is not img_float - - # Converting the float output to uint8 with scaling and asserting the - # uint8 output is close to the converted float output. - out_float_uint8 = to_dtype(out_float, torch.uint8, scale=True) - torch.testing.assert_close(out_uint8, out_float_uint8, rtol=1e-5) + out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha) + assert out.dtype == dtype + assert out is not img # Make sure the image didn't change where there's no mask masked_pixels = masks[0] | masks[1] - assert_equal(img_uint8[:, ~masked_pixels], out_uint8[:, ~masked_pixels]) - torch.testing.assert_close(img_float[:, ~masked_pixels], out_float[:, ~masked_pixels], rtol=1e-5) + assert_equal(img[:, ~masked_pixels], out[:, ~masked_pixels]) if colors is None: colors = utils._generate_color_palette(num_masks) @@ -250,20 +236,35 @@ def test_draw_segmentation_masks(colors, alpha, device): for mask, color in zip(masks, colors): if isinstance(color, str): color = ImageColor.getrgb(color) - color = torch.tensor(color, dtype=torch.uint8, device=device) + color = torch.tensor(color, dtype=dtype, device=device) if alpha == 1: - assert (out_uint8[:, mask] == color[:, None]).all() - torch.testing.assert_close(out_float_uint8[:, mask], color[:, None].to(out_float_uint8.dtype), rtol=1e-5) + assert (out[:, mask] == color[:, None]).all() elif alpha == 0: - assert (out_uint8[:, mask] == img_uint8[:, mask]).all() - torch.testing.assert_close(out_float_uint8[:, mask], out_float[:, mask], rtol=1e-5) + assert (out[:, mask] == img[:, mask]).all() + + interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype) + torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0) + + +def test_draw_segmentation_masks_dtypes(): + num_masks, h, w = 2, 100, 100 + + masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool) + + img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8) + out_uint8 = utils.draw_segmentation_masks(img_uint8, masks) + + assert img_uint8 is not out_uint8 + assert out_uint8.dtype == torch.uint8 + + img_float = to_dtype(img_uint8, torch.float32, scale=True) + out_float = utils.draw_segmentation_masks(img_float, masks) - interpolated_color = (img_uint8[:, mask].float() * (1 - alpha) + color[:, None] * alpha).to(torch.uint8) - torch.testing.assert_close(out_uint8[:, mask], interpolated_color, rtol=0.0, atol=1.0) - interpolated_color_float = (out_float[:, mask] * (1 - alpha) + color[:, None] * 
alpha) - torch.testing.assert_close(out_float_uint8[:, mask], interpolated_color_float, rtol=1e-5) + assert img_float is not out_float + assert out_float.is_floating_point () + torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1.0) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_draw_segmentation_masks_errors(device): diff --git a/torchvision/utils.py b/torchvision/utils.py index 6b0f7dc81e3..4c91affac66 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -265,7 +265,7 @@ def draw_segmentation_masks( The image values should be uint8 in [0, 255] or float in [0, 1]. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float32. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. alpha (float): Float number between 0 and 1 denoting the transparency of the masks. 0 means full transparency, 1 means no transparency. @@ -297,6 +297,10 @@ def draw_segmentation_masks( if masks.shape[-2:] != image.shape[-2:]: raise ValueError("The image and the masks must have the same height and width") + original_dtype = image.dtype + if image.is_floating_point(): + image = (image * 255).to(torch.uint8) + num_masks = masks.size()[0] if num_masks == 0: @@ -304,7 +308,7 @@ def draw_segmentation_masks( return image colors = [ - torch.tensor(color, dtype=out_dtype, device=image.device) + torch.tensor(color, dtype=torch.uint8, device=image.device) for color in _parse_colors(colors, num_objects=num_masks) ] @@ -314,7 +318,10 @@ def draw_segmentation_masks( img_to_draw[:, mask] = color[:, None] out = image * (1 - alpha) + img_to_draw * alpha - return out.to(image.dtype) + if original_dtype in {torch.float16, torch.float32, torch.float64}: + out = out.float() / 255.0 + + return out.to(original_dtype) @torch.no_grad() From baec368e5f38f9df87b9b44367c981b9c961187f Mon Sep 17 00:00:00 2001 From: Mithra Date: Fri, 15 Dec 2023 01:38:32 +0530 Subject: [PATCH 08/20] cleaning up other method and checks --- .DS_Store | Bin 0 -> 10244 bytes .github/.DS_Store | Bin 0 -> 6148 bytes android/.DS_Store | Bin 0 -> 6148 bytes examples/.DS_Store | Bin 0 -> 6148 bytes gallery/.DS_Store | Bin 0 -> 6148 bytes packaging/.DS_Store | Bin 0 -> 6148 bytes references/.DS_Store | Bin 0 -> 8196 bytes test/.DS_Store | Bin 0 -> 6148 bytes test/test_utils.py | 6 +++--- torchvision/.DS_Store | Bin 0 -> 8196 bytes torchvision/utils.py | 2 +- 11 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 .DS_Store create mode 100644 .github/.DS_Store create mode 100644 android/.DS_Store create mode 100644 examples/.DS_Store create mode 100644 gallery/.DS_Store create mode 100644 packaging/.DS_Store create mode 100644 references/.DS_Store create mode 100644 test/.DS_Store create mode 100644 torchvision/.DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a8e2f4928f1c3680c02774e60779283faca812c2 GIT binary patch literal 10244 zcmeI1J7^S96o${_L1KKR@P&z2= zh`f#yy_<;=MAV3lWAh?Z5rx}%3zZtXXBVP}{xp!?+?g)qheArX!#Us_a1J;JoCD5* zzr_JOv$-@b^rNet1I_{GK-mGlKP0emOxjrVqq%jUGD`r&I_x$Dznl37V#3CxjYU7A z1;z8bBKloDyTtJP?%1EOx}=RoKlX2;97gDL7fkF@VH9l|J`10xalO4Lfw`%)PU+`_fn{7c?YAEHU zJ~}}?GCR@i!LULyJUI92k zoNKwPTP4Tmbw9SJ`$EkGrJ;U~uo&L3$G*P+MT}B1&CF!?Sf*HY;LqG@#d|Rc2wE1+i?@Reu7@mb zG}pvYDyTD{{Y!tPuRN#S2PXI0XWqPX%W?Rf{_4ExCZD$}&H?9ubHF+9XB^-+IFtmx z3G9sj|GyEr$~oX1_&*)+>QbGl4xF3)wjMf5{??wrb_5$2))oC|x}ahS?>C?0@dtm7 
y=k+u1YvITRzmql={fIp%fBc^TZv7wol|a3hyY+uo`QiU}DRQl@)?)v_`u``aHcL(b literal 0 HcmV?d00001 diff --git a/.github/.DS_Store b/.github/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..f844d354f60a38e0548a56da845039e3c90dcb17 GIT binary patch literal 6148 zcmeHKOG*Pl5Pcm3CMXhjF6Rj529q#B1O*|wabrk=A555Flq}rL+I`O9Df9p?Jb_!6 z?nHdm-ArckF}M-Ys?hyZ}T= z0;)i&fP5bU4#Ci4;?RCNIN26}Sftw;*V0Q!PT(>0m^kDensKQ_muh^AVO%=>fsYG4 zCJtRXjBh@Sf3oo%im^}U{(%jL2_5RJ3aA2k1yUV;=%fn#r~;qV CJ%ZT) literal 0 HcmV?d00001 diff --git a/android/.DS_Store b/android/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..554d09e34931977b20818b41191bea4f4f03eca1 GIT binary patch literal 6148 zcmeHKu}%U(5S8j} zQSlN{=_Q{QqiUyIbag(P_oHdQ%cb`9`Y|+IyfwPXgvq=ZpT{Tj*m8N$8TF_Mm98kI zKK0BgH{z}Ce~Nc`KJUDZcrL8r@?JK_-$%&%C{O~Oe)Z%D*MDxCLQykjf*U122DCC z`}k0{va&A}rK@B8P=}L>3`#EwhyuO>1@p7S_y5WK_rISczeE91;9n`Aiq%@Rf-Tv* yHMcpwYeTeiG&YXQ4C)kg_Bz%BzKVCzw80nh02o=!3}OZ*e+0A)(uo4Us=znRW4+}7 literal 0 HcmV?d00001 diff --git a/examples/.DS_Store b/examples/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5bb92b23617f060180db45ecdc460ba22f31bb4e GIT binary patch literal 6148 zcmeHKF-}7<47F*6zJw4R88b1kRpJ1l`kbH_Ky4*D^z;7!16vNj+KuaQ2M)kNc(z0P z10P0&z?PDi*s-1es>TTs@$7yzC7KXXg(k={7!Xkpnhx?}F|x#%&X;Y|4TSi;D%typ zmef(Vx#{o!_;GJ(+PYqB+66S#$@$CccKdwUx30dicI$rB2sb)Jwgz!}&vAm>9s6O0*? zp&lLR^a%hQL#=|b-YtPa6u_7<8A1bLa|N0!=O+f6JMU%L3r8E^*v7y~@6=k*L9m22y#&y!jk(5}%$ tWL=U01Y<45KrO|X$T@PVMzaSQ!yYpxLy01KbqD%IKnd~A8TbJP-T>o7M>GHc literal 0 HcmV?d00001 diff --git a/gallery/.DS_Store b/gallery/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..4636c78a39945ef71256a8688e3a21fd1973ab8f GIT binary patch literal 6148 zcmeHKJ5Iwu5S z&9iwUS{?sU0seM}bVCdL?x_C#;H`G|NZuNBzssuhr`Pycsjz8Z(~;>m)8Q z1$g9OFa?jIzK-6m>f<4g?`6Ovm?O)7d9=k;cwiO7j=`dUCJ)f$9-Iow>U=(4+u}lA z&o66U8|Mu9ZHuk)w9HdY6;K6Kfn61l-HxJ@T_?~yRX`QkTLpN3uxN~l$5PO39q3#M z0PMo;4A)#saE|qucq|1m0yB~dlvLwO3?u3ATOXHrECnT76Q|3hb2vrkf7a0ba?Ut&NxCv(}U;@%zVexw`u* zdi0Ix+1O{H(Jk#`m6_8uT~S20AU7O5=p)NFjOt=S9x?OSqXiUuz%LD1 zL}y{E&*A2L&B4p}8B|A9jNiog6PnTpZ*M{KkVCT#4qFjEkKyH-M=5HfMtxKTAJ#d*gS zlH7>`qQJjWKsAzn(#0dWy|wgkyw_^@J)Di>3X3KMh1-tJ1#iVCaARoGd;yFd78X$h PlMewSgA}5`k1Fs9yyudx literal 0 HcmV?d00001 diff --git a/references/.DS_Store b/references/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c67d2d97b7c9173bc916db1ad4014e3fb0b17f2e GIT binary patch literal 8196 zcmeHMyG|QH6ukpPGOs8F1t5x*D5wz?gknJs1w|THG`wtH(k^SoCWwmd_yXure#8xB zqSPo=M9T-zaqi5po{@KLLz4*Kk!H?n=G>Vx$NSJ76Omd!t1S{u6Hy73@zXe(oWkQ= zJIavlnFA`s6E*0B4pFPL*TI$!aX=gp2gCt!KpglV9Kbi5OR3pko4yp~cjo4#k`%dr&r1*%U*WaEu2`2Zk0?gC?Am z2`6P&RyIRXdUeDFnocS-D6Kdk4s<%ex%)@t>k%^chL^uLk;f-=8XYZGs!3eep=$L< zFMgbq%llEX4_;~J>;1#^_1#+UnA)pj>c_6GT`WN#s741=rzRbHs#k6={ZUurd?nN9 z#K%cnQ9i?4|8vOekLvZC&`^bj9QT$c^I5K*boS4SXF2n^y1FwxeM%_AHJ{qH|rkL&r=M)`RV!wY|@4A^~T>Ko~!rQW~MiHPFHI$uCG(y7gO{d z+W6T?a8LC6E}*LlwD%qB?q_Q|+s`ZF+xO91;{AQMAYF<$FwhPRo83=y{{N-;{C}W5 zlFf?)@7)14Qm&NOu}hmLtF=TEVP&!#0bj2 PIS6R!wWSdU2GxOQpe;KP literal 0 HcmV?d00001 diff --git a/test/.DS_Store b/test/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..de372509067f06b8498d92d2458dbd04b320e93d GIT binary patch literal 6148 zcmeHKK~4fO6#W&VAc?YYVO&k32M`yoM}ilSm>4%GFaZ;giD06;Ie>{S$^@(vimxRq&&k|7S zF&cD84Qf+BwM4WyL;+FYuPMN5w}91Jqma(Ab|>q%)@hZhjb1m@lvnc`>5<)DrBd(o z>WIiaE!;Nmhj&|FMK`|@9kydE)_9+yS+MNm6VQPA4$HlV>U4XY+hhOveQ3BjV|05H 
zMrkn)kNqp=k%3+zg$@InyO^}ZbKh*fjB?w4@vO;KC!gM*G4FuRkTF2sfs-}A!r!CQ z9?!kI^E%2s6*HN4T8zVE(MKAZgYGB#mvThM)N+_m{yG^E=W_Wr<}x)-T#xi9hYzhY zAF-b)fHRvdS1>5GC?E=m0%HYufACNkeT$Jny>+0{R{&rZZf$thXEe}b2hg_|8H5L> zOjn@ks_Yd*neOm=HqN&g88qEV>C70%W>)rwqI7onJsnQUHz>6zAPTq&WXx-k&;RZ8 z`@h>s-b4XW;9n`AvXyG3j49c(H8D9pYdw?$6gG~F4C*eZ>~X9g_$XdR(S|mUAAr8a T$RIp0`6HlZkV+KzRRul(P(rry literal 0 HcmV?d00001 diff --git a/test/test_utils.py b/test/test_utils.py index b3c61a561f2..3da87847104 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -152,7 +152,7 @@ def test_draw_boxes_grayscale(): def test_draw_invalid_boxes(): img_tp = ((1, 1, 1), (1, 2, 3)) - img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) + img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.long) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) @@ -162,7 +162,7 @@ def test_draw_invalid_boxes(): with pytest.raises(TypeError, match="Tensor expected"): utils.draw_bounding_boxes(img_tp, boxes) - with pytest.raises(ValueError, match="Tensor uint8 expected"): + with pytest.raises(ValueError, match="Tensor uint8 or float expected"): utils.draw_bounding_boxes(img_wrong1, boxes) with pytest.raises(ValueError, match="Pass individual images, not batches"): utils.draw_bounding_boxes(img_wrong2, boxes) @@ -258,7 +258,7 @@ def test_draw_segmentation_masks_dtypes(): assert img_uint8 is not out_uint8 assert out_uint8.dtype == torch.uint8 - img_float = to_dtype(img_uint8, torch.float32, scale=True) + img_float = to_dtype(img_uint8, torch.float, scale=True) out_float = utils.draw_segmentation_masks(img_float, masks) assert img_float is not out_float diff --git a/torchvision/.DS_Store b/torchvision/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..29d00ea36157dfb7e507d2ac245d07fc8e0094ac GIT binary patch literal 8196 zcmeHMJ#Q015S_Iz=tx9lBqUOv08x{Y1|88+Ql&sthWMNqVdtE)L2`Ezf;v+0FQAST zrQ-)cO@U}=5OPDsU*OH|+P8Zjj%W~ycctC2ci!B*eYd*vtcghMyo$Gpwus0_XWQDq zP|?`Wz0?|d<_@etJW+?9qxYyw?GoBeI0c*nP64NYQ@|bg0lJhEU<~2MmV| zwat`O;UrWz37uK!3Ps55hzlf~M5wGAodQmQvI1PYZ=hOl(~u@s{od{$Z0~lG!B7SL zp4B4*EW|w~H}75OZ`nfK}BDAwyP zVDdHBVJseZk9oaKuoxB?crP!5+*q8g@nBqR!d-KM(H+49Z<4N>#u2tnV zO_mp7EuJ?QJYVBUGsX^LuFTZtDc*h8JWDu}=k4?NCeI3NzOOFAT09$)Cu{1OQ{b{I z&{PZD;`;wl{rUf8m%**iDX`KC2rt@=b|8VStxIBXtsS60L+8f0Oj#L%L6+mN9;Vs9 h_`?vN13edNn<)zq+CM)CDE$6#?|*q4+^}K_`~kBDN3s9_ literal 0 HcmV?d00001 diff --git a/torchvision/utils.py b/torchvision/utils.py index 4c91affac66..04513f42f10 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -339,7 +339,7 @@ def draw_keypoints( The image values should be uint8 in [0, 255] or float in [0, 1]. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float32. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, in the format [x, y]. 
connectivity (List[Tuple[int, int]]]): A List of tuple where, From 745e1b24963eb1e58f8c963639e03ce018ae9fe0 Mon Sep 17 00:00:00 2001 From: Mithra Date: Fri, 15 Dec 2023 01:46:12 +0530 Subject: [PATCH 09/20] removing .DS_Store files --- .DS_Store | Bin 10244 -> 0 bytes .github/.DS_Store | Bin 6148 -> 0 bytes android/.DS_Store | Bin 6148 -> 0 bytes examples/.DS_Store | Bin 6148 -> 0 bytes gallery/.DS_Store | Bin 6148 -> 0 bytes packaging/.DS_Store | Bin 6148 -> 0 bytes references/.DS_Store | Bin 8196 -> 0 bytes test/.DS_Store | Bin 6148 -> 0 bytes torchvision/.DS_Store | Bin 8196 -> 0 bytes 9 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store delete mode 100644 .github/.DS_Store delete mode 100644 android/.DS_Store delete mode 100644 examples/.DS_Store delete mode 100644 gallery/.DS_Store delete mode 100644 packaging/.DS_Store delete mode 100644 references/.DS_Store delete mode 100644 test/.DS_Store delete mode 100644 torchvision/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index a8e2f4928f1c3680c02774e60779283faca812c2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10244 zcmeI1J7^S96o${_L1KKR@P&z2= zh`f#yy_<;=MAV3lWAh?Z5rx}%3zZtXXBVP}{xp!?+?g)qheArX!#Us_a1J;JoCD5* zzr_JOv$-@b^rNet1I_{GK-mGlKP0emOxjrVqq%jUGD`r&I_x$Dznl37V#3CxjYU7A z1;z8bBKloDyTtJP?%1EOx}=RoKlX2;97gDL7fkF@VH9l|J`10xalO4Lfw`%)PU+`_fn{7c?YAEHU zJ~}}?GCR@i!LULyJUI92k zoNKwPTP4Tmbw9SJ`$EkGrJ;U~uo&L3$G*P+MT}B1&CF!?Sf*HY;LqG@#d|Rc2wE1+i?@Reu7@mb zG}pvYDyTD{{Y!tPuRN#S2PXI0XWqPX%W?Rf{_4ExCZD$}&H?9ubHF+9XB^-+IFtmx z3G9sj|GyEr$~oX1_&*)+>QbGl4xF3)wjMf5{??wrb_5$2))oC|x}ahS?>C?0@dtm7 y=k+u1YvITRzmql={fIp%fBc^TZv7wol|a3hyY+uo`QiU}DRQl@)?)v_`u``aHcL(b diff --git a/.github/.DS_Store b/.github/.DS_Store deleted file mode 100644 index f844d354f60a38e0548a56da845039e3c90dcb17..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKOG*Pl5Pcm3CMXhjF6Rj529q#B1O*|wabrk=A555Flq}rL+I`O9Df9p?Jb_!6 z?nHdm-ArckF}M-Ys?hyZ}T= z0;)i&fP5bU4#Ci4;?RCNIN26}Sftw;*V0Q!PT(>0m^kDensKQ_muh^AVO%=>fsYG4 zCJtRXjBh@Sf3oo%im^}U{(%jL2_5RJ3aA2k1yUV;=%fn#r~;qV CJ%ZT) diff --git a/android/.DS_Store b/android/.DS_Store deleted file mode 100644 index 554d09e34931977b20818b41191bea4f4f03eca1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKu}%U(5S8j} zQSlN{=_Q{QqiUyIbag(P_oHdQ%cb`9`Y|+IyfwPXgvq=ZpT{Tj*m8N$8TF_Mm98kI zKK0BgH{z}Ce~Nc`KJUDZcrL8r@?JK_-$%&%C{O~Oe)Z%D*MDxCLQykjf*U122DCC z`}k0{va&A}rK@B8P=}L>3`#EwhyuO>1@p7S_y5WK_rISczeE91;9n`Aiq%@Rf-Tv* yHMcpwYeTeiG&YXQ4C)kg_Bz%BzKVCzw80nh02o=!3}OZ*e+0A)(uo4Us=znRW4+}7 diff --git a/examples/.DS_Store b/examples/.DS_Store deleted file mode 100644 index 5bb92b23617f060180db45ecdc460ba22f31bb4e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKF-}7<47F*6zJw4R88b1kRpJ1l`kbH_Ky4*D^z;7!16vNj+KuaQ2M)kNc(z0P z10P0&z?PDi*s-1es>TTs@$7yzC7KXXg(k={7!Xkpnhx?}F|x#%&X;Y|4TSi;D%typ zmef(Vx#{o!_;GJ(+PYqB+66S#$@$CccKdwUx30dicI$rB2sb)Jwgz!}&vAm>9s6O0*? 
zp&lLR^a%hQL#=|b-YtPa6u_7<8A1bLa|N0!=O+f6JMU%L3r8E^*v7y~@6=k*L9m22y#&y!jk(5}%$ tWL=U01Y<45KrO|X$T@PVMzaSQ!yYpxLy01KbqD%IKnd~A8TbJP-T>o7M>GHc diff --git a/gallery/.DS_Store b/gallery/.DS_Store deleted file mode 100644 index 4636c78a39945ef71256a8688e3a21fd1973ab8f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKJ5Iwu5S z&9iwUS{?sU0seM}bVCdL?x_C#;H`G|NZuNBzssuhr`Pycsjz8Z(~;>m)8Q z1$g9OFa?jIzK-6m>f<4g?`6Ovm?O)7d9=k;cwiO7j=`dUCJ)f$9-Iow>U=(4+u}lA z&o66U8|Mu9ZHuk)w9HdY6;K6Kfn61l-HxJ@T_?~yRX`QkTLpN3uxN~l$5PO39q3#M z0PMo;4A)#saE|qucq|1m0yB~dlvLwO3?u3ATOXHrECnT76Q|3hb2vrkf7a0ba?Ut&NxCv(}U;@%zVexw`u* zdi0Ix+1O{H(Jk#`m6_8uT~S20AU7O5=p)NFjOt=S9x?OSqXiUuz%LD1 zL}y{E&*A2L&B4p}8B|A9jNiog6PnTpZ*M{KkVCT#4qFjEkKyH-M=5HfMtxKTAJ#d*gS zlH7>`qQJjWKsAzn(#0dWy|wgkyw_^@J)Di>3X3KMh1-tJ1#iVCaARoGd;yFd78X$h PlMewSgA}5`k1Fs9yyudx diff --git a/references/.DS_Store b/references/.DS_Store deleted file mode 100644 index c67d2d97b7c9173bc916db1ad4014e3fb0b17f2e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHMyG|QH6ukpPGOs8F1t5x*D5wz?gknJs1w|THG`wtH(k^SoCWwmd_yXure#8xB zqSPo=M9T-zaqi5po{@KLLz4*Kk!H?n=G>Vx$NSJ76Omd!t1S{u6Hy73@zXe(oWkQ= zJIavlnFA`s6E*0B4pFPL*TI$!aX=gp2gCt!KpglV9Kbi5OR3pko4yp~cjo4#k`%dr&r1*%U*WaEu2`2Zk0?gC?Am z2`6P&RyIRXdUeDFnocS-D6Kdk4s<%ex%)@t>k%^chL^uLk;f-=8XYZGs!3eep=$L< zFMgbq%llEX4_;~J>;1#^_1#+UnA)pj>c_6GT`WN#s741=rzRbHs#k6={ZUurd?nN9 z#K%cnQ9i?4|8vOekLvZC&`^bj9QT$c^I5K*boS4SXF2n^y1FwxeM%_AHJ{qH|rkL&r=M)`RV!wY|@4A^~T>Ko~!rQW~MiHPFHI$uCG(y7gO{d z+W6T?a8LC6E}*LlwD%qB?q_Q|+s`ZF+xO91;{AQMAYF<$FwhPRo83=y{{N-;{C}W5 zlFf?)@7)14Qm&NOu}hmLtF=TEVP&!#0bj2 PIS6R!wWSdU2GxOQpe;KP diff --git a/test/.DS_Store b/test/.DS_Store deleted file mode 100644 index de372509067f06b8498d92d2458dbd04b320e93d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKK~4fO6#W&VAc?YYVO&k32M`yoM}ilSm>4%GFaZ;giD06;Ie>{S$^@(vimxRq&&k|7S zF&cD84Qf+BwM4WyL;+FYuPMN5w}91Jqma(Ab|>q%)@hZhjb1m@lvnc`>5<)DrBd(o z>WIiaE!;Nmhj&|FMK`|@9kydE)_9+yS+MNm6VQPA4$HlV>U4XY+hhOveQ3BjV|05H zMrkn)kNqp=k%3+zg$@InyO^}ZbKh*fjB?w4@vO;KC!gM*G4FuRkTF2sfs-}A!r!CQ z9?!kI^E%2s6*HN4T8zVE(MKAZgYGB#mvThM)N+_m{yG^E=W_Wr<}x)-T#xi9hYzhY zAF-b)fHRvdS1>5GC?E=m0%HYufACNkeT$Jny>+0{R{&rZZf$thXEe}b2hg_|8H5L> zOjn@ks_Yd*neOm=HqN&g88qEV>C70%W>)rwqI7onJsnQUHz>6zAPTq&WXx-k&;RZ8 z`@h>s-b4XW;9n`AvXyG3j49c(H8D9pYdw?$6gG~F4C*eZ>~X9g_$XdR(S|mUAAr8a T$RIp0`6HlZkV+KzRRul(P(rry diff --git a/torchvision/.DS_Store b/torchvision/.DS_Store deleted file mode 100644 index 29d00ea36157dfb7e507d2ac245d07fc8e0094ac..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHMJ#Q015S_Iz=tx9lBqUOv08x{Y1|88+Ql&sthWMNqVdtE)L2`Ezf;v+0FQAST zrQ-)cO@U}=5OPDsU*OH|+P8Zjj%W~ycctC2ci!B*eYd*vtcghMyo$Gpwus0_XWQDq zP|?`Wz0?|d<_@etJW+?9qxYyw?GoBeI0c*nP64NYQ@|bg0lJhEU<~2MmV| zwat`O;UrWz37uK!3Ps55hzlf~M5wGAodQmQvI1PYZ=hOl(~u@s{od{$Z0~lG!B7SL zp4B4*EW|w~H}75OZ`nfK}BDAwyP zVDdHBVJseZk9oaKuoxB?crP!5+*q8g@nBqR!d-KM(H+49Z<4N>#u2tnV zO_mp7EuJ?QJYVBUGsX^LuFTZtDc*h8JWDu}=k4?NCeI3NzOOFAT09$)Cu{1OQ{b{I z&{PZD;`;wl{rUf8m%**iDX`KC2rt@=b|8VStxIBXtsS60L+8f0Oj#L%L6+mN9;Vs9 h_`?vN13edNn<)zq+CM)CDE$6#?|*q4+^}K_`~kBDN3s9_ From 6cfa9d84b7ad5d5d5f08805b3e577e8ed4cb23b0 Mon Sep 17 00:00:00 2001 From: Mithra Date: Fri, 15 Dec 2023 01:49:30 +0530 Subject: [PATCH 10/20] gitignore updated for ignoring .DS_Store files --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index f16b54061e0..91df13f6b35 100644 --- a/.gitignore 
+++ b/.gitignore @@ -42,3 +42,5 @@ xcuserdata/ # direnv .direnv .envrc + +.DS_Store \ No newline at end of file From 956a7bdea58a81b50ecb58cb3c91b97e15b905d6 Mon Sep 17 00:00:00 2001 From: Mithra <50359897+GsnMithra@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:28:41 +0530 Subject: [PATCH 11/20] Update test/test_utils.py Co-authored-by: Nicolas Hug --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index 3da87847104..e550b0e586f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -262,7 +262,7 @@ def test_draw_segmentation_masks_dtypes(): out_float = utils.draw_segmentation_masks(img_float, masks) assert img_float is not out_float - assert out_float.is_floating_point () + assert out_float.is_floating_point() torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1.0) From 3bcd9a372ae22bd2514dc4afe85a44d3d2049fde Mon Sep 17 00:00:00 2001 From: Mithra <50359897+GsnMithra@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:28:52 +0530 Subject: [PATCH 12/20] Update .gitignore Co-authored-by: Nicolas Hug --- .gitignore | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 91df13f6b35..b4e1f7af1f0 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,4 @@ xcuserdata/ # direnv .direnv -.envrc - -.DS_Store \ No newline at end of file +.envrc \ No newline at end of file From 7e3a1767f9b16b8bf9ea72b36281c71989dbdc87 Mon Sep 17 00:00:00 2001 From: Mithra <50359897+GsnMithra@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:29:33 +0530 Subject: [PATCH 13/20] Update test/test_utils.py Co-authored-by: Nicolas Hug --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index e550b0e586f..fba97cd1126 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -264,7 +264,7 @@ def test_draw_segmentation_masks_dtypes(): assert img_float is not out_float assert out_float.is_floating_point() - torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1.0) + torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_draw_segmentation_masks_errors(device): From acdad335bf05262c357278b231f00dd9257264a3 Mon Sep 17 00:00:00 2001 From: Mithra Date: Fri, 15 Dec 2023 21:08:30 +0530 Subject: [PATCH 14/20] reverting boxes and keypoints methods and unit tests --- .gitignore | 2 +- test/test_utils.py | 9 +++++---- torchvision/utils.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index b4e1f7af1f0..f16b54061e0 100644 --- a/.gitignore +++ b/.gitignore @@ -41,4 +41,4 @@ xcuserdata/ # direnv .direnv -.envrc \ No newline at end of file +.envrc diff --git a/test/test_utils.py b/test/test_utils.py index fba97cd1126..f2445e32e4d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,11 +7,11 @@ import numpy as np import pytest import torch -import torchvision.utils as utils import torchvision.transforms.functional as F -from torchvision.transforms.v2.functional import to_dtype +import torchvision.utils as utils from common_utils import assert_equal, cpu_and_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageColor +from torchvision.transforms.v2.functional import to_dtype PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) @@ -152,7 +152,7 @@ 
def test_draw_boxes_grayscale(): def test_draw_invalid_boxes(): img_tp = ((1, 1, 1), (1, 2, 3)) - img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.long) + img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) @@ -162,7 +162,7 @@ def test_draw_invalid_boxes(): with pytest.raises(TypeError, match="Tensor expected"): utils.draw_bounding_boxes(img_tp, boxes) - with pytest.raises(ValueError, match="Tensor uint8 or float expected"): + with pytest.raises(ValueError, match="Tensor uint8 expected"): utils.draw_bounding_boxes(img_wrong1, boxes) with pytest.raises(ValueError, match="Pass individual images, not batches"): utils.draw_bounding_boxes(img_wrong2, boxes) @@ -266,6 +266,7 @@ def test_draw_segmentation_masks_dtypes(): torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_draw_segmentation_masks_errors(device): h, w = 10, 10 diff --git a/torchvision/utils.py b/torchvision/utils.py index 04513f42f10..e023501ca14 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -164,11 +164,11 @@ def draw_bounding_boxes( """ Draws bounding boxes on given image. - The image values should be uint8 in [0, 255] or float in [0, 1]. + The values of the input image should be uint8 between 0 and 255. If fill is True, Resulting Tensor should be saved as PNG image. Args: - image (Tensor): Tensor of shape (C x H x W) and dtype uint8 or float32. + image (Tensor): Tensor of shape (C x H x W) and dtype uint8. boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`. @@ -192,8 +192,8 @@ def draw_bounding_boxes( _log_api_usage_once(draw_bounding_boxes) if not isinstance(image, torch.Tensor): raise TypeError(f"Tensor expected, got {type(image)}") - elif not (image.dtype == torch.uint8 or image.is_floating_point()): - raise ValueError(f"Tensor uint8 or float expected, got {image.dtype}") + elif image.dtype != torch.uint8: + raise ValueError(f"Tensor uint8 expected, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size(0) not in {1, 3}: @@ -249,7 +249,7 @@ def draw_bounding_boxes( margin = width + 1 draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=image.dtype) + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) @torch.no_grad() @@ -336,10 +336,10 @@ def draw_keypoints( """ Draws Keypoints on given RGB image. - The image values should be uint8 in [0, 255] or float in [0, 1]. + The values of the input image should be uint8 between 0 and 255. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, in the format [x, y]. 
connectivity (List[Tuple[int, int]]]): A List of tuple where, @@ -357,8 +357,8 @@ def draw_keypoints( _log_api_usage_once(draw_keypoints) if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif not (image.dtype == torch.uint8 or image.is_floating_point()): - raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: From cad307c3e324e5adcc40631ad01633c3d9622ca1 Mon Sep 17 00:00:00 2001 From: Mithra Date: Sat, 16 Dec 2023 00:16:14 +0530 Subject: [PATCH 15/20] floating_point check --- torchvision/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index e023501ca14..85dc3988f84 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -10,6 +10,8 @@ import torch from PIL import Image, ImageColor, ImageDraw, ImageFont +# from .transforms.v2.functional import to_dtype + __all__ = [ "make_grid", "save_image", @@ -297,7 +299,7 @@ def draw_segmentation_masks( if masks.shape[-2:] != image.shape[-2:]: raise ValueError("The image and the masks must have the same height and width") - original_dtype = image.dtype + original_image = image if image.is_floating_point(): image = (image * 255).to(torch.uint8) @@ -318,10 +320,10 @@ def draw_segmentation_masks( img_to_draw[:, mask] = color[:, None] out = image * (1 - alpha) + img_to_draw * alpha - if original_dtype in {torch.float16, torch.float32, torch.float64}: + if original_image.is_floating_point(): out = out.float() / 255.0 - return out.to(original_dtype) + return out.to(original_image.dtype) @torch.no_grad() From 90982ffe002e4381c7d865e7ff3c790bb9811ade Mon Sep 17 00:00:00 2001 From: Mithra Date: Sat, 16 Dec 2023 00:17:50 +0530 Subject: [PATCH 16/20] cleaning up --- torchvision/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 85dc3988f84..840d8eed3bf 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -10,7 +10,6 @@ import torch from PIL import Image, ImageColor, ImageDraw, ImageFont -# from .transforms.v2.functional import to_dtype __all__ = [ "make_grid", From 020a84850cc2800e21301cd5e1e5fd6b8980be67 Mon Sep 17 00:00:00 2001 From: Mithra Date: Sat, 16 Dec 2023 00:40:28 +0530 Subject: [PATCH 17/20] floating_point check updated --- torchvision/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 840d8eed3bf..3a902ecf995 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -298,7 +298,7 @@ def draw_segmentation_masks( if masks.shape[-2:] != image.shape[-2:]: raise ValueError("The image and the masks must have the same height and width") - original_image = image + original_dtype = image.dtype if image.is_floating_point(): image = (image * 255).to(torch.uint8) @@ -319,10 +319,10 @@ def draw_segmentation_masks( img_to_draw[:, mask] = color[:, None] out = image * (1 - alpha) + img_to_draw * alpha - if original_image.is_floating_point(): + if torch.tensor(0, dtype=original_dtype).is_floating_point(): out = out.float() / 255.0 - return out.to(original_image.dtype) + return out.to(original_dtype) @torch.no_grad() From ee048c59a850ee805c267c54dfba0d484032fb60 Mon Sep 17 00:00:00 2001 From: Mithra Date: Sat, 16 Dec 2023 00:58:45 +0530 
Subject: [PATCH 18/20] type conversions using to_dtype() with local import --- torchvision/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 3a902ecf995..8692fde5bb4 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -297,10 +297,11 @@ def draw_segmentation_masks( raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") if masks.shape[-2:] != image.shape[-2:]: raise ValueError("The image and the masks must have the same height and width") + from torchvision.transforms.v2.functional import to_dtype original_dtype = image.dtype if image.is_floating_point(): - image = (image * 255).to(torch.uint8) + image = to_dtype(image, torch.uint8, scale=True) num_masks = masks.size()[0] @@ -320,7 +321,7 @@ def draw_segmentation_masks( out = image * (1 - alpha) + img_to_draw * alpha if torch.tensor(0, dtype=original_dtype).is_floating_point(): - out = out.float() / 255.0 + out = to_dtype(out, torch.float) / 255.0 return out.to(original_dtype) From adf249d17fea82970c7f651b41722eccf7b1eaea Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 18 Dec 2023 11:58:19 +0000 Subject: [PATCH 19/20] Use _parse_colors to handle dtype conversion --- torchvision/utils.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 8692fde5bb4..122ae4684eb 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -297,11 +297,6 @@ def draw_segmentation_masks( raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") if masks.shape[-2:] != image.shape[-2:]: raise ValueError("The image and the masks must have the same height and width") - from torchvision.transforms.v2.functional import to_dtype - - original_dtype = image.dtype - if image.is_floating_point(): - image = to_dtype(image, torch.uint8, scale=True) num_masks = masks.size()[0] @@ -309,9 +304,10 @@ def draw_segmentation_masks( warnings.warn("masks doesn't contain any mask. No mask was drawn") return image + original_dtype = image.dtype colors = [ - torch.tensor(color, dtype=torch.uint8, device=image.device) - for color in _parse_colors(colors, num_objects=num_masks) + torch.tensor(color, dtype=original_dtype, device=image.device) + for color in _parse_colors(colors, num_objects=num_masks, dtype=original_dtype) ] img_to_draw = image.detach().clone() @@ -320,9 +316,7 @@ def draw_segmentation_masks( img_to_draw[:, mask] = color[:, None] out = image * (1 - alpha) + img_to_draw * alpha - if torch.tensor(0, dtype=original_dtype).is_floating_point(): - out = to_dtype(out, torch.float) / 255.0 - + # Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype return out.to(original_dtype) @@ -524,6 +518,7 @@ def _parse_colors( colors: Union[None, str, Tuple[int, int, int], List[Union[str, Tuple[int, int, int]]]], *, num_objects: int, + dtype: torch.dtype = torch.uint8, ) -> List[Tuple[int, int, int]]: """ Parses a specification of colors for a set of objects. 
@@ -560,7 +555,10 @@ def _parse_colors( else: # colors specifies a single color for all objects colors = [colors] * num_objects - return [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors] + colors = [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors] + if dtype.is_floating_point: # [0, 255] -> [0, 1] + colors = [tuple(v / 255 for v in color) for color in colors] + return colors def _log_api_usage_once(obj: Any) -> None: From 3dce586d6b816bfeb812d65801685fd6f7358599 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 18 Dec 2023 11:59:48 +0000 Subject: [PATCH 20/20] Revert some changes --- torchvision/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 122ae4684eb..630eada5cbc 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -186,7 +186,7 @@ def draw_bounding_boxes( font_size (int): The requested font size in points. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 or float32 with bounding boxes plotted. + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -346,7 +346,7 @@ def draw_keypoints( width (int): Integer denoting width of line connecting keypoints. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 or float32 with keypoints drawn. + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -389,7 +389,7 @@ def draw_keypoints( width=width, ) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=image.dtype) + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
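Taken together, the series leaves draw_bounding_boxes and draw_keypoints uint8-only (patches 14 and 20 revert the earlier changes to them) and adds float support to draw_segmentation_masks, whose output now keeps the input dtype. Below is a minimal usage sketch of that final behaviour, assuming a torchvision build with these patches applied; the shapes, the alpha value, and the atol=1 tolerance are illustrative and simply mirror the new test_draw_segmentation_masks_dtypes test rather than prescribing exact values.

import torch
from torchvision.utils import draw_segmentation_masks
from torchvision.transforms.v2.functional import to_dtype

num_masks, h, w = 2, 100, 100
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
img_float = to_dtype(img_uint8, torch.float32, scale=True)  # same image, values in [0, 1]

out_uint8 = draw_segmentation_masks(img_uint8, masks, alpha=0.5)  # stays uint8
out_float = draw_segmentation_masks(img_float, masks, alpha=0.5)  # stays float32

assert out_uint8.dtype == torch.uint8
assert out_float.is_floating_point()

# Up to rounding, the uint8 and float paths draw the same result.
torch.testing.assert_close(
    out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1
)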