Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into rework-dockerfile
Browse files Browse the repository at this point in the history
  • Loading branch information
greenw0lf committed Oct 9, 2024
2 parents 39c7599 + c02f737 commit 410cb41
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 42 deletions.
12 changes: 12 additions & 0 deletions base_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import subprocess
import json
from urllib.parse import urlparse
from typing import Tuple
from config import data_base_dir

Expand Down Expand Up @@ -73,3 +74,14 @@ def save_provenance(provenance: dict, asr_output_dir: str) -> bool:
return False

return True


def validate_http_uri(http_uri: str) -> bool:
    """Check that http_uri is a well-formed HTTP(S) URI with an object path.

    Returns True when the URI uses the http or https scheme and has a
    non-empty path (the object name). Logs an error and returns False
    otherwise. Fragments are not split off, so a '#' stays in the path.
    """
    o = urlparse(http_uri, allow_fragments=False)
    # membership test instead of the chained != comparisons
    if o.scheme not in ("http", "https"):
        logger.error(f"Invalid protocol in {http_uri}")
        return False
    # truthiness check: empty path means no object name was given
    if not o.path:
        logger.error(f"No object_name specified in {http_uri}")
        return False
    return True
19 changes: 10 additions & 9 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,13 @@ def assert_tuple(param: str) -> str:


assert w_device in ["cuda", "cpu"], "Please use either cuda|cpu for W_DEVICE"
assert w_model in [
"tiny",
"base",
"small",
"medium",
"large",
"large-v2",
"large-v3",
], "Please use one of: tiny|base|small|medium|large|large-v2|large-v3 for W_MODEL"
if input_uri[0:5] != "s3://" and not validators.url(output_uri):
assert w_model in [
"tiny",
"base",
"small",
"medium",
"large",
"large-v2",
"large-v3",
], "Please use one of: tiny|base|small|medium|large|large-v2|large-v3 for W_MODEL"
103 changes: 76 additions & 27 deletions model_download.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,83 @@
import logging
import os
import tarfile
from urllib.parse import urlparse
import requests
from s3_util import S3Store, parse_s3_uri, validate_s3_uri
from config import model_base_dir, w_model, s3_endpoint_url


# makes sure the model is available locally, if not download it from S3, if that fails download from Huggingface
# FIXME should also check if the correct w_model type is available locally!
def check_model_availability() -> bool:
logger = logging.getLogger(__name__)
if os.path.exists(model_base_dir + "/model.bin"):
logger.info("Model found locally")
return True
else:
logger.info("Model not found locally, attempting to download from S3")
if not validate_s3_uri(w_model):
logger.info("No S3 URI detected")
logger.info(f"Downloading version {w_model} from Huggingface instead")
return False
s3 = S3Store(s3_endpoint_url)
bucket, object_name = parse_s3_uri(w_model)
success = s3.download_file(bucket, object_name, model_base_dir)
if not success:
logger.error(f"Could not download {w_model} into {model_base_dir}")
return False
logger.info(f"Downloaded {w_model} into {model_base_dir}")
logger.info("Extracting the model")
tar_path = model_base_dir + "/" + object_name
from base_util import get_asset_info, validate_http_uri
from config import s3_endpoint_url


logger = logging.getLogger(__name__)


# e.g. {base_dir}/modelx.tar.gz will be extracted in {base_dir}/modelx
def extract_model(destination: str, extension: str) -> str:
    """Extract the archive at {destination}.{extension} into `destination`.

    Returns `destination` when the extracted archive contains a model.bin
    file, "" when extraction fails or model.bin is missing. The tarball is
    deleted after a successful extraction.
    """
    tar_path = f"{destination}.{extension}"
    logger.info(f"Extracting {tar_path} into {destination}")
    if not os.path.exists(destination):  # Create dir for model to be extracted in
        os.makedirs(destination)
    try:
        with tarfile.open(tar_path) as tar:
            tar.extractall(path=destination)
        # cleanup: delete the tar file now that its contents are on disk
        os.remove(tar_path)
        if os.path.exists(os.path.join(destination, "model.bin")):
            logger.info(
                f"model.bin found in {destination}. Model extracted successfully!"
            )
            return destination
        logger.error(f"{destination} does not contain a model.bin file. Exiting...")
        return ""
    except tarfile.ReadError:
        logger.error("Could not extract the model")
        return ""


# makes sure the model is obtained from S3/HTTP/Huggingface, if w_model doesn't exist locally
def get_model_location(base_dir: str, whisper_model: str) -> str:
    """Resolve whisper_model to a usable location.

    S3/HTTP URIs are downloaded (and extracted) into base_dir; anything
    else is returned as-is for faster-whisper to resolve via HuggingFace.
    """
    logger.info(f"Checking w_model: {whisper_model} and download if necessary")
    if validate_s3_uri(whisper_model):
        return check_s3_location(base_dir, whisper_model)
    if validate_http_uri(whisper_model):
        return check_http_location(base_dir, whisper_model)
    # Plain model name: faster-whisper auto-detects a locally cached copy,
    # so no extra existence checks are needed here
    logger.info(f"{whisper_model} is not an S3/HTTP URI. Using HuggingFace instead")
    return whisper_model


def check_s3_location(base_dir: str, whisper_model: str) -> str:
    """Fetch the model tarball from S3 into base_dir and extract it.

    Skips the download when the target directory already exists. Returns
    the local model directory, or "" when download/extraction fails.
    """
    logger.info(f"{whisper_model} is an S3 URI. Attempting to download")
    bucket, object_name = parse_s3_uri(whisper_model)
    asset_id, extension = get_asset_info(object_name)
    destination = os.path.join(base_dir, asset_id)
    if os.path.exists(destination):
        logger.info("Model already exists")
        return destination
    downloaded = S3Store(s3_endpoint_url).download_file(bucket, object_name, base_dir)
    if downloaded:
        return extract_model(destination, extension)
    logger.error(f"Could not download {whisper_model} into {base_dir}")
    return ""


def check_http_location(base_dir: str, whisper_model: str) -> str:
    """Fetch the model tarball over HTTP(S) into base_dir and extract it.

    Skips the download when the target directory already exists. Returns
    the local model directory, or "" when download/extraction fails.
    """
    logger.info(f"{whisper_model} is an HTTP URI. Attempting to download")
    asset_id, extension = get_asset_info(urlparse(whisper_model).path)
    destination = os.path.join(base_dir, asset_id)
    if os.path.exists(destination):
        logger.info("Model already exists")
        return destination
    # Perform the request BEFORE opening the output file: the original
    # opened the file first, leaving an empty tarball behind on a failed
    # download. The redundant file.close() inside the with-block is gone.
    response = requests.get(whisper_model)
    if response.status_code >= 400:
        logger.error(f"Could not download {whisper_model} into {base_dir}")
        return ""
    with open(f"{destination}.{extension}", "wb") as file:
        file.write(response.content)
    return extract_model(destination, extension)
21 changes: 19 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ flake8 = "^7.1.1"
black = "^24.8.0"
pytest = "^8.3.3"
pytest-cov = "^5.0.0"
pytest-mock = "^3.14.0"

[tool.poetry.group.service]
optional = true
Expand Down
2 changes: 1 addition & 1 deletion s3_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def validate_s3_uri(s3_uri: str) -> bool:
logger.error(f"Invalid protocol in {s3_uri}")
return False
if o.path == "":
logger.error(f"No object_name specified {s3_uri}")
logger.error(f"No object_name specified in {s3_uri}")
return False
return True

Expand Down
Binary file not shown.
Empty file.
Binary file not shown.
47 changes: 47 additions & 0 deletions tests/test_model_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
import shutil
import os

# Set the config env vars BEFORE importing model_download, so config.py
# picks them up at import time (MODEL_BASE_DIR points at the test fixtures).
os.environ["DATA_BASE_DIR"] = "data"
os.environ["MODEL_BASE_DIR"] = "tests/input/extract_model_test"
os.environ["S3_ENDPOINT_URL"] = "http://url.com"

from model_download import extract_model, get_model_location  # noqa


@pytest.mark.parametrize(
    "destination, extension, expected_output",
    [
        ("valid_model", "tar.gz", "valid_model"),
        ("valid_model", "mp3", ""),
        ("invalid_model", "tar.gz", ""),
    ],
)
def test_extract_model(destination, extension, expected_output, tmp_path):
    # Copy the fixture tarball into pytest's tmp dir so extraction (and
    # its tar-file cleanup) never touches the checked-in fixtures.
    fixture = os.path.join(
        "tests/input/extract_model_test", f"{destination}.{extension}"
    )
    shutil.copy(fixture, str(tmp_path))
    expected = os.path.join(tmp_path, expected_output) if expected_output else ""
    assert extract_model(os.path.join(tmp_path, destination), extension) == expected


@pytest.mark.parametrize(
    "whisper_model, expected_output",
    [
        ("s3://test-model/assets/modeltest.tar.gz", "s3"),
        ("http://model-hosting.beng.nl/whisper-test.mp3", "http"),
        ("large-v2", "large-v2"),
    ],
)
def test_get_model_location(whisper_model, expected_output, mocker):
    # Stub out the downloaders; only the URI-dispatch logic is under test.
    for target, sentinel in (
        ("model_download.check_s3_location", "s3"),
        ("model_download.check_http_location", "http"),
    ):
        mocker.patch(target, return_value=sentinel)
    assert get_model_location("model", whisper_model) == expected_output


# TODO: test check_s3_location (have to mock: S3Store, s3.dl_file, extract_model)

# TODO: test check_http_location (have to mock: whole "with" block?, extract_model)
8 changes: 5 additions & 3 deletions whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,25 @@
w_word_timestamps,
)
from base_util import get_asset_info
from model_download import check_model_availability
from model_download import get_model_location


WHISPER_JSON_FILE = "whisper-transcript.json"
logger = logging.getLogger(__name__)


# loads the whisper model
# FIXME does not check if the specific model_type is available locally!
def load_model(model_base_dir: str, model_type: str, device: str) -> WhisperModel:
logger.info(f"Loading Whisper model {model_type} for device: {device}")

# change HuggingFace dir to where model is downloaded
os.environ["HF_HOME"] = model_base_dir

# determine loading locally or have Whisper download from HuggingFace
model_location = model_base_dir if check_model_availability() else model_type
model_location = get_model_location(model_base_dir, w_model)
# FIXME handle cases where model_location is ""
if model_location == "":
raise ValueError("Model could not be loaded! Exiting...")
model = WhisperModel(
model_location, # either local path or e.g. large-v2 (means HuggingFace download)
device=device,
Expand Down

0 comments on commit 410cb41

Please sign in to comment.