
[Training] Unifying Preprocess + Postprocessing logic for Train/Oneshot #1212

Merged
merged 13 commits on Mar 6, 2025
1 change: 1 addition & 0 deletions src/llmcompressor/entrypoints/__init__.py
@@ -1,2 +1,3 @@
# flake8: noqa
from .oneshot import Oneshot, oneshot
from .utils import post_process, pre_process
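
This re-export is what lets the train and one-shot paths share one lifecycle. A minimal sketch of the intended call pattern, with signatures taken from the call sites in the diff below; the wrapper function is hypothetical, and the train-side wiring is not part of this diff, so treat both as assumptions:

from llmcompressor.entrypoints import post_process, pre_process

def run_with_shared_lifecycle(model_args, output_dir=None):
    # Hypothetical wrapper sketch: both Train and Oneshot funnel through these hooks
    pre_process(model_args)  # model/processor init, tied-tensor patch (see diff below)
    # ... entrypoint-specific work (calibration or training) goes here ...
    post_process(model_args=model_args, output_dir=output_dir)  # save, or warn if no dir
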
116 changes: 5 additions & 111 deletions src/llmcompressor/entrypoints/oneshot.py
@@ -1,21 +1,12 @@
from pathlib import PosixPath
from typing import Optional

from loguru import logger
from torch.utils.data import DataLoader
from transformers import PreTrainedModel

from llmcompressor.args import parse_args
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.transformers.finetune.text_generation import (
initialize_model_from_path,
initialize_processor_from_path,
)
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
patch_tied_tensors_bug,
)
from llmcompressor.entrypoints.utils import post_process, pre_process

__all__ = ["Oneshot", "oneshot"]

@@ -71,7 +62,7 @@ class Oneshot:
Initializes the `Oneshot` object by parsing input arguments, performing
preprocessing, and setting instance attributes.

run(**kwargs):
__call__(**kwargs):
Performs the one-shot calibration process by preparing a calibration
dataloader, applying recipe modifiers to the model, and executing
postprocessing steps.
@@ -86,17 +77,6 @@ class Oneshot:
defined in the recipe. Each action is executed via the global
`CompressionSession`.

_pre_process():
Handles preprocessing steps, including model initialization,
tokenizer/processor setup, and resolving tied embedding issues.

check_tied_embeddings():
Logs a warning if `tie_word_embeddings=True`, which may interfere with
saving in the one-shot workflow.

_post_process():
Executes postprocessing steps such as saving the model and resetting
lifecycle actions, especially when a custom `output_dir` is specified.
"""

def __init__(
@@ -151,7 +131,7 @@ def from_args(

# only run for the first oneshot call
if do_preprocess:
instance._pre_process()
pre_process(model_args)

# Set instance attributes
instance.model = instance.model_args.model
@@ -172,7 +152,7 @@ def __call__(self):
"""
# TODO: move back once stage runner is removed
# Preprocess the model and tokenizer/processor
self._pre_process()
pre_process(self.model_args)
self.model = self.model_args.model
self.recipe = self.recipe_args.recipe
self.processor = self.model_args.processor
@@ -183,24 +163,7 @@ def __call__(self):
self.apply_recipe_modifiers(
calibration_dataloader=calibration_dataloader,
)
self._post_process()

def save(self):
"""
Saves the model and tokenizer/processor to the output directory.

The model is saved in a compressed format if specified in `model_args`.
The tokenizer or processor, if available, is also saved.

Raises:
ValueError: If saving fails due to an invalid `output_dir` or other issues.
"""
self.model.save_pretrained(
self.output_dir,
save_compressed=self.model_args.save_compressed,
)
if self.processor is not None:
self.processor.save_pretrained(self.output_dir)
post_process(model_args=self.model_args, output_dir=self.output_dir)

def apply_recipe_modifiers(
self,
@@ -236,75 +199,6 @@ def apply_recipe_modifiers(
session.initialize(**session_kwargs)
session.finalize(**session_kwargs)

def _pre_process(self):
"""
Prepares the model and tokenizer/processor for calibration.

- Initializes the model if it's specified as a path or string.
- Applies patches to fix tied tensor issues and modifies `save_pretrained`
behavior.
- Initializes the processor if specified as a path or `None`.
- Sets the minimum tokens per module if `dataset_args` are provided.

Raises:
FileNotFoundError: If the model or processor path is invalid.
"""
self.check_tied_embeddings()

# Initialize model
if isinstance(self.model_args.model, (str, PosixPath)):
self.model_args.model, _ = initialize_model_from_path(self.model_args)

patch_tied_tensors_bug(self.model_args.model)
modify_save_pretrained(self.model_args.model)

# Initialize processor
if isinstance(self.model_args.processor, (str, type(None))):
self.model_args.processor = initialize_processor_from_path(
self.model_args, self.model_args.model
)
# TODO: move to init once stage runner is removed
self.processor = self.model_args.processor

# Set minimum tokens per module if data arguments are provided
if self.dataset_args:
self.min_tokens_per_module = self.dataset_args.min_tokens_per_module

def check_tied_embeddings(self):
"""
Logs a warning if the model has tied word embeddings.

The `tie_word_embeddings` flag may cause issues during saving in the one-shot
calibration workflow due to shared tensor addresses.
"""
if self.model_args.tie_word_embeddings:
logger.debug(
"The tie_word_embeddings flag is by default set to False. "
"This guarantees that the one-shot algorithm saves the final "
"weights without errors. Detected tie_word_embeddings=True. "
"This may cause issues with the one-shot algorithm on save."
)

def _post_process(self):
"""
Executes post-calibration steps.

This method saves the model and resets lifecycle actions if the `output_dir`
is not the default directory.

Raises:
ValueError: If saving fails due to invalid configurations.
"""
if self.output_dir is not None:
self.save()
return

logger.warning(
"Optimized model not saved. To save, please provide "
"`output_dir` as input arg. "
"Ex. `oneshot(..., output_dir=...)`"
)


def oneshot(**kwargs) -> PreTrainedModel:
one_shot = Oneshot(**kwargs)
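
The new src/llmcompressor/entrypoints/utils.py itself is not shown in this diff. A plausible reconstruction of the two helpers, based on the method bodies deleted above, looks roughly like this; it is a sketch, not the merged file, and the min_tokens_per_module bookkeeping from the removed _pre_process is omitted because pre_process only receives model_args here:

# Sketch of entrypoints/utils.py, reconstructed from the methods deleted above.
# Signatures match the call sites in this diff; the merged file may differ.
from pathlib import PosixPath
from typing import Optional

from loguru import logger

from llmcompressor.transformers.finetune.text_generation import (
    initialize_model_from_path,
    initialize_processor_from_path,
)
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
    patch_tied_tensors_bug,
)


def pre_process(model_args):
    """Prepare the model and tokenizer/processor for calibration or training."""
    # Flag tied embeddings early, since they can break saving later
    if model_args.tie_word_embeddings:
        logger.debug(
            "Detected tie_word_embeddings=True. "
            "This may cause issues with the one-shot algorithm on save."
        )

    # Initialize the model if it was passed as a path or string
    if isinstance(model_args.model, (str, PosixPath)):
        model_args.model, _ = initialize_model_from_path(model_args)

    patch_tied_tensors_bug(model_args.model)
    modify_save_pretrained(model_args.model)

    # Initialize the processor if it was passed as a path or left unset
    if isinstance(model_args.processor, (str, type(None))):
        model_args.processor = initialize_processor_from_path(
            model_args, model_args.model
        )


def post_process(model_args, output_dir: Optional[str] = None):
    """Save the model and tokenizer/processor if an output directory was given."""
    if output_dir is not None:
        model_args.model.save_pretrained(
            output_dir, save_compressed=model_args.save_compressed
        )
        if model_args.processor is not None:
            model_args.processor.save_pretrained(output_dir)
        return

    logger.warning(
        "Optimized model not saved. To save, please provide "
        "`output_dir` as input arg. Ex. `oneshot(..., output_dir=...)`"
    )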
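
For reference, the public entrypoint is unchanged by this refactor. A minimal usage sketch; the model id, recipe path, and output directory are placeholders, and the import is grounded in the __init__.py re-export above:

from llmcompressor.entrypoints import oneshot  # re-exported in __init__.py above

# Placeholder values; a string/path model is initialized inside pre_process
model = oneshot(
    model="path/to/model",            # hypothetical model id or path
    recipe="recipe.yaml",             # hypothetical recipe path
    output_dir="./compressed-model",  # triggers the save path in post_process
)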