Skip to content

Commit

Permalink
Bringing back the changes I made
Browse files Browse the repository at this point in the history
They actually work! Only thing to track is memory usage
  • Loading branch information
greenw0lf committed Sep 26, 2024
2 parents 6690ccd + d894219 commit 58dc75f
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 37 deletions.
7 changes: 4 additions & 3 deletions asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
import os

from base_util import get_asset_info, asr_output_dir
from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket
from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket, model_base_dir
from download import download_uri
from whisper import run_asr, WHISPER_JSON_FILE
from s3_util import S3Store
from transcode import try_transcode
from daan_transcript import generate_daan_transcript, DAAN_JSON_FILE

logger = logging.getLogger(__name__)
os.environ["HF_HOME"] = model_base_dir # change dir where model is downloaded


def run(input_uri: str, output_uri: str) -> bool:
def run(input_uri: str, output_uri: str, model=None) -> bool:
logger.info(f"Processing {input_uri} (save to --> {output_uri})")
# 1. download input
result = download_uri(input_uri)
Expand All @@ -36,7 +37,7 @@ def run(input_uri: str, output_uri: str) -> bool:
# 3. run ASR
if not asr_already_done(output_path):
logger.info("No Whisper transcript found")
run_asr(input_path, output_path)
run_asr(input_path, output_path, model)
else:
logger.info(f"Whisper transcript already present in {output_path}")

Expand Down
3 changes: 1 addition & 2 deletions base_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import ntpath
import os
import subprocess
from typing import Tuple
Expand All @@ -12,7 +11,7 @@

# the file name without extension is used as asset ID
def get_asset_info(input_file: str) -> Tuple[str, str]:
file_name = ntpath.basename(input_file)
file_name = os.path.basename(input_file)
asset_id, extension = os.path.splitext(file_name)
logger.info(f"working with this asset ID {asset_id}")
return asset_id, extension
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def run_api(port: int):
uvicorn.run(api, port=port, host="0.0.0.0")


def run_job(intput_uri: str, output_uri: str):
def run_job(input_uri: str, output_uri: str):
import asr

logger.info("Running Whisper as a one time job")
Expand Down
16 changes: 8 additions & 8 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "whisper-asr-worker"
version = "0.2.0"
version = "0.3.0"
description = "Whisper speech-to-text worker"
authors = ["Dragos Alexandru Balan <[email protected]>"]
license = "MIT"
Expand Down
17 changes: 9 additions & 8 deletions s3_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import boto3
import logging
import ntpath
import os
from pathlib import Path
import tarfile
from typing import List, Tuple, Optional
from urllib.parse import urlparse


logger = logging.getLogger(__name__)
Expand All @@ -16,12 +16,12 @@ def generate_asset_id_from_input_file(
input_file: str, with_extension: bool = False
) -> str:
logger.info(f"generating asset ID for {input_file}")
file_name = ntpath.basename(input_file) # grab the file_name from the path
file_name = os.path.basename(input_file) # grab the file_name from the path
if with_extension:
return file_name

# otherwise cut off the extension
asset_id, extension = os.path.splitext(file_name)
asset_id, _ = os.path.splitext(file_name)
return asset_id


Expand Down Expand Up @@ -60,10 +60,11 @@ def tar_list_of_files(archive_path: str, file_list: List[str]) -> bool:


def validate_s3_uri(s3_uri: str) -> bool:
if s3_uri[0:5] != "s3://":
o = urlparse(s3_uri, allow_fragments=False)
if o.scheme != "s3":
logger.error(f"Invalid protocol in {s3_uri}")
return False
if len(s3_uri[5:].split("/")) < 2:
if o.path == "":
logger.error(f"No object_name specified {s3_uri}")
return False
return True
Expand All @@ -72,9 +73,9 @@ def validate_s3_uri(s3_uri: str) -> bool:
# e.g. "s3://beng-daan-visxp/jaap-dane-test/dane-test.tar.gz"
def parse_s3_uri(s3_uri: str) -> Tuple[str, str]:
logger.info(f"Parsing s3 URI {s3_uri}")
tmp = s3_uri[5:]
bucket = tmp[: tmp.find("/")] # beng-daan-visxp
object_name = s3_uri[len(bucket) + 6 :] # jaap-dane-test/dane-test.tar.gz
o = urlparse(s3_uri, allow_fragments=False)
bucket = o.netloc # beng-daan-visxp
object_name = o.path.lstrip("/") # jaap-dane-test/dane-test.tar.gz
return bucket, object_name


Expand Down
2 changes: 1 addition & 1 deletion transcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

import base_util
from base_util import data_base_dir
from config import data_base_dir

logger = logging.getLogger(__name__)

Expand Down
25 changes: 13 additions & 12 deletions whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,20 @@
logger = logging.getLogger(__name__)


def run_asr(input_path, output_dir) -> bool:
def run_asr(input_path, output_dir, model=None) -> bool:
logger.info(f"Starting ASR on {input_path}")
logger.info(f"Device used: {w_device}")
# checking if model needs to be downloaded from HF or not
model_location = model_base_dir if check_model_availability() else w_model
model = faster_whisper.WhisperModel(
model_location,
device=w_device,
compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU
"float16" if w_device == "cuda" else "float32"
),
)
logger.info("Model loaded, now getting segments")
if not model:
logger.info(f"Device used: {w_device}")
# checking if model needs to be downloaded from HF or not
model_location = model_base_dir if check_model_availability() else w_model
model = faster_whisper.WhisperModel(
model_location,
device=w_device,
compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU
"float16" if w_device == "cuda" else "float32"
),
)
logger.info("Model loaded, now getting segments")
segments, _ = model.transcribe(
input_path,
vad_filter=w_vad,
Expand Down
21 changes: 20 additions & 1 deletion whisper_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,29 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from config import (
model_base_dir,
w_device,
w_model,
)
import faster_whisper
from model_download import check_model_availability

logger = logging.getLogger(__name__)
api = FastAPI()

logger.info(f"Loading model on device {w_device}")
# checking if model needs to be downloaded from HF or not
model_location = model_base_dir if check_model_availability() else w_model
model = faster_whisper.WhisperModel(
model_location,
device=w_device,
compute_type=( # float16 only works on GPU, float32 or int8 are recommended for CPU
"float16" if w_device == "cuda" else "float32"
),
)
logger.info("Model loaded!")


class Status(Enum):
CREATED = "CREATED"
Expand Down Expand Up @@ -69,7 +88,7 @@ def try_whisper(task: Task):
try:
task.status = Status.PROCESSING
update_task(task)
run(task.input_uri, task.output_uri)
run(task.input_uri, task.output_uri, model)
task.status = Status.DONE
except Exception:
logger.exception("Failed to run whisper")
Expand Down

0 comments on commit 58dc75f

Please sign in to comment.