From 23b0dbd27653a026ce6990f48397ddbe7b353571 Mon Sep 17 00:00:00 2001 From: Ali Hamdi Ali Fadel Date: Sun, 30 Jul 2023 20:37:57 +0300 Subject: [PATCH] Use multiple wit.ai API keys to transcribe the same video --- colab_notebook.ipynb | 2 +- tafrigh/cli.py | 2 +- tafrigh/config.py | 10 ++-- tafrigh/recognizers/wit_recognizer.py | 77 +++++++++++++++++++-------- tafrigh/utils/cli_utils.py | 8 +-- tafrigh/utils/decorators.py | 26 --------- 6 files changed, 66 insertions(+), 59 deletions(-) delete mode 100644 tafrigh/utils/decorators.py diff --git a/colab_notebook.ipynb b/colab_notebook.ipynb index 3873117..1b3d84a 100644 --- a/colab_notebook.ipynb +++ b/colab_notebook.ipynb @@ -141,7 +141,7 @@ " beam_size=5,\n", " ct2_compute_type='default',\n", "\n", - " wit_client_access_token=wit_api_key,\n", + " wit_client_access_tokens=[wit_api_key],\n", " max_cutting_duration=max_cutting_duration,\n", " min_words_per_segment=min_words_per_segment,\n", "\n", diff --git a/tafrigh/cli.py b/tafrigh/cli.py index ef4c602..e8cc989 100644 --- a/tafrigh/cli.py +++ b/tafrigh/cli.py @@ -51,7 +51,7 @@ def main(): beam_size=args.beam_size, ct2_compute_type=args.ct2_compute_type, - wit_client_access_token=args.wit_client_access_token, + wit_client_access_tokens=args.wit_client_access_tokens, max_cutting_duration=args.max_cutting_duration, min_words_per_segment=args.min_words_per_segment, diff --git a/tafrigh/config.py b/tafrigh/config.py index 3598de4..41677a3 100644 --- a/tafrigh/config.py +++ b/tafrigh/config.py @@ -19,7 +19,7 @@ def __init__( use_whisper_jax: bool, beam_size: int, ct2_compute_type: str, - wit_client_access_token: str, + wit_client_access_tokens: List[str], max_cutting_duration: int, min_words_per_segment: int, save_files_before_compact: bool, @@ -40,7 +40,7 @@ def __init__( ct2_compute_type, ) - self.wit = self.Wit(wit_client_access_token, max_cutting_duration) + self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration) self.output = self.Output( 
min_words_per_segment, @@ -52,7 +52,7 @@ def __init__( ) def use_wit(self) -> bool: - return self.wit.wit_client_access_token != '' + return self.wit.wit_client_access_tokens != [] class Input: def __init__(self, urls_or_paths: List[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool): @@ -85,8 +85,8 @@ def __init__( self.ct2_compute_type = ct2_compute_type class Wit: - def __init__(self, wit_client_access_token: str, max_cutting_duration: int): - self.wit_client_access_token = wit_client_access_token + def __init__(self, wit_client_access_tokens: List[str], max_cutting_duration: int): + self.wit_client_access_tokens = wit_client_access_tokens self.max_cutting_duration = max_cutting_duration class Output: diff --git a/tafrigh/recognizers/wit_recognizer.py b/tafrigh/recognizers/wit_recognizer.py index 0200006..3d1905c 100644 --- a/tafrigh/recognizers/wit_recognizer.py +++ b/tafrigh/recognizers/wit_recognizer.py @@ -2,12 +2,13 @@ import logging import multiprocessing import os +import threading import requests import shutil import tempfile import time -from itertools import repeat +from multiprocessing.dummy import Pool as ThreadPool from requests.adapters import HTTPAdapter from typing import Dict, Generator, List, Tuple, Union from urllib3.util.retry import Retry @@ -16,7 +17,6 @@ from tafrigh.audio_splitter import AudioSplitter from tafrigh.config import Config -from tafrigh.utils.decorators import minimum_execution_time class WitRecognizer: @@ -28,6 +28,8 @@ def recognize( file_path: str, wit_config: Config.Wit, ) -> Generator[Dict[str, float], None, List[Dict[str, Union[str, float]]]]: + self.semaphore_pools = [threading.Semaphore(60) for _ in range(len(wit_config.wit_client_access_tokens))] + self.timer = threading.Timer(60, self._reset_semaphore_pools) temp_directory = tempfile.mkdtemp() @@ -50,10 +52,20 @@ def recognize( session = requests.Session() session.mount('https://', adapter) - with multiprocessing.Pool(processes=min(4, 
multiprocessing.cpu_count() - 1)) as pool: + self.timer.start() + + with ThreadPool(processes=min(4, multiprocessing.cpu_count() - 1) * len(wit_config.wit_client_access_tokens)) as pool: async_results = [ - pool.apply_async(self._process_segment, (segment, file_path, wit_config, session)) - for segment in segments + pool.apply_async( + self._process_segment, + ( + segment, + file_path, + wit_config, + session, + index % len(wit_config.wit_client_access_tokens), + ), + ) for index, segment in enumerate(segments) ] transcriptions = [] @@ -69,19 +81,17 @@ def recognize( 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] if pbar.format_dict['rate'] and pbar.total else None, } - time.sleep(0.5) - shutil.rmtree(temp_directory) return transcriptions - @minimum_execution_time(min(4, multiprocessing.cpu_count() - 1) + 1) def _process_segment( self, segment: Tuple[str, float, float], file_path: str, wit_config: Config.Wit, session: requests.Session, + wit_client_access_token_index: int ) -> Dict[str, Union[str, float]]: segment_file_path, start, end = segment @@ -90,25 +100,28 @@ def _process_segment( retries = 5 + self.semaphore_pools[wit_client_access_token_index].acquire() + text = '' while retries > 0: - response = session.post( - 'https://api.wit.ai/speech', - headers={ - 'Accept': 'application/vnd.wit.20200513+json', - 'Content-Type': 'audio/wav', - 'Authorization': f'Bearer {wit_config.wit_client_access_token}', - }, - data=audio_content, - ) - - if response.status_code == 200: - try: + try: + response = session.post( + 'https://api.wit.ai/speech', + headers={ + 'Accept': 'application/vnd.wit.20200513+json', + 'Content-Type': 'audio/wav', + 'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_client_access_token_index]}', + }, + data=audio_content, + ) + + if response.status_code == 200: text = json.loads(response.text)['text'] break - except KeyError: + else: retries -= 1 - else: + time.sleep(min(4, multiprocessing.cpu_count() - 1) + 1) 
+ except: retries -= 1 time.sleep(min(4, multiprocessing.cpu_count() - 1) + 1) @@ -123,3 +136,23 @@ def _process_segment( 'end': end, 'text': text.strip(), } + + def _reset_semaphore_pools(self) -> None: + do_reset = False + + for semaphore_pool in self.semaphore_pools: + if semaphore_pool._value < 60: + do_reset = True + + if do_reset: + for semaphore_pool in self.semaphore_pools: + while semaphore_pool._value < 60: + semaphore_pool.release() + + if self.timer: + self.timer.cancel() + + self.timer = threading.Timer(60, self._reset_semaphore_pools) + self.timer.start() + else: + self.timer.cancel() diff --git a/tafrigh/utils/cli_utils.py b/tafrigh/utils/cli_utils.py index 6d4c7c5..7e69461 100644 --- a/tafrigh/utils/cli_utils.py +++ b/tafrigh/utils/cli_utils.py @@ -110,11 +110,11 @@ def parse_args(argv: List[str]) -> argparse.Namespace: wit_group = parser.add_argument_group('Wit') - wit_group.add_argument( + input_group.add_argument( '-w', - '--wit_client_access_token', - default='', - help='wit.ai client access token. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.', + '--wit_client_access_tokens', + nargs='+', + help='List of wit.ai client access tokens. 
If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.', ) wit_group.add_argument( diff --git a/tafrigh/utils/decorators.py b/tafrigh/utils/decorators.py deleted file mode 100644 index 717f8e0..0000000 --- a/tafrigh/utils/decorators.py +++ /dev/null @@ -1,26 +0,0 @@ -import time - -from functools import wraps -from typing import Callable, TypeVar - - -T = TypeVar("T", bound=Callable) - - -def minimum_execution_time(minimum_time: float) -> Callable[[T], T]: - def decorator(func: T) -> T: - @wraps(func) - def wrapper(*args, **kwargs): - start_time = time.time() - result = func(*args, **kwargs) - end_time = time.time() - - elapsed_time = end_time - start_time - if elapsed_time < minimum_time: - time.sleep(minimum_time - elapsed_time) - - return result - - return wrapper - - return decorator