Skip to content

Commit

Permalink
Use multiple wit.ai API keys to transcript the same video
Browse files Browse the repository at this point in the history
  • Loading branch information
AliOsm committed Aug 2, 2023
1 parent 5a4ac3b commit 23b0dbd
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 59 deletions.
2 changes: 1 addition & 1 deletion colab_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
" beam_size=5,\n",
" ct2_compute_type='default',\n",
"\n",
" wit_client_access_token=wit_api_key,\n",
" wit_client_access_tokens=[wit_api_key],\n",
" max_cutting_duration=max_cutting_duration,\n",
" min_words_per_segment=min_words_per_segment,\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion tafrigh/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def main():
beam_size=args.beam_size,
ct2_compute_type=args.ct2_compute_type,

wit_client_access_token=args.wit_client_access_token,
wit_client_access_tokens=args.wit_client_access_tokens,
max_cutting_duration=args.max_cutting_duration,
min_words_per_segment=args.min_words_per_segment,

Expand Down
10 changes: 5 additions & 5 deletions tafrigh/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
use_whisper_jax: bool,
beam_size: int,
ct2_compute_type: str,
wit_client_access_token: str,
wit_client_access_tokens: List[str],
max_cutting_duration: int,
min_words_per_segment: int,
save_files_before_compact: bool,
Expand All @@ -40,7 +40,7 @@ def __init__(
ct2_compute_type,
)

self.wit = self.Wit(wit_client_access_token, max_cutting_duration)
self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration)

self.output = self.Output(
min_words_per_segment,
Expand All @@ -52,7 +52,7 @@ def __init__(
)

def use_wit(self) -> bool:
return self.wit.wit_client_access_token != ''
return self.wit.wit_client_access_tokens != []

class Input:
def __init__(self, urls_or_paths: List[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool):
Expand Down Expand Up @@ -85,8 +85,8 @@ def __init__(
self.ct2_compute_type = ct2_compute_type

class Wit:
def __init__(self, wit_client_access_token: str, max_cutting_duration: int):
self.wit_client_access_token = wit_client_access_token
def __init__(self, wit_client_access_tokens: List[str], max_cutting_duration: int):
self.wit_client_access_tokens = wit_client_access_tokens
self.max_cutting_duration = max_cutting_duration

class Output:
Expand Down
77 changes: 55 additions & 22 deletions tafrigh/recognizers/wit_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import logging
import multiprocessing
import os
import threading
import requests
import shutil
import tempfile
import time

from itertools import repeat
from multiprocessing.dummy import Pool as ThreadPool
from requests.adapters import HTTPAdapter
from typing import Dict, Generator, List, Tuple, Union
from urllib3.util.retry import Retry
Expand All @@ -16,7 +17,6 @@

from tafrigh.audio_splitter import AudioSplitter
from tafrigh.config import Config
from tafrigh.utils.decorators import minimum_execution_time


class WitRecognizer:
Expand All @@ -28,6 +28,8 @@ def recognize(
file_path: str,
wit_config: Config.Wit,
) -> Generator[Dict[str, float], None, List[Dict[str, Union[str, float]]]]:
self.semaphore_pools = [threading.Semaphore(60) for _ in range(len(wit_config.wit_client_access_tokens))]
self.timer = threading.Timer(60, self._reset_semaphore_pools)

temp_directory = tempfile.mkdtemp()

Expand All @@ -50,10 +52,20 @@ def recognize(
session = requests.Session()
session.mount('https://', adapter)

with multiprocessing.Pool(processes=min(4, multiprocessing.cpu_count() - 1)) as pool:
self.timer.start()

with ThreadPool(processes=min(4, multiprocessing.cpu_count() - 1) * len(wit_config.wit_client_access_tokens)) as pool:
async_results = [
pool.apply_async(self._process_segment, (segment, file_path, wit_config, session))
for segment in segments
pool.apply_async(
self._process_segment,
(
segment,
file_path,
wit_config,
session,
index % len(wit_config.wit_client_access_tokens),
),
) for index, segment in enumerate(segments)
]

transcriptions = []
Expand All @@ -69,19 +81,17 @@ def recognize(
'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] if pbar.format_dict['rate'] and pbar.total else None,
}

time.sleep(0.5)

shutil.rmtree(temp_directory)

return transcriptions

@minimum_execution_time(min(4, multiprocessing.cpu_count() - 1) + 1)
def _process_segment(
self,
segment: Tuple[str, float, float],
file_path: str,
wit_config: Config.Wit,
session: requests.Session,
wit_client_access_token_index: int
) -> Dict[str, Union[str, float]]:
segment_file_path, start, end = segment

Expand All @@ -90,25 +100,28 @@ def _process_segment(

retries = 5

self.semaphore_pools[wit_client_access_token_index].acquire()

text = ''
while retries > 0:
response = session.post(
'https://api.wit.ai/speech',
headers={
'Accept': 'application/vnd.wit.20200513+json',
'Content-Type': 'audio/wav',
'Authorization': f'Bearer {wit_config.wit_client_access_token}',
},
data=audio_content,
)

if response.status_code == 200:
try:
try:
response = session.post(
'https://api.wit.ai/speech',
headers={
'Accept': 'application/vnd.wit.20200513+json',
'Content-Type': 'audio/wav',
'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_client_access_token_index]}',
},
data=audio_content,
)

if response.status_code == 200:
text = json.loads(response.text)['text']
break
except KeyError:
else:
retries -= 1
else:
time.sleep(min(4, multiprocessing.cpu_count() - 1) + 1)
except:
retries -= 1
time.sleep(min(4, multiprocessing.cpu_count() - 1) + 1)

Expand All @@ -123,3 +136,23 @@ def _process_segment(
'end': end,
'text': text.strip(),
}

def _reset_semaphore_pools(self) -> None:
do_reset = False

for semaphore_pool in self.semaphore_pools:
if semaphore_pool._value < 60:
do_reset = True

if do_reset:
for semaphore_pool in self.semaphore_pools:
while semaphore_pool._value < 60:
semaphore_pool.release()

if self.timer:
self.timer.cancel()

self.timer = threading.Timer(60, self._reset_semaphore_pools)
self.timer.start()
else:
self.timer.cancel()
8 changes: 4 additions & 4 deletions tafrigh/utils/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ def parse_args(argv: List[str]) -> argparse.Namespace:

wit_group = parser.add_argument_group('Wit')

wit_group.add_argument(
input_group.add_argument(
'-w',
'--wit_client_access_token',
default='',
help='wit.ai client access token. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.',
'--wit_client_access_tokens',
nargs='+',
help='List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.',
)

wit_group.add_argument(
Expand Down
26 changes: 0 additions & 26 deletions tafrigh/utils/decorators.py

This file was deleted.

0 comments on commit 23b0dbd

Please sign in to comment.