-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/main' into rework-dockerfile
- Loading branch information
Showing
11 changed files
with
171 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,83 @@ | ||
import logging | ||
import os | ||
import tarfile | ||
from urllib.parse import urlparse | ||
import requests | ||
from s3_util import S3Store, parse_s3_uri, validate_s3_uri | ||
from config import model_base_dir, w_model, s3_endpoint_url | ||
|
||
|
||
# makes sure the model is available locally, if not download it from S3, if that fails download from Huggingface | ||
# FIXME should also check if the correct w_model type is available locally! | ||
def check_model_availability() -> bool: | ||
logger = logging.getLogger(__name__) | ||
if os.path.exists(model_base_dir + "/model.bin"): | ||
logger.info("Model found locally") | ||
return True | ||
else: | ||
logger.info("Model not found locally, attempting to download from S3") | ||
if not validate_s3_uri(w_model): | ||
logger.info("No S3 URI detected") | ||
logger.info(f"Downloading version {w_model} from Huggingface instead") | ||
return False | ||
s3 = S3Store(s3_endpoint_url) | ||
bucket, object_name = parse_s3_uri(w_model) | ||
success = s3.download_file(bucket, object_name, model_base_dir) | ||
if not success: | ||
logger.error(f"Could not download {w_model} into {model_base_dir}") | ||
return False | ||
logger.info(f"Downloaded {w_model} into {model_base_dir}") | ||
logger.info("Extracting the model") | ||
tar_path = model_base_dir + "/" + object_name | ||
from base_util import get_asset_info, validate_http_uri | ||
from config import s3_endpoint_url | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def extract_model(destination: str, extension: str) -> str:
    """Unpack the archive ``{destination}.{extension}`` into ``destination``.

    E.g. ``{base_dir}/modelx.tar.gz`` will be extracted in ``{base_dir}/modelx``.
    The archive is deleted after it has been unpacked.

    Args:
        destination: Directory to extract into; it is also the archive path
            without its extension.
        extension: Archive extension without the leading dot, e.g. ``tar.gz``.

    Returns:
        ``destination`` if it contains a ``model.bin`` after extraction,
        otherwise an empty string (unreadable/missing archive, or no
        ``model.bin`` inside it).
    """
    tar_path = f"{destination}.{extension}"
    logger.info(f"Extracting {tar_path} into {destination}")
    # Create dir for model to be extracted in. exist_ok removes the race
    # between a separate existence check and the creation.
    os.makedirs(destination, exist_ok=True)
    logger.info(f"Extracting the model into {destination}")

    # The archive comes from a remote location, so refuse members that would
    # escape `destination` wherever the interpreter supports extraction
    # filters; older interpreters fall back to the unfiltered behaviour.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    try:
        with tarfile.open(tar_path) as tar:
            tar.extractall(path=destination, **extract_kwargs)
    except (tarfile.TarError, OSError, EOFError):
        # Not a tar file, truncated download, missing archive or rejected member
        logger.exception("Could not extract the model")
        return ""

    # cleanup: delete the tar file
    os.remove(tar_path)

    if os.path.exists(os.path.join(destination, "model.bin")):
        logger.info(
            f"model.bin found in {destination}. Model extracted successfully!"
        )
        return destination

    logger.error(f"{destination} does not contain a model.bin file. Exiting...")
    return ""
|
||
|
||
def get_model_location(base_dir: str, whisper_model: str) -> str:
    """Resolve `whisper_model` to something the model loader can use.

    Makes sure the model is obtained from S3/HTTP/Huggingface if `whisper_model`
    doesn't exist locally.

    Args:
        base_dir: Directory that downloaded models are stored in.
        whisper_model: An S3 URI, an HTTP URI, or a model version name.

    Returns:
        Whatever the matching S3/HTTP handler returns, or `whisper_model`
        itself when it is neither kind of URI.
    """
    logger.info(f"Checking w_model: {whisper_model} and download if necessary")

    # Ordered (recognizer, handler) pairs; the first recognizer that accepts
    # the value decides which handler fetches the model.
    handlers = (
        (validate_s3_uri, check_s3_location),
        (validate_http_uri, check_http_location),
    )
    for recognizes, fetch in handlers:
        if recognizes(whisper_model):
            return fetch(base_dir, whisper_model)

    # The faster-whisper API can auto-detect if the version exists locally. No need to add extra checks
    logger.info(f"{whisper_model} is not an S3/HTTP URI. Using HuggingFace instead")
    return whisper_model
|
||
|
||
def check_s3_location(base_dir: str, whisper_model: str) -> str:
    """Return the local directory of the model referenced by an S3 URI.

    Downloads the object into `base_dir` and extracts it unless an extracted
    copy is already present.

    Args:
        base_dir: Directory that downloaded models are stored in.
        whisper_model: S3 URI of the model archive.

    Returns:
        The directory holding the extracted model, or an empty string when the
        download or the extraction failed.
    """
    logger.info(f"{whisper_model} is an S3 URI. Attempting to download")
    bucket, object_name = parse_s3_uri(whisper_model)
    # NOTE(review): presumably splits the object name into a base name and the
    # archive extension — confirm against base_util.get_asset_info
    asset_id, extension = get_asset_info(object_name)
    destination = os.path.join(base_dir, asset_id)

    # Only trust the cached copy when model.bin is present: extract_model
    # creates `destination` before unpacking, so a bare directory may be the
    # leftover of an earlier failed extraction rather than a usable model.
    if os.path.exists(os.path.join(destination, "model.bin")):
        logger.info("Model already exists")
        return destination

    s3 = S3Store(s3_endpoint_url)
    success = s3.download_file(bucket, object_name, base_dir)
    if not success:
        logger.error(f"Could not download {whisper_model} into {base_dir}")
        return ""
    return extract_model(destination, extension)
|
||
|
||
def check_http_location(base_dir: str, whisper_model: str) -> str:
    """Return the local directory of the model referenced by an HTTP URI.

    Downloads the archive next to its extraction directory and extracts it
    unless an extracted copy is already present.

    Args:
        base_dir: Directory that downloaded models are stored in.
        whisper_model: HTTP(S) URI of the model archive.

    Returns:
        The directory holding the extracted model, or an empty string when the
        download or the extraction failed.
    """
    logger.info(f"{whisper_model} is an HTTP URI. Attempting to download")
    # NOTE(review): presumably splits the URL path into a base name and the
    # archive extension — confirm against base_util.get_asset_info
    asset_id, extension = get_asset_info(urlparse(whisper_model).path)
    destination = os.path.join(base_dir, asset_id)

    # Only trust the cached copy when model.bin is present: extract_model
    # creates `destination` before unpacking, so a bare directory may be the
    # leftover of an earlier failed extraction rather than a usable model.
    if os.path.exists(os.path.join(destination, "model.bin")):
        logger.info("Model already exists")
        return destination

    archive_path = f"{destination}.{extension}"
    try:
        # Stream the body so a multi-GB model is never held in memory, and set
        # (connect, read) timeouts so a stalled server cannot hang the caller.
        with requests.get(whisper_model, stream=True, timeout=(10, 60)) as response:
            if response.status_code >= 400:
                # Checked before opening the file so no empty archive is left behind
                logger.error(f"Could not download {whisper_model} into {base_dir}")
                return ""
            with open(archive_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
    except requests.RequestException:
        logger.exception(f"Could not download {whisper_model} into {base_dir}")
        # Drop a partially written archive so it is not mistaken for a download
        if os.path.exists(archive_path):
            os.remove(archive_path)
        return ""
    return extract_model(destination, extension)
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Empty file.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import pytest | ||
import shutil | ||
import os | ||
|
||
# Mocking environment used in model_download (by loading tests/.env) | ||
os.environ["DATA_BASE_DIR"] = "data" | ||
os.environ["MODEL_BASE_DIR"] = "tests/input/extract_model_test" | ||
os.environ["S3_ENDPOINT_URL"] = "http://url.com" | ||
|
||
from model_download import extract_model, get_model_location # noqa | ||
|
||
|
||
@pytest.mark.parametrize(
    "destination, extension, expected_output",
    [
        ("valid_model", "tar.gz", "valid_model"),
        ("valid_model", "mp3", ""),
        ("invalid_model", "tar.gz", ""),
    ],
)
def test_extract_model(destination, extension, expected_output, tmp_path):
    """extract_model returns the extraction dir on success and "" otherwise.

    Each fixture archive is copied into pytest's tmp_path first, so the
    extraction (which deletes the archive) never touches the checked-in
    fixtures.
    """
    fixture_stem = os.path.join("tests/input/extract_model_test", destination)
    shutil.copy(f"{fixture_stem}.{extension}", str(tmp_path))

    # A non-empty expectation is relative to the temporary working directory
    expected = (
        os.path.join(tmp_path, expected_output)
        if expected_output != ""
        else expected_output
    )
    extraction_dir = os.path.join(tmp_path, destination)

    assert extract_model(extraction_dir, extension) == expected
|
||
|
||
@pytest.mark.parametrize(
    "whisper_model, expected_output",
    [
        ("s3://test-model/assets/modeltest.tar.gz", "s3"),
        ("http://model-hosting.beng.nl/whisper-test.mp3", "http"),
        ("large-v2", "large-v2"),
    ],
)
def test_get_model_location(whisper_model, expected_output, mocker):
    """get_model_location dispatches on the URI type of the model reference.

    Both download handlers are stubbed with a marker value, so the returned
    value shows which branch was taken without any network access.
    """
    stubbed_handlers = {
        "model_download.check_s3_location": "s3",
        "model_download.check_http_location": "http",
    }
    for target, marker in stubbed_handlers.items():
        mocker.patch(target, return_value=marker)

    resolved = get_model_location("model", whisper_model)
    assert resolved == expected_output
|
||
|
||
# TODO: test check_s3_location (have to mock: S3Store, s3.dl_file, extract_model) | ||
|
||
# TODO: test check_http_location (have to mock: whole "with" block?, extract_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters