diff --git a/examples/online_serving/whisper_api_client.py b/examples/online_serving/whisper_api_client.py new file mode 100644 index 0000000000000..654832b89ee37 --- /dev/null +++ b/examples/online_serving/whisper_api_client.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse +import asyncio +import json +from subprocess import CalledProcessError, run + +import aiohttp +import numpy as np + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) +SAMPLE_RATE = 16000 + + +def load_audio_from_file(file: str, sample_rate: int = SAMPLE_RATE): + cmd = [ + "ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", + "-ac", "1", "-acodec", "pcm_s16le", "-ar", + str(sample_rate), "-" + ] + # fmt: on + try: + out = run(cmd, capture_output=True, check=True).stdout + except CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +async def iterate_response(response): + output_text = "" + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") + if chunk != "[DONE]": + output_text += json.loads(chunk)["text"] + return output_text + + +async def transcribe_from_waveform(base_url: str, file_path: str): + """Send waveform to API Server for transcription.""" + + waveform = load_audio_from_file(file_path, SAMPLE_RATE) + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + + url = f"{base_url}/generate_from_waveform" + data = { + "waveform_bytes": waveform.tobytes(), + "sampling_rate": str(SAMPLE_RATE) + } + async with session.post(url, data=data) as response: + output = await iterate_response(response) + return output + + +async def transcribe_from_file(base_url: str, file_path: str): + """Send file to API Server for transcription.""" + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + + url = f"{base_url}/generate_from_file" + with open(file_path, 'rb') as f: + async with session.post(url, data={'file': f}) as response: + output = await iterate_response(response) + print(output) + + +parser = argparse.ArgumentParser() +parser.add_argument("--filepath", type=str, default="1221-135766-0002.wav") +parser.add_argument("--send-waveform", action="store_true") +parser.add_argument("--host", type=str, default="localhost") +parser.add_argument("--port", type=int, default=8000) + +if __name__ == "__main__": + args = parser.parse_args() + api_url = f"http://{args.host}:{args.port}" + + if args.send_waveform: + asyncio.run( + transcribe_from_waveform(base_url=api_url, + file_path=args.filepath)) + else: + asyncio.run( + transcribe_from_file(base_url=api_url, file_path=args.filepath)) diff --git a/examples/online_serving/whisper_wer_test.py b/examples/online_serving/whisper_wer_test.py new file mode 100644 index 0000000000000..449385a2afb34 --- /dev/null +++ b/examples/online_serving/whisper_wer_test.py @@ -0,0 +1,230 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Evaluate Transcription API correctness by computing Word Error Rate (WER) +on a given ASR dataset. When provided, it will also compare the WER against +a baseline. 
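+
+Illustrative invocation (assumes the whisper_server API from
+vllm/entrypoints/whisper_server/api_server.py is already running on the
+default localhost:8000; the dataset repo/name below are only an example):
+    python whisper_wer_test.py \
+        --dataset-repo librispeech_asr --dataset-name clean \
+        --n-examples 64 --max-concurrent-request 16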
+""" + +import asyncio +import json +import time +from argparse import ArgumentParser +from statistics import mean, median +from typing import List, Optional + +import aiohttp +import librosa +import numpy as np +import torch +from datasets import load_dataset +from evaluate import load +from transformers import AutoTokenizer, PreTrainedTokenizer + +WHISPER_SAMPLING_RATE = 16000 + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +async def iterate_response(response) -> str: + output_text = "" + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") + if chunk != "[DONE]": + output_text += json.loads(chunk)["text"] + return output_text + + +async def _transcribe_from_waveform(base_url: str, waveform: np.array, + sr: int) -> str: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + + assert sr == WHISPER_SAMPLING_RATE + url = f"{base_url}/generate_from_waveform" + data = {"waveform_bytes": waveform.tobytes(), "sampling_rate": str(sr)} + async with session.post(url, data=data) as response: + return await iterate_response(response) + + +async def transcribe(tokenizer: PreTrainedTokenizer, sem: asyncio.Semaphore, + base_url: str, waveform: np.ndarray, sampling_rate: int, + reference: str): + + # Use semaphore to limit concurrent requests. + async with sem: + start = time.perf_counter() + transcribed_text = await _transcribe_from_waveform( + base_url=base_url, + waveform=waveform, + sr=sampling_rate, + ) + latency = time.perf_counter() - start + + num_tokens = len( + tokenizer(transcribed_text, add_special_tokens=False).input_ids) + + # Normalize *english* output/reference for evaluation. 
+ out = tokenizer.normalize(transcribed_text) + ref = tokenizer.normalize(reference) + return latency, num_tokens, out, ref + + +async def process_dataset(model_name, + data, + concurrent_request, + base_url="http://localhost:8000"): + tokenizer = AutoTokenizer.from_pretrained(model_name) + + sem = asyncio.Semaphore(concurrent_request) + tasks: List[asyncio.Task] = [] + for sample in data: + waveform = sample["audio"]["array"].astype(np.float32) + sampling_rate = sample["audio"]["sampling_rate"] + reference = sample["text"] + assert sampling_rate == WHISPER_SAMPLING_RATE + + task = asyncio.create_task( + transcribe(tokenizer, sem, base_url, waveform, sampling_rate, + reference)) + tasks.append(task) + return await asyncio.gather(*tasks) + + +def print_performance_metrics(results, total_time): + latencies = [res[0] for res in results] + total_tokens = sum([res[1] for res in results]) + + total = len(results) + print(f"Total Requests: {total}") + print(f"Successful Requests: {len(latencies)}") + print(f"Average Latency: {mean(latencies):.4f} seconds") + print(f"Median Latency: {median(latencies):.4f} seconds") + perc = sorted(latencies)[int(len(latencies) * 0.95) - 1] + print(f"95th Percentile Latency: {perc:.4f} seconds") + # Throughput + req_throughput = len(latencies) / total_time + print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s") + throughput = total_tokens / total_time + print(f"Estimated Throughput: {throughput:.2f} tok/s") + + +def add_duration(sample): + y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] + sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 + return sample + + +def to_float32(sample): + sample["audio"]["array"] = sample["audio"]["array"].astype(np.float32) + return sample + + +def load_hf_dataset(dataset_repo: str, + dataset_name: str, + split="validation", + **hf_kwargs): + ## Load and filter the dataset + dataset = load_dataset(dataset_repo, + dataset_name, + split=split, + **hf_kwargs) + if 'duration_ms' not in dataset[0]: + # compute duration to filter + dataset = dataset.map(add_duration) + + # Whisper max supported duration + dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) + + return dataset + + +def run_evaluation(model: str, + dataset, + n_examples: int = -1, + max_concurrent_reqs: Optional[int] = None, + print_metrics: bool = True): + if n_examples > 0: + dataset = dataset.select(range(n_examples)) + + # Warmup + _ = asyncio.run( + process_dataset(model, dataset.select(range(1)), max_concurrent_reqs)) + + start = time.perf_counter() + results = asyncio.run(process_dataset(model, dataset, max_concurrent_reqs)) + end = time.perf_counter() + total_time = end - start + print(f"Total Test Time: {total_time:.4f} seconds") + if print_metrics: + print_performance_metrics(results, total_time) + # Compute WER + predictions = [res[2] for res in results] + references = [res[3] for res in results] + wer = load("wer") + wer_score = 100 * wer.compute(references=references, + predictions=predictions) + print("WER:", wer_score) + return wer_score + + +if __name__ == "__main__": + args = ArgumentParser() + # alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo". 
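+    # NOTE: the model is only used client-side, to build the tokenizer for
+    # token counting and text normalization; transcription itself is performed
+    # by whatever model the whisper_server API was launched with, so the two
+    # should match.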
+ args.add_argument("-m", + "--model-name", + type=str, + help="Name of the ASR model to evaluate.", + default="openai/whisper-large-v3") + args.add_argument("-dr", + "--dataset-repo", + type=str, + help="Path/repo of the hf asr dataset to test on.") + args.add_argument("-dn", + "--dataset-name", + type=str, + help="Name of the hf asr dataset to test on.") + args.add_argument("--n-examples", + type=int, + help="Limit the number of examples to evaluate on.", + default=-1) + args.add_argument( + "--max-concurrent-request", + type=int, + help="Limit the number of requests sent to the server at the same time" + ) + args.add_argument("--expected-wer", + type=float, + help="Expected WER to compare against.") + args.add_argument( + "--extra", + nargs="*", + help="Extra keyword arguments (key=value pairs) to be passed " + "to hf `load_dataset`") + args = args.parse_args() + + extra_kwargs = {} + if args.extra: + for item in args.extra: + key, value = item.split("=", 1) + extra_kwargs[key] = value + + print("Running evaluation with args", vars(args)) + dataset = load_hf_dataset(args.dataset_repo, args.dataset_name, + **extra_kwargs) + + if not args.max_concurrent_request: + # No max concurrency + args.max_concurrent_request = args.n_examples if args.n_examples > 0\ + else len(dataset) + + wer = run_evaluation(args.model_name, dataset, args.n_examples, + args.max_concurrent_request) + if args.expected_wer: + torch.testing.assert_close(wer, + args.expected_wer, + atol=1e-1, + rtol=1e-2) diff --git a/vllm/entrypoints/whisper_server/__init__.py b/vllm/entrypoints/whisper_server/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/whisper_server/api_server.py b/vllm/entrypoints/whisper_server/api_server.py new file mode 100644 index 0000000000000..59c7dd0ca4662 --- /dev/null +++ b/vllm/entrypoints/whisper_server/api_server.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import multiprocessing +import signal +from argparse import Namespace +from contextlib import asynccontextmanager +from functools import partial +from typing import Annotated, Any, AsyncGenerator, AsyncIterator + +import numpy as np +import uvloop +from fastapi import FastAPI, Form, Request, UploadFile +from fastapi.responses import Response, StreamingResponse +from pydantic import BaseModel + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation +from vllm.entrypoints.whisper_server.helper import (load_audio_from_bytes, + validate_length) +from vllm.logger import init_logger +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, + random_uuid, set_ulimit) +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger("vllm.entrypoints.api_server_whisper") + +TRANSCRIBE_PROMPT = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" +SAMPLING_RATE = 16000 +TIMEOUT_KEEP_ALIVE = 5 # seconds. 
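+
+# Both endpoints stream the transcription back as server-sent events: each
+# event is a line of the form `data: {"text": "..."}` carrying a text delta,
+# followed by a final `data: [DONE]` sentinel (this is what `iterate_response`
+# in examples/online_serving/whisper_api_client.py parses). An illustrative
+# request against a locally running server (audio.wav is a placeholder):
+#   curl -N -F "file=@audio.wav" http://localhost:8000/generate_from_file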
+ +app = FastAPI() + + +def format_prompt(waveform: np.ndarray, sampling_rate: int): + assert sampling_rate == SAMPLING_RATE + return { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (waveform, sampling_rate), + } + }, + "decoder_prompt": TRANSCRIBE_PROMPT, + } + + +class TranscriptionResponse(BaseModel): + """The response object from the transcription.""" + text: str + + +class TranscribeFromFile(BaseModel): + """The audio file (flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm).""" + file: UploadFile + + async def to_prompt(self): + audio_bytes = await self.file.read() + audio_data = load_audio_from_bytes(audio_bytes, SAMPLING_RATE) + validate_length(audio_data) + return format_prompt(audio_data, SAMPLING_RATE) + + +class TranscribeFromWaveform(BaseModel): + """Numpy array of audio waveform to be transcribed.""" + + waveform_bytes: UploadFile + sampling_rate: Annotated[str, Form()] + + async def to_prompt(self): + waveform = np.frombuffer(await self.waveform_bytes.read(), + dtype=np.float32) + sampling_rate = int(self.sampling_rate) + if sampling_rate != SAMPLING_RATE: + raise ValueError( + f"Model uses sampling rate of {SAMPLING_RATE}, but got " + f"sampling_rate = {sampling_rate}.") + return format_prompt(waveform, SAMPLING_RATE) + + +@app.post("/generate_from_waveform") +async def generate_from_waveform(data: Annotated[TranscribeFromWaveform, + Form()], + raw_request: Request): + """Transcribe from a waveform.""" + + prompt = await data.to_prompt() + return await _generate(prompt, raw_request=raw_request) + + +@app.post("/generate_from_file") +async def generate_from_file(data: Annotated[TranscribeFromFile, + Form()], raw_request: Request): + """Transcribe from a file.""" + + prompt = await data.to_prompt() + return await _generate(prompt, raw_request=raw_request) + + +@with_cancellation +async def _generate(prompt, raw_request: Request) -> Response: + + sampling_params = SamplingParams(temperature=0, + max_tokens=440, + output_kind=RequestOutputKind.DELTA) + request_id = random_uuid() + + engine = raw_request.app.state.engine + results_generator = engine.generate(prompt, sampling_params, request_id) + + async def stream_results() -> AsyncGenerator[str, None]: + async for request_output in results_generator: + assert len(request_output.outputs) == 1 + chunk = TranscriptionResponse(text=request_output.outputs[0].text) + response_json = chunk.model_dump_json(exclude_unset=False) + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse(stream_results()) + + +@asynccontextmanager +async def build_engine( + engine_args: AsyncEngineArgs, ) -> AsyncIterator[MQLLMEngineClient]: + + # Select random path for IPC. + ipc_path = get_open_zmq_ipc_path() + context = multiprocessing.get_context("spawn") + engine_alive = multiprocessing.Value('b', True, lock=False) + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path, engine_alive)) + engine_process.start() + engine_pid = engine_process.pid + assert engine_pid is not None, "Engine process failed to start." + logger.info("Started engine process with PID %d", engine_pid) + + # Build RPCClient, which conforms to EngineClient Protocol. 
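+    # The client is built in a thread-pool executor since its construction may
+    # block; `setup()` is then retried until the engine process is ready, or
+    # aborted with a RuntimeError if the process has died in the meantime.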
+ engine_config = engine_args.create_engine_config() + build_client = partial(MQLLMEngineClient, ipc_path, engine_config, + engine_pid) + mq_engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_client) + try: + while True: + try: + await mq_engine_client.setup() + break + except TimeoutError: + if (not engine_process.is_alive() or not engine_alive.value): + raise RuntimeError( + "Engine process failed to start. See stack " + "trace for the root cause.") from None + + yield mq_engine_client # type: ignore[misc] + finally: + # Ensure rpc server process was terminated + engine_process.terminate() + + # Close all open connections to the backend + mq_engine_client.close() + + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + engine_process.kill() + + +async def run_server(args: Namespace, **uvicorn_kwargs: Any) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + engine_args = AsyncEngineArgs.from_cli_args(args) + async with build_engine(engine_args) as engine: + # Build App. + app.state.engine = engine + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + await shutdown_task + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/whisper_server/helper.py b/vllm/entrypoints/whisper_server/helper.py new file mode 100644 index 0000000000000..c12ce809b4305 --- /dev/null +++ b/vllm/entrypoints/whisper_server/helper.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/openai/whisper/blob/fc5ded7d9045c693692f13853857c3f8baea3a7b/whisper/audio.py +# MIT License + +import subprocess +from subprocess import CalledProcessError + +import numpy as np + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk +N_FRAMES = exact_div(N_SAMPLES, + HOP_LENGTH) # 3000 frames in a mel spectrogram input + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, + N_SAMPLES_PER_TOKEN) # 20ms per audio token + + +def load_audio_from_bytes(audio_bytes: bytes, sample_rate: int = SAMPLE_RATE): + """ + Read bytes from audio file as mono waveform, resampling as necessary + + Parameters + ---------- + audio_bytes: bytes + sample_rate: int + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + + # This launches a subprocess to decode audio while down-mixing + # and resampling as necessary. Requires the ffmpeg CLI in PATH. 
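+    # ffmpeg reads the raw bytes from stdin ("-i -") and writes signed 16-bit
+    # little-endian PCM ("-f s16le", "-acodec pcm_s16le"), down-mixed to mono
+    # ("-ac 1") and resampled to `sample_rate` ("-ar"), onto stdout. The int16
+    # samples are then scaled by 1/32768 into float32 in [-1.0, 1.0].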
+ cmd = [ + "ffmpeg", "-nostdin", "-threads", "0", "-i", "-", "-f", "s16le", "-ac", + "1", "-acodec", "pcm_s16le", "-ar", + str(sample_rate), "-" + ] + try: + process = subprocess.Popen(cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + stdout, _ = process.communicate(input=audio_bytes) + except CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(stdout, np.int16).flatten().astype( + np.float32) / 32768.0 + + +def validate_length(array, max_length: int = N_SAMPLES, axis: int = -1): + audio_length = array.shape[axis] + if audio_length > max_length: + raise ValueError( + f"Length of audio {audio_length} is bigger than the maximum " + f"length of {max_length} = MAX_LENGTH * SAMPLING_RATE") diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 0a3011d361013..5e2c439529a41 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -598,11 +598,15 @@ def input_processor_for_whisper(ctx: InputContext, inputs): audio, orig_sr = multi_modal_data["audio"] processor = cached_get_processor(ctx.model_config.model) target_sr = processor.feature_extractor.sampling_rate - audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) + # NOTE: resampling is expensive, so skip it if the audio data + # sent to the Engine is already in Whisper's SAMPLE_RATE=16000. + if orig_sr != target_sr: + audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) multi_modal_data["audio"] = (audio, target_sr) # Pre-allocate placeholder tokens in encoder sequence num_tokens = get_max_whisper_audio_tokens(ctx) inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens + return inputs @@ -623,6 +627,9 @@ def input_mapper_for_whisper( audios = [audio for audio, _ in multi_modal_data] + # 1) Pad out with empty audio to N_SAMPLES=480000 (30s * SAMPLE_RATE) + # 2) Apply log_mel_spectrogram to padded (N_MEL_FILTERS=128, N_FRAMES=3000) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py#L175 # noqa: E501 kwargs = processor(audios, sampling_rate=sampling_rate, return_tensors="pt")