[TRANS, IMP] Add new max_size parameter to Resize #3494

Merged: 7 commits, Mar 4, 2021
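For context, a minimal sketch of the behavior this PR adds (the image dimensions are illustrative; `transforms.Resize` and its new `max_size` parameter are the API under review):

```python
from PIL import Image
from torchvision import transforms

# A 500x300 (W x H) image: the smaller edge (300) is matched to size=256,
# which would make the longer edge int(256 * 500 / 300) = 426. That exceeds
# max_size=400, so the result is rescaled again so the longer edge is 400.
img = Image.new("RGB", (500, 300))
out = transforms.Resize(256, max_size=400)(img)
print(out.size)  # (400, 240)
```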
72 changes: 42 additions & 30 deletions test/test_functional_tensor.py
@@ -13,7 +13,7 @@

from common_utils import TransformsTester

from typing import Dict, List, Tuple
from typing import Dict, List, Sequence, Tuple


NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
@@ -409,46 +409,58 @@ def test_resize(self):
batch_tensors = batch_tensors.to(dt)

for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)

self.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
)

if interpolation not in [NEAREST, ]:
# We cannot check values if mode = NEAREST, as results are different
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
resized_tensor_f = resized_tensor
# cast uint8 to float so the values can be compared with the PIL image
if resized_tensor_f.dtype == torch.uint8:
resized_tensor_f = resized_tensor_f.to(torch.float)

# Pay attention to high tolerance for MAE
self.approxEqualTensorToPIL(
resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
for max_size in (None, 33, 40, 1000):
if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
continue # unsupported, see assertRaises below
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)

self.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1],
msg="{}, {}".format(size, interpolation)
)

if isinstance(size, int):
script_size = [size, ]
else:
script_size = size
if interpolation not in [NEAREST, ]:
# We cannot check values if mode = NEAREST, as results are different
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
resized_tensor_f = resized_tensor
# cast uint8 to float so the values can be compared with the PIL image
if resized_tensor_f.dtype == torch.uint8:
resized_tensor_f = resized_tensor_f.to(torch.float)

# Pay attention to high tolerance for MAE
self.approxEqualTensorToPIL(
resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
)

resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
if isinstance(size, int):
script_size = [size, ]
else:
script_size = size

self._test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation
)
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
max_size=max_size)
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))

self._test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
)

# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.resize(tensor, size=32, interpolation=2)
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))

for img in (tensor, pil_img):
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
with self.assertRaisesRegex(ValueError, exp_msg):
F.resize(img, size=(32, 34), max_size=35)
with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
F.resize(img, size=32, max_size=32)

def test_resized_crop(self):
# test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity
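The two error paths asserted at the end of `test_resize` above can be sketched as follows (a hedged illustration; the image size is arbitrary):

```python
from PIL import Image
import torchvision.transforms.functional as F

img = Image.new("RGB", (36, 34))

# max_size is only meaningful when size specifies the smaller edge,
# so pairing it with an explicit (h, w) size raises a ValueError.
try:
    F.resize(img, size=(32, 34), max_size=35)
except ValueError as exc:
    print(exc)  # "max_size should only be passed if size specifies ..."

# max_size must be strictly greater than the requested smaller-edge size.
try:
    F.resize(img, size=32, max_size=32)
except ValueError as exc:
    print(exc)  # "max_size = 32 must be strictly greater ..."
```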
41 changes: 24 additions & 17 deletions test/test_transforms.py
@@ -312,23 +312,30 @@ def test_resize(self):
img = Image.new("RGB", size=(width, height), color=127)

for osize in test_output_sizes_1:

t = transforms.Resize(osize)
result = t(img)

msg = "{}, {} - {}".format(height, width, osize)
osize = osize[0] if isinstance(osize, (list, tuple)) else osize
# If size is an int, smaller edge of the image will be matched to this number.
# i.e., if height > width, then image will be rescaled to (size * height / width, size).
if height < width:
expected_size = (int(osize * width / height), osize) # (w, h)
self.assertEqual(result.size, expected_size, msg=msg)
elif width < height:
expected_size = (osize, int(osize * height / width)) # (w, h)
self.assertEqual(result.size, expected_size, msg=msg)
else:
expected_size = (osize, osize) # (w, h)
self.assertEqual(result.size, expected_size, msg=msg)
for max_size in (None, 37, 1000):

t = transforms.Resize(osize, max_size=max_size)
result = t(img)

msg = "{}, {} - {} - {}".format(height, width, osize, max_size)
osize = osize[0] if isinstance(osize, (list, tuple)) else osize
# If size is an int, smaller edge of the image will be matched to this number.
# i.e., if height > width, then image will be rescaled to (size * height / width, size).
if height < width:
exp_w, exp_h = (int(osize * width / height), osize) # (w, h)
if max_size is not None and max_size < exp_w:
exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
elif width < height:
exp_w, exp_h = (osize, int(osize * height / width)) # (w, h)
if max_size is not None and max_size < exp_h:
exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
else:
exp_w, exp_h = (osize, osize) # (w, h)
if max_size is not None and max_size < osize:
exp_w, exp_h = max_size, max_size
self.assertEqual(result.size, (exp_w, exp_h), msg=msg)

for height, width in input_sizes:
img = Image.new("RGB", size=(width, height), color=127)
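The inline expected-size arithmetic in the test above can be condensed into a small helper (`expected_size` is a hypothetical name, mirroring the logic under test):

```python
def expected_size(w: int, h: int, size: int, max_size=None):
    """Expected PIL-style (W, H) after Resize(size, max_size=max_size)."""
    short, long = (w, h) if w <= h else (h, w)
    new_short, new_long = size, int(size * long / short)
    if max_size is not None and new_long > max_size:
        new_short, new_long = int(max_size * new_short / new_long), max_size
    return (new_short, new_long) if w <= h else (new_long, new_short)

assert expected_size(w=500, h=300, size=256, max_size=400) == (400, 240)
```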
30 changes: 14 additions & 16 deletions test/test_transforms_tensor.py
@@ -7,6 +7,7 @@
import numpy as np

import unittest
from typing import Sequence

from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes

@@ -322,32 +322,29 @@ def test_resize(self):

tensor, _ = self._create_data(height=34, width=36, device=self.device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
script_fn = torch.jit.script(F.resize)

for dt in [None, torch.float32, torch.float64]:
if dt is not None:
# This is a trivial cast of the uint8 data to float, to test all cases
tensor = tensor.to(dt)
for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
for max_size in (None, 35, 1000):
if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
continue # Not supported
for interpolation in [BILINEAR, BICUBIC, NEAREST]:

resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
if isinstance(size, int):
script_size = [size, ]
else:
script_size = size

if isinstance(size, int):
script_size = [size, ]
else:
script_size = size

s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(s_resized_tensor.equal(resized_tensor))

transform = T.Resize(size=script_size, interpolation=interpolation)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

with get_tmp_dir() as tmp_dir:
script_fn.save(os.path.join(tmp_dir, "t_resize.pt"))
s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))

def test_resized_crop(self):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
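A minimal sketch of the scripted path exercised above (the tensor shape is illustrative; a one-element list keeps `size` compatible with torchscript's `List[int]`, as the tests do via `script_size`):

```python
import torch
import torchvision.transforms as T

tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)

transform = T.Resize(size=[32], max_size=40)
s_transform = torch.jit.script(transform)

# Eager and scripted execution should produce identical results.
assert transform(tensor).equal(s_transform(tensor))
print(s_transform(tensor).shape)  # torch.Size([3, 32, 40])
```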
15 changes: 12 additions & 3 deletions torchvision/transforms/functional.py
@@ -337,7 +337,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
return tensor


def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR) -> Tensor:
def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None) -> Tensor:
r"""Resize the input image to the given size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
@@ -355,6 +356,14 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
the image is resized again so that the longer edge is equal to
``max_size``. As a result, ``size`` might be overruled, i.e. the
smaller edge may be shorter than ``size``. This is only supported
if ``size`` is an int (or a sequence of length 1 in torchscript
mode).

Returns:
PIL Image or Tensor: Resized image.
@@ -372,9 +381,9 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte

if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(img, size=size, interpolation=pil_interpolation)
return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)

return F_t.resize(img, size=size, interpolation=interpolation.value)
return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size)


def scale(*args, **kwargs):
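A hedged illustration of the docstring semantics on the tensor path (dimensions are illustrative):

```python
import torch
import torchvision.transforms.functional as F

img = torch.zeros(3, 200, 600)  # [..., H, W]

# The smaller edge 200 is matched to size=128; the longer edge would
# become int(128 * 600 / 200) = 384, which exceeds max_size=300, so the
# image is resized again so that the longer edge equals 300.
out = F.resize(img, size=128, max_size=300)
print(out.shape)  # torch.Size([3, 100, 300])
```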
39 changes: 26 additions & 13 deletions torchvision/transforms/functional_pil.py
@@ -204,27 +204,40 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag


@torch.jit.unused
def resize(img, size, interpolation=Image.BILINEAR):
def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
raise TypeError('Got inappropriate size arg: {}'.format(size))

if isinstance(size, int) or len(size) == 1:
if isinstance(size, Sequence):
size = size[0]
if isinstance(size, Sequence) and len(size) == 1:
size = size[0]
if isinstance(size, int):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):

short, long = (w, h) if w <= h else (h, w)
if short == size:
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), interpolation)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), interpolation)

new_short, new_long = size, int(size * long / short)

if max_size is not None:
if max_size <= size:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size

new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
return img.resize((new_w, new_h), interpolation)
else:
if max_size is not None:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
return img.resize(size[::-1], interpolation)


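Since the PIL and tensor backends now share the same integer arithmetic, their output sizes should agree; a quick sketch of that property (sizes illustrative):

```python
import torch
from PIL import Image
import torchvision.transforms.functional as F

pil_img = Image.new("RGB", (500, 300))  # PIL size is (W, H)
tensor = torch.zeros(3, 300, 500)       # tensors are (..., H, W)

pil_out = F.resize(pil_img, size=64, max_size=100)
ten_out = F.resize(tensor, size=64, max_size=100)

# Same arithmetic on both paths, modulo the (W, H) vs (H, W) convention.
assert pil_out.size == (ten_out.shape[-1], ten_out.shape[-2])
print(pil_out.size)  # (100, 60)
```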
49 changes: 33 additions & 16 deletions torchvision/transforms/functional_tensor.py
@@ -470,7 +470,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
return img


def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor:
def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_size: Optional[int] = None) -> Tensor:
_assert_image_tensor(img)

if not isinstance(size, (int, tuple, list)):
@@ -484,34 +484,51 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
if isinstance(size, tuple):
size = list(size)

if isinstance(size, list) and len(size) not in [1, 2]:
raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
"{} element tuple/list".format(len(size)))
if isinstance(size, list):
if len(size) not in [1, 2]:
raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
"{} element tuple/list".format(len(size)))
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)

w, h = _get_image_size(img)

if isinstance(size, int):
size_w, size_h = size, size
elif len(size) < 2:
size_w, size_h = size[0], size[0]
else:
size_w, size_h = size[1], size[0] # Convention (h, w)
if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
short, long = (w, h) if w <= h else (h, w)

if isinstance(size, int) or len(size) < 2:
if w < h:
size_h = int(size_w * h / w)
if isinstance(size, int):
requested_new_short = size
else:
size_w = int(size_h * w / h)
requested_new_short = size[0]

if (w <= h and w == size_w) or (h <= w and h == size_h):
if short == requested_new_short:
return img

new_short, new_long = requested_new_short, int(requested_new_short * long / short)

if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size

new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)

else: # specified both h and w
new_w, new_h = size[1], size[0]

img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None

img = interpolate(img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners)
img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)

if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
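The final `clamp` exists because bicubic interpolation can overshoot the input range; a minimal demonstration (values illustrative, using only `torch.nn.functional.interpolate`):

```python
import torch
from torch.nn.functional import interpolate

# Bicubic kernels have negative lobes, so a hard step edge produces
# ringing: values below 0 and above 255. For uint8 images the result
# must therefore be clamped before casting back, as the diff above does.
x = torch.tensor([[[[0.0, 0.0, 255.0, 255.0]]]])  # NCHW
y = interpolate(x, size=[1, 8], mode="bicubic", align_corners=False)
print(y.min().item() < 0, y.max().item() > 255)  # True True
```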