From bb319d776a64d2a36c2015032549c4d3803cff7a Mon Sep 17 00:00:00 2001 From: Ali Hamdi Ali Fadel Date: Thu, 27 Jun 2024 00:10:15 +0000 Subject: [PATCH] Process audio in memory instead of writing files in AudioSplitter --- pyproject.toml | 2 +- src/audio_splitter.py | 26 +++++++------- src/cli.py | 22 ++++++------ src/recognizers/wit_recognizer.py | 37 ++++++-------------- src/utils/whisper/whisper_utils.py | 5 +-- src/writer.py | 55 ++++++------------------------ 6 files changed, 44 insertions(+), 103 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8807c6a..c2cdd58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tafrigh" -version = "1.2.1" +version = "1.2.2" description = "تفريغ النصوص وإنشاء ملفات SRT و VTT باستخدام نماذج Whisper وتقنية wit.ai." authors = ["EasyBooks "] license = "MIT" diff --git a/src/audio_splitter.py b/src/audio_splitter.py index 0abb895..7ec0285 100644 --- a/src/audio_splitter.py +++ b/src/audio_splitter.py @@ -1,4 +1,4 @@ -import os +import io from auditok import AudioRegion from auditok.core import split @@ -10,7 +10,6 @@ class AudioSplitter: def split( self, file_path: str, - output_dir: str, min_dur: float = 0.5, max_dur: float = 15, max_silence: float = 0.5, @@ -36,7 +35,7 @@ def split( ) for segment in segments ] - return self._save_segments(output_dir, segments) + return self._semgnets_to_data(segments) def _expand_segment_with_noise( self, @@ -57,21 +56,20 @@ def _expand_segment_with_noise( return pre_noise + audio_segment + post_noise - def _save_segments( + def _semgnets_to_data( self, - output_dir: str, segments: list[AudioSegment | tuple[AudioSegment, float, float]], - ) -> list[tuple[str, float, float]]: - segment_paths = [] + ) -> list[tuple[bytes, float, float]]: + segment_data = [] - for i, segment in enumerate(segments): - output_file = os.path.join(output_dir, f'segment_{i + 1}.mp3') + for segment in segments: + output_buffer = io.BytesIO() if isinstance(segment, tuple): - segment[0].export(output_file, format='mp3') - segment_paths.append((output_file, segment[1], segment[2])) + segment[0].export(output_buffer, format='mp3') + segment_data.append((output_buffer.getvalue(), segment[1], segment[2])) else: - segment.save(output_file) - segment_paths.append((output_file, segment.meta.start, segment.meta.end)) + segment.export(output_buffer, format='mp3') + segment_data.append((output_buffer.getvalue(), segment.meta.start, segment.meta.end)) - return segment_paths + return segment_data diff --git a/src/cli.py b/src/cli.py index e31fd4e..dd177b2 100644 --- a/src/cli.py +++ b/src/cli.py @@ -202,15 +202,13 @@ def process_url( continue new_progress_info = progress_info.copy() - new_progress_info.update( - { - 'inner_total': len(elements), - 'inner_current': idx + 1, - 'inner_status': 'processing', - 'progress': 0.0, - 'remaining_time': None, - } - ) + new_progress_info.update({ + 'inner_total': len(elements), + 'inner_current': idx + 1, + 'inner_status': 'processing', + 'progress': 0.0, + 'remaining_time': None, + }) yield new_progress_info, [] writer = Writer() @@ -226,9 +224,9 @@ def process_url( recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(file_path, config.wit) else: recognize_generator = WhisperRecognizer(verbose=config.input.verbose).recognize( - file_path, - model, - config.whisper, + file_path, + model, + config.whisper, ) while True: diff --git a/src/recognizers/wit_recognizer.py b/src/recognizers/wit_recognizer.py index 57d1335..21f7f8d 100644 --- a/src/recognizers/wit_recognizer.py +++ b/src/recognizers/wit_recognizer.py @@ -1,9 +1,6 @@ import json import logging import multiprocessing -import os -import shutil -import tempfile import time from typing import Generator, cast @@ -36,11 +33,8 @@ def recognize( file_path: str, wit_config: Config.Wit, ) -> Generator[dict[str, float], None, list[SegmentType]]: - temp_directory = tempfile.mkdtemp() - segments = AudioSplitter().split( file_path, - temp_directory, max_dur=wit_config.max_cutting_duration, expand_segments_with_noise=True, ) @@ -73,13 +67,13 @@ def recognize( async_results = [ pool.apply_async( self._process_segment, - ( - segment, - file_path, - wit_config, - session, - index % len(wit_config.wit_client_access_tokens or []), - ), + ( + segment, + file_path, + wit_config, + session, + index % len(wit_config.wit_client_access_tokens or []), + ), ) for index, segment in enumerate(segments) ] @@ -100,8 +94,6 @@ def recognize( else None, } - shutil.rmtree(temp_directory) - return transcriptions def _process_segment( @@ -114,10 +106,7 @@ def _process_segment( ) -> SegmentType: wit_calling_throttle.throttle(wit_client_access_token_index) # type: ignore - segment_file_path, start, end = segment - - with open(segment_file_path, 'rb') as mp3_file: - audio_content = mp3_file.read() + data, start, end = segment retries = 5 @@ -131,7 +120,7 @@ def _process_segment( 'Content-Type': 'audio/mpeg3', 'Authorization': f'Bearer {cast(list[str], wit_config.wit_client_access_tokens)[wit_client_access_token_index]}', }, - data=audio_content, + data=data, ) if response.status_code == 200: @@ -149,10 +138,4 @@ def _process_segment( f"The segment from `{file_path}` file that starts at {start} and ends at {end} didn't transcribed successfully." ) - os.remove(segment_file_path) - - return SegmentType( - text=text.strip(), - start=start, - end=end, - ) + return SegmentType(text=text.strip(), start=start, end=end) diff --git a/src/utils/whisper/whisper_utils.py b/src/utils/whisper/whisper_utils.py index 9ba1685..4cb8f7d 100644 --- a/src/utils/whisper/whisper_utils.py +++ b/src/utils/whisper/whisper_utils.py @@ -7,9 +7,6 @@ def load_model(whisper_config: Config.Whisper) -> WhisperModel: # type: ignore if whisper_config.use_faster_whisper: - return faster_whisper.WhisperModel( - whisper_config.model_name_or_path, - compute_type=whisper_config.ct2_compute_type, - ) + return faster_whisper.WhisperModel(whisper_config.model_name_or_path, compute_type=whisper_config.ct2_compute_type) else: return stable_whisper.load_model(whisper_config.model_name_or_path) diff --git a/src/writer.py b/src/writer.py index 83c802e..7da2727 100644 --- a/src/writer.py +++ b/src/writer.py @@ -11,12 +11,7 @@ class Writer: - def write_all( - self, - file_name: str, - segments: list[SegmentType], - output_config: Config.Output, - ) -> None: + def write_all(self, file_name: str, segments: list[SegmentType], output_config: Config.Output) -> None: if output_config.save_files_before_compact: for output_format in output_config.output_formats: self.write( @@ -35,12 +30,7 @@ def write_all( compacted_segments, ) - def write( - self, - format: TranscriptType, - file_path: str, - segments: list[SegmentType], - ) -> None: + def write(self, format: TranscriptType, file_path: str, segments: list[SegmentType]) -> None: if format == TranscriptType.TXT: self.write_txt(file_path, segments) elif format == TranscriptType.SRT: @@ -54,43 +44,22 @@ def write( elif format == TranscriptType.JSON: self.write_json(file_path, segments) - def write_txt( - self, - file_path: str, - segments: list[SegmentType], - ) -> None: + def write_txt(self, file_path: str, segments: list[SegmentType]) -> None: self._write_to_file(file_path, self.generate_txt(segments)) - def write_srt( - self, - file_path: str, - segments: list[SegmentType], - ) -> None: + def write_srt(self, file_path: str, segments: list[SegmentType]) -> None: self._write_to_file(file_path, self.generate_srt(segments)) - def write_vtt( - self, - file_path: str, - segments: list[SegmentType], - ) -> None: + def write_vtt(self, file_path: str, segments: list[SegmentType]) -> None: self._write_to_file(file_path, self.generate_vtt(segments)) - def write_csv( - self, - file_path: str, - segments: list[SegmentType], - delimiter=',', - ) -> None: + def write_csv(self, file_path: str, segments: list[SegmentType], delimiter=',') -> None: with open(file_path, 'w', encoding='utf-8') as fp: writer = csv.DictWriter(fp, fieldnames=['text', 'start', 'end'], delimiter=delimiter) writer.writeheader() writer.writerows(segments) - def write_json( - self, - file_path: str, - segments: list[SegmentType], - ) -> None: + def write_json(self, file_path: str, segments: list[SegmentType]) -> None: with open(file_path, 'w', encoding='utf-8') as fp: json.dump(segments, fp, ensure_ascii=False, indent=2) @@ -113,11 +82,7 @@ def generate_vtt(self, segments: list[SegmentType]) -> str: for segment in segments ) - def compact_segments( - self, - segments: list[SegmentType], - min_words_per_segment: int, - ) -> list[SegmentType]: + def compact_segments(self, segments: list[SegmentType], min_words_per_segment: int) -> list[SegmentType]: if min_words_per_segment == 0: return segments @@ -133,9 +98,9 @@ def compact_segments( compacted_segments.append(tmp_segment) tmp_segment = None elif len(segment['text'].split()) < min_words_per_segment: - tmp_segment = SegmentType(text=segment['text'], start=segment['start'], end=segment['end']) + tmp_segment = segment.copy() elif len(segment['text'].split()) >= min_words_per_segment: - compacted_segments.append(SegmentType(text=segment['text'], start=segment['start'], end=segment['end'])) + compacted_segments.append(segment.copy()) if tmp_segment: compacted_segments.append(tmp_segment)