diff --git a/asr.py b/asr.py
index 24c3f9c..521afe5 100644
--- a/asr.py
+++ b/asr.py
@@ -2,7 +2,7 @@
 import os
 
 from base_util import get_asset_info, asr_output_dir
-from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket
+from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket, model_base_dir
 from download import download_uri
 from whisper import run_asr, WHISPER_JSON_FILE
 from s3_util import S3Store
@@ -10,9 +10,10 @@ from daan_transcript import generate_daan_transcript, DAAN_JSON_FILE
 
 logger = logging.getLogger(__name__)
+os.environ["HF_HOME"] = model_base_dir  # change dir where model is downloaded
 
 
-def run(input_uri: str, output_uri: str) -> bool:
+def run(input_uri: str, output_uri: str, model=None) -> bool:
     logger.info(f"Processing {input_uri} (save to --> {output_uri})")
     # 1. download input
     result = download_uri(input_uri)
@@ -36,7 +37,7 @@ def run(input_uri: str, output_uri: str) -> bool:
     # 3. run ASR
     if not asr_already_done(output_path):
         logger.info("No Whisper transcript found")
-        run_asr(input_path, output_path)
+        run_asr(input_path, output_path, model)
     else:
         logger.info(f"Whisper transcript already present in {output_path}")
diff --git a/base_util.py b/base_util.py
index 4ffd5fe..d0384b8 100644
--- a/base_util.py
+++ b/base_util.py
@@ -1,5 +1,4 @@
 import logging
-import ntpath
 import os
 import subprocess
 from typing import Tuple
@@ -12,7 +11,7 @@
 
 # the file name without extension is used as asset ID
 def get_asset_info(input_file: str) -> Tuple[str, str]:
-    file_name = ntpath.basename(input_file)
+    file_name = os.path.basename(input_file)
     asset_id, extension = os.path.splitext(file_name)
     logger.info(f"working with this asset ID {asset_id}")
     return asset_id, extension
diff --git a/main.py b/main.py
index 0d2fcc1..90a5e85 100644
--- a/main.py
+++ b/main.py
@@ -19,7 +19,7 @@ def run_api(port: int):
     uvicorn.run(api, port=port, host="0.0.0.0")
 
 
-def run_job(intput_uri: str, output_uri: str):
+def run_job(input_uri: str, output_uri: str):
     import asr
 
     logger.info("Running Whisper as a one time job")
diff --git a/poetry.lock b/poetry.lock
index 58232e9..73762e0 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "annotated-types"
@@ -146,17 +146,17 @@ files = [
 
 [[package]]
 name = "boto3"
-version = "1.35.24"
+version = "1.35.26"
 description = "The AWS SDK for Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "boto3-1.35.24-py3-none-any.whl", hash = "sha256:97fcc1a14cbc759e4ba9535ced703a99fcf652c9c4b8dfcd06f292c80551684b"},
-    {file = "boto3-1.35.24.tar.gz", hash = "sha256:be7807f30f26d6c0057e45cfd09dad5968e664488bf4f9138d0bb7a0f6d8ed40"},
+    {file = "boto3-1.35.26-py3-none-any.whl", hash = "sha256:c31db992655db233d98762612690cfe60723c9e1503b5709aad92c1c564877bb"},
+    {file = "boto3-1.35.26.tar.gz", hash = "sha256:b04087afd3570ba540fd293823c77270ec675672af23da9396bd5988a3f8128b"},
 ]
 
 [package.dependencies]
-botocore = ">=1.35.24,<1.36.0"
+botocore = ">=1.35.26,<1.36.0"
 jmespath = ">=0.7.1,<2.0.0"
 s3transfer = ">=0.10.0,<0.11.0"
 
 [package.extras]
 crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
 
 [[package]]
 name = "botocore"
-version = "1.35.24"
+version = "1.35.26"
 description = "Low-level, data-driven core of boto 3."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "botocore-1.35.24-py3-none-any.whl", hash = "sha256:eb9ccc068255cc3d24c36693fda6aec7786db05ae6c2b13bcba66dce6a13e2e3"},
-    {file = "botocore-1.35.24.tar.gz", hash = "sha256:1e59b0f14f4890c4f70bd6a58a634b9464bed1c4c6171f87c8795d974ade614b"},
+    {file = "botocore-1.35.26-py3-none-any.whl", hash = "sha256:0b9dee5e4a3314e251e103585837506b17fcc7485c3c8adb61a9a913f46da1e7"},
+    {file = "botocore-1.35.26.tar.gz", hash = "sha256:19efc3a22c9df77960712b4e203f912486f8bcd3794bff0fd7b2a0f5f1d5712d"},
 ]
 
 [package.dependencies]
diff --git a/pyproject.toml b/pyproject.toml
index 7294fa3..bb9f202 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "whisper-asr-worker"
-version = "0.2.0"
+version = "0.3.0"
 description = "Whisper speech-to-text worker"
 authors = ["Dragos Alexandru Balan "]
 license = "MIT"
diff --git a/s3_util.py b/s3_util.py
index 842aee7..303d36e 100644
--- a/s3_util.py
+++ b/s3_util.py
@@ -1,10 +1,10 @@
 import boto3
 import logging
-import ntpath
 import os
 from pathlib import Path
 import tarfile
 from typing import List, Tuple, Optional
+from urllib.parse import urlparse
 
 logger = logging.getLogger(__name__)
 
@@ -16,12 +16,12 @@ def generate_asset_id_from_input_file(
     input_file: str, with_extension: bool = False
 ) -> str:
     logger.info(f"generating asset ID for {input_file}")
-    file_name = ntpath.basename(input_file)  # grab the file_name from the path
+    file_name = os.path.basename(input_file)  # grab the file_name from the path
     if with_extension:
         return file_name
 
     # otherwise cut off the extension
-    asset_id, extension = os.path.splitext(file_name)
+    asset_id, _ = os.path.splitext(file_name)
     return asset_id
@@ -60,10 +60,11 @@ def tar_list_of_files(archive_path: str, file_list: List[str]) -> bool:
 
 
 def validate_s3_uri(s3_uri: str) -> bool:
-    if s3_uri[0:5] != "s3://":
+    o = urlparse(s3_uri, allow_fragments=False)
+    if o.scheme != "s3":
         logger.error(f"Invalid protocol in {s3_uri}")
         return False
-    if len(s3_uri[5:].split("/")) < 2:
+    if o.path == "":
        logger.error(f"No object_name specified {s3_uri}")
        return False
     return True
@@ -72,9 +73,9 @@ def validate_s3_uri(s3_uri: str) -> bool:
 # e.g. "s3://beng-daan-visxp/jaap-dane-test/dane-test.tar.gz"
"s3://beng-daan-visxp/jaap-dane-test/dane-test.tar.gz" def parse_s3_uri(s3_uri: str) -> Tuple[str, str]: logger.info(f"Parsing s3 URI {s3_uri}") - tmp = s3_uri[5:] - bucket = tmp[: tmp.find("/")] # beng-daan-visxp - object_name = s3_uri[len(bucket) + 6 :] # jaap-dane-test/dane-test.tar.gz + o = urlparse(s3_uri, allow_fragments=False) + bucket = o.netloc # beng-daan-visxp + object_name = o.path.lstrip("/") # jaap-dane-test/dane-test.tar.gz return bucket, object_name diff --git a/transcode.py b/transcode.py index 2a1231d..63e019c 100644 --- a/transcode.py +++ b/transcode.py @@ -3,7 +3,7 @@ from typing import Optional import base_util -from base_util import data_base_dir +from config import data_base_dir logger = logging.getLogger(__name__) diff --git a/whisper.py b/whisper.py index 41cd5aa..68e5e90 100644 --- a/whisper.py +++ b/whisper.py @@ -22,19 +22,20 @@ logger = logging.getLogger(__name__) -def run_asr(input_path, output_dir) -> bool: +def run_asr(input_path, output_dir, model=None) -> bool: logger.info(f"Starting ASR on {input_path}") - logger.info(f"Device used: {w_device}") - # checking if model needs to be downloaded from HF or not - model_location = model_base_dir if check_model_availability() else w_model - model = faster_whisper.WhisperModel( - model_location, - device=w_device, - compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU - "float16" if w_device == "cuda" else "float32" - ), - ) - logger.info("Model loaded, now getting segments") + if not model: + logger.info(f"Device used: {w_device}") + # checking if model needs to be downloaded from HF or not + model_location = model_base_dir if check_model_availability() else w_model + model = faster_whisper.WhisperModel( + model_location, + device=w_device, + compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU + "float16" if w_device == "cuda" else "float32" + ), + ) + logger.info("Model loaded, now getting segments") segments, _ = model.transcribe( input_path, vad_filter=w_vad, diff --git a/whisper_api.py b/whisper_api.py index dbae13d..21ee862 100644 --- a/whisper_api.py +++ b/whisper_api.py @@ -5,10 +5,29 @@ from enum import Enum from typing import Optional from pydantic import BaseModel +from config import ( + model_base_dir, + w_device, + w_model, +) +import faster_whisper +from model_download import check_model_availability logger = logging.getLogger(__name__) api = FastAPI() +logger.info(f"Loading model on device {w_device}") +# checking if model needs to be downloaded from HF or not +model_location = model_base_dir if check_model_availability() else w_model +model = faster_whisper.WhisperModel( + model_location, + device=w_device, + compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU + "float16" if w_device == "cuda" else "float32" + ), +) +logger.info("Model loaded!") + class Status(Enum): CREATED = "CREATED" @@ -69,7 +88,7 @@ def try_whisper(task: Task): try: task.status = Status.PROCESSING update_task(task) - run(task.input_uri, task.output_uri) + run(task.input_uri, task.output_uri, model) task.status = Status.DONE except Exception: logger.exception("Failed to run whisper")