Use multiple wit.ai API keys to transcribe the same video
AliOsm committed Aug 5, 2023
1 parent 5a4ac3b commit df379f7
Showing 8 changed files with 119 additions and 82 deletions.
17 changes: 9 additions & 8 deletions README.md
@@ -109,7 +109,7 @@
<li>
Wit options
<ul dir="rtl">
<li><a href="wit.ai">wit.ai</a> key: you can use <a href="wit.ai">wit.ai</a> to transcribe material into text by passing your key via the <code dir="ltr">--wit_client_access_token</code> option. If this option is passed, <a href="wit.ai">wit.ai</a> will be used for the transcription; otherwise, Whisper models will be used</li>
<li><a href="wit.ai">wit.ai</a> keys: you can use <a href="wit.ai">wit.ai</a> to transcribe material into text by passing your key or keys via the <code dir="ltr">--wit_client_access_tokens</code> option. If this option is passed, <a href="wit.ai">wit.ai</a> will be used for the transcription; otherwise, Whisper models will be used</li>
<li>Maximum cutting duration: you can set the maximum cutting duration, which affects the length of the sentences in the SRT and VTT files, via the <code dir="ltr">--max_cutting_duration</code> option. The default value is <code>15</code></li>
</ul>
</li>
@@ -141,8 +141,9 @@
usage: tafrigh [-h] [--skip_if_output_exist | --no-skip_if_output_exist] [--playlist_items PLAYLIST_ITEMS] [--verbose | --no-verbose] [-m MODEL_NAME_OR_PATH] [-t {transcribe,translate}]
[-l {af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh}]
[--use_faster_whisper | --no-use_faster_whisper] [--use_whisper_jax | --no-use_whisper_jax] [--beam_size BEAM_SIZE] [--ct2_compute_type {default,int8,int8_float16,int16,float16}]
[-w WIT_CLIENT_ACCESS_TOKEN] [--max_cutting_duration [1-17]] [--min_words_per_segment MIN_WORDS_PER_SEGMENT] [--save_files_before_compact | --no-save_files_before_compact]
[--save_yt_dlp_responses | --no-save_yt_dlp_responses] [--output_sample OUTPUT_SAMPLE] [-f {all,txt,srt,vtt,none} [{all,txt,srt,vtt,none} ...]] [-o OUTPUT_DIR]
[-w WIT_CLIENT_ACCESS_TOKENS [WIT_CLIENT_ACCESS_TOKENS ...]] [--max_cutting_duration [1-17]] [--min_words_per_segment MIN_WORDS_PER_SEGMENT]
[--save_files_before_compact | --no-save_files_before_compact] [--save_yt_dlp_responses | --no-save_yt_dlp_responses] [--output_sample OUTPUT_SAMPLE]
[-f {all,txt,srt,vtt,none} [{all,txt,srt,vtt,none} ...]] [-o OUTPUT_DIR]
urls_or_paths [urls_or_paths ...]
options:
@@ -174,8 +175,8 @@ Whisper:
Quantization type applied while converting the model to CTranslate2 format.
Wit:
-w WIT_CLIENT_ACCESS_TOKEN, --wit_client_access_token WIT_CLIENT_ACCESS_TOKEN
wit.ai client access token. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.
-w WIT_CLIENT_ACCESS_TOKENS [WIT_CLIENT_ACCESS_TOKENS ...], --wit_client_access_tokens WIT_CLIENT_ACCESS_TOKENS [WIT_CLIENT_ACCESS_TOKENS ...]
List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.
--max_cutting_duration [1-17]
The maximum allowed cutting duration. It should be between 1 and 17.
@@ -265,7 +266,7 @@ tafrigh "https://youtu.be/Di0vcmnxULs" \

```
tafrigh "https://youtu.be/dDzxYcEJbgo" \
--wit_client_access_token XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
--wit_client_access_tokens XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
--output_dir . \
--output_formats txt srt \
--min_words_per_segment 10 \
@@ -276,7 +277,7 @@ tafrigh "https://youtu.be/dDzxYcEJbgo" \

```
tafrigh "https://youtube.com/playlist?list=PLyS-PHSxRDxsLnVsPrIwnsHMO5KgLz7T5" \
--wit_client_access_token XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
--wit_client_access_tokens XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
--output_dir . \
--output_formats txt srt \
--min_words_per_segment 10 \
@@ -287,7 +288,7 @@ tafrigh "https://youtube.com/playlist?list=PLyS-PHSxRDxsLnVsPrIwnsHMO5KgLz7T5" \

```
tafrigh "https://youtu.be/4h5P7jXvW98" "https://youtu.be/jpfndVSROpw" \
--wit_client_access_token XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
--wit_client_access_tokens XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX \
--output_dir . \
--output_formats txt srt \
--min_words_per_segment 10 \
2 changes: 1 addition & 1 deletion colab_notebook.ipynb
@@ -141,7 +141,7 @@
" beam_size=5,\n",
" ct2_compute_type='default',\n",
"\n",
" wit_client_access_token=wit_api_key,\n",
" wit_client_access_tokens=[wit_api_key],\n",
" max_cutting_duration=max_cutting_duration,\n",
" min_words_per_segment=min_words_per_segment,\n",
"\n",
2 changes: 1 addition & 1 deletion tafrigh/cli.py
@@ -51,7 +51,7 @@ def main():
beam_size=args.beam_size,
ct2_compute_type=args.ct2_compute_type,

wit_client_access_token=args.wit_client_access_token,
wit_client_access_tokens=args.wit_client_access_tokens,
max_cutting_duration=args.max_cutting_duration,
min_words_per_segment=args.min_words_per_segment,

10 changes: 5 additions & 5 deletions tafrigh/config.py
@@ -19,7 +19,7 @@ def __init__(
use_whisper_jax: bool,
beam_size: int,
ct2_compute_type: str,
wit_client_access_token: str,
wit_client_access_tokens: List[str],
max_cutting_duration: int,
min_words_per_segment: int,
save_files_before_compact: bool,
@@ -40,7 +40,7 @@ def __init__(
ct2_compute_type,
)

self.wit = self.Wit(wit_client_access_token, max_cutting_duration)
self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration)

self.output = self.Output(
min_words_per_segment,
@@ -52,7 +52,7 @@ def __init__(
)

def use_wit(self) -> bool:
return self.wit.wit_client_access_token != ''
return self.wit.wit_client_access_tokens != []

class Input:
def __init__(self, urls_or_paths: List[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool):
@@ -85,8 +85,8 @@ def __init__(
self.ct2_compute_type = ct2_compute_type

class Wit:
def __init__(self, wit_client_access_token: str, max_cutting_duration: int):
self.wit_client_access_token = wit_client_access_token
def __init__(self, wit_client_access_tokens: List[str], max_cutting_duration: int):
self.wit_client_access_tokens = wit_client_access_tokens
self.max_cutting_duration = max_cutting_duration

class Output:
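The updated `use_wit()` check above keys off the token list rather than a single token string. A minimal standalone mirror of that decision (hypothetical helper, not the project's API):

```python
from typing import List


# Minimal mirror of Config.use_wit(): wit.ai is selected only when at
# least one client access token was supplied; otherwise Whisper is used.
def use_wit(wit_client_access_tokens: List[str]) -> bool:
    return wit_client_access_tokens != []


print(use_wit([]))         # False: fall back to Whisper models
print(use_wit(['KEY_A']))  # True: transcribe with wit.ai
```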
110 changes: 72 additions & 38 deletions tafrigh/recognizers/wit_recognizer.py
@@ -7,7 +7,7 @@
import tempfile
import time

from itertools import repeat
from multiprocessing.managers import BaseManager
from requests.adapters import HTTPAdapter
from typing import Dict, Generator, List, Tuple, Union
from urllib3.util.retry import Retry
@@ -16,19 +16,32 @@

from tafrigh.audio_splitter import AudioSplitter
from tafrigh.config import Config
from tafrigh.utils.decorators import minimum_execution_time
from tafrigh.wit_calling_throttle import WitCallingThrottle


class WitCallingThrottleManager(BaseManager):
pass


WitCallingThrottleManager.register('WitCallingThrottle', WitCallingThrottle)


def init_pool(throttle: WitCallingThrottle) -> None:
global wit_calling_throttle

wit_calling_throttle = throttle


class WitRecognizer:
def __init__(self, verbose: bool):
self.verbose = verbose
self.processes_per_wit_client_access_token = min(4, multiprocessing.cpu_count())

def recognize(
self,
file_path: str,
wit_config: Config.Wit,
) -> Generator[Dict[str, float], None, List[Dict[str, Union[str, float]]]]:

temp_directory = tempfile.mkdtemp()

segments = AudioSplitter().split(
@@ -50,39 +63,59 @@ def recognize(
session = requests.Session()
session.mount('https://', adapter)

with multiprocessing.Pool(processes=min(4, multiprocessing.cpu_count() - 1)) as pool:
async_results = [
pool.apply_async(self._process_segment, (segment, file_path, wit_config, session))
for segment in segments
]

transcriptions = []

with tqdm(total=len(segments), disable=self.verbose is not False) as pbar:
while async_results:
if async_results[0].ready():
transcriptions.append(async_results.pop(0).get())
pbar.update(1)

yield {
'progress': round(len(transcriptions) / len(segments) * 100, 2),
'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] if pbar.format_dict['rate'] and pbar.total else None,
}
pool_processes_count = min(
self.processes_per_wit_client_access_token * len(wit_config.wit_client_access_tokens),
multiprocessing.cpu_count(),
)

time.sleep(0.5)
with WitCallingThrottleManager() as manager:
wit_calling_throttle = manager.WitCallingThrottle(len(wit_config.wit_client_access_tokens))

with multiprocessing.Pool(
processes=pool_processes_count,
initializer=init_pool,
initargs=(wit_calling_throttle,),
) as pool:
async_results = [
pool.apply_async(
self._process_segment,
(
segment,
file_path,
wit_config,
session,
index % len(wit_config.wit_client_access_tokens),
),
) for index, segment in enumerate(segments)
]

transcriptions = []

with tqdm(total=len(segments), disable=self.verbose is not False) as pbar:
while async_results:
if async_results[0].ready():
transcriptions.append(async_results.pop(0).get())
pbar.update(1)

yield {
'progress': round(len(transcriptions) / len(segments) * 100, 2),
'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] if pbar.format_dict['rate'] and pbar.total else None,
}

shutil.rmtree(temp_directory)

return transcriptions

@minimum_execution_time(min(4, multiprocessing.cpu_count() - 1) + 1)
def _process_segment(
self,
segment: Tuple[str, float, float],
file_path: str,
wit_config: Config.Wit,
session: requests.Session,
wit_client_access_token_index: int
) -> Dict[str, Union[str, float]]:
wit_calling_throttle.throttle(wit_client_access_token_index)

segment_file_path, start, end = segment

with open(segment_file_path, 'rb') as wav_file:
@@ -92,25 +125,26 @@ def _process_segment(

text = ''
while retries > 0:
response = session.post(
'https://api.wit.ai/speech',
headers={
'Accept': 'application/vnd.wit.20200513+json',
'Content-Type': 'audio/wav',
'Authorization': f'Bearer {wit_config.wit_client_access_token}',
},
data=audio_content,
)

if response.status_code == 200:
try:
try:
response = session.post(
'https://api.wit.ai/speech',
headers={
'Accept': 'application/vnd.wit.20200513+json',
'Content-Type': 'audio/wav',
'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_client_access_token_index]}',
},
data=audio_content,
)

if response.status_code == 200:
text = json.loads(response.text)['text']
break
except KeyError:
else:
retries -= 1
else:
time.sleep(self.processes_per_wit_client_access_token + 1)
except Exception:
retries -= 1
time.sleep(min(4, multiprocessing.cpu_count() - 1) + 1)
time.sleep(self.processes_per_wit_client_access_token + 1)

if retries == 0:
logging.warning(
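Two small pieces of arithmetic drive the rewritten recognizer above: the pool size is the per-token worker count times the number of tokens, capped by the CPU count, and segment `index` is pinned to token `index % len(tokens)`. A toy sketch with assumed values (`cpu_count = 8`, placeholder keys):

```python
# Pool sizing as in WitRecognizer: up to 4 worker processes per wit.ai
# token, but never more processes than CPUs.
processes_per_token = 4
cpu_count = 8  # assumed value for illustration
tokens = ['KEY_A', 'KEY_B', 'KEY_C']  # placeholder tokens

pool_size = min(processes_per_token * len(tokens), cpu_count)
print(pool_size)  # min(12, 8) -> 8

# Round-robin token assignment: segment i is transcribed with token
# number i % len(tokens), spreading load evenly across API keys.
segments = [f'seg{i}.wav' for i in range(5)]
assignment = [(seg, i % len(tokens)) for i, seg in enumerate(segments)]
print(assignment)
```

With three keys the five segments land on token indices 0, 1, 2, 0, 1, so no key receives more than its share of consecutive requests.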
6 changes: 3 additions & 3 deletions tafrigh/utils/cli_utils.py
@@ -112,9 +112,9 @@ def parse_args(argv: List[str]) -> argparse.Namespace:

wit_group.add_argument(
'-w',
'--wit_client_access_token',
default='',
help='wit.ai client access token. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.',
'--wit_client_access_tokens',
nargs='+',
help='List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.',
)

wit_group.add_argument(
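The switch above from a single `default=''` string to `nargs='+'` means argparse now collects one or more values into a list. A quick illustration (placeholder keys, stripped down to the one flag):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-w', '--wit_client_access_tokens', nargs='+')

# One flag, several values: argparse gathers them into a list.
args = parser.parse_args(['-w', 'KEY_A', 'KEY_B'])
print(args.wit_client_access_tokens)  # ['KEY_A', 'KEY_B']

# With nargs='+' and no explicit default, omitting the flag yields None.
print(parser.parse_args([]).wit_client_access_tokens)  # None
```

Note that without an explicit `default`, an omitted flag produces `None` rather than an empty list, so downstream code that compares against `[]` typically wants `default=[]` on the argument.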
26 changes: 0 additions & 26 deletions tafrigh/utils/decorators.py

This file was deleted.

28 changes: 28 additions & 0 deletions tafrigh/wit_calling_throttle.py
@@ -0,0 +1,28 @@
import time

from threading import Lock


class WitCallingThrottle:
def __init__(self, wit_client_access_tokens_count: int, call_times_limit: int = 1, expired_time: int = 1):
self.wit_client_access_tokens_count = wit_client_access_tokens_count
self.call_times_limit = call_times_limit
self.expired_time = expired_time
self.call_timestamps = [[] for _ in range(self.wit_client_access_tokens_count)]
self.locks = [Lock() for _ in range(self.wit_client_access_tokens_count)]

def throttle(self, wit_client_access_token_index: int) -> None:
with self.locks[wit_client_access_token_index]:
while len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit:
now = time.time()

self.call_timestamps[wit_client_access_token_index] = list(filter(
lambda x: now - x < self.expired_time,
self.call_timestamps[wit_client_access_token_index],
))

if len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit:
time_to_sleep = self.call_timestamps[wit_client_access_token_index][0] + self.expired_time - now
time.sleep(time_to_sleep)

self.call_timestamps[wit_client_access_token_index].append(time.time())
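WitCallingThrottle is a per-token sliding-window rate limiter: at most `call_times_limit` calls per `expired_time` seconds for each token. A single-token toy version of the same loop (made-up class, short window so it runs quickly):

```python
import time


class SlidingWindowThrottle:
    """Allow at most `limit` calls per `window` seconds (single token)."""

    def __init__(self, limit: int = 1, window: float = 1.0):
        self.limit = limit
        self.window = window
        self.timestamps = []

    def throttle(self) -> None:
        while len(self.timestamps) >= self.limit:
            now = time.time()
            # Drop timestamps that have aged out of the window.
            self.timestamps = [t for t in self.timestamps if now - t < self.window]
            if len(self.timestamps) >= self.limit:
                # Sleep until the oldest call leaves the window, then re-check.
                time.sleep(self.timestamps[0] + self.window - now)
        self.timestamps.append(time.time())


throttle = SlidingWindowThrottle(limit=2, window=0.2)
start = time.time()
for _ in range(4):
    throttle.throttle()
elapsed = time.time() - start
print(elapsed >= 0.15)  # the third call must wait for the window to clear
```

The committed class extends this idea with one timestamp list and one lock per token, so workers throttling different tokens never block each other; sharing it across pool processes is what the `BaseManager` registration in `wit_recognizer.py` is for.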
