From fb99bca7a8049ba1eb206b86b7dc8b84836a44e7 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 3 Mar 2021 12:26:16 +0000
Subject: [PATCH 1/5] WIP, still needs tests and docs

---
 torchvision/transforms/functional_tensor.py | 40 ++++++++++++---------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index d20d24a8413..1daecfc5243 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -470,7 +470,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
     return img
 
 
-def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor:
+def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_size: Optional[int] = None) -> Tensor:
     _assert_image_tensor(img)
 
     if not isinstance(size, (int, tuple, list)):
@@ -484,34 +484,42 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
     if isinstance(size, tuple):
         size = list(size)
 
-    if isinstance(size, list) and len(size) not in [1, 2]:
-        raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
-                         "{} element tuple/list".format(len(size)))
+    if isinstance(size, list):
+        if len(size) not in [1, 2]:
+            raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
+                         "{} element tuple/list".format(len(size)))
+        if max_size is not None and len(size) != 1:
+            raise ValueError("TODO proper error message")
 
     w, h = _get_image_size(img)
 
-    if isinstance(size, int):
-        size_w, size_h = size, size
-    elif len(size) < 2:
-        size_w, size_h = size[0], size[0]
-    else:
-        size_w, size_h = size[1], size[0]  # Convention (h, w)
+    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
+        short, long = (w, h) if w <= h else (h, w)
 
-    if isinstance(size, int) or len(size) < 2:
-        if w < h:
-            size_h = int(size_w * h / w)
+        if isinstance(size, int):
+            requested_new_short = size
         else:
-            size_w = int(size_h * w / h)
+            requested_new_short = size[0]
 
-        if (w <= h and w == size_w) or (h <= w and h == size_h):
+        if short == requested_new_short:
             return img
 
+        new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+        if max_size is not None and new_long > max_size:
+            new_short, new_long = int(max_size * new_short / new_long), max_size
+
+        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+
+    else:  # specified both h and w
+        new_w, new_h = size[1], size[0]
+
     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
 
     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
 
-    img = interpolate(img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners)
+    img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
 
     if interpolation == "bicubic" and out_dtype == torch.uint8:
         img = img.clamp(min=0, max=255)
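
Note: the smaller-edge branch that PATCH 1 adds reduces to the following standalone arithmetic. This is a minimal sketch mirroring the new code with illustrative numbers; it is not part of the patch.

    # Mirror of the smaller-edge + max_size arithmetic in the patched F_t.resize.
    # Variable names follow the diff; the input dimensions are made up.
    w, h = 400, 250                                      # image width and height
    requested_new_short, max_size = 300, 320

    short, long = (w, h) if w <= h else (h, w)           # short=250, long=400
    new_short = requested_new_short                      # 300
    new_long = int(requested_new_short * long / short)   # int(300 * 400 / 250) = 480

    if max_size is not None and new_long > max_size:     # 480 > 320: rescale again
        new_short, new_long = int(max_size * new_short / new_long), max_size  # (200, 320)

    new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
    print(new_w, new_h)                                  # 320 200
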
From 4b824a20b7a68876224d7b25701726887f7187cd Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 3 Mar 2021 15:28:26 +0000
Subject: [PATCH 2/5] tests

---
 test/test_functional_tensor.py              | 70 ++++++++++++---------
 test/test_transforms.py                     | 41 +++++++-----
 test/test_transforms_tensor.py              | 30 +++++----
 torchvision/transforms/functional.py        |  6 +-
 torchvision/transforms/functional_pil.py    | 33 ++++++----
 torchvision/transforms/functional_tensor.py |  7 ++-
 torchvision/transforms/transforms.py        |  7 ++-
 7 files changed, 110 insertions(+), 84 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index c043e90129f..0c184228f27 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -13,7 +13,7 @@
 
 from common_utils import TransformsTester
 
-from typing import Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple
 
 NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
 
@@ -409,39 +409,42 @@ def test_resize(self):
                 batch_tensors = batch_tensors.to(dt)
 
             for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
-                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
-                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
-                    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
-
-                    self.assertEqual(
-                        resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
-                    )
-
-                    if interpolation not in [NEAREST, ]:
-                        # We can not check values if mode = NEAREST, as results are different
-                        # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
-                        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
-                        resized_tensor_f = resized_tensor
-                        # we need to cast to uint8 to compare with PIL image
-                        if resized_tensor_f.dtype == torch.uint8:
-                            resized_tensor_f = resized_tensor_f.to(torch.float)
-
-                        # Pay attention to high tolerance for MAE
-                        self.approxEqualTensorToPIL(
-                            resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
+                for max_size in (None, 33, 40, 1000):
+                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
+                        continue  # unsupported, see assertRaises below
+                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+                        resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
+                        resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
+
+                        self.assertEqual(
+                            resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
                         )
 
-                    if isinstance(size, int):
-                        script_size = [size, ]
-                    else:
-                        script_size = size
+                        if interpolation not in [NEAREST, ]:
+                            # We can not check values if mode = NEAREST, as results are different
+                            # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
+                            # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
+                            resized_tensor_f = resized_tensor
+                            # we need to cast to uint8 to compare with PIL image
+                            if resized_tensor_f.dtype == torch.uint8:
+                                resized_tensor_f = resized_tensor_f.to(torch.float)
+
+                            # Pay attention to high tolerance for MAE
+                            self.approxEqualTensorToPIL(
+                                resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
+                            )
 
-                    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
-                    self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+                        if isinstance(size, int):
+                            script_size = [size, ]
+                        else:
+                            script_size = size
 
-                    self._test_fn_on_batch(
-                        batch_tensors, F.resize, size=script_size, interpolation=interpolation
-                    )
+                        resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size)
+                        self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+
+                        self._test_fn_on_batch(
+                            batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
+                        )
 
         # assert changed type warning
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
@@ -449,6 +452,13 @@ def test_resize(self):
             res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
         self.assertTrue(res1.equal(res2))
 
+        for img in (tensor, pil_img):
+            with self.assertRaisesRegex(ValueError, "TODO"):
+                F.resize(img, size=(32, 34), max_size=35)
+            with self.assertRaisesRegex(ValueError, "TODO"):
+                F.resize(img, size=32, max_size=32)
+
+
     def test_resized_crop(self):
         # test values of F.resized_crop in several cases:
         # 1) resize to the same size, crop to the same size => should be identity
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 392978d988b..64c6830472e 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -312,23 +312,30 @@ def test_resize(self):
             img = Image.new("RGB", size=(width, height), color=127)
 
             for osize in test_output_sizes_1:
-
-                t = transforms.Resize(osize)
-                result = t(img)
-
-                msg = "{}, {} - {}".format(height, width, osize)
-                osize = osize[0] if isinstance(osize, (list, tuple)) else osize
-                # If size is an int, smaller edge of the image will be matched to this number.
-                # i.e, if height > width, then image will be rescaled to (size * height / width, size).
-                if height < width:
-                    expected_size = (int(osize * width / height), osize)  # (w, h)
-                    self.assertEqual(result.size, expected_size, msg=msg)
-                elif width < height:
-                    expected_size = (osize, int(osize * height / width))  # (w, h)
-                    self.assertEqual(result.size, expected_size, msg=msg)
-                else:
-                    expected_size = (osize, osize)  # (w, h)
-                    self.assertEqual(result.size, expected_size, msg=msg)
+                for max_size in (None, 37, 1000):
+
+                    t = transforms.Resize(osize, max_size=max_size)
+                    result = t(img)
+
+                    msg = "{}, {} - {} - {}".format(height, width, osize, max_size)
+                    osize = osize[0] if isinstance(osize, (list, tuple)) else osize
+                    # If size is an int, smaller edge of the image will be matched to this number.
+                    # i.e, if height > width, then image will be rescaled to (size * height / width, size).
+                    if height < width:
+                        exp_w, exp_h = (int(osize * width / height), osize)  # (w, h)
+                        if max_size is not None and max_size < exp_w:
+                            exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
+                        self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
+                    elif width < height:
+                        exp_w, exp_h = (osize, int(osize * height / width))  # (w, h)
+                        if max_size is not None and max_size < exp_h:
+                            exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
+                        self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
+                    else:
+                        exp_w, exp_h = (osize, osize)  # (w, h)
+                        if max_size is not None and max_size < osize:
+                            exp_w, exp_h = max_size, max_size
+                        self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
 
         for height, width in input_sizes:
             img = Image.new("RGB", size=(width, height), color=127)
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 3263dd3e5dd..5ba63b9b6d3 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -7,6 +7,7 @@
 import numpy as np
 
 import unittest
+from typing import Sequence
 
 from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
 
@@ -322,32 +323,29 @@ def test_resize(self):
         tensor, _ = self._create_data(height=34, width=36, device=self.device)
         batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
-        script_fn = torch.jit.script(F.resize)
 
         for dt in [None, torch.float32, torch.float64]:
             if dt is not None:
                 # This is a trivial cast to float of uint8 data to test all cases
                 tensor = tensor.to(dt)
             for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
-                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+                for max_size in (None, 35, 1000):
+                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
+                        continue  # Not supported
+                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
 
-                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
+                        if isinstance(size, int):
+                            script_size = [size, ]
+                        else:
+                            script_size = size
 
-                    if isinstance(size, int):
-                        script_size = [size, ]
-                    else:
-                        script_size = size
-
-                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
-                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
-
-                    transform = T.Resize(size=script_size, interpolation=interpolation)
-                    s_transform = torch.jit.script(transform)
-                    self._test_transform_vs_scripted(transform, s_transform, tensor)
-                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                        transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
+                        s_transform = torch.jit.script(transform)
+                        self._test_transform_vs_scripted(transform, s_transform, tensor)
+                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
         with get_tmp_dir() as tmp_dir:
-            script_fn.save(os.path.join(tmp_dir, "t_resize.pt"))
+            s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))
 
     def test_resized_crop(self):
         tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 34ea8de6ad0..16881238e31 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -337,7 +337,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return tensor
 
 
-def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR) -> Tensor:
+def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None) -> Tensor:
     r"""Resize the input image to the given size.
     If the image is torch Tensor, it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
@@ -372,9 +372,9 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
 
     if not isinstance(img, torch.Tensor):
         pil_interpolation = pil_modes_mapping[interpolation]
-        return F_pil.resize(img, size=size, interpolation=pil_interpolation)
+        return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
 
-    return F_t.resize(img, size=size, interpolation=interpolation.value)
+    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size)
 
 
 def scale(*args, **kwargs):

diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 6999a2acf5f..7554bfc94a6 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -204,27 +204,34 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
 
 
 @torch.jit.unused
-def resize(img, size, interpolation=Image.BILINEAR):
+def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
         raise TypeError('Got inappropriate size arg: {}'.format(size))
 
-    if isinstance(size, int) or len(size) == 1:
-        if isinstance(size, Sequence):
-            size = size[0]
+    if isinstance(size, Sequence) and len(size) == 1:
+        size = size[0]
+    if isinstance(size, int):
         w, h = img.size
-        if (w <= h and w == size) or (h <= w and h == size):
+
+        short, long = (w, h) if w <= h else (h, w)
+        if short == size:
             return img
-        if w < h:
-            ow = size
-            oh = int(size * h / w)
-            return img.resize((ow, oh), interpolation)
-        else:
-            oh = size
-            ow = int(size * w / h)
-            return img.resize((ow, oh), interpolation)
+
+        new_short, new_long = size, int(size * long / short)
+
+        if max_size is not None:
+            if max_size <= size:
+                raise ValueError("TODO")
+            if new_long > max_size:
+                new_short, new_long = int(max_size * new_short / new_long), max_size
+
+        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+        return img.resize((new_w, new_h), interpolation)
     else:
+        if max_size is not None:
+            raise ValueError("TODO: proper error message")
         return img.resize(size[::-1], interpolation)

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 1daecfc5243..2c1152e6204 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -506,8 +506,11 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
 
         new_short, new_long = requested_new_short, int(requested_new_short * long / short)
 
-        if max_size is not None and new_long > max_size:
-            new_short, new_long = int(max_size * new_short / new_long), max_size
+        if max_size is not None:
+            if max_size <= requested_new_short:
+                raise ValueError("TODO")
+            if new_long > max_size:
+                new_short, new_long = int(max_size * new_short / new_long), max_size
 
         new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 916956e29fd..312b448ab8f 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -244,13 +244,14 @@ class Resize(torch.nn.Module):
     """
 
-    def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
+    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None):
         super().__init__()
         if not isinstance(size, (int, Sequence)):
             raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
         if isinstance(size, Sequence) and len(size) not in (1, 2):
             raise ValueError("If size is a sequence, it should have 1 or 2 values")
         self.size = size
+        self.max_size = max_size
 
         # Backward compatibility with integer value
         if isinstance(interpolation, int):
@@ -270,11 +271,11 @@ def forward(self, img):
         Returns:
             PIL Image or Tensor: Rescaled image.
         """
-        return F.resize(img, self.size, self.interpolation)
+        return F.resize(img, self.size, self.interpolation, self.max_size)
 
     def __repr__(self):
         interpolate_str = self.interpolation.value
-        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
+        return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2})'.format(self.size, interpolate_str, max_size)
 
 
 class Scale(Resize):
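
Note: the new tests above pin down which argument combinations are rejected. The sketch below mirrors them outside the test harness; the error messages are still "TODO" placeholders at this point in the series (PATCH 5 finalizes them), and it assumes a torchvision build with these patches applied.

    # Sketch of the argument checks exercised by the new assertRaises tests.
    import torch
    import torchvision.transforms.functional as F

    img = torch.randint(0, 256, size=(3, 26, 36), dtype=torch.uint8)

    out = F.resize(img, size=32, max_size=48)  # OK: smaller edge 26 -> 32, longer edge 44 <= 48

    try:
        F.resize(img, size=(32, 34), max_size=35)  # size fixes both h and w
    except ValueError as e:
        print(e)  # max_size is only valid when size gives the smaller-edge length

    try:
        F.resize(img, size=32, max_size=32)  # cap not strictly greater than size
    except ValueError as e:
        print(e)
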
From 0fa14d3a7620146b110bdaef5f0caec230fef6b4 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 3 Mar 2021 15:31:52 +0000
Subject: [PATCH 3/5] flake8

---
 test/test_functional_tensor.py              | 7 ++++---
 torchvision/transforms/functional.py        | 3 ++-
 torchvision/transforms/functional_tensor.py | 2 +-
 torchvision/transforms/transforms.py        | 3 ++-
 4 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 0c184228f27..11d586a56af 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -417,7 +417,8 @@ def test_resize(self):
                         resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
 
                         self.assertEqual(
-                            resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
+                            resized_tensor.size()[1:], resized_pil_img.size[::-1],
+                            msg="{}, {}".format(size, interpolation)
                         )
 
                         if interpolation not in [NEAREST, ]:
@@ -439,7 +440,8 @@ def test_resize(self):
                         else:
                             script_size = size
 
-                        resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size)
+                        resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
+                                                  max_size=max_size)
                         self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
 
                         self._test_fn_on_batch(
@@ -458,7 +460,6 @@ def test_resize(self):
             with self.assertRaisesRegex(ValueError, "TODO"):
                 F.resize(img, size=32, max_size=32)
 
-
     def test_resized_crop(self):
         # test values of F.resized_crop in several cases:
         # 1) resize to the same size, crop to the same size => should be identity

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 16881238e31..1a1fdd3129d 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -337,7 +337,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return tensor
 
 
-def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None) -> Tensor:
+def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+           max_size: Optional[int] = None) -> Tensor:
     r"""Resize the input image to the given size.
     If the image is torch Tensor, it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 2c1152e6204..5e63c84c4ec 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -487,7 +487,7 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
     if isinstance(size, list):
         if len(size) not in [1, 2]:
             raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
-                         "{} element tuple/list".format(len(size)))
+                             "{} element tuple/list".format(len(size)))
         if max_size is not None and len(size) != 1:
             raise ValueError("TODO proper error message")

diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 312b448ab8f..40c06fbe778 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -275,7 +275,8 @@ def forward(self, img):
 
     def __repr__(self):
         interpolate_str = self.interpolation.value
-        return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2})'.format(self.size, interpolate_str, max_size)
+        return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2})'.format(
+            self.size, interpolate_str, max_size)
 
 
 class Scale(Resize):
From 399dadf4e78800d4cc53578f8bca4e94a2b05875 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 3 Mar 2021 15:58:13 +0000
Subject: [PATCH 4/5] Docs + fixed some tests

---
 torchvision/transforms/functional.py |  8 ++++++++
 torchvision/transforms/transforms.py | 10 +++++++++-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 1a1fdd3129d..5b630e72c75 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -356,6 +356,14 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
             Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
             ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
+        max_size (int, optional): The maximum allowed for the longer edge of
+            the resized image: if the longer edge of the image is greater
+            than ``max_size`` after being resized according to ``size``, then
+            the image is resized again so that the longer edge is equal to
+            ``max_size``. As a result, ``size`` might be overruled, i.e. the
+            smaller edge may be shorter than ``size``. This is only supported
+            if ``size`` is an int (or a sequence of length 1 in torchscript
+            mode).
 
     Returns:
         PIL Image or Tensor: Resized image.

diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 40c06fbe778..76bba5fd983 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -241,6 +241,14 @@ class Resize(torch.nn.Module):
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
             ``InterpolationMode.BICUBIC`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
+        max_size (int, optional): The maximum allowed for the longer edge of
+            the resized image: if the longer edge of the image is greater
+            than ``max_size`` after being resized according to ``size``, then
+            the image is resized again so that the longer edge is equal to
+            ``max_size``. As a result, ``size`` might be overruled, i.e. the
+            smaller edge may be shorter than ``size``. This is only supported
+            if ``size`` is an int (or a sequence of length 1 in torchscript
+            mode).
 
     """
 
@@ -276,7 +284,7 @@ def forward(self, img):
     def __repr__(self):
         interpolate_str = self.interpolation.value
         return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2})'.format(
-            self.size, interpolate_str, max_size)
+            self.size, interpolate_str, self.max_size)
 
 
 class Scale(Resize):
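
Note: the "``size`` might be overruled" case from the docstrings above, worked through with concrete numbers. A minimal sketch assuming a torchvision build with these patches; the dimensions are illustrative, not taken from the patch.

    # How max_size can overrule size, per the docstring semantics.
    from PIL import Image
    from torchvision import transforms

    img = Image.new("RGB", (1280, 720))        # (w, h); smaller edge is 720

    t = transforms.Resize(600, max_size=1000)
    out = t(img)

    # Smaller edge 720 -> 600 would make the longer edge int(600 * 1280 / 720) = 1066,
    # which exceeds max_size=1000, so the image is rescaled again: the longer edge is
    # capped at 1000 and the smaller edge lands at int(1000 * 600 / 1066) = 562 < 600.
    print(out.size)  # (1000, 562)
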
From aa5162c5a9875e880248a7a0e53ff0a503842062 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 3 Mar 2021 16:14:33 +0000
Subject: [PATCH 5/5] proper error messages

---
 test/test_functional_tensor.py              |  5 +++--
 torchvision/transforms/functional_pil.py    | 10 ++++++++--
 torchvision/transforms/functional_tensor.py | 10 ++++++++--
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 11d586a56af..f1219ff7ce9 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -455,9 +455,10 @@ def test_resize(self):
         self.assertTrue(res1.equal(res2))
 
         for img in (tensor, pil_img):
-            with self.assertRaisesRegex(ValueError, "TODO"):
+            exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
+            with self.assertRaisesRegex(ValueError, exp_msg):
                 F.resize(img, size=(32, 34), max_size=35)
-            with self.assertRaisesRegex(ValueError, "TODO"):
+            with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
                 F.resize(img, size=32, max_size=32)
 
     def test_resized_crop(self):

diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 7554bfc94a6..42d7db9f260 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -223,7 +223,10 @@ def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
 
         if max_size is not None:
             if max_size <= size:
-                raise ValueError("TODO")
+                raise ValueError(
+                    f"max_size = {max_size} must be strictly greater than the requested "
+                    f"size for the smaller edge size = {size}"
+                )
             if new_long > max_size:
                 new_short, new_long = int(max_size * new_short / new_long), max_size
 
@@ -231,7 +234,10 @@ def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
         return img.resize((new_w, new_h), interpolation)
     else:
         if max_size is not None:
-            raise ValueError("TODO: proper error message")
+            raise ValueError(
+                "max_size should only be passed if size specifies the length of the smaller edge, "
+                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
+            )
         return img.resize(size[::-1], interpolation)

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 5e63c84c4ec..f4358bb6c8c 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -489,7 +489,10 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
             raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
                              "{} element tuple/list".format(len(size)))
         if max_size is not None and len(size) != 1:
-            raise ValueError("TODO proper error message")
+            raise ValueError(
+                "max_size should only be passed if size specifies the length of the smaller edge, "
+                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
+            )
 
     w, h = _get_image_size(img)
 
@@ -508,7 +511,10 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
 
         if max_size is not None:
             if max_size <= requested_new_short:
-                raise ValueError("TODO")
+                raise ValueError(
+                    f"max_size = {max_size} must be strictly greater than the requested "
+                    f"size for the smaller edge size = {size}"
+                )
             if new_long > max_size:
                 new_short, new_long = int(max_size * new_short / new_long), max_size
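
Note: the tensor tests in PATCH 2 also guard that ``max_size`` survives ``torch.jit.script``. A closing sketch of scripted usage, assuming a torchvision build with this series applied; in torchscript mode ``size`` must be a sequence, hence ``size=[600]``.

    # max_size with a scripted Resize, as the tensor tests exercise it.
    import torch
    from torchvision import transforms as T

    transform = T.Resize(size=[600], max_size=1000)  # length-1 sequence for torchscript
    s_transform = torch.jit.script(transform)

    img = torch.randint(0, 256, size=(3, 720, 1280), dtype=torch.uint8)
    print(s_transform(img).shape)  # torch.Size([3, 562, 1000])
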