From 5def84b22c95d136134cb45265f1246f5be1fcc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Thu, 26 Sep 2024 14:30:02 +0200 Subject: [PATCH 1/4] Add provenance phew --- asr.py | 95 +++++++++++++++++++++++++++++++++++++++++----- base_util.py | 18 +++++++++ daan_transcript.py | 22 +++++++++-- download.py | 29 ++++++++++++-- transcode.py | 44 ++++++++++++++++++--- whisper.py | 23 ++++++++++- 6 files changed, 207 insertions(+), 24 deletions(-) diff --git a/asr.py b/asr.py index 521afe5..a2b76bd 100644 --- a/asr.py +++ b/asr.py @@ -1,8 +1,22 @@ import logging import os +import time +import pkg_resources + +from base_util import get_asset_info, asr_output_dir, save_provenance +from config import ( + s3_endpoint_url, + s3_bucket, + s3_folder_in_bucket, + model_base_dir, + w_word_timestamps, + w_device, + w_model, + w_beam_size, + w_best_of, + w_vad, +) -from base_util import get_asset_info, asr_output_dir -from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket, model_base_dir from download import download_uri from whisper import run_asr, WHISPER_JSON_FILE from s3_util import S3Store @@ -11,10 +25,15 @@ logger = logging.getLogger(__name__) os.environ["HF_HOME"] = model_base_dir # change dir where model is downloaded +my_version = pkg_resources.get_distribution( + "whisper-asr-worker" +).version # get worker version def run(input_uri: str, output_uri: str, model=None) -> bool: logger.info(f"Processing {input_uri} (save to --> {output_uri})") + start_time = time.time() + prov_steps = [] # track provenance # 1. download input result = download_uri(input_uri) logger.info(result) @@ -22,39 +41,95 @@ def run(input_uri: str, output_uri: str, model=None) -> bool: logger.error("Could not obtain input, quitting...") return False + prov_steps.append(result.provenance) + input_path = result.file_path asset_id, extension = get_asset_info(input_path) output_path = asr_output_dir(input_path) # 2. check if the input file is suitable for processing any further - transcoded_file_path = try_transcode(input_path, asset_id, extension) - if not transcoded_file_path: + transcode_output = try_transcode(input_path, asset_id, extension) + if not transcode_output: logger.error("The transcode failed to yield a valid file to continue with") return False else: - input_path = transcoded_file_path + input_path = transcode_output.transcoded_file_path + prov_steps.append(transcode_output.provenance) # 3. run ASR if not asr_already_done(output_path): logger.info("No Whisper transcript found") - run_asr(input_path, output_path, model) + whisper_prov = run_asr(input_path, output_path, model) + if whisper_prov: + prov_steps.append(whisper_prov) else: logger.info(f"Whisper transcript already present in {output_path}") + provenance = { + "activity_name": "Whisper transcript already exists", + "activity_description": "", + "processing_time_ms": "", + "start_time_unix": "", + "parameters": [], + "software_version": "", + "input_data": "", + "output_data": "", + "steps": [], + } + prov_steps.append(provenance) # 4. 
generate JSON transcript if not daan_transcript_already_done(output_path): logger.info("No DAAN transcript found") - success = generate_daan_transcript(output_path) - if not success: + daan_prov = generate_daan_transcript(output_path) + if daan_prov: + prov_steps.append(daan_prov) + else: logger.warning("Could not generate DAAN transcript") else: logger.info(f"DAAN transcript already present in {output_path}") + provenance = { + "activity_name": "DAAN transcript already exists", + "activity_description": "", + "processing_time_ms": "", + "start_time_unix": "", + "parameters": [], + "software_version": "", + "input_data": "", + "output_data": "", + "steps": [], + } + prov_steps.append(provenance) + + end_time = (time.time() - start_time) * 1000 + final_prov = { + "activity_name": "Whisper ASR Worker", + "activity_description": "Worker that gets a video/audio file as input and outputs JSON transcripts in various formats", + "processing_time_ms": end_time, + "start_time_unix": start_time, + "parameters": { + "word_timestamps": w_word_timestamps, + "device": w_device, + "vad": w_vad, + "model": w_model, + "beam_size": w_beam_size, + "best_of": w_best_of, + }, + "software_version": my_version, + "input_data": input_uri, + "output_data": output_uri if output_uri else output_path, + "steps": prov_steps, + } + + prov_success = save_provenance(final_prov, output_path) + if not prov_success: + logger.warning("Could not save the provenance") # 5. transfer output if output_uri: transfer_asr_output(output_path, asset_id) else: logger.info("No output_uri specified, so all is done") + return True @@ -90,14 +165,14 @@ def transfer_asr_output(output_path: str, asset_id: str) -> bool: # check if there is a whisper-transcript.json -def asr_already_done(output_dir): +def asr_already_done(output_dir) -> bool: whisper_transcript = os.path.join(output_dir, WHISPER_JSON_FILE) logger.info(f"Checking existence of {whisper_transcript}") return os.path.exists(os.path.join(output_dir, WHISPER_JSON_FILE)) # check if there is a daan-es-transcript.json -def daan_transcript_already_done(output_dir): +def daan_transcript_already_done(output_dir) -> bool: daan_transcript = os.path.join(output_dir, DAAN_JSON_FILE) logger.info(f"Checking existence of {daan_transcript}") return os.path.exists(os.path.join(output_dir, DAAN_JSON_FILE)) diff --git a/base_util.py b/base_util.py index d0384b8..d0ea6fe 100644 --- a/base_util.py +++ b/base_util.py @@ -1,11 +1,13 @@ import logging import os import subprocess +import json from typing import Tuple from config import data_base_dir LOG_FORMAT = "%(asctime)s|%(levelname)s|%(process)d|%(module)s|%(funcName)s|%(lineno)d|%(message)s" +PROVENANCE_JSON_FILE = "provenance.json" logger = logging.getLogger(__name__) @@ -55,3 +57,19 @@ def run_shell_command(cmd: str) -> bool: except Exception: logger.exception("Exception") return False + + +def save_provenance(provenance: dict, asr_output_dir: str) -> bool: + logger.info(f"Saving provenance to: {asr_output_dir}") + try: + # write provenance.json + with open( + os.path.join(asr_output_dir, PROVENANCE_JSON_FILE), "w+", encoding="utf-8" + ) as f: + logger.info(provenance) + json.dump(provenance, f, ensure_ascii=False, indent=4) + except EnvironmentError as e: # OSError or IOError... 
+ logger.exception(os.strerror(e.errno)) + return False + + return True diff --git a/daan_transcript.py b/daan_transcript.py index 6c75526..ee77732 100644 --- a/daan_transcript.py +++ b/daan_transcript.py @@ -1,6 +1,7 @@ import json import logging import os +import time from typing import TypedDict, List, Optional from whisper import WHISPER_JSON_FILE @@ -19,12 +20,13 @@ class ParsedResult(TypedDict): # asr_output_dir e.g /data/output/whisper-test/ -def generate_daan_transcript(asr_output_dir: str) -> bool: +def generate_daan_transcript(asr_output_dir: str) -> Optional[dict]: logger.info(f"Generating transcript from: {asr_output_dir}") + start_time = time.time() whisper_transcript = load_whisper_transcript(asr_output_dir) if not whisper_transcript: logger.error("No whisper_transcript.json found") - return False + return None transcript = parse_whisper_transcript(whisper_transcript) @@ -37,9 +39,21 @@ def generate_daan_transcript(asr_output_dir: str) -> bool: json.dump(transcript, f, ensure_ascii=False, indent=4) except EnvironmentError as e: # OSError or IOError... logger.exception(os.strerror(e.errno)) - return False + return None - return True + end_time = (time.time() - start_time) * 1000 + provenance = { + "activity_name": "Whisper transcript -> DAAN transcript", + "activity_description": "Converts the output of Whisper to the DAAN index format", + "processing_time_ms": end_time, + "start_time_unix": start_time, + "parameters": [], + "software_version": "", + "input_data": whisper_transcript, + "output_data": transcript, + "steps": [], + } + return provenance def load_whisper_transcript(asr_output_dir: str) -> Optional[dict]: diff --git a/download.py b/download.py index 473ae07..7c2c35b 100644 --- a/download.py +++ b/download.py @@ -17,7 +17,8 @@ @dataclass class DownloadResult: file_path: str # target_file_path, # TODO harmonize with dane-download-worker - mime_type: str # download_data.get("mime_type", "unknown"), + mime_type: str + provenance: dict download_time: float = -1 # time (ms) taken to receive data after request content_length: int = -1 # download_data.get("content_length", -1), @@ -53,8 +54,19 @@ def http_download(url: str) -> Optional[DownloadResult]: file.write(response.content) file.close() download_time = (time.time() - start_time) * 1000 # time in ms + provenance = { + "activity_name": "Input download", + "activity_description": "Downloads the input file from INPUT_URI", + "processing_time_ms": download_time, + "start_time_unix": start_time, + "parameters": [], + "software_version": "", + "input_data": url, + "output_data": input_file, + "steps": [], + } return DownloadResult( - input_file, mime_type, download_time # TODO add content_length + input_file, mime_type, provenance, download_time # TODO add content_length ) @@ -90,6 +102,17 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]: logger.error("Failed to download input data from S3") return None download_time = (time.time() - start_time) * 1000 # time in ms + provenance = { + "activity_name": "Input download", + "activity_description": "Downloads the input file from INPUT_URI", + "processing_time_ms": download_time, + "start_time_unix": start_time, + "parameters": [], + "software_version": "", + "input_data": s3_uri, + "output_data": input_file, + "steps": [], + } return DownloadResult( - input_file, mime_type, download_time # TODO add content_length + input_file, mime_type, provenance, download_time # TODO add content_length ) diff --git a/transcode.py b/transcode.py index 63e019c..b68a252 100644 --- 
a/transcode.py
+++ b/transcode.py
@@ -1,6 +1,8 @@
+from dataclasses import dataclass
 import logging
 import os
 from typing import Optional
+import time
 
 import base_util
 from config import data_base_dir
@@ -8,15 +10,38 @@
 logger = logging.getLogger(__name__)
 
 
-def try_transcode(input_path, asset_id, extension) -> Optional[str]:
+@dataclass
+class TranscodeOutput:
+    transcoded_file_path: str
+    provenance: dict
+
+
+def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]:
     logger.info(
         f"Determining if transcode is required for input_path: {input_path} asset_id: ({asset_id}) extension: ({extension})"
     )
+    start_time = time.time()
+
+    provenance = {
+        "activity_name": "Transcoding",
+        "activity_description": "Checks if input needs transcoding, then transcodes if so",
+        "processing_time_ms": 0,
+        "start_time_unix": start_time,
+        "parameters": [],
+        "software_version": "",
+        "input_data": input_path,
+        "output_data": "",
+        "steps": [],
+    }
 
     # if it's already valid audio no transcode necessary
     if _is_audio_file(extension):
         logger.info("No transcode required, input is audio")
-        return input_path
+        end_time = (time.time() - start_time) * 1000
+        provenance["processing_time_ms"] = end_time
+        provenance["output_data"] = input_path
+        provenance["steps"].append("No transcode required, input is audio")
+        return TranscodeOutput(input_path, provenance)
 
     # if the input format is not supported, fail
     if not _is_transcodable(extension):
@@ -27,7 +52,13 @@ def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]:
     transcoded_file_path = os.path.join(data_base_dir, "input", f"{asset_id}.mp3")
     if os.path.exists(transcoded_file_path):
         logger.info("Transcoded file is already available, no new transcode needed")
-        return transcoded_file_path
+        end_time = (time.time() - start_time) * 1000
+        provenance["processing_time_ms"] = end_time
+        provenance["output_data"] = transcoded_file_path
+        provenance["steps"].append(
+            "Transcoded file is already available, no new transcode needed"
+        )
+        return TranscodeOutput(transcoded_file_path, provenance)
 
     # go ahead and transcode the input file
     success = transcode_to_mp3(
@@ -41,8 +72,11 @@ def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]:
     logger.info(
         f"Transcode of {extension} successful, returning: {transcoded_file_path}"
     )
-
-    return transcoded_file_path
+    end_time = (time.time() - start_time) * 1000
+    provenance["processing_time_ms"] = end_time
+    provenance["output_data"] = transcoded_file_path
+    provenance["steps"].append("Transcode successful")
+    return TranscodeOutput(transcoded_file_path, provenance)
 
 
 def transcode_to_mp3(path: str, asr_path: str) -> bool:
diff --git a/whisper.py b/whisper.py
index 68e5e90..e2d7d88 100644
--- a/whisper.py
+++ b/whisper.py
@@ -2,6 +2,8 @@
 import json
 import logging
 import os
+import time
+from typing import Optional
 
 import faster_whisper
 from config import (
@@ -22,8 +24,9 @@
 logger = logging.getLogger(__name__)
 
 
-def run_asr(input_path, output_dir, model=None) -> bool:
+def run_asr(input_path, output_dir, model=None) -> Optional[dict]:
     logger.info(f"Starting ASR on {input_path}")
+    start_time = time.time()
     if not model:
         logger.info(f"Device used: {w_device}")
         # checking if model needs to be downloaded from HF or not
@@ -77,8 +80,24 @@ def run_asr(input_path, output_dir, model=None) -> Optional[dict]:
     asset_id, _ = get_asset_info(input_path)
     # Also added "carrierId" because the DAAN format requires it
     transcript = {"carrierId": asset_id, "segments": segments_to_add}
+    end_time = (time.time() - start_time) * 1000  # time in ms
 
-    
return write_whisper_json(transcript, output_dir) + provenance = { + "activity_name": "Running Whisper", + "activity_description": "Runs Whisper to transcribe the input audio file", + "processing_time_ms": end_time, + "start_time_unix": start_time, + "parameters": [], + "software_version": faster_whisper.__version__, + "input_data": input_path, + "output_data": transcript, + "steps": [], + } + + if write_whisper_json(transcript, output_dir): + return provenance + else: + return None def write_whisper_json(transcript: dict, output_dir: str) -> bool: From 5b832ea0fe20bf044e05ebc6bcbd4dff1b384ebc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99=20B=C4=83lan?= <33976463+greenw0lf@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:50:59 +0200 Subject: [PATCH 2/4] Update download.py remove whitespace on empty line --- download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/download.py b/download.py index 33c682c..5a55b90 100644 --- a/download.py +++ b/download.py @@ -105,7 +105,7 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]: download_time = int((time.time() - start_time) * 1000) # time in ms else: download_time = -1 # Report back? - + provenance = { "activity_name": "Input download", "activity_description": "Downloads the input file from INPUT_URI", From cdce764d7afd31fa01a53dcae329bc16414a8728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Mon, 30 Sep 2024 11:03:13 +0200 Subject: [PATCH 3/4] Change how version is obtained Previous attempt did not work --- asr.py | 16 +++++++++++----- poetry.lock | 33 ++++++++++++++++++++++----------- pyproject.toml | 1 + transcode.py | 2 +- 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/asr.py b/asr.py index a2b76bd..4b31d22 100644 --- a/asr.py +++ b/asr.py @@ -1,7 +1,7 @@ import logging import os import time -import pkg_resources +import tomli from base_util import get_asset_info, asr_output_dir, save_provenance from config import ( @@ -25,9 +25,15 @@ logger = logging.getLogger(__name__) os.environ["HF_HOME"] = model_base_dir # change dir where model is downloaded -my_version = pkg_resources.get_distribution( - "whisper-asr-worker" -).version # get worker version + + +def _get_project_meta(): + with open('pyproject.toml', mode='rb') as pyproject: + return tomli.load(pyproject)['tool']['poetry'] + + +pkg_meta = _get_project_meta() +version = str(pkg_meta['version']) def run(input_uri: str, output_uri: str, model=None) -> bool: @@ -114,7 +120,7 @@ def run(input_uri: str, output_uri: str, model=None) -> bool: "beam_size": w_beam_size, "best_of": w_best_of, }, - "software_version": my_version, + "software_version": version, "input_data": input_uri, "output_data": output_uri if output_uri else output_path, "steps": prov_steps, diff --git a/poetry.lock b/poetry.lock index 73762e0..acc39e5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -146,17 +146,17 @@ files = [ [[package]] name = "boto3" -version = "1.35.26" +version = "1.35.29" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.35.26-py3-none-any.whl", hash = "sha256:c31db992655db233d98762612690cfe60723c9e1503b5709aad92c1c564877bb"}, - {file = "boto3-1.35.26.tar.gz", hash = "sha256:b04087afd3570ba540fd293823c77270ec675672af23da9396bd5988a3f8128b"}, + {file = "boto3-1.35.29-py3-none-any.whl", hash = "sha256:2244044cdfa8ac345d7400536dc15a4824835e7ec5c55bc267e118af66bb27db"}, + {file = "boto3-1.35.29.tar.gz", hash = 
"sha256:7bbb1ee649e09e956952285782cfdebd7e81fc78384f48dfab3d66c6eaf3f63f"}, ] [package.dependencies] -botocore = ">=1.35.26,<1.36.0" +botocore = ">=1.35.29,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -165,13 +165,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.26" +version = "1.35.29" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.26-py3-none-any.whl", hash = "sha256:0b9dee5e4a3314e251e103585837506b17fcc7485c3c8adb61a9a913f46da1e7"}, - {file = "botocore-1.35.26.tar.gz", hash = "sha256:19efc3a22c9df77960712b4e203f912486f8bcd3794bff0fd7b2a0f5f1d5712d"}, + {file = "botocore-1.35.29-py3-none-any.whl", hash = "sha256:f8e3ae0d84214eff3fb69cb4dc51cea6c43d3bde82027a94d00c52b941d6c3d5"}, + {file = "botocore-1.35.29.tar.gz", hash = "sha256:4ed28ab03675bb008a290c452c5ddd7aaa5d4e3fa1912aadbdf93057ee84362b"}, ] [package.dependencies] @@ -931,13 +931,13 @@ files = [ [[package]] name = "moto" -version = "5.0.15" +version = "5.0.16" description = "" optional = false python-versions = ">=3.8" files = [ - {file = "moto-5.0.15-py2.py3-none-any.whl", hash = "sha256:fa1e92ffb55dbfb9fa92a2115a88c32481b75aa3fbd24075d1f29af2f9becffa"}, - {file = "moto-5.0.15.tar.gz", hash = "sha256:57aa8c2af417cc64a0ddfe63e5bcd1ada90f5079b73cdd1f74c4e9fb30a1a7e6"}, + {file = "moto-5.0.16-py2.py3-none-any.whl", hash = "sha256:4ce1f34830307f7b3d553d77a7ef26066ab3b70006203d4226b048c9d11a3be4"}, + {file = "moto-5.0.16.tar.gz", hash = "sha256:f4afb176a964cd7a70da9bc5e053d43109614ce3cab26044bcbb53610435dff4"}, ] [package.dependencies] @@ -1760,6 +1760,17 @@ dev = ["tokenizers[testing]"] docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + [[package]] name = "tqdm" version = "4.66.5" @@ -1885,4 +1896,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "42ea278bed0d6ef83b60e61b4d648984a6c31113b682c3b79f6a2d5517660c7e" +content-hash = "499f96375940c9a93fadd4e6f91b91ed01b5a389bdbe7110df15c8358df9c15f" diff --git a/pyproject.toml b/pyproject.toml index bb9f202..a74c33d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ validators = "^0.33.0" boto3 = "^1.35.10" fastapi = "^0.115.0" uvicorn = "^0.30.6" +tomli = "^2.0.1" [tool.poetry.group.dev.dependencies] moto = "^5.0.13" diff --git a/transcode.py b/transcode.py index b68a252..29f3982 100644 --- a/transcode.py +++ b/transcode.py @@ -25,7 +25,7 @@ def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]: provenance = { "activity_name": "Transcoding", "activity_description": "Checks if input needs transcoding, then transcodes if so", - "processing_time_ms": 0, + "processing_time_ms": -1, "start_time_unix": start_time, "parameters": [], "software_version": "", From 15a0d12e8f1584e7ebb998bab3bd33b2a3481238 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Drago=C8=99?= Date: Mon, 30 Sep 2024 11:20:31 +0200 Subject: [PATCH 4/4] black formatting --- asr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) 
diff --git a/asr.py b/asr.py index 4b31d22..46b942f 100644 --- a/asr.py +++ b/asr.py @@ -28,12 +28,12 @@ def _get_project_meta(): - with open('pyproject.toml', mode='rb') as pyproject: - return tomli.load(pyproject)['tool']['poetry'] + with open("pyproject.toml", mode="rb") as pyproject: + return tomli.load(pyproject)["tool"]["poetry"] pkg_meta = _get_project_meta() -version = str(pkg_meta['version']) +version = str(pkg_meta["version"]) def run(input_uri: str, output_uri: str, model=None) -> bool:
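
A note for reviewers on consuming the provenance these patches emit: every record written by save_provenance() shares one shape (activity_name, activity_description, processing_time_ms, start_time_unix, parameters, software_version, input_data, output_data, steps), and records nest through "steps": the worker-level record built in asr.py contains the download, transcode, Whisper and DAAN sub-records, while the transcode record appends plain strings to its own "steps". Below is a minimal sketch of a consumer that walks that tree; the example output directory and the helper names are illustrative only, but the file name and field names come from the diffs above.

import json
import os

PROVENANCE_JSON_FILE = "provenance.json"  # same constant as in base_util.py


def load_provenance(asr_output_dir: str) -> dict:
    # save_provenance() writes provenance.json into the ASR output dir
    with open(
        os.path.join(asr_output_dir, PROVENANCE_JSON_FILE), encoding="utf-8"
    ) as f:
        return json.load(f)


def print_step_tree(prov: dict, depth: int = 0) -> None:
    # each dict-valued step is itself a full provenance record
    indent = "  " * depth
    print(f"{indent}{prov['activity_name']} ({prov['processing_time_ms']} ms)")
    for step in prov.get("steps", []):
        if isinstance(step, dict):
            print_step_tree(step, depth + 1)
        else:
            # the transcode record logs plain strings, e.g. "Transcode successful"
            print(f"{indent}  {step}")


if __name__ == "__main__":
    # hypothetical path; asr_output_dir() derives the real one from the input file
    print_step_tree(load_provenance("/data/output/whisper-test"))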