Skip to content

Commit

Permalink
Error handling for input download + transcode
Browse files Browse the repository at this point in the history
Also added a function that removes all input/output files generated by the worker for the input URI's asset ID
  • Loading branch information
greenw0lf committed Oct 11, 2024
1 parent da4d937 commit 177c17a
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 28 deletions.
12 changes: 7 additions & 5 deletions asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from download import download_uri
from whisper import run_asr, WHISPER_JSON_FILE
from s3_util import S3Store, parse_s3_uri
from base_util import remove_all_input_output
from transcode import try_transcode
from daan_transcript import generate_daan_transcript, DAAN_JSON_FILE

Expand All @@ -45,9 +46,9 @@ def run(input_uri: str, output_uri: str, model=None) -> Optional[str]:
# 1. download input
result = download_uri(input_uri)
logger.info(result)
if not result:
if result.error != "":
logger.error("Could not obtain input, quitting...")
return "Input download failed"
return result.error

prov_steps.append(result.provenance)

Expand All @@ -57,9 +58,10 @@ def run(input_uri: str, output_uri: str, model=None) -> Optional[str]:

# 2. check if the input file is suitable for processing any further
transcode_output = try_transcode(input_path, asset_id, extension)
if not transcode_output:
logger.error("The transcode failed to yield a valid file to continue with")
return "Transcode failed"
if transcode_output.error != "":
logger.error("The transcode failed to yield a valid file to continue with, quitting...")
remove_all_input_output(input_path, asset_id, extension, output_path)
return transcode_output.error
else:
input_path = transcode_output.transcoded_file_path
prov_steps.append(transcode_output.provenance)
Expand Down
16 changes: 16 additions & 0 deletions base_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,19 @@ def validate_http_uri(http_uri: str) -> bool:
logger.error(f"No object_name specified in {http_uri}")
return False
return True


def remove_all_input_output(input_path: str, asset_id: str, output_path: str) -> bool:
try:
if os.path.exists(input_path):
os.remove(input_path)
dirname, _ = os.path.split(input_path)
if os.path.exists(os.path.join(dirname, asset_id + ".mp3")):
os.remove(os.path.join(dirname, asset_id + ".mp3"))
if os.path.exists(output_path):
for file in os.listdir(output_path):
os.remove(file)
return True

except OSError:
return False
49 changes: 32 additions & 17 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import os
import requests
import time
from typing import Optional
from urllib.parse import urlparse
from s3_util import S3Store, parse_s3_uri, validate_s3_uri
from config import data_base_dir, s3_endpoint_url
from base_util import get_asset_info, extension_to_mime_type
from base_util import get_asset_info, extension_to_mime_type, validate_http_uri

logger = logging.getLogger(__name__)

Expand All @@ -21,25 +20,35 @@ class DownloadResult:
provenance: dict
download_time: float = -1 # time (ms) taken to receive data after request
content_length: int = -1 # download_data.get("content_length", -1),
error: str = ""


def download_uri(uri: str) -> Optional[DownloadResult]:
def download_uri(uri: str) -> DownloadResult:
logger.info(f"Trying to download {uri}")
if validate_s3_uri(uri):
logger.info("URI seems to be an s3 uri")
logger.info("URI seems to be an S3 URI")
return s3_download(uri)
return http_download(uri)
if validate_http_uri(uri):
logger.info("URI seems to be an HTTP URI")
return http_download(uri)
return DownloadResult(
uri, "", dict(), -1, "Input failure: URI is neither S3, nor HTTP"
)


def http_download(url: str) -> Optional[DownloadResult]:
def http_download(url: str) -> DownloadResult:
logger.info(f"Checking if {url} was already downloaded")
steps = [] # to report if input is already downloaded

fn = os.path.basename(urlparse(url).path)
input_file = os.path.join(input_file_dir, fn)
_, extension = get_asset_info(input_file)
mime_type = extension_to_mime_type(extension)

# download if the file is not present (preventing unnecessary downloads)
start_time = time.time()
download_time = -1

if not os.path.exists(input_file):
logger.info(f"File {input_file} not downloaded yet")
# Create /data/input/ folder if not exists
Expand All @@ -50,10 +59,15 @@ def http_download(url: str) -> Optional[DownloadResult]:
response = requests.get(url)
if response.status_code >= 400:
logger.error(f"Could not download url: {response.status_code}")
return None
download_time = (time.time() - start_time) * 1000
return DownloadResult(
input_file, mime_type, dict(), download_time, f"Input failure: Could not download url. Response code: {response.status_code}"
)
file.write(response.content)
file.close()
download_time = (time.time() - start_time) * 1000 # time in ms
download_time = (time.time() - start_time) * 1000 # time in ms
else:
steps.append("Download skipped: input already exists")
provenance = {
"activity_name": "Input download",
"activity_description": "Downloads the input file from INPUT_URI",
Expand All @@ -63,19 +77,16 @@ def http_download(url: str) -> Optional[DownloadResult]:
"software_version": "",
"input_data": url,
"output_data": input_file,
"steps": [],
"steps": steps,
}
return DownloadResult(
input_file, mime_type, provenance, download_time # TODO add content_length
)


def s3_download(s3_uri: str) -> Optional[DownloadResult]:
def s3_download(s3_uri: str) -> DownloadResult:
logger.info(f"Checking if {s3_uri} was already downloaded")

if not validate_s3_uri(s3_uri):
logger.error(f"Invalid S3 URI: {s3_uri}")
return None
steps = [] # to report if input is already downloaded

# parse S3 URI
bucket, object_name = parse_s3_uri(s3_uri)
Expand All @@ -89,6 +100,7 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]:
mime_type = extension_to_mime_type(extension)

start_time = time.time()
download_time = -1

if not os.path.exists(input_file):
s3 = S3Store(s3_endpoint_url)
Expand All @@ -100,11 +112,14 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]:

if not success:
logger.error("Failed to download input data from S3")
return None
download_time = (time.time() - start_time) * 1000
return DownloadResult(
input_file, mime_type, dict(), download_time, "Input failure: Could not download S3 URI"
)

download_time = int((time.time() - start_time) * 1000) # time in ms
else:
download_time = -1 # Report back?
steps.append("Download skipped: input already exists")

provenance = {
"activity_name": "Input download",
Expand All @@ -115,7 +130,7 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]:
"software_version": "",
"input_data": s3_uri,
"output_data": input_file,
"steps": [],
"steps": steps,
}

return DownloadResult(
Expand Down
12 changes: 6 additions & 6 deletions transcode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass
import logging
import os
from typing import Optional
import time

import base_util
Expand All @@ -14,9 +13,10 @@
class TranscodeOutput:
transcoded_file_path: str
provenance: dict
error: str = ""


def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]:
def try_transcode(input_path, asset_id, extension) -> TranscodeOutput:
logger.info(
f"Determining if transcode is required for input_path: {input_path} asset_id: ({asset_id}) extension: ({extension})"
)
Expand All @@ -34,7 +34,7 @@ def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]:
"steps": [],
}

# if it's alrady valid audio no transcode necessary
# if it's already valid audio no transcode necessary
if _is_audio_file(extension):
logger.info("No transcode required, input is audio")
end_time = (time.time() - start_time) * 1000
Expand All @@ -46,7 +46,7 @@ def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]:
# if the input format is not supported, fail
if not _is_transcodable(extension):
logger.error(f"input with extension {extension} is not transcodable")
return None
return TranscodeOutput(input_path, dict(), f"Transcode failure: Input with extension {extension} is not transcodable")

# check if the input file was already transcoded
transcoded_file_path = os.path.join(data_base_dir, "input", f"{asset_id}.mp3")
Expand All @@ -66,8 +66,8 @@ def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]:
transcoded_file_path,
)
if not success:
logger.error("Transcode failed")
return None
logger.error("Running ffmpeg to transcode failed")
return TranscodeOutput(input_path, dict(), "Running ffmpeg to transcode failed")

logger.info(
f"Transcode of {extension} successful, returning: {transcoded_file_path}"
Expand Down

0 comments on commit 177c17a

Please sign in to comment.