-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5301fee
commit 2e7c77e
Showing
5 changed files
with
287 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
162 changes: 162 additions & 0 deletions
162
src/hope_dedup_engine/apps/core/management/commands/workerupgrade.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
import logging | ||
import sys | ||
from argparse import ArgumentParser | ||
from typing import Any, Final | ||
|
||
from django.conf import settings | ||
from django.core.exceptions import ValidationError | ||
from django.core.management import BaseCommand | ||
from django.core.management.base import CommandError, SystemCheckError | ||
|
||
import requests | ||
from storages.backends.azure_storage import AzureStorage | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
MESSAGES: Final[dict[str, str]] = { | ||
"already": "File '%s' already exists in FILE_STORAGE_DNN storage.", | ||
"process": "Downloading file from '%s' to '%s' in FILE_STORAGE_DNN storage...", | ||
"empty": "File at '%s' is empty (size is 0 bytes).", | ||
"halted": "\n\n***\nSYSTEM HALTED\nUnable to start without DNN files...", | ||
} | ||
|
||
|
||
class Command(BaseCommand): | ||
help = "Synchronizes DNN files from the git to azure storage" | ||
dnn_files = None | ||
|
||
def add_arguments(self, parser: ArgumentParser) -> None: | ||
""" | ||
Adds custom command-line arguments to the management command. | ||
Args: | ||
parser (ArgumentParser): The argument parser instance to which the arguments should be added. | ||
Adds the following arguments: | ||
--force: A boolean flag that, when provided, forces the re-download of files even if they already exist | ||
in Azure storage. Defaults to False. | ||
--deployfile-url (str): The URL from which the deploy (prototxt) file is downloaded. | ||
Defaults to the value set in the project settings. | ||
--caffemodelfile-url (str): The URL from which the pre-trained model weights (caffemodel) are downloaded. | ||
Defaults to the value set in the project settings. | ||
--download-timeout (int): The maximum time allowed for downloading files, in seconds. | ||
Defaults to 3 minutes (180 seconds). | ||
--chunk-size (int): The size of each chunk to download in bytes. Defaults to 256 KB. | ||
""" | ||
parser.add_argument( | ||
"--force", | ||
action="store_true", | ||
default=False, | ||
help="Force the re-download of files even if they already exist", | ||
) | ||
parser.add_argument( | ||
"--deployfile-url", | ||
type=str, | ||
default=settings.DNN_FILES.get("prototxt", {}) | ||
.get("sources", {}) | ||
.get("github"), | ||
help="The URL of the model architecture (deploy) file", | ||
) | ||
parser.add_argument( | ||
"--caffemodelfile-url", | ||
type=str, | ||
default=settings.DNN_FILES.get("caffemodel", {}) | ||
.get("sources", {}) | ||
.get("github"), | ||
help="The URL of the pre-trained model weights (caffemodel) file", | ||
) | ||
parser.add_argument( | ||
"--download-timeout", | ||
type=int, | ||
default=3 * 60, # 3 minutes | ||
help="The timeout for downloading files", | ||
) | ||
parser.add_argument( | ||
"--chunk-size", | ||
type=int, | ||
default=256 * 1024, # 256 KB | ||
help="The size of each chunk to download in bytes", | ||
) | ||
|
||
def get_options(self, options: dict[str, Any]) -> None: | ||
self.verbosity = options["verbosity"] | ||
self.force = options["force"] | ||
self.dnn_files = ( | ||
{ | ||
"url": options["deployfile_url"], | ||
"filename": settings.DNN_FILES.get("prototxt", {}) | ||
.get("sources", {}) | ||
.get("azure"), | ||
}, | ||
{ | ||
"url": options["caffemodelfile_url"], | ||
"filename": settings.DNN_FILES.get("caffemodel", {}) | ||
.get("sources", {}) | ||
.get("azure"), | ||
}, | ||
) | ||
self.download_timeout = options["download_timeout"] | ||
self.chunk_size = options["chunk_size"] | ||
|
||
def handle(self, *args: Any, **options: Any) -> None: | ||
""" | ||
Executes the command to download and store DNN files from a given source to Azure Blob Storage. | ||
Args: | ||
*args (Any): Positional arguments passed to the command. | ||
**options (dict[str, Any]): Keyword arguments passed to the command, including: | ||
- force (bool): If True, forces the re-download of files even if they already exist in storage. | ||
- deployfile_url (str): The URL of the DNN model architecture file to download. | ||
- caffemodelfile_url (str): The URL of the pre-trained model weights to download. | ||
- download_timeout (int): Timeout for downloading each file, in seconds. | ||
- chunk_size (int): The size of chunks for streaming downloads, in bytes. | ||
Raises: | ||
FileNotFoundError: If the downloaded file is empty (size is 0 bytes). | ||
ValidationError: If any arguments are invalid or improperly configured. | ||
CommandError: If an issue occurs with the Django command execution. | ||
SystemCheckError: If a system check error is encountered during execution. | ||
Exception: For any other errors that occur during the download or storage process. | ||
""" | ||
self.get_options(options) | ||
if self.verbosity >= 1: | ||
echo = self.stdout.write | ||
else: | ||
echo = lambda *a, **kw: None # noqa: E731 | ||
|
||
try: | ||
dnn_storage = AzureStorage(**settings.STORAGES.get("dnn").get("OPTIONS")) | ||
_, files = dnn_storage.listdir("") | ||
for file in self.dnn_files: | ||
if self.force or not file.get("filename") in files: | ||
echo(MESSAGES["process"] % (file.get("url"), file.get("filename"))) | ||
with requests.get( | ||
file.get("url"), stream=True, timeout=self.download_timeout | ||
) as r: | ||
r.raise_for_status() | ||
if int(r.headers.get("Content-Length", 1)) == 0: | ||
raise FileNotFoundError(MESSAGES["empty"] % file.get("url")) | ||
with dnn_storage.open(file.get("filename"), "wb") as f: | ||
for chunk in r.iter_content(chunk_size=self.chunk_size): | ||
f.write(chunk) | ||
else: | ||
echo(MESSAGES["already"] % file.get("filename")) | ||
except ValidationError as e: | ||
self.halt(Exception("\n- ".join(["Wrong argument(s):", *e.messages]))) | ||
except (CommandError, FileNotFoundError, SystemCheckError) as e: | ||
self.halt(e) | ||
except Exception as e: | ||
self.halt(e) | ||
|
||
def halt(self, e: Exception) -> None: | ||
""" | ||
Handle an exception by logging the error and exiting the program. | ||
Args: | ||
e (Exception): The exception that occurred. | ||
""" | ||
logger.exception(e) | ||
self.stdout.write(self.style.ERROR(str(e))) | ||
self.stdout.write(self.style.ERROR(MESSAGES["halted"])) | ||
sys.exit(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from io import StringIO | ||
from typing import Final | ||
from unittest import mock | ||
|
||
from django.core.exceptions import ValidationError | ||
from django.core.management import call_command | ||
from django.core.management.base import CommandError, SystemCheckError | ||
|
||
import pytest | ||
from pytest_mock import MockerFixture | ||
|
||
DNN_FILES: Final[tuple[dict[str, str]]] = ( | ||
{"url": "http://example.com/file1", "filename": "file1"}, | ||
{"url": "http://example.com/file2", "filename": "file2"}, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def mock_requests_get(): | ||
with mock.patch("requests.get") as mock_get: | ||
mock_response = mock_get.return_value.__enter__.return_value | ||
mock_response.iter_content.return_value = [b"Hello, world!"] * 3 | ||
mock_response.raise_for_status = lambda: None | ||
yield mock_get | ||
|
||
|
||
@pytest.fixture | ||
def mock_azurite_manager(mocker: MockerFixture): | ||
yield mocker.patch( | ||
"hope_dedup_engine.apps.core.management.commands.workerupgrade.AzureStorage", | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def mock_dnn_files(mocker: MockerFixture): | ||
yield mocker.patch( | ||
"hope_dedup_engine.apps.core.management.commands.workerupgrade.Command.dnn_files", | ||
new_callable=mocker.PropertyMock, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"force, expected_count, existing_files", | ||
[ | ||
(False, 2, []), | ||
(False, 1, [DNN_FILES[0]["filename"]]), | ||
(False, 0, [f["filename"] for f in DNN_FILES][:2]), | ||
(True, 2, []), | ||
(True, 2, [DNN_FILES[0]["filename"]]), | ||
(True, 2, [f["filename"] for f in DNN_FILES][:2]), | ||
], | ||
) | ||
def test_workerupgrade_handle_success( | ||
mock_requests_get, | ||
mock_azurite_manager, | ||
mock_dnn_files, | ||
force, | ||
expected_count, | ||
existing_files, | ||
): | ||
mock_dnn_files.return_value = DNN_FILES | ||
mock_azurite_manager().listdir.return_value = ([], existing_files) | ||
out = StringIO() | ||
|
||
call_command("workerupgrade", stdout=out, force=force) | ||
|
||
assert "SYSTEM HALTED" not in out.getvalue() | ||
assert mock_requests_get.call_count == expected_count | ||
assert mock_azurite_manager().open.call_count == expected_count | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"side_effect, expected_exception", | ||
[ | ||
(FileNotFoundError("File not found"), SystemExit), | ||
(ValidationError("Invalid argument"), SystemExit), | ||
(CommandError("Command execution failed"), SystemExit), | ||
(SystemCheckError("System check failed"), SystemExit), | ||
(Exception("Unknown error"), SystemExit), | ||
], | ||
) | ||
def test_workerupgrade_handle_exception( | ||
mock_requests_get, mock_azurite_manager, side_effect, expected_exception | ||
): | ||
mock_azurite_manager.side_effect = side_effect | ||
out = StringIO() | ||
with pytest.raises(expected_exception): | ||
call_command("workerupgrade", stdout=out) | ||
|
||
assert "SYSTEM HALTED" in out.getvalue() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters