Skip to content

Commit

Permalink
Fix v2 transforms in spawn mp context (#8067)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Oct 27, 2023
1 parent 96d2ce9 commit b80bdb7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 40 deletions.
34 changes: 19 additions & 15 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.transforms.functional import get_dimensions
from torchvision.transforms.v2.functional import get_size


__all__ = [
Expand Down Expand Up @@ -568,9 +572,6 @@ def test_transforms(self, config):

@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]:
Expand Down Expand Up @@ -709,26 +710,29 @@ def _no_collate(batch):
return batch


def check_transforms_v2_wrapper_spawn(dataset):
# On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
# subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
# we are enforcing here.
if platform.system() != "Darwin":
pytest.skip("Multiprocessing spawning is only checked on macOS.")
def check_transforms_v2_wrapper_spawn(dataset, expected_size):
# This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
# We also check that transforms are applied correctly as a non-regression test for
# https://github.com/pytorch/vision/issues/8066
# Implicitly, this also checks that the wrapped datasets are pickleable.

from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
# To save CI/test time, we only check on Windows where "spawn" is the default
if platform.system() != "Windows":
pytest.skip("Multiprocessing spawning is only checked on macOS.")

wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

for wrapped_sample in dataloader:
assert tree_any(
lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
def resize_was_applied(item):
# Checking the size of the output ensures that the Resize transform was correctly applied
return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list(
expected_size
)

for wrapped_sample in dataloader:
assert tree_any(resize_was_applied, wrapped_sample)


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
Expand Down
62 changes: 38 additions & 24 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import torch.nn.functional as F
from common_utils import combinations_grid
from torchvision import datasets
from torchvision.transforms import v2


class STL10TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -184,8 +185,9 @@ def test_combined_targets(self):
f"{actual} is not {expected}",

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(target_type="category") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(target_type="category", transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -263,8 +265,9 @@ def inject_fake_data(self, tmpdir, config):
return split_to_num_examples[config["split"]]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -391,9 +394,10 @@ def test_feature_types_target_polygon(self):
(polygon_target, info["expected_polygon_target"])

def test_transforms_v2_wrapper_spawn(self):
expected_size = (123, 321)
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -427,8 +431,9 @@ def inject_fake_data(self, tmpdir, config):
return num_examples

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -625,9 +630,10 @@ def test_images_names_split(self):
assert merged_imgs_names == all_imgs_names

def test_transforms_v2_wrapper_spawn(self):
expected_size = (123, 321)
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -717,8 +723,9 @@ def add_bndbox(obj, bndbox=None):
return data

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class VOCDetectionTestCase(VOCSegmentationTestCase):
Expand All @@ -741,8 +748,9 @@ def test_annotations(self):
assert object == info["annotation"]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -815,8 +823,9 @@ def _create_json(self, root, name, content):
return file

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CocoCaptionsTestCase(CocoDetectionTestCase):
Expand Down Expand Up @@ -1005,9 +1014,11 @@ def inject_fake_data(self, tmpdir, config):
)
return num_videos_per_class * len(classes)

@pytest.mark.xfail(reason="FIXME")
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(output_format="TCHW") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(output_format="TCHW", transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
Expand Down Expand Up @@ -1237,8 +1248,9 @@ def _file_stem(self, idx):
return f"2008_{idx:06d}"

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(mode="segmentation") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(mode="segmentation", transforms=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1690,8 +1702,9 @@ def inject_fake_data(self, tmpdir, config):
return split_to_num_examples[config["train"]]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -2568,8 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
return (image_id, class_id, species, breed_id)

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down
14 changes: 13 additions & 1 deletion torchvision/tv_tensors/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import contextlib
from collections import defaultdict
from copy import copy

import torch

Expand Down Expand Up @@ -198,8 +199,19 @@ def __getitem__(self, idx):
def __len__(self):
return len(self._dataset)

# TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
def __reduce__(self):
return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
# __reduce__ gets called when we try to pickle the dataset.
# In a DataLoader with spawn context, this gets called `num_workers` times from the main process.

# We have to reset the [target_]transform[s] attributes of the dataset
# to their original values, because we previously set them to None in __init__().
dataset = copy(self._dataset)
dataset.transform = self.transform
dataset.transforms = self.transforms
dataset.target_transform = self.target_transform

return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)


def raise_not_supported(description):
Expand Down

0 comments on commit b80bdb7

Please sign in to comment.