[Training] Datasets - update Module #1209

Merged
merged 9 commits into from
Mar 5, 2025
8 changes: 8 additions & 0 deletions src/llmcompressor/datasets/__init__.py
@@ -0,0 +1,8 @@
# flake8: noqa

from .utils import (
    format_calibration_data,
    get_calibration_dataloader,
    get_processed_dataset,
    make_dataset_splits,
)
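For reference, a minimal sketch (not part of the diff) of importing the helpers that the new module re-exports; it only restates the names listed above.

```python
# Minimal sketch: the four helpers re-exported by the new datasets module.
from llmcompressor.datasets import (
    format_calibration_data,
    get_calibration_dataloader,
    get_processed_dataset,
    make_dataset_splits,
)
```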
191 changes: 191 additions & 0 deletions src/llmcompressor/datasets/utils.py
@@ -0,0 +1,191 @@
import re
from typing import Any, Callable, Dict, List, Optional

import torch
from datasets import Dataset
from loguru import logger
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.data import default_data_collator

from llmcompressor.args import DatasetArguments
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.typing import Processor


def get_processed_dataset(
    dataset_args: DatasetArguments,
    processor: Processor,
    do_oneshot: bool = False,
    do_train: bool = True,
) -> Optional[Dict[str, Dataset]]:
    """
    Loads a dataset for each enabled flow based on dataset_args and stores the
    tokenized Dataset for that flow.

    :param dataset_args: DatasetArguments that contain dataset loading and
        processing params
    :param processor: processor or tokenizer to use for dataset tokenization
    :param do_oneshot: True for the oneshot pathway
    :param do_train: True for the train pathway
    :return: A dictionary of datasets keyed for train and calibration (oneshot),
        or None if no dataset was provided
    """
    if dataset_args.dataset is None:
        logger.warning(
            "Running oneshot without calibration data. This is expected for "
            "weight-only and dynamic quantization"
        )
        return None

    splits = dataset_args.splits
    tokenized_datasets = {}

    def _get_split_name(inp_str):
        # strip the slice suffix from a split name, e.g. train[60%:] -> train
        match = re.match(r"(\w*)\[.*\]", inp_str)
        if match is not None:
            return match.group(1)
        return inp_str

    if splits is None:
        splits = {"all": None}
    elif isinstance(splits, str):
        splits = {_get_split_name(splits): splits}
    elif isinstance(splits, List):
        splits = {_get_split_name(s): s for s in splits}

    # default to a custom dataset if the dataset provided isn't a string
    registry_id = (
        dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom"
    )
    for split_name, split_str in splits.items():
        dataset = dataset_args.dataset
        if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
            # dataset is already tokenized
            tokenized_datasets[split_name] = dataset
        else:
            # dataset needs to be tokenized
            dataset_manager = TextGenerationDataset.load_from_registry(
                registry_id,
                dataset_args=dataset_args,
                split=split_str,
                processor=processor,
            )
            tokenized_datasets[split_name] = dataset_manager(add_labels=do_train)

    return make_dataset_splits(
        tokenized_datasets,
        do_oneshot=do_oneshot,
        do_train=do_train,
    )


def get_calibration_dataloader(
    dataset_args: DatasetArguments,
    processor: Processor,
) -> torch.utils.data.DataLoader:
    """
    Get the dataloader used for oneshot calibration.

    :param dataset_args: DatasetArguments that contains the dataset parameters.
    :param processor: Processor or the tokenizer of the model.
    :return: PyTorch dataloader object that contains the calibration dataset.
    """
    if dataset_args.dataset is None:
        # weight-only quantization or dynamic quantization
        return None

    datasets = get_processed_dataset(
        dataset_args=dataset_args,
        processor=processor,
        do_oneshot=True,
        do_train=False,
    )

    calibration_dataset = datasets.get("calibration")

    return format_calibration_data(
        tokenized_dataset=calibration_dataset,
        num_calibration_samples=dataset_args.num_calibration_samples,
        do_shuffle=dataset_args.shuffle_calibration_samples,
        collate_fn=dataset_args.data_collator,
    )


def format_calibration_data(
    tokenized_dataset: Dataset,
    num_calibration_samples: Optional[int] = None,
    do_shuffle: bool = True,
    collate_fn: Callable = default_data_collator,
) -> DataLoader:
    """
    Creates a dataloader out of the calibration dataset split, trimming it to
    the desired number of calibration samples.

    :param tokenized_dataset: dataset to convert to a dataloader
    :param num_calibration_samples: number of data samples to keep
    :param do_shuffle: whether to shuffle the dataset before selecting calibration
        samples, True by default
    :param collate_fn: optional custom collate function, or use default
    :return: PyTorch DataLoader over the trimmed calibration samples
    """
    safe_calibration_samples = len(tokenized_dataset)
    if num_calibration_samples is not None:
        safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
        if safe_calibration_samples != num_calibration_samples:
            logger.warning(
                f"Requested {num_calibration_samples} calibration samples but "
                f"the provided dataset only has {safe_calibration_samples}. "
            )

    if do_shuffle:
        tokenized_dataset = tokenized_dataset.shuffle()
    tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

    dataloader_params = {
        "batch_size": 1,
        "sampler": RandomSampler(tokenized_calibration)
        if do_shuffle
        else SequentialSampler(tokenized_calibration),
        "collate_fn": collate_fn,
        "pin_memory": True,
    }

    calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params)

    return calibration_dataloader


def make_dataset_splits(
    tokenized_datasets: Dict[str, Any],
    do_oneshot: bool = True,
    do_train: bool = False,
) -> Dict[str, Dataset]:
    """
    Restructures the datasets dictionary based on what tasks will be run.

    :param tokenized_datasets: dictionary of processed datasets
    :param do_oneshot: Whether to store the calibration dataset
    :param do_train: Whether to store the train dataset
    :return: A dictionary with "train" and "calibration" (oneshot) splits
    """

    # handles case where all splits are contained in a single dataset
    if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
        tokenized_datasets = tokenized_datasets.get("all")
        if isinstance(tokenized_datasets, Dataset):
            tokenized_datasets = {"train": tokenized_datasets}

    train_split = calib_split = None

    if do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_split = tokenized_datasets["train"]
    if do_oneshot:
        calib_split = tokenized_datasets.get("calibration")
        if calib_split is None:
            if "train" not in tokenized_datasets:
                raise ValueError("--do_oneshot requires a calibration dataset")
            calib_split = tokenized_datasets["train"]

    split_datasets = {
        "train": train_split,
        "calibration": calib_split,
    }
    return split_datasets
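As a rough illustration of the new utilities (not part of the diff), the sketch below builds a calibration dataloader from an already-tokenized in-memory dataset with format_calibration_data; the toy token IDs and requested sample count are made up for the example.

```python
# Hedged sketch (not from the PR): exercising format_calibration_data with a
# tiny, already-tokenized dataset. Toy token IDs below are illustrative only.
from datasets import Dataset

from llmcompressor.datasets import format_calibration_data

tokenized = Dataset.from_dict(
    {
        "input_ids": [[101, 7592, 102], [101, 2088, 102], [101, 999, 102]],
        "attention_mask": [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
    }
)

# Ask for more samples than exist; the helper trims to the dataset size and warns.
loader = format_calibration_data(
    tokenized_dataset=tokenized,
    num_calibration_samples=8,
    do_shuffle=False,
)

for batch in loader:
    # default_data_collator yields dict batches of tensors with batch_size=1
    print(batch["input_ids"].shape)  # torch.Size([1, 3])
```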
4 changes: 1 addition & 3 deletions src/llmcompressor/entrypoints/oneshot.py
@@ -7,9 +7,7 @@

 from llmcompressor.args import parse_args
 from llmcompressor.core.session_functions import active_session
-from llmcompressor.transformers.finetune.data.data_helpers import (
-    get_calibration_dataloader,
-)
+from llmcompressor.datasets import get_calibration_dataloader
 from llmcompressor.transformers.finetune.text_generation import (
     initialize_model_from_path,
     initialize_processor_from_path,