diff --git a/README.md b/README.md
index 44c0956..01f272a 100644
--- a/README.md
+++ b/README.md
@@ -109,7 +109,7 @@
خيارات تقنية Wit
- - مفتاح wit.ai: يمكنك استخدام تقنيات wit.ai لتفريغ المواد إلى نصوص من خلال تمرير المفتاح الخاص بك للاختيار
--wit_client_access_token
. إذا تم تمرير هذا الاختيار، سيتم استخدام wit.ai لتفريغ المواد إلى نصوص. غير ذلك، سيتم استخدام نماذج Whisper
+ - مفاتيح wit.ai: يمكنك استخدام تقنيات wit.ai لتفريغ المواد إلى نصوص من خلال تمرير المفتاح أو المفاتيح الخاصة بك للاختيار
--wit_client_access_tokens
. إذا تم تمرير هذا الاختيار، سيتم استخدام wit.ai لتفريغ المواد إلى نصوص. غير ذلك، سيتم استخدام نماذج Whisper
- تحديد أقصى مدة للتقطيع: يمكنك تحديد أقصى مدة للتقطيع والتي ستؤثر على طول الجمل في ملفات SRT و VTT من خلال تمرير الاختيار
--max_cutting_duration
. القيمة الافتراضية هي 15
@@ -141,8 +141,9 @@
usage: tafrigh [-h] [--skip_if_output_exist | --no-skip_if_output_exist] [--playlist_items PLAYLIST_ITEMS] [--verbose | --no-verbose] [-m MODEL_NAME_OR_PATH] [-t {transcribe,translate}]
[-l {af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh}]
[--use_faster_whisper | --no-use_faster_whisper] [--use_whisper_jax | --no-use_whisper_jax] [--beam_size BEAM_SIZE] [--ct2_compute_type {default,int8,int8_float16,int16,float16}]
- [-w WIT_CLIENT_ACCESS_TOKEN] [--max_cutting_duration [1-17]] [--min_words_per_segment MIN_WORDS_PER_SEGMENT] [--save_files_before_compact | --no-save_files_before_compact]
- [--save_yt_dlp_responses | --no-save_yt_dlp_responses] [--output_sample OUTPUT_SAMPLE] [-f {all,txt,srt,vtt,none} [{all,txt,srt,vtt,none} ...]] [-o OUTPUT_DIR]
+ [-w WIT_CLIENT_ACCESS_TOKENS [WIT_CLIENT_ACCESS_TOKENS ...]] [--max_cutting_duration [1-17]] [--min_words_per_segment MIN_WORDS_PER_SEGMENT]
+ [--save_files_before_compact | --no-save_files_before_compact] [--save_yt_dlp_responses | --no-save_yt_dlp_responses] [--output_sample OUTPUT_SAMPLE]
+ [-f {all,txt,srt,vtt,none} [{all,txt,srt,vtt,none} ...]] [-o OUTPUT_DIR]
urls_or_paths [urls_or_paths ...]
options:
@@ -174,8 +175,8 @@ Whisper:
Quantization type applied while converting the model to CTranslate2 format.
Wit:
- -w WIT_CLIENT_ACCESS_TOKEN, --wit_client_access_token WIT_CLIENT_ACCESS_TOKEN
- wit.ai client access token. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.
+ -w WIT_CLIENT_ACCESS_TOKENS [WIT_CLIENT_ACCESS_TOKENS ...], --wit_client_access_tokens WIT_CLIENT_ACCESS_TOKENS [WIT_CLIENT_ACCESS_TOKENS ...]
+ List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.
--max_cutting_duration [1-17]
The maximum allowed cutting duration. It should be between 1 and 17.
@@ -265,7 +266,7 @@ tafrigh "https://youtu.be/Di0vcmnxULs" \
```
tafrigh "https://youtu.be/dDzxYcEJbgo" \
- --wit_client_access_token XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
+ --wit_client_access_tokens XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
--output_dir . \
--output_formats txt srt \
--min_words_per_segment 10 \
@@ -276,7 +277,7 @@ tafrigh "https://youtu.be/dDzxYcEJbgo" \
```
tafrigh "https://youtube.com/playlist?list=PLyS-PHSxRDxsLnVsPrIwnsHMO5KgLz7T5" \
- --wit_client_access_token XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
+ --wit_client_access_tokens XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
--output_dir . \
--output_formats txt srt \
--min_words_per_segment 10 \
@@ -287,7 +288,7 @@ tafrigh "https://youtube.com/playlist?list=PLyS-PHSxRDxsLnVsPrIwnsHMO5KgLz7T5" \
```
tafrigh "https://youtu.be/4h5P7jXvW98" "https://youtu.be/jpfndVSROpw" \
- --wit_client_access_token XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
+ --wit_client_access_tokens XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
--output_dir . \
--output_formats txt srt \
--min_words_per_segment 10 \
diff --git a/colab_notebook.ipynb b/colab_notebook.ipynb
index 3873117..1b3d84a 100644
--- a/colab_notebook.ipynb
+++ b/colab_notebook.ipynb
@@ -141,7 +141,7 @@
" beam_size=5,\n",
" ct2_compute_type='default',\n",
"\n",
- " wit_client_access_token=wit_api_key,\n",
+ " wit_client_access_tokens=[wit_api_key],\n",
" max_cutting_duration=max_cutting_duration,\n",
" min_words_per_segment=min_words_per_segment,\n",
"\n",
diff --git a/tafrigh/cli.py b/tafrigh/cli.py
index ef4c602..e8cc989 100644
--- a/tafrigh/cli.py
+++ b/tafrigh/cli.py
@@ -51,7 +51,7 @@ def main():
beam_size=args.beam_size,
ct2_compute_type=args.ct2_compute_type,
- wit_client_access_token=args.wit_client_access_token,
+ wit_client_access_tokens=args.wit_client_access_tokens,
max_cutting_duration=args.max_cutting_duration,
min_words_per_segment=args.min_words_per_segment,
diff --git a/tafrigh/config.py b/tafrigh/config.py
index 3598de4..41677a3 100644
--- a/tafrigh/config.py
+++ b/tafrigh/config.py
@@ -19,7 +19,7 @@ def __init__(
use_whisper_jax: bool,
beam_size: int,
ct2_compute_type: str,
- wit_client_access_token: str,
+ wit_client_access_tokens: List[str],
max_cutting_duration: int,
min_words_per_segment: int,
save_files_before_compact: bool,
@@ -40,7 +40,7 @@ def __init__(
ct2_compute_type,
)
- self.wit = self.Wit(wit_client_access_token, max_cutting_duration)
+ self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration)
self.output = self.Output(
min_words_per_segment,
@@ -52,7 +52,7 @@ def __init__(
)
def use_wit(self) -> bool:
- return self.wit.wit_client_access_token != ''
+        return bool(self.wit.wit_client_access_tokens)
class Input:
def __init__(self, urls_or_paths: List[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool):
@@ -85,8 +85,8 @@ def __init__(
self.ct2_compute_type = ct2_compute_type
class Wit:
- def __init__(self, wit_client_access_token: str, max_cutting_duration: int):
- self.wit_client_access_token = wit_client_access_token
+ def __init__(self, wit_client_access_tokens: List[str], max_cutting_duration: int):
+ self.wit_client_access_tokens = wit_client_access_tokens
self.max_cutting_duration = max_cutting_duration
class Output:
diff --git a/tafrigh/recognizers/wit_recognizer.py b/tafrigh/recognizers/wit_recognizer.py
index 0200006..07689ad 100644
--- a/tafrigh/recognizers/wit_recognizer.py
+++ b/tafrigh/recognizers/wit_recognizer.py
@@ -7,7 +7,7 @@
import tempfile
import time
-from itertools import repeat
+from multiprocessing.managers import BaseManager
from requests.adapters import HTTPAdapter
from typing import Dict, Generator, List, Tuple, Union
from urllib3.util.retry import Retry
@@ -16,19 +16,32 @@
from tafrigh.audio_splitter import AudioSplitter
from tafrigh.config import Config
-from tafrigh.utils.decorators import minimum_execution_time
+from tafrigh.wit_calling_throttle import WitCallingThrottle
+
+
+class WitCallingThrottleManager(BaseManager):
+ pass
+
+
+WitCallingThrottleManager.register('WitCallingThrottle', WitCallingThrottle)
+
+
+def init_pool(throttle: WitCallingThrottle) -> None:
+ global wit_calling_throttle
+
+ wit_calling_throttle = throttle
class WitRecognizer:
def __init__(self, verbose: bool):
self.verbose = verbose
+        self.processes_per_wit_client_access_token = max(1, min(4, multiprocessing.cpu_count() - 1))
def recognize(
self,
file_path: str,
wit_config: Config.Wit,
) -> Generator[Dict[str, float], None, List[Dict[str, Union[str, float]]]]:
-
temp_directory = tempfile.mkdtemp()
segments = AudioSplitter().split(
@@ -50,39 +63,59 @@ def recognize(
session = requests.Session()
session.mount('https://', adapter)
- with multiprocessing.Pool(processes=min(4, multiprocessing.cpu_count() - 1)) as pool:
- async_results = [
- pool.apply_async(self._process_segment, (segment, file_path, wit_config, session))
- for segment in segments
- ]
-
- transcriptions = []
-
- with tqdm(total=len(segments), disable=self.verbose is not False) as pbar:
- while async_results:
- if async_results[0].ready():
- transcriptions.append(async_results.pop(0).get())
- pbar.update(1)
-
- yield {
- 'progress': round(len(transcriptions) / len(segments) * 100, 2),
- 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] if pbar.format_dict['rate'] and pbar.total else None,
- }
+        pool_processes_count = min(
+            self.processes_per_wit_client_access_token * len(wit_config.wit_client_access_tokens),
+            max(1, multiprocessing.cpu_count() - 1)
+        )
- time.sleep(0.5)
+ with WitCallingThrottleManager() as manager:
+ wit_calling_throttle = manager.WitCallingThrottle(len(wit_config.wit_client_access_tokens))
+
+ with multiprocessing.Pool(
+ processes=pool_processes_count,
+ initializer=init_pool,
+ initargs=(wit_calling_throttle,),
+ ) as pool:
+ async_results = [
+ pool.apply_async(
+ self._process_segment,
+ (
+ segment,
+ file_path,
+ wit_config,
+ session,
+ index % len(wit_config.wit_client_access_tokens),
+ ),
+ ) for index, segment in enumerate(segments)
+ ]
+
+ transcriptions = []
+
+ with tqdm(total=len(segments), disable=self.verbose is not False) as pbar:
+ while async_results:
+ if async_results[0].ready():
+ transcriptions.append(async_results.pop(0).get())
+ pbar.update(1)
+
+ yield {
+ 'progress': round(len(transcriptions) / len(segments) * 100, 2),
+ 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] if pbar.format_dict['rate'] and pbar.total else None,
+ }
shutil.rmtree(temp_directory)
return transcriptions
- @minimum_execution_time(min(4, multiprocessing.cpu_count() - 1) + 1)
def _process_segment(
self,
segment: Tuple[str, float, float],
file_path: str,
wit_config: Config.Wit,
session: requests.Session,
+ wit_client_access_token_index: int
) -> Dict[str, Union[str, float]]:
+ wit_calling_throttle.throttle(wit_client_access_token_index)
+
segment_file_path, start, end = segment
with open(segment_file_path, 'rb') as wav_file:
@@ -92,25 +125,26 @@ def _process_segment(
text = ''
while retries > 0:
- response = session.post(
- 'https://api.wit.ai/speech',
- headers={
- 'Accept': 'application/vnd.wit.20200513+json',
- 'Content-Type': 'audio/wav',
- 'Authorization': f'Bearer {wit_config.wit_client_access_token}',
- },
- data=audio_content,
- )
-
- if response.status_code == 200:
- try:
+ try:
+ response = session.post(
+ 'https://api.wit.ai/speech',
+ headers={
+ 'Accept': 'application/vnd.wit.20200513+json',
+ 'Content-Type': 'audio/wav',
+ 'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_client_access_token_index]}',
+ },
+ data=audio_content,
+ )
+
+ if response.status_code == 200:
text = json.loads(response.text)['text']
break
- except KeyError:
+ else:
retries -= 1
- else:
+ time.sleep(self.processes_per_wit_client_access_token + 1)
+            except Exception:
retries -= 1
- time.sleep(min(4, multiprocessing.cpu_count() - 1) + 1)
+ time.sleep(self.processes_per_wit_client_access_token + 1)
if retries == 0:
logging.warn(
diff --git a/tafrigh/utils/cli_utils.py b/tafrigh/utils/cli_utils.py
index 6d4c7c5..8a61c0f 100644
--- a/tafrigh/utils/cli_utils.py
+++ b/tafrigh/utils/cli_utils.py
@@ -112,9 +112,9 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
wit_group.add_argument(
'-w',
- '--wit_client_access_token',
- default='',
- help='wit.ai client access token. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.',
+ '--wit_client_access_tokens',
+ nargs='+',
+ help='List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.',
)
wit_group.add_argument(
diff --git a/tafrigh/utils/decorators.py b/tafrigh/utils/decorators.py
deleted file mode 100644
index 717f8e0..0000000
--- a/tafrigh/utils/decorators.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import time
-
-from functools import wraps
-from typing import Callable, TypeVar
-
-
-T = TypeVar("T", bound=Callable)
-
-
-def minimum_execution_time(minimum_time: float) -> Callable[[T], T]:
- def decorator(func: T) -> T:
- @wraps(func)
- def wrapper(*args, **kwargs):
- start_time = time.time()
- result = func(*args, **kwargs)
- end_time = time.time()
-
- elapsed_time = end_time - start_time
- if elapsed_time < minimum_time:
- time.sleep(minimum_time - elapsed_time)
-
- return result
-
- return wrapper
-
- return decorator
diff --git a/tafrigh/wit_calling_throttle.py b/tafrigh/wit_calling_throttle.py
new file mode 100644
index 0000000..c8a93a2
--- /dev/null
+++ b/tafrigh/wit_calling_throttle.py
@@ -0,0 +1,29 @@
+import time
+
+from threading import Lock
+
+
+class WitCallingThrottle:
+    """Per-token rate limiter: allows `call_times_limit` calls per `expired_time` seconds for each wit.ai token."""
+
+    def __init__(self, wit_client_access_tokens_count: int, call_times_limit: int = 1, expired_time: int = 1):
+        self.wit_client_access_tokens_count = wit_client_access_tokens_count
+        self.call_times_limit = call_times_limit
+        self.expired_time = expired_time
+        # Build with comprehensions: `[list()] * n` / `[Lock()] * n` would alias a
+        # single shared list and a single shared Lock across every token,
+        # collapsing the per-token rate limits into one global limit.
+        self.call_timestamps = [[] for _ in range(wit_client_access_tokens_count)]
+        self.locks = [Lock() for _ in range(wit_client_access_tokens_count)]
+
+    def throttle(self, wit_client_access_token_index: int) -> None:
+        with self.locks[wit_client_access_token_index]:
+            while len(self.call_timestamps[wit_client_access_token_index]) >= self.call_times_limit:
+                now = time.time()
+                self.call_timestamps[wit_client_access_token_index] = list(filter(
+                    lambda timestamp: now - timestamp < self.expired_time,
+                    self.call_timestamps[wit_client_access_token_index],
+                ))
+                if len(self.call_timestamps[wit_client_access_token_index]) >= self.call_times_limit:
+                    time_to_sleep = self.call_timestamps[wit_client_access_token_index][0] + self.expired_time - now
+                    time.sleep(time_to_sleep)
+            self.call_timestamps[wit_client_access_token_index].append(time.time())