Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix v2 transforms in spawn mp context #8067

Merged
merged 7 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.transforms.functional import get_dimensions
from torchvision.transforms.v2.functional import get_size


__all__ = [
Expand Down Expand Up @@ -568,9 +572,6 @@ def test_transforms(self, config):

@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]:
Expand Down Expand Up @@ -709,26 +710,29 @@ def _no_collate(batch):
return batch


def check_transforms_v2_wrapper_spawn(dataset):
# On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
# subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
# we are enforcing here.
def check_transforms_v2_wrapper_spawn(dataset, expected_size):
# This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
# We also check that transforms are applied correctly as a non-regression test for
# https://github.com/pytorch/vision/issues/8066
# Implicitly, this also checks that the wrapped datasets are pickleable.

# To save CI/test time, we only check on macOS where "spawn" is the default
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spawn is also the default on Windows. Since macOS CI is by far the costliest, should we just use Windows here?

if platform.system() != "Darwin":
pytest.skip("Multiprocessing spawning is only checked on macOS.")

from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2

wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

for wrapped_sample in dataloader:
assert tree_any(
lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
def resize_was_applied(item):
# Checking the size of the output ensures that the Resize transform was correctly applied
return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list(
expected_size
)

for wrapped_sample in dataloader:
assert tree_any(resize_was_applied, wrapped_sample)


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
Expand Down
62 changes: 38 additions & 24 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import torch.nn.functional as F
from common_utils import combinations_grid
from torchvision import datasets
from torchvision.transforms import v2


class STL10TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -184,8 +185,9 @@ def test_combined_targets(self):
f"{actual} is not {expected}",

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(target_type="category") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(target_type="category", transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -263,8 +265,9 @@ def inject_fake_data(self, tmpdir, config):
return split_to_num_examples[config["split"]]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -391,9 +394,10 @@ def test_feature_types_target_polygon(self):
(polygon_target, info["expected_polygon_target"])

def test_transforms_v2_wrapper_spawn(self):
expected_size = (123, 321)
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -427,8 +431,9 @@ def inject_fake_data(self, tmpdir, config):
return num_examples

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -625,9 +630,10 @@ def test_images_names_split(self):
assert merged_imgs_names == all_imgs_names

def test_transforms_v2_wrapper_spawn(self):
expected_size = (123, 321)
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -717,8 +723,9 @@ def add_bndbox(obj, bndbox=None):
return data

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class VOCDetectionTestCase(VOCSegmentationTestCase):
Expand All @@ -741,8 +748,9 @@ def test_annotations(self):
assert object == info["annotation"]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -815,8 +823,9 @@ def _create_json(self, root, name, content):
return file

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CocoCaptionsTestCase(CocoDetectionTestCase):
Expand Down Expand Up @@ -1005,9 +1014,11 @@ def inject_fake_data(self, tmpdir, config):
)
return num_videos_per_class * len(classes)

@pytest.mark.xfail(reason="FIXME")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails. I think it's because Kinetics doesn't convert its transform attribute into a transforms attribute, but I haven't double-checked. If that's OK, I'd like to merge this PR right now to get it behind us and investigate that separately just after.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK with investigating later. Will have a look.

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(output_format="TCHW") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(output_format="TCHW", transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
Expand Down Expand Up @@ -1237,8 +1248,9 @@ def _file_stem(self, idx):
return f"2008_{idx:06d}"

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(mode="segmentation") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(mode="segmentation", transforms=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1690,8 +1702,9 @@ def inject_fake_data(self, tmpdir, config):
return split_to_num_examples[config["train"]]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -2568,8 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
return (image_id, class_id, species, breed_id)

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down
14 changes: 13 additions & 1 deletion torchvision/tv_tensors/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import contextlib
from collections import defaultdict
from copy import copy

import torch

Expand Down Expand Up @@ -198,8 +199,19 @@ def __getitem__(self, idx):
def __len__(self):
return len(self._dataset)

# TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See just above this link

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can try, but I think it is not possible. The "state" in __getstate__ and __setstate__ is the second return value of __reduce__ below. And while we can recreate a VisionDatasetTVTensorWrapper from them, pickle does not know how to create the dynamic type. I'll give it a shot.

def __reduce__(self):
return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
# __reduce__ gets called when we try to pickle the dataset.
# In a DataLoader with spawn context, this gets called `num_workers` times from the main process.

# We have to reset the [target_]transform[s] attributes of the dataset
# to their original values, because we previously set them to None in __init__().
dataset = copy(self._dataset)
dataset.transform = self.transform
dataset.transforms = self.transforms
dataset.target_transform = self.target_transform

return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH, I don't really understand why we needed to have __reduce__ in the first place. I understand it's needed to support pickle, but my understanding stops here.

The whole logic seems really strange, i.e. calling yet again wrap_dataset_for_transforms_v2(), which is how the instance-being-pickled was created in the first place anyway. And now on top of that we add the "field resetting logic" i.e. we just undo what we did in __init__.

I understand it's needed to support pickle, but my understanding stops here.

@pmeier can you remind me of the details of this? Do you think we could support pickleability of these datasets in a different way that would not require this "fix"?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pickle does not support dynamic types like we do in

wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetTVTensorWrapper, type(dataset)), {})

by default. Thus, we implement the __reduce__ method and tell it how to construct the object by its parts. Now only the items on L212 need to be pickled. While unpickling, the dynamic type is created anew and thus we circumvent the issue.

We have the dynamic type in the first place to support isinstance checks. There was also a different option in #7239 (comment) that would work without a dynamic type. I think that could potentially work without your fix, but I would need to test. However, this option also has its drawbacks (see discussion in the original PR).



def raise_not_supported(description):
Expand Down