diff --git a/README.md b/README.md index 44c0956..01f272a 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@
  • خيارات تقنية Wit
  • @@ -141,8 +141,9 @@ usage: tafrigh [-h] [--skip_if_output_exist | --no-skip_if_output_exist] [--playlist_items PLAYLIST_ITEMS] [--verbose | --no-verbose] [-m MODEL_NAME_OR_PATH] [-t {transcribe,translate}] [-l {af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh}] [--use_faster_whisper | --no-use_faster_whisper] [--use_whisper_jax | --no-use_whisper_jax] [--beam_size BEAM_SIZE] [--ct2_compute_type {default,int8,int8_float16,int16,float16}] - [-w WIT_CLIENT_ACCESS_TOKEN] [--max_cutting_duration [1-17]] [--min_words_per_segment MIN_WORDS_PER_SEGMENT] [--save_files_before_compact | --no-save_files_before_compact] - [--save_yt_dlp_responses | --no-save_yt_dlp_responses] [--output_sample OUTPUT_SAMPLE] [-f {all,txt,srt,vtt,none} [{all,txt,srt,vtt,none} ...]] [-o OUTPUT_DIR] + [-w WIT_CLIENT_ACCESS_TOKENS [WIT_CLIENT_ACCESS_TOKENS ...]] [--max_cutting_duration [1-17]] [--min_words_per_segment MIN_WORDS_PER_SEGMENT] + [--save_files_before_compact | --no-save_files_before_compact] [--save_yt_dlp_responses | --no-save_yt_dlp_responses] [--output_sample OUTPUT_SAMPLE] + [-f {all,txt,srt,vtt,none} [{all,txt,srt,vtt,none} ...]] [-o OUTPUT_DIR] urls_or_paths [urls_or_paths ...] options: @@ -174,8 +175,8 @@ Whisper: Quantization type applied while converting the model to CTranslate2 format. Wit: - -w WIT_CLIENT_ACCESS_TOKEN, --wit_client_access_token WIT_CLIENT_ACCESS_TOKEN - wit.ai client access token. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used. + -w WIT_CLIENT_ACCESS_TOKENS [WIT_CLIENT_ACCESS_TOKENS ...], --wit_client_access_tokens WIT_CLIENT_ACCESS_TOKENS [WIT_CLIENT_ACCESS_TOKENS ...] + List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used. --max_cutting_duration [1-17] The maximum allowed cutting duration. It should be between 1 and 17. @@ -265,7 +266,7 @@ tafrigh "https://youtu.be/Di0vcmnxULs" \ ``` tafrigh "https://youtu.be/dDzxYcEJbgo" \ - --wit_client_access_token XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \ + --wit_client_access_tokens XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \ --output_dir . \ --output_formats txt srt \ --min_words_per_segment 10 \ @@ -276,7 +277,7 @@ tafrigh "https://youtu.be/dDzxYcEJbgo" \ ``` tafrigh "https://youtube.com/playlist?list=PLyS-PHSxRDxsLnVsPrIwnsHMO5KgLz7T5" \ - --wit_client_access_token XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \ + --wit_client_access_tokens XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \ --output_dir . \ --output_formats txt srt \ --min_words_per_segment 10 \ @@ -287,7 +288,7 @@ tafrigh "https://youtube.com/playlist?list=PLyS-PHSxRDxsLnVsPrIwnsHMO5KgLz7T5" \ ``` tafrigh "https://youtu.be/4h5P7jXvW98" "https://youtu.be/jpfndVSROpw" \ - --wit_client_access_token XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \ + --wit_client_access_tokens XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \ --output_dir . \ --output_formats txt srt \ --min_words_per_segment 10 \ diff --git a/colab_notebook.ipynb b/colab_notebook.ipynb index 3873117..1b3d84a 100644 --- a/colab_notebook.ipynb +++ b/colab_notebook.ipynb @@ -141,7 +141,7 @@ " beam_size=5,\n", " ct2_compute_type='default',\n", "\n", - " wit_client_access_token=wit_api_key,\n", + " wit_client_access_tokens=[wit_api_key],\n", " max_cutting_duration=max_cutting_duration,\n", " min_words_per_segment=min_words_per_segment,\n", "\n", diff --git a/tafrigh/cli.py b/tafrigh/cli.py index ef4c602..e8cc989 100644 --- a/tafrigh/cli.py +++ b/tafrigh/cli.py @@ -51,7 +51,7 @@ def main(): beam_size=args.beam_size, ct2_compute_type=args.ct2_compute_type, - wit_client_access_token=args.wit_client_access_token, + wit_client_access_tokens=args.wit_client_access_tokens, max_cutting_duration=args.max_cutting_duration, min_words_per_segment=args.min_words_per_segment, diff --git a/tafrigh/config.py b/tafrigh/config.py index 3598de4..41677a3 100644 --- a/tafrigh/config.py +++ b/tafrigh/config.py @@ -19,7 +19,7 @@ def __init__( use_whisper_jax: bool, beam_size: int, ct2_compute_type: str, - wit_client_access_token: str, + wit_client_access_tokens: List[str], max_cutting_duration: int, min_words_per_segment: int, save_files_before_compact: bool, @@ -40,7 +40,7 @@ def __init__( ct2_compute_type, ) - self.wit = self.Wit(wit_client_access_token, max_cutting_duration) + self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration) self.output = self.Output( min_words_per_segment, @@ -52,7 +52,7 @@ def __init__( ) def use_wit(self) -> bool: - return self.wit.wit_client_access_token != '' + return self.wit.wit_client_access_tokens != [] class Input: def __init__(self, urls_or_paths: List[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool): @@ -85,8 +85,8 @@ def __init__( self.ct2_compute_type = ct2_compute_type class Wit: - def __init__(self, wit_client_access_token: str, max_cutting_duration: int): - self.wit_client_access_token = wit_client_access_token + def __init__(self, wit_client_access_tokens: List[str], max_cutting_duration: int): + self.wit_client_access_tokens = wit_client_access_tokens self.max_cutting_duration = max_cutting_duration class Output: diff --git a/tafrigh/recognizers/wit_recognizer.py b/tafrigh/recognizers/wit_recognizer.py index 0200006..07689ad 100644 --- a/tafrigh/recognizers/wit_recognizer.py +++ b/tafrigh/recognizers/wit_recognizer.py @@ -7,7 +7,7 @@ import tempfile import time -from itertools import repeat +from multiprocessing.managers import BaseManager from requests.adapters import HTTPAdapter from typing import Dict, Generator, List, Tuple, Union from urllib3.util.retry import Retry @@ -16,19 +16,32 @@ from tafrigh.audio_splitter import AudioSplitter from tafrigh.config import Config -from tafrigh.utils.decorators import minimum_execution_time +from tafrigh.wit_calling_throttle import WitCallingThrottle + + +class WitCallingThrottleManager(BaseManager): + pass + + +WitCallingThrottleManager.register('WitCallingThrottle', WitCallingThrottle) + + +def init_pool(throttle: WitCallingThrottle) -> None: + global wit_calling_throttle + + wit_calling_throttle = throttle class WitRecognizer: def __init__(self, verbose: bool): self.verbose = verbose + self.processes_per_wit_client_access_token = min(4, multiprocessing.cpu_count() - 1) def recognize( self, file_path: str, wit_config: Config.Wit, ) -> Generator[Dict[str, float], None, List[Dict[str, Union[str, float]]]]: - temp_directory = tempfile.mkdtemp() segments = AudioSplitter().split( @@ -50,39 +63,59 @@ def recognize( session = requests.Session() session.mount('https://', adapter) - with multiprocessing.Pool(processes=min(4, multiprocessing.cpu_count() - 1)) as pool: - async_results = [ - pool.apply_async(self._process_segment, (segment, file_path, wit_config, session)) - for segment in segments - ] - - transcriptions = [] - - with tqdm(total=len(segments), disable=self.verbose is not False) as pbar: - while async_results: - if async_results[0].ready(): - transcriptions.append(async_results.pop(0).get()) - pbar.update(1) - - yield { - 'progress': round(len(transcriptions) / len(segments) * 100, 2), - 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] if pbar.format_dict['rate'] and pbar.total else None, - } + pool_processes_count = min( + self.processes_per_wit_client_access_token * len(wit_config.wit_client_access_tokens), + multiprocessing.cpu_count() - 1 + ) - time.sleep(0.5) + with WitCallingThrottleManager() as manager: + wit_calling_throttle = manager.WitCallingThrottle(len(wit_config.wit_client_access_tokens)) + + with multiprocessing.Pool( + processes=pool_processes_count, + initializer=init_pool, + initargs=(wit_calling_throttle,), + ) as pool: + async_results = [ + pool.apply_async( + self._process_segment, + ( + segment, + file_path, + wit_config, + session, + index % len(wit_config.wit_client_access_tokens), + ), + ) for index, segment in enumerate(segments) + ] + + transcriptions = [] + + with tqdm(total=len(segments), disable=self.verbose is not False) as pbar: + while async_results: + if async_results[0].ready(): + transcriptions.append(async_results.pop(0).get()) + pbar.update(1) + + yield { + 'progress': round(len(transcriptions) / len(segments) * 100, 2), + 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] if pbar.format_dict['rate'] and pbar.total else None, + } shutil.rmtree(temp_directory) return transcriptions - @minimum_execution_time(min(4, multiprocessing.cpu_count() - 1) + 1) def _process_segment( self, segment: Tuple[str, float, float], file_path: str, wit_config: Config.Wit, session: requests.Session, + wit_client_access_token_index: int ) -> Dict[str, Union[str, float]]: + wit_calling_throttle.throttle(wit_client_access_token_index) + segment_file_path, start, end = segment with open(segment_file_path, 'rb') as wav_file: @@ -92,25 +125,26 @@ def _process_segment( text = '' while retries > 0: - response = session.post( - 'https://api.wit.ai/speech', - headers={ - 'Accept': 'application/vnd.wit.20200513+json', - 'Content-Type': 'audio/wav', - 'Authorization': f'Bearer {wit_config.wit_client_access_token}', - }, - data=audio_content, - ) - - if response.status_code == 200: - try: + try: + response = session.post( + 'https://api.wit.ai/speech', + headers={ + 'Accept': 'application/vnd.wit.20200513+json', + 'Content-Type': 'audio/wav', + 'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_client_access_token_index]}', + }, + data=audio_content, + ) + + if response.status_code == 200: text = json.loads(response.text)['text'] break - except KeyError: + else: retries -= 1 - else: + time.sleep(self.processes_per_wit_client_access_token + 1) + except: retries -= 1 - time.sleep(min(4, multiprocessing.cpu_count() - 1) + 1) + time.sleep(self.processes_per_wit_client_access_token + 1) if retries == 0: logging.warn( diff --git a/tafrigh/utils/cli_utils.py b/tafrigh/utils/cli_utils.py index 6d4c7c5..8a61c0f 100644 --- a/tafrigh/utils/cli_utils.py +++ b/tafrigh/utils/cli_utils.py @@ -112,9 +112,9 @@ def parse_args(argv: List[str]) -> argparse.Namespace: wit_group.add_argument( '-w', - '--wit_client_access_token', - default='', - help='wit.ai client access token. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.', + '--wit_client_access_tokens', + nargs='+', + help='List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.', ) wit_group.add_argument( diff --git a/tafrigh/utils/decorators.py b/tafrigh/utils/decorators.py deleted file mode 100644 index 717f8e0..0000000 --- a/tafrigh/utils/decorators.py +++ /dev/null @@ -1,26 +0,0 @@ -import time - -from functools import wraps -from typing import Callable, TypeVar - - -T = TypeVar("T", bound=Callable) - - -def minimum_execution_time(minimum_time: float) -> Callable[[T], T]: - def decorator(func: T) -> T: - @wraps(func) - def wrapper(*args, **kwargs): - start_time = time.time() - result = func(*args, **kwargs) - end_time = time.time() - - elapsed_time = end_time - start_time - if elapsed_time < minimum_time: - time.sleep(minimum_time - elapsed_time) - - return result - - return wrapper - - return decorator diff --git a/tafrigh/wit_calling_throttle.py b/tafrigh/wit_calling_throttle.py new file mode 100644 index 0000000..c8a93a2 --- /dev/null +++ b/tafrigh/wit_calling_throttle.py @@ -0,0 +1,25 @@ +import time + +from threading import Lock + + +class WitCallingThrottle: + def __init__(self, wit_client_access_tokens_count: int, call_times_limit: int = 1, expired_time: int = 1): + self.wit_client_access_tokens_count = wit_client_access_tokens_count + self.call_times_limit = call_times_limit + self.expired_time = expired_time + self.call_timestamps = [list()] * self.wit_client_access_tokens_count + self.locks = [Lock()] * wit_client_access_tokens_count + + def throttle(self, wit_client_access_token_index: int) -> None: + with self.locks[wit_client_access_token_index]: + while len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: + now = time.time() + self.call_timestamps[wit_client_access_token_index] = list(filter( + lambda x: now - x < self.expired_time, + self.call_timestamps[wit_client_access_token_index] + )) + if len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: + time_to_sleep = self.call_timestamps[wit_client_access_token_index][0] + self.expired_time - now + time.sleep(time_to_sleep) + self.call_timestamps[wit_client_access_token_index].append(time.time())