Skip to content

Commit

Permalink
Process audio in memory instead of writing files in AudioSplitter
Browse files Browse the repository at this point in the history
  • Loading branch information
AliOsm committed Jun 27, 2024
1 parent 82ff165 commit bb319d7
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 103 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tafrigh"
version = "1.2.1"
version = "1.2.2"
description = "تفريغ النصوص وإنشاء ملفات SRT و VTT باستخدام نماذج Whisper وتقنية wit.ai."
authors = ["EasyBooks <[email protected]>"]
license = "MIT"
Expand Down
26 changes: 12 additions & 14 deletions src/audio_splitter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import io

from auditok import AudioRegion
from auditok.core import split
Expand All @@ -10,7 +10,6 @@ class AudioSplitter:
def split(
self,
file_path: str,
output_dir: str,
min_dur: float = 0.5,
max_dur: float = 15,
max_silence: float = 0.5,
Expand All @@ -36,7 +35,7 @@ def split(
) for segment in segments
]

return self._save_segments(output_dir, segments)
return self._semgnets_to_data(segments)

def _expand_segment_with_noise(
self,
Expand All @@ -57,21 +56,20 @@ def _expand_segment_with_noise(

return pre_noise + audio_segment + post_noise

def _save_segments(
def _semgnets_to_data(
self,
output_dir: str,
segments: list[AudioSegment | tuple[AudioSegment, float, float]],
) -> list[tuple[str, float, float]]:
segment_paths = []
) -> list[tuple[bytes, float, float]]:
segment_data = []

for i, segment in enumerate(segments):
output_file = os.path.join(output_dir, f'segment_{i + 1}.mp3')
for segment in segments:
output_buffer = io.BytesIO()

if isinstance(segment, tuple):
segment[0].export(output_file, format='mp3')
segment_paths.append((output_file, segment[1], segment[2]))
segment[0].export(output_buffer, format='mp3')
segment_data.append((output_buffer.getvalue(), segment[1], segment[2]))
else:
segment.save(output_file)
segment_paths.append((output_file, segment.meta.start, segment.meta.end))
segment.export(output_buffer, format='mp3')
segment_data.append((output_buffer.getvalue(), segment.meta.start, segment.meta.end))

return segment_paths
return segment_data
22 changes: 10 additions & 12 deletions src/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,13 @@ def process_url(
continue

new_progress_info = progress_info.copy()
new_progress_info.update(
{
'inner_total': len(elements),
'inner_current': idx + 1,
'inner_status': 'processing',
'progress': 0.0,
'remaining_time': None,
}
)
new_progress_info.update({
'inner_total': len(elements),
'inner_current': idx + 1,
'inner_status': 'processing',
'progress': 0.0,
'remaining_time': None,
})
yield new_progress_info, []

writer = Writer()
Expand All @@ -226,9 +224,9 @@ def process_url(
recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(file_path, config.wit)
else:
recognize_generator = WhisperRecognizer(verbose=config.input.verbose).recognize(
file_path,
model,
config.whisper,
file_path,
model,
config.whisper,
)

while True:
Expand Down
37 changes: 10 additions & 27 deletions src/recognizers/wit_recognizer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import json
import logging
import multiprocessing
import os
import shutil
import tempfile
import time

from typing import Generator, cast
Expand Down Expand Up @@ -36,11 +33,8 @@ def recognize(
file_path: str,
wit_config: Config.Wit,
) -> Generator[dict[str, float], None, list[SegmentType]]:
temp_directory = tempfile.mkdtemp()

segments = AudioSplitter().split(
file_path,
temp_directory,
max_dur=wit_config.max_cutting_duration,
expand_segments_with_noise=True,
)
Expand Down Expand Up @@ -73,13 +67,13 @@ def recognize(
async_results = [
pool.apply_async(
self._process_segment,
(
segment,
file_path,
wit_config,
session,
index % len(wit_config.wit_client_access_tokens or []),
),
(
segment,
file_path,
wit_config,
session,
index % len(wit_config.wit_client_access_tokens or []),
),
)
for index, segment in enumerate(segments)
]
Expand All @@ -100,8 +94,6 @@ def recognize(
else None,
}

shutil.rmtree(temp_directory)

return transcriptions

def _process_segment(
Expand All @@ -114,10 +106,7 @@ def _process_segment(
) -> SegmentType:
wit_calling_throttle.throttle(wit_client_access_token_index) # type: ignore

segment_file_path, start, end = segment

with open(segment_file_path, 'rb') as mp3_file:
audio_content = mp3_file.read()
data, start, end = segment

retries = 5

Expand All @@ -131,7 +120,7 @@ def _process_segment(
'Content-Type': 'audio/mpeg3',
'Authorization': f'Bearer {cast(list[str], wit_config.wit_client_access_tokens)[wit_client_access_token_index]}',
},
data=audio_content,
data=data,
)

if response.status_code == 200:
Expand All @@ -149,10 +138,4 @@ def _process_segment(
f"The segment from `{file_path}` file that starts at {start} and ends at {end} didn't transcribed successfully."
)

os.remove(segment_file_path)

return SegmentType(
text=text.strip(),
start=start,
end=end,
)
return SegmentType(text=text.strip(), start=start, end=end)
5 changes: 1 addition & 4 deletions src/utils/whisper/whisper_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

def load_model(whisper_config: Config.Whisper) -> WhisperModel: # type: ignore
if whisper_config.use_faster_whisper:
return faster_whisper.WhisperModel(
whisper_config.model_name_or_path,
compute_type=whisper_config.ct2_compute_type,
)
return faster_whisper.WhisperModel(whisper_config.model_name_or_path, compute_type=whisper_config.ct2_compute_type)
else:
return stable_whisper.load_model(whisper_config.model_name_or_path)
55 changes: 10 additions & 45 deletions src/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@


class Writer:
def write_all(
self,
file_name: str,
segments: list[SegmentType],
output_config: Config.Output,
) -> None:
def write_all(self, file_name: str, segments: list[SegmentType], output_config: Config.Output) -> None:
if output_config.save_files_before_compact:
for output_format in output_config.output_formats:
self.write(
Expand All @@ -35,12 +30,7 @@ def write_all(
compacted_segments,
)

def write(
self,
format: TranscriptType,
file_path: str,
segments: list[SegmentType],
) -> None:
def write(self, format: TranscriptType, file_path: str, segments: list[SegmentType]) -> None:
if format == TranscriptType.TXT:
self.write_txt(file_path, segments)
elif format == TranscriptType.SRT:
Expand All @@ -54,43 +44,22 @@ def write(
elif format == TranscriptType.JSON:
self.write_json(file_path, segments)

def write_txt(
self,
file_path: str,
segments: list[SegmentType],
) -> None:
def write_txt(self, file_path: str, segments: list[SegmentType]) -> None:
self._write_to_file(file_path, self.generate_txt(segments))

def write_srt(
self,
file_path: str,
segments: list[SegmentType],
) -> None:
def write_srt(self, file_path: str, segments: list[SegmentType]) -> None:
self._write_to_file(file_path, self.generate_srt(segments))

def write_vtt(
self,
file_path: str,
segments: list[SegmentType],
) -> None:
def write_vtt(self, file_path: str, segments: list[SegmentType]) -> None:
self._write_to_file(file_path, self.generate_vtt(segments))

def write_csv(
self,
file_path: str,
segments: list[SegmentType],
delimiter=',',
) -> None:
def write_csv(self, file_path: str, segments: list[SegmentType], delimiter=',') -> None:
with open(file_path, 'w', encoding='utf-8') as fp:
writer = csv.DictWriter(fp, fieldnames=['text', 'start', 'end'], delimiter=delimiter)
writer.writeheader()
writer.writerows(segments)

def write_json(
self,
file_path: str,
segments: list[SegmentType],
) -> None:
def write_json(self, file_path: str, segments: list[SegmentType]) -> None:
with open(file_path, 'w', encoding='utf-8') as fp:
json.dump(segments, fp, ensure_ascii=False, indent=2)

Expand All @@ -113,11 +82,7 @@ def generate_vtt(self, segments: list[SegmentType]) -> str:
for segment in segments
)

def compact_segments(
self,
segments: list[SegmentType],
min_words_per_segment: int,
) -> list[SegmentType]:
def compact_segments(self, segments: list[SegmentType], min_words_per_segment: int) -> list[SegmentType]:
if min_words_per_segment == 0:
return segments

Expand All @@ -133,9 +98,9 @@ def compact_segments(
compacted_segments.append(tmp_segment)
tmp_segment = None
elif len(segment['text'].split()) < min_words_per_segment:
tmp_segment = SegmentType(text=segment['text'], start=segment['start'], end=segment['end'])
tmp_segment = segment.copy()
elif len(segment['text'].split()) >= min_words_per_segment:
compacted_segments.append(SegmentType(text=segment['text'], start=segment['start'], end=segment['end']))
compacted_segments.append(segment.copy())

if tmp_segment:
compacted_segments.append(tmp_segment)
Expand Down

0 comments on commit bb319d7

Please sign in to comment.