Skip to content

Commit

Permalink
Fix model download and load the model only on service startup
Browse files Browse the repository at this point in the history
  • Loading branch information
greenw0lf committed Sep 25, 2024
1 parent 6690ccd commit b6f1c40
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 18 deletions.
7 changes: 4 additions & 3 deletions asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
import os

from base_util import get_asset_info, asr_output_dir
from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket
from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket, model_base_dir
from download import download_uri
from whisper import run_asr, WHISPER_JSON_FILE
from s3_util import S3Store
from transcode import try_transcode
from daan_transcript import generate_daan_transcript, DAAN_JSON_FILE

logger = logging.getLogger(__name__)
os.environ['HF_HOME'] = model_base_dir # change dir where model is downloaded


def run(input_uri: str, output_uri: str) -> bool:
def run(input_uri: str, output_uri: str, model=None) -> bool:
logger.info(f"Processing {input_uri} (save to --> {output_uri})")
# 1. download input
result = download_uri(input_uri)
Expand All @@ -36,7 +37,7 @@ def run(input_uri: str, output_uri: str) -> bool:
# 3. run ASR
if not asr_already_done(output_path):
logger.info("No Whisper transcript found")
run_asr(input_path, output_path)
run_asr(input_path, output_path, model)
else:
logger.info(f"Whisper transcript already present in {output_path}")

Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def run_api(port: int):
uvicorn.run(api, port=port, host="0.0.0.0")


def run_job(intput_uri: str, output_uri: str):
def run_job(input_uri: str, output_uri: str):
import asr

logger.info("Running Whisper as a one time job")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "whisper-asr-worker"
version = "0.2.0"
version = "0.3.0"
description = "Whisper speech-to-text worker"
authors = ["Dragos Alexandru Balan <[email protected]>"]
license = "MIT"
Expand Down
25 changes: 13 additions & 12 deletions whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,20 @@
logger = logging.getLogger(__name__)


def run_asr(input_path, output_dir) -> bool:
def run_asr(input_path, output_dir, model=None) -> bool:
logger.info(f"Starting ASR on {input_path}")
logger.info(f"Device used: {w_device}")
# checking if model needs to be downloaded from HF or not
model_location = model_base_dir if check_model_availability() else w_model
model = faster_whisper.WhisperModel(
model_location,
device=w_device,
compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU
"float16" if w_device == "cuda" else "float32"
),
)
logger.info("Model loaded, now getting segments")
if not model:
logger.info(f"Device used: {w_device}")
# checking if model needs to be downloaded from HF or not
model_location = model_base_dir if check_model_availability() else w_model
model = faster_whisper.WhisperModel(
model_location,
device=w_device,
compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU
"float16" if w_device == "cuda" else "float32"
),
)
logger.info("Model loaded, now getting segments")
segments, _ = model.transcribe(
input_path,
vad_filter=w_vad,
Expand Down
21 changes: 20 additions & 1 deletion whisper_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,29 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from config import (
model_base_dir,
w_device,
w_model,
)
import faster_whisper
from model_download import check_model_availability

logger = logging.getLogger(__name__)
api = FastAPI()

logger.info(f"Loading model on device {w_device}")
# checking if model needs to be downloaded from HF or not
model_location = model_base_dir if check_model_availability() else w_model
model = faster_whisper.WhisperModel(
model_location,
device=w_device,
compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU
"float16" if w_device == "cuda" else "float32"
),
)
logger.info("Model loaded!")


class Status(Enum):
CREATED = "CREATED"
Expand Down Expand Up @@ -69,7 +88,7 @@ def try_whisper(task: Task):
try:
task.status = Status.PROCESSING
update_task(task)
run(task.input_uri, task.output_uri)
run(task.input_uri, task.output_uri, model)
task.status = Status.DONE
except Exception:
logger.exception("Failed to run whisper")
Expand Down

0 comments on commit b6f1c40

Please sign in to comment.