diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index cd7f8ec8..aa9c7247 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -102,6 +102,7 @@ def align(
     return_char_alignments: bool = False,
     print_progress: bool = False,
     combined_progress: bool = False,
+    preprocess: bool = True,
 ) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
@@ -120,6 +121,11 @@ def align(
     model_lang = align_model_metadata["language"]
     model_type = align_model_metadata["type"]
 
+    # Load the align model's Hugging Face processor so the waveform can be normalized before inference
+    if preprocess and model_type == "huggingface":
+        processor = Wav2Vec2Processor.from_pretrained(
+            DEFAULT_ALIGN_MODELS_HF[model_lang])
+
     # 1. Preprocess to keep only characters in dictionary
     total_segments = len(transcript)
     for sdx, segment in enumerate(transcript):
@@ -222,7 +228,11 @@ def align(
             if model_type == "torchaudio":
                 emissions, _ = model(waveform_segment.to(device), lengths=lengths)
             elif model_type == "huggingface":
-                emissions = model(waveform_segment.to(device)).logits
+                if preprocess:
+                    inputs = processor(waveform_segment.squeeze(), sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt").to(device)
+                    emissions = model(**inputs).logits
+                else:
+                    emissions = model(waveform_segment.to(device)).logits
             else:
                 raise NotImplementedError(f"Align model of type {model_type} not supported.")
             emissions = torch.log_softmax(emissions, dim=-1)
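
For context, a minimal sketch of how the new preprocess flag would be exercised, assuming the standard whisperx transcribe-then-align workflow; the audio path and Whisper model size below are placeholders, not part of this patch:

    import whisperx

    device = "cuda"
    audio = whisperx.load_audio("audio.wav")  # placeholder input file

    # Transcribe first to obtain segments to align
    model = whisperx.load_model("large-v2", device)
    result = model.transcribe(audio)

    # Align; with preprocess=True (the new default) the Hugging Face processor
    # normalizes the waveform before emissions are computed. The flag only
    # affects huggingface-type align models, not torchaudio ones.
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    aligned = whisperx.align(
        result["segments"], model_a, metadata, audio, device,
        preprocess=True,  # set False to keep the previous raw-waveform behavior
    )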