Skip to content

Commit

Permalink
Add provenance
Browse files Browse the repository at this point in the history
phew
  • Loading branch information
greenw0lf committed Sep 26, 2024
1 parent 58dc75f commit 5def84b
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 24 deletions.
95 changes: 85 additions & 10 deletions asr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
import logging
import os
import time
import pkg_resources

from base_util import get_asset_info, asr_output_dir, save_provenance
from config import (
s3_endpoint_url,
s3_bucket,
s3_folder_in_bucket,
model_base_dir,
w_word_timestamps,
w_device,
w_model,
w_beam_size,
w_best_of,
w_vad,
)

from base_util import get_asset_info, asr_output_dir
from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket, model_base_dir
from download import download_uri
from whisper import run_asr, WHISPER_JSON_FILE
from s3_util import S3Store
Expand All @@ -11,50 +25,111 @@

logger = logging.getLogger(__name__)
os.environ["HF_HOME"] = model_base_dir # change dir where model is downloaded
my_version = pkg_resources.get_distribution(
"whisper-asr-worker"
).version # get worker version


def run(input_uri: str, output_uri: str, model=None) -> bool:
logger.info(f"Processing {input_uri} (save to --> {output_uri})")
start_time = time.time()
prov_steps = [] # track provenance
# 1. download input
result = download_uri(input_uri)
logger.info(result)
if not result:
logger.error("Could not obtain input, quitting...")
return False

prov_steps.append(result.provenance)

input_path = result.file_path
asset_id, extension = get_asset_info(input_path)
output_path = asr_output_dir(input_path)

# 2. check if the input file is suitable for processing any further
transcoded_file_path = try_transcode(input_path, asset_id, extension)
if not transcoded_file_path:
transcode_output = try_transcode(input_path, asset_id, extension)
if not transcode_output:
logger.error("The transcode failed to yield a valid file to continue with")
return False
else:
input_path = transcoded_file_path
input_path = transcode_output.transcoded_file_path
prov_steps.append(transcode_output.provenance)

# 3. run ASR
if not asr_already_done(output_path):
logger.info("No Whisper transcript found")
run_asr(input_path, output_path, model)
whisper_prov = run_asr(input_path, output_path, model)
if whisper_prov:
prov_steps.append(whisper_prov)
else:
logger.info(f"Whisper transcript already present in {output_path}")
provenance = {
"activity_name": "Whisper transcript already exists",
"activity_description": "",
"processing_time_ms": "",
"start_time_unix": "",
"parameters": [],
"software_version": "",
"input_data": "",
"output_data": "",
"steps": [],
}
prov_steps.append(provenance)

# 4. generate JSON transcript
if not daan_transcript_already_done(output_path):
logger.info("No DAAN transcript found")
success = generate_daan_transcript(output_path)
if not success:
daan_prov = generate_daan_transcript(output_path)
if daan_prov:
prov_steps.append(daan_prov)
else:
logger.warning("Could not generate DAAN transcript")
else:
logger.info(f"DAAN transcript already present in {output_path}")
provenance = {
"activity_name": "DAAN transcript already exists",
"activity_description": "",
"processing_time_ms": "",
"start_time_unix": "",
"parameters": [],
"software_version": "",
"input_data": "",
"output_data": "",
"steps": [],
}
prov_steps.append(provenance)

end_time = (time.time() - start_time) * 1000
final_prov = {
"activity_name": "Whisper ASR Worker",
"activity_description": "Worker that gets a video/audio file as input and outputs JSON transcripts in various formats",
"processing_time_ms": end_time,
"start_time_unix": start_time,
"parameters": {
"word_timestamps": w_word_timestamps,
"device": w_device,
"vad": w_vad,
"model": w_model,
"beam_size": w_beam_size,
"best_of": w_best_of,
},
"software_version": my_version,
"input_data": input_uri,
"output_data": output_uri if output_uri else output_path,
"steps": prov_steps,
}

prov_success = save_provenance(final_prov, output_path)
if not prov_success:
logger.warning("Could not save the provenance")

# 5. transfer output
if output_uri:
transfer_asr_output(output_path, asset_id)
else:
logger.info("No output_uri specified, so all is done")

return True


Expand Down Expand Up @@ -90,14 +165,14 @@ def transfer_asr_output(output_path: str, asset_id: str) -> bool:


# check if there is a whisper-transcript.json
def asr_already_done(output_dir):
def asr_already_done(output_dir) -> bool:
whisper_transcript = os.path.join(output_dir, WHISPER_JSON_FILE)
logger.info(f"Checking existence of {whisper_transcript}")
return os.path.exists(os.path.join(output_dir, WHISPER_JSON_FILE))


# check if there is a daan-es-transcript.json
def daan_transcript_already_done(output_dir):
def daan_transcript_already_done(output_dir) -> bool:
daan_transcript = os.path.join(output_dir, DAAN_JSON_FILE)
logger.info(f"Checking existence of {daan_transcript}")
return os.path.exists(os.path.join(output_dir, DAAN_JSON_FILE))
18 changes: 18 additions & 0 deletions base_util.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
import os
import subprocess
import json
from typing import Tuple
from config import data_base_dir


LOG_FORMAT = "%(asctime)s|%(levelname)s|%(process)d|%(module)s|%(funcName)s|%(lineno)d|%(message)s"
PROVENANCE_JSON_FILE = "provenance.json"
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -55,3 +57,19 @@ def run_shell_command(cmd: str) -> bool:
except Exception:
logger.exception("Exception")
return False


def save_provenance(provenance: dict, asr_output_dir: str) -> bool:
logger.info(f"Saving provenance to: {asr_output_dir}")
try:
# write provenance.json
with open(
os.path.join(asr_output_dir, PROVENANCE_JSON_FILE), "w+", encoding="utf-8"
) as f:
logger.info(provenance)
json.dump(provenance, f, ensure_ascii=False, indent=4)
except EnvironmentError as e: # OSError or IOError...
logger.exception(os.strerror(e.errno))
return False

return True
22 changes: 18 additions & 4 deletions daan_transcript.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import time
from typing import TypedDict, List, Optional
from whisper import WHISPER_JSON_FILE

Expand All @@ -19,12 +20,13 @@ class ParsedResult(TypedDict):


# asr_output_dir e.g /data/output/whisper-test/
def generate_daan_transcript(asr_output_dir: str) -> bool:
def generate_daan_transcript(asr_output_dir: str) -> Optional[dict]:
logger.info(f"Generating transcript from: {asr_output_dir}")
start_time = time.time()
whisper_transcript = load_whisper_transcript(asr_output_dir)
if not whisper_transcript:
logger.error("No whisper_transcript.json found")
return False
return None

transcript = parse_whisper_transcript(whisper_transcript)

Expand All @@ -37,9 +39,21 @@ def generate_daan_transcript(asr_output_dir: str) -> bool:
json.dump(transcript, f, ensure_ascii=False, indent=4)
except EnvironmentError as e: # OSError or IOError...
logger.exception(os.strerror(e.errno))
return False
return None

return True
end_time = (time.time() - start_time) * 1000
provenance = {
"activity_name": "Whisper transcript -> DAAN transcript",
"activity_description": "Converts the output of Whisper to the DAAN index format",
"processing_time_ms": end_time,
"start_time_unix": start_time,
"parameters": [],
"software_version": "",
"input_data": whisper_transcript,
"output_data": transcript,
"steps": [],
}
return provenance


def load_whisper_transcript(asr_output_dir: str) -> Optional[dict]:
Expand Down
29 changes: 26 additions & 3 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
@dataclass
class DownloadResult:
file_path: str # target_file_path, # TODO harmonize with dane-download-worker
mime_type: str # download_data.get("mime_type", "unknown"),
mime_type: str
provenance: dict
download_time: float = -1 # time (ms) taken to receive data after request
content_length: int = -1 # download_data.get("content_length", -1),

Expand Down Expand Up @@ -53,8 +54,19 @@ def http_download(url: str) -> Optional[DownloadResult]:
file.write(response.content)
file.close()
download_time = (time.time() - start_time) * 1000 # time in ms
provenance = {
"activity_name": "Input download",
"activity_description": "Downloads the input file from INPUT_URI",
"processing_time_ms": download_time,
"start_time_unix": start_time,
"parameters": [],
"software_version": "",
"input_data": url,
"output_data": input_file,
"steps": [],
}
return DownloadResult(
input_file, mime_type, download_time # TODO add content_length
input_file, mime_type, provenance, download_time # TODO add content_length
)


Expand Down Expand Up @@ -90,6 +102,17 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]:
logger.error("Failed to download input data from S3")
return None
download_time = (time.time() - start_time) * 1000 # time in ms
provenance = {
"activity_name": "Input download",
"activity_description": "Downloads the input file from INPUT_URI",
"processing_time_ms": download_time,
"start_time_unix": start_time,
"parameters": [],
"software_version": "",
"input_data": s3_uri,
"output_data": input_file,
"steps": [],
}
return DownloadResult(
input_file, mime_type, download_time # TODO add content_length
input_file, mime_type, provenance, download_time # TODO add content_length
)
44 changes: 39 additions & 5 deletions transcode.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,47 @@
from dataclasses import dataclass
import logging
import os
from typing import Optional
import time

import base_util
from config import data_base_dir

logger = logging.getLogger(__name__)


def try_transcode(input_path, asset_id, extension) -> Optional[str]:
@dataclass
class TranscodeOutput:
transcoded_file_path: str
provenance: dict


def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]:
logger.info(
f"Determining if transcode is required for input_path: {input_path} asset_id: ({asset_id}) extension: ({extension})"
)
start_time = time.time()

provenance = {
"activity_name": "Transcoding",
"activity_description": "Checks if input needs transcoding, then transcodes if so",
"processing_time_ms": 0,
"start_time_unix": start_time,
"parameters": [],
"software_version": "",
"input_data": input_path,
"output_data": "",
"steps": [],
}

# if it's alrady valid audio no transcode necessary
if _is_audio_file(extension):
logger.info("No transcode required, input is audio")
return input_path
end_time = (time.time() - start_time) * 1000
provenance["processing_time_ms"] = end_time
provenance["output_data"] = input_path
provenance["steps"].append("No transcode required, input is audio")
return TranscodeOutput(input_path, provenance)

# if the input format is not supported, fail
if not _is_transcodable(extension):
Expand All @@ -27,7 +52,13 @@ def try_transcode(input_path, asset_id, extension) -> Optional[str]:
transcoded_file_path = os.path.join(data_base_dir, "input", f"{asset_id}.mp3")
if os.path.exists(transcoded_file_path):
logger.info("Transcoded file is already available, no new transcode needed")
return transcoded_file_path
end_time = (time.time() - start_time) * 1000
provenance["processing_time_ms"] = end_time
provenance["output_data"] = transcoded_file_path
provenance["steps"].append(
"Transcoded file is already available, no new transcode needed"
)
return TranscodeOutput(transcoded_file_path, provenance)

# go ahead and transcode the input file
success = transcode_to_mp3(
Expand All @@ -41,8 +72,11 @@ def try_transcode(input_path, asset_id, extension) -> Optional[str]:
logger.info(
f"Transcode of {extension} successful, returning: {transcoded_file_path}"
)

return transcoded_file_path
end_time = (time.time() - start_time) * 1000
provenance["processing_time_ms"] = end_time
provenance["output_data"] = transcoded_file_path
provenance["steps"].append("Transcode successful")
return TranscodeOutput(transcoded_file_path, provenance)


def transcode_to_mp3(path: str, asr_path: str) -> bool:
Expand Down
Loading

0 comments on commit 5def84b

Please sign in to comment.