diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index f35571a0356..efb48cb1c6c 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2098,6 +2098,7 @@ def prepare_data_loader(
             even_batches=self.even_batches,
             slice_fn_for_dispatch=slice_fn_for_dispatch,
             use_seedable_sampler=self.use_seedable_sampler,
+            data_seed=self.dataloader_config.data_seed,
             non_blocking=self.non_blocking,
             use_stateful_dataloader=self.use_stateful_dataloader,
         )
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index c793e85f8dc..bf3f35fb7e8 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -77,9 +77,11 @@ class SeedableRandomSampler(RandomSampler):
     """

     def __init__(self, *args, **kwargs):
+        data_seed = kwargs.pop("data_seed", None)
         super().__init__(*args, **kwargs)
+
+        self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
         self.epoch = 0
-        self.initial_seed = torch.random.initial_seed()

     def __iter__(self):
         if self.generator is None:
@@ -937,6 +939,7 @@ def prepare_data_loader(
     even_batches: bool = True,
     slice_fn_for_dispatch: Optional[Callable] = None,
     use_seedable_sampler: bool = False,
+    data_seed: Optional[int] = None,
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
 ) -> DataLoader:
@@ -996,6 +999,9 @@ def prepare_data_loader(
             reproducability. Comes at a cost of potentially different performances due to different shuffling
             algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
             `self.set_epoch`
+        data_seed (`int`, *optional*, defaults to `None`):
+            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
+            will use the current default seed from torch.
         non_blocking (`bool`, *optional*, defaults to `False`):
             If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
             `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
@@ -1069,6 +1075,7 @@ def prepare_data_loader(
                 replacement=sampler.replacement,
                 num_samples=sampler._num_samples,
                 generator=getattr(sampler, "generator", torch.Generator()),
+                data_seed=data_seed,
             )

     if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py
index 1d54d098a63..3b6f302f76a 100644
--- a/src/accelerate/test_utils/scripts/test_script.py
+++ b/src/accelerate/test_utils/scripts/test_script.py
@@ -400,6 +400,34 @@ def check_seedable_sampler_in_batch_sampler_shard():
     ), "Sampler in BatchSamplerShard is not SeedableRandomSampler."


+def check_seedable_sampler_with_data_seed():
+    # Set seed
+    set_seed(42)
+    data_seed = 42
+    train_set = RegressionDataset(length=10, seed=42)
+    train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
+
+    config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=data_seed)
+    accelerator = Accelerator(dataloader_config=config)
+    prepared_dl = accelerator.prepare(train_dl)
+    original_items = []
+    for _ in range(3):
+        for batch in prepared_dl:
+            original_items.append(batch["x"])
+    original_items = torch.cat(original_items)
+
+    # Set new data seed
+    config.data_seed = 43
+    accelerator = Accelerator(dataloader_config=config)
+    prepared_dl = accelerator.prepare(train_dl)
+    new_items = []
+    for _ in range(3):
+        for batch in prepared_dl:
+            new_items.append(batch["x"])
+    new_items = torch.cat(new_items)
+    assert not torch.allclose(original_items, new_items), "Obtained the same items with different data seed."
+
+
 def mock_training(length, batch_size, generator, use_seedable_sampler=False):
     set_seed(42)
     generator.manual_seed(42)
@@ -800,6 +828,7 @@ def main():
         central_dl_preparation_check()
         custom_sampler_check()
         check_seedable_sampler()
+        check_seedable_sampler_with_data_seed()

     if state.num_processes > 1:
         check_seedable_sampler_in_batch_sampler_shard()
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index a06ceb0f659..a28cc52ebdf 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -737,6 +737,9 @@ class DataLoaderConfiguration:
             training results are fully reproducable using a different sampling technique. While seed-to-seed results
             may differ, on average the differences are neglible when using multiple different seeds to compare. Should
             also be ran with [`~utils.set_seed`] for the best results.
+        data_seed (`int`, defaults to `None`):
+            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
+            will use the current default seed from torch.
         non_blocking (`bool`, defaults to `False`):
             If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device
             transfers, allowing for better overlap between dataloader communication and computation. Recommended that
@@ -781,6 +784,13 @@ class DataLoaderConfiguration:
             "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
         },
     )
+    data_seed: int = field(
+        default=None,
+        metadata={
+            "help": "The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator"
+            " will use the current default seed from torch."
+        },
+    )
     non_blocking: bool = field(
         default=False,
         metadata={
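For reference, a minimal usage sketch of the new option (not part of the patch; the toy `TensorDataset` is an illustrative stand-in, and the import paths are assumed from the test script above):

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, set_seed

# Global seed: governs model init, dropout, and (without data_seed) the sampler too.
set_seed(42)

dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# `data_seed` pins the SeedableRandomSampler's shuffle order independently of the
# global seed: re-running with the same data_seed reproduces the same batch order.
config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=1234)
accelerator = Accelerator(dataloader_config=config)
dataloader = accelerator.prepare(dataloader)

for batch in dataloader:
    print(batch)

Changing only `data_seed` (e.g. to 1235) reshuffles the data while every other seeded component stays fixed; with `data_seed=None` the sampler falls back to `torch.random.initial_seed()`, preserving the previous behavior. This decoupling of data order from the global seed is exactly what `check_seedable_sampler_with_data_seed` asserts.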