Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Train] Training Pipeline #1214

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/llmcompressor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
active_session,
callbacks,
create_session,
finalize,
initialize,
reset_session,
)
from llmcompressor.entrypoints import Oneshot, oneshot
from llmcompressor.entrypoints import Oneshot, oneshot, train
4 changes: 0 additions & 4 deletions src/llmcompressor/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
active_session,
callbacks,
create_session,
finalize,
initialize,
reset_session,
)
from llmcompressor.core.state import Data, Hardware, ModifiedState, State
Expand All @@ -35,8 +33,6 @@
"create_session",
"active_session",
"reset_session",
"initialize",
"finalize",
"apply",
"callbacks",
"LifecycleCallbacks",
Expand Down
78 changes: 1 addition & 77 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import threading
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional

from llmcompressor.core.events import EventType
from llmcompressor.core.session import CompressionSession
from llmcompressor.core.state import ModifiedState
from llmcompressor.recipe import Recipe

__all__ = [
"create_session",
"active_session",
"reset_session",
"initialize",
"finalize",
"callbacks",
"LifecycleCallbacks",
]
Expand Down Expand Up @@ -58,79 +55,6 @@ def reset_session():
session._lifecycle.reset()


def initialize(
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
recipe_stage: Union[str, List[str], None] = None,
recipe_args: Optional[Dict[str, Any]] = None,
model: Optional[Any] = None,
teacher_model: Optional[Any] = None,
optimizer: Optional[Any] = None,
attach_optim_callbacks: bool = True,
train_data: Optional[Any] = None,
val_data: Optional[Any] = None,
test_data: Optional[Any] = None,
calib_data: Optional[Any] = None,
copy_data: bool = True,
start: Optional[float] = None,
steps_per_epoch: Optional[int] = None,
batches_per_step: Optional[int] = None,
**kwargs,
) -> ModifiedState:
"""
A method to initialize the active session for sparsification

:param recipe: the recipe to use for the sparsification, can be a path to a
recipe file, a raw recipe string, a recipe object, or a list of recipe objects.
:param recipe_stage: the stage to target for the sparsification
:param recipe_args: the args to use for overriding the recipe defaults
:param model: the model to sparsify
:param teacher_model: the teacher model to use for knowledge distillation
:param optimizer: the optimizer to use for the sparsification
:param attach_optim_callbacks: True to attach the optimizer callbacks to the
sparsification lifecycle, False otherwise
:param train_data: the training data to use for the sparsification
:param val_data: the validation data to use for the sparsification
:param test_data: the testing data to use for the sparsification
:param calib_data: the calibration data to use for the sparsification
:param copy_data: True to copy the data, False otherwise
:param start: the start epoch to use for the sparsification
:param steps_per_epoch: the number of steps per epoch to use for the
sparsification
:param batches_per_step: the number of batches per step to use for
sparsification
:param kwargs: additional kwargs to pass to the lifecycle's initialize method
:return: the modified state of the active session after initializing
"""
return active_session().initialize(
recipe=recipe,
recipe_stage=recipe_stage,
recipe_args=recipe_args,
model=model,
teacher_model=teacher_model,
optimizer=optimizer,
attach_optim_callbacks=attach_optim_callbacks,
train_data=train_data,
val_data=val_data,
test_data=test_data,
calib_data=calib_data,
copy_data=copy_data,
start=start,
steps_per_epoch=steps_per_epoch,
batches_per_step=batches_per_step,
**kwargs,
)


def finalize(**kwargs) -> ModifiedState:
"""
Method to finalize the active session for sparsification

:param kwargs: additional kwargs to pass to the lifecycle's finalize method
:return: the modified state of the active session after finalizing
"""
return active_session().finalize(**kwargs)


class LifecycleCallbacks:
"""
A class for invoking lifecycle events for the active session
Expand Down
Loading
Loading