Provenance #101

Merged: 5 commits (Oct 1, 2024)
Changes from 3 commits
95 changes: 85 additions & 10 deletions asr.py
@@ -1,8 +1,22 @@
 import logging
 import os
 import time
+import pkg_resources
+
-from base_util import get_asset_info, asr_output_dir
-from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket, model_base_dir
+from base_util import get_asset_info, asr_output_dir, save_provenance
+from config import (
+    s3_endpoint_url,
+    s3_bucket,
+    s3_folder_in_bucket,
+    model_base_dir,
+    w_word_timestamps,
+    w_device,
+    w_model,
+    w_beam_size,
+    w_best_of,
+    w_vad,
+)
 
 from download import download_uri
 from whisper import run_asr, WHISPER_JSON_FILE
 from s3_util import S3Store
@@ -11,50 +25,111 @@

 logger = logging.getLogger(__name__)
 os.environ["HF_HOME"] = model_base_dir  # change dir where model is downloaded
+my_version = pkg_resources.get_distribution(
+    "whisper-asr-worker"
+).version  # get worker version

[Review comment]
This is really interesting. How does this work to get the worker version? We have previously struggled with trying to put in a GitHub version, but we had trouble getting that into the Dockerfile.

[Author reply (Collaborator)]
This approach doesn't put in the GitHub version; it simply obtains the version from the pyproject.toml file. It also wasn't actually working, so I'm replacing it with a different approach that actually reads pyproject.toml and outputs what is written in the version field.
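A version lookup that reads pyproject.toml directly could look roughly like the sketch below. This is an illustration, not the code from the follow-up commit; the helper name is made up, and it assumes Python 3.11+ (for the stdlib tomllib) plus a Poetry-style layout where the version lives under [tool.poetry]. With a [project] table it would be pyproject["project"]["version"] instead.

import tomllib  # stdlib TOML parser, Python 3.11+


def get_worker_version(pyproject_path: str = "pyproject.toml") -> str:
    # Read the version field straight from pyproject.toml instead of relying
    # on installed package metadata (which is what pkg_resources inspects).
    try:
        with open(pyproject_path, "rb") as f:
            pyproject = tomllib.load(f)
        return pyproject["tool"]["poetry"]["version"]
    except (FileNotFoundError, KeyError):
        return "unknown"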


 def run(input_uri: str, output_uri: str, model=None) -> bool:
     logger.info(f"Processing {input_uri} (save to --> {output_uri})")
+    start_time = time.time()
+    prov_steps = []  # track provenance
     # 1. download input
     result = download_uri(input_uri)
     logger.info(result)
     if not result:
         logger.error("Could not obtain input, quitting...")
         return False
 
+    prov_steps.append(result.provenance)
+
     input_path = result.file_path
     asset_id, extension = get_asset_info(input_path)
     output_path = asr_output_dir(input_path)

     # 2. check if the input file is suitable for processing any further
-    transcoded_file_path = try_transcode(input_path, asset_id, extension)
-    if not transcoded_file_path:
+    transcode_output = try_transcode(input_path, asset_id, extension)
+    if not transcode_output:
         logger.error("The transcode failed to yield a valid file to continue with")
         return False
     else:
-        input_path = transcoded_file_path
+        input_path = transcode_output.transcoded_file_path
+        prov_steps.append(transcode_output.provenance)

     # 3. run ASR
     if not asr_already_done(output_path):
         logger.info("No Whisper transcript found")
-        run_asr(input_path, output_path, model)
+        whisper_prov = run_asr(input_path, output_path, model)
+        if whisper_prov:
+            prov_steps.append(whisper_prov)
     else:
         logger.info(f"Whisper transcript already present in {output_path}")
+        provenance = {
+            "activity_name": "Whisper transcript already exists",
+            "activity_description": "",
+            "processing_time_ms": "",
+            "start_time_unix": "",
+            "parameters": [],
+            "software_version": "",
+            "input_data": "",
+            "output_data": "",
+            "steps": [],
+        }
+        prov_steps.append(provenance)

[Review comment on the "Whisper transcript already exists" entry]
Interesting point for later discussion with the team: do we model this as a different activity, or as a different output of the Whisper speech processing activity? There's something to be said for both.

     # 4. generate JSON transcript
     if not daan_transcript_already_done(output_path):
         logger.info("No DAAN transcript found")
-        success = generate_daan_transcript(output_path)
-        if not success:
+        daan_prov = generate_daan_transcript(output_path)
+        if daan_prov:
+            prov_steps.append(daan_prov)
+        else:
             logger.warning("Could not generate DAAN transcript")
     else:
         logger.info(f"DAAN transcript already present in {output_path}")
+        provenance = {
+            "activity_name": "DAAN transcript already exists",
+            "activity_description": "",
+            "processing_time_ms": "",
+            "start_time_unix": "",
+            "parameters": [],
+            "software_version": "",
+            "input_data": "",
+            "output_data": "",
+            "steps": [],
+        }
+        prov_steps.append(provenance)

+    end_time = (time.time() - start_time) * 1000
+    final_prov = {
+        "activity_name": "Whisper ASR Worker",
+        "activity_description": "Worker that gets a video/audio file as input and outputs JSON transcripts in various formats",
+        "processing_time_ms": end_time,
+        "start_time_unix": start_time,
+        "parameters": {
+            "word_timestamps": w_word_timestamps,
+            "device": w_device,
+            "vad": w_vad,
+            "model": w_model,
+            "beam_size": w_beam_size,
+            "best_of": w_best_of,
+        },
+        "software_version": my_version,
+        "input_data": input_uri,
+        "output_data": output_uri if output_uri else output_path,
+        "steps": prov_steps,
+    }
+
+    prov_success = save_provenance(final_prov, output_path)
+    if not prov_success:
+        logger.warning("Could not save the provenance")

     # 5. transfer output
     if output_uri:
         transfer_asr_output(output_path, asset_id)
     else:
         logger.info("No output_uri specified, so all is done")
 
     return True


@@ -90,14 +165,14 @@ def transfer_asr_output(output_path: str, asset_id: str) -> bool:


 # check if there is a whisper-transcript.json
-def asr_already_done(output_dir):
+def asr_already_done(output_dir) -> bool:
     whisper_transcript = os.path.join(output_dir, WHISPER_JSON_FILE)
     logger.info(f"Checking existence of {whisper_transcript}")
     return os.path.exists(os.path.join(output_dir, WHISPER_JSON_FILE))
 
 
 # check if there is a daan-es-transcript.json
-def daan_transcript_already_done(output_dir):
+def daan_transcript_already_done(output_dir) -> bool:
     daan_transcript = os.path.join(output_dir, DAAN_JSON_FILE)
     logger.info(f"Checking existence of {daan_transcript}")
     return os.path.exists(os.path.join(output_dir, DAAN_JSON_FILE))
18 changes: 18 additions & 0 deletions base_util.py
@@ -1,11 +1,13 @@
 import logging
 import os
 import subprocess
+import json
 from typing import Tuple
 from config import data_base_dir
 
 
 LOG_FORMAT = "%(asctime)s|%(levelname)s|%(process)d|%(module)s|%(funcName)s|%(lineno)d|%(message)s"
+PROVENANCE_JSON_FILE = "provenance.json"
 logger = logging.getLogger(__name__)

@@ -55,3 +57,19 @@ def run_shell_command(cmd: str) -> bool:
     except Exception:
         logger.exception("Exception")
         return False
+
+
+def save_provenance(provenance: dict, asr_output_dir: str) -> bool:
+    logger.info(f"Saving provenance to: {asr_output_dir}")
+    try:
+        # write provenance.json
+        with open(
+            os.path.join(asr_output_dir, PROVENANCE_JSON_FILE), "w+", encoding="utf-8"
+        ) as f:
+            logger.info(provenance)
+            json.dump(provenance, f, ensure_ascii=False, indent=4)
+    except EnvironmentError as e:  # OSError or IOError...
+        logger.exception(os.strerror(e.errno))
+        return False
+
+    return True
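For reference, the provenance.json this writes after a full run is the final_prov dict from asr.py, with the per-step dicts collected along the way nested under steps. A purely illustrative, abridged example follows (all values invented; the exact shape of the Whisper step comes from run_asr in whisper.py, which is not part of this diff):

{
    "activity_name": "Whisper ASR Worker",
    "activity_description": "Worker that gets a video/audio file as input and outputs JSON transcripts in various formats",
    "processing_time_ms": 84210.7,
    "start_time_unix": 1727772000.0,
    "parameters": {
        "word_timestamps": true,
        "device": "cuda",
        "vad": true,
        "model": "large-v2",
        "beam_size": 5,
        "best_of": 5
    },
    "software_version": "0.1.0",
    "input_data": "http://example.com/input.mp4",
    "output_data": "s3://example-bucket/folder/asset-id",
    "steps": [
        {"activity_name": "Input download", "processing_time_ms": 1200.3},
        {"activity_name": "Transcoding", "processing_time_ms": 5300.9},
        {"activity_name": "Whisper transcript -> DAAN transcript", "processing_time_ms": 80.5}
    ]
}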
22 changes: 18 additions & 4 deletions daan_transcript.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import time
 from typing import TypedDict, List, Optional
 from whisper import WHISPER_JSON_FILE
 
@@ -19,12 +20,13 @@ class ParsedResult(TypedDict):


 # asr_output_dir e.g /data/output/whisper-test/
-def generate_daan_transcript(asr_output_dir: str) -> bool:
+def generate_daan_transcript(asr_output_dir: str) -> Optional[dict]:
     logger.info(f"Generating transcript from: {asr_output_dir}")
+    start_time = time.time()
     whisper_transcript = load_whisper_transcript(asr_output_dir)
     if not whisper_transcript:
         logger.error("No whisper_transcript.json found")
-        return False
+        return None
 
     transcript = parse_whisper_transcript(whisper_transcript)
 
@@ -37,9 +39,21 @@ def generate_daan_transcript(asr_output_dir: str) -> bool:
         json.dump(transcript, f, ensure_ascii=False, indent=4)
     except EnvironmentError as e:  # OSError or IOError...
         logger.exception(os.strerror(e.errno))
-        return False
+        return None
 
-    return True
+    end_time = (time.time() - start_time) * 1000
+    provenance = {
+        "activity_name": "Whisper transcript -> DAAN transcript",
+        "activity_description": "Converts the output of Whisper to the DAAN index format",
+        "processing_time_ms": end_time,
+        "start_time_unix": start_time,
+        "parameters": [],
+        "software_version": "",
+        "input_data": whisper_transcript,
+        "output_data": transcript,
+        "steps": [],
+    }
+    return provenance


def load_whisper_transcript(asr_output_dir: str) -> Optional[dict]:
32 changes: 29 additions & 3 deletions download.py
@@ -17,7 +17,8 @@
 @dataclass
 class DownloadResult:
     file_path: str  # target_file_path, # TODO harmonize with dane-download-worker
-    mime_type: str  # download_data.get("mime_type", "unknown"),
+    mime_type: str
+    provenance: dict
     download_time: float = -1  # time (ms) taken to receive data after request
     content_length: int = -1  # download_data.get("content_length", -1),

@@ -53,8 +54,19 @@ def http_download(url: str) -> Optional[DownloadResult]:
     file.write(response.content)
     file.close()
     download_time = (time.time() - start_time) * 1000  # time in ms
+    provenance = {
+        "activity_name": "Input download",
+        "activity_description": "Downloads the input file from INPUT_URI",
+        "processing_time_ms": download_time,
+        "start_time_unix": start_time,
+        "parameters": [],
+        "software_version": "",
+        "input_data": url,
+        "output_data": input_file,
+        "steps": [],
+    }
     return DownloadResult(
-        input_file, mime_type, download_time  # TODO add content_length
+        input_file, mime_type, provenance, download_time  # TODO add content_length
     )


@@ -89,9 +101,23 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]:
         if not success:
             logger.error("Failed to download input data from S3")
             return None
 
         download_time = int((time.time() - start_time) * 1000)  # time in ms
     else:
         download_time = -1  # Report back?
 
+    provenance = {
+        "activity_name": "Input download",
+        "activity_description": "Downloads the input file from INPUT_URI",
+        "processing_time_ms": download_time,
+        "start_time_unix": start_time,
+        "parameters": [],
+        "software_version": "",
+        "input_data": s3_uri,
+        "output_data": input_file,
+        "steps": [],
+    }
+
     return DownloadResult(
-        input_file, mime_type, download_time  # TODO add content_length
+        input_file, mime_type, provenance, download_time  # TODO add content_length
     )
44 changes: 39 additions & 5 deletions transcode.py
@@ -1,22 +1,47 @@
+from dataclasses import dataclass
 import logging
 import os
 from typing import Optional
+import time
 
 import base_util
 from config import data_base_dir
 
 logger = logging.getLogger(__name__)
 
 
-def try_transcode(input_path, asset_id, extension) -> Optional[str]:
+@dataclass
+class TranscodeOutput:
+    transcoded_file_path: str
+    provenance: dict
+
+
+def try_transcode(input_path, asset_id, extension) -> Optional[TranscodeOutput]:
     logger.info(
         f"Determining if transcode is required for input_path: {input_path} asset_id: ({asset_id}) extension: ({extension})"
     )
+    start_time = time.time()
+
+    provenance = {
+        "activity_name": "Transcoding",
+        "activity_description": "Checks if input needs transcoding, then transcodes if so",
+        "processing_time_ms": 0,
+        "start_time_unix": start_time,
+        "parameters": [],
+        "software_version": "",
+        "input_data": input_path,
+        "output_data": "",
+        "steps": [],
+    }

     # if it's already valid audio, no transcode is necessary
     if _is_audio_file(extension):
         logger.info("No transcode required, input is audio")
-        return input_path
+        end_time = (time.time() - start_time) * 1000
+        provenance["processing_time_ms"] = end_time
+        provenance["output_data"] = input_path
+        provenance["steps"].append("No transcode required, input is audio")
+        return TranscodeOutput(input_path, provenance)

[Review comment]
Here again we need to think (in the future) a bit more clearly as a team about how we want to communicate steps that were skipped (but not failed).

     # if the input format is not supported, fail
     if not _is_transcodable(extension):
@@ -27,7 +52,13 @@ def try_transcode(input_path, asset_id, extension) -> Optional[str]:
     transcoded_file_path = os.path.join(data_base_dir, "input", f"{asset_id}.mp3")
     if os.path.exists(transcoded_file_path):
         logger.info("Transcoded file is already available, no new transcode needed")
-        return transcoded_file_path
+        end_time = (time.time() - start_time) * 1000
+        provenance["processing_time_ms"] = end_time
+        provenance["output_data"] = transcoded_file_path
+        provenance["steps"].append(
+            "Transcoded file is already available, no new transcode needed"
+        )
+        return TranscodeOutput(transcoded_file_path, provenance)

     # go ahead and transcode the input file
     success = transcode_to_mp3(
@@ -41,8 +72,11 @@ def try_transcode(input_path, asset_id, extension) -> Optional[str]:
     logger.info(
         f"Transcode of {extension} successful, returning: {transcoded_file_path}"
     )
-
-    return transcoded_file_path
+    end_time = (time.time() - start_time) * 1000
+    provenance["processing_time_ms"] = end_time
+    provenance["output_data"] = transcoded_file_path
+    provenance["steps"].append("Transcode successful")
+    return TranscodeOutput(transcoded_file_path, provenance)


def transcode_to_mp3(path: str, asr_path: str) -> bool: