diff --git a/asr.py b/asr.py index 24c3f9c..2c3defd 100644 --- a/asr.py +++ b/asr.py @@ -2,7 +2,7 @@ import os from base_util import get_asset_info, asr_output_dir -from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket +from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket, model_base_dir from download import download_uri from whisper import run_asr, WHISPER_JSON_FILE from s3_util import S3Store @@ -10,9 +10,10 @@ from daan_transcript import generate_daan_transcript, DAAN_JSON_FILE logger = logging.getLogger(__name__) +os.environ['HF_HOME'] = model_base_dir # change dir where model is downloaded -def run(input_uri: str, output_uri: str) -> bool: +def run(input_uri: str, output_uri: str, model=None) -> bool: logger.info(f"Processing {input_uri} (save to --> {output_uri})") # 1. download input result = download_uri(input_uri) @@ -36,7 +37,7 @@ def run(input_uri: str, output_uri: str) -> bool: # 3. run ASR if not asr_already_done(output_path): logger.info("No Whisper transcript found") - run_asr(input_path, output_path) + run_asr(input_path, output_path, model) else: logger.info(f"Whisper transcript already present in {output_path}") diff --git a/main.py b/main.py index 0d2fcc1..90a5e85 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ def run_api(port: int): uvicorn.run(api, port=port, host="0.0.0.0") -def run_job(intput_uri: str, output_uri: str): +def run_job(input_uri: str, output_uri: str): import asr logger.info("Running Whisper as a one time job") diff --git a/pyproject.toml b/pyproject.toml index 7294fa3..bb9f202 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "whisper-asr-worker" -version = "0.2.0" +version = "0.3.0" description = "Whisper speech-to-text worker" authors = ["Dragos Alexandru Balan "] license = "MIT" diff --git a/whisper.py b/whisper.py index 41cd5aa..68e5e90 100644 --- a/whisper.py +++ b/whisper.py @@ -22,19 +22,20 @@ logger = logging.getLogger(__name__) -def run_asr(input_path, output_dir) -> bool: +def run_asr(input_path, output_dir, model=None) -> bool: logger.info(f"Starting ASR on {input_path}") - logger.info(f"Device used: {w_device}") - # checking if model needs to be downloaded from HF or not - model_location = model_base_dir if check_model_availability() else w_model - model = faster_whisper.WhisperModel( - model_location, - device=w_device, - compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU - "float16" if w_device == "cuda" else "float32" - ), - ) - logger.info("Model loaded, now getting segments") + if not model: + logger.info(f"Device used: {w_device}") + # checking if model needs to be downloaded from HF or not + model_location = model_base_dir if check_model_availability() else w_model + model = faster_whisper.WhisperModel( + model_location, + device=w_device, + compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU + "float16" if w_device == "cuda" else "float32" + ), + ) + logger.info("Model loaded, now getting segments") segments, _ = model.transcribe( input_path, vad_filter=w_vad, diff --git a/whisper_api.py b/whisper_api.py index dbae13d..21ee862 100644 --- a/whisper_api.py +++ b/whisper_api.py @@ -5,10 +5,29 @@ from enum import Enum from typing import Optional from pydantic import BaseModel +from config import ( + model_base_dir, + w_device, + w_model, +) +import faster_whisper +from model_download import check_model_availability logger = logging.getLogger(__name__) api = FastAPI() +logger.info(f"Loading model on device {w_device}") +# checking if model needs to be downloaded from HF or not +model_location = model_base_dir if check_model_availability() else w_model +model = faster_whisper.WhisperModel( + model_location, + device=w_device, + compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU + "float16" if w_device == "cuda" else "float32" + ), +) +logger.info("Model loaded!") + class Status(Enum): CREATED = "CREATED" @@ -69,7 +88,7 @@ def try_whisper(task: Task): try: task.status = Status.PROCESSING update_task(task) - run(task.input_uri, task.output_uri) + run(task.input_uri, task.output_uri, model) task.status = Status.DONE except Exception: logger.exception("Failed to run whisper")