Skip to content

Commit

Permalink
add split_by_gap
Browse files Browse the repository at this point in the history
  • Loading branch information
AlperHuseyn authored Jan 28, 2025
1 parent 6c7bc99 commit c09bb9c
Showing 1 changed file with 96 additions and 1 deletion.
97 changes: 96 additions & 1 deletion whisperx/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,86 @@
"lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
}

def split_by_gap(segments: List[dict], gap_threshold: float = 0.7) -> List[dict]:
"""
Post-process aligned segments to split texts where word gaps exceed threshold.
Only splits within existing segments, does not modify already separate segments.
"""
result_segments = []

for segment in segments:
if not segment.get("words") or len(segment["words"]) < 2:
result_segments.append(segment)
continue

words = segment["words"]
gaps_found = False

# First check if any gaps exist in this segment
for i in range(1, len(words)):
prev_word = words[i - 1]
curr_word = words[i]
if (
prev_word.get("end")
and curr_word.get("start")
and (curr_word["start"] - prev_word["end"]) > gap_threshold
):
gaps_found = True
break

# If no gaps found, keep segment as is
if not gaps_found:
result_segments.append(segment)
continue

# If gaps found, split into multiple segments
current_segment = {"text": "", "words": [], "start": None, "end": None}

for i, word in enumerate(words):
if not word.get("start") or not word.get("end"):
if current_segment["words"]:
current_segment["words"].append(word)
continue

# First word
if current_segment["start"] is None:
current_segment["start"] = word["start"]
current_segment["words"] = [word]
continue

prev_word = words[i - 1]
# Check gap with previous word
if (
prev_word.get("end")
and (word["start"] - prev_word["end"]) > gap_threshold
):
# Complete current segment
current_segment["end"] = prev_word["end"]
current_segment["text"] = " ".join(
w["word"] for w in current_segment["words"]
)
result_segments.append(current_segment)

# Start new segment
current_segment = {
"text": "",
"words": [word],
"start": word["start"],
"end": None,
}
else:
current_segment["words"].append(word)

# Add final segment
if current_segment["words"]:
current_segment["end"] = current_segment["words"][-1]["end"]
current_segment["text"] = " ".join(
w["word"] for w in current_segment["words"]
)
result_segments.append(current_segment)

return result_segments


def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
if model_name is None:
Expand Down Expand Up @@ -119,6 +199,7 @@ def align(
return_char_alignments: bool = False,
print_progress: bool = False,
combined_progress: bool = False,
gap_threshold: float = 0.7,
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
Expand Down Expand Up @@ -376,7 +457,21 @@ def align(
for segment in aligned_segments:
word_segments += segment["words"]

return {"segments": aligned_segments, "word_segments": word_segments}
aligned_result = {"segments": aligned_segments, "word_segments": word_segments}

# Post-process to split segments with large gaps
aligned_result["segments"] = split_by_gap(aligned_result["segments"], gap_threshold)

# Update word_segments list to match new segments
word_segments = []
for segment in aligned_result["segments"]:
if segment.get("words"):
word_segments.extend(segment["words"])
aligned_result["word_segments"] = word_segments

return aligned_result

# return {"segments": aligned_segments, "word_segments": word_segments}

"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
Expand Down

0 comments on commit c09bb9c

Please sign in to comment.