# Distributed Inference Support for ESM2/Geneformer (#482)
## Summary

This PR adds a `PredictionWriter` callback (a subclass of `BasePredictionWriter`) that is supposed to capture the `batch_indices` from the dataloader's batch sampler and the `predictions` from the Lightning module's `predict_step`. However, `batch_indices` are not tracked properly, which prevents us from mapping the predictions back to the input sequences. Below are the issues that block distributed inference (writing predictions to disk) in BioNeMo 2.0. The first has no clear workaround in BioNeMo and requires changes in NeMo. The second and third issues are not blocking right now.
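
For orientation, here is a minimal sketch of such a callback, assuming Lightning's `BasePredictionWriter` API; the class name, file naming, and exactly what gets serialized are illustrative rather than the implementation added by this PR.

```python
# Minimal sketch of a prediction writer built on Lightning's BasePredictionWriter.
# Illustrative only: the class name and serialization details are assumptions.
import os

import torch
from lightning.pytorch.callbacks import BasePredictionWriter


class SketchPredictionWriter(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str = "epoch"):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # `batch_indices` is the piece that issue 1 below is about: Lightning only
        # populates it when the batch sampler is an _IndexBatchSamplerWrapper.
        torch.save(
            {"predictions": predictions, "batch_indices": batch_indices},
            os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt"),
        )
```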

Reference: #319 


----------------------
## Details

### 1. NeMo `MegatronDataSampler` does not track `batch_indices`

`MegatronDataSampler` [transforms the dataloaders](https://github.com/NVIDIA/NeMo/blob/1757ff9ed10272bf5ee7332d64fccd4bd9676f1b/nemo/lightning/pytorch/plugins/data_sampler.py#L67) by adding the Megatron batch sampler. However, the [Lightning prediction loop](https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/loops/prediction_loop.py#L304) tracks `batch_indices` only if the batch sampler is an instance of `lightning.pytorch.overrides.distributed._IndexBatchSamplerWrapper`. To fix this, I propose the following change in [`nemo.lightning.data.add_megatron_sampler`](https://github.com/NVIDIA/NeMo/blob/1757ff9ed10272bf5ee7332d64fccd4bd9676f1b/nemo/lightning/data.py#L218):

```python
def add_megatron_sampler(...):
    ...
    return DataLoader(
        dataloader.dataset,
        batch_sampler=_IndexBatchSamplerWrapper(batch_sampler),  # from lightning.pytorch.overrides.distributed
        num_workers=dataloader.num_workers,
        pin_memory=dataloader.pin_memory,
        persistent_workers=dataloader.persistent_workers,
        collate_fn=dataloader.collate_fn,
    )
```
### Solution:
Fixed by NVIDIA/NeMo/pull/10934

### 2. Lightning `_PredictionLoop` does not store data for the prediction writer callback

Again, in the [Lightning prediction loop](https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/loops/prediction_loop.py#L235), if `data_fetcher` is an instance of `lightning.pytorch.loops.fetchers._DataLoaderIterDataFetcher`, the data for the prediction writer (including `batch_indices`) is not stored. This is always the case, since the [NeMo Megatron strategy always wraps the data fetcher](https://github.com/NVIDIA/NeMo/blob/1757ff9ed10272bf5ee7332d64fccd4bd9676f1b/nemo/lightning/pytorch/strategies/megatron_strategy.py#L833) with `_DataLoaderIterDataFetcher`. This is also a question for the Lightning team, as I don't understand the reason for skipping the data store for this type of data fetcher.

### Solution:
A temporary workaround I proposed for this issue is to get the indices from the `trainer` (`trainer.predict_dataloaders.batch_sampler.seen_batch_indices`). However, this only works for the "epoch" write interval. We can now optionally return `input_ids` with the predictions (`--include-input-ids`); these should be used instead as a reliable way to map predictions back to input sequences, as sketched below.
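
For illustration only, the snippet below shows how the returned `input_ids` can be used to recover the mapping; it assumes the results file was produced with `--include-input-ids`, and the file path is an example.

```python
# Hedged sketch: with --include-input-ids, every per-sequence result tensor is
# row-aligned with input_ids, so row i of `embeddings` (or any other output)
# belongs to the sequence whose token ids are input_ids[i]. Detokenizing those
# ids, or re-tokenizing the input CSV and comparing, recovers the original sequence.
import torch

results = torch.load("predictions__rank_0.pt")  # example path, produced with --include-input-ids
input_ids = results["input_ids"]    # [num_sequences, seq_len]
embeddings = results["embeddings"]  # [num_sequences, hidden_dim]

assert input_ids.shape[0] == embeddings.shape[0], "outputs are row-aligned with input_ids"
for i, (ids, emb) in enumerate(zip(input_ids, embeddings)):
    print(f"row {i}: seq_len={ids.shape[0]}, embedding_dim={emb.shape[0]}")
```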

### 3. Lightning bug when `return_predictions` is `False` in `trainer.predict`

This is related to the second issue above. If the `data_fetcher` is an instance of `_DataLoaderIterDataFetcher`, the Lightning prediction loop skips creating the `any_on_epoch` variable, which is referenced later in the loop:

```bash
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/prediction_loop.py", line 124, in run
  self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/prediction_loop.py", line 271, in _predict_step
  if self._return_predictions or any_on_epoch:
UnboundLocalError: local variable 'any_on_epoch' referenced before assignment
```
### Solution:
This is also a Lightning bug. There is no fix at the moment; we avoid setting `return_predictions=False`, as sketched below.
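
Until this is fixed upstream, the safe pattern is simply to leave `return_predictions` at its default (or set it explicitly to `True`) when calling `trainer.predict`; a minimal sketch, with `model` and `predict_dataloader` standing in for your own Lightning module and dataloader:

```python
# Workaround sketch: keep return_predictions=True so the prediction loop never
# reaches the uninitialized `any_on_epoch` branch. `model` and `predict_dataloader`
# are placeholders for your own LightningModule and dataloader.
from lightning.pytorch import Trainer

trainer = Trainer(accelerator="gpu", devices=1)
predictions = trainer.predict(model, dataloaders=predict_dataloader, return_predictions=True)
```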
farhadrgh authored Dec 16, 2024
1 parent 0360d50 commit c39b2d4
Showing 11 changed files with 690 additions and 402 deletions.
10 changes: 5 additions & 5 deletions docs/docs/user-guide/examples/bionemo-esm2/finetune.md
@@ -230,22 +230,22 @@ We download a CSV example dataset of artificial sequences for this inference exam
mkdir -p $WORKDIR/esm2_finetune_tutorial
# download sample data CSV for inference
DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_infer:2.0 --source ngc)
RESULTS_PATH=$WORKDIR/esm2_finetune_tutorial/inference_results.pt
DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_infer:2.0)
RESULTS_PATH=$WORKDIR/esm2_finetune_tutorial/
infer_esm2 --checkpoint-path <finetune checkpoint path> \
--data-path $DATA_PATH \
--results-path $RESULTS_PATH \
--config-class ESM2FineTuneSeqConfig
```

This will create a result `.pt` file under `$WORKDIR/esm2_finetune_tutorial/inference_results.pt` which can be loaded via PyTorch library in python environment:
This will create a result `.pt` file under `$WORKDIR/esm2_finetune_tutorial/predictions__rank_0.pt` which can be loaded via the PyTorch library in a Python environment:

```python
import torch
# Set the path to results file e.g. /workspace/bionemo2/esm2_finetune_tutorial/inference_results.pt
# results_path = /workspace/bionemo2/esm2_finetune_tutorial/inference_results.pt
# Set the path to results file e.g. /workspace/bionemo2/esm2_finetune_tutorial/predictions__rank_0.pt
# results_path = /workspace/bionemo2/esm2_finetune_tutorial/predictions__rank_0.pt
results = torch.load(results_path)
# results is a python dict which includes the following result tensors for this example:
63 changes: 50 additions & 13 deletions docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb
@@ -152,7 +152,7 @@
"source": [
"from bionemo.core.data.load import load\n",
"\n",
"checkpoint_path = load(\"esm2/650m:2.0\", source=\"ngc\")\n",
"checkpoint_path = load(\"esm2/650m:2.0\")\n",
"print(checkpoint_path)"
]
},
@@ -238,21 +238,24 @@
"name": "stdout",
"output_type": "stream",
"text": [
"2024-11-25 21:18:43 - faiss.loader - INFO - Loading faiss with AVX512 support.\n",
"2024-11-25 21:18:43 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n",
"[NeMo W 2024-11-25 21:18:43 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n",
"2024-12-05 17:51:10 - faiss.loader - INFO - Loading faiss with AVX512 support.\n",
"2024-12-05 17:51:10 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n",
"[NeMo W 2024-12-05 17:51:10 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n",
" warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n",
" \n",
"[NeMo W 2024-12-05 17:51:11 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n",
" cm = get_cmap(\"Set1\")\n",
" \n",
"usage: infer_esm2 [-h] --checkpoint-path CHECKPOINT_PATH --data-path DATA_PATH\n",
" --results-path RESULTS_PATH\n",
" [--precision {fp16,bf16,fp32,bf16-mixed,fp32-mixed,16-mixed,fp16-mixed,16,32}]\n",
" [--num-gpus NUM_GPUS] [--num-nodes NUM_NODES]\n",
" [--micro-batch-size MICRO_BATCH_SIZE]\n",
" [--pipeline-model-parallel-size PIPELINE_MODEL_PARALLEL_SIZE]\n",
" [--tensor-model-parallel-size TENSOR_MODEL_PARALLEL_SIZE]\n",
" [--include-hiddens] [--include-input-ids]\n",
" [--include-embeddings] [--include-logits]\n",
" [--config-class CONFIG_CLASS]\n",
" [--prediction-interval {epoch,batch}] [--include-hiddens]\n",
" [--include-input-ids] [--include-embeddings]\n",
" [--include-logits] [--config-class CONFIG_CLASS]\n",
"\n",
"Infer ESM2.\n",
"\n",
@@ -264,7 +267,7 @@
" Path to the CSV file containing sequences and label\n",
" columns\n",
" --results-path RESULTS_PATH\n",
" Path to the results file.\n",
" Path to the results directory.\n",
" --precision {fp16,bf16,fp32,bf16-mixed,fp32-mixed,16-mixed,fp16-mixed,16,32}\n",
" Precision type to use for training.\n",
" --num-gpus NUM_GPUS Number of GPUs to use for training. Default is 1.\n",
@@ -277,6 +280,8 @@
" Pipeline model parallel size. Default is 1.\n",
" --tensor-model-parallel-size TENSOR_MODEL_PARALLEL_SIZE\n",
" Tensor model parallel size. Default is 1.\n",
" --prediction-interval {epoch,batch}\n",
" Intervals to write DDP predictions into disk\n",
" --include-hiddens Include hiddens in output of inference\n",
" --include-input-ids Include input_ids in output of inference\n",
" --include-embeddings Include embeddings in output of inference\n",
@@ -327,11 +332,11 @@
"source": [
"%%capture --no-display --no-stderr cell_output\n",
"\n",
"results_path = os.path.join(work_dir, \"inference_results.pt\")\n",
"\n",
"! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
" --data-path {data_path} \\\n",
" --results-path {results_path} \\\n",
" --results-path {work_dir} \\\n",
" --micro-batch-size 3 \\\n",
" --num-gpus 1 \\\n",
" --precision \"fp32\" \\\n",
" --include-hiddens \\\n",
" --include-embeddings \\\n",
@@ -350,7 +355,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The bash command in previous step creates the `inference_results.pt` file under the work directory of this notebook (defined above) to stores the results. The `.pt` file containes a dictionary of `{'result_key': torch.Tensor}` that be loaded with PyTorch:"
"Inference predictions are stored into `.pt` files for each device. Since we only used one device to run the inference (`--num-gpus 1`) in the previous step, the results were written to `{work_dir}/predictions__rank_0.pt` under the work directory of this notebook (defined above). The `.pt` file containes a dictionary of `{'result_key': torch.Tensor}` that be loaded with PyTorch:"
]
},
{
Expand All @@ -371,7 +376,7 @@
],
"source": [
"import torch\n",
"results = torch.load(results_path)\n",
"results = torch.load(f\"{work_dir}/predictions__rank_0.pt\")\n",
"\n",
"for key, val in results.items():\n",
" if val is not None:\n",
@@ -472,6 +477,38 @@
"mask = torch.isin(input_ids, torch.tensor(extra_indices))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DDP Inference Support\n",
"\n",
"Although this tutorial is utilizing one devive to run the inference, distributed inference is supported for ESM2 in BioNeMo Framework. One can simply set the the `--num-gpus n` to run distributed inference on `n` devices. The output predictions will be written into `predictions__rank_<0...n-1>.pt` under the `--results-path` provided. Moreover, by optionally including input token IDs with `--include-input-ids` we can snure 1:1 mapping between input sequences and output predictions."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following snippet can be used to load and collate the predictions into a single dictionary.\n",
"\n",
"\n",
"```python\n",
"import glob\n",
"from bionemo.llm.lightning import batch_collator\n",
"\n",
"collated_preditions = batch_collator([torch.load(path) for path in glob.glob(f\"{work_dir}/predictions__rank_*.pt\")])\n",
"for key, val in collated_preditions.items():\n",
" if val is not None:\n",
" print(f'{key}\\t{val.shape}')\n",
"\n",
"# token_logits\ttorch.Size([1024, 10, 128])\n",
"# hidden_states\ttorch.Size([10, 1024, 1280])\n",
"# input_ids torch.Size([10, 1024])\n",
"# embeddings\ttorch.Size([10, 1280])\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
22 changes: 11 additions & 11 deletions docs/docs/user-guide/examples/bionemo-esm2/mutant-design.ipynb
@@ -83,7 +83,7 @@
"id": "dd6bed85-787b-4456-8426-55194da94852",
"metadata": {},
"source": [
"This notebbok should be executed inside the BioNeMo docker container, which has all ESM-2 dependencies pre-installed. This tutorial assumes that a copy of the BioNeMo framework repo exists on workstation or server and has been mounted inside the container at `/workspace/bionemo2`. For more information on how to build or pull the BioNeMo2 container, refer to the [Initialization Guide](https://docs.nvidia.com/bionemo-framework/latest/user-guide/getting-started/initialization-guide/)."
"This notebbok should be executed inside the BioNeMo docker container, which has all ESM-2 dependencies pre-installed. This tutorial assumes that a copy of the BioNeMo framework repo exists on workstation or server and has been mounted inside the container at `/workspace/bionemo2`. For more information on how to build or pull the BioNeMo2 container, refer to the [Initialization Guide](../../getting-started/initialization-guide.md)."
]
},
{
Expand Down Expand Up @@ -206,7 +206,7 @@
"source": [
"from bionemo.core.data.load import load\n",
"\n",
"checkpoint_path = load(\"esm2/3b:2.0\", source=\"ngc\")\n",
"checkpoint_path = load(\"esm2/3b:2.0\")\n",
"print(checkpoint_path)"
]
},
@@ -378,11 +378,13 @@
"source": [
"%%capture --no-display --no-stderr cell_output\n",
"\n",
"results_path = os.path.join(work_dir, \"inference_results.pt\")\n",
"example_dir = os.path.join(work_dir, \"inference_example\")\n",
"os.makedirs(example_dir, exist_ok=True)\n",
"\n",
"! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
" --data-path {data_path} \\\n",
" --results-path {results_path} \\\n",
" --results-path {example_dir} \\\n",
" --num-gpus 1 \\\n",
" --precision \"fp32\" \\\n",
" --include-hiddens \\\n",
" --include-embeddings \\\n",
@@ -395,8 +397,7 @@
"id": "67d09581-e784-4ccc-be88-194c8909068c",
"metadata": {},
"source": [
"\n",
"This will write the output of ESM-2 inference into a python dictionary and save that into `inference_results.pt` which can be loaded via PyTorch:"
"This will write the output of ESM-2 inference into a python dictionary and save that into `predictions__rank_0.pt` which can be loaded via PyTorch. DDP inference is supported in BioNeMo Framework and can be utilized by setting `--num-gpus n` to use `n` devices. The output predictions are then written to n distinct files `predictions__rank_<0...n-1>.pt`. Please refer to [ESM-2 Inference Tutorial](./inference.ipynb) for more information regarding the DDP support and how to interpret the prediction outputs."
]
},
{
@@ -417,7 +418,7 @@
}
],
"source": [
"results = torch.load(results_path)\n",
"results = torch.load(f\"{example_dir}/predictions__rank_0.pt\")\n",
"\n",
"for key, val in results.items():\n",
" if val is not None:\n",
@@ -736,11 +737,10 @@
"source": [
"%%capture --no-display --no-stderr cell_output\n",
"\n",
"sequentially_masked_results_path = os.path.join(work_dir, \"sequentially_masked_inference_results.pt\")\n",
"\n",
"! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
" --data-path {masked_data_path} \\\n",
" --results-path {sequentially_masked_results_path} \\\n",
" --results-path {work_dir} \\\n",
" --num-gpus 1 \\\n",
" --precision \"fp32\" \\\n",
" --include-logits \\\n",
" --include-input-ids"
@@ -761,7 +761,7 @@
}
],
"source": [
"results = torch.load(sequentially_masked_results_path)\n",
"results = torch.load(f\"{work_dir}/predictions__rank_0.pt\")\n",
"logits = results['token_logits'].transpose(0, 1) # s, b, h -> b, s, h\n",
"\n",
"probs = logits_to_probs(logits)\n",