Whisper worker as a service (Jobs don't work yet in airflow) #96

Merged: 59 commits, Oct 2, 2024
Commits (59)
1c6b113
Moved source to src folder
Aug 5, 2024
53f951c
Init for src
Aug 5, 2024
2574d4c
Setting up Flask app, WIP
Aug 5, 2024
828c562
Add Flask dependencies to dedicated dependency group
Aug 5, 2024
5fb42db
Renamed app folder to service
Aug 5, 2024
c2d76dd
Moved worker.py outside src
Aug 5, 2024
529fd32
Fixed structure & imports
Aug 5, 2024
dd6db72
Add callable for service
Aug 5, 2024
c70ab26
rename api module
Aug 5, 2024
658862f
Make src a package to be installed
Aug 5, 2024
fe60573
Update README.md
Veldhoen Aug 8, 2024
d5e8727
Make service dependency group optional
Aug 8, 2024
5441450
Renamed app to service
Aug 8, 2024
9d2737f
Renamed app to service
Aug 8, 2024
b0819c3
Started API definition
Aug 8, 2024
2f9fe1b
Merge branch '52-setup-service-decouple-from-dane' of github.com:beel…
Aug 8, 2024
480e109
Add actual content to api definition
Aug 8, 2024
b5a5ad9
Renamed to api, becaus it is strictly not restful
Aug 8, 2024
d349849
merged with main
jblom Sep 23, 2024
7d42d31
merged with main; restructured; server starts; just fix some flask error
jblom Sep 23, 2024
db33f9e
now it also runs;
jblom Sep 23, 2024
27fc9e5
Updated title and description for API
Sep 23, 2024
b40cf63
Actually run service
Sep 23, 2024
d6b21a8
Fixed api model
Sep 24, 2024
0fffc15
Request works, now start debugging the function
Sep 24, 2024
bf95c7f
Remove unnecessary prefix from path
Sep 24, 2024
ab6cdd3
Wired up API, figuring out how to start async processing now
Sep 24, 2024
dd0fab7
Last attempt
Sep 24, 2024
40b80f2
migrated to fastapi, which has a BackGroundTasks mechanism built in
jblom Sep 25, 2024
2ae6af4
readded oddily deleted whisper_api.py
jblom Sep 25, 2024
713dfc5
moved code back in
jblom Sep 25, 2024
a577524
added Status; added ping; removed old health_check endpoint
jblom Sep 25, 2024
9f20bf0
added port and log arguments to whisper_api.py
jblom Sep 25, 2024
bdd5dc7
adapted main.py to either start job or service
jblom Sep 25, 2024
6690ccd
disable temperature because of runtime bug
jblom Sep 25, 2024
b6f1c40
model download should work + loading model on service startup only
greenw0lf Sep 25, 2024
6584448
Fix pipeline issues
greenw0lf Sep 25, 2024
66e10f0
Implement some changes from the audio extraction PR
greenw0lf Sep 25, 2024
d894219
Hopefully pipeline passes so I can test this
greenw0lf Sep 25, 2024
58dc75f
Bringing back the changes I made
greenw0lf Sep 26, 2024
ecb2eb5
If file already exists, report back less than zero download_time
Sep 26, 2024
5def84b
Add provenance
greenw0lf Sep 26, 2024
d541866
Merge branch '52-setup-service-decouple-from-dane' into 81-prov
greenw0lf Sep 26, 2024
5b832ea
Update download.py
greenw0lf Sep 26, 2024
aac0bc8
Use Task object to represent tasks, and use dict to keep all_tasks. WIP!
Sep 26, 2024
ebda32b
Revert some stuff + add response code when busy
greenw0lf Sep 26, 2024
52625c1
Make status not be mandatory
greenw0lf Sep 26, 2024
d64a6a0
no status in __init__
greenw0lf Sep 26, 2024
863d210
remove __init__ for Task class
greenw0lf Sep 26, 2024
5d141db
logging current task
greenw0lf Sep 26, 2024
cefbf9d
status should be checked properly now
greenw0lf Sep 26, 2024
cdce764
Change how version is obtained
greenw0lf Sep 30, 2024
15a0d12
black formatting
greenw0lf Sep 30, 2024
a4e44a2
Merge pull request #101 from beeldengeluid/81-prov
jblom Oct 1, 2024
da939ee
remove unused flask lib references
jblom Oct 1, 2024
0281720
move hugging face env var to module where it is used
jblom Oct 1, 2024
8144798
centralized model download in new function in whisper.py
jblom Oct 1, 2024
ddf0295
updated packages
jblom Oct 2, 2024
341fda0
add prov to S3 output
jblom Oct 2, 2024
README.md (6 changes: 3 additions & 3 deletions)

@@ -41,7 +41,7 @@ apt-get -y update && apt-get -y upgrade && apt-get install -y --no-install-recom
## Running the worker using a CUDA-compatible GPU

To run the worker with a CUDA-compatible GPU instead of the CPU, either:
- skip steps 3 & 4 from "Docker CPU run"
- skip steps 3 & 4 from "Docker CPU run"
- skip step 3 from "Local run"

**(OUTDATED BUT STILL MIGHT BE RELEVANT)** To run it using a GPU via Docker, check [the instructions from the dane-example-worker](https://github.com/beeldengeluid/dane-example-worker/wiki/Containerization#running-the-container-locally-using-cuda-compatible-gpu).
@@ -50,7 +50,7 @@ Make sure to replace `dane-example-worker` in the `docker run` command with `dan

## Expected run

The expected run of this worker (whose pipeline is defined in `asr.py`) should
The expected run of this worker (whose pipeline is defined in `asr.py`) should

1. download the input file if it isn't downloaded already in `/data/input/` via `download.py`

@@ -80,4 +80,4 @@ The pre-trained Whisper model version can be adjusted in the `.env` file by edit

We recommend version `large-v2` as it performs better than `large-v3` in our [benchmarks](https://opensource-spraakherkenning-nl.github.io/ASR_NL_results/).

You can also specify an S3 URI if you have your own custom model available via S3 (by modifying the `W_MODEL` parameter).
You can also specify an S3 URI if you have your own custom model available via S3 (by modifying the `W_MODEL` parameter).
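
As a rough sketch of how the `W_MODEL` setting described above could be consumed in Python (the repo's actual config.py is not part of this diff, so the default and helper below are assumptions):

import os

# Hypothetical resolution of W_MODEL; defaults here to the recommended large-v2.
w_model = os.environ.get("W_MODEL", "large-v2")

def model_is_custom_s3(model: str) -> bool:
    # A custom model hosted on S3 is indicated by an s3:// URI
    return model.startswith("s3://")
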
asr.py (108 changes: 97 additions & 11 deletions)

@@ -1,8 +1,26 @@
import logging
import os
import time
import tomli

from base_util import (
get_asset_info,
asr_output_dir,
save_provenance,
PROVENANCE_JSON_FILE,
)
from config import (
s3_endpoint_url,
s3_bucket,
s3_folder_in_bucket,
w_word_timestamps,
w_device,
w_model,
w_beam_size,
w_best_of,
w_vad,
)

from base_util import get_asset_info, asr_output_dir
from config import s3_endpoint_url, s3_bucket, s3_folder_in_bucket
from download import download_uri
from whisper import run_asr, WHISPER_JSON_FILE
from s3_util import S3Store
@@ -12,48 +12,115 @@
logger = logging.getLogger(__name__)


def run(input_uri: str, output_uri: str) -> bool:
def _get_project_meta():
with open("pyproject.toml", mode="rb") as pyproject:
return tomli.load(pyproject)["tool"]["poetry"]


pkg_meta = _get_project_meta()
version = str(pkg_meta["version"])


def run(input_uri: str, output_uri: str, model=None) -> bool:
logger.info(f"Processing {input_uri} (save to --> {output_uri})")
start_time = time.time()
prov_steps = [] # track provenance
# 1. download input
result = download_uri(input_uri)
logger.info(result)
if not result:
logger.error("Could not obtain input, quitting...")
return False

prov_steps.append(result.provenance)

input_path = result.file_path
asset_id, extension = get_asset_info(input_path)
output_path = asr_output_dir(input_path)

# 2. check if the input file is suitable for processing any further
transcoded_file_path = try_transcode(input_path, asset_id, extension)
if not transcoded_file_path:
transcode_output = try_transcode(input_path, asset_id, extension)
if not transcode_output:
logger.error("The transcode failed to yield a valid file to continue with")
return False
else:
input_path = transcoded_file_path
input_path = transcode_output.transcoded_file_path
prov_steps.append(transcode_output.provenance)

# 3. run ASR
if not asr_already_done(output_path):
logger.info("No Whisper transcript found")
run_asr(input_path, output_path)
whisper_prov = run_asr(input_path, output_path, model)
if whisper_prov:
prov_steps.append(whisper_prov)
else:
logger.info(f"Whisper transcript already present in {output_path}")
provenance = {
"activity_name": "Whisper transcript already exists",
"activity_description": "",
"processing_time_ms": "",
"start_time_unix": "",
"parameters": [],
"software_version": "",
"input_data": "",
"output_data": "",
"steps": [],
}
prov_steps.append(provenance)

# 4. generate JSON transcript
if not daan_transcript_already_done(output_path):
logger.info("No DAAN transcript found")
success = generate_daan_transcript(output_path)
if not success:
daan_prov = generate_daan_transcript(output_path)
if daan_prov:
prov_steps.append(daan_prov)
else:
logger.warning("Could not generate DAAN transcript")
else:
logger.info(f"DAAN transcript already present in {output_path}")
provenance = {
"activity_name": "DAAN transcript already exists",
"activity_description": "",
"processing_time_ms": "",
"start_time_unix": "",
"parameters": [],
"software_version": "",
"input_data": "",
"output_data": "",
"steps": [],
}
prov_steps.append(provenance)

end_time = (time.time() - start_time) * 1000
final_prov = {
"activity_name": "Whisper ASR Worker",
"activity_description": "Worker that gets a video/audio file as input and outputs JSON transcripts in various formats",
"processing_time_ms": end_time,
"start_time_unix": start_time,
"parameters": {
"word_timestamps": w_word_timestamps,
"device": w_device,
"vad": w_vad,
"model": w_model,
"beam_size": w_beam_size,
"best_of": w_best_of,
},
"software_version": version,
"input_data": input_uri,
"output_data": output_uri if output_uri else output_path,
"steps": prov_steps,
}

prov_success = save_provenance(final_prov, output_path)
if not prov_success:
logger.warning("Could not save the provenance")

# 5. transfer output
if output_uri:
transfer_asr_output(output_path, asset_id)
else:
logger.info("No output_uri specified, so all is done")

return True


@@ -84,19 +169,20 @@ def transfer_asr_output(output_path: str, asset_id: str) -> bool:
[
os.path.join(output_path, DAAN_JSON_FILE),
os.path.join(output_path, WHISPER_JSON_FILE),
os.path.join(output_path, PROVENANCE_JSON_FILE),
],
)


# check if there is a whisper-transcript.json
def asr_already_done(output_dir):
def asr_already_done(output_dir) -> bool:
whisper_transcript = os.path.join(output_dir, WHISPER_JSON_FILE)
logger.info(f"Checking existence of {whisper_transcript}")
return os.path.exists(os.path.join(output_dir, WHISPER_JSON_FILE))


# check if there is a daan-es-transcript.json
def daan_transcript_already_done(output_dir):
def daan_transcript_already_done(output_dir) -> bool:
daan_transcript = os.path.join(output_dir, DAAN_JSON_FILE)
logger.info(f"Checking existence of {daan_transcript}")
return os.path.exists(os.path.join(output_dir, DAAN_JSON_FILE))
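
For context, a minimal invocation of the reworked entry point might look as follows (the URIs are placeholders; leaving output_uri empty keeps all output local, as the code above shows):

from asr import run

success = run(
    input_uri="s3://example-bucket/input/interview.mp4",  # placeholder URI
    output_uri="s3://example-bucket/output/interview",  # placeholder URI
    # model defaults to None; the service variant can pass a preloaded model
)
print("pipeline finished:", success)
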
base_util.py (21 changes: 19 additions & 2 deletions)

@@ -1,18 +1,19 @@
import logging
import ntpath
import os
import subprocess
import json
from typing import Tuple
from config import data_base_dir


LOG_FORMAT = "%(asctime)s|%(levelname)s|%(process)d|%(module)s|%(funcName)s|%(lineno)d|%(message)s"
PROVENANCE_JSON_FILE = "provenance.json"
logger = logging.getLogger(__name__)


# the file name without extension is used as asset ID
def get_asset_info(input_file: str) -> Tuple[str, str]:
file_name = ntpath.basename(input_file)
file_name = os.path.basename(input_file)
asset_id, extension = os.path.splitext(file_name)
logger.info(f"working with this asset ID {asset_id}")
return asset_id, extension
@@ -56,3 +57,19 @@ def run_shell_command(cmd: str) -> bool:
except Exception:
logger.exception("Exception")
return False


def save_provenance(provenance: dict, asr_output_dir: str) -> bool:
logger.info(f"Saving provenance to: {asr_output_dir}")
try:
# write provenance.json
with open(
os.path.join(asr_output_dir, PROVENANCE_JSON_FILE), "w+", encoding="utf-8"
) as f:
logger.info(provenance)
json.dump(provenance, f, ensure_ascii=False, indent=4)
except EnvironmentError as e: # OSError or IOError...
logger.exception(os.strerror(e.errno))
return False

return True
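
A usage sketch for the new helper (the output directory below is hypothetical; real runs derive it with asr_output_dir()):

import json
import os

from base_util import PROVENANCE_JSON_FILE, save_provenance

out_dir = "/data/output/example-asset"  # hypothetical path
os.makedirs(out_dir, exist_ok=True)

if save_provenance({"activity_name": "demo", "steps": []}, out_dir):
    with open(os.path.join(out_dir, PROVENANCE_JSON_FILE), encoding="utf-8") as f:
        print(json.load(f)["activity_name"])
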
daan_transcript.py (22 changes: 18 additions & 4 deletions)

@@ -1,6 +1,7 @@
import json
import logging
import os
import time
from typing import TypedDict, List, Optional
from whisper import WHISPER_JSON_FILE

@@ -19,12 +20,13 @@ class ParsedResult(TypedDict):


# asr_output_dir e.g /data/output/whisper-test/
def generate_daan_transcript(asr_output_dir: str) -> bool:
def generate_daan_transcript(asr_output_dir: str) -> Optional[dict]:
logger.info(f"Generating transcript from: {asr_output_dir}")
start_time = time.time()
whisper_transcript = load_whisper_transcript(asr_output_dir)
if not whisper_transcript:
logger.error("No whisper_transcript.json found")
return False
return None

transcript = parse_whisper_transcript(whisper_transcript)

@@ -37,9 +39,21 @@ def generate_daan_transcript(asr_output_dir: str) -> bool:
json.dump(transcript, f, ensure_ascii=False, indent=4)
except EnvironmentError as e: # OSError or IOError...
logger.exception(os.strerror(e.errno))
return False
return None

return True
end_time = (time.time() - start_time) * 1000
provenance = {
"activity_name": "Whisper transcript -> DAAN transcript",
"activity_description": "Converts the output of Whisper to the DAAN index format",
"processing_time_ms": end_time,
"start_time_unix": start_time,
"parameters": [],
"software_version": "",
"input_data": whisper_transcript,
"output_data": transcript,
"steps": [],
}
return provenance


def load_whisper_transcript(asr_output_dir: str) -> Optional[dict]:
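
The switch from a bool return to Optional[dict] means callers now treat the provenance dict itself as the success signal, e.g. (output directory hypothetical):

from daan_transcript import generate_daan_transcript

prov = generate_daan_transcript("/data/output/example-asset")
if prov is None:
    print("DAAN transcript generation failed")
else:
    print(f"converted in {prov['processing_time_ms']:.0f} ms")
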
download.py (39 changes: 35 additions & 4 deletions)

@@ -17,7 +17,8 @@
@dataclass
class DownloadResult:
file_path: str # target_file_path, # TODO harmonize with dane-download-worker
mime_type: str # download_data.get("mime_type", "unknown"),
mime_type: str
provenance: dict
download_time: float = -1 # time (ms) taken to receive data after request
content_length: int = -1 # download_data.get("content_length", -1),

@@ -47,11 +48,25 @@ def http_download(url: str) -> Optional[DownloadResult]:
os.makedirs(input_file_dir)
with open(input_file, "wb") as file:
response = requests.get(url)
if response.status_code >= 400:
logger.error(f"Could not download url: {response.status_code}")
return None
file.write(response.content)
file.close()
download_time = (time.time() - start_time) * 1000 # time in ms
provenance = {
"activity_name": "Input download",
"activity_description": "Downloads the input file from INPUT_URI",
"processing_time_ms": download_time,
"start_time_unix": start_time,
"parameters": [],
"software_version": "",
"input_data": url,
"output_data": input_file,
"steps": [],
}
return DownloadResult(
input_file, mime_type, download_time # TODO add content_length
input_file, mime_type, provenance, download_time # TODO add content_length
)


@@ -86,7 +101,23 @@ def s3_download(s3_uri: str) -> Optional[DownloadResult]:
if not success:
logger.error("Failed to download input data from S3")
return None
download_time = (time.time() - start_time) * 1000 # time in ms

download_time = int((time.time() - start_time) * 1000) # time in ms
else:
download_time = -1 # Report back?

provenance = {
"activity_name": "Input download",
"activity_description": "Downloads the input file from INPUT_URI",
"processing_time_ms": download_time,
"start_time_unix": start_time,
"parameters": [],
"software_version": "",
"input_data": s3_uri,
"output_data": input_file,
"steps": [],
}

return DownloadResult(
input_file, mime_type, download_time # TODO add content_length
input_file, mime_type, provenance, download_time # TODO add content_length
)
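
A short sketch of consuming the extended DownloadResult (the URL is a placeholder; download_uri presumably dispatches to http_download or s3_download based on the URI scheme):

from download import download_uri

result = download_uri("https://example.com/media/sample.mp3")
if result:
    print(result.file_path, result.mime_type)
    # download_time stays at -1 when the file was already present locally
    print("download took", result.provenance["processing_time_ms"], "ms")
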