From 8d6c9961d11956d38372d78d63353617b8b75d01 Mon Sep 17 00:00:00 2001
From: nguyenhoangthuan99
Date: Thu, 26 Dec 2024 08:33:03 +0700
Subject: [PATCH] init whispervq deployment

---
 .gitignore                                    |   4 +-
 deployments/whispervq/metadata.yml            |   3 +
 deployments/whispervq/model.yml               |  17 ++
 deployments/whispervq/requirements.cuda.txt   |  24 +++
 deployments/whispervq/requirements.txt        |  23 +++
 deployments/whispervq/src/app.py              |  62 +++++++
 .../common/abstract/controller_abstract.py    |  18 ++
 .../src/common/constant/fastapi_constant.py   |   8 +
 .../src/common/utility/convert_utility.py     |  34 ++++
 .../src/common/utility/generator_utility.py   |  13 ++
 .../src/common/utility/logger_utility.py      |  63 +++++++
 .../src/services/audio/audio_controller.py    |  23 +++
 .../src/services/audio/audio_interface.py     |  11 ++
 .../src/services/audio/audio_model.py         |  35 ++++
 .../audio/implementation/audio_service.py     | 165 ++++++++++++++++++
 .../audio/implementation/audio_utils.py       | 151 ++++++++++++++++
 .../src/services/health/health_controller.py  |  23 +++
 .../src/services/health/health_interface.py   |  10 ++
 .../src/services/health/health_model.py       |  11 ++
 .../health/implementation/health_service.py   |  16 ++
 .../model/implementation/model_service.py     |  59 +++++++
 .../src/services/model/model_controller.py    |   0
 .../src/services/model/model_interface.py     |  10 ++
 .../src/services/model/model_model.py         |  19 ++
 deployments/whispervq/src/variables/.env      |   4 +
 .../src/variables/whisper_variable.py         |  13 ++
 26 files changed, 818 insertions(+), 1 deletion(-)
 create mode 100644 deployments/whispervq/metadata.yml
 create mode 100644 deployments/whispervq/model.yml
 create mode 100644 deployments/whispervq/requirements.cuda.txt
 create mode 100644 deployments/whispervq/requirements.txt
 create mode 100644 deployments/whispervq/src/app.py
 create mode 100644 deployments/whispervq/src/common/abstract/controller_abstract.py
 create mode 100644 deployments/whispervq/src/common/constant/fastapi_constant.py
 create mode 100644 deployments/whispervq/src/common/utility/convert_utility.py
 create mode 100644 deployments/whispervq/src/common/utility/generator_utility.py
 create mode 100644 deployments/whispervq/src/common/utility/logger_utility.py
 create mode 100644 deployments/whispervq/src/services/audio/audio_controller.py
 create mode 100644 deployments/whispervq/src/services/audio/audio_interface.py
 create mode 100644 deployments/whispervq/src/services/audio/audio_model.py
 create mode 100644 deployments/whispervq/src/services/audio/implementation/audio_service.py
 create mode 100644 deployments/whispervq/src/services/audio/implementation/audio_utils.py
 create mode 100644 deployments/whispervq/src/services/health/health_controller.py
 create mode 100644 deployments/whispervq/src/services/health/health_interface.py
 create mode 100644 deployments/whispervq/src/services/health/health_model.py
 create mode 100644 deployments/whispervq/src/services/health/implementation/health_service.py
 create mode 100644 deployments/whispervq/src/services/model/implementation/model_service.py
 create mode 100644 deployments/whispervq/src/services/model/model_controller.py
 create mode 100644 deployments/whispervq/src/services/model/model_interface.py
 create mode 100644 deployments/whispervq/src/services/model/model_model.py
 create mode 100644 deployments/whispervq/src/variables/.env
 create mode 100644 deployments/whispervq/src/variables/whisper_variable.py

diff --git a/.gitignore b/.gitignore
index 3ac4f46..1392450 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,6 @@ old_train.py
 *.ipynb
 __pycache__
 .ipynb_checkpoints
-outputs
\ No newline at end of file
+outputs
+*.pt
+*.model
\ No newline at end of file
diff --git a/deployments/whispervq/metadata.yml b/deployments/whispervq/metadata.yml
new file mode 100644
index 0000000..20501cf
--- /dev/null
+++ b/deployments/whispervq/metadata.yml
@@ -0,0 +1,3 @@
+version: 1
+name: whispervq
+default: fp16
diff --git a/deployments/whispervq/model.yml b/deployments/whispervq/model.yml
new file mode 100644
index 0000000..7e91cb6
--- /dev/null
+++ b/deployments/whispervq/model.yml
@@ -0,0 +1,17 @@
+id: whispervq:fp16
+model: whispervq:fp16
+name: Ichigo WhisperVQ
+version: 1
+
+engine: python-engine
+
+extra_params:
+  device_id: 0
+  package_dir: ""
+
+port: 3348
+script: src/app.py
+log_path: whisper.log
+log_level: INFO
+command:
+  - python
\ No newline at end of file
diff --git a/deployments/whispervq/requirements.cuda.txt b/deployments/whispervq/requirements.cuda.txt
new file mode 100644
index 0000000..f39f80a
--- /dev/null
+++ b/deployments/whispervq/requirements.cuda.txt
@@ -0,0 +1,24 @@
+openai-whisper==20231117
+huggingface_hub
+IPython
+pyarrow
+matplotlib
+librosa
+soundfile
+datasets
+encodec
+boto3
+fire
+vector_quantize_pytorch
+webdataset
+whisperspeech
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.2.0
+torchaudio==2.2.0
+numpy==1.26.4
+fastapi
+uvicorn
+
+python-multipart
+transformers
+psutil
diff --git a/deployments/whispervq/requirements.txt b/deployments/whispervq/requirements.txt
new file mode 100644
index 0000000..2887d96
--- /dev/null
+++ b/deployments/whispervq/requirements.txt
@@ -0,0 +1,23 @@
+openai-whisper==20231117
+huggingface_hub
+IPython
+pyarrow
+matplotlib
+librosa
+soundfile
+datasets
+encodec
+boto3
+fire
+vector_quantize_pytorch
+webdataset
+whisperspeech
+torch==2.2.0
+torchaudio==2.2.0
+numpy==1.26.4
+fastapi
+uvicorn
+
+python-multipart
+transformers
+psutil
diff --git a/deployments/whispervq/src/app.py b/deployments/whispervq/src/app.py
new file mode 100644
index 0000000..26545bb
--- /dev/null
+++ b/deployments/whispervq/src/app.py
@@ -0,0 +1,62 @@
+
+import argparse
+import os
+from contextlib import asynccontextmanager
+from pathlib import Path
+from typing import AsyncGenerator, List
+
+import uvicorn
+from dotenv import load_dotenv
+from fastapi import APIRouter, FastAPI
+
+from common.utility.logger_utility import LoggerUtility
+from services.audio.audio_controller import AudioController
+from services.audio.implementation.audio_service import AudioService
+from services.health.health_controller import HealthController
+
+
+@asynccontextmanager
+async def application_lifecycle(app: FastAPI) -> AsyncGenerator[None, None]:
+    try:
+        AudioService.get_audio_service()
+    except Exception as e:
+        LoggerUtility.get_logger().error(f"Error initializing audio service: {e}")
+        raise e
+    yield
+
+
+def create_app() -> FastAPI:
+    routes: List[APIRouter] = [
+        HealthController(),
+        AudioController()
+    ]
+    app = FastAPI(lifespan=application_lifecycle)
+    for route in routes:
+        app.include_router(route)
+    return app
+
+
+def parse_argument():
+    parser = argparse.ArgumentParser(description="WhisperVQ Application")
+    parser.add_argument('--log_path', type=str,
+                        default='whisper.log', help='The log file path')
+    parser.add_argument('--log_level', type=str, default='INFO',
+                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='The log level')
+    parser.add_argument('--port', type=int, default=3348,
+                        help='The port to run the WhisperVQ app on')
+    parser.add_argument('--device_id', type=str, default="0",
+                        help='The device id to run WhisperVQ on')
+    parser.add_argument('--package_dir', type=str, default="",
+                        help='The package-dir to be extended to sys.path')
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_argument()
+    LoggerUtility.init_logger(__name__, args.log_level, args.log_path)
+
+    env_path = Path(os.path.dirname(os.path.realpath(__file__))) / "variables" / ".env"
+    load_dotenv(dotenv_path=env_path)
+    app: FastAPI = create_app()
+    print("Server is running at: 0.0.0.0:", args.port)
+    uvicorn.run(app=app, host="0.0.0.0", port=args.port)
diff --git a/deployments/whispervq/src/common/abstract/controller_abstract.py b/deployments/whispervq/src/common/abstract/controller_abstract.py
new file mode 100644
index 0000000..b10e7ef
--- /dev/null
+++ b/deployments/whispervq/src/common/abstract/controller_abstract.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+
+from fastapi import APIRouter
+
+
+class ControllerAbstract(APIRouter, ABC):
+    def __init__(self, prefix: str):
+        super().__init__(prefix=prefix)
+        self._setup_services()
+        self._setup_routes()
+
+    @abstractmethod
+    def _setup_services(self) -> None:
+        pass
+
+    @abstractmethod
+    def _setup_routes(self) -> None:
+        pass
diff --git a/deployments/whispervq/src/common/constant/fastapi_constant.py b/deployments/whispervq/src/common/constant/fastapi_constant.py
new file mode 100644
index 0000000..2412232
--- /dev/null
+++ b/deployments/whispervq/src/common/constant/fastapi_constant.py
@@ -0,0 +1,8 @@
+class ContentTypeConstant:
+    application_json: str = "application/json"
+    text_event_stream: str = "text/event-stream"
+    audio_wav: str = "audio/wav"
+
+class RestConstant:
+    post: str = "POST"
+    get: str = "GET"
\ No newline at end of file
diff --git a/deployments/whispervq/src/common/utility/convert_utility.py b/deployments/whispervq/src/common/utility/convert_utility.py
new file mode 100644
index 0000000..83008b1
--- /dev/null
+++ b/deployments/whispervq/src/common/utility/convert_utility.py
@@ -0,0 +1,34 @@
+import base64
+
+
+class ConvertUtility:
+    @staticmethod
+    def encode_to_base64(byte_data: bytes) -> str:
+
+        try:
+            base64_encoded = base64.b64encode(byte_data).decode('utf-8')
+            return base64_encoded
+        except IOError as e:
+            raise IOError(f"Error reading audio file: {e}")
+
+    @staticmethod
+    def decode_base64(
+        base64_string: str
+    ) -> bytes:
+        """
+        Decode a base64 string to audio bytes.
+
+        Args:
+            base64_string (str): Base64 encoded string
+
+        Returns:
+            bytes: Decoded audio bytes
+
+        Raises:
+            ValueError: If the base64 string is invalid
+        """
+        try:
+            audio_bytes = base64.b64decode(base64_string)
+            return audio_bytes
+        except base64.binascii.Error as e:
+            raise ValueError(f"Invalid base64 string: {e}")
diff --git a/deployments/whispervq/src/common/utility/generator_utility.py b/deployments/whispervq/src/common/utility/generator_utility.py
new file mode 100644
index 0000000..3557fe9
--- /dev/null
+++ b/deployments/whispervq/src/common/utility/generator_utility.py
@@ -0,0 +1,13 @@
+import hashlib
+import uuid
+
+
+class GeneratorUtility:
+
+    @staticmethod
+    def generate_uuid_v4(seed: str = "") -> uuid.UUID:
+        if not seed:
+            return uuid.uuid4()
+        hash_object: hashlib._Hash = hashlib.sha256(seed.encode('utf-8'))
+        hash_bytes: bytes = hash_object.digest()[:16]
+        return uuid.UUID(bytes=hash_bytes, version=4)
\ No newline at end of file
diff --git a/deployments/whispervq/src/common/utility/logger_utility.py b/deployments/whispervq/src/common/utility/logger_utility.py
new file mode 100644
index 0000000..2c8bf1a
--- /dev/null
+++ b/deployments/whispervq/src/common/utility/logger_utility.py
@@ -0,0 +1,63 @@
+import logging
+from enum import Enum
+from typing import ClassVar, Optional
+
+from uvicorn.config import LOGGING_CONFIG
+
+
+class LoggerUtility:
+    """
+    This class is used to create a logger object.
+    """
+    _logger: ClassVar[Optional[logging.Logger]] = None
+
+    class LogLevel(Enum):
+        """
+        This class is used to define the log level.
+        """
+        DEBUG = logging.DEBUG
+        INFO = logging.INFO
+        WARNING = logging.WARNING
+        ERROR = logging.ERROR
+        CRITICAL = logging.CRITICAL
+
+    @staticmethod
+    def init_logger(name: str, log_level: LogLevel = LogLevel.INFO, log_file: Optional[str] = None) -> None:
+        """
+        This method is used to initialize the logger.
+        """
+        if LoggerUtility._logger is None:
+            LoggerUtility._logger = logging.getLogger(name)
+            LoggerUtility._logger.setLevel(log_level)
+            formatter = logging.Formatter(
+                '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+            if log_file:
+                file_handler = logging.FileHandler(log_file)
+                file_handler.setFormatter(formatter)
+                LoggerUtility._logger.addHandler(file_handler)
+            console_handler = logging.StreamHandler()
+            console_handler.setFormatter(formatter)
+            LoggerUtility._logger.addHandler(console_handler)
+
+            LOGGING_CONFIG["handlers"]["default"] = {
+                "class": "logging.FileHandler",
+                "filename": log_file,
+                "formatter": "default"
+            }
+            LOGGING_CONFIG["handlers"]["access"] = {
+                "class": "logging.FileHandler",
+                "filename": log_file,
+                "formatter": "access"
+            }
+            LOGGING_CONFIG["loggers"]["uvicorn.error"]["level"] = log_level
+            LOGGING_CONFIG["loggers"]["uvicorn.access"]["level"] = log_level
+
+    @staticmethod
+    def get_logger() -> logging.Logger:
+        """
+        This method returns the initialized logger object.
+ """ + if LoggerUtility._logger is None: + raise (Exception("Logger is not initialized.")) + else: + return LoggerUtility._logger diff --git a/deployments/whispervq/src/services/audio/audio_controller.py b/deployments/whispervq/src/services/audio/audio_controller.py new file mode 100644 index 0000000..89fc3b9 --- /dev/null +++ b/deployments/whispervq/src/services/audio/audio_controller.py @@ -0,0 +1,23 @@ + +from common.abstract.controller_abstract import ControllerAbstract +from common.constant.fastapi_constant import RestConstant +from services.audio.audio_model import AudioModel +from services.audio.implementation.audio_service import AudioService + + +class AudioController(ControllerAbstract): + + _prefix = "/inference" + + def __init__(self): + super().__init__(prefix=self._prefix) + + def _setup_routes(self): + self.add_api_route("", self.inference, + methods=[RestConstant.post]) + + def _setup_services(self): + self.audio_service = AudioService.get_audio_service() + + async def inference(self, req: AudioModel.Request) -> AudioModel.Response: + return await self.audio_service.inference(req) diff --git a/deployments/whispervq/src/services/audio/audio_interface.py b/deployments/whispervq/src/services/audio/audio_interface.py new file mode 100644 index 0000000..1788f1c --- /dev/null +++ b/deployments/whispervq/src/services/audio/audio_interface.py @@ -0,0 +1,11 @@ + +from abc import ABC, abstractmethod + +from services.audio.audio_model import AudioModel + + +class AudioInterface(ABC): + + @abstractmethod + async def inference(self, req: AudioModel.Request) -> AudioModel.Response: + pass diff --git a/deployments/whispervq/src/services/audio/audio_model.py b/deployments/whispervq/src/services/audio/audio_model.py new file mode 100644 index 0000000..2ada906 --- /dev/null +++ b/deployments/whispervq/src/services/audio/audio_model.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass +from enum import Enum + + +class AudioFormat(str, Enum): + WAV = "wav" # Supported by both backends + MP3 = "mp3" # Supported by ffmpeg + FLAC = "flac" # Supported by both + AAC = "aac" # Supported by ffmpeg + OGG = "ogg" # Supported by ffmpeg + OPUS = "opus" # Supported by ffmpeg + PCM = "pcm" # Raw PCM data + + +class AudioModel: + FORMAT_BACKENDS = { + AudioFormat.WAV: ["soundfile", "ffmpeg"], + AudioFormat.MP3: ["ffmpeg"], + AudioFormat.FLAC: ["soundfile", "ffmpeg"], + AudioFormat.AAC: ["ffmpeg"], + AudioFormat.OGG: ["ffmpeg"], + AudioFormat.OPUS: ["ffmpeg"], + AudioFormat.PCM: ["soundfile"] + } + + @dataclass + class Request: + data: str + format: AudioFormat = "wav" + + @dataclass + class Response: + tokens: str + sample_rate: int + format: AudioFormat diff --git a/deployments/whispervq/src/services/audio/implementation/audio_service.py b/deployments/whispervq/src/services/audio/implementation/audio_service.py new file mode 100644 index 0000000..b1563f2 --- /dev/null +++ b/deployments/whispervq/src/services/audio/implementation/audio_service.py @@ -0,0 +1,164 @@ +import os +import tempfile +from pathlib import Path +from typing import Tuple + +import torch +import torchaudio +from fastapi import HTTPException +from huggingface_hub import hf_hub_download + +from common.utility.convert_utility import ConvertUtility +from common.utility.logger_utility import LoggerUtility +from services.audio.audio_interface import AudioInterface +from services.audio.audio_model import AudioFormat, AudioModel +from services.audio.implementation.audio_utils import CustomRQBottleneckTransformer +from 
+from variables.whisper_variable import WhisperVariable
+
+
+class AudioService(AudioInterface):
+    _audio_service = None
+
+    @staticmethod
+    def get_audio_service():
+        if AudioService._audio_service is None:
+            AudioService._audio_service = AudioService()
+        return AudioService._audio_service
+
+    def __init__(self,):
+        self.model_service = ModelService()
+        self.logger = LoggerUtility.get_logger()
+        self.whisper_variable = WhisperVariable()
+        self.available_backends = torchaudio.list_audio_backends()
+        self.download_folder = Path(os.path.dirname(os.path.dirname(os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))))/"downloads"
+        self.load_vq_model()
+
+    def load_vq_model(self):
+        self.has_ffmpeg = "ffmpeg" in self.available_backends
+        if not self.has_ffmpeg:
+            self.logger.warning(
+                "FFMPEG backend not available. Some formats may not be supported")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if not os.path.exists(self.download_folder/self.whisper_variable.whisper_model_path):
+            hf_hub_download(
+                repo_id=self.whisper_variable.repo_id,
+                filename=self.whisper_variable.whisper_model_path,
+                local_dir=self.download_folder,
+            )
+        self.vq_model = CustomRQBottleneckTransformer.load_vq_only(
+            self.download_folder / self.whisper_variable.whisper_model_path
+        ).to(device)
+        self.vq_model.load_encoder(device)
+        self.vq_model.eval()
+
+    def _get_best_backend(self, format: AudioFormat) -> str:
+        """Determine the best backend for the given format"""
+        supported_backends = AudioModel.FORMAT_BACKENDS[format]
+        for backend in supported_backends:
+            if backend in self.available_backends:
+                return backend
+        raise ValueError(f"No available backend supports format {format}")
+
+    def load_audio(
+        self,
+        file_obj: bytes,
+        format: AudioFormat,
+        target_sr: int = 16000
+    ) -> Tuple[torch.Tensor, int]:
+        """
+        Load audio from bytes object with format handling
+
+        Args:
+            file_obj: Audio file bytes
+            format: Audio format enum
+            target_sr: Target sample rate (default: 16000)
+
+        Returns:
+            Tuple[torch.Tensor, int]: Audio tensor and sample rate
+        """
+        try:
+            # Get appropriate backend
+            backend = self._get_best_backend(format)
+            torchaudio.set_audio_backend(backend)
+            self.logger.info(f"Using {backend} backend for {format} format")
+
+            if format == AudioFormat.PCM:
+                # Handle raw PCM
+                wav = torch.frombuffer(file_obj, dtype=torch.int16)
+                wav = wav.float() / 32768.0  # Normalize to [-1, 1]
+                wav = wav.unsqueeze(0)  # Add channel dimension
+                sr = target_sr
+            else:
+                # For formats that might need ffmpeg processing
+                if os.name == "nt":  # for windows
+                    wav, sr = torchaudio.load(io.BytesIO(file_obj))
+                else:
+                    with tempfile.NamedTemporaryFile(suffix=f".{format}") as temp_file:
+                        # Write bytes to temporary file
+                        temp_file.write(file_obj)
+                        temp_file.flush()
+
+                        # Load audio
+                        wav, sr = torchaudio.load(temp_file.name)
+
+            # Convert to mono if stereo
+            if wav.shape[0] > 1:
+                wav = torch.mean(wav, dim=0, keepdim=True)
+
+            # Resample if needed
+            if sr != target_sr:
+                wav = torchaudio.functional.resample(wav, sr, target_sr)
+                sr = target_sr
+
+            return wav, sr
+
+        except Exception as e:
+            self.logger.error(f"Error loading audio: {e}")
+            raise HTTPException(
+                status_code=400,
+                detail=f"Error processing {format} audio: {str(e)}"
+            )
+
+    def get_format_info(self) -> dict:
+        """Get information about supported formats"""
+        supported_formats = {}
+        for format in AudioFormat:
+            try:
+                backend = self._get_best_backend(format)
+                supported_formats[format] = {
+                    "supported": True,
+                    "backend": backend
+                }
+            except ValueError:
+                supported_formats[format] = {
+                    "supported": False,
+                    "backend": None
+                }
+        return supported_formats
+
+    async def inference(self, req: AudioModel.Request) -> AudioModel.Response:
+        try:
+            wav, sr = self.load_audio(ConvertUtility.decode_base64(req.data), req.format)
+
+            # Ensure we're using CUDA if available
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            wav = wav.to(device)
+
+            # Generate tokens
+            with torch.no_grad():
+                codes = self.vq_model.encode_audio(wav)
+                codes = codes[0].cpu().tolist()
+
+            # Format result
+            result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
+
+            return AudioModel.Response(tokens=f'<|sound_start|>{result}<|sound_end|>', sample_rate=sr, format=req.format)
+
+        except Exception as e:
+            self.logger.error(f"Error processing request: {e}")
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error processing request: {str(e)}"
+            )
diff --git a/deployments/whispervq/src/services/audio/implementation/audio_utils.py b/deployments/whispervq/src/services/audio/implementation/audio_utils.py
new file mode 100644
index 0000000..082d493
--- /dev/null
+++ b/deployments/whispervq/src/services/audio/implementation/audio_utils.py
@@ -0,0 +1,151 @@
+import io
+import os
+from pathlib import Path
+from typing import Any, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+import whisper
+from huggingface_hub import hf_hub_download
+from whisper.model import AudioEncoder, ModelDimensions
+from whisperspeech.vq_stoks import RQBottleneckTransformer, Tunables
+
+from services.model.implementation.model_service import ModelService
+from variables.whisper_variable import WhisperVariable
+
+
+class CustomWhisperEncoder(nn.Module):
+    """
+    Lightweight wrapper that only loads the AudioEncoder part of Whisper
+    """
+
+    def __init__(self, name: str, device: str = None, download_root: str = None, in_memory: bool = False,):
+        super().__init__()
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.whisper_variable = WhisperVariable()
+        checkpoint = self.download(download_root, name, in_memory, device)
+        dims = ModelDimensions(**checkpoint["dims"])
+        self.encoder = AudioEncoder(
+            dims.n_mels,
+            dims.n_audio_ctx,
+            dims.n_audio_state,
+            dims.n_audio_head,
+            dims.n_audio_layer,
+        )
+
+        self.encoder.load_state_dict(checkpoint["model_state_dict"])
+
+        if device:
+            self.to(device)
+
+        self.eval()
+
+    def download(self, download_root: str, name: str, in_memory: bool, device: str) -> Any:
+        if download_root is None:
+            download_root = Path(os.path.dirname(os.path.dirname(os.path.dirname(
+                os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))))/"downloads"
+
+        if name in self.whisper_variable.HF_MODELS:
+            checkpoint_file = ModelService.download_encoder(
+                self.whisper_variable.HF_MODELS[name], download_root, in_memory)
+        elif os.path.isfile(name):
+            checkpoint_file = open(name, "rb").read() if in_memory else name
+        else:
+            raise RuntimeError(
+                f"Model {name} not found; available models = {self.available_models()}"
+            )
+
+        # Load weights
+        with (
+            io.BytesIO(checkpoint_file) if in_memory else open(
+                checkpoint_file, "rb")
+        ) as fp:
+            checkpoint = torch.load(fp, map_location=device)
+        del checkpoint_file
+        return checkpoint
+
+    def available_models(self) -> List[str]:
+        """Returns the names of available models"""
+        return list(self.whisper_variable.HF_MODELS.keys())
+
+    def forward(self, mel: torch.Tensor):
+        return self.encoder(mel)
+
+
+class CustomRQBottleneckTransformer(RQBottleneckTransformer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def load_vq_only(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
+                     repo_id=None, filename=None, local_filename=None):
+        if repo_id is None and filename is None and local_filename is None:
+            if ":" in str(ref):
+                repo_id, filename = ref.split(":", 1)
+            else:
+                local_filename = ref
+        if not local_filename:
+            local_filename = hf_hub_download(
+                repo_id=repo_id, filename=filename)
+
+        # Load the spec
+        spec = torch.load(local_filename)
+
+        # Create instance with minimal required components
+        instance = cls(**spec['config'], tunables=Tunables(
+            **Tunables.upgrade(spec.get('tunables', {}))))
+
+        # Load only necessary state dict entries
+        required_components = {
+            'rq', 'mlp', 'mlp_ln'
+        }
+        filtered_state_dict = {
+            k: v for k, v in spec['state_dict'].items()
+            if any(k.startswith(comp) for comp in required_components)
+        }
+
+        instance.load_state_dict(filtered_state_dict, strict=False)
+        instance.eval()
+        return instance
+
+    def load_encoder(self, device=None):
+        if self.whmodel is not None:
+            return
+        device = device or self.device
+        # Use our custom encoder-only model
+        if self.whmodel is None:
+            encoder = CustomWhisperEncoder(
+                self.whisper_model_name, device=device)
+            self.whmodel = encoder
+        multilingual = not self.whisper_model_name.endswith('.en')
+        self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)
+
+    def optimized_encode_mel(self, mel):
+        assert len(
+            mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
+        self.load_encoder()
+        n = mel.shape[-1]
+        if n > whisper.audio.N_FRAMES:
+            padding = 0
+            padded = mel[:, :, :whisper.audio.N_FRAMES]
+        else:
+            padding = -n % whisper.audio.N_FRAMES
+            padded = F.pad(mel, (0, padding), value=-1.5)
+        # .to(self.whmodel[0].device))#[:,:n//2]
+        embs = self.whmodel.encoder(padded)
+        stoks = self.quantize(embs)
+        if self.tunables.mask_embs:
+            return stoks[:, :n//2//self.downsample]
+        else:
+            return stoks
+
+    # override
+    def encode_audio(self, audio):
+        if isinstance(audio, str):
+            x, sr = torchaudio.load(audio)
+            x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
+            audio = x.unsqueeze(0)
+        return self.optimized_encode_mel(self.log_mel_spectrogram(audio).to(self.device))
diff --git a/deployments/whispervq/src/services/health/health_controller.py b/deployments/whispervq/src/services/health/health_controller.py
new file mode 100644
index 0000000..1474682
--- /dev/null
+++ b/deployments/whispervq/src/services/health/health_controller.py
@@ -0,0 +1,23 @@
+
+from common.abstract.controller_abstract import ControllerAbstract
+from common.constant.fastapi_constant import RestConstant
+from services.health.health_model import ServerStatusModel
+from services.health.implementation.health_service import HealthService
+
+
+class HealthController(ControllerAbstract):
+
+    _prefix = "/health"
+
+    def __init__(self):
+        super().__init__(prefix=self._prefix)
+
+    def _setup_routes(self):
+        self.add_api_route("", self.server_status,
+                           methods=[RestConstant.get])
+
+    def _setup_services(self):
+        self.health_service = HealthService()
+
+    async def server_status(self) -> ServerStatusModel.Response:
+        return await self.health_service.server_status()
diff --git a/deployments/whispervq/src/services/health/health_interface.py b/deployments/whispervq/src/services/health/health_interface.py
new file mode 100644
index 0000000..3e34a3e
--- /dev/null
+++ b/deployments/whispervq/src/services/health/health_interface.py
@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+
+from services.health.health_model import ServerStatusModel
+
+
+class HealthInterface(ABC):
+
+    @abstractmethod
+    async def server_status(self) -> ServerStatusModel.Response:
+        pass
diff --git a/deployments/whispervq/src/services/health/health_model.py b/deployments/whispervq/src/services/health/health_model.py
new file mode 100644
index 0000000..6897e45
--- /dev/null
+++ b/deployments/whispervq/src/services/health/health_model.py
@@ -0,0 +1,11 @@
+from dataclasses import dataclass
+from enum import Enum
+
+class ServerStatusModel:
+    @dataclass()
+    class Response:
+        class StatusEnum(Enum):
+            OK = "OK"
+            ERROR = "ERROR"
+        status: StatusEnum
+        message: str
\ No newline at end of file
diff --git a/deployments/whispervq/src/services/health/implementation/health_service.py b/deployments/whispervq/src/services/health/implementation/health_service.py
new file mode 100644
index 0000000..e8ce6d9
--- /dev/null
+++ b/deployments/whispervq/src/services/health/implementation/health_service.py
@@ -0,0 +1,16 @@
+from logging import Logger
+
+from common.utility.logger_utility import LoggerUtility
+from services.health.health_interface import HealthInterface
+from services.health.health_model import ServerStatusModel
+
+
+class HealthService(HealthInterface):
+
+    def __init__(self,):
+        self.logger: Logger = LoggerUtility.get_logger()
+
+    async def server_status(self) -> ServerStatusModel.Response:
+        status = ServerStatusModel.Response.StatusEnum.OK
+        message = "Still alive!"
+        return ServerStatusModel.Response(status=status, message=message)
diff --git a/deployments/whispervq/src/services/model/implementation/model_service.py b/deployments/whispervq/src/services/model/implementation/model_service.py
new file mode 100644
index 0000000..67a673e
--- /dev/null
+++ b/deployments/whispervq/src/services/model/implementation/model_service.py
@@ -0,0 +1,59 @@
+import os
+import urllib.request
+from pathlib import Path
+
+from huggingface_hub import hf_hub_download
+
+from common.utility.logger_utility import LoggerUtility
+from services.model.model_interface import ModelInterface
+from services.model.model_model import ModelModel
+
+
+class ModelService(ModelInterface):
+    def __init__(self,):
+        self.logger = LoggerUtility.get_logger()
+        self.download_folder = Path(os.path.dirname(os.path.dirname(os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))))/"downloads"
+
+    @staticmethod
+    def download_encoder(url: str, root: str, in_memory: bool):
+        logger = LoggerUtility.get_logger()
+        os.makedirs(root, exist_ok=True)
+        download_target = os.path.join(root, os.path.basename(url))
+
+        if os.path.exists(download_target) and not os.path.isfile(download_target):
+            raise RuntimeError(
+                f"{download_target} exists and is not a regular file")
+        if os.path.isfile(download_target):
+            with open(download_target, "rb") as f:
+                model_bytes = f.read()
+            return model_bytes if in_memory else download_target
+        import ssl
+        ssl._create_default_https_context = ssl._create_unverified_context
+        with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+            total = int(source.info().get("Content-Length"))
+            downloaded = 0
+            count = 0
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+                count += 1
+                output.write(buffer)
+                downloaded += len(buffer)
+                if count % 1000 == 0:
+                    logger.info(f"Downloaded {downloaded}/{total} bytes")
+
+        model_bytes = open(download_target, "rb").read()
+        return model_bytes if in_memory else download_target
+
+    async def download_model(self, req: ModelModel.Request) -> ModelModel.Response:
+        if not os.path.exists(self.download_folder/req.whisper_model_path):
+            hf_hub_download(
+                repo_id=req.repo_id,
+                filename=req.whisper_model_path,
+                local_dir=self.download_folder,
+            )
+        ModelService.download_encoder(req.whisper_encoder_path,
+                                      self.download_folder, False)
+        return ModelModel.Response(status=ModelModel.Response.StatusEnum.OK, message="downloaded successfully!")
diff --git a/deployments/whispervq/src/services/model/model_controller.py b/deployments/whispervq/src/services/model/model_controller.py
new file mode 100644
index 0000000..e69de29
diff --git a/deployments/whispervq/src/services/model/model_interface.py b/deployments/whispervq/src/services/model/model_interface.py
new file mode 100644
index 0000000..83f9583
--- /dev/null
+++ b/deployments/whispervq/src/services/model/model_interface.py
@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+
+from services.model.model_model import ModelModel
+
+
+class ModelInterface(ABC):
+
+    @abstractmethod
+    async def download_model(self, req: ModelModel.Request) -> ModelModel.Response:
+        pass
\ No newline at end of file
diff --git a/deployments/whispervq/src/services/model/model_model.py b/deployments/whispervq/src/services/model/model_model.py
new file mode 100644
index 0000000..6ea1144
--- /dev/null
+++ b/deployments/whispervq/src/services/model/model_model.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass
+from enum import Enum
+
+
+class ModelModel:
+    @dataclass()
+    class Request:
+        whisper_model_path: str
+        whisper_encoder_name: str
+        whisper_encoder_path: str
+        repo_id: str
+
+    @dataclass()
+    class Response:
+        class StatusEnum(Enum):
+            OK = "OK"
+            ERROR = "ERROR"
+        status: StatusEnum
+        message: str
diff --git a/deployments/whispervq/src/variables/.env b/deployments/whispervq/src/variables/.env
new file mode 100644
index 0000000..8dce535
--- /dev/null
+++ b/deployments/whispervq/src/variables/.env
@@ -0,0 +1,4 @@
+WHISPER_MODEL_PATH="whisper-vq-stoks-v3-7lang-fixed.model"
+WHISPER_ENCODER_NAME="medium"
+WHISPER_ENCODER_PATH="https://huggingface.co/jan-hq/WhisperVQ/resolve/main/medium_encoder_only.pt"
+REPO_ID="jan-hq/WhisperVQ"
\ No newline at end of file
diff --git a/deployments/whispervq/src/variables/whisper_variable.py b/deployments/whispervq/src/variables/whisper_variable.py
new file mode 100644
index 0000000..d816c06
--- /dev/null
+++ b/deployments/whispervq/src/variables/whisper_variable.py
@@ -0,0 +1,13 @@
+import os
+
+from dotenv import load_dotenv
+
+
+class WhisperVariable:
+
+    def __init__(self):
+        load_dotenv()
+        self.whisper_model_path: str | None = os.getenv("WHISPER_MODEL_PATH")
+        self.repo_id: str | None = os.getenv("REPO_ID")
+        self.HF_MODELS: dict[str, str | None] = {
+            os.getenv("WHISPER_ENCODER_NAME", "medium"): os.getenv("WHISPER_ENCODER_PATH")}
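--
Usage sketch (not part of the patch): once the service is running on its default port 3348 (see model.yml and app.py), the two routes above can be smoke-tested with a small client. The snippet below is illustrative and assumes things this diff does not provide: a server already started locally, a sample.wav file on disk, and the requests package installed (it is not listed in requirements.txt). The endpoint paths and payload fields follow HealthController, AudioController, and AudioModel above.

    import base64

    import requests

    BASE_URL = "http://localhost:3348"  # default port from model.yml / app.py

    # Liveness probe against HealthController (GET /health).
    health = requests.get(f"{BASE_URL}/health")
    print(health.json())  # expected: {"status": "OK", "message": "Still alive!"}

    # Tokenize audio via AudioController (POST /inference).
    # AudioModel.Request carries base64-encoded audio bytes plus a format tag.
    with open("sample.wav", "rb") as f:  # hypothetical input file
        payload = {
            "data": base64.b64encode(f.read()).decode("utf-8"),
            "format": "wav",
        }
    resp = requests.post(f"{BASE_URL}/inference", json=payload)
    resp.raise_for_status()
    body = resp.json()

    # "tokens" looks like "<|sound_start|><|sound_0012|>...<|sound_end|>"
    print(body["sample_rate"], body["tokens"][:80])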