Merge pull request #85 from kadirnar/add_metric_md
Add Readme file for metric results
kadirnar committed May 5, 2024
2 parents e8cbada + 9ee5370 commit 02688b9
Showing 2 changed files with 76 additions and 57 deletions.
8 changes: 8 additions & 0 deletions benckmarks.md
**Word error rate (WER) on the mozilla-foundation/common_voice_17_0 dataset ("dv" test split)**

| Model                                | WER (%)      |
| ------------------------------------ | ------------ |
| distil-whisper/distil-large-v3 | 120.48 |
| distil-whisper/distil-large-v3 + Hqq | 120.88 |

HQQ (Half-Quadratic Quantization): https://github.com/mobiusml/hqq/
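Both values exceed 100% because WER counts word-level substitutions, deletions, and insertions against the number of reference words, so it is unbounded above. A quick illustration with the same `evaluate` metric used in whisperplus/audio_eval.py below (the example strings are made up):

from evaluate import load

wer_metric = load("wer")

# 1 reference word, 2 edit operations (substitute "hello", insert "world")
# -> WER = 2 / 1 = 2.0, i.e. 200%
print(100 * wer_metric.compute(references=["hello"], predictions=["goodbye world"]))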
125 changes: 68 additions & 57 deletions whisperplus/audio_eval.py
import torch
from datasets import load_dataset
from evaluate import load
from tqdm import tqdm
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, BitsAndBytesConfig, HqqConfig, pipeline
from transformers.pipelines.pt_utils import KeyDataset

from whisperplus.pipelines.whisper import SpeechToTextPipeline

model_id = "distil-whisper/distil-large-v3"
processor = AutoProcessor.from_pretrained(model_id)


def hqq_load_model():
    from hqq.core.quantize import HQQBackend, HQQLinear
    from hqq.utils.patching import prepare_for_inference  # noqa: F401 -- optional patching helper, unused here

    # Each set_backend call overrides the previous one, so only the last takes
    # effect (ATEN is also selected automatically by default when available).
    HQQLinear.set_backend(HQQBackend.PYTORCH)  # Pytorch backend
    HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)  # Compiled Pytorch via dynamo
    HQQLinear.set_backend(HQQBackend.ATEN)  # C++ Aten/CUDA backend

    hqq_config = HqqConfig(
        nbits=4,
        group_size=64,
        quant_zero=False,
        quant_scale=False,
        axis=0,  # axis=0 is used by default
        offload_meta=False,
    )

    model = SpeechToTextPipeline(model_id="distil-whisper/distil-large-v3", quant_config=hqq_config)
    model = model.model  # unwrap the underlying transformers model
    return model


def base_load_model():
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    ).to("cuda:0")  # from_pretrained has no `device` kwarg; move the model explicitly
    return model


def calculate_metrics(model, processor, model_id):
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        torch_dtype=torch.bfloat16,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        model_kwargs={"use_flash_attention_2": True},  # only applied when the pipeline loads the checkpoint itself
    )

    wer_metric = load("wer")

    common_voice_test = load_dataset(
        "mozilla-foundation/common_voice_17_0",
        "dv",  # Divehi test split
        split="test")

    all_predictions = []

    # run streamed inference
    for prediction in tqdm(
            pipe(
                KeyDataset(common_voice_test, "audio"),
                max_new_tokens=128,
                generate_kwargs={"task": "transcribe"},
                batch_size=32,
            ),
            total=len(common_voice_test),
    ):
        all_predictions.append(prediction["text"])

    wer_ortho = 100 * wer_metric.compute(references=common_voice_test["sentence"], predictions=all_predictions)

    print(f"WER: {wer_ortho:.2f}%")
    return wer_ortho
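
For reference, a minimal driver that would produce the two rows of the table in benckmarks.md (this `__main__` block is an illustrative sketch, not part of the commit):

if __name__ == "__main__":
    # baseline bfloat16 model vs. 4-bit HQQ-quantized model
    base_wer = calculate_metrics(base_load_model(), processor, model_id)
    hqq_wer = calculate_metrics(hqq_load_model(), processor, model_id)
    print(f"base: {base_wer:.2f} | hqq: {hqq_wer:.2f}")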
