-
Notifications
You must be signed in to change notification settings - Fork 96
/
Copy pathoneshot.py
207 lines (166 loc) · 7.24 KB
/
oneshot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from typing import Optional
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from llmcompressor.args import parse_args
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.entrypoints.utils import post_process, pre_process
# Public names exported by this module (what `from <module> import *` exposes).
__all__ = ["Oneshot", "oneshot"]
class Oneshot:
    """
    Carry out one-shot calibration on a pretrained model.

    This class drives the entire lifecycle of one-shot calibration, including
    preprocessing (model and tokenizer/processor initialization), model
    optimization (quantization or sparsification), and postprocessing (saving
    outputs). The instructions for model optimization are specified by a recipe.

    - **Input Keyword Arguments:**

        `kwargs` are parsed by `parse_args` into:
        - `model_args`: Arguments for loading and configuring a pretrained model
            (e.g., `AutoModelForCausalLM`).
        - `dataset_args`: Arguments for dataset-related configurations, such as
            calibration dataloaders.
        - `recipe_args`: Arguments for defining and configuring recipes that
            specify optimization actions.
        - `output_dir`: Destination handed to postprocessing.

        Parsers are defined in `src/llmcompressor/args/`.

    - **Lifecycle Overview:**

        Calling the instance runs three steps in order:
        1. **Preprocessing** (`pre_process(model_args)`):
            - Prepares the model and tokenizer/processor. The details (model
              instantiation, untying shared embeddings, patching save logic)
              live in `pre_process`, not in this class.
        2. **Oneshot Calibration** (`apply_recipe_modifiers`):
            - Optimizes the model using the global `CompressionSession` and
              applies recipe-defined modifiers (e.g., `GPTQModifier`,
              `SparseGPTModifier`).
        3. **Postprocessing** (`post_process(model_args, output_dir)`):
            - Responsible for saving the outputs to `output_dir`; implemented
              in `post_process`, not in this class.

    - **Usage:**

        ```python
        oneshot = Oneshot(model=model, recipe=recipe, dataset=dataset)
        oneshot()

        # Access the processed components
        model = oneshot.model
        processor = oneshot.processor
        recipe = oneshot.recipe
        ```

    Methods:
        __init__(**kwargs):
            Parses the keyword arguments and stores them on the instance.
            Preprocessing is deferred to `__call__`.

        from_args(model_args, dataset_args, recipe_args, output_dir,
                  do_preprocess=True):
            Alternate constructor for already-parsed argument objects.

        __call__():
            Preprocesses, builds the calibration dataloader, applies the recipe
            modifiers and runs postprocessing.

        apply_recipe_modifiers(calibration_dataloader, recipe_stage=None):
            Runs the `initialize` and `finalize` lifecycle actions on the active
            `CompressionSession`.
    """

    def __init__(
        self,
        **kwargs,
    ):
        """
        Initialize the `Oneshot` object from raw keyword arguments.

        The keyword arguments are handed to `parse_args`, which returns
        `model_args`, `dataset_args`, `recipe_args`, a fourth value that is
        ignored here, and `output_dir`. No preprocessing happens in the
        constructor; `model`, `processor` and `recipe` therefore hold the values
        exactly as parsed until the instance is called.

        :param kwargs: keyword arguments accepted by `parse_args`
        """
        model_args, dataset_args, recipe_args, _, output_dir = parse_args(**kwargs)

        self.model_args = model_args
        self.dataset_args = dataset_args
        self.recipe_args = recipe_args
        self.output_dir = output_dir

        self._refresh_components()

    @classmethod
    def from_args(
        cls,
        model_args,
        dataset_args,
        recipe_args,
        output_dir,
        do_preprocess: bool = True,
    ):
        """
        Build an instance from already-parsed arguments, bypassing `__init__`.

        Used only for the stage runner to populate the args.

        NOTE(review): `__call__` always runs `pre_process` itself, so calling an
        instance created here with `do_preprocess=True` preprocesses twice.

        :param model_args: parsed model arguments; must expose `model` and
            `processor`
        :param dataset_args: parsed dataset arguments
        :param recipe_args: parsed recipe arguments; must expose `recipe`
        :param output_dir: path stored on the instance for postprocessing
        :param do_preprocess: when True, run `pre_process(model_args)` before
            the instance attributes are populated
        :return: the new `Oneshot` instance
        """
        # Allocate without running __init__, which would re-parse raw kwargs.
        instance = super().__new__(cls)
        instance.model_args = model_args
        instance.dataset_args = dataset_args
        instance.recipe_args = recipe_args
        instance.output_dir = output_dir

        # only run for the first oneshot call
        if do_preprocess:
            pre_process(model_args)

        instance._refresh_components()
        return instance

    def _refresh_components(self) -> None:
        """Copy `model`, `processor` and `recipe` from the args onto the instance."""
        self.model = self.model_args.model
        self.processor = self.model_args.processor
        self.recipe = self.recipe_args.recipe

    def __call__(self):
        """
        Perform one-shot calibration.

        Runs preprocessing, prepares a calibration dataloader from the dataset
        arguments, applies the recipe-based modifiers to the model and finally
        hands `model_args` and `output_dir` to postprocessing. The steps are
        executed sequentially.
        """
        # TODO: move back once stage runner is removed
        # Preprocess the model and tokenizer/processor
        pre_process(self.model_args)
        # Re-read the attributes so they reflect whatever preprocessing produced.
        self._refresh_components()

        calibration_dataloader = get_calibration_dataloader(
            self.dataset_args, self.processor
        )
        self.apply_recipe_modifiers(
            calibration_dataloader=calibration_dataloader,
        )
        post_process(model_args=self.model_args, output_dir=self.output_dir)

    def apply_recipe_modifiers(
        self,
        calibration_dataloader: Optional[DataLoader],
        recipe_stage: Optional[str] = None,
    ):
        """
        Apply recipe modifiers to the model during the lifecycle.

        The modifiers are defined in the recipe and executed via lifecycle
        actions (`initialize`, then `finalize`) on the session returned by
        `active_session()`. Both actions receive the same keyword arguments.
        Nothing is caught here, so any exception raised by the session
        propagates to the caller.

        :param calibration_dataloader: Dataloader for calibration data, passed
            to the session as `calib_data`; may be None
        :param recipe_stage: optional stage identifier forwarded unchanged to
            the session
        """
        session = active_session()

        session_kwargs = {
            "model": self.model,
            "recipe": self.recipe,
            "recipe_args": self.recipe_args.recipe_args,
            "calib_data": calibration_dataloader,
            "start": -1,  # oneshot-specific argument
            "copy_data": False,
            # NOTE(review): nothing in this class assigns `min_tokens_per_module`,
            # so this resolves to None unless a caller sets it on the instance.
            "min_tokens_per_module": getattr(self, "min_tokens_per_module", None),
            "recipe_stage": recipe_stage,
        }

        session.initialize(**session_kwargs)
        session.finalize(**session_kwargs)
def oneshot(**kwargs) -> PreTrainedModel:
    """
    Run one-shot calibration end to end and return the resulting model.

    Convenience wrapper that builds an `Oneshot` from the given keyword
    arguments, invokes it once, and returns its `model` attribute.

    :param kwargs: keyword arguments forwarded unchanged to `Oneshot`
    :return: the `model` attribute of the `Oneshot` instance after the call
    """
    calibration = Oneshot(**kwargs)
    # Invoking the instance runs preprocessing, calibration and postprocessing.
    calibration()
    return calibration.model