Add scripts for speech-to-text using whisper and stt+forced alignment with whisperX #13

Merged: 11 commits, Mar 28, 2024
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -19,13 +19,14 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
]
requires-python = ">=3.10"
requires-python = ">=3.10, <3.12"
dependencies = [
"speechbrain>=1.0.0",
"torchaudio>=2.0.0",
"opensmile>=2.3.0",
"matplotlib>=3.8.3",
"click",
"whisperx @ git+https://github.com/m-bain/whisperx.git@f2da2f858e99e4211fe4f64b5f2938b007827e17",
"pydra~=0.23",
"TTS",
"accelerate",
@@ -35,6 +36,7 @@ dependencies = [
[project.optional-dependencies]
dev = [
"pytest",
"pre-commit"
]

[project.scripts]
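
As a quick sanity check (illustrative only, not part of this PR), the tightened constraints can be verified at runtime before importing anything that depends on whisperX:

import sys

# requires-python is now ">=3.10, <3.12", so guard against unsupported interpreters
assert (3, 10) <= sys.version_info[:2] < (3, 12), "b2aiprep expects Python 3.10 or 3.11"

# whisperx is installed from the pinned git commit listed in pyproject.toml
import whisperx
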
74 changes: 74 additions & 0 deletions src/b2aiprep/speech2text.py
@@ -0,0 +1,74 @@
import typing as ty

import torch
import whisperx

from .process import Audio


# Transcribes speech to text using the whisperX model
def transcribe_audio_whisperx(
    audio: Audio,
    hf_token: ty.Optional[str] = None,
    model: str = "base",
    device: ty.Optional[str] = None,
    batch_size: int = 16,
    compute_type: ty.Optional[str] = None,
    force_alignment: bool = True,
    return_char_alignments: bool = False,
    diarize: bool = False,
    min_speakers: ty.Optional[int] = None,
    max_speakers: ty.Optional[int] = None,
):
"""
Transcribes audio to text using OpenAI's whisper model.

Args:
audio (audio): Audio object.
model (str): Model to use for transcription. Defaults to "base".
See https://github.com/openai/whisper/ for a list of all available models.
device (str): Device to use for computation. Defaults to "cuda".
batch_size (int): Batch size for transcription. Defaults to 16.
compute_type (str): Type of computation to use. Defaults to "float16".
Change to "int8" if low on GPU mem (may reduce accuracy)
force_alignment (bool): Whether or not to perform forced alignment of the
speech-to-text output
diarize (bool): Whether or not to assign speaker labels to the text
hf_token (str): A Huggingface auth token, required to perform speaker diarization

Returns:
Result of the transcription.
"""

    # 1. Transcribe with original whisper (batched)
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    # whisperx needs an explicit compute type; default to "float16" on GPU and "float32" on CPU
    if compute_type is None:
        compute_type = "float16" if device.startswith("cuda") else "float32"
    model = whisperx.load_model(model, device, compute_type=compute_type)

    # whisper expects 16 kHz mono input
    if audio.sample_rate != 16000:
        audio = audio.to_16khz()
    audio = audio.signal.squeeze().numpy()
    result = model.transcribe(audio, batch_size=batch_size)

    if force_alignment:
        # 2. Align whisper output
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            device,
            return_char_alignments=return_char_alignments,
        )

    if diarize:
        # 3. Assign speaker labels
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)

        # pass the min/max number of speakers if known
        diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
        result = whisperx.assign_word_speakers(diarize_segments, result)
    return result
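
As a usage sketch (illustrative, not part of the diff; the audio path and CPU settings mirror the new test):

from b2aiprep.process import Audio
from b2aiprep.speech2text import transcribe_audio_whisperx

audio = Audio.from_file("data/vc_source.wav")  # same sample file the new test uses
result = transcribe_audio_whisperx(audio, model="tiny", device="cpu", compute_type="float32")
for segment in result["segments"]:
    print(segment["start"], segment["end"], segment["text"])

# Speaker diarization additionally requires a Hugging Face token with access to the pyannote models:
# result = transcribe_audio_whisperx(
#     audio, hf_token="<your-hf-token>", diarize=True, min_speakers=1, max_speakers=2
# )
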
16 changes: 16 additions & 0 deletions src/tests/test_speech_to_text.py
@@ -3,6 +3,7 @@
import pytest

from b2aiprep.process import Audio, SpeechToText
from b2aiprep.speech2text import transcribe_audio_whisperx


def test_transcribe():
@@ -22,6 +23,21 @@ def test_transcribe():
    assert text.strip() == audio_content


def test_transcribe_whisperx():
    """
    Validates transcribe_audio_whisperx's ability to convert audio to text accurately.
    Checks if the transcription matches the expected output, considering known model discrepancies.
    """
    audio_path = str((Path(__file__).parent.parent.parent / "data/vc_source.wav").absolute())
    # Note: the utterance is "If it didn't, it didn't.", but this is what the tiny model transcribes
    audio_content = "If it isn't, it isn't."

    audio = Audio.from_file(audio_path)

    result = transcribe_audio_whisperx(audio, model="tiny", device="cpu", compute_type="float32")
    assert result["segments"][0]["text"].strip() == audio_content


def test_cuda_not_available():
"""
Test behavior when CUDA is not available.