From c09bb9c575b361228306cf864b55a2a8d5aac120 Mon Sep 17 00:00:00 2001 From: Alper Huseyin Dogan Date: Tue, 28 Jan 2025 21:41:31 +0300 Subject: [PATCH] add split_by_gap --- whisperx/alignment.py | 97 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 1 deletion(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 3b2fdae9..1c19ff6d 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -72,6 +72,86 @@ "lv": "jimregan/wav2vec2-large-xlsr-latvian-cv", } +def split_by_gap(segments: List[dict], gap_threshold: float = 0.7) -> List[dict]: + """ + Post-process aligned segments to split texts where word gaps exceed threshold. + Only splits within existing segments, does not modify already separate segments. + """ + result_segments = [] + + for segment in segments: + if not segment.get("words") or len(segment["words"]) < 2: + result_segments.append(segment) + continue + + words = segment["words"] + gaps_found = False + + # First check if any gaps exist in this segment + for i in range(1, len(words)): + prev_word = words[i - 1] + curr_word = words[i] + if ( + prev_word.get("end") + and curr_word.get("start") + and (curr_word["start"] - prev_word["end"]) > gap_threshold + ): + gaps_found = True + break + + # If no gaps found, keep segment as is + if not gaps_found: + result_segments.append(segment) + continue + + # If gaps found, split into multiple segments + current_segment = {"text": "", "words": [], "start": None, "end": None} + + for i, word in enumerate(words): + if not word.get("start") or not word.get("end"): + if current_segment["words"]: + current_segment["words"].append(word) + continue + + # First word + if current_segment["start"] is None: + current_segment["start"] = word["start"] + current_segment["words"] = [word] + continue + + prev_word = words[i - 1] + # Check gap with previous word + if ( + prev_word.get("end") + and (word["start"] - prev_word["end"]) > gap_threshold + ): + # Complete current segment + current_segment["end"] = prev_word["end"] + current_segment["text"] = " ".join( + w["word"] for w in current_segment["words"] + ) + result_segments.append(current_segment) + + # Start new segment + current_segment = { + "text": "", + "words": [word], + "start": word["start"], + "end": None, + } + else: + current_segment["words"].append(word) + + # Add final segment + if current_segment["words"]: + current_segment["end"] = current_segment["words"][-1]["end"] + current_segment["text"] = " ".join( + w["word"] for w in current_segment["words"] + ) + result_segments.append(current_segment) + + return result_segments + def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None): if model_name is None: @@ -119,6 +199,7 @@ def align( return_char_alignments: bool = False, print_progress: bool = False, combined_progress: bool = False, + gap_threshold: float = 0.7, ) -> AlignedTranscriptionResult: """ Align phoneme recognition predictions to known transcription. @@ -376,7 +457,21 @@ def align( for segment in aligned_segments: word_segments += segment["words"] - return {"segments": aligned_segments, "word_segments": word_segments} + aligned_result = {"segments": aligned_segments, "word_segments": word_segments} + + # Post-process to split segments with large gaps + aligned_result["segments"] = split_by_gap(aligned_result["segments"], gap_threshold) + + # Update word_segments list to match new segments + word_segments = [] + for segment in aligned_result["segments"]: + if segment.get("words"): + word_segments.extend(segment["words"]) + aligned_result["word_segments"] = word_segments + + return aligned_result + + # return {"segments": aligned_segments, "word_segments": word_segments} """ source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html