-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix v2 transforms in spawn mp context #8067
Changes from 5 commits
ad22dca
9e584c7
5132eab
52f05fb
b8e2e1b
e948758
33dc494
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
import torch.nn.functional as F | ||
from common_utils import combinations_grid | ||
from torchvision import datasets | ||
from torchvision.transforms import v2 | ||
|
||
|
||
class STL10TestCase(datasets_utils.ImageDatasetTestCase): | ||
|
@@ -184,8 +185,9 @@ def test_combined_targets(self): | |
f"{actual} is not {expected}", | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset(target_type="category") as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(target_type="category", transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): | ||
|
@@ -263,8 +265,9 @@ def inject_fake_data(self, tmpdir, config): | |
return split_to_num_examples[config["split"]] | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset() as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): | ||
|
@@ -391,9 +394,10 @@ def test_feature_types_target_polygon(self): | |
(polygon_target, info["expected_polygon_target"]) | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
expected_size = (123, 321) | ||
for target_type in ["instance", "semantic", ["instance", "semantic"]]: | ||
with self.create_dataset(target_type=target_type) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): | ||
|
@@ -427,8 +431,9 @@ def inject_fake_data(self, tmpdir, config): | |
return num_examples | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset() as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase): | ||
|
@@ -625,9 +630,10 @@ def test_images_names_split(self): | |
assert merged_imgs_names == all_imgs_names | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
expected_size = (123, 321) | ||
for target_type in ["identity", "bbox", ["identity", "bbox"]]: | ||
with self.create_dataset(target_type=target_type) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): | ||
|
@@ -717,8 +723,9 @@ def add_bndbox(obj, bndbox=None): | |
return data | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset() as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class VOCDetectionTestCase(VOCSegmentationTestCase): | ||
|
@@ -741,8 +748,9 @@ def test_annotations(self): | |
assert object == info["annotation"] | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset() as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): | ||
|
@@ -815,8 +823,9 @@ def _create_json(self, root, name, content): | |
return file | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset() as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class CocoCaptionsTestCase(CocoDetectionTestCase): | ||
|
@@ -1005,9 +1014,11 @@ def inject_fake_data(self, tmpdir, config): | |
) | ||
return num_videos_per_class * len(classes) | ||
|
||
@pytest.mark.xfail(reason="FIXME") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This fails; I think it's because Kinetics doesn't convert its […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK with investigating later. Will have a look. |
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset(output_format="TCHW") as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(output_format="TCHW", transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): | ||
|
@@ -1237,8 +1248,9 @@ def _file_stem(self, idx): | |
return f"2008_{idx:06d}" | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset(mode="segmentation") as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(mode="segmentation", transforms=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class FakeDataTestCase(datasets_utils.ImageDatasetTestCase): | ||
|
@@ -1690,8 +1702,9 @@ def inject_fake_data(self, tmpdir, config): | |
return split_to_num_examples[config["train"]] | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset() as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class SvhnTestCase(datasets_utils.ImageDatasetTestCase): | ||
|
@@ -2568,8 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx): | |
return (image_id, class_id, species, breed_id) | ||
|
||
def test_transforms_v2_wrapper_spawn(self): | ||
with self.create_dataset() as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset) | ||
expected_size = (123, 321) | ||
with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): | ||
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) | ||
|
||
|
||
class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -6,6 +6,7 @@ | |||
|
||||
import contextlib | ||||
from collections import defaultdict | ||||
from copy import copy | ||||
|
||||
import torch | ||||
|
||||
|
@@ -198,8 +199,19 @@ def __getitem__(self, idx): | |||
def __len__(self): | ||||
return len(self._dataset) | ||||
|
||||
# TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs. | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. See just above this link. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can try, but I think it is not possible. The "state" in […] |
||||
def __reduce__(self): | ||||
return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys) | ||||
# __reduce__ gets called when we try to pickle the dataset. | ||||
# In a DataLoader with spawn context, this gets called `num_workers` times from the main process. | ||||
|
||||
# We have to reset the [target_]transform[s] attributes of the dataset | ||||
# to their original values, because we previously set them to None in __init__(). | ||||
dataset = copy(self._dataset) | ||||
dataset.transform = self.transform | ||||
dataset.transforms = self.transforms | ||||
dataset.target_transform = self.target_transform | ||||
|
||||
return wrap_dataset_for_transforms_v2, (dataset, self._target_keys) | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. TBH, I don't really understand why we needed to have […]. The whole logic seems really strange, i.e. calling yet again […].
@pmeier can you remind me the details of this? Do you think we could support pickleability of these datasets in a different way that would require this "fix"? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
[…] by default. Thus, we implement the […]. We have the dynamic type in the first place to support […] |
||||
|
||||
|
||||
def raise_not_supported(description): | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Spawn is also the default on Windows. Since macOS CI is by far the costliest, should we just use Windows here?