Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support of float dtypes for draw_segmentation_masks #8150

Merged
merged 23 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f36dd2c
support for float32 for draw_*
GsnMithra Dec 10, 2023
f091cc7
cleaning up
GsnMithra Dec 10, 2023
5672d8c
method description
GsnMithra Dec 10, 2023
e06c1f6
Merge branch 'main' into support-for-float32
GsnMithra Dec 11, 2023
301e42b
Update torchvision/utils.py
GsnMithra Dec 11, 2023
9146ebe
floating point dtype check
GsnMithra Dec 11, 2023
bb1cd08
test_draw_segmentation_masks
GsnMithra Dec 12, 2023
8d486a8
unit test for draw_segmentation_masks dtypes
GsnMithra Dec 14, 2023
baec368
cleaning up other method and checks
GsnMithra Dec 14, 2023
745e1b2
removing .DS_Store files
GsnMithra Dec 14, 2023
6cfa9d8
gitignore updated for ignoring .DS_Store files
GsnMithra Dec 14, 2023
956a7bd
Update test/test_utils.py
GsnMithra Dec 15, 2023
3bcd9a3
Update .gitignore
GsnMithra Dec 15, 2023
7e3a176
Update test/test_utils.py
GsnMithra Dec 15, 2023
acdad33
reverting boxes and keypoints methods and unit tests
GsnMithra Dec 15, 2023
cad307c
floating_point check
GsnMithra Dec 15, 2023
90982ff
cleaning up
GsnMithra Dec 15, 2023
020a848
floating_point check updated
GsnMithra Dec 15, 2023
ee048c5
type conversions using to_dtype() with local import
GsnMithra Dec 15, 2023
a5e4c54
Merge branch 'main' of github.com:pytorch/vision into support-for-flo…
NicolasHug Dec 18, 2023
0b82508
Merge branch 'support-for-float32' of github.com:GsnMithra/vision int…
NicolasHug Dec 18, 2023
adf249d
Use _parse_colors to handle dtype conversion
NicolasHug Dec 18, 2023
3dce586
Revert some changes
NicolasHug Dec 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ xcuserdata/
# direnv
.direnv
.envrc

.DS_Store
GsnMithra marked this conversation as resolved.
Show resolved Hide resolved
26 changes: 23 additions & 3 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import numpy as np
import pytest
import torch
import torchvision.transforms.functional as F
import torchvision.utils as utils
import torchvision.transforms.functional as F
from torchvision.transforms.v2.functional import to_dtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to re-order the imports, check out https://github.com/pytorch/vision/actions/runs/7214131414/job/19676695616?pr=8150, or run the pre-commit hooks locally (check our contributing instructions)

from common_utils import assert_equal, cpu_and_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor

Expand Down Expand Up @@ -151,7 +152,7 @@ def test_draw_boxes_grayscale():

def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.long)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
Expand All @@ -161,7 +162,7 @@ def test_draw_invalid_boxes():

with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"):
with pytest.raises(ValueError, match="Tensor uint8 or float expected"):
utils.draw_bounding_boxes(img_wrong1, boxes)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
utils.draw_bounding_boxes(img_wrong2, boxes)
Expand Down Expand Up @@ -246,6 +247,25 @@ def test_draw_segmentation_masks(colors, alpha, device):
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)


def test_draw_segmentation_masks_dtypes():
num_masks, h, w = 2, 100, 100

masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
out_uint8 = utils.draw_segmentation_masks(img_uint8, masks)

assert img_uint8 is not out_uint8
assert out_uint8.dtype == torch.uint8

img_float = to_dtype(img_uint8, torch.float, scale=True)
out_float = utils.draw_segmentation_masks(img_float, masks)

assert img_float is not out_float
assert out_float.is_floating_point ()
GsnMithra marked this conversation as resolved.
Show resolved Hide resolved

torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1.0)
GsnMithra marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks_errors(device):
h, w = 10, 10
Expand Down
44 changes: 25 additions & 19 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ def draw_bounding_boxes(

"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.
The image values should be uint8 in [0, 255] or float in [0, 1].
If fill is True, Resulting Tensor should be saved as PNG image.

Args:
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
image (Tensor): Tensor of shape (C x H x W) and dtype uint8 or float32.
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
Expand All @@ -185,15 +185,15 @@ def draw_bounding_boxes(
font_size (int): The requested font size in points.

Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
img (Tensor[C, H, W]): Image Tensor of dtype uint8 or float32 with bounding boxes plotted.
"""

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_bounding_boxes)
if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"Tensor uint8 or float expected, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size(0) not in {1, 3}:
Expand Down Expand Up @@ -249,7 +249,7 @@ def draw_bounding_boxes(
margin = width + 1
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=image.dtype)


@torch.no_grad()
Expand All @@ -262,10 +262,10 @@ def draw_segmentation_masks(

"""
Draws segmentation masks on given RGB image.
The values of the input image should be uint8 between 0 and 255.
The image values should be uint8 in [0, 255] or float in [0, 1].

Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
0 means full transparency, 1 means no transparency.
Expand All @@ -282,8 +282,8 @@ def draw_segmentation_masks(
_log_api_usage_once(draw_segmentation_masks)
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
Expand All @@ -297,15 +297,18 @@ def draw_segmentation_masks(
if masks.shape[-2:] != image.shape[-2:]:
raise ValueError("The image and the masks must have the same height and width")

original_dtype = image.dtype
if image.is_floating_point():
image = (image * 255).to(torch.uint8)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use to_dtype here instead, like you did in the test.


num_masks = masks.size()[0]

if num_masks == 0:
warnings.warn("masks doesn't contain any mask. No mask was drawn")
return image

out_dtype = torch.uint8
colors = [
torch.tensor(color, dtype=out_dtype, device=image.device)
torch.tensor(color, dtype=torch.uint8, device=image.device)
for color in _parse_colors(colors, num_objects=num_masks)
]

Expand All @@ -315,7 +318,10 @@ def draw_segmentation_masks(
img_to_draw[:, mask] = color[:, None]

out = image * (1 - alpha) + img_to_draw * alpha
return out.to(out_dtype)
if original_dtype in {torch.float16, torch.float32, torch.float64}:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check for is_floating_point() instead.

out = out.float() / 255.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same: Let's use to_dtype here instead, like you did in the test.


return out.to(original_dtype)


@torch.no_grad()
Expand All @@ -330,10 +336,10 @@ def draw_keypoints(

"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
The image values should be uint8 in [0, 255] or float in [0, 1].

Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
in the format [x, y].
connectivity (List[Tuple[int, int]]]): A List of tuple where,
Expand All @@ -344,15 +350,15 @@ def draw_keypoints(
width (int): Integer denoting width of line connecting keypoints.

Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
img (Tensor[C, H, W]): Image Tensor of dtype uint8 or float32 with keypoints drawn.
"""

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_keypoints)
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
Expand Down Expand Up @@ -387,7 +393,7 @@ def draw_keypoints(
width=width,
)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=image.dtype)


# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
Expand Down
Loading