Add support for pre-tokenized streaming dataset finetuning #601

Closed
wants to merge 11 commits
24 changes: 23 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
@@ -38,6 +38,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from typing import Any, Callable, Dict, List, Optional, Union

import datasets as hf_datasets
import numpy as np
import torch
from omegaconf import DictConfig
from streaming import StreamingDataset
from transformers import PreTrainedTokenizerBase
@@ -67,6 +69,19 @@ def _tokenize_formatted_example(
return tokenizer(text=example['prompt'], text_target=example['response'])


def _read_binary_tokenized_sample(
        sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    # np.frombuffer returns a read-only view into the raw bytes, so copy
    # the array before handing it to torch.from_numpy.
    example = {
        'input_ids':
            torch.from_numpy(
                np.frombuffer(sample['tokens'], dtype=np.int64).copy()),
        'labels':
            torch.from_numpy(
                np.frombuffer(sample['labels'], dtype=np.int64).copy()),
    }
    # Pre-tokenized samples arrive unpadded, so attend to every position.
    example['attention_mask'] = torch.ones(example['input_ids'].size())
    return example
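

def _example_binary_sample_round_trip() -> None:
    """Sketch only (not in the original diff): round-trip of a binary sample.

    Serializing int64 tokens with numpy's tobytes() and decoding them with
    _read_binary_tokenized_sample should reproduce the original ids; the
    int64 dtype and this helper's name are illustrative assumptions.
    """
    tokens = np.asarray([1, 2, 3, 4], dtype=np.int64)
    sample = {'tokens': tokens.tobytes(), 'labels': tokens.tobytes()}
    example = _read_binary_tokenized_sample(sample)
    assert torch.equal(example['input_ids'], torch.tensor([1, 2, 3, 4]))
    assert example['attention_mask'].shape == example['input_ids'].shape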


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.

@@ -185,7 +200,14 @@ def __init__(self,
    # How to process a sample
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample = super().__getitem__(idx)
        if 'prompt' in sample and 'response' in sample:
            return _tokenize_formatted_example(sample,
                                               tokenizer=self.tokenizer)
        elif 'tokens' in sample and 'labels' in sample:
            return _read_binary_tokenized_sample(sample)
        else:
            raise RuntimeError(
                'StreamingFinetuningDataset needs samples to have '
                'prompt/response columns or tokens/labels columns')
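
    # For reference (sketch, not part of the diff): the two sample schemas
    # __getitem__ dispatches on; anything else raises the RuntimeError above.
    #
    #   {'prompt': 'What is 2+2?', 'response': '4'}        -> tokenized lazily
    #   {'tokens': <int64 bytes>, 'labels': <int64 bytes>} -> decoded directly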


class DatasetConstructor:
62 changes: 62 additions & 0 deletions tests/test_dataloader.py
@@ -11,6 +11,7 @@

import pytest
import torch
import numpy
from composer.utils import dist, using_torch_2
from omegaconf import OmegaConf as om
from streaming import MDSWriter
@@ -62,6 +63,25 @@ def build_mock_ft_streaming_dataset(data_path: str, split: str):
            output_writer.write(sample)


def build_mock_tokenized_ft_streaming_dataset(data_path: str, split: str):
    columns = {'tokens': 'bytes', 'labels': 'bytes'}

    # Use an explicit int64 dtype so the bytes round-trip with the np.int64
    # decoding in _read_binary_tokenized_sample on every platform.
    dataset = [{
        'tokens': numpy.asarray([1, 2, 3, 4], dtype=numpy.int64).tobytes(),
        'labels': numpy.asarray([2, 3, 4, 5], dtype=numpy.int64).tobytes()
    }, {
        'tokens': numpy.asarray([2, 3, 4, 5], dtype=numpy.int64).tobytes(),
        'labels': numpy.asarray([3, 4, 5, 6], dtype=numpy.int64).tobytes()
    }]

    output_path = os.path.join(data_path, split)

    with MDSWriter(columns=columns, out=output_path,
                   compression=None) as output_writer:
        for sample in dataset:
            output_writer.write(sample)
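

def _check_mock_tokenized_dataset(data_path: str, split: str) -> None:
    """Sketch only (not in the original diff): read the shards back.

    Uses a local-only StreamingDataset to confirm the bytes round-trip; the
    helper name and the local= read pattern are illustrative assumptions.
    """
    from streaming import StreamingDataset
    dataset = StreamingDataset(local=os.path.join(data_path, split),
                               shuffle=False)
    first = dataset[0]
    tokens = numpy.frombuffer(first['tokens'], dtype=numpy.int64)
    assert tokens.tolist() == [1, 2, 3, 4]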


@pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m'])
@pytest.mark.parametrize('pretokenize', [False, True])
def test_correct_padding(tokenizer_name: str,
@@ -474,6 +494,48 @@ def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):
_ = build_finetuning_dataloader(cfg, tokenizer, 4)


def test_finetuning_dataloader_streaming_tokenized(tmp_path: pathlib.Path):
    remote_path = os.path.join(tmp_path, 'remote')
    local_path = os.path.join(tmp_path, 'local')

    build_mock_tokenized_ft_streaming_dataset(remote_path, 'train')

    cfg = {
        'name': 'finetuning',
        'dataset': {
            'remote': remote_path,
            'local': local_path,
            'split': 'train',
            'max_seq_len': 2048,
            'decoder_only_format': True,
            'allow_pad_trimming': False,
            'packing_ratio': None,
            'shuffle': True,
        },
        'drop_last': False,
        'num_workers': 4,
        'pin_memory': False,
        'prefetch_factor': 2,
        'persistent_workers': False,
        'timeout': 0
    }

    cfg = om.create(cfg)

    tokenizer = build_tokenizer(
        tokenizer_name='gpt2',
        tokenizer_kwargs={'model_max_length': 2048},
    )

    ft_dataloader = build_finetuning_dataloader(cfg, tokenizer, 4)

    expected_keys = ['input_ids', 'attention_mask', 'labels']

    for batch in ft_dataloader:
        for k in expected_keys:
            assert k in batch
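
    # A stronger check (sketch, not in the original diff) could also assert,
    # inside the loop above, that collation kept integer ids and a mask
    # aligned with the padded ids:
    #   assert batch['input_ids'].dtype == torch.long
    #   assert batch['attention_mask'].shape == batch['input_ids'].shape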


@pytest.mark.parametrize('add_bad_data_dropped', [True, False])
@pytest.mark.parametrize('add_bad_data_error', [True, False])
def test_malformed_data(