def remove_all_input_output(input_path: str, asset_id: str, output_path: str) -> bool:
    """Delete the input file, its transcoded .mp3 and every file in the output dir.

    Args:
        input_path: Path of the downloaded input file.
        asset_id: Asset identifier; a file named ``<asset_id>.mp3`` located in
            the same directory as ``input_path`` is removed as well.
        output_path: Directory whose entries are removed. The directory itself
            is left in place.

    Returns:
        True when everything that existed was removed (paths that do not exist
        are skipped silently), False as soon as any filesystem call raises
        OSError, e.g. on a permission problem or when ``output_path`` contains
        a subdirectory.
    """
    transcoded_path = os.path.join(os.path.dirname(input_path), asset_id + ".mp3")
    try:
        for path in (input_path, transcoded_path):
            if os.path.exists(path):
                os.remove(path)

        if os.path.exists(output_path):
            for file_name in os.listdir(output_path):
                # os.listdir() yields bare entry names, so they must be joined
                # with the directory; otherwise the removal would be resolved
                # against the current working directory.
                os.remove(os.path.join(output_path, file_name))
        return True
    except OSError:
        return False
downloaded yet") # Create /data/input/ folder if not exists @@ -50,10 +59,15 @@ def http_download(url: str) -> Optional[DownloadResult]: response = requests.get(url) if response.status_code >= 400: logger.error(f"Could not download url: {response.status_code}") - return None + download_time = (time.time() - start_time) * 1000 + return DownloadResult( + input_file, mime_type, dict(), download_time, f"Input failure: Could not download url. Response code: {response.status_code}" + ) file.write(response.content) file.close() - download_time = (time.time() - start_time) * 1000 # time in ms + download_time = (time.time() - start_time) * 1000 # time in ms + else: + steps.append("Download skipped: input already exists") provenance = { "activity_name": "Input download", "activity_description": "Downloads the input file from INPUT_URI", @@ -63,19 +77,16 @@ def http_download(url: str) -> Optional[DownloadResult]: "software_version": "", "input_data": url, "output_data": input_file, - "steps": [], + "steps": steps, } return DownloadResult( input_file, mime_type, provenance, download_time # TODO add content_length ) -def s3_download(s3_uri: str) -> Optional[DownloadResult]: +def s3_download(s3_uri: str) -> DownloadResult: logger.info(f"Checking if {s3_uri} was already downloaded") - - if not validate_s3_uri(s3_uri): - logger.error(f"Invalid S3 URI: {s3_uri}") - return None + steps = [] # to report if input is already downloaded # parse S3 URI bucket, object_name = parse_s3_uri(s3_uri) @@ -89,6 +100,7 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]: mime_type = extension_to_mime_type(extension) start_time = time.time() + download_time = -1 if not os.path.exists(input_file): s3 = S3Store(s3_endpoint_url) @@ -100,11 +112,14 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]: if not success: logger.error("Failed to download input data from S3") - return None + download_time = (time.time() - start_time) * 1000 + return DownloadResult( + input_file, mime_type, 
dict(), download_time, "Input failure: Could not download S3 URI" + ) download_time = int((time.time() - start_time) * 1000) # time in ms else: - download_time = -1 # Report back? + steps.append("Download skipped: input already exists") provenance = { "activity_name": "Input download", @@ -115,7 +130,7 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]: "software_version": "", "input_data": s3_uri, "output_data": input_file, - "steps": [], + "steps": steps, } return DownloadResult( diff --git a/transcode.py b/transcode.py index 29f3982..4b36c19 100644 --- a/transcode.py +++ b/transcode.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import logging import os -from typing import Optional import time import base_util @@ -14,9 +13,10 @@ class TranscodeOutput: transcoded_file_path: str provenance: dict + error: str = "" -def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]: +def try_transcode(input_path, asset_id, extension) -> TranscodeOutput: logger.info( f"Determining if transcode is required for input_path: {input_path} asset_id: ({asset_id}) extension: ({extension})" ) @@ -34,7 +34,7 @@ def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]: "steps": [], } - # if it's alrady valid audio no transcode necessary + # if it's already valid audio no transcode necessary if _is_audio_file(extension): logger.info("No transcode required, input is audio") end_time = (time.time() - start_time) * 1000 @@ -46,7 +46,7 @@ def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]: # if the input format is not supported, fail if not _is_transcodable(extension): logger.error(f"input with extension {extension} is not transcodable") - return None + return TranscodeOutput(input_path, dict(), f"Transcode failure: Input with extension {extension} is not transcodable") # check if the input file was already transcoded transcoded_file_path = os.path.join(data_base_dir, "input", f"{asset_id}.mp3") @@ -66,8 
+66,8 @@ def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]: transcoded_file_path, ) if not success: - logger.error("Transcode failed") - return None + logger.error("Running ffmpeg to transcode failed") + return TranscodeOutput(input_path, dict(), "Running ffmpeg to transcode failed") logger.info( f"Transcode of {extension} successful, returning: {transcoded_file_path}"