Commit 9d1bfb3 (1 parent: ad2733f)

fix

horheynm committed Mar 5, 2025

Showing 2 changed files with 6 additions and 40 deletions.

src/llmcompressor/entrypoints/oneshot.py (3 additions, 37 deletions)

@@ -1,18 +1,12 @@
from typing import Optional

from typing import Optional, Tuple

from loguru import logger
from torch.utils.data import DataLoader
from transformers import PreTrainedModel

from llmcompressor.args import parse_args
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.entrypoints.utils import post_process, pre_process
from llmcompressor.transformers.finetune.data.data_helpers import (
get_calibration_dataloader,
)
from llmcompressor.transformers.utils.helpers import resolve_processor_from_model_args


__all__ = ["Oneshot", "oneshot"]

@@ -68,7 +62,7 @@ class Oneshot:
Initializes the `Oneshot` object by parsing input arguments, performing
preprocessing, and setting instance attributes.
run(**kwargs):
__call__(**kwargs):
Performs the one-shot calibration process by preparing a calibration
dataloader, applying recipe modifiers to the model, and executing
postprocessing steps.
@@ -83,17 +77,6 @@ class Oneshot:
defined in the recipe. Each action is executed via the global
`CompressionSession`.
_pre_process():
Handles preprocessing steps, including model initialization,
tokenizer/processor setup, and resolving tied embedding issues.
check_tied_embeddings():
Logs a warning if `tie_word_embeddings=True`, which may interfere with
saving in the one-shot workflow.
_post_process():
Executes postprocessing steps such as saving the model and resetting
lifecycle actions, especially when a custom `output_dir` is specified.
"""

def __init__(
@@ -182,23 +165,6 @@ def __call__(self):
)
post_process(model_args=self.model_args, output_dir=self.output_dir)

def save(self):
"""
Saves the model and tokenizer/processor to the output directory.
The model is saved in a compressed format if specified in `model_args`.
The tokenizer or processor, if available, is also saved.
Raises:
ValueError: If saving fails due to an invalid `output_dir` or other issues.
"""
self.model.save_pretrained(
self.output_dir,
save_compressed=self.model_args.save_compressed,
)
if self.processor is not None:
self.processor.save_pretrained(self.output_dir)

def apply_recipe_modifiers(
self,
calibration_dataloader: Optional[DataLoader],
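For orientation, the hunks above trim the Oneshot entrypoint's public surface: the class docstring's run/__call__ entry is updated and the standalone save() helper is removed, leaving saving to post_process. Below is a minimal usage sketch of this entrypoint, assuming the keyword interface described in the class docstring; the model id, dataset name, and recipe values are illustrative examples, not taken from this commit.

    # Illustrative only: drives the oneshot entrypoint documented above.
    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import GPTQModifier

    oneshot(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # assumed example model id
        dataset="open_platypus",                     # calibration data for the dataloader
        recipe=GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
        output_dir="./tinyllama-w4a16",              # post_process saves the model here
        num_calibration_samples=512,
        max_seq_length=2048,
    )

The same flow is available by constructing Oneshot(**kwargs) and then calling the instance, which is the __init__/__call__ split the docstring describes.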

src/llmcompressor/transformers/finetune/text_generation.py (3 additions, 3 deletions)

@@ -46,12 +46,11 @@ def train(**kwargs):
"""
CLI entrypoint for running training
"""
model_args, dataset_args, recipe_args, training_args, _ = parse_args(**kwargs)
model_args, dataset_args, recipe_args, training_args = parse_args(**kwargs)
training_args.do_train = True
main(model_args, dataset_args, recipe_args, training_args)



@deprecated(
message=(
"`from llmcompressor.transformers import oneshot` is deprecated, "
@@ -69,10 +68,11 @@ def apply(**kwargs):
CLI entrypoint for any of training, oneshot
"""
from llmcompressor.args import parse_args

model_args, dataset_args, recipe_args, training_args, _ = parse_args(
include_training_args=True, **kwargs
)

training_args.run_stages = True
report_to = kwargs.get("report_to", None)
if report_to is None: # user didn't specify any reporters
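The truncated @deprecated message in the first hunk of this file points users away from the transformers-namespaced import. A two-line sketch of the migration it implies, assuming the top-level path that the deprecation message recommends:

    # Deprecated path per the message above:
    from llmcompressor.transformers import oneshot
    # Assumed replacement: the top-level re-export of the entrypoint in entrypoints/oneshot.py
    from llmcompressor import oneshot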
