This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

What happens with custom samplers? #252

Open
AugustoPeres opened this issue Mar 9, 2023 · 0 comments

@AugustoPeres

For training reasons I had to write my own sampler. Something like:

class MySampler(torch.utils.data.Sampler):
    def __init__(self, data, batches_per_epoch, batch_size):
        # some python code storing data, batches_per_epoch and batch_size
        ...

    def __iter__(self):
        # my iter method yielding batches of indices, obeying specific rules
        ...

To create the data loader I then simply use:

sampler = MySampler(dataset, batches_per_epoch, batch_size)
dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
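
For concreteness, here is a self-contained version of the setup (the dataset and the sampling rule are simplified stand-ins for my real code):

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class RandomDataset(Dataset):
    # stand-in dataset; my real data is more complex
    def __init__(self, size):
        self.data = torch.randn(size, 8)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class MySampler(Sampler):
    # yields batches_per_epoch batches of indices per epoch
    def __init__(self, data, batches_per_epoch, batch_size):
        self.data = data
        self.batches_per_epoch = batches_per_epoch
        self.batch_size = batch_size

    def __len__(self):
        return self.batches_per_epoch

    def __iter__(self):
        for _ in range(self.batches_per_epoch):
            # random indices stand in for my real sampling rule
            yield torch.randint(len(self.data), (self.batch_size,)).tolist()


dataset = RandomDataset(1000)
sampler = MySampler(dataset, batches_per_epoch=50, batch_size=32)
dataloader = DataLoader(dataset, batch_sampler=sampler)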

This works great up to the point where I try to use Ray Lightning. At first I tried to use it as follows:

plugin = RayStrategy(num_workers=num_workers,
                     num_cpus_per_worker=num_cpus_per_worker,
                     use_gpu=use_gpu)
trainer = pl.Trainer(max_epochs=max_epochs,
                     strategy=plugin,
                     logger=False)

This raised the error:

AttributeError: 'SeqMatchSeqSampler' object has no attribute 'drop_last'
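
My guess is that the strategy tries to re-wrap the custom batch sampler the way it would a regular torch.utils.data.BatchSampler, which carries sampler, batch_size and drop_last attributes. As a sketch only (I have not confirmed this against the ray_lightning source), exposing drop_last on my sampler would look like this:

import torch


class MySampler(torch.utils.data.Sampler):
    # sketch: expose the attributes torch.utils.data.BatchSampler has,
    # in case the strategy reads them when re-wrapping the loader
    # (untested assumption, not a confirmed fix)
    def __init__(self, data, batches_per_epoch, batch_size, drop_last=False):
        self.data = data
        self.batches_per_epoch = batches_per_epoch
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __len__(self):
        return self.batches_per_epoch

    def __iter__(self):
        for _ in range(self.batches_per_epoch):
            yield torch.randint(len(self.data), (self.batch_size,)).tolist()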

I then saw that there is a flag that disables sampler replacement: replace_sampler_ddp. Using this code:

plugin = RayStrategy(num_workers=num_workers,
                     num_cpus_per_worker=num_cpus_per_worker,
                     use_gpu=use_gpu)
trainer = pl.Trainer(max_epochs=max_epochs,
                     strategy=plugin,
                     logger=False,
                     replace_sampler_ddp=False)

I no longer get an error. However, something strange seems to happen: on my local machine, each epoch takes longer when I use more workers. Why is that? What exactly are the effects of replace_sampler_ddp=False on distributed data loading?

I could not find clear documentation on this particular topic:

  • Does every worker have its own copy of the sampler?
  • If so, are there in fact more batches being computed in every epoch?
  • How can I wrap my own sampler for DDP? Is there a way to instantiate the sampler so that every worker handles different batches?

For example, if I use:

sampler = MySampler(data, int(batches_per_epoch/num_ray_workers), batch_size)

Will this be equivalent for, say, 1 and 4 workers?
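
For reference, this is the kind of rank-aware variant I have in mind, so that each worker draws a different share of the batches (a sketch only; it assumes torch.distributed is initialized inside each Ray worker, which I have not verified):

import torch
import torch.distributed as dist


class DistributedMySampler(torch.utils.data.Sampler):
    # sketch: split the global batch budget across workers and seed
    # each rank differently so workers draw different batches
    def __init__(self, data, batches_per_epoch, batch_size, seed=0):
        self.data = data
        self.batch_size = batch_size
        self.seed = seed
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        # each worker handles its share of the global batch budget
        self.batches_per_epoch = batches_per_epoch // self.world_size

    def __len__(self):
        return self.batches_per_epoch

    def __iter__(self):
        # per-rank generator so the ranks do not all draw the same batches
        g = torch.Generator().manual_seed(self.seed + self.rank)
        for _ in range(self.batches_per_epoch):
            yield torch.randint(
                len(self.data), (self.batch_size,), generator=g).tolist()

If my understanding is correct, with replace_sampler_ddp=False this would keep the total number of batches per epoch constant regardless of the number of workers. Is that the intended way to do it?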
