Provenance #101

Merged (5 commits) on Oct 1, 2024

asr.py (101 changes: 91 additions & 10 deletions)
@@ -1,8 +1,22 @@
import logging
import os
import time
import tomli

from base_util import get_asset_info, asr_output_dir, save_provenance
from config import (
    s3_endpoint_url,
    s3_bucket,
    s3_folder_in_bucket,
    model_base_dir,
    w_word_timestamps,
    w_device,
    w_model,
    w_beam_size,
    w_best_of,
    w_vad,
)
from download import download_uri
from whisper import run_asr, WHISPER_JSON_FILE
from s3_util import S3Store
@@ -13,48 +27,115 @@
os.environ["HF_HOME"] = model_base_dir # change dir where model is downloaded


def _get_project_meta():
    # read package metadata from pyproject.toml (Poetry layout)
    with open("pyproject.toml", mode="rb") as pyproject:
        return tomli.load(pyproject)["tool"]["poetry"]


pkg_meta = _get_project_meta()
version = str(pkg_meta["version"])


def run(input_uri: str, output_uri: str, model=None) -> bool:
    logger.info(f"Processing {input_uri} (save to --> {output_uri})")
    start_time = time.time()
    prov_steps = []  # track provenance
    # 1. download input
    result = download_uri(input_uri)
    logger.info(result)
    if not result:
        logger.error("Could not obtain input, quitting...")
        return False

    prov_steps.append(result.provenance)

    input_path = result.file_path
    asset_id, extension = get_asset_info(input_path)
    output_path = asr_output_dir(input_path)

    # 2. check if the input file is suitable for processing any further
    transcode_output = try_transcode(input_path, asset_id, extension)
    if not transcode_output:
        logger.error("The transcode failed to yield a valid file to continue with")
        return False
    else:
        input_path = transcode_output.transcoded_file_path
        prov_steps.append(transcode_output.provenance)

    # 3. run ASR
    if not asr_already_done(output_path):
        logger.info("No Whisper transcript found")
        whisper_prov = run_asr(input_path, output_path, model)
        if whisper_prov:
            prov_steps.append(whisper_prov)
    else:
        logger.info(f"Whisper transcript already present in {output_path}")
        provenance = {
            "activity_name": "Whisper transcript already exists",
[Review comment on the line above] Interesting point for later discussion with the team as to whether we model this as a different activity, or as a different output of the Whisper speech processing activity. There's something to be said for both.

"activity_description": "",
"processing_time_ms": "",
"start_time_unix": "",
"parameters": [],
"software_version": "",
"input_data": "",
"output_data": "",
"steps": [],
}
prov_steps.append(provenance)

    # 4. generate JSON transcript
    if not daan_transcript_already_done(output_path):
        logger.info("No DAAN transcript found")
        daan_prov = generate_daan_transcript(output_path)
        if daan_prov:
            prov_steps.append(daan_prov)
        else:
            logger.warning("Could not generate DAAN transcript")
    else:
        logger.info(f"DAAN transcript already present in {output_path}")
        provenance = {
            "activity_name": "DAAN transcript already exists",
            "activity_description": "",
            "processing_time_ms": "",
            "start_time_unix": "",
            "parameters": [],
            "software_version": "",
            "input_data": "",
            "output_data": "",
            "steps": [],
        }
        prov_steps.append(provenance)

    end_time = (time.time() - start_time) * 1000
    final_prov = {
        "activity_name": "Whisper ASR Worker",
        "activity_description": "Worker that gets a video/audio file as input and outputs JSON transcripts in various formats",
        "processing_time_ms": end_time,
        "start_time_unix": start_time,
        "parameters": {
            "word_timestamps": w_word_timestamps,
            "device": w_device,
            "vad": w_vad,
            "model": w_model,
            "beam_size": w_beam_size,
            "best_of": w_best_of,
        },
        "software_version": version,
        "input_data": input_uri,
        "output_data": output_uri if output_uri else output_path,
        "steps": prov_steps,
    }

    prov_success = save_provenance(final_prov, output_path)
    if not prov_success:
        logger.warning("Could not save the provenance")

    # 5. transfer output
    if output_uri:
        transfer_asr_output(output_path, asset_id)
    else:
        logger.info("No output_uri specified, so all is done")

    return True
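
For orientation, a minimal invocation sketch (the URIs are illustrative, not taken from this PR):

    # hypothetical call: download the input, run ASR, write whisper-transcript.json,
    # the DAAN transcript and provenance.json, then transfer the output to S3
    run("s3://example-bucket/input/clip.mp3", "s3://example-bucket/output")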


@@ -90,14 +171,14 @@ def transfer_asr_output(output_path: str, asset_id: str) -> bool:


# check if there is a whisper-transcript.json
def asr_already_done(output_dir) -> bool:
    whisper_transcript = os.path.join(output_dir, WHISPER_JSON_FILE)
    logger.info(f"Checking existence of {whisper_transcript}")
    return os.path.exists(whisper_transcript)


# check if there is a daan-es-transcript.json
def daan_transcript_already_done(output_dir) -> bool:
    daan_transcript = os.path.join(output_dir, DAAN_JSON_FILE)
    logger.info(f"Checking existence of {daan_transcript}")
    return os.path.exists(daan_transcript)

base_util.py (18 changes: 18 additions & 0 deletions)
@@ -1,11 +1,13 @@
import logging
import os
import subprocess
import json
from typing import Tuple
from config import data_base_dir


LOG_FORMAT = "%(asctime)s|%(levelname)s|%(process)d|%(module)s|%(funcName)s|%(lineno)d|%(message)s"
PROVENANCE_JSON_FILE = "provenance.json"
logger = logging.getLogger(__name__)


@@ -55,3 +57,19 @@ def run_shell_command(cmd: str) -> bool:
    except Exception:
        logger.exception("Exception")
        return False


def save_provenance(provenance: dict, asr_output_dir: str) -> bool:
    logger.info(f"Saving provenance to: {asr_output_dir}")
    try:
        # write provenance.json
        with open(
            os.path.join(asr_output_dir, PROVENANCE_JSON_FILE), "w+", encoding="utf-8"
        ) as f:
            logger.info(provenance)
            json.dump(provenance, f, ensure_ascii=False, indent=4)
    except EnvironmentError as e:  # OSError or IOError...
        logger.exception(os.strerror(e.errno))
        return False

    return True
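
Taken together with run() in asr.py, the provenance.json this writes ends up with a nested shape along these lines (an illustrative, trimmed sketch; all values are made up):

    # illustrative provenance.json content, shown as a Python literal
    example_prov = {
        "activity_name": "Whisper ASR Worker",
        "processing_time_ms": 5321.7,  # made-up duration
        "software_version": "0.1.0",  # read from pyproject.toml via tomli
        "input_data": "s3://example-bucket/input/clip.mp3",
        "output_data": "s3://example-bucket/output",
        "steps": [  # one dict per step, trimmed here
            {"activity_name": "Input download"},
            {"activity_name": "Whisper transcript -> DAAN transcript"},
        ],
    }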

daan_transcript.py (22 changes: 18 additions & 4 deletions)
@@ -1,6 +1,7 @@
import json
import logging
import os
import time
from typing import TypedDict, List, Optional
from whisper import WHISPER_JSON_FILE

@@ -19,12 +20,13 @@


# asr_output_dir e.g /data/output/whisper-test/
def generate_daan_transcript(asr_output_dir: str) -> Optional[dict]:
    logger.info(f"Generating transcript from: {asr_output_dir}")
    start_time = time.time()
    whisper_transcript = load_whisper_transcript(asr_output_dir)
    if not whisper_transcript:
        logger.error("No whisper_transcript.json found")
        return None

    transcript = parse_whisper_transcript(whisper_transcript)

@@ -37,9 +39,21 @@ def generate_daan_transcript(asr_output_dir: str) -> bool:
            json.dump(transcript, f, ensure_ascii=False, indent=4)
    except EnvironmentError as e:  # OSError or IOError...
        logger.exception(os.strerror(e.errno))
        return None

    end_time = (time.time() - start_time) * 1000
    provenance = {
        "activity_name": "Whisper transcript -> DAAN transcript",
        "activity_description": "Converts the output of Whisper to the DAAN index format",
        "processing_time_ms": end_time,
        "start_time_unix": start_time,
        "parameters": [],
        "software_version": "",
        "input_data": whisper_transcript,
        "output_data": transcript,
        "steps": [],
    }
    return provenance
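
Callers now receive a provenance dict instead of a bool, with None signalling failure; a minimal sketch of the new contract (the directory path is illustrative):

    # None means failure; a dict is the provenance step to record
    daan_prov = generate_daan_transcript("/data/output/whisper-test")
    if daan_prov is not None:
        print(daan_prov["activity_name"])  # Whisper transcript -> DAAN transcript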


def load_whisper_transcript(asr_output_dir: str) -> Optional[dict]:

download.py (32 changes: 29 additions & 3 deletions)
@@ -17,7 +17,8 @@
@dataclass
class DownloadResult:
    file_path: str  # target_file_path, # TODO harmonize with dane-download-worker
    mime_type: str
    provenance: dict
    download_time: float = -1  # time (ms) taken to receive data after request
    content_length: int = -1  # download_data.get("content_length", -1),
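
With provenance added as a positional field ahead of the defaulted ones, results are now constructed like this (values illustrative; content_length keeps its -1 default):

    result = DownloadResult(
        "/data/input/clip.mp3",  # file_path
        "audio/mpeg",  # mime_type
        {"activity_name": "Input download"},  # provenance, trimmed
        123.4,  # download_time in ms
    )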

@@ -53,8 +54,19 @@ def http_download(url: str) -> Optional[DownloadResult]:
    file.write(response.content)
    file.close()
    download_time = (time.time() - start_time) * 1000  # time in ms
    provenance = {
        "activity_name": "Input download",
        "activity_description": "Downloads the input file from INPUT_URI",
        "processing_time_ms": download_time,
        "start_time_unix": start_time,
        "parameters": [],
        "software_version": "",
        "input_data": url,
        "output_data": input_file,
        "steps": [],
    }
    return DownloadResult(
        input_file, mime_type, provenance, download_time  # TODO add content_length
    )


@@ -89,9 +101,23 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]:
        if not success:
            logger.error("Failed to download input data from S3")
            return None

        download_time = int((time.time() - start_time) * 1000)  # time in ms
    else:
        download_time = -1  # Report back?

    provenance = {
        "activity_name": "Input download",
        "activity_description": "Downloads the input file from INPUT_URI",
        "processing_time_ms": download_time,
        "start_time_unix": start_time,
        "parameters": [],
        "software_version": "",
        "input_data": s3_uri,
        "output_data": input_file,
        "steps": [],
    }

    return DownloadResult(
        input_file, mime_type, provenance, download_time  # TODO add content_length
    )

poetry.lock (33 changes: 22 additions & 11 deletions)

Some generated files are not rendered by default.

pyproject.toml (1 change: 1 addition & 0 deletions)

@@ -14,6 +14,7 @@ validators = "^0.33.0"
boto3 = "^1.35.10"
fastapi = "^0.115.0"
uvicorn = "^0.30.6"
tomli = "^2.0.1"

[tool.poetry.group.dev.dependencies]
moto = "^5.0.13"