[Cosmetic] Rename data_args to dataset_args (#1206)
Order of reviews:
#1206  <-- Here
#1207
#1209 
#1212
#1214 

SUMMARY:
Rename `data_args` to `dataset_args`

TEST PLAN:
Pass tests
Find any remaining `data_args` occurrences using `grep`
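A minimal sketch of that check in Python (illustrative only; any recursive grep over the repository does the same thing):

from pathlib import Path

# Report any remaining `data_args` occurrences after the rename.
# Note: "dataset_args" does not contain "data_args" as a substring,
# so a plain substring match is sufficient here.
for path in Path(".").rglob("*.py"):
    for lineno, line in enumerate(path.read_text().splitlines(), start=1):
        if "data_args" in line:
            print(f"{path}:{lineno}: {line.strip()}")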

---------

Signed-off-by: George Ohashi <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
horheynm and dsikka authored Mar 5, 2025
1 parent 07726ef commit 391b202
Showing 25 changed files with 256 additions and 226 deletions.
8 changes: 4 additions & 4 deletions examples/trl_mixin/ex_trl_distillation.py
@@ -19,12 +19,12 @@
max_seq_length = 512

# Load gsm8k using SparseML dataset tools
data_args = DatasetArguments(
dataset_args = DatasetArguments(
dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length
)
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset,
data_args=data_args,
dataset_args.dataset,
dataset_args=dataset_args,
split="train",
processor=tokenizer,
)
@@ -69,7 +69,7 @@
train_dataset=train_dataset,
data_collator=data_collator,
trl_sft_config_args=trl_sft_config_args,
data_args=data_args,
dataset_args=dataset_args,
model_args=model_args,
)
trainer.train()
26 changes: 15 additions & 11 deletions src/llmcompressor/entrypoints/oneshot.py
@@ -35,7 +35,7 @@ class Oneshot:
`kwargs` are parsed into:
- `model_args`: Arguments for loading and configuring a pretrained model
(e.g., `AutoModelForCausalLM`).
- `data_args`: Arguments for dataset-related configurations, such as
- `dataset_args`: Arguments for dataset-related configurations, such as
calibration dataloaders.
- `recipe_args`: Arguments for defining and configuring recipes that specify
optimization actions.
@@ -108,24 +108,23 @@ def __init__(
"""
Initializes the `Oneshot` class with provided arguments.
Parses the input keyword arguments into `model_args`, `data_args`, and
Parses the input keyword arguments into `model_args`, `dataset_args`, and
`recipe_args`. Performs preprocessing to initialize the model and
tokenizer/processor.
:param model_args: ModelArguments parameters, responsible for controlling
model loading and saving logic
:param data_args: DatasetArguments parameters, responsible for controlling
:param dataset_args: DatasetArguments parameters, responsible for controlling
dataset loading, preprocessing and dataloader loading
:param recipe_args: RecipeArguments parameters, responsible for containing
recipe-related parameters
:param output_dir: Path to save the output model after carrying out oneshot
"""

model_args, dataset_args, recipe_args, _, output_dir = parse_args(**kwargs)

self.model_args = model_args
self.data_args = dataset_args
self.dataset_args = dataset_args
self.recipe_args = recipe_args
self.output_dir = output_dir

@@ -136,14 +135,19 @@ def __init__(

@classmethod
def from_args(
cls, model_args, data_args, recipe_args, output_dir, do_preprocess: bool = True
cls,
model_args,
dataset_args,
recipe_args,
output_dir,
do_preprocess: bool = True,
):
"""
Used only for the stage runner to populate the args.
"""
instance = super().__new__(cls)
instance.model_args = model_args
instance.data_args = data_args
instance.dataset_args = dataset_args
instance.recipe_args = recipe_args
instance.output_dir = output_dir

@@ -176,7 +180,7 @@ def __call__(self):
self.processor = self.model_args.processor

calibration_dataloader = get_calibration_dataloader(
self.data_args, self.processor
self.dataset_args, self.processor
)
self.apply_recipe_modifiers(
calibration_dataloader=calibration_dataloader,
@@ -242,7 +246,7 @@ def _pre_process(self):
- Applies patches to fix tied tensor issues and modifies `save_pretrained`
behavior.
- Initializes the processor if specified as a path or `None`.
- Sets the minimum tokens per module if `data_args` are provided.
- Sets the minimum tokens per module if `dataset_args` are provided.
Raises:
FileNotFoundError: If the model or processor path is invalid.
@@ -265,8 +269,8 @@ def _pre_process(self):
self.processor = self.model_args.processor

# Set minimum tokens per module if data arguments are provided
if self.data_args:
self.min_tokens_per_module = self.data_args.min_tokens_per_module
if self.dataset_args:
self.min_tokens_per_module = self.dataset_args.min_tokens_per_module

def check_tied_embeddings(self):
"""
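A minimal sketch of the renamed keyword at the entrypoint level, assuming the `model_args`, `dataset_args`, and `recipe_args` objects have already been built (for example by `parse_args`, as in `__init__` above):

# Sketch only: the argument objects are assumed to exist; from_args mirrors
# the signature shown in the diff above.
oneshot = Oneshot.from_args(
    model_args=model_args,
    dataset_args=dataset_args,   # previously `data_args`
    recipe_args=recipe_args,
    output_dir=output_dir,
)
oneshot()  # builds the calibration dataloader from dataset_args and applies the recipe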
76 changes: 40 additions & 36 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -31,7 +31,7 @@ class TextGenerationDataset(RegistryMixin):
3. Tokenize dataset using model tokenizer/processor
4. Apply post processing such as grouping text and/or adding labels for finetuning
:param data_args: configuration settings for dataset loading
:param dataset_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param processor: processor or tokenizer to use on dataset
"""
@@ -41,11 +41,11 @@ class TextGenerationDataset(RegistryMixin):

def __init__(
self,
data_args: DatasetArguments,
dataset_args: DatasetArguments,
split: str,
processor: Processor,
):
self.data_args = data_args
self.dataset_args = dataset_args
self.split = split
self.processor = processor

@@ -58,23 +58,23 @@ def __init__(
self.tokenizer.pad_token = self.tokenizer.eos_token

# configure sequence length
max_seq_length = data_args.max_seq_length
if data_args.max_seq_length > self.tokenizer.model_max_length:
max_seq_length = dataset_args.max_seq_length
if dataset_args.max_seq_length > self.tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({max_seq_length}) is larger than "
f"maximum length for model ({self.tokenizer.model_max_length}). "
f"Using max_seq_length={self.tokenizer.model_max_length}."
)
self.max_seq_length = min(
data_args.max_seq_length, self.tokenizer.model_max_length
dataset_args.max_seq_length, self.tokenizer.model_max_length
)

# configure padding
self.padding = (
False
if self.data_args.concatenate_data
if self.dataset_args.concatenate_data
else "max_length"
if self.data_args.pad_to_max_length
if self.dataset_args.pad_to_max_length
else False
)
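For readers scanning the diff, the padding expression above unrolls to the following (a reading aid only, not part of the change):

def resolve_padding(concatenate_data: bool, pad_to_max_length: bool):
    # Mirrors the nested conditional expression above.
    if concatenate_data:
        return False          # concatenated samples are never padded
    if pad_to_max_length:
        return "max_length"   # pad every sample up to max_seq_length
    return False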

@@ -83,7 +83,7 @@ def __init__(
self.padding = False

def __call__(self, add_labels: bool = True) -> DatasetType:
dataset = self.data_args.dataset
dataset = self.dataset_args.dataset

if isinstance(dataset, str):
# load dataset: load from huggingface or disk
@@ -96,8 +96,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
dataset,
self.preprocess,
batched=False,
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
num_proc=self.dataset_args.preprocessing_num_workers,
load_from_cache_file=not self.dataset_args.overwrite_cache,
desc="Preprocessing",
)
logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}")
@@ -121,20 +121,20 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
# regardless of `batched` argument
remove_columns=get_columns(dataset), # assumes that input names
# and output names are disjoint
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
num_proc=self.dataset_args.preprocessing_num_workers,
load_from_cache_file=not self.dataset_args.overwrite_cache,
desc="Tokenizing",
)
logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}")

if self.data_args.concatenate_data:
if self.dataset_args.concatenate_data:
# postprocess: group text
dataset = self.map(
dataset,
self.group_text,
batched=True,
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
num_proc=self.dataset_args.preprocessing_num_workers,
load_from_cache_file=not self.dataset_args.overwrite_cache,
desc="Concatenating data",
)
logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}")
@@ -145,8 +145,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
dataset,
self.add_labels,
batched=False, # not compatible with batching, need row lengths
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
num_proc=self.dataset_args.preprocessing_num_workers,
load_from_cache_file=not self.dataset_args.overwrite_cache,
desc="Adding labels",
)
logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}")
@@ -165,27 +165,31 @@ def load_dataset(self):
:param cache_dir: disk location to search for cached dataset
:return: the requested dataset
"""
if self.data_args.dataset_path is not None:
if self.data_args.dvc_data_repository is not None:
self.data_args.raw_kwargs["storage_options"] = {
"url": self.data_args.dvc_data_repository
if self.dataset_args.dataset_path is not None:
if self.dataset_args.dvc_data_repository is not None:
self.dataset_args.raw_kwargs["storage_options"] = {
"url": self.dataset_args.dvc_data_repository
}
self.data_args.raw_kwargs["data_files"] = self.data_args.dataset_path
self.dataset_args.raw_kwargs["data_files"] = (
self.dataset_args.dataset_path
)
else:
self.data_args.raw_kwargs["data_files"] = get_custom_datasets_from_path(
self.data_args.dataset_path,
self.data_args.dataset
if hasattr(self.data_args, "dataset")
else self.data_args.dataset_name,
self.dataset_args.raw_kwargs["data_files"] = (
get_custom_datasets_from_path(
self.dataset_args.dataset_path,
self.dataset_args.dataset
if hasattr(self.dataset_args, "dataset")
else self.dataset_args.dataset_name,
)
)

logger.debug(f"Loading dataset {self.data_args.dataset}")
logger.debug(f"Loading dataset {self.dataset_args.dataset}")
return get_raw_dataset(
self.data_args,
self.dataset_args,
None,
split=self.split,
streaming=self.data_args.streaming,
**self.data_args.raw_kwargs,
streaming=self.dataset_args.streaming,
**self.dataset_args.raw_kwargs,
)

@cached_property
Expand All @@ -194,7 +198,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
The function must return keys which correspond to processor/tokenizer kwargs,
optionally including PROMPT_KEY
"""
preprocessing_func = self.data_args.preprocessing_func
preprocessing_func = self.dataset_args.preprocessing_func

if callable(preprocessing_func):
return preprocessing_func
@@ -218,9 +222,9 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]:
def rename_columns(self, dataset: DatasetType) -> DatasetType:
# rename columns to match processor/tokenizer kwargs
column_names = get_columns(dataset)
if self.data_args.text_column in column_names and "text" not in column_names:
logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`")
dataset = dataset.rename_column(self.data_args.text_column, "text")
if self.dataset_args.text_column in column_names and "text" not in column_names:
logger.debug(f"Renaming column `{self.dataset_args.text_column}` to `text`")
dataset = dataset.rename_column(self.dataset_args.text_column, "text")

return dataset

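Taken together with the example earlier in this commit, the renamed flow through this class looks like the following sketch (names such as `tokenizer` come from that example; this is not a standalone script):

dataset_args = DatasetArguments(
    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)
dataset_manager = TextGenerationDataset.load_from_registry(
    dataset_args.dataset,
    dataset_args=dataset_args,   # previously data_args=data_args
    split="train",
    processor=tokenizer,
)
# __call__ runs the pipeline described in the class docstring:
# load -> preprocess -> tokenize -> (optional) group text -> add labels
train_dataset = dataset_manager(add_labels=True)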
14 changes: 8 additions & 6 deletions src/llmcompressor/transformers/finetune/data/c4.py
@@ -13,14 +13,16 @@ class C4Dataset(TextGenerationDataset):
"""
Child text generation class for the C4 dataset
:param data_args: configuration settings for dataset loading
:param dataset_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "allenai/c4"
data_args.text_column = "text"
def __init__(
self, dataset_args: "DatasetArguments", split: str, processor: Processor
):
dataset_args = deepcopy(dataset_args)
dataset_args.dataset = "allenai/c4"
dataset_args.text_column = "text"

super().__init__(data_args=data_args, split=split, processor=processor)
super().__init__(dataset_args=dataset_args, split=split, processor=processor)
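The C4 and CNN/DailyMail classes follow the same pattern; a hypothetical new dataset would do likewise, forwarding the renamed keyword to the base class (registry decorator and imports are omitted here since they are not shown in this diff):

from copy import deepcopy

class MyCorpusDataset(TextGenerationDataset):
    """Hypothetical child dataset; `my_org/my_corpus` is an invented dataset id."""

    def __init__(
        self, dataset_args: "DatasetArguments", split: str, processor: Processor
    ):
        dataset_args = deepcopy(dataset_args)
        dataset_args.dataset = "my_org/my_corpus"
        dataset_args.text_column = "text"

        super().__init__(dataset_args=dataset_args, split=split, processor=processor)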
14 changes: 8 additions & 6 deletions src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -13,19 +13,21 @@ class CNNDailyMailDataset(TextGenerationDataset):
"""
Text generation class for the CNN/DailyMail dataset
:param data_args: configuration settings for dataset loading
:param dataset_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param processor: processor or tokenizer to use on dataset
"""

SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"

def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "cnn_dailymail"
data_args.dataset_config_name = "3.0.0"
def __init__(
self, dataset_args: "DatasetArguments", split: str, processor: Processor
):
dataset_args = deepcopy(dataset_args)
dataset_args.dataset = "cnn_dailymail"
dataset_args.dataset_config_name = "3.0.0"

super().__init__(data_args=data_args, split=split, processor=processor)
super().__init__(dataset_args=dataset_args, split=split, processor=processor)

def dataset_template(self, sample):
return {
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/finetune/data/custom.py
@@ -7,7 +7,7 @@ class CustomDataset(TextGenerationDataset):
Child text generation class for custom local dataset supporting load
for csv and json
:param data_args: configuration settings for dataset loading
:param dataset_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
Can also be set to None to load all the splits
:param processor: processor or tokenizer to use on dataset