# Distributed Inference Support for ESM2/Geneformer (#482)
## Summary

This PR adds a `PredictionWriter` callback (a subclass of `BasePredictionWriter`) that is supposed to capture the `batch_indices` from the dataloader's batch sampler and the `predictions` from the Lightning module's `predict_step`. However, `batch_indices` are not tracked properly, which prevents us from mapping predictions back to input sequences. Below are the issues blocking distributed inference (writing predictions to disk) in BioNeMo 2.0. The first has no clear workaround in BioNeMo and requires changes in NeMo. The second and third issues are not blocking right now.

Reference: #319

----------------------

## Details

### 1. NeMo `MegatronDataSampler` is not tracking the `batch_indices`

`MegatronDataSampler` [transforms the dataloaders](https://github.com/NVIDIA/NeMo/blob/1757ff9ed10272bf5ee7332d64fccd4bd9676f1b/nemo/lightning/pytorch/plugins/data_sampler.py#L67) by adding a Megatron batch sampler. But the [Lightning prediction loop](https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/loops/prediction_loop.py#L304) tracks `batch_indices` only if the batch sampler is an instance of `lightning.pytorch.overrides.distributed._IndexBatchSamplerWrapper`. To get this fixed, I propose the following change in [`nemo.lightning.data.add_megatron_sampler`](https://github.com/NVIDIA/NeMo/blob/1757ff9ed10272bf5ee7332d64fccd4bd9676f1b/nemo/lightning/data.py#L218):

```python
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from torch.utils.data import DataLoader


def add_megatron_sampler(...):
    ...
    # Wrap the Megatron batch sampler so Lightning's prediction loop can track batch_indices.
    return DataLoader(
        dataloader.dataset,
        batch_sampler=_IndexBatchSamplerWrapper(batch_sampler),
        num_workers=dataloader.num_workers,
        pin_memory=dataloader.pin_memory,
        persistent_workers=dataloader.persistent_workers,
        collate_fn=dataloader.collate_fn,
    )
```

### Solution:

Fixed by NVIDIA/NeMo#10934.

### 2. Lightning `_PredictionLoop` not storing data for the prediction writer callback

Again in the [Lightning prediction loop](https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/loops/prediction_loop.py#L235), if `data_fetcher` is an instance of `lightning.pytorch.loops.fetchers._DataLoaderIterDataFetcher`, the data needed by the prediction writer (including `batch_indices`) is not stored. This is always the case for us, since the [NeMo Megatron strategy always wraps the data fetcher](https://github.com/NVIDIA/NeMo/blob/1757ff9ed10272bf5ee7332d64fccd4bd9676f1b/nemo/lightning/pytorch/strategies/megatron_strategy.py#L833) with `_DataLoaderIterDataFetcher`. This is also a question for the Lightning team, as I don't understand the reason for skipping the data store for this type of data fetcher.

### Solution:

A temporary workaround I proposed for this issue is to get the indices from the `trainer` (`trainer.predict_dataloaders.batch_sampler.seen_batch_indices`). However, this only works for "epoch" write intervals. We can now optionally return `input_ids` with the predictions (`--include-input-ids`), which should be used instead as the reliable way to map predictions back to input sequences.
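For reference, below is a minimal sketch (not the actual BioNeMo implementation) of what a `BasePredictionWriter`-style callback that persists each rank's predictions together with the tracked `batch_indices` can look like; the class name, output layout, and file naming are illustrative assumptions:

```python
import os

import torch
from lightning.pytorch.callbacks import BasePredictionWriter


class SimplePredictionWriter(BasePredictionWriter):
    """Illustrative sketch: write each rank's predictions and batch indices to disk."""

    def __init__(self, output_dir: str, write_interval: str = "epoch"):
        super().__init__(write_interval)
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # batch_indices is what maps predictions back to dataset rows; it stays empty
        # unless the batch sampler is wrapped in _IndexBatchSamplerWrapper (issue 1 above).
        torch.save(
            {"predictions": predictions, "batch_indices": batch_indices},
            os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt"),
        )
```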
### 3. Lightning bug when `return_predictions` is `False` in `trainer.predict`

This is also related to the 2nd issue listed here. If the `data_fetcher` is an instance of `_DataLoaderIterDataFetcher`, the Lightning prediction loop skips creating the `any_on_epoch` variable, which is referenced later in the loop:

```bash
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/prediction_loop.py", line 124, in run
    self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/prediction_loop.py", line 271, in _predict_step
    if self._return_predictions or any_on_epoch:
UnboundLocalError: local variable 'any_on_epoch' referenced before assignment
```

### Solution:

This is also a Lightning bug. There is no fix at the moment; we avoid setting `return_predictions=False`.
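Until an upstream fix lands, a hedged usage sketch (the `model` and `predict_dataloader` names, and the reuse of the `SimplePredictionWriter` sketch above, are illustrative) that sidesteps the error by leaving `return_predictions` at its default:

```python
from lightning.pytorch import Trainer

# Leaving return_predictions at its default (True) avoids the UnboundLocalError above
# when the data fetcher is a _DataLoaderIterDataFetcher.
trainer = Trainer(callbacks=[SimplePredictionWriter(output_dir="results")])
predictions = trainer.predict(model, dataloaders=predict_dataloader, return_predictions=True)
```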