From 012013650c8e646fb88be136a13fe50c19b98311 Mon Sep 17 00:00:00 2001 From: Ali Hamdi Ali Fadel Date: Wed, 26 Jun 2024 21:49:31 +0000 Subject: [PATCH] Fix pre-commit warnings --- .github/workflows/formatter.yml | 8 +- pyproject.toml | 2 +- src/__init__.py | 26 +- src/audio_splitter.py | 118 +++---- src/cli.py | 397 ++++++++++++------------ src/config.py | 216 +++++++------ src/downloader.py | 98 +++--- src/recognizers/whisper_recognizer.py | 179 +++++------ src/recognizers/wit_calling_throttle.py | 42 +-- src/recognizers/wit_recognizer.py | 262 ++++++++-------- src/types/segment_type.py | 9 + src/types/transcript_type.py | 20 +- src/types/whisper/type_hints.py | 6 +- src/utils/cli_utils.py | 356 ++++++++++----------- src/utils/file_utils.py | 22 +- src/utils/time_utils.py | 20 +- src/utils/whisper/whisper_utils.py | 16 +- src/utils/wit/file_utils.py | 8 +- src/writer.py | 291 ++++++++--------- 19 files changed, 1054 insertions(+), 1042 deletions(-) create mode 100644 src/types/segment_type.py diff --git a/.github/workflows/formatter.yml b/.github/workflows/formatter.yml index 90c9231..1eda0cd 100644 --- a/.github/workflows/formatter.yml +++ b/.github/workflows/formatter.yml @@ -14,12 +14,8 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: 3.9 - - name: black formatter - uses: rickstaa/action-black@v1 - with: - black_args: ". --check --diff --skip-string-normalization --line-length 120" + python-version: 3.11 - name: isort formatter uses: isort/isort-action@v1 with: - configuration: "--profile black --src tafrigh --line-length 120 --lines-between-types 1 --lines-after-imports 2 --case-sensitive --trailing-comma --check-only --diff" + configuration: "--src src --line-length 120 --lines-between-types 1 --lines-after-imports 2 --check-only --diff" diff --git a/pyproject.toml b/pyproject.toml index e8117c1..187e008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ homepage = "https://tafrigh.ieasybooks.com" repository = "https://github.com/ieasybooks/tafrigh" [tool.poetry.dependencies] -python = ">=3.10,<3.12" +python = "3.11" tqdm = ">=4.66.4" yt-dlp = ">=2024.4.9" auditok = {version = ">=0.2.0", extras = ["wit"]} diff --git a/src/__init__.py b/src/__init__.py index aba4c1b..b271497 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,12 +1,12 @@ __all__ = [ - "farrigh", - "Config", - "Downloader", - "TranscriptType", - "Writer", - "WhisperRecognizer", - "AudioSplitter", - "WitRecognizer", + 'farrigh', + 'Config', + 'Downloader', + 'TranscriptType', + 'Writer', + 'WhisperRecognizer', + 'AudioSplitter', + 'WitRecognizer', ] @@ -18,12 +18,12 @@ try: - from .recognizers.whisper_recognizer import WhisperRecognizer + from .recognizers.whisper_recognizer import WhisperRecognizer except ModuleNotFoundError: - pass + pass try: - from .audio_splitter import AudioSplitter - from .recognizers.wit_recognizer import WitRecognizer + from .audio_splitter import AudioSplitter + from .recognizers.wit_recognizer import WitRecognizer except ModuleNotFoundError: - pass + pass diff --git a/src/audio_splitter.py b/src/audio_splitter.py index b08bdbd..0abb895 100644 --- a/src/audio_splitter.py +++ b/src/audio_splitter.py @@ -7,71 +7,71 @@ class AudioSplitter: - def split( - self, - file_path: str, - output_dir: str, - min_dur: float = 0.5, - max_dur: float = 15, - max_silence: float = 0.5, - energy_threshold: float = 50, - expand_segments_with_noise: bool = False, - noise_seconds: int = 1, - noise_amplitude: int = 0, - ) -> list[tuple[str, float, 
float]]: - segments = split( - file_path, - min_dur=min_dur, - max_dur=max_dur, - max_silence=max_silence, - energy_threshold=energy_threshold, - ) + def split( + self, + file_path: str, + output_dir: str, + min_dur: float = 0.5, + max_dur: float = 15, + max_silence: float = 0.5, + energy_threshold: float = 50, + expand_segments_with_noise: bool = False, + noise_seconds: int = 1, + noise_amplitude: int = 0, + ) -> list[tuple[str, float, float]]: + segments = split( + file_path, + min_dur=min_dur, + max_dur=max_dur, + max_silence=max_silence, + energy_threshold=energy_threshold, + ) - if expand_segments_with_noise: - segments = [ - ( - self._expand_segment_with_noise(segment, noise_seconds, noise_amplitude), - segment.meta.start, - segment.meta.end, - ) for segment in segments - ] + if expand_segments_with_noise: + segments = [ + ( + self._expand_segment_with_noise(segment, noise_seconds, noise_amplitude), + segment.meta.start, + segment.meta.end, + ) for segment in segments + ] - return self._save_segments(output_dir, segments) + return self._save_segments(output_dir, segments) - def _expand_segment_with_noise( - self, - segment: AudioRegion, - noise_seconds: int, - noise_amplitude: int, - ) -> AudioSegment: + def _expand_segment_with_noise( + self, + segment: AudioRegion, + noise_seconds: int, + noise_amplitude: int, + ) -> AudioSegment: - audio_segment = AudioSegment( - segment._data, - frame_rate=segment.sampling_rate, - sample_width=segment.sample_width, - channels=segment.channels, - ) + audio_segment = AudioSegment( + segment._data, + frame_rate=segment.sampling_rate, + sample_width=segment.sample_width, + channels=segment.channels, + ) - pre_noise = WhiteNoise().to_audio_segment(duration=noise_seconds * 1000, volume=noise_amplitude) - post_noise = WhiteNoise().to_audio_segment(duration=noise_seconds * 1000, volume=noise_amplitude) + pre_noise = WhiteNoise().to_audio_segment(duration=noise_seconds * 1000, volume=noise_amplitude) + post_noise = WhiteNoise().to_audio_segment(duration=noise_seconds * 1000, volume=noise_amplitude) - return pre_noise + audio_segment + post_noise + return pre_noise + audio_segment + post_noise - def _save_segments( - self, - output_dir: str, - segments: list[AudioSegment | tuple[AudioSegment, float, float]], - ) -> list[tuple[str, float, float]]: - segment_paths = [] + def _save_segments( + self, + output_dir: str, + segments: list[AudioSegment | tuple[AudioSegment, float, float]], + ) -> list[tuple[str, float, float]]: + segment_paths = [] - for i, segment in enumerate(segments): - output_file = os.path.join(output_dir, f'segment_{i + 1}.mp3') + for i, segment in enumerate(segments): + output_file = os.path.join(output_dir, f'segment_{i + 1}.mp3') - if isinstance(segment, tuple): - segment[0].export(output_file, format='mp3') - segment_paths.append((output_file, segment[1], segment[2])) - else: - segment.save(output_file) - segment_paths.append((output_file, segment.meta.start, segment.meta.end)) - - return segment_paths + if isinstance(segment, tuple): + segment[0].export(output_file, format='mp3') + segment_paths.append((output_file, segment[1], segment[2])) + else: + segment.save(output_file) + segment_paths.append((output_file, segment.meta.start, segment.meta.end)) + + return segment_paths diff --git a/src/cli.py b/src/cli.py index 069d9e3..e31fd4e 100644 --- a/src/cli.py +++ b/src/cli.py @@ -13,252 +13,261 @@ from .config import Config from .downloader import Downloader +from .types.segment_type import SegmentType from .utils import cli_utils, 
file_utils, time_utils from .writer import Writer try: - import requests + import requests - from .recognizers.wit_recognizer import WitRecognizer - from .utils.wit import file_utils as wit_file_utils + from .recognizers.wit_recognizer import WitRecognizer + from .utils.wit import file_utils as wit_file_utils except ModuleNotFoundError: - pass + pass + try: - from .recognizers.whisper_recognizer import WhisperRecognizer - from .types.whisper.type_hints import WhisperModel - from .utils.whisper import whisper_utils + from .recognizers.whisper_recognizer import WhisperRecognizer + from .types.whisper.type_hints import WhisperModel + from .utils.whisper import whisper_utils except ModuleNotFoundError: - pass + pass def main(): - args = cli_utils.parse_args(sys.argv[1:]) - - config = Config( - urls_or_paths=args.urls_or_paths, - skip_if_output_exist=args.skip_if_output_exist, - playlist_items=args.playlist_items, - verbose=args.verbose, - # - model_name_or_path=args.model_name_or_path, - task=args.task, - language=args.language, - use_faster_whisper=args.use_faster_whisper, - beam_size=args.beam_size, - ct2_compute_type=args.ct2_compute_type, - # - wit_client_access_tokens=args.wit_client_access_tokens, - max_cutting_duration=args.max_cutting_duration, - min_words_per_segment=args.min_words_per_segment, - # - save_files_before_compact=args.save_files_before_compact, - save_yt_dlp_responses=args.save_yt_dlp_responses, - output_sample=args.output_sample, - output_formats=args.output_formats, - output_dir=args.output_dir, - ) - - if config.use_wit() and config.input.skip_if_output_exist: - retries = 3 - - while retries > 0: - try: - deque(farrigh(config), maxlen=0) - break - except requests.exceptions.RetryError: - retries -= 1 - else: + args = cli_utils.parse_args(sys.argv[1:]) + + config = Config( + urls_or_paths=args.urls_or_paths, + skip_if_output_exist=args.skip_if_output_exist, + playlist_items=args.playlist_items, + verbose=args.verbose, + # + model_name_or_path=args.model_name_or_path, + task=args.task, + language=args.language, + use_faster_whisper=args.use_faster_whisper, + beam_size=args.beam_size, + ct2_compute_type=args.ct2_compute_type, + # + wit_client_access_tokens=args.wit_client_access_tokens, + max_cutting_duration=args.max_cutting_duration, + min_words_per_segment=args.min_words_per_segment, + # + save_files_before_compact=args.save_files_before_compact, + save_yt_dlp_responses=args.save_yt_dlp_responses, + output_sample=args.output_sample, + output_formats=args.output_formats, + output_dir=args.output_dir, + ) + + if config.use_wit() and config.input.skip_if_output_exist: + retries = 3 + + while retries > 0: + try: deque(farrigh(config), maxlen=0) + break + except requests.exceptions.RetryError: + retries -= 1 + else: + deque(farrigh(config), maxlen=0) -def farrigh(config: Config) -> Generator[dict[str, int], None, None]: - prepare_output_dir(config.output.output_dir) +def farrigh(config: Config) -> Generator[dict[str, Any], None, None]: + prepare_output_dir(config.output.output_dir) - model = None - if not config.use_wit(): - model = whisper_utils.load_model(config.whisper) + model = None + if not config.use_wit(): + model = whisper_utils.load_model(config.whisper) - segments = [] + segments: list[SegmentType] = [] - for idx, item in enumerate(tqdm(config.input.urls_or_paths, desc='URLs or local paths')): - progress_info = { - 'outer_total': len(config.input.urls_or_paths), - 'outer_current': idx + 1, - 'outer_status': 'processing', - } + for idx, item in 
enumerate(tqdm(config.input.urls_or_paths, desc='URLs or local paths')): + progress_info = { + 'outer_total': len(config.input.urls_or_paths), + 'outer_current': idx + 1, + 'outer_status': 'processing', + } - if Path(item).exists(): - file_or_folder = Path(item) - for progress_info, local_elements_segments in process_local(file_or_folder, model, config, progress_info): - segments.extend(local_elements_segments) - yield progress_info - elif re.match('(https?://)', item): - for progress_info, url_elements_segments in process_url(item, model, config, progress_info): - segments.extend(url_elements_segments) - yield progress_info - else: - logging.error(f'Path {item} does not exist and is not a URL either.') + if Path(item).exists(): + file_or_folder = Path(item) + for progress_info, local_elements_segments in process_local(file_or_folder, model, config, progress_info): + segments.extend(local_elements_segments) + yield progress_info + elif re.match('(https?://)', item): + for progress_info, url_elements_segments in process_url(item, model, config, progress_info): + segments.extend(url_elements_segments) + yield progress_info + else: + logging.error(f'Path {item} does not exist and is not a URL either.') - progress_info['outer_status'] = 'completed' - yield progress_info + progress_info['outer_status'] = 'completed' + yield progress_info - continue + continue - progress_info['outer_status'] = 'completed' - yield progress_info + progress_info['outer_status'] = 'completed' + yield progress_info - write_output_sample(segments, config.output) + write_output_sample(segments, config.output) def prepare_output_dir(output_dir: str) -> None: - os.makedirs(output_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) def process_local( - path: Path, - model: 'WhisperModel', - config: Config, - progress_info: dict, -) -> Generator[tuple[dict[str, int], list[list[dict[str, str | float]]]], None, None]: - filtered_media_files: list[Path] = file_utils.filter_media_files([path] if path.is_file() else path.iterdir()) - files: list[dict[str, Any]] = [{'file_name': file.name, 'file_path': file} for file in filtered_media_files] - - for idx, file in enumerate(tqdm(files, desc='Local files')): - new_progress_info = progress_info.copy() - new_progress_info.update( - { - 'inner_total': len(files), - 'inner_current': idx + 1, - 'inner_status': 'processing', - 'progress': 0.0, - 'remaining_time': None, - } - ) - yield new_progress_info, [] - - writer = Writer() - if config.input.skip_if_output_exist and writer.is_output_exist(Path(file['file_name']).stem, config.output): - new_progress_info['inner_status'] = 'completed' - yield new_progress_info, [] + path: Path, + model: 'WhisperModel', + config: Config, + progress_info: dict, +) -> Generator[tuple[dict[str, Any], list[SegmentType]], None, None]: + filtered_media_files = file_utils.filter_media_files([path] if path.is_file() else list(path.iterdir())) + files: list[dict[str, Any]] = [{'file_name': file.name, 'file_path': file} for file in filtered_media_files] + + for idx, file in enumerate(tqdm(files, desc='Local files')): + new_progress_info = progress_info.copy() + new_progress_info.update( + { + 'inner_total': len(files), + 'inner_current': idx + 1, + 'inner_status': 'processing', + 'progress': 0.0, + 'remaining_time': None, + } + ) + yield new_progress_info, [] - continue + writer = Writer() + if config.input.skip_if_output_exist and writer.is_output_exist(Path(file['file_name']).stem, config.output): + new_progress_info['inner_status'] = 'completed' + 
yield new_progress_info, [] - file_path = str(file['file_path'].absolute()) + continue - if config.use_wit(): - mp3_file_path = str(wit_file_utils.convert_to_mp3(file['file_path']).absolute()) - recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(mp3_file_path, config.wit) - else: - recognize_generator = WhisperRecognizer(verbose=config.input.verbose).recognize( - file_path, - model, - config.whisper, - ) + file_path = str(file['file_path'].absolute()) - while True: - try: - new_progress_info.update(next(recognize_generator)) - yield new_progress_info, [] - except StopIteration as exception: - segments = exception.value - break + if config.use_wit(): + mp3_file_path = str(wit_file_utils.convert_to_mp3(file['file_path']).absolute()) + recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(mp3_file_path, config.wit) + else: + recognize_generator = WhisperRecognizer(verbose=config.input.verbose).recognize( + file_path, + model, + config.whisper, + ) + + while True: + try: + new_progress_info.update(next(recognize_generator)) + yield new_progress_info, [] + except StopIteration as exception: + segments: list[SegmentType] = exception.value + break - if config.use_wit() and file['file_path'].suffix != '.mp3': - Path(mp3_file_path).unlink(missing_ok=True) + if config.use_wit() and file['file_path'].suffix != '.mp3': + Path(mp3_file_path).unlink(missing_ok=True) - writer.write_all(Path(file['file_name']).stem, segments, config.output) + writer.write_all(Path(file['file_name']).stem, segments, config.output) - for segment in segments: - segment['url'] = f"file://{file_path}&t={int(segment['start'])}" - segment['file_path'] = file_path + for segment in segments: + segment['url'] = f"file://{file_path}&t={int(segment['start'])}" + segment['file_path'] = file_path - new_progress_info['inner_status'] = 'completed' - new_progress_info['progress'] = 100.0 - yield new_progress_info, writer.compact_segments(segments, config.output.min_words_per_segment) + new_progress_info['inner_status'] = 'completed' + new_progress_info['progress'] = 100.0 + yield new_progress_info, writer.compact_segments(segments, config.output.min_words_per_segment) def process_url( - url: str, - model: 'WhisperModel', - config: Config, - progress_info: dict, -) -> Generator[tuple[dict[str, int], list[list[dict[str, str | float]]]], None, None]: - url_data = Downloader(playlist_items=config.input.playlist_items, output_dir=config.output.output_dir).download( - url, - save_response=config.output.save_yt_dlp_responses, + url: str, + model: 'WhisperModel', + config: Config, + progress_info: dict, +) -> Generator[tuple[dict[str, Any], list[SegmentType]], None, None]: + url_data = Downloader(playlist_items=config.input.playlist_items, output_dir=config.output.output_dir).download( + url, + save_response=config.output.save_yt_dlp_responses, + ) + + if '_type' in url_data and url_data['_type'] == 'playlist': + elements = url_data['entries'] + else: + elements = [url_data] + + for idx, element in enumerate(tqdm(elements, desc='URL elements')): + if not element: + continue + + new_progress_info = progress_info.copy() + new_progress_info.update( + { + 'inner_total': len(elements), + 'inner_current': idx + 1, + 'inner_status': 'processing', + 'progress': 0.0, + 'remaining_time': None, + } ) + yield new_progress_info, [] - if '_type' in url_data and url_data['_type'] == 'playlist': - url_data = url_data['entries'] - else: - url_data = [url_data] - - for idx, element in enumerate(tqdm(url_data, 
desc='URL elements')): - if not element: - continue - - new_progress_info = progress_info.copy() - new_progress_info.update( - { - 'inner_total': len(url_data), - 'inner_current': idx + 1, - 'inner_status': 'processing', - 'progress': 0.0, - 'remaining_time': None, - } - ) - yield new_progress_info, [] - - writer = Writer() - if config.input.skip_if_output_exist and writer.is_output_exist(element['id'], config.output): - new_progress_info['inner_status'] = 'completed' - yield new_progress_info, [] + writer = Writer() + if config.input.skip_if_output_exist and writer.is_output_exist(element['id'], config.output): + new_progress_info['inner_status'] = 'completed' + yield new_progress_info, [] - continue + continue - file_path = os.path.join(config.output.output_dir, f"{element['id']}.mp3") + file_path = os.path.join(config.output.output_dir, f"{element['id']}.mp3") - if config.use_wit(): - recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(file_path, config.wit) - else: - recognize_generator = WhisperRecognizer(verbose=config.input.verbose).recognize( - file_path, - model, - config.whisper, - ) + if config.use_wit(): + recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(file_path, config.wit) + else: + recognize_generator = WhisperRecognizer(verbose=config.input.verbose).recognize( + file_path, + model, + config.whisper, + ) + + while True: + try: + new_progress_info.update(next(recognize_generator)) + yield new_progress_info, [] + except StopIteration as exception: + segments: list[SegmentType] = exception.value + break - while True: - try: - new_progress_info.update(next(recognize_generator)) - yield new_progress_info, [] - except StopIteration as exception: - segments = exception.value - break + writer.write_all(element['id'], segments, config.output) - writer.write_all(element['id'], segments, config.output) + for segment in segments: + segment['url'] = f"https://youtube.com/watch?v={element['id']}&t={int(segment['start'])}" + segment['file_path'] = file_path - for segment in segments: - segment['url'] = f"https://youtube.com/watch?v={element['id']}&t={int(segment['start'])}" - segment['file_path'] = file_path + new_progress_info['inner_status'] = 'completed' + new_progress_info['progress'] = 100.0 + yield new_progress_info, writer.compact_segments(segments, config.output.min_words_per_segment) - new_progress_info['inner_status'] = 'completed' - new_progress_info['progress'] = 100.0 - yield new_progress_info, writer.compact_segments(segments, config.output.min_words_per_segment) +def write_output_sample(segments: list[SegmentType], output: Config.Output) -> None: + if output.output_sample == 0: + return -def write_output_sample(segments: list[dict[str, str | float]], output: Config.Output) -> None: - if output.output_sample == 0: - return + random.shuffle(segments) - random.shuffle(segments) + with open(os.path.join(output.output_dir, 'sample.csv'), 'w') as fp: + writer = csv.DictWriter(fp, fieldnames=['start', 'end', 'text', 'url', 'file_path']) + writer.writeheader() - with open(os.path.join(output.output_dir, 'sample.csv'), 'w') as fp: - writer = csv.DictWriter(fp, fieldnames=['start', 'end', 'text', 'url', 'file_path']) - writer.writeheader() + for segment in segments[:output.output_sample]: + formatted_start = time_utils.format_timestamp(segment['start'], include_hours=True, decimal_marker=',') + formatted_end = time_utils.format_timestamp(segment['end'], include_hours=True, decimal_marker=',') - for segment in segments[: 
output.output_sample]: - segment['start'] = time_utils.format_timestamp(segment['start'], include_hours=True, decimal_marker=',') - segment['end'] = time_utils.format_timestamp(segment['end'], include_hours=True, decimal_marker=',') - writer.writerow(segment) + writer.writerow({ + 'start': formatted_start, + 'end': formatted_end, + 'text': segment['text'], + 'url': segment['url'], + 'file_path': segment['file_path'], + }) diff --git a/src/config.py b/src/config.py index 0b44601..ad19e8d 100644 --- a/src/config.py +++ b/src/config.py @@ -4,115 +4,111 @@ class Config: + def __init__( + self, + urls_or_paths: list[str], + skip_if_output_exist: bool, + playlist_items: str, + verbose: bool, + model_name_or_path: str, + task: str, + language: str, + use_faster_whisper: bool, + beam_size: int, + ct2_compute_type: str, + wit_client_access_tokens: list[str], + max_cutting_duration: int, + min_words_per_segment: int, + save_files_before_compact: bool, + save_yt_dlp_responses: bool, + output_sample: int, + output_formats: list[str], + output_dir: str, + ): + self.input = self.Input(urls_or_paths, skip_if_output_exist, playlist_items, verbose) + + self.whisper = self.Whisper( + model_name_or_path, + task, + language, + use_faster_whisper, + beam_size, + ct2_compute_type, + ) + + self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration) + + self.output = self.Output( + min_words_per_segment, + save_files_before_compact, + save_yt_dlp_responses, + output_sample, + output_formats, + output_dir, + ) + + def use_wit(self) -> bool: + return self.wit.wit_client_access_tokens is not None and self.wit.wit_client_access_tokens != [] + + class Input: + def __init__(self, urls_or_paths: list[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool): + self.urls_or_paths = urls_or_paths + self.skip_if_output_exist = skip_if_output_exist + self.playlist_items = playlist_items + self.verbose = verbose + + class Whisper: def __init__( - self, - urls_or_paths: list[str], - skip_if_output_exist: bool, - playlist_items: str, - verbose: bool, - model_name_or_path: str, - task: str, - language: str, - use_faster_whisper: bool, - beam_size: int, - ct2_compute_type: str, - wit_client_access_tokens: list[str], - max_cutting_duration: int, - min_words_per_segment: int, - save_files_before_compact: bool, - save_yt_dlp_responses: bool, - output_sample: int, - output_formats: list[str], - output_dir: str, + self, + model_name_or_path: str, + task: str, + language: str, + use_faster_whisper: bool, + beam_size: int, + ct2_compute_type: str, ): - self.input = self.Input(urls_or_paths, skip_if_output_exist, playlist_items, verbose) - - self.whisper = self.Whisper( - model_name_or_path, - task, - language, - use_faster_whisper, - beam_size, - ct2_compute_type, - ) - - self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration) - - self.output = self.Output( - min_words_per_segment, - save_files_before_compact, - save_yt_dlp_responses, - output_sample, - output_formats, - output_dir, - ) - - def use_wit(self) -> bool: - return self.wit.wit_client_access_tokens is not None and self.wit.wit_client_access_tokens != [] - - class Input: - def __init__(self, urls_or_paths: list[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool): - self.urls_or_paths = urls_or_paths - self.skip_if_output_exist = skip_if_output_exist - self.playlist_items = playlist_items - self.verbose = verbose - - class Whisper: - def __init__( - self, - model_name_or_path: str, - task: str, - language: str, - 
use_faster_whisper: bool,
- beam_size: int,
- ct2_compute_type: str,
- ):
- if model_name_or_path.endswith('.en'):
- logging.warn(f'{model_name_or_path} is an English-only model, setting language to English.')
- language = 'en'
-
- self.model_name_or_path = model_name_or_path
- self.task = task
- self.language = language
- self.use_faster_whisper = use_faster_whisper
- self.beam_size = beam_size
- self.ct2_compute_type = ct2_compute_type
-
- class Wit:
- def __init__(self, wit_client_access_tokens: list[str], max_cutting_duration: int):
- if wit_client_access_tokens is None:
- self.wit_client_access_tokens = None
- else:
- self.wit_client_access_tokens = [
- key for key in wit_client_access_tokens if key is not None and key != ''
- ]
-
- self.max_cutting_duration = max_cutting_duration
-
- class Output:
- def __init__(
- self,
- min_words_per_segment: int,
- save_files_before_compact: bool,
- save_yt_dlp_responses: bool,
- output_sample: int,
- output_formats: list[str],
- output_dir: str,
- ):
- if 'all' in output_formats:
- output_formats = list(TranscriptType)
- else:
- output_formats = [TranscriptType(output_format) for output_format in output_formats]
-
- if TranscriptType.ALL in output_formats:
- output_formats.remove(TranscriptType.ALL)
-
- if TranscriptType.NONE in output_formats:
- output_formats.remove(TranscriptType.NONE)
-
- self.min_words_per_segment = min_words_per_segment
- self.save_files_before_compact = save_files_before_compact
- self.save_yt_dlp_responses = save_yt_dlp_responses
- self.output_sample = output_sample
- self.output_formats = output_formats
- self.output_dir = output_dir
+ if model_name_or_path.endswith('.en'):
+ logging.warning(f'{model_name_or_path} is an English-only model, setting language to English.')
+ language = 'en'
+
+ self.model_name_or_path = model_name_or_path
+ self.task = task
+ self.language = language
+ self.use_faster_whisper = use_faster_whisper
+ self.beam_size = beam_size
+ self.ct2_compute_type = ct2_compute_type
+
+ class Wit:
+ def __init__(self, wit_client_access_tokens: list[str] | None, max_cutting_duration: int):
+ if wit_client_access_tokens is None:
+ self.wit_client_access_tokens = None
+ else:
+ self.wit_client_access_tokens = [key for key in wit_client_access_tokens if key is not None and key != '']
+
+ self.max_cutting_duration = max_cutting_duration
+
+ class Output:
+ def __init__(
+ self,
+ min_words_per_segment: int,
+ save_files_before_compact: bool,
+ save_yt_dlp_responses: bool,
+ output_sample: int,
+ output_formats: list[str],
+ output_dir: str,
+ ):
+ if 'all' in output_formats:
+ output_formats = list(TranscriptType)
+ else:
+ output_formats = [TranscriptType(output_format) for output_format in output_formats]
+
+ if TranscriptType.ALL in output_formats:
+ output_formats.remove(TranscriptType.ALL)
+
+ if TranscriptType.NONE in output_formats:
+ output_formats.remove(TranscriptType.NONE)
+
+ self.min_words_per_segment = min_words_per_segment
+ self.save_files_before_compact = save_files_before_compact
+ self.save_yt_dlp_responses = save_yt_dlp_responses
+ self.output_sample = output_sample
+ self.output_formats = output_formats
+ self.output_dir = output_dir
diff --git a/src/downloader.py b/src/downloader.py
index 026dcbc..1f4839c 100644
--- a/src/downloader.py
+++ b/src/downloader.py
@@ -7,52 +7,52 @@ class Downloader:
- def __init__(self, playlist_items: str, output_dir: str):
- self.playlist_items = playlist_items
- self.output_dir = output_dir
- self.youtube_dl_with_archive = yt_dlp.YoutubeDL(self._config(os.path.join(self.output_dir, 'archive.txt')))
- self.youtube_dl_without_archive =
yt_dlp.YoutubeDL(self._config(False)) - - def _config(self, download_archive: str | bool) -> dict[str, Any]: - return { - 'quiet': True, - 'verbose': False, - 'format': 'bestaudio', - 'extract_audio': True, - 'outtmpl': os.path.join(self.output_dir, '%(id)s.%(ext)s'), - 'ignoreerrors': True, - 'download_archive': download_archive, - 'playlist_items': self.playlist_items, - 'postprocessors': [ - { - 'key': 'FFmpegExtractAudio', - 'preferredcodec': 'mp3', - }, - ], - } - - def download(self, url: str, save_response: bool = False) -> dict[str, Any]: - self.youtube_dl_with_archive.download(url) - url_data = self.youtube_dl_without_archive.extract_info(url, download=False) - - if save_response: - self._save_response(url_data) - - return url_data - - def _save_response(self, url_data: dict[str, Any]) -> None: - if '_type' in url_data and url_data['_type'] == 'playlist': - for entry in url_data['entries']: - if entry and 'requested_downloads' in entry: - self._remove_postprocessors(entry['requested_downloads']) - elif 'requested_downloads' in url_data: - self._remove_postprocessors(url_data['requested_downloads']) - - file_path = os.path.join(self.output_dir, f"{url_data['id']}.json") - - with open(file_path, 'w', encoding='utf-8') as fp: - json.dump(url_data, fp, indent=2, ensure_ascii=False) - - def _remove_postprocessors(self, requested_downloads: list[dict[str, Any]]) -> None: - for requested_download in requested_downloads: - requested_download.pop('__postprocessors') + def __init__(self, playlist_items: str, output_dir: str): + self.playlist_items = playlist_items + self.output_dir = output_dir + self.youtube_dl_with_archive = yt_dlp.YoutubeDL(self._config(os.path.join(self.output_dir, 'archive.txt'))) + self.youtube_dl_without_archive = yt_dlp.YoutubeDL(self._config(False)) + + def _config(self, download_archive: str | bool) -> dict[str, Any]: + return { + 'quiet': True, + 'verbose': False, + 'format': 'bestaudio', + 'extract_audio': True, + 'outtmpl': os.path.join(self.output_dir, '%(id)s.%(ext)s'), + 'ignoreerrors': True, + 'download_archive': download_archive, + 'playlist_items': self.playlist_items, + 'postprocessors': [ + { + 'key': 'FFmpegExtractAudio', + 'preferredcodec': 'mp3', + }, + ], + } + + def download(self, url: str, save_response: bool = False) -> dict[str, Any]: + self.youtube_dl_with_archive.download(url) + url_data = self.youtube_dl_without_archive.extract_info(url, download=False) + + if save_response: + self._save_response(url_data) + + return url_data + + def _save_response(self, url_data: dict[str, Any]) -> None: + if '_type' in url_data and url_data['_type'] == 'playlist': + for entry in url_data['entries']: + if entry and 'requested_downloads' in entry: + self._remove_postprocessors(entry['requested_downloads']) + elif 'requested_downloads' in url_data: + self._remove_postprocessors(url_data['requested_downloads']) + + file_path = os.path.join(self.output_dir, f"{url_data['id']}.json") + + with open(file_path, 'w', encoding='utf-8') as fp: + json.dump(url_data, fp, indent=2, ensure_ascii=False) + + def _remove_postprocessors(self, requested_downloads: list[dict[str, Any]]) -> None: + for requested_download in requested_downloads: + requested_download.pop('__postprocessors') diff --git a/src/recognizers/whisper_recognizer.py b/src/recognizers/whisper_recognizer.py index 7870002..e27329e 100644 --- a/src/recognizers/whisper_recognizer.py +++ b/src/recognizers/whisper_recognizer.py @@ -1,6 +1,6 @@ import warnings -from typing import Generator +from typing import Any, 
Generator import faster_whisper import whisper @@ -8,97 +8,98 @@ from tqdm import tqdm from src.config import Config +from src.types.segment_type import SegmentType from src.types.whisper.type_hints import WhisperModel class WhisperRecognizer: - def __init__(self, verbose: bool): - self.verbose = verbose - - def recognize( - self, - file_path: str, - model: WhisperModel, - whisper_config: Config.Whisper, - ) -> Generator[dict[str, float], None, list[dict[str, str | float]]]: - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - - if isinstance(model, whisper.Whisper): - whisper_generator = self._recognize_stable_whisper(file_path, model, whisper_config) - elif isinstance(model, faster_whisper.WhisperModel): - whisper_generator = self._recognize_faster_whisper(file_path, model, whisper_config) - - while True: - try: - yield next(whisper_generator) - except StopIteration as e: - return e.value - - def _recognize_stable_whisper( - self, - audio_file_path: str, - model: whisper.Whisper, - whisper_config: Config.Whisper, - ) -> Generator[dict[str, float], None, list[dict[str, str | float]]]: - yield {'progress': 0.0, 'remaining_time': None} - - segments = model.transcribe( - audio=audio_file_path, - verbose=self.verbose, - task=whisper_config.task, - language=whisper_config.language, - beam_size=whisper_config.beam_size, - ).segments - - return [ - { - 'start': segment.start, - 'end': segment.end, - 'text': segment.text.strip(), - } - for segment in segments - ] - - def _recognize_faster_whisper( - self, - audio_file_path: str, - model: faster_whisper.WhisperModel, - whisper_config: Config.Whisper, - ) -> Generator[dict[str, float], None, list[dict[str, str | float]]]: - segments, info = model.transcribe( - audio=audio_file_path, - task=whisper_config.task, - language=whisper_config.language, - beam_size=whisper_config.beam_size, + def __init__(self, verbose: bool): + self.verbose = verbose + + def recognize( + self, + file_path: str, + model: WhisperModel, + whisper_config: Config.Whisper, + ) -> Generator[dict[str, float], None, list[SegmentType]]: + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + + if isinstance(model, whisper.Whisper): + whisper_generator = self._recognize_stable_whisper(file_path, model, whisper_config) + elif isinstance(model, faster_whisper.WhisperModel): + whisper_generator = self._recognize_faster_whisper(file_path, model, whisper_config) + + while True: + try: + yield next(whisper_generator) + except StopIteration as e: + return e.value + + def _recognize_stable_whisper( + self, + audio_file_path: str, + model: whisper.Whisper, + whisper_config: Config.Whisper, + ) -> Generator[dict[str, Any], None, list[SegmentType]]: + yield {'progress': 0.0, 'remaining_time': None} + + segments = model.transcribe( + audio=audio_file_path, + verbose=self.verbose, + task=whisper_config.task, + language=whisper_config.language, + beam_size=whisper_config.beam_size, + ).segments + + return [ + SegmentType( + text=segment.text.strip(), + start=segment.start, + end=segment.end, + ) + for segment in segments + ] + + def _recognize_faster_whisper( + self, + audio_file_path: str, + model: faster_whisper.WhisperModel, + whisper_config: Config.Whisper, + ) -> Generator[dict[str, float], None, list[SegmentType]]: + segments, info = model.transcribe( + audio=audio_file_path, + task=whisper_config.task, + language=whisper_config.language, + beam_size=whisper_config.beam_size, + ) + + converted_segments = [] + last_end = 0 + with tqdm( + 
total=round(info.duration, 2), + unit='sec', + bar_format='{desc}: {percentage:.2f}%|{bar}| {n:.2f}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]', + disable=self.verbose is not False, + ) as pbar: + for segment in segments: + converted_segments.append( + SegmentType( + start=segment.start, + end=segment.end, + text=segment.text.strip(), + ) ) - converted_segments = [] - last_end = 0 - with tqdm( - total=round(info.duration, 2), - unit='sec', - bar_format='{desc}: {percentage:.2f}%|{bar}| {n:.2f}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]', - disable=self.verbose is not False, - ) as pbar: - for segment in segments: - converted_segments.append( - { - 'start': segment.start, - 'end': segment.end, - 'text': segment.text.strip(), - } - ) - - pbar_update = min(segment.end - last_end, info.duration - pbar.n) - pbar.update(pbar_update) - last_end = segment.end - - yield { - 'progress': round(pbar.n / pbar.total * 100, 2), - 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] - if pbar.format_dict['rate'] and pbar.total - else None, - } - - return converted_segments + pbar_update = min(segment.end - last_end, info.duration - pbar.n) + pbar.update(pbar_update) + last_end = segment.end + + yield { + 'progress': round(pbar.n / pbar.total * 100, 2), + 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] + if pbar.format_dict['rate'] and pbar.total + else None, + } + + return converted_segments diff --git a/src/recognizers/wit_calling_throttle.py b/src/recognizers/wit_calling_throttle.py index 18dde0c..1c724d8 100644 --- a/src/recognizers/wit_calling_throttle.py +++ b/src/recognizers/wit_calling_throttle.py @@ -5,34 +5,34 @@ class WitCallingThrottle: - def __init__(self, wit_client_access_tokens_count: int, call_times_limit: int = 1, expired_time: int = 1): - self.wit_client_access_tokens_count = wit_client_access_tokens_count - self.call_times_limit = call_times_limit - self.expired_time = expired_time - self.call_timestamps = [[] for _ in range(self.wit_client_access_tokens_count)] - self.locks = [Lock() for _ in range(self.wit_client_access_tokens_count)] + def __init__(self, wit_client_access_tokens_count: int, call_times_limit: int = 1, expired_time: int = 1): + self.wit_client_access_tokens_count = wit_client_access_tokens_count + self.call_times_limit = call_times_limit + self.expired_time = expired_time + self.call_timestamps: list[list[float]] = [[] for _ in range(self.wit_client_access_tokens_count)] + self.locks = [Lock() for _ in range(self.wit_client_access_tokens_count)] - def throttle(self, wit_client_access_token_index: int) -> None: - with self.locks[wit_client_access_token_index]: - while len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: - now = time.time() + def throttle(self, wit_client_access_token_index: int) -> None: + with self.locks[wit_client_access_token_index]: + while len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: + now = time.time() - self.call_timestamps[wit_client_access_token_index] = list( - filter( - lambda call_timestamp, now=now: now - call_timestamp < self.expired_time, - self.call_timestamps[wit_client_access_token_index], - ) - ) + self.call_timestamps[wit_client_access_token_index] = list( + filter( + lambda call_timestamp: now - call_timestamp < self.expired_time, + self.call_timestamps[wit_client_access_token_index], + ) + ) - if len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: - time_to_sleep = 
self.call_timestamps[wit_client_access_token_index][0] + self.expired_time - now - time.sleep(time_to_sleep) + if len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: + time_to_sleep = self.call_timestamps[wit_client_access_token_index][0] + self.expired_time - now + time.sleep(time_to_sleep) - self.call_timestamps[wit_client_access_token_index].append(time.time()) + self.call_timestamps[wit_client_access_token_index].append(time.time()) class WitCallingThrottleManager(BaseManager): - pass + pass WitCallingThrottleManager.register('WitCallingThrottle', WitCallingThrottle) diff --git a/src/recognizers/wit_recognizer.py b/src/recognizers/wit_recognizer.py index 0242f40..57d1335 100644 --- a/src/recognizers/wit_recognizer.py +++ b/src/recognizers/wit_recognizer.py @@ -6,7 +6,7 @@ import tempfile import time -from typing import Generator +from typing import Generator, cast import requests @@ -17,142 +17,142 @@ from src.audio_splitter import AudioSplitter from src.config import Config from src.recognizers.wit_calling_throttle import WitCallingThrottle, WitCallingThrottleManager +from src.types.segment_type import SegmentType def init_pool(throttle: WitCallingThrottle) -> None: - global wit_calling_throttle + global wit_calling_throttle - wit_calling_throttle = throttle + wit_calling_throttle = throttle # type: ignore class WitRecognizer: - def __init__(self, verbose: bool): - self.verbose = verbose - self.processes_per_wit_client_access_token = min(4, multiprocessing.cpu_count()) - - def recognize( - self, - file_path: str, - wit_config: Config.Wit, - ) -> Generator[dict[str, float], None, list[dict[str, str | float]]]: - temp_directory = tempfile.mkdtemp() - - segments = AudioSplitter().split( - file_path, - temp_directory, - max_dur=wit_config.max_cutting_duration, - expand_segments_with_noise=True, + def __init__(self, verbose: bool): + self.verbose = verbose + self.processes_per_wit_client_access_token = min(4, multiprocessing.cpu_count()) + + def recognize( + self, + file_path: str, + wit_config: Config.Wit, + ) -> Generator[dict[str, float], None, list[SegmentType]]: + temp_directory = tempfile.mkdtemp() + + segments = AudioSplitter().split( + file_path, + temp_directory, + max_dur=wit_config.max_cutting_duration, + expand_segments_with_noise=True, + ) + + retry_strategy = Retry( + total=5, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=['POST'], + backoff_factor=1, + ) + + adapter = HTTPAdapter(max_retries=retry_strategy) + + session = requests.Session() + session.mount('https://', adapter) + + pool_processes_count = min( + self.processes_per_wit_client_access_token * len(wit_config.wit_client_access_tokens or []), + multiprocessing.cpu_count(), + ) + + with WitCallingThrottleManager() as manager: + wit_calling_throttle = manager.WitCallingThrottle(len(wit_config.wit_client_access_tokens)) # type: ignore + + with multiprocessing.Pool( + processes=pool_processes_count, + initializer=init_pool, + initargs=(wit_calling_throttle,), + ) as pool: + async_results = [ + pool.apply_async( + self._process_segment, + ( + segment, + file_path, + wit_config, + session, + index % len(wit_config.wit_client_access_tokens or []), + ), + ) + for index, segment in enumerate(segments) + ] + + transcriptions = [] + + with tqdm(total=len(segments), disable=self.verbose is not False) as pbar: + for async_result in async_results: + async_result.wait() + pbar.update(1) + + transcriptions.append(async_result.get()) + + yield { + 'progress': round(len(transcriptions) 
/ len(segments) * 100, 2), + 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] + if pbar.format_dict['rate'] and pbar.total + else None, + } + + shutil.rmtree(temp_directory) + + return transcriptions + + def _process_segment( + self, + segment: tuple[str, float, float], + file_path: str, + wit_config: Config.Wit, + session: requests.Session, + wit_client_access_token_index: int, + ) -> SegmentType: + wit_calling_throttle.throttle(wit_client_access_token_index) # type: ignore + + segment_file_path, start, end = segment + + with open(segment_file_path, 'rb') as mp3_file: + audio_content = mp3_file.read() + + retries = 5 + + text = '' + while retries > 0: + try: + response = session.post( + 'https://api.wit.ai/speech', + headers={ + 'Accept': 'application/vnd.wit.20200513+json', + 'Content-Type': 'audio/mpeg3', + 'Authorization': f'Bearer {cast(list[str], wit_config.wit_client_access_tokens)[wit_client_access_token_index]}', + }, + data=audio_content, ) - retry_strategy = Retry( - total=5, - status_forcelist=[429, 500, 502, 503, 504], - allowed_methods=['POST'], - backoff_factor=1, - ) - - adapter = HTTPAdapter(max_retries=retry_strategy) - - session = requests.Session() - session.mount('https://', adapter) - - pool_processes_count = min( - self.processes_per_wit_client_access_token * len(wit_config.wit_client_access_tokens), - multiprocessing.cpu_count(), - ) - - with WitCallingThrottleManager() as manager: - wit_calling_throttle = manager.WitCallingThrottle(len(wit_config.wit_client_access_tokens)) - - with multiprocessing.Pool( - processes=pool_processes_count, - initializer=init_pool, - initargs=(wit_calling_throttle,), - ) as pool: - async_results = [ - pool.apply_async( - self._process_segment, - ( - segment, - file_path, - wit_config, - session, - index % len(wit_config.wit_client_access_tokens), - ), - ) - for index, segment in enumerate(segments) - ] - - transcriptions = [] - - with tqdm(total=len(segments), disable=self.verbose is not False) as pbar: - for async_result in async_results: - async_result.wait() - pbar.update(1) - - transcriptions.append(async_result.get()) - - yield { - 'progress': round(len(transcriptions) / len(segments) * 100, 2), - 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] - if pbar.format_dict['rate'] and pbar.total - else None, - } - - shutil.rmtree(temp_directory) - - return transcriptions - - def _process_segment( - self, - segment: tuple[str, float, float], - file_path: str, - wit_config: Config.Wit, - session: requests.Session, - wit_client_access_token_index: int, - ) -> dict[str, str | float]: - wit_calling_throttle.throttle(wit_client_access_token_index) - - segment_file_path, start, end = segment - - with open(segment_file_path, 'rb') as mp3_file: - audio_content = mp3_file.read() - - retries = 5 - - text = '' - while retries > 0: - try: - response = session.post( - 'https://api.wit.ai/speech', - headers={ - 'Accept': 'application/vnd.wit.20200513+json', - 'Content-Type': 'audio/mpeg3', - 'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_client_access_token_index]}', - }, - data=audio_content, - ) - - if response.status_code == 200: - text = json.loads(response.text)['text'] - break - else: - retries -= 1 - time.sleep(self.processes_per_wit_client_access_token + 1) - except Exception: - retries -= 1 - time.sleep(self.processes_per_wit_client_access_token + 1) - - if retries == 0: - logging.warn( - f"The segment from `{file_path}` file that starts at {start} and ends at {end}" - " didn't 
transcribed successfully."
- )
-
- os.remove(segment_file_path)
-
- return {
- 'start': start,
- 'end': end,
- 'text': text.strip(),
- }
+ if response.status_code == 200:
+ text = json.loads(response.text)['text']
+ break
+ else:
+ retries -= 1
+ time.sleep(self.processes_per_wit_client_access_token + 1)
+ except Exception:
+ retries -= 1
+ time.sleep(self.processes_per_wit_client_access_token + 1)
+
+ if retries == 0:
+ logging.warning(
+ f"The segment from `{file_path}` file that starts at {start} and ends at {end} wasn't transcribed successfully."
+ )
+
+ os.remove(segment_file_path)
+
+ return SegmentType(
+ text=text.strip(),
+ start=start,
+ end=end,
+ )
diff --git a/src/types/segment_type.py b/src/types/segment_type.py
new file mode 100644
index 0000000..4bcb115
--- /dev/null
+++ b/src/types/segment_type.py
@@ -0,0 +1,9 @@
+from typing import NotRequired, TypedDict
+
+
+class SegmentType(TypedDict):
+ text: str
+ start: float
+ end: float
+ url: NotRequired[str]
+ file_path: NotRequired[str]
diff --git a/src/types/transcript_type.py b/src/types/transcript_type.py
index 3cfa105..9043ee5 100644
--- a/src/types/transcript_type.py
+++ b/src/types/transcript_type.py
@@ -2,14 +2,14 @@ class TranscriptType(Enum):
- ALL = 'all'
- TXT = 'txt'
- SRT = 'srt'
- VTT = 'vtt'
- CSV = 'csv'
- TSV = 'tsv'
- JSON = 'json'
- NONE = 'none'
+ ALL = 'all'
+ TXT = 'txt'
+ SRT = 'srt'
+ VTT = 'vtt'
+ CSV = 'csv'
+ TSV = 'tsv'
+ JSON = 'json'
+ NONE = 'none'
- def __str__(self):
- return self.value
+ def __str__(self):
+ return self.value
diff --git a/src/types/whisper/type_hints.py b/src/types/whisper/type_hints.py
index cd13ed9..0fba876 100644
--- a/src/types/whisper/type_hints.py
+++ b/src/types/whisper/type_hints.py
@@ -5,7 +5,7 @@ WhisperModel = TypeVar(
- 'WhisperModel',
- whisper.Whisper,
- faster_whisper.WhisperModel,
+ 'WhisperModel',
+ whisper.Whisper,
+ faster_whisper.WhisperModel,
)
diff --git a/src/utils/cli_utils.py b/src/utils/cli_utils.py
index fa82c8f..afa6972 100644
--- a/src/utils/cli_utils.py
+++ b/src/utils/cli_utils.py
@@ -6,195 +6,195 @@ PLAYLIST_ITEMS_RE = re.compile(
- r'''(?x)
- (?P<start>[+-]?\d+)?
- (?P<range>[:-]
- (?P<end>[+-]?\d+|inf(?:inite)?)?
- (?::(?P<step>[+-]?\d+))?
- )?'''
+ r'''(?x)
+ (?P<start>[+-]?\d+)?
+ (?P<range>[:-]
+ (?P<end>[+-]?\d+|inf(?:inite)?)?
+ (?::(?P<step>[+-]?\d+))?
+ )?'''
)
def parse_args(argv: list[str]) -> argparse.Namespace:
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- '--version',
- action='version',
- version=importlib.metadata.version('tafrigh'),
- )
-
- input_group = parser.add_argument_group('Input')
-
- input_group.add_argument(
- 'urls_or_paths',
- nargs='+',
- help='Video/Playlist URLs or local folder/file(s) to transcribe.',
- )
-
- input_group.add_argument(
- '--skip_if_output_exist',
- action=argparse.BooleanOptionalAction,
- default=False,
- help='Whether to skip generating the output if the output file already exists.',
- )
-
- input_group.add_argument(
- '--playlist_items',
- type=parse_playlist_items,
- help='Comma separated playlist_index of the items to download.
You can specify a range using "[START]:[STOP][:STEP]".', - ) - - input_group.add_argument( - '--verbose', - action=argparse.BooleanOptionalAction, - default=False, - help='Whether to print out the progress and debug messages.', - ) - - whisper_group = parser.add_argument_group('Whisper') - - whisper_group.add_argument( - '-m', - '--model_name_or_path', - default='small', - help='Name or path of the Whisper model to use.', - ) - - whisper_group.add_argument( - '-t', - '--task', - default='transcribe', - choices=[ - 'transcribe', - 'translate', - ], - help="Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate').", - ) - - whisper_group.add_argument( - '-l', - '--language', - default=None, - choices=['af', 'am', 'ar', 'as', 'az', 'ba', 'be', 'bg', 'bn', 'bo', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de'] - + ['el', 'en', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'gl', 'gu', 'ha', 'haw', 'he', 'hi', 'hr', 'ht', 'hu'] - + ['hy', 'id', 'is', 'it', 'ja', 'jw', 'ka', 'kk', 'km', 'kn', 'ko', 'la', 'lb', 'ln', 'lo', 'lt', 'lv', 'mg'] - + ['mi', 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'my', 'ne', 'nl', 'nn', 'no', 'oc', 'pa', 'pl', 'ps', 'pt', 'ro'] - + ['ru', 'sa', 'sd', 'si', 'sk', 'sl', 'sn', 'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', 'tk'] - + ['tl', 'tr', 'tt', 'uk', 'ur', 'uz', 'vi', 'yi', 'yo', 'zh'], - help='Language spoken in the audio, skip to perform language detection.', - ) - - whisper_group.add_argument( - '--use_faster_whisper', - action=argparse.BooleanOptionalAction, - default=False, - help='Whether to use Faster Whisper implementation.', - ) - - whisper_group.add_argument( - '--beam_size', - type=int, - default=5, - help='Number of beams in beam search, only applicable when temperature is zero.', - ) - - whisper_group.add_argument( - '--ct2_compute_type', - default='default', - choices=[ - 'default', - 'int8', - 'int8_float16', - 'int16', - 'float16', - ], - help='Quantization type applied while converting the model to CTranslate2 format.', - ) - - wit_group = parser.add_argument_group('Wit') - - wit_group.add_argument( - '-w', - '--wit_client_access_tokens', - nargs='+', - help='List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.', - ) - - wit_group.add_argument( - '--max_cutting_duration', - type=int, - default=15, - choices=range(1, 17), - metavar='[1-17]', - help='The maximum allowed cutting duration. It should be between 1 and 17.', - ) - - output_group = parser.add_argument_group('Output') - - output_group.add_argument( - '--min_words_per_segment', - type=int, - default=1, - help='The minimum number of words should appear in each transcript segment. Any segment have words count less than this threshold will be merged with the next one. Pass 0 to disable this behavior.', - ) - - output_group.add_argument( - '--save_files_before_compact', - action=argparse.BooleanOptionalAction, - default=False, - help='Saves the output files before applying the compact logic that is based on --min_words_per_segment.', - ) - - output_group.add_argument( - '--save_yt_dlp_responses', - action=argparse.BooleanOptionalAction, - default=False, - help='Whether to save the yt-dlp library JSON responses or not.', - ) - - output_group.add_argument( - '--output_sample', - type=int, - default=0, - help='Samples random compacted segments from the output and generates a CSV file contains the sampled data. 
Pass 0 to disable this behavior.', - ) - - output_group.add_argument( - '-f', - '--output_formats', - nargs='+', - default='all', - choices=[transcript_type.value for transcript_type in TranscriptType], - help='Format of the output file; if not specified, all available formats will be produced.', - ) - - output_group.add_argument('-o', '--output_dir', default='.', help='Directory to save the outputs.') - - return parser.parse_args(argv) + parser = argparse.ArgumentParser() + + parser.add_argument( + '--version', + action='version', + version=importlib.metadata.version('tafrigh'), + ) + + input_group = parser.add_argument_group('Input') + + input_group.add_argument( + 'urls_or_paths', + nargs='+', + help='Video/Playlist URLs or local folder/file(s) to transcribe.', + ) + + input_group.add_argument( + '--skip_if_output_exist', + action=argparse.BooleanOptionalAction, + default=False, + help='Whether to skip generating the output if the output file already exists.', + ) + + input_group.add_argument( + '--playlist_items', + type=parse_playlist_items, + help='Comma separated playlist_index of the items to download. You can specify a range using "[START]:[STOP][:STEP]".', + ) + + input_group.add_argument( + '--verbose', + action=argparse.BooleanOptionalAction, + default=False, + help='Whether to print out the progress and debug messages.', + ) + + whisper_group = parser.add_argument_group('Whisper') + + whisper_group.add_argument( + '-m', + '--model_name_or_path', + default='small', + help='Name or path of the Whisper model to use.', + ) + + whisper_group.add_argument( + '-t', + '--task', + default='transcribe', + choices=[ + 'transcribe', + 'translate', + ], + help="Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate').", + ) + + whisper_group.add_argument( + '-l', + '--language', + default=None, + choices=['af', 'am', 'ar', 'as', 'az', 'ba', 'be', 'bg', 'bn', 'bo', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de'] + + ['el', 'en', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'gl', 'gu', 'ha', 'haw', 'he', 'hi', 'hr', 'ht', 'hu'] + + ['hy', 'id', 'is', 'it', 'ja', 'jw', 'ka', 'kk', 'km', 'kn', 'ko', 'la', 'lb', 'ln', 'lo', 'lt', 'lv', 'mg'] + + ['mi', 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'my', 'ne', 'nl', 'nn', 'no', 'oc', 'pa', 'pl', 'ps', 'pt', 'ro'] + + ['ru', 'sa', 'sd', 'si', 'sk', 'sl', 'sn', 'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', 'tk'] + + ['tl', 'tr', 'tt', 'uk', 'ur', 'uz', 'vi', 'yi', 'yo', 'zh'], + help='Language spoken in the audio, skip to perform language detection.', + ) + + whisper_group.add_argument( + '--use_faster_whisper', + action=argparse.BooleanOptionalAction, + default=False, + help='Whether to use Faster Whisper implementation.', + ) + + whisper_group.add_argument( + '--beam_size', + type=int, + default=5, + help='Number of beams in beam search, only applicable when temperature is zero.', + ) + + whisper_group.add_argument( + '--ct2_compute_type', + default='default', + choices=[ + 'default', + 'int8', + 'int8_float16', + 'int16', + 'float16', + ], + help='Quantization type applied while converting the model to CTranslate2 format.', + ) + + wit_group = parser.add_argument_group('Wit') + + wit_group.add_argument( + '-w', + '--wit_client_access_tokens', + nargs='+', + help='List of wit.ai client access tokens. 
If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.',
+ )
+
+ wit_group.add_argument(
+ '--max_cutting_duration',
+ type=int,
+ default=15,
+ choices=range(1, 17),
+ metavar='[1-17]',
+ help='The maximum allowed cutting duration. It should be between 1 and 17.',
+ )
+
+ output_group = parser.add_argument_group('Output')
+
+ output_group.add_argument(
+ '--min_words_per_segment',
+ type=int,
+ default=1,
+ help='The minimum number of words that should appear in each transcript segment. Any segment with a word count less than this threshold will be merged with the next one. Pass 0 to disable this behavior.',
+ )
+
+ output_group.add_argument(
+ '--save_files_before_compact',
+ action=argparse.BooleanOptionalAction,
+ default=False,
+ help='Saves the output files before applying the compact logic that is based on --min_words_per_segment.',
+ )
+
+ output_group.add_argument(
+ '--save_yt_dlp_responses',
+ action=argparse.BooleanOptionalAction,
+ default=False,
+ help='Whether to save the yt-dlp library JSON responses or not.',
+ )
+
+ output_group.add_argument(
+ '--output_sample',
+ type=int,
+ default=0,
+ help='Samples random compacted segments from the output and generates a CSV file that contains the sampled data. Pass 0 to disable this behavior.',
+ )
+
+ output_group.add_argument(
+ '-f',
+ '--output_formats',
+ nargs='+',
+ default='all',
+ choices=[transcript_type.value for transcript_type in TranscriptType],
+ help='Format of the output file; if not specified, all available formats will be produced.',
+ )
+
+ output_group.add_argument('-o', '--output_dir', default='.', help='Directory to save the outputs.')
+
+ return parser.parse_args(argv)
def parse_playlist_items(arg_value: str) -> str:
- for segment in arg_value.split(','):
- if not segment:
- raise ValueError('There is two or more consecutive commas.')
-
- mobj = PLAYLIST_ITEMS_RE.fullmatch(segment)
- if not mobj:
- raise ValueError(f'{segment!r} is not a valid specification.')
-
- _, _, step, _ = mobj.group('start', 'end', 'step', 'range')
- if int_or_none(step) == 0:
- raise ValueError(f'Step in {segment!r} cannot be zero.')
-
- return arg_value
+ for segment in arg_value.split(','):
+ if not segment:
+ raise ValueError('There are two or more consecutive commas.')
+
+ mobj = PLAYLIST_ITEMS_RE.fullmatch(segment)
+ if not mobj:
+ raise ValueError(f'{segment!r} is not a valid specification.')
+
+ _, _, step, _ = mobj.group('start', 'end', 'step', 'range')
+ if int_or_none(step) == 0:
+ raise ValueError(f'Step in {segment!r} cannot be zero.')
+
+ return arg_value
def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1):
- if get_attr and v is not None:
- v = getattr(v, get_attr, None)
-
- try:
- return int(v) * invscale // scale
- except (ValueError, TypeError, OverflowError):
- return default
+ if get_attr and v is not None:
+ v = getattr(v, get_attr, None)
+
+ try:
+ return int(v) * invscale // scale
+ except (ValueError, TypeError, OverflowError):
+ return default
diff --git a/src/utils/file_utils.py b/src/utils/file_utils.py
index 00f0078..ab22838 100644
--- a/src/utils/file_utils.py
+++ b/src/utils/file_utils.py
@@ -7,14 +7,14 @@ def filter_media_files(paths: list[Path]) -> list[Path]:
- # Filter out non audio or video files
- filtered_media_files: list[str] = []
- for path in paths:
- mime = mimetypes.guess_type(path)[0]
- if mime is None:
- continue
- mime_type = mime.split('/')[0]
- if mime_type not in ('audio', 'video'):
- continue
- filtered_media_files.append(path)
- return
diff --git a/src/utils/file_utils.py b/src/utils/file_utils.py
index 00f0078..ab22838 100644
--- a/src/utils/file_utils.py
+++ b/src/utils/file_utils.py
@@ -7,14 +7,14 @@
 
 
 def filter_media_files(paths: list[Path]) -> list[Path]:
-    # Filter out non audio or video files
-    filtered_media_files: list[str] = []
-    for path in paths:
-        mime = mimetypes.guess_type(path)[0]
-        if mime is None:
-            continue
-        mime_type = mime.split('/')[0]
-        if mime_type not in ('audio', 'video'):
-            continue
-        filtered_media_files.append(path)
-    return filtered_media_files
+    # Filter out non-audio and non-video files
+    filtered_media_files: list[Path] = []
+    for path in paths:
+        mime = mimetypes.guess_type(path)[0]
+        if mime is None:
+            continue
+        mime_type = mime.split('/')[0]
+        if mime_type not in ('audio', 'video'):
+            continue
+        filtered_media_files.append(path)
+    return filtered_media_files
diff --git a/src/utils/time_utils.py b/src/utils/time_utils.py
index 4f113a4..0341866 100644
--- a/src/utils/time_utils.py
+++ b/src/utils/time_utils.py
@@ -1,15 +1,15 @@
 def format_timestamp(seconds: float, include_hours: bool = False, decimal_marker: str = '.') -> str:
-    assert seconds >= 0, 'Non-negative timestamp expected'
+    assert seconds >= 0, 'Non-negative timestamp expected'
 
-    total_milliseconds = int(round(seconds * 1_000))
+    total_milliseconds = int(round(seconds * 1_000))
 
-    hours, total_milliseconds = divmod(total_milliseconds, 3_600_000)
-    minutes, total_milliseconds = divmod(total_milliseconds, 60_000)
-    seconds, milliseconds = divmod(total_milliseconds, 1_000)
+    hours, total_milliseconds = divmod(total_milliseconds, 3_600_000)
+    minutes, total_milliseconds = divmod(total_milliseconds, 60_000)
+    seconds, milliseconds = divmod(total_milliseconds, 1_000)
 
-    if include_hours or hours > 0:
-        time_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
-    else:
-        time_str = f"{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    if include_hours or hours > 0:
+        time_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    else:
+        time_str = f"{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
 
-    return time_str
+    return time_str
diff --git a/src/utils/whisper/whisper_utils.py b/src/utils/whisper/whisper_utils.py
index 83b033c..9ba1685 100644
--- a/src/utils/whisper/whisper_utils.py
+++ b/src/utils/whisper/whisper_utils.py
@@ -5,11 +5,11 @@
 from src.types.whisper.type_hints import WhisperModel
 
 
-def load_model(whisper_config: Config.Whisper) -> WhisperModel:
-    if whisper_config.use_faster_whisper:
-        return faster_whisper.WhisperModel(
-            whisper_config.model_name_or_path,
-            compute_type=whisper_config.ct2_compute_type,
-        )
-    else:
-        return stable_whisper.load_model(whisper_config.model_name_or_path)
+def load_model(whisper_config: Config.Whisper) -> WhisperModel:  # type: ignore
+    if whisper_config.use_faster_whisper:
+        return faster_whisper.WhisperModel(
+            whisper_config.model_name_or_path,
+            compute_type=whisper_config.ct2_compute_type,
+        )
+    else:
+        return stable_whisper.load_model(whisper_config.model_name_or_path)
diff --git a/src/utils/wit/file_utils.py b/src/utils/wit/file_utils.py
index d087d3d..88f9ae6 100644
--- a/src/utils/wit/file_utils.py
+++ b/src/utils/wit/file_utils.py
@@ -4,7 +4,7 @@
 
 
 def convert_to_mp3(file: Path) -> Path:
-    audio_file = AudioSegment.from_file(str(file))
-    converted_file_path = file.with_suffix('.mp3')
-    audio_file.export(str(converted_file_path), format='mp3')
-    return converted_file_path
+    audio_file = AudioSegment.from_file(str(file))
+    converted_file_path = file.with_suffix('.mp3')
+    audio_file.export(str(converted_file_path), format='mp3')
+    return converted_file_path
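These utilities are small and pure, so their behavior is easy to pin down. A quick sketch (file names are hypothetical; import paths follow the src layout used above):

from pathlib import Path

from src.utils.file_utils import filter_media_files
from src.utils.time_utils import format_timestamp

# Classification relies on mimetypes: 'audio/*' and 'video/*' are kept.
paths = [Path('lecture.mp3'), Path('talk.mp4'), Path('notes.txt')]
assert filter_media_files(paths) == [Path('lecture.mp3'), Path('talk.mp4')]

# The divmod chain above yields zero-padded mm:ss / hh:mm:ss strings.
assert format_timestamp(59.123) == '00:59.123'
assert format_timestamp(3661.5) == '01:01:01.500'              # hours appear when > 0
assert format_timestamp(1.5, decimal_marker=',') == '00:01,500'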
diff --git a/src/writer.py b/src/writer.py
index 3944dd6..83c802e 100644
--- a/src/writer.py
+++ b/src/writer.py
@@ -5,157 +5,158 @@
 from pathlib import Path
 
 from .config import Config
+from .types.segment_type import SegmentType
 from .types.transcript_type import TranscriptType
 from .utils import time_utils
 
 
 class Writer:
-    def write_all(
-        self,
-        file_name: str,
-        segments: list[dict[str, str | float]],
-        output_config: Config.Output,
-    ) -> None:
-        if output_config.save_files_before_compact:
-            for output_format in output_config.output_formats:
-                self.write(
-                    output_format,
-                    os.path.join(output_config.output_dir, f'{file_name}-original.{output_format}'),
-                    segments,
-                )
-
-        if not output_config.save_files_before_compact or output_config.min_words_per_segment != 0:
-            compacted_segments = self.compact_segments(segments, output_config.min_words_per_segment)
-
-            for output_format in output_config.output_formats:
-                self.write(
-                    output_format,
-                    os.path.join(output_config.output_dir, f'{file_name}.{output_format}'),
-                    compacted_segments,
-                )
-
-    def write(
-        self,
-        format: TranscriptType,
-        file_path: str,
-        segments: list[dict[str, str | float]],
-    ) -> None:
-        if format == TranscriptType.TXT:
-            self.write_txt(file_path, segments)
-        elif format == TranscriptType.SRT:
-            self.write_srt(file_path, segments)
-        elif format == TranscriptType.VTT:
-            self.write_vtt(file_path, segments)
-        elif format == TranscriptType.CSV:
-            self.write_csv(file_path, segments)
-        elif format == TranscriptType.TSV:
-            self.write_csv(file_path, segments, '\t')
-        elif format == TranscriptType.JSON:
-            self.write_json(file_path, segments)
-
-    def write_txt(
-        self,
-        file_path: str,
-        segments: list[dict[str, str | float]],
-    ) -> None:
-        self._write_to_file(file_path, self.generate_txt(segments))
-
-    def write_srt(
-        self,
-        file_path: str,
-        segments: list[dict[str, str | float]],
-    ) -> None:
-        self._write_to_file(file_path, self.generate_srt(segments))
-
-    def write_vtt(
-        self,
-        file_path: str,
-        segments: list[dict[str, str | float]],
-    ) -> None:
-        self._write_to_file(file_path, self.generate_vtt(segments))
-
-    def write_csv(
-        self,
-        file_path: str,
-        segments: list[dict[str, str | float]],
-        delimiter=',',
-    ) -> None:
-        with open(file_path, 'w', encoding='utf-8') as fp:
-            writer = csv.DictWriter(fp, fieldnames=['text', 'start', 'end'], delimiter=delimiter)
-            writer.writeheader()
-            writer.writerows(segments)
-
-    def write_json(
-        self,
-        file_path: str,
-        segments: list[dict[str, str | float]],
-    ) -> None:
-        with open(file_path, 'w', encoding='utf-8') as fp:
-            json.dump(segments, fp, ensure_ascii=False, indent=2)
-
-    def generate_txt(self, segments: list[dict[str, str | float]]) -> str:
-        return '\n'.join(list(map(lambda segment: segment['text'].strip(), segments))) + '\n'
-
-    def generate_srt(self, segments: list[dict[str, str | float]]) -> str:
-        return ''.join(
-            f"{i}\n"
-            f"{time_utils.format_timestamp(segment['start'], include_hours=True, decimal_marker=',')} --> "
-            f"{time_utils.format_timestamp(segment['end'], include_hours=True, decimal_marker=',')}\n"
-            f"{segment['text'].strip()}\n\n"
-            for i, segment in enumerate(segments, start=1)
+    def write_all(
+        self,
+        file_name: str,
+        segments: list[SegmentType],
+        output_config: Config.Output,
+    ) -> None:
+        if output_config.save_files_before_compact:
+            for output_format in output_config.output_formats:
+                self.write(
+                    TranscriptType(output_format),
+                    os.path.join(output_config.output_dir, f'{file_name}-original.{output_format}'),
+                    segments,
                )
 
-    def generate_vtt(self, segments: list[dict[str, str | float]]) -> str:
-        return 'WEBVTT\n\n' + ''.join(
-            f"{time_utils.format_timestamp(segment['start'])} --> {time_utils.format_timestamp(segment['end'])}\n"
-            f"{segment['text'].strip()}\n\n"
-            for segment in segments
+        if not output_config.save_files_before_compact or output_config.min_words_per_segment != 0:
+            compacted_segments = self.compact_segments(segments, output_config.min_words_per_segment)
+
+            for output_format in output_config.output_formats:
+                self.write(
+                    TranscriptType(output_format),
+                    os.path.join(output_config.output_dir, f'{file_name}.{output_format}'),
+                    compacted_segments,
                )
 
-    def compact_segments(
-        self,
-        segments: list[dict[str, str | float]],
-        min_words_per_segment: int,
-    ) -> list[dict[str, str | float]]:
-        if min_words_per_segment == 0:
-            return segments
-
-        compacted_segments = []
-        tmp_segment = None
-
-        for segment in segments:
-            if tmp_segment:
-                tmp_segment['text'] += f" {segment['text'].strip()}"
-                tmp_segment['end'] = segment['end']
-
-                if len(tmp_segment['text'].split()) >= min_words_per_segment:
-                    compacted_segments.append(tmp_segment)
-                    tmp_segment = None
-            elif len(segment['text'].split()) < min_words_per_segment:
-                tmp_segment = dict(segment)
-            elif len(segment['text'].split()) >= min_words_per_segment:
-                compacted_segments.append(dict(segment))
-
-        if tmp_segment:
-            compacted_segments.append(tmp_segment)
-
-        return compacted_segments
-
-    def is_output_exist(self, file_name: str, output_config: Config.Output):
-        if output_config.save_files_before_compact and not all(
-            Path(output_config.output_dir, f'{file_name}-original.{output_format}').is_file()
-            for output_format in output_config.output_formats
-        ):
-            return False
-
-        if (not output_config.save_files_before_compact or output_config.min_words_per_segment != 0) and not all(
-            Path(output_config.output_dir, f'{file_name}.{output_format}').is_file()
-            for output_format in output_config.output_formats
-        ):
-            return False
-
-        return True
-
-    def _write_to_file(self, file_path: str, content: str) -> None:
-        with open(file_path, 'w', encoding='utf-8') as fp:
-            fp.write(content)
+    def write(
+        self,
+        format: TranscriptType,
+        file_path: str,
+        segments: list[SegmentType],
+    ) -> None:
+        if format == TranscriptType.TXT:
+            self.write_txt(file_path, segments)
+        elif format == TranscriptType.SRT:
+            self.write_srt(file_path, segments)
+        elif format == TranscriptType.VTT:
+            self.write_vtt(file_path, segments)
+        elif format == TranscriptType.CSV:
+            self.write_csv(file_path, segments)
+        elif format == TranscriptType.TSV:
+            self.write_csv(file_path, segments, '\t')
+        elif format == TranscriptType.JSON:
+            self.write_json(file_path, segments)
+
+    def write_txt(
+        self,
+        file_path: str,
+        segments: list[SegmentType],
+    ) -> None:
+        self._write_to_file(file_path, self.generate_txt(segments))
+
+    def write_srt(
+        self,
+        file_path: str,
+        segments: list[SegmentType],
+    ) -> None:
+        self._write_to_file(file_path, self.generate_srt(segments))
+
+    def write_vtt(
+        self,
+        file_path: str,
+        segments: list[SegmentType],
+    ) -> None:
+        self._write_to_file(file_path, self.generate_vtt(segments))
+
+    def write_csv(
+        self,
+        file_path: str,
+        segments: list[SegmentType],
+        delimiter: str = ',',
+    ) -> None:
+        with open(file_path, 'w', encoding='utf-8') as fp:
+            writer = csv.DictWriter(fp, fieldnames=['text', 'start', 'end'], delimiter=delimiter)
+            writer.writeheader()
+            writer.writerows(segments)
+
+    def write_json(
+        self,
+        file_path: str,
+        segments: list[SegmentType],
+    ) -> None:
+        with open(file_path, 'w', encoding='utf-8') as fp:
+            json.dump(segments, fp, ensure_ascii=False, indent=2)
+
+    def generate_txt(self, segments: list[SegmentType]) -> str:
+        return '\n'.join(segment['text'].strip() for segment in segments) + '\n'
+
+    def generate_srt(self, segments: list[SegmentType]) -> str:
+        return ''.join(
+            f'{i}\n'
+            f"{time_utils.format_timestamp(segment['start'], include_hours=True, decimal_marker=',')} --> "
+            f"{time_utils.format_timestamp(segment['end'], include_hours=True, decimal_marker=',')}\n"
+            f"{segment['text'].strip()}\n\n"
+            for i, segment in enumerate(segments, start=1)
+        )
+
+    def generate_vtt(self, segments: list[SegmentType]) -> str:
+        return 'WEBVTT\n\n' + ''.join(
+            f"{time_utils.format_timestamp(segment['start'])} --> {time_utils.format_timestamp(segment['end'])}\n"
+            f"{segment['text'].strip()}\n\n"
+            for segment in segments
+        )
+
+    def compact_segments(
+        self,
+        segments: list[SegmentType],
+        min_words_per_segment: int,
+    ) -> list[SegmentType]:
+        if min_words_per_segment == 0:
+            return segments
+
+        compacted_segments = []
+        tmp_segment = None
+
+        for segment in segments:
+            if tmp_segment:
+                tmp_segment['text'] += f" {segment['text'].strip()}"
+                tmp_segment['end'] = segment['end']
+
+                if len(tmp_segment['text'].split()) >= min_words_per_segment:
+                    compacted_segments.append(tmp_segment)
+                    tmp_segment = None
+            elif len(segment['text'].split()) < min_words_per_segment:
+                tmp_segment = SegmentType(text=segment['text'], start=segment['start'], end=segment['end'])
+            elif len(segment['text'].split()) >= min_words_per_segment:
+                compacted_segments.append(SegmentType(text=segment['text'], start=segment['start'], end=segment['end']))
+
+        if tmp_segment:
+            compacted_segments.append(tmp_segment)
+
+        return compacted_segments
+
+    def is_output_exist(self, file_name: str, output_config: Config.Output) -> bool:
+        if output_config.save_files_before_compact and not all(
+            Path(output_config.output_dir, f'{file_name}-original.{output_format}').is_file()
+            for output_format in output_config.output_formats
+        ):
+            return False
+
+        if (not output_config.save_files_before_compact or output_config.min_words_per_segment != 0) and not all(
+            Path(output_config.output_dir, f'{file_name}.{output_format}').is_file()
+            for output_format in output_config.output_formats
+        ):
+            return False
+
+        return True
+
+    def _write_to_file(self, file_path: str, content: str) -> None:
+        with open(file_path, 'w', encoding='utf-8') as fp:
+            fp.write(content)