
Dataset updated in "on_train_epoch_start", but "training_step" still gets old data #20407

Open
Yak1m4Sg opened this issue Nov 8, 2024 · 1 comment
Labels: bug (Something isn't working), loops (Related to the Loop API)


Yak1m4Sg commented Nov 8, 2024

Bug description

I start training with `trainer.fit(model, datamodule=dm)`, where `dm` is an instance of a class that inherits from `pl.LightningDataModule`. In that class I override `train_dataloader`:

```python
from torch.utils.data import DataLoader

def train_dataloader(self):
    # A new dataset object is built on every call, from the datamodule's own state.
    train_dataset = MixedBatchMultiviewDataset(
        self.args,
        self.tokenizer,
        known_exs=self.known_train,
        unknown_exs=self.unknown_train,
        feature=self.args.feature,
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=self.args.train_batch_size,
        shuffle=True,
        num_workers=self.args.num_workers,
        pin_memory=True,
        collate_fn=self.collate_batch_feat,
    )
    return train_dataloader
```
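Note that `train_dataloader()` constructs a new `MixedBatchMultiviewDataset` from `self.known_train` and `self.unknown_train` every time it is called. If the fit loop requests the dataloader again (which resetting `_combined_loader` and calling `setup_data()` appears intended to trigger), the rebuilt dataset may not carry over updates that were applied only to the previous dataset object, unless the updated state also lives on the datamodule. The sketch below illustrates that idea with hypothetical, framework-free classes: `ToyDataset`, `ToyDataModule` and `train_dataset()` are stand-ins, not Lightning or PyTorch APIs, and this is one possible design rather than a confirmed fix for this issue.

```python
from typing import Dict, List, Optional


class ToyDataset:
    """Hypothetical stand-in for MixedBatchMultiviewDataset."""

    def __init__(self, uids: List[str], pseudo_labels: Optional[Dict[str, int]] = None):
        self.uids = list(uids)
        # Copy, so each dataset owns a snapshot of the labels it was built with.
        self.pseudo_labels = dict(pseudo_labels or {})

    def __len__(self) -> int:
        return len(self.uids)

    def __getitem__(self, idx: int):
        uid = self.uids[idx]
        # -1 marks "no pseudo-label yet" in this sketch.
        return uid, self.pseudo_labels.get(uid, -1)


class ToyDataModule:
    """Hypothetical datamodule that keeps the pseudo-labels as its own state."""

    def __init__(self, uids: List[str]):
        self.uids = uids
        self.pseudo_labels: Dict[str, int] = {}

    def update_pseudo_labels(self, uid2pl: Dict[str, int]) -> None:
        # Store the update on the datamodule, not on one dataset instance.
        self.pseudo_labels.update(uid2pl)

    def train_dataset(self) -> ToyDataset:
        # Like train_dataloader() above, this builds a new dataset on every call,
        # but the new dataset is seeded with the latest pseudo-labels.
        return ToyDataset(self.uids, self.pseudo_labels)


dm = ToyDataModule(["a", "b"])
old_ds = dm.train_dataset()        # built before the update
dm.update_pseudo_labels({"a": 1})
new_ds = dm.train_dataset()        # rebuilt after the update
print(old_ds[0], new_ds[0])        # ('a', -1) ('a', 1)
```

Because the state sits on the datamodule, it no longer matters which dataset object the trainer happens to hold when the loader is rebuilt.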

In the model's `on_train_epoch_start` hook, I update the dataset and then ask the fit loop to set up its training data again:

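```python
def on_train_epoch_start(self):
    # Update the pseudo-labels on the dataset currently held by the trainer.
    # `uid2pl` is built elsewhere (not shown).
    train_dl = self.trainer.train_dataloader
    train_dl.dataset.update_pseudo_labels(uid2pl)

    # Ask the fit loop to set up its training data again.
    # Note: `_combined_loader` is a private attribute of the fit loop.
    loop = self.trainer.fit_loop
    loop._combined_loader = None
    loop.setup_data()
```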
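One PyTorch behaviour worth ruling out here: with `num_workers > 0` (the loader above passes `num_workers=self.args.num_workers`), the PyTorch documentation describes worker processes being created each time an iterator over the `DataLoader` is created, with the dataset handed to each worker at that point. An in-place update made afterwards in the main process, such as `update_pseudo_labels`, is then not visible to batches coming from that iterator; with `persistent_workers=True` the workers and their copies even survive across epochs. Whether Lightning has already created the epoch's iterator when `on_train_epoch_start` runs is not established here. The framework-free model below imitates the worker copy with `copy.deepcopy` at iterator-creation time; `SnapshotLoader` is hypothetical and spawns no processes.

```python
import copy
from typing import Iterator, List


class SnapshotLoader:
    """Hypothetical loader mimicking a DataLoader with num_workers > 0:
    the dataset is copied when an iterator is created, as if handed to a worker."""

    def __init__(self, dataset: List[int]):
        self.dataset = dataset

    def __iter__(self) -> Iterator[int]:
        worker_copy = copy.deepcopy(self.dataset)  # snapshot at iterator creation
        return iter(worker_copy)


data = [0, 0, 0]
loader = SnapshotLoader(data)

stale_iter = iter(loader)      # iterator (and "workers") created first...
data[:] = [1, 1, 1]            # ...then the dataset is updated in place

print(list(stale_iter))        # [0, 0, 0]: the iterator still yields the old data
print(loader.dataset)          # [1, 1, 1]: the loader's dataset attribute is updated
print(list(iter(loader)))      # [1, 1, 1]: a fresh iterator sees the update
```

In this model the only thing that matters is ordering: the update has to happen before the iterator (and its workers) is created.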

In `training_step`, however, `batch` still contains the old data, while `trainer.train_dataloader.dataset` already holds the updated data:

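```python
import logging
from typing import Dict, List
import torch

logger = logging.getLogger(__name__)  # assumed: a standard-library logger

def training_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx: int):
    self.mv_model._on_train_batch_start()
    logger.info(self.trainer.train_dataloader.dataset.unknown_feats)  # updated data
    logger.info(batch)  # still the old data
    # ... rest of training_step not shown
```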
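To make the mismatch easy to see in a minimal reproduction, one option is to have the dataset bump a version counter on every update and return it with each sample, then compare the version carried by the batch with the dataset's current version inside `training_step`. A minimal sketch, assuming a hypothetical `VersionedDataset` and `batch_is_stale` helper (neither is part of the original code or of Lightning):

```python
from typing import Dict, List, Tuple


class VersionedDataset:
    """Hypothetical dataset that tags every sample with the label version."""

    def __init__(self, uids: List[str]):
        self.uids = uids
        self.pseudo_labels: Dict[str, int] = {}
        self.version = 0

    def update_pseudo_labels(self, uid2pl: Dict[str, int]) -> None:
        self.pseudo_labels.update(uid2pl)
        self.version += 1  # each update yields a new, distinguishable version

    def __len__(self) -> int:
        return len(self.uids)

    def __getitem__(self, idx: int) -> Tuple[str, int, int]:
        uid = self.uids[idx]
        return uid, self.pseudo_labels.get(uid, -1), self.version


def batch_is_stale(batch_versions: List[int], dataset: VersionedDataset) -> bool:
    """True if any sample in the batch was produced before the latest update."""
    return any(v != dataset.version for v in batch_versions)


ds = VersionedDataset(["a", "b"])
batch = [ds[i] for i in range(len(ds))]            # samples produced at version 0
ds.update_pseudo_labels({"a": 1})                  # dataset moves to version 1
print(batch_is_stale([s[2] for s in batch], ds))   # True
```

In a real `training_step` the versions would arrive collated (depending on `collate_batch_feat`, possibly as a tensor), so they would need converting, for example with `.tolist()`, before the comparison.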

What version are you seeing the problem on?

v2.3

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

cc @justusschock

Yak1m4Sg added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Nov 8, 2024
lantiga (Collaborator) commented Nov 13, 2024

Hey, thanks for filing the issue. Can you provide a minimal runnable reproduction I can just copy and paste? That'll speed up the investigation, thanks!

lantiga added the loops (Related to the Loop API) label and removed the needs triage (Waiting to be triaged by maintainers) and ver: 2.3.x labels on Nov 13, 2024