Skip to content

Commit

Permalink
feat: whisper - rely on word probability and punctuation for better segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
rpurdel authored Jan 28, 2025
1 parent e0dcfa9 commit b06d22c
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 98 deletions.
24 changes: 18 additions & 6 deletions demos/streaming-whisper/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ <h1 class="title">Streaming Whisper Demo</h1>
<div class="field">
<button class="button is-primary" id="transcribebtn">Transcribe</button>
<button class="button" id="stopbtn" disabled>Stop</button>
<button class="button" id="clearbtn">Clear</button>
<div class="select is-disabled">
<select id="langselector" name="langselector">
<option value="en" selected="selected">English</option>
Expand All @@ -48,7 +49,7 @@ <h1 class="title">Streaming Whisper Demo</h1>
</section>
<section class="section">
<div class="container is-max-desktop">
<div class="box has-text-centered is-family-monospace has-text-grey" id="outputcontainer">
<div class="box is-family-monospace has-text-grey" id="outputcontainer">
<p>Waiting for interims...</p>
</div>
</div>
Expand All @@ -68,8 +69,6 @@ <h1 class="title">Streaming Whisper Demo</h1>
// mostly taken from https://dev.to/louisgv/quick-guide-to-audioworklet-30df
const main = async () => {
let finals = []
const CLIENTID = crypto.randomUUID()
const MEETINGID = crypto.randomUUID()
const context = new AudioContext({ sampleRate: 16000 })
const microphone = await navigator.mediaDevices.getUserMedia({
audio: true,
Expand All @@ -84,13 +83,16 @@ <h1 class="title">Streaming Whisper Demo</h1>
const langSel = document.getElementById('langselector')
const transcribeBtn = document.getElementById('transcribebtn')
const stopBtn = document.getElementById('stopbtn')
const clearBtn = document.getElementById('clearbtn')
const output = document.getElementById('outputcontainer')
const final = document.getElementById('finalcontainer')
const jwt = document.getElementById('jwt')
const wsHost = document.getElementById('wshost')
const muteBtn = document.getElementById('mutebtn')

let isMuted = false
let clientId = crypto.randomUUID()

muteBtn.addEventListener('click', () => {
if (isMuted) {
microphone.getAudioTracks()[0].enabled = true
Expand All @@ -103,6 +105,13 @@ <h1 class="title">Streaming Whisper Demo</h1>
}
})

// Reset both transcript panes: drop accumulated finals, wipe the final
// container, and restore the interim container's placeholder text.
clearBtn.addEventListener('click', () => {
finals = []
final.innerHTML = ''
// NOTE(review): `interims` is not declared anywhere in the visible scope
// (only `finals` is) — confirm this assignment targets a real variable.
interims = []
output.innerHTML = 'Waiting for interims...'
})

var isSpeaking = false

// create the recorder worklet
Expand All @@ -128,7 +137,8 @@ <h1 class="title">Streaming Whisper Demo</h1>
final.innerHTML += '<div class="columns">' +
'<div class="column"><span class="tag is-info is-light is-family-monospace">' +
h + ':' + m + ':' + s + '.' + ms + '</span>' + playButton + '</div>' +
'<div class="column is-three-fifths"><span class="transcript">' + msg.text + '</span>' +
'<div class="column is-three-fifths"><span class="tag is-info">' +
msg.variance.toFixed(2) + '</span><span class="transcript">' + msg.text + '</span>' +
'</div></div>'
}
}
Expand Down Expand Up @@ -156,7 +166,9 @@ <h1 class="title">Streaming Whisper Demo</h1>
}

function wsConnect() {
let wsConnectionString = wsHost.value.trim() + '/' + MEETINGID
let meetingId = crypto.randomUUID()
clientId = crypto.randomUUID()
let wsConnectionString = wsHost.value.trim() + '/' + meetingId
if (jwt.value.trim() != '') {
wsConnectionString += '?auth_token=' + jwt.value.trim()
}
Expand All @@ -173,7 +185,7 @@ <h1 class="title">Streaming Whisper Demo</h1>

function preparePayload(data) {
let lang = langSel.value
let str = CLIENTID + "|" + lang
let str = clientId + "|" + lang
if (str.length < 60) {
str = str.padEnd(60, " ")
}
Expand Down
4 changes: 3 additions & 1 deletion skynet/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def tobool(val: str | None):
whisper_model_path = os.environ.get('WHISPER_MODEL_PATH', f'{os.getcwd()}/models/streaming_whisper')
whisper_return_transcribed_audio = tobool(os.environ.get('WHISPER_RETURN_TRANSCRIBED_AUDIO'))
whisper_max_connections = int(os.environ.get('WHISPER_MAX_CONNECTIONS', 10))
whisper_min_probability = float(os.environ.get('WHISPER_MIN_PROBABILITY', 0.7))
ws_max_size_bytes = int(os.environ.get('WS_MAX_SIZE_BYTES', 1000000))
ws_max_queue_size = int(os.environ.get('WS_MAX_QUEUE_SIZE', 3000))
ws_max_ping_interval = int(os.environ.get('WS_MAX_PING_INTERVAL', 30))
Expand All @@ -91,7 +92,8 @@ def tobool(val: str | None):
# This is used to provide some context to the model
# The larger the initial prompt (max 224 tokens), the slower the inference.
whisper_max_finals_in_initial_prompt = int(os.environ.get('WHISPER_MAX_FINALS_IN_INITIAL_PROMPT', 2))

# The period in milliseconds to flush the buffer after no new spoken audio is detected
whisper_flush_interval = int(os.environ.get('WHISPER_FLUSH_BUFFER_INTERVAL', 2000))

# jobs
job_timeout = int(os.environ.get('JOB_TIMEOUT', 60 * 5)) # 5 minutes default
Expand Down
17 changes: 8 additions & 9 deletions skynet/modules/stt/streaming_whisper/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
from fastapi import WebSocket, WebSocketDisconnect

from skynet.auth.jwt import authorize
from skynet.env import bypass_auth, whisper_max_connections
from skynet.env import bypass_auth, whisper_flush_interval, whisper_max_connections
from skynet.logs import get_logger
from skynet.modules.monitoring import CONNECTIONS_METRIC, TRANSCRIBE_CONNECTIONS_COUNTER, TRANSCRIBE_STRESS_LEVEL_METRIC
from skynet.modules.stt.streaming_whisper.meeting_connection import MeetingConnection
from skynet.modules.stt.streaming_whisper.utils import utils

log = get_logger(__name__)

FLUSH_AFTER_MS = 2000


class ConnectionManager:
connections: dict[str, MeetingConnection]
Expand Down Expand Up @@ -76,12 +74,13 @@ async def flush_working_audio_worker(self):
while True:
for meeting_id in self.connections:
for participant in self.connections[meeting_id].participants:
now = utils.now()
last_received_chunk = self.connections[meeting_id].participants[participant].last_received_chunk
is_due = now - last_received_chunk > FLUSH_AFTER_MS
is_silent, _ = utils.is_silent(self.connections[meeting_id].participants[participant].working_audio)
if is_due and not is_silent:
state = self.connections[meeting_id].participants[participant]
diff = utils.now() - state.last_received_chunk
log.debug(
f'Participant {participant} in meeting {meeting_id} has been silent for {diff} ms and has {len(state.working_audio)} bytes of audio'
)
if diff > whisper_flush_interval and len(state.working_audio) > 0 and not state.is_transcribing:
log.info(f'Forcing a transcription in meeting {meeting_id} for {participant}')
results = await self.connections[meeting_id].participants[participant].force_transcription()
results = await self.connections[meeting_id].force_transcription(participant)
await self.send(meeting_id, results)
await asyncio.sleep(1)
26 changes: 17 additions & 9 deletions skynet/modules/stt/streaming_whisper/meeting_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ def __init__(self, ws: WebSocket):
self.meeting_language = None
self.tokenizer = None

async def update_initial_prompt(self, new_transcription: str):
self.previous_transcription_store.append(self.tokenizer.encode(f' {new_transcription.strip()}'))
if len(self.previous_transcription_tokens) > max_finals:
self.previous_transcription_store.pop(0)
# flatten the list of lists
self.previous_transcription_tokens = list(chain.from_iterable(self.previous_transcription_store))
async def update_initial_prompt(self, previous_payloads: list[utils.TranscriptionResponse]):
    """Fold finalized transcription text into the rolling initial-prompt store.

    Only payloads whose ``type`` is ``'final'`` and whose text contains none of
    ``utils.black_listed_prompts`` are tokenized and appended. The store is
    trimmed from the front and ``previous_transcription_tokens`` is rebuilt as
    the flattened concatenation of all stored token lists.
    """
    # NOTE(review): this nesting is reconstructed from a diff view whose
    # indentation was lost — confirm against the repository whether the
    # trim/flatten steps sit inside the `if`, at loop level, or after the loop.
    for payload in previous_payloads:
        if payload.type == 'final' and not any(prompt in payload.text for prompt in utils.black_listed_prompts):
            # Leading space matches tokenizer expectations for continuation text.
            self.previous_transcription_store.append(self.tokenizer.encode(f' {payload.text.strip()}'))
            # NOTE(review): compares the flattened TOKEN count against
            # `max_finals` (a count of finals) — verify this is intended and
            # not meant to be len(self.previous_transcription_store).
            if len(self.previous_transcription_tokens) > max_finals:
                self.previous_transcription_store.pop(0)
            # flatten the list of lists
            self.previous_transcription_tokens = list(chain.from_iterable(self.previous_transcription_store))

async def process(self, chunk: bytes, chunk_timestamp: int) -> List[utils.TranscriptionResponse] | None:
a_chunk = Chunk(chunk, chunk_timestamp)
Expand All @@ -56,7 +58,13 @@ async def process(self, chunk: bytes, chunk_timestamp: int) -> List[utils.Transc

payloads = await self.participants[a_chunk.participant_id].process(a_chunk, self.previous_transcription_tokens)
if payloads:
for payload in payloads:
if payload.type == 'final':
await self.update_initial_prompt(payload.text)
await self.update_initial_prompt(payloads)
return payloads

async def force_transcription(self, participant_id: str):
    """Force a transcription for one participant of this meeting.

    Delegates to the participant's own ``force_transcription``, passing the
    shared initial-prompt tokens, then feeds any resulting payloads back into
    the rolling initial prompt via ``update_initial_prompt``.

    Returns the participant's payload list (which may be empty/falsy), or
    ``None`` when ``participant_id`` is not a known participant.
    """
    if participant_id in self.participants:
        payloads = await self.participants[participant_id].force_transcription(self.previous_transcription_tokens)
        if payloads:
            # Absorb the forced finals so later segments get them as context.
            await self.update_initial_prompt(payloads)
        return payloads
    return None
Loading

0 comments on commit b06d22c

Please sign in to comment.