Add sequence packing support for SFTPackedDataset #275

Merged
8 commits merged on Sep 5, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Critic and Reward Model server refactored. Now the reward model will have a flag called `model.forward_micro_batch_size` which determines the micro batch size on which it runs inferences. This can be higher than the training micro batch size since during inference, we have less memory pressure.
- In the critic and reward model server, it is now possible to specify `inference_micro_batch_size` as a list. This allows us to provide more information to PyTriton regarding the preferred batch sizes for inference.
- It is no longer a requirement to specify `num_rollout_samples` to be a multiple of `inference_micro_batch_size * dp size` in PPO.
- Sequence packing is now supported when running SFT with prompt-response datasets.

### Breaking Changes
- `inference.micro_batch_size` is now renamed to `inference.inference_micro_batch_size` when running reward model inference in `inference_rm.yaml`. This is to stay consistent with the naming scheme of the PPO critic.
13 changes: 13 additions & 0 deletions docs/user-guide/sft.rst
@@ -100,6 +100,11 @@ This script converts the *Instruction*, *Context*, and *Response* fields into *I
"category": "closed_qa"
}

Sequence packing is also supported with prompt-response datasets. Sequence packing is a training technique in which multiple training examples are concatenated to create one longer sequence.
This approach eliminates the need for padding and improves GPU utilization. Refer to the `sequence packing documentation <https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/sequence_packing.html?highlight=packing#>`_ for a detailed overview of sequence packing and its advantages.

NeMo provides a script to pack your SFT prompt-response dataset. Refer to the `prepare dataset <https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/sequence_packing.html?highlight=packing#prepare-dataset>`_ section of the documentation for details on how to use this script.
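
The packing script handles bin assignment and the on-disk output format for you. Purely to illustrate the idea (this sketch is not the NeMo script's implementation or API, and all names in it are hypothetical), packing can be thought of as concatenating tokenized examples into fixed-size bins while recording the sequence boundaries, the cumulative sequence lengths (``cu_seqlens``), that a packing-aware attention kernel later uses to prevent cross-sequence attention:

.. code-block:: python

   # Illustrative only: greedy packing of tokenized examples into fixed-size bins.
   # The function name and output format are hypothetical, not the NeMo script's.
   from typing import Dict, List


   def pack_examples(tokenized_examples: List[List[int]], pack_size: int) -> List[Dict]:
       bins = []
       tokens, boundaries = [], [0]  # boundaries == cumulative sequence lengths (cu_seqlens)
       for example in tokenized_examples:
           if len(example) > pack_size:
               continue  # an example longer than the pack size cannot be packed
           if len(tokens) + len(example) > pack_size:
               bins.append({"input_ids": tokens, "cu_seqlens": boundaries})
               tokens, boundaries = [], [0]
           tokens = tokens + example
           boundaries = boundaries + [len(tokens)]
       if tokens:
           bins.append({"input_ids": tokens, "cu_seqlens": boundaries})
       return bins


   # Three short sequences packed into bins of at most 8 tokens: no padding is needed,
   # and `cu_seqlens` marks where each original example starts and ends inside a bin.
   print(pack_examples([[1, 2, 3], [4, 5], [6, 7, 8, 9]], pack_size=8))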

Step 2: Run SFT training
^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -218,6 +223,14 @@ Now, you will use the data for supervised fine-tuning with NeMo-Aligner.
srun -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}"
set +x

If using sequence packing, replace the data paths with the paths to your packed datasets. For each packed dataset, you should also set ``packed_sequence=True`` in the config:

.. code-block:: bash

   +model.data.train_ds.packed_sequence=True \
   +model.data.validation_ds.packed_sequence=True

It is not required to pack both the train and validation datasets. If packing only the train dataset, exclude ``+model.data.validation_ds.packed_sequence=True``.
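
As a fuller illustration, the overrides for a packed run might look like the following fragment of the training command above. The ``.npy`` paths are placeholders for the files produced by the packing script, and the per-split ``micro_batch_size`` keys are assumed to follow the standard SFT config layout; a micro batch size of 1 is required when ``packed_sequence=True``:

.. code-block:: bash

   model.data.train_ds.file_path=/path/to/train_packed.npy \
   +model.data.train_ds.packed_sequence=True \
   model.data.train_ds.micro_batch_size=1 \
   model.data.validation_ds.file_path=/path/to/val_packed.npy \
   +model.data.validation_ds.packed_sequence=True \
   model.data.validation_ds.micro_batch_size=1 \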

To scale to thousands of GPUs, adjust the ``trainer.num_nodes`` and ``trainer.devices`` accordingly based on the size of your machine.

For this particular run on the 2B model, the final training loss is approximately 1.536. Once the training finishes, you’ll find a file called ``megatron_gpt_sft.nemo`` available for use.
20 changes: 18 additions & 2 deletions nemo_aligner/data/nlp/builders.py
@@ -35,7 +35,7 @@
 )
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import get_indexed_dataset_
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
-from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
+from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset, GPTSFTPackedDataset
 from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
     MegatronPretrainingBatchSampler,
     MegatronPretrainingRandomBatchSampler,
@@ -265,7 +265,22 @@ def build_dataset(index, name):


 def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None):
-    dataset_cls = GPTSFTChatDataset if is_chat else GPTSFTDataset
+    packed_sequence = data_cfg.get("packed_sequence", False)
+    dataset_kwargs = {}
+
+    if is_chat:
+        assert not packed_sequence, "Sequence packing is currently not supported with chat datasets."
+        dataset_cls = GPTSFTChatDataset
+    elif packed_sequence:
+        dataset_cls = GPTSFTPackedDataset
+        # Whether to return `cu_seqlen` to pass to the model. Having `cu_seqlen` in the model input
+        # enables the THD attention kernel, which is the correct format for training with packed sequences
+        # because it prevents cross-sequence attention. This flag should be True unless you have a specific use case.
+        dataset_kwargs = {"return_cu_seqlen": data_cfg.get("packed_sequence_return_cu_seqlen", True)}
+        assert data_cfg.micro_batch_size == 1, "Micro batch size must be 1 if using packed sequence"
+    else:
+        dataset_cls = GPTSFTDataset
+
     dataset = dataset_cls(
         file_path=data_cfg.file_path,
         tokenizer=tokenizer,
@@ -295,6 +310,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None):
         ),  # used to choose truncation method. Options: ['random', 'left', 'right']
         special_tokens=special_tokens,
         output_original_text=data_cfg.get("output_original_text", False),
+        **dataset_kwargs,
     )
     return dataset

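For reference, here is a hedged sketch (not part of this PR) of how the new code path might be exercised from user code. The config keys mirror those read by the builder above; the tokenizer, file path, and remaining dataset fields are placeholders:

.. code-block:: python

   # Hedged usage sketch (not part of this PR). Placeholder values are marked; a real
   # config also needs the remaining GPTSFTPackedDataset fields (max_seq_length, etc.).
   from omegaconf import OmegaConf

   from nemo_aligner.data.nlp.builders import build_sft_dataset

   data_cfg = OmegaConf.create(
       {
           "file_path": "/path/to/train_packed.npy",  # placeholder: output of the packing script
           "micro_batch_size": 1,  # must be 1 when packed_sequence is True
           "packed_sequence": True,
           "packed_sequence_return_cu_seqlen": True,  # keep True so the THD attention kernel is used
           # ... other per-dataset fields from the SFT config go here
       }
   )

   # `tokenizer` must be a NeMo tokenizer instance, e.g. the one attached to the model
   # being fine-tuned; shown here only as a placeholder.
   tokenizer = ...

   # is_chat=False selects the packed (prompt-response) path; the builder asserts that
   # packing is not supported for chat datasets.
   train_ds = build_sft_dataset(
       data_cfg, tokenizer, num_samples=None, answer_only_loss=True, is_chat=False,
   )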