
Added the KeyPoints TVTensor #8817


Open: wants to merge 22 commits into base: main

Commits (22)
8253305
Added Keypoints to the library
Alexandre-SCHOEPP Dec 12, 2024
484561d
Improved KeyPoints to be exported
Alexandre-SCHOEPP Dec 13, 2024
3255890
Added kernels to support the keypoints
Alexandre-SCHOEPP Dec 13, 2024
7436636
Added tests for keypoints
Alexandre-SCHOEPP Dec 13, 2024
b35cba6
Applied ufmt formatting
Alexandre-SCHOEPP Dec 13, 2024
a19ec0b
Fixed the bugs found while testing
Alexandre-SCHOEPP Dec 16, 2024
5f4b188
Improved documentation to take KeyPoints into account
Alexandre-SCHOEPP Dec 17, 2024
cabce1c
Applied ufmt check
Alexandre-SCHOEPP Dec 17, 2024
d1b27ad
Fixed the hflip not being along the right coordinate
Alexandre-SCHOEPP Dec 17, 2024
6fa38f4
Merge branch 'main' into main
Alexandre-SCHOEPP Dec 18, 2024
05e4ad6
Merge branch 'main' into main
Alexandre-SCHOEPP Feb 10, 2025
03dc6c8
Merge branch 'main' into main
Alexandre-SCHOEPP Feb 20, 2025
d4d087c
Merge branch 'main' into main
Alexandre-SCHOEPP Mar 4, 2025
5a8c5b4
Fixed order of arguments
Alex-S-H-P Apr 30, 2025
dea31e2
Reworked logic of the conditions to better handle mutable/non mutable…
Alex-S-H-P Apr 30, 2025
71e20a5
Renamed out variable to be more similar with _resized_crop_bounding_b…
Alex-S-H-P Apr 30, 2025
2f77527
renamed _xyxy_to_points to _xyxy_to_keypoints for consistency
Alex-S-H-P Apr 30, 2025
517a6de
clarified _xyxy_to_points and changed the name of its caller for the …
Alex-S-H-P Apr 30, 2025
63ed4a5
Renamed half_point to more explicit single_coord_shape
Alex-S-H-P Apr 30, 2025
166c1ec
Integrated KeyPoints better in the transforms. It now warns alongside…
Alex-S-H-P Apr 30, 2025
fcfd597
Merge branch 'main' into main
Alexandre-SCHOEPP Apr 30, 2025
1cc3b6f
Fixed _geometry.py post botched merge request
Alex-S-H-P Apr 30, 2025
1 change: 1 addition & 0 deletions docs/source/tv_tensors.rst
@@ -21,6 +21,7 @@ info.

Image
Video
KeyPoints
BoundingBoxFormat
BoundingBoxes
Mask
11 changes: 10 additions & 1 deletion gallery/transforms/plot_tv_tensors.py
@@ -46,11 +46,12 @@
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
# :mod:`torchvision.tv_tensors` supports five types of TVTensors:
#
# * :class:`~torchvision.tv_tensors.Image`
# * :class:`~torchvision.tv_tensors.Video`
# * :class:`~torchvision.tv_tensors.BoundingBoxes`
# * :class:`~torchvision.tv_tensors.KeyPoints`
# * :class:`~torchvision.tv_tensors.Mask`
#
# What can I do with a TVTensor?
@@ -96,6 +97,7 @@
# :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.
# In a similar fashion, :class:`~torchvision.tv_tensors.KeyPoints` also require the ``canvas_size`` metadata to be added.

bboxes = tv_tensors.BoundingBoxes(
[[17, 16, 344, 495], [0, 10, 0, 10]],
@@ -104,6 +106,13 @@
)
print(bboxes)


keypoints = tv_tensors.KeyPoints(
[[17, 16], [344, 495], [0, 10], [0, 10]],
canvas_size=image.shape[-2:]
)
print(keypoints)

# %%
# Using ``tv_tensors.wrap()``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
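For readers skimming the gallery change above, here is a small, hedged sketch of what the new TVTensor enables once this branch is installed: constructing a KeyPoints tensor with its canvas_size metadata and passing it through a v2 functional, which dispatches to the new keypoints kernels. The exact flip convention is whatever horizontal_flip_keypoints implements; the values here are only illustrative.

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# Two (x, y) points on a canvas of height 10 and width 20.
kp = tv_tensors.KeyPoints([[2, 3], [15, 7]], canvas_size=(10, 20))

# F.horizontal_flip dispatches on the input type, so the KeyPoints metadata
# (canvas_size) is used to mirror the x coordinates across the image width.
flipped = F.horizontal_flip(kp)
print(flipped)              # still a tv_tensors.KeyPoints
print(flipped.canvas_size)  # (10, 20), preserved by the transform
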
15 changes: 15 additions & 0 deletions test/common_utils.py
@@ -8,6 +8,7 @@
import shutil
import sys
import tempfile
from typing import Sequence, Tuple
import warnings
from subprocess import CalledProcessError, check_output, STDOUT

@@ -400,6 +401,20 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_keypoints(
canvas_size: Tuple[int, int] = DEFAULT_SIZE, *, num_points: int | Sequence[int] = 4, dtype=None, device='cpu'
) -> tv_tensors.KeyPoints:
"""Make the KeyPoints for testing purposes"""
if isinstance(num_points, int):
num_points = [num_points]
single_coord_shape: Tuple[int, ...] = tuple(num_points) + (1,)
y = torch.randint(0, canvas_size[0] - 1, single_coord_shape, dtype=dtype, device=device)
x = torch.randint(0, canvas_size[1] - 1, single_coord_shape, dtype=dtype, device=device)
points = torch.cat((x, y), dim=-1)
keypoints = tv_tensors.KeyPoints(points, canvas_size=canvas_size)
return keypoints


def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
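A hedged usage sketch of the make_keypoints helper added above, assuming it is imported from the test suite's common_utils module. The shapes follow directly from the function body (one trailing dimension of size 2 for x and y), while the coordinate values are random.

from common_utils import make_keypoints

# Default call: 4 points on the suite's DEFAULT_SIZE canvas, shape (4, 2).
kp = make_keypoints()
print(kp.shape, kp.canvas_size)

# num_points can also be a sequence, producing a batched layout such as (2, 3, 2).
kp_batched = make_keypoints((32, 32), num_points=(2, 3))
print(kp_batched.shape)  # torch.Size([2, 3, 2])
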
54 changes: 49 additions & 5 deletions test/test_transforms_v2.py
@@ -31,6 +31,7 @@
make_image,
make_image_pil,
make_image_tensor,
make_keypoints,
make_segmentation_mask,
make_video,
make_video_tensor,
@@ -232,6 +233,7 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type):
# explicitly passed to the kernel.
explicit_metadata = {
tv_tensors.BoundingBoxes: {"format", "canvas_size"},
tv_tensors.KeyPoints: {"canvas_size"}
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

@@ -336,6 +338,18 @@ def _make_transform_sample(transform, *, image_or_video, adapter):
canvas_size=size,
device=device,
),
keypoints=make_keypoints(canvas_size=size), keypoints_degenerate=tv_tensors.KeyPoints(
[
[0, 1], # left edge
[1, 0], # top edge
[0, 0], # top left corner
[size[1], 1], # right edge
[size[1], 0], # top right corner
[1, size[0]], # bottom edge
[0, size[0]], # bottom left corner
[size[1], size[0]] # bottom right corner
], canvas_size=size, device=device
),
detection_mask=make_detection_masks(size, device=device),
segmentation_mask=make_segmentation_mask(size, device=device),
int=0,
@@ -689,6 +703,7 @@ def test_functional(self, size, make_input):
(F.resize_image, torch.Tensor),
(F._geometry._resize_image_pil, PIL.Image.Image),
(F.resize_image, tv_tensors.Image),
(F.resize_keypoints, tv_tensors.KeyPoints),
(F.resize_bounding_boxes, tv_tensors.BoundingBoxes),
(F.resize_mask, tv_tensors.Mask),
(F.resize_video, tv_tensors.Video),
@@ -1044,6 +1059,7 @@ def test_functional(self, make_input):
(F.horizontal_flip_image, torch.Tensor),
(F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
(F.horizontal_flip_image, tv_tensors.Image),
(F.horizontal_flip_keypoints, tv_tensors.KeyPoints),
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.horizontal_flip_mask, tv_tensors.Mask),
(F.horizontal_flip_video, tv_tensors.Video),
@@ -1214,6 +1230,7 @@ def test_functional(self, make_input):
(F.affine_image, torch.Tensor),
(F._geometry._affine_image_pil, PIL.Image.Image),
(F.affine_image, tv_tensors.Image),
(F.affine_keypoints, tv_tensors.KeyPoints),
(F.affine_bounding_boxes, tv_tensors.BoundingBoxes),
(F.affine_mask, tv_tensors.Mask),
(F.affine_video, tv_tensors.Video),
@@ -1496,6 +1513,7 @@ def test_functional(self, make_input):
(F.vertical_flip_image, torch.Tensor),
(F._geometry._vertical_flip_image_pil, PIL.Image.Image),
(F.vertical_flip_image, tv_tensors.Image),
(F.vertical_flip_keypoints, tv_tensors.KeyPoints),
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.vertical_flip_mask, tv_tensors.Mask),
(F.vertical_flip_video, tv_tensors.Video),
@@ -1638,6 +1656,7 @@ def test_functional(self, make_input):
(F.rotate_image, torch.Tensor),
(F._geometry._rotate_image_pil, PIL.Image.Image),
(F.rotate_image, tv_tensors.Image),
(F.rotate_keypoints, tv_tensors.KeyPoints),
(F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
(F.rotate_mask, tv_tensors.Mask),
(F.rotate_video, tv_tensors.Video),
@@ -2343,7 +2362,9 @@ def test_error(self, T):
F.to_pil_image(imgs[0]),
tv_tensors.Mask(torch.rand(12, 12)),
tv_tensors.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12),
tv_tensors.KeyPoints(torch.rand(2, 2), canvas_size=(12, 12))
):
print(type(input_with_bad_type), cutmix_mixup)
with pytest.raises(ValueError, match="does not support PIL images, "):
cutmix_mixup(input_with_bad_type)

@@ -2751,8 +2772,9 @@ def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
"make_input", [
make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints
],
)
def test_displacement_error(self, make_input):
input = make_input()
@@ -2764,8 +2786,10 @@ def test_displacement_error(self, make_input):
F.elastic(input, displacement=torch.rand(F.get_size(input)))

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
"make_input", [
make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video,
make_keypoints
],
)
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
@pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)])
@@ -3471,7 +3495,7 @@ def _sample_input_adapter(self, transform, input, device):
adapted_input = {}
image_or_video_found = False
for key, value in input.items():
if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, tv_tensors.Mask)):
# AA transforms don't support bounding boxes or masks
continue
elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
@@ -6271,3 +6295,23 @@ def test_different_sizes(self, make_input1, make_input2, query):
def test_no_valid_input(self, query):
with pytest.raises(TypeError, match="No image"):
query(["blah"])

@pytest.mark.parametrize(
'boxes', [
tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
]
)
def test_convert_bounding_boxes_to_points(self, boxes: tv_tensors.BoundingBoxes):
# TODO: this test can't handle rotated boxes yet
kp = F.convert_bounding_boxes_to_points(boxes)
assert kp.shape == boxes.shape + (2, )
assert kp.dtype == boxes.dtype
# kp is a list of A, B, C, D polygons.
# If we use A | C, we should get back the XYXY format of bounding box
reconverted = torch.cat([kp[..., 0, :], kp[..., 2, :]], dim=-1)
reconverted_bbox = F.convert_bounding_box_format(
tv_tensors.BoundingBoxes(
reconverted, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=kp.canvas_size
), new_format=boxes.format
)
assert (reconverted_bbox == boxes).all(), f"Invalid reconversion : {reconverted_bbox}"
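To make the round trip in test_convert_bounding_boxes_to_points concrete, here is a hedged sketch of the same check outside pytest. The only property relied on, and the only one asserted here, is that corners 0 and 2 of each returned polygon are the two opposite XYXY corners; the full corner ordering is not pinned down by the test.

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

box = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))

kp = F.convert_bounding_boxes_to_points(box)
print(kp.shape)  # (1, 4, 2): four (x, y) corner points per box

# Rebuilding XYXY from corners 0 and 2 recovers the original coordinates.
xyxy = torch.cat([kp[..., 0, :], kp[..., 2, :]], dim=-1)
assert (xyxy == box).all()
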
54 changes: 30 additions & 24 deletions test/test_transforms_v2_utils.py
@@ -4,7 +4,7 @@
import torch

import torchvision.transforms.v2._utils
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image, make_keypoints

from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any
@@ -14,29 +14,32 @@
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
MASK = make_detection_masks(DEFAULT_SIZE)
KEYPOINTS = make_keypoints(DEFAULT_SIZE)


@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.KeyPoints,), True),
((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), False),
((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask, tv_tensors.KeyPoints), False),
((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False),
((KEYPOINTS,), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
(
(IMAGE, BOUNDING_BOX, MASK),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
(IMAGE, BOUNDING_BOX, MASK, KEYPOINTS),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints),
True,
),
((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: True,), True),
((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
(
(torch.Tensor(IMAGE),),
@@ -57,15 +60,18 @@ def test_has_any(sample, types, expected):
@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask, tv_tensors.KeyPoints), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), True),
(
(IMAGE, BOUNDING_BOX, MASK),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
(IMAGE, BOUNDING_BOX, MASK, KEYPOINTS),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints),
True,
),
((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),