diff --git a/src/llmcompressor/args/__init__.py b/src/llmcompressor/args/__init__.py
index d60435c42..26ad530b6 100644
--- a/src/llmcompressor/args/__init__.py
+++ b/src/llmcompressor/args/__init__.py
@@ -4,3 +4,4 @@
 from .model_arguments import ModelArguments
 from .recipe_arguments import RecipeArguments
 from .training_arguments import TrainingArguments
+from .utils import parse_args
diff --git a/src/llmcompressor/args/utils.py b/src/llmcompressor/args/utils.py
new file mode 100644
index 000000000..810d2f6ab
--- /dev/null
+++ b/src/llmcompressor/args/utils.py
@@ -0,0 +1,73 @@
+from typing import Tuple
+
+from loguru import logger
+from transformers import HfArgumentParser
+
+from llmcompressor.args import (
+    DatasetArguments,
+    ModelArguments,
+    RecipeArguments,
+    TrainingArguments,
+)
+from llmcompressor.transformers.utils.helpers import resolve_processor_from_model_args
+
+
+def parse_args(
+    include_training_args: bool = False, **kwargs
+) -> Tuple[ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments, str]:
+    """
+    Separate the keyword arguments passed in from `oneshot` or `train`
+    into the following argument dataclasses:
+
+        * ModelArguments in
+            src/llmcompressor/args/model_args.py
+        * DatasetArguments in
+            src/llmcompressor/args/dataset_args.py
+        * RecipeArguments in
+            src/llmcompressor/args/recipe_args.py
+        * TrainingArguments in
+            src/llmcompressor/args/training_args.py
+
+    ModelArguments, DatasetArguments, and RecipeArguments are used for both
+    `oneshot` and `train`. TrainingArguments is only used for `train`.
+
+    """
+
+    # pop output_dir: it is a TrainingArguments attr, which oneshot does not parse
+    output_dir = kwargs.pop("output_dir", None)
+
+    parser_args = (ModelArguments, DatasetArguments, RecipeArguments)
+    if include_training_args:
+        parser_args += (TrainingArguments,)
+
+    parser = HfArgumentParser(parser_args)
+    parsed_args = parser.parse_dict(kwargs)
+
+    training_args = None
+    if include_training_args:
+        model_args, dataset_args, recipe_args, training_args = parsed_args
+        if output_dir is not None:
+            training_args.output_dir = output_dir
+    else:
+        model_args, dataset_args, recipe_args = parsed_args
+
+    if recipe_args.recipe_args is not None:
+        if not isinstance(recipe_args.recipe_args, dict):
+            arg_dict = {}
+            for recipe_arg in recipe_args.recipe_args:
+                key, value = recipe_arg.split("=", 1)
+                arg_dict[key] = value
+            recipe_args.recipe_args = arg_dict
+
+    # raise deprecation warnings
+    if dataset_args.remove_columns is not None:
+        logger.warning(
+            "`remove_columns` argument is depreciated. When tokenizing datasets, all "
+            "columns which are invalid inputs the tokenizer will be removed",
+            DeprecationWarning,
+        )
+
+    # silently assign tokenizer to processor
+    resolve_processor_from_model_args(model_args)
+
+    return model_args, dataset_args, recipe_args, training_args, output_dir
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index 1440c08ad..ecdebf46b 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -1,11 +1,11 @@
 from pathlib import PosixPath
-from typing import Optional, Tuple
+from typing import Optional
 
 from loguru import logger
 from torch.utils.data import DataLoader
-from transformers import HfArgumentParser, PreTrainedModel
+from transformers import PreTrainedModel
 
-from llmcompressor.args import DatasetArguments, ModelArguments, RecipeArguments
+from llmcompressor.args import parse_args
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.transformers.finetune.data.data_helpers import (
     get_calibration_dataloader,
@@ -18,9 +18,8 @@
     modify_save_pretrained,
     patch_tied_tensors_bug,
 )
-from llmcompressor.transformers.utils.helpers import resolve_processor_from_model_args
 
-__all__ = ["Oneshot", "oneshot", "parse_oneshot_args"]
+__all__ = ["Oneshot", "oneshot"]
 
 
 class Oneshot:
@@ -123,10 +122,10 @@ def __init__(
 
         """
 
-        model_args, data_args, recipe_args, output_dir = parse_oneshot_args(**kwargs)
+        model_args, dataset_args, recipe_args, _, output_dir = parse_args(**kwargs)
 
         self.model_args = model_args
-        self.data_args = data_args
+        self.data_args = dataset_args
         self.recipe_args = recipe_args
         self.output_dir = output_dir
 
@@ -310,64 +309,3 @@ def oneshot(**kwargs) -> PreTrainedModel:
     one_shot()
 
     return one_shot.model
-
-
-def parse_oneshot_args(
-    **kwargs,
-) -> Tuple[ModelArguments, DatasetArguments, RecipeArguments, str]:
-    """
-    Parses kwargs by grouping into model, data or training arg groups:
-        * model_args in
-            src/llmcompressor/transformers/utils/arg_parser/model_args.py
-        * data_args in
-            src/llmcompressor/transformers/utils/arg_parser/data_args.py
-        * recipe_args in
-            src/llmcompressor/transformers/utils/arg_parser/recipe_args.py
-        * training_args in
-            src/llmcompressor/transformers/utils/arg_parser/training_args.py
-    """
-    output_dir = kwargs.pop("output_dir", None)
-
-    parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments))
-
-    if not kwargs:
-
-        def _get_output_dir_from_argv() -> Optional[str]:
-            import sys
-
-            output_dir = None
-            if "--output_dir" in sys.argv:
-                index = sys.argv.index("--output_dir")
-                sys.argv.pop(index)
-                if index < len(sys.argv):  # Check if value exists afer the flag
-                    output_dir = sys.argv.pop(index)
-
-            return output_dir
-
-        output_dir = _get_output_dir_from_argv() or output_dir
-        parsed_args = parser.parse_args_into_dataclasses()
-    else:
-        parsed_args = parser.parse_dict(kwargs)
-
-    model_args, data_args, recipe_args = parsed_args
-
-    if recipe_args.recipe_args is not None:
-        if not isinstance(recipe_args.recipe_args, dict):
-            arg_dict = {}
-            for recipe_arg in recipe_args.recipe_args:
-                key, value = recipe_arg.split("=")
-                arg_dict[key] = value
-            recipe_args.recipe_args = arg_dict
-
-    # raise depreciation warnings
-    if data_args.remove_columns is not None:
-        logger.warning(
-            "`remove_columns` argument is depreciated. When tokenizing datasets, all "
-            "columns which are invalid inputs the tokenizer will be removed",
-            DeprecationWarning,
-        )
-
-    # silently assign tokenizer to processor
-    resolve_processor_from_model_args(model_args)
-
-    return model_args, data_args, recipe_args, output_dir
diff --git a/tests/llmcompressor/entrypoints/test_oneshot.py b/tests/llmcompressor/entrypoints/test_oneshot.py
index 4a7f2a5a7..1d00c828f 100644
--- a/tests/llmcompressor/entrypoints/test_oneshot.py
+++ b/tests/llmcompressor/entrypoints/test_oneshot.py
@@ -1,7 +1,7 @@
 from transformers import AutoModelForCausalLM
 
 from llmcompressor import Oneshot
-from llmcompressor.entrypoints.oneshot import parse_oneshot_args
+from llmcompressor.args import parse_args
 
 
 def test_oneshot_from_args():
@@ -17,7 +17,7 @@ def test_oneshot_from_args():
 
     output_dir = "bar_output_dir"
 
-    model_args, data_args, recipe_args, output_dir = parse_oneshot_args(
+    model_args, data_args, recipe_args, _, output_dir = parse_args(
         model=model,
         dataset=dataset,
         recipe=recipe,