From 8144798a33cf274cfd6ba5d0b351a2d70f2d9613 Mon Sep 17 00:00:00 2001
From: Jaap Blom
Date: Tue, 1 Oct 2024 16:39:12 +0200
Subject: [PATCH] centralized model download in new function in whisper.py

---
 model_download.py |  1 +
 whisper.py        | 36 +++++++++++++++++++++++++-----------
 whisper_api.py    | 21 +++------------------
 3 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/model_download.py b/model_download.py
index 86a5a79..bed7e18 100644
--- a/model_download.py
+++ b/model_download.py
@@ -6,6 +6,7 @@
 
 
 # makes sure the model is available locally, if not download it from S3, if that fails download from Huggingface
+# FIXME should also check if the correct w_model type is available locally!
 def check_model_availability() -> bool:
     logger = logging.getLogger(__name__)
     if os.path.exists(model_base_dir + "/model.bin"):
diff --git a/whisper.py b/whisper.py
index e2d7d88..d730d22 100644
--- a/whisper.py
+++ b/whisper.py
@@ -6,6 +6,7 @@
 from typing import Optional
 
 import faster_whisper
+from faster_whisper import WhisperModel
 from config import (
     model_base_dir,
     w_beam_size,
@@ -24,21 +25,34 @@
 logger = logging.getLogger(__name__)
 
 
+# loads the whisper model
+# FIXME does not check if the specific model_type is available locally!
+def load_model(model_base_dir: str, model_type: str, device: str) -> WhisperModel:
+    logger.info(f"Loading Whisper model {model_type} for device: {device}")
+
+    # change HuggingFace dir to where model is downloaded
+    os.environ["HF_HOME"] = model_base_dir
+
+    # determine loading locally or have Whisper download from HuggingFace
+    model_location = model_base_dir if check_model_availability() else model_type
+    model = WhisperModel(
+        model_location,  # either local path or e.g. large-v2 (means HuggingFace download)
+        device=device,
+        compute_type=(  # float16 only works on GPU, float32 or int8 are recommended for CPU
+            "float16" if device == "cuda" else "float32"
+        ),
+    )
+    logger.info(f"Model loaded from location: {model_location}")
+    return model
+
+
 def run_asr(input_path, output_dir, model=None) -> Optional[dict]:
     logger.info(f"Starting ASR on {input_path}")
     start_time = time.time()
     if not model:
-        logger.info(f"Device used: {w_device}")
-        # checking if model needs to be downloaded from HF or not
-        model_location = model_base_dir if check_model_availability() else w_model
-        model = faster_whisper.WhisperModel(
-            model_location,
-            device=w_device,
-            compute_type=(  # float16 only works on GPU, float32 or int8 are recommended for CPU
-                "float16" if w_device == "cuda" else "float32"
-            ),
-        )
-        logger.info("Model loaded, now getting segments")
+        logger.info("Model not passed as param, need to obtain it first")
+        model = load_model(model_base_dir, w_model, w_device)
+    logger.info("Processing segments")
     segments, _ = model.transcribe(
         input_path,
         vad_filter=w_vad,
diff --git a/whisper_api.py b/whisper_api.py
index 4537f02..a3b7af3 100644
--- a/whisper_api.py
+++ b/whisper_api.py
@@ -1,9 +1,9 @@
 import logging
-import os
 from typing import Optional
 from uuid import uuid4
 from fastapi import BackgroundTasks, FastAPI, HTTPException, status, Response
 from asr import run
+from whisper import load_model
 from enum import Enum
 from pydantic import BaseModel
 from config import (
@@ -11,29 +11,14 @@
     w_device,
     w_model,
 )
-import faster_whisper
-from model_download import check_model_availability
 
 logger = logging.getLogger(__name__)
 
 api = FastAPI()
 
 logger.info(f"Loading model on device {w_device}")
-
-# change hugging face home dir where model is downloaded
-os.environ["HF_HOME"] = model_base_dir
-
-# checking if model needs to be downloaded from HF or not
-model_location = model_base_dir if check_model_availability() else w_model
-
-model = faster_whisper.WhisperModel(
-    model_location,
-    device=w_device,
-    compute_type=(  # float16 only works on GPU, float32 or int8 are recommended for CPU
-        "float16" if w_device == "cuda" else "float32"
-    ),
-)
-logger.info("Model loaded!")
+# load the model in memory on API startup
+model = load_model(model_base_dir, w_model, w_device)
 
 class Status(Enum):
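
Usage sketch (not part of the patch): a minimal example of how the centralized loader is consumed after this change. Both call sites now go through the same load_model function: whisper_api.py loads eagerly at API startup, while run_asr in whisper.py loads lazily when no model is passed in. The config values are the ones the patch imports; the input and output paths are hypothetical placeholders.

    # sketch, assuming the config module from this repo; paths below are placeholders
    from config import model_base_dir, w_device, w_model
    from whisper import load_model, run_asr

    # eager path (whisper_api.py): load the model once at API startup
    model = load_model(model_base_dir, w_model, w_device)

    # lazy path (whisper.py): run_asr calls load_model itself when model is None
    result = run_asr("/tmp/input.wav", "/tmp/output", model=model)  # hypothetical paths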