Replace model-specific audio data collators with an inline generic collator
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Jan 21, 2025
1 parent 36ec9f0 commit d11af96
Showing 3 changed files with 20 additions and 26 deletions.
17 changes: 12 additions & 5 deletions examples/multimodal_audio/qwen2_audio_example.py
@@ -1,3 +1,4 @@
+import torch
 from datasets import load_dataset
 from transformers import AutoProcessor

@@ -6,7 +7,6 @@
 from llmcompressor.transformers.tracing import (
     TraceableQwen2AudioForConditionalGeneration,
 )
-from llmcompressor.transformers.utils.data_collator import qwen2_audio_data_collator
 
 # Select model and load it.
 MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
@@ -67,14 +67,21 @@ def tokenize(sample):
 
 ds = ds.map(tokenize, remove_columns=ds.column_names)
 
+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Configure the quantization algorithm to run.
 # * quantize the weights to 4 bit with GPTQ with a group size 128
 recipe = GPTQModifier(
     targets="Linear",
     scheme="W4A16",
     ignore=[
-        "re:audio_tower.*",
-        "re:multi_modal_projector.*",
+        # "re:audio_tower.*",
+        # "re:multi_modal_projector.*",
         "lm_head",
     ],  # TODO: honestly, there's a decent number of parameters in the audio tower worth quantizing
 )
@@ -86,14 +93,14 @@ def tokenize(sample):
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-    data_collator=qwen2_audio_data_collator,
+    data_collator=data_collator,
 )
 
 # Confirm generations of the quantized model look sane.
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
 breakpoint()
-sample_input = qwen2_audio_data_collator([next(iter(ds))])
+sample_input = data_collator([next(iter(ds))])
 sample_input = {k: v.to(model.device) for k, v in sample_input.items()}
 output = model.generate(**sample_input)
 print(processor.batch_decode(output, skip_special_tokens=True)[0])
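The inlined collator replaces the model-specific qwen2_audio_data_collator with a generic dict comprehension: it assumes oneshot calibration feeds exactly one sample per batch and wraps every feature the processor produced in a tensor. A minimal sketch of how it behaves, using a dummy sample whose key names follow the Qwen2-Audio processor but whose values and shapes are purely illustrative:

import torch

def data_collator(batch):
    # Oneshot calibration passes one sample at a time, so a batch is a
    # single-element list of feature dicts.
    assert len(batch) == 1
    # Wrap every feature (nested Python lists) in a tensor; dtype is inferred.
    return {key: torch.tensor(value) for key, value in batch[0].items()}

# Dummy processed sample (illustrative values, not real audio features).
sample = {
    "input_ids": [[1, 2, 3]],
    "attention_mask": [[1, 1, 1]],
    "input_features": [[[0.0] * 16] * 8],
    "feature_attention_mask": [[1] * 16],
}
collated = data_collator([sample])
print({key: (tuple(value.shape), value.dtype) for key, value in collated.items()})

Since torch.tensor infers int64 for Python ints, integer fields such as input_ids keep the dtype that the removed collator produced explicitly via torch.LongTensor.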
10 changes: 8 additions & 2 deletions examples/multimodal_audio/whisper_example.py
@@ -5,7 +5,6 @@
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableWhisperForConditionalGeneration
-from llmcompressor.transformers.utils.data_collator import whisper_data_collator
 
 # Select model and load it.
 MODEL_ID = "openai/whisper-large-v2"
@@ -70,6 +69,13 @@ def process(sample):
 
 ds = ds.map(process, remove_columns=ds.column_names)
 
+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Configure the quantization algorithm to run.
 # * quantize the weights to 4 bit with GPTQ with a group size 128
 recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
@@ -81,7 +87,7 @@ def process(sample):
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-    data_collator=whisper_data_collator,
+    data_collator=data_collator,
 )
 
 # Confirm generations of the quantized model look sane.
19 changes: 0 additions & 19 deletions src/llmcompressor/transformers/utils/data_collator.py
@@ -57,22 +57,3 @@ def phi3_vision_data_collator(batch):
         "pixel_values": torch.tensor(batch[0]["pixel_values"]),
         "image_sizes": torch.tensor(batch[0]["image_sizes"]),
     }
-
-
-def whisper_data_collator(batch):
-    assert len(batch) == 1
-    return {
-        "input_features": torch.tensor(batch[0]["input_features"]),
-        "decoder_input_ids": torch.tensor(batch[0]["decoder_input_ids"]),
-        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
-    }
-
-
-def qwen2_audio_data_collator(batch):
-    assert len(batch) == 1
-    return {
-        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
-        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
-        "input_features": torch.tensor(batch[0]["input_features"]),
-        "feature_attention_mask": torch.tensor(batch[0]["feature_attention_mask"]),
-    }
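For these examples the deleted helpers are behaviorally equivalent to the generic inline replacement, since the processed samples contain exactly the keys the helpers enumerated. A quick sketch checking that equivalence for the whisper variant on a dummy sample (values are illustrative; decoder_input_ids stands in for whatever the example's process step emits). One caveat worth double-checking: the inlined collator needs torch in scope, and the whisper example's visible hunks don't add that import, so the file presumably imports it already.

import torch

def whisper_data_collator(batch):  # as removed in this commit
    assert len(batch) == 1
    return {
        "input_features": torch.tensor(batch[0]["input_features"]),
        "decoder_input_ids": torch.tensor(batch[0]["decoder_input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
    }

def data_collator(batch):  # the generic inline replacement
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}

# Dummy processed sample (illustrative values only).
sample = {
    "input_features": [[[0.0] * 16] * 8],
    "decoder_input_ids": [[1, 2, 3]],
    "attention_mask": [[1, 1, 1]],
}
old = whisper_data_collator([sample])
new = data_collator([sample])
assert old.keys() == new.keys()
assert all(torch.equal(old[k], new[k]) for k in old)  # same values and dtypes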
