Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Remove torchvision dep and simplify resize and rgb_to_grayscale in torch backend #20868

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integration_tests/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
BACKEND_REQ = {
"tensorflow": ("tensorflow-cpu", ""),
"torch": (
"torch torchvision",
"torch",
"--extra-index-url https://download.pytorch.org/whl/cpu ",
),
"jax": ("jax[cpu]", ""),
Expand Down
268 changes: 134 additions & 134 deletions keras/src/backend/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
import operator

import torch
import torch.nn.functional as F

from keras.src import backend
from keras.src.backend.torch.core import convert_to_tensor
from keras.src.utils.module_utils import torchvision

RESIZE_INTERPOLATIONS = {} # populated after torchvision import
RESIZE_INTERPOLATIONS = {
"bilinear": "bilinear",
"nearest": "nearest-exact",
"bicubic": "bicubic",
}

UNSUPPORTED_INTERPOLATIONS = (
"lanczos3",
Expand All @@ -19,23 +23,27 @@
def rgb_to_grayscale(images, data_format=None):
images = convert_to_tensor(images)
data_format = backend.standardize_data_format(data_format)
if data_format == "channels_last":
if images.ndim == 4:
images = images.permute((0, 3, 1, 2))
elif images.ndim == 3:
images = images.permute((2, 0, 1))
else:
raise ValueError(
"Invalid images rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"images.shape={images.shape}"
)
images = torchvision.transforms.functional.rgb_to_grayscale(img=images)
if data_format == "channels_last":
if len(images.shape) == 4:
images = images.permute((0, 2, 3, 1))
elif len(images.shape) == 3:
images = images.permute((1, 2, 0))
if images.ndim not in (3, 4):
raise ValueError(
"Invalid images rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"images.shape={images.shape}"
)
channel_axis = -3 if data_format == "channels_first" else -1
if images.shape[channel_axis] not in (1, 3):
raise ValueError(
"Invalid channel size: expected 3 (RGB) or 1 (Grayscale). "
f"Received input with shape: images.shape={images.shape}"
)

# This implementation is based on
# https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
if images.shape[channel_axis] == 3:
r, g, b = images.unbind(dim=channel_axis)
images = (0.2989 * r + 0.587 * g + 0.114 * b).to(images.dtype)
images = images.unsqueeze(dim=channel_axis)
else:
images = images.clone()
return images


Expand Down Expand Up @@ -129,6 +137,40 @@ def hsv_planes_to_rgb_planes(hue, saturation, value):
return images


def _cast_squeeze_in(image, req_dtypes):
need_squeeze = False
# make image NCHW
if image.ndim < 4:
image = image.unsqueeze(dim=0)
need_squeeze = True

out_dtype = image.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
image = image.to(req_dtype)
return image, need_cast, need_squeeze, out_dtype


def _cast_squeeze_out(image, need_cast, need_squeeze, out_dtype):
if need_squeeze:
image = image.squeeze(dim=0)

if need_cast:
if out_dtype in (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
):
# it is better to round before cast
image = torch.round(image)
image = image.to(out_dtype)
return image


def resize(
images,
size,
Expand All @@ -141,13 +183,6 @@ def resize(
data_format=None,
):
data_format = backend.standardize_data_format(data_format)
RESIZE_INTERPOLATIONS.update(
{
"bilinear": torchvision.transforms.InterpolationMode.BILINEAR,
"nearest": torchvision.transforms.InterpolationMode.NEAREST_EXACT,
"bicubic": torchvision.transforms.InterpolationMode.BICUBIC,
}
)
if interpolation in UNSUPPORTED_INTERPOLATIONS:
raise ValueError(
"Resizing with Lanczos interpolation is "
Expand Down Expand Up @@ -182,11 +217,11 @@ def resize(
"or rank 4 (batch of images). Received input with shape: "
f"images.shape={images.shape}"
)
images, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
images, [torch.float32, torch.float64]
)
if data_format == "channels_last":
if images.ndim == 4:
images = images.permute((0, 3, 1, 2))
else:
images = images.permute((2, 0, 1))
images = images.permute((0, 3, 1, 2))

if crop_to_aspect_ratio:
shape = images.shape
Expand All @@ -198,19 +233,12 @@ def resize(
crop_width = max(min(width, crop_width), 1)
crop_box_hstart = int(float(height - crop_height) / 2)
crop_box_wstart = int(float(width - crop_width) / 2)
if len(images.shape) == 4:
images = images[
:,
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
else:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
images = images[
:,
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
elif pad_to_aspect_ratio:
shape = images.shape
height, width = shape[-2], shape[-1]
Expand All @@ -221,105 +249,77 @@ def resize(
pad_width = max(width, pad_width)
img_box_hstart = int(float(pad_height - height) / 2)
img_box_wstart = int(float(pad_width - width) / 2)
if len(images.shape) == 4:
batch_size = images.shape[0]
channels = images.shape[1]
if img_box_hstart > 0:
padded_img = torch.cat(
[
torch.ones(
(batch_size, channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
images,
torch.ones(
(batch_size, channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=2,
)
else:
padded_img = images

if img_box_wstart > 0:
padded_img = torch.cat(
[
torch.ones(
(batch_size, channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
),
padded_img,
torch.ones(
(batch_size, channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=3,
)

batch_size = images.shape[0]
channels = images.shape[1]
if img_box_hstart > 0:
padded_img = torch.cat(
[
torch.ones(
(batch_size, channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
images,
torch.ones(
(batch_size, channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=2,
)
else:
channels = images.shape[0]
if img_box_wstart > 0:
padded_img = torch.cat(
[
torch.ones(
(channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
images,
torch.ones(
(channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=1,
)
else:
padded_img = images
if img_box_wstart > 0:
torch.cat(
[
torch.ones(
(channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
)
* fill_value,
padded_img,
torch.ones(
(channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=2,
)
padded_img = images
if img_box_wstart > 0:
padded_img = torch.cat(
[
torch.ones(
(batch_size, channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
),
padded_img,
torch.ones(
(batch_size, channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=3,
)
images = padded_img

resized = torchvision.transforms.functional.resize(
img=images,
# This implementation is based on
# https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
if antialias and interpolation not in ("bilinear", "bicubic"):
# We manually set it to False to avoid an error downstream in
# interpolate(). This behaviour is documented: the parameter is
# irrelevant for modes that are not bilinear or bicubic. We used to
# raise an error here, but now we don't use True as the default.
antialias = False
# Define align_corners to avoid warnings
align_corners = False if interpolation in ("bilinear", "bicubic") else None
resized = F.interpolate(
images,
size=size,
interpolation=RESIZE_INTERPOLATIONS[interpolation],
mode=RESIZE_INTERPOLATIONS[interpolation],
align_corners=align_corners,
antialias=antialias,
)
if interpolation == "bicubic" and out_dtype == torch.uint8:
resized = resized.clamp(min=0, max=255)
if data_format == "channels_last":
if len(images.shape) == 4:
resized = resized.permute((0, 2, 3, 1))
elif len(images.shape) == 3:
resized = resized.permute((1, 2, 0))
resized = resized.permute((0, 2, 3, 1))
resized = _cast_squeeze_out(
resized,
need_cast=need_cast,
need_squeeze=need_squeeze,
out_dtype=out_dtype,
)
return resized


Expand Down
1 change: 0 additions & 1 deletion keras/src/utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __repr__(self):
tensorflow_io = LazyModule("tensorflow_io")
scipy = LazyModule("scipy")
jax = LazyModule("jax")
torchvision = LazyModule("torchvision")
torch_xla = LazyModule(
"torch_xla",
import_error_msg=(
Expand Down
1 change: 1 addition & 0 deletions requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pytest
numpy
scipy
scikit-learn
pillow
pandas
absl-py
requests
Expand Down
1 change: 0 additions & 1 deletion requirements-jax-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ tf2onnx
# Torch cpu-only version (needed for testing).
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.6.0+cpu
torchvision==0.21.0+cpu

# Jax with cuda support.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Expand Down
1 change: 0 additions & 1 deletion requirements-tensorflow-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ tf2onnx
# Torch cpu-only version (needed for testing).
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.6.0+cpu
torchvision==0.21.0+cpu

# Jax cpu-only version (needed for testing).
jax[cpu]
Expand Down
1 change: 0 additions & 1 deletion requirements-torch-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ tf2onnx
# - torch-xla is pinned to a version that supports GPU (2.6 doesn't)
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.5.1+cu121
torchvision==0.20.1+cu121
torch-xla==2.5.1;sys_platform != 'darwin'

# Jax cpu-only version (needed for testing).
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ tf2onnx
# Torch.
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.6.0+cpu
torchvision==0.21.0+cpu
torch-xla==2.6.0;sys_platform != 'darwin'

# Jax.
Expand Down