
Commit de38a64
WIP
Signed-off-by: Kyle Sayers <[email protected]>
Parent: 29f93d3

File tree: 7 files changed, +43 −19 lines

examples/quantization_w4a16/llama3_example.py

Lines changed: 4 additions & 1 deletion

@@ -5,7 +5,8 @@
 from llmcompressor.transformers import oneshot

 # Select model and load it.
-MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"

 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
@@ -22,6 +23,7 @@
 # Increasing the number of samples can improve accuracy.
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
+BATCH_SIZE = 2

 # Load dataset and preprocess.
 ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
@@ -64,6 +66,7 @@ def tokenize(sample):
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    per_device_oneshot_batch_size=BATCH_SIZE,
 )

 # Confirm generations of the quantized model look sane.
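For context, a minimal sketch of how the new BATCH_SIZE constant reaches the oneshot call once the hunks above are applied. The dataset and recipe setup from the rest of the example script are elided with `...` placeholders; the remaining keyword names follow the surrounding example, and `per_device_oneshot_batch_size` is the WIP argument introduced by this commit, so its name may still change.

# Sketch only: assembles the fragments shown in the hunks above.
from transformers import AutoModelForCausalLM
from llmcompressor.transformers import oneshot

MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
BATCH_SIZE = 2  # new constant introduced by this commit

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

ds = ...      # tokenized calibration split prepared earlier in the script (elided here)
recipe = ...  # quantization recipe defined earlier in the script (elided here)

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    per_device_oneshot_batch_size=BATCH_SIZE,  # WIP argument added by this commit
)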

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 2 additions & 0 deletions

@@ -254,6 +254,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             )
             if isinstance(exception, unfixable_errors):
                 raise exception
+
+            raise exception

             warnings.warn("Falling back to layer_sequential pipeline")
             try:

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 1 addition & 0 deletions

@@ -71,6 +71,7 @@ def trace_subgraphs(
     concrete_args = populate_concrete_args(model, sample_input)

     # trace
+    breakpoint()
     with (
         calibration_forward_context(model),
         HooksMixin.disable_hooks(),

src/llmcompressor/transformers/finetune/data/base.py

Lines changed: 2 additions & 6 deletions

@@ -53,11 +53,7 @@ def __init__(
         self.tokenizer = getattr(self.processor, "tokenizer", self.processor)

         if self.tokenizer is not None:
-            # fill in pad token
-            if not self.tokenizer.pad_token:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            # configure sequence length
+            # resolve sequence length
             max_seq_length = data_args.max_seq_length
             if data_args.max_seq_length > self.tokenizer.model_max_length:
                 logger.warning(
@@ -69,7 +65,7 @@ def __init__(
                 data_args.max_seq_length, self.tokenizer.model_max_length
             )

-            # configure padding
+            # resolve padding
             self.padding = (
                 False
                 if self.data_args.concatenate_data

src/llmcompressor/transformers/finetune/data/data_helpers.py

Lines changed: 9 additions & 6 deletions

@@ -30,7 +30,7 @@ def format_calibration_data(
     batch_size: int = 1,
     do_shuffle: bool = True,
     processor: Optional[Processor] = None,
-    collate_fn: Callable = default_data_collator,
+    collate_fn: Optional[Callable] = None,
     accelerator: Optional[Any] = None,
 ) -> List[torch.Tensor]:
     """
@@ -41,7 +41,9 @@ def format_calibration_data(
     :param num_calibration_samples: number of data samples to convert
     :param do_shuffle: whether to shuffle the dataset before selecting calibration
         samples, true by default
-    :param collate_fn: optional custom collate function, or use default
+    :param collate_fn: optional custom collate function, defaults to
+        `DataCollatorWithPadding` if None is provided. If the tokenizer fails to
+        resolve, then `default_data_collator` is used
     :param accelerator: optional accelerator for if preparing in FSDP mode
     :return: list of trimmed calibration data tensors
     """
@@ -61,16 +63,17 @@ def format_calibration_data(
     tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

     # collate data
+    breakpoint()
     if collate_fn is None:
         tokenizer = getattr(processor, "tokenizer", processor)
-        if tokenizer is None:
+        if hasattr(tokenizer, "pad"):
+            collate_fn = DataCollatorWithPadding(tokenizer)
+        else:
             warnings.warn(
                 "Could not find processor, attempting to collate with without padding "
                 "(may fail for batch_size > 1)"
             )
-            return default_data_collator()
-
-        collate_fn = DataCollatorWithPadding(tokenizer)
+            collate_fn = default_data_collator

     dataloader_params = {
         "batch_size": batch_size,

src/llmcompressor/transformers/finetune/runner.py

Lines changed: 4 additions & 3 deletions

@@ -49,14 +49,15 @@ def __init__(
         data_args: "DataTrainingArguments",
         model_args: "ModelArguments",
         training_args: "TrainingArguments",
+        processor: Processor,
     ):
         self._data_args = data_args
         self._model_args = model_args
         self._training_args = training_args

         self.datasets = {}
         self.trainer = None
-        self.processor = None
+        self.processor = processor
         self.parent_output_dir = self._training_args.output_dir
         self._output_dir = self._training_args.output_dir

@@ -68,8 +69,8 @@ def populate_datasets(self, processor: Processor, add_labels: bool = True):
         :param processor: processor or tokenizer to use for dataset tokenization
         :param add_labels: if True, add labels column to dataset splits
         """
+        # TODO: remove `processor` arg in favor of self.processor
         if self._data_args.dataset is None:
-            self.processor = self._model_args.processor
             logger.info(
                 "Running oneshot without calibration data. This is expected for "
                 "weight-only and dynamic quantization"
@@ -110,7 +111,7 @@ def _get_split_name(inp_str):
                 registry_id,
                 data_args=self._data_args,
                 split=split_str,
-                processor=processor,
+                processor=self.processor,
             )
             tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels)
src/llmcompressor/transformers/finetune/text_generation.py

Lines changed: 21 additions & 3 deletions

@@ -20,6 +20,7 @@
 import os
 import warnings
 from pathlib import PosixPath
+from types import NoneType

 from loguru import logger
 from transformers import (
@@ -286,6 +287,20 @@ def initialize_processor_from_path(
     return processor


+def configure_processor(processor: Processor):
+    # configure tokenizer pad_token, required for padding and data collation
+    tokenizer = getattr(processor, "tokenizer", processor)
+    if getattr(tokenizer, "pad_token", None) is None:
+        if hasattr(tokenizer, "eos_token"):
+            logger.debug("Tokenizer is missing pad_token, using eos_token instead")
+            tokenizer.pad_token = tokenizer.eos_token
+        else:
+            logger.debug(
+                "Tokenizer is missing pad_token and eos_token, this may lead to issues "
+                " when padding"
+            )
+
+
 def main(
     model_args: ModelArguments,
     data_args: DataTrainingArguments,
@@ -361,8 +376,9 @@ def main(
         teacher.eval()

     processor = model_args.processor
-    if isinstance(processor, str) or processor is None:
+    if isinstance(processor, (str, NoneType)):
         processor = initialize_processor_from_path(model_args, model, teacher)
+    configure_processor(processor)

     pre_initialize_structure(model=model)

@@ -371,10 +387,12 @@ def main(

     # Load datasets
     stage_runner = StageRunner(
-        model_args=model_args, data_args=data_args, training_args=training_args
+        model_args=model_args, data_args=data_args, training_args=training_args, processor=processor
     )
     add_labels = training_args.do_train or training_args.run_stages
-    stage_runner.populate_datasets(processor=processor, add_labels=add_labels)
+    stage_runner.populate_datasets(
+        processor=processor, add_labels=add_labels
+    )
     train_dataset = stage_runner.get_dataset_split("train")
     eval_dataset = stage_runner.get_dataset_split("validation")
     calib_dataset = stage_runner.get_dataset_split("calibration")
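One note on the new import in this file: `types.NoneType` only exists on Python 3.10 and later, so the `isinstance(processor, (str, NoneType))` check assumes a 3.10+ interpreter. A quick sketch of the equivalence, with `type(None)` as the spelling that also works on older versions:

# NoneType is importable from types only on Python >= 3.10.
from types import NoneType

processor = None  # stand-in for model_args.processor
assert NoneType is type(None)
assert isinstance(processor, (str, NoneType))    # check used in this commit
assert isinstance(processor, (str, type(None)))  # equivalent, works pre-3.10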
