diff --git a/whisperx/asr.py b/whisperx/asr.py index 6de94900..7041a512 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -195,10 +195,15 @@ def transcribe( print_progress=False, combined_progress=False, verbose=False, + initial_prompt: Optional[str] = None, + ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) + if initial_prompt is not None: + self.options = replace(self.options, initial_prompt=initial_prompt) + def data(audio, segments): for seg in segments: f1 = int(seg['start'] * SAMPLE_RATE)