Skip to content

Commit

Permalink
Add STTSentinel() to be added with some delay after the last
Browse files Browse the repository at this point in the history
active chunk. Use it to flush STT buffer to the LLM.
  • Loading branch information
sobomax committed Jan 11, 2025
1 parent a73e4db commit d1832b4
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 40 deletions.
11 changes: 6 additions & 5 deletions Apps/AIAttendant/AIAActor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, List
from typing import Dict, Optional, List, Union
from uuid import UUID
from functools import partial

Expand All @@ -11,7 +11,7 @@
from Cluster.InfernTTSActor import InfernTTSActor
from Cluster.InfernSTTActor import InfernSTTActor
from Cluster.InfernLLMActor import InfernLLMActor
from Cluster.STTSession import STTResult
from Cluster.STTSession import STTResult, STTSentinel
from Cluster.LLMSession import LLMResult
from SIP.RemoteSession import RemoteSessionOffer
from Core.T2T.NumbersToWords import NumbersToWords
Expand Down Expand Up @@ -74,9 +74,10 @@ def sess_term(self, sess_id:UUID, sip_sess_id:UUID, relaxed:bool=False):
if len(self.thumbstones) > 100:
self.thumbstones = self.thumbstones[-100:]

def text_in(self, sess_id:UUID, result:STTResult):
self.swriter.add_scalar(f'stt/inf_time', result.inf_time, self.nstts)
self.nstts += 1
def text_in(self, sess_id:UUID, result:Union[STTResult,STTSentinel]):
    """Route an STT payload (result or flush sentinel) to its session.

    Inference-time stats are recorded only for real STTResult payloads;
    sentinels carry no timing information.
    """
    if isinstance(result, STTResult):
        step = self.nstts
        self.nstts = step + 1
        self.swriter.add_scalar('stt/inf_time', result.inf_time, step)
    self._get_session(sess_id).text_in(result)

def text_out(self, sess_id:UUID, result:LLMResult):
Expand Down
71 changes: 50 additions & 21 deletions Apps/AIAttendant/AIASession.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Tuple, List, Optional, Dict
from typing import Tuple, List, Optional, Dict, Union
from uuid import UUID, uuid4
from functools import partial
import ray

from nltk.tokenize import sent_tokenize

from Cluster.TTSSession import TTSRequest
from Cluster.STTSession import STTRequest, STTResult
from Cluster.STTSession import STTRequest, STTResult, STTSentinel
from Cluster.LLMSession import LLMRequest, LLMResult
from Cluster.RemoteTTSSession import RemoteTTSSession
from Cluster.InfernRTPActor import InfernRTPActor
Expand All @@ -17,16 +17,34 @@
from Core.AudioChunk import AudioChunk
from ..LiveTranslator.LTSession import _sess_term, TTSProxy

class STTProxy():
class STTProxy(AudioInput):
from time import monotonic
last_chunk_time: Optional[float] = None
debug = True
stt_do: callable
stt_done: callable
def __init__(self, stt_actr, stt_lang, stt_sess_id, stt_done):
self.stt_do = partial(stt_actr.stt_session_soundin.remote, sess_id=stt_sess_id)
self.lang, self.stt_done = stt_lang, stt_done

def audio_in(self, chunk:AudioChunk):
    """Watch the raw inbound audio stream and, once 2 seconds have passed
    since the last VAD-active chunk, ask the STT session to flush its
    buffer by sending an STTSentinel.

    `last_chunk_time` is stamped by vad_chunk_in(); None means either no
    speech has been seen yet or a flush has already been dispatched.
    """
    if self.last_chunk_time is None:
        # Nothing pending: no speech since the last flush (or ever).
        return
    if chunk.active:
        # Speech is ongoing; cancel the silence countdown.
        self.last_chunk_time = None
        return
    if self.monotonic() - self.last_chunk_time < 2.0:
        return
    # Clear the timer *before* the asynchronous round-trip: otherwise every
    # further inactive chunk arriving before the remote callback fires would
    # re-pass the 2s test and queue duplicate flush sentinels.
    self.last_chunk_time = None
    def stt_done(result:STTSentinel):
        print(f'STTProxy: {result=}')
        self.stt_done(result=result)
        self.last_chunk_time = None
    sreq = STTSentinel('flush', stt_done)
    self.stt_do(req=sreq)

# This method runs in the context of the inbound RTP Actor
def __call__(self, chunk:AudioChunk):
def vad_chunk_in(self, chunk:AudioChunk):
self.last_chunk_time = self.monotonic()
if self.debug:
print(f'STTProxy: VAD: {len(chunk.audio)=} {chunk.track_id=}')
def stt_done(result:STTResult):
Expand All @@ -48,6 +66,7 @@ class AIASession():
say_buffer: List[TTSRequest]
translator: Optional[Translator]
stt_sess_term: callable
text_in_buffer: List[str]

def __init__(self, aiaa:'AIAActor', new_sess:RemoteSessionOffer):
self.id = uuid4()
Expand All @@ -70,7 +89,7 @@ def __init__(self, aiaa:'AIAActor', new_sess:RemoteSessionOffer):
self.translator = aiaa.translator
text_cb = partial(aiaa.aia_actr.text_in.remote, sess_id=self.id)
vad_handler = STTProxy(aiaa.stt_actr, aiaa.stt_lang, self.stt_sess_id, text_cb)
self.rtp_actr.rtp_session_connect.remote(self.rtp_sess_id, AudioInput(None, vad_handler))
self.rtp_actr.rtp_session_connect.remote(self.rtp_sess_id, vad_handler)
soundout = partial(self.rtp_actr.rtp_session_soundout.remote, self.rtp_sess_id)
tts_soundout = TTSProxy(soundout)
self.tts_sess.start(tts_soundout)
Expand All @@ -81,31 +100,41 @@ def __init__(self, aiaa:'AIAActor', new_sess:RemoteSessionOffer):
sess_id=self.llm_sess_id)
si = new_sess.sess_info
self.n2w = NumbersToWords()
self.text_in_buffer = []
self.text_to_llm(f'<Incoming call from "{si.from_name}" at "{si.from_number}">')
print(f'Agent {self.speaker} at your service.')

def text_to_llm(self, text:str):
    """Submit *text* to the LLM session and remember the request id so a
    late reply to a superseded request can be recognized and dropped."""
    request = LLMRequest(text, self.llm_text_cb)
    request.auto_ctx_add = False
    self.llm_session_textin(req=request)
    self.last_llm_req_id = request.id

def text_in(self, result:STTResult):
print(f'STT: "{result.text=}" {result.no_speech_prob=}')
nsp = result.no_speech_prob
if nsp > STTRequest.max_ns_prob or len(result.text) == 0:
if result.duration < 5.0:
return
text = f'<unintelligible duration={result.duration} no_speech_probability={nsp}>'
else:
text = result.text
if len(self.say_buffer) > 1:
self.say_buffer = self.say_buffer[:1]
self.llm_session_context_add(content='<sentence interrupted>', role='user')
def text_in(self, result:Union[STTResult,STTSentinel]):
    """Accumulate STT text; when a sentinel arrives, flush the buffer to the LLM.

    Low-confidence / empty results shorter than 5s are dropped outright;
    longer ones are recorded as an explicit unintelligible marker so the
    LLM knows audio was present.
    """
    if not isinstance(result, STTResult):
        # Sentinel: hand any buffered sentences to the LLM in one request.
        if len(self.text_in_buffer) == 0:
            return
        text = ' '.join(self.text_in_buffer)
        self.text_in_buffer = []
        self.text_to_llm(text)
        return
    if self.debug:
        print(f'STT: "{result.text=}" {result.no_speech_prob=}')
    nsp = result.no_speech_prob
    if nsp > STTRequest.max_ns_prob or len(result.text) == 0:
        if result.duration < 5.0:
            return
        text = f'<unaudible duration={result.duration} no_speech_probability={nsp}>'
    else:
        text = result.text
    self.text_in_buffer.append(text)
    # The caller started talking over queued TTS output: keep only the
    # utterance already in flight and note the interruption for the LLM.
    if len(self.say_buffer) > 1:
        self.say_buffer = self.say_buffer[:1]
        self.llm_session_context_add(content='<sentence interrupted>', role='user')

def text_out(self, result:LLMResult):
print(f'text_out({result.text=})')
if self.debug: print(f'text_out({result.text=})')
if result.req_id != self.last_llm_req_id:
print(f'LLMResult for old req_id: {result.req_id}')
return
Expand All @@ -120,11 +149,11 @@ def text_out(self, result:LLMResult):
self.tts_say(t)

def _tts_say(self, tr:TTSRequest):
self.tts_sess.say(tr)
self.llm_session_context_add(content=tr.text[0], role='assistant')
self.tts_sess.say(tr)
self.llm_session_context_add(content=tr.text[0], role='assistant')

def tts_say(self, text):
print(f'tts_say({text=})')
if self.debug: print(f'tts_say({text=})')
text = self.n2w(text)
tts_req = TTSRequest([text,], done_cb=self.tts_say_done_cb, speaker_id=self.speaker)
self.say_buffer.append(tts_req)
Expand Down
6 changes: 3 additions & 3 deletions Cluster/InfernSTTActor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#try: import intel_extension_for_pytorch as ipex
#except ModuleNotFoundError: ipex = None

from typing import Dict
from typing import Dict, Union
from uuid import UUID

import ray

from Cluster.InfernSTTWorker import InfernSTTWorker
from Cluster.STTSession import STTSession, STTRequest
from Cluster.STTSession import STTSession, STTRequest, STTSentinel

@ray.remote(num_gpus=0.25, resources={"stt": 1})
class InfernSTTActor():
Expand Down Expand Up @@ -47,7 +47,7 @@ def stt_session_end(self, sess_id):
sess.stop()
del self.sessions[sess_id]

def stt_session_soundin(self, sess_id, req:Union[STTRequest,STTSentinel]):
    """Forward an audio request or flush sentinel to the session *sess_id*."""
    if self.debug:
        print('InfernSTTActor.stt_session_soundin')
    self.sessions[sess_id].soundin(req)
Expand Down
44 changes: 33 additions & 11 deletions Cluster/STTSession.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Union
from uuid import uuid4, UUID
from fractions import Fraction
from functools import partial
Expand All @@ -19,6 +19,13 @@ def __init__(self, chunk:AudioChunk, text_cb:callable, lang:str):
self.stime = monotonic()
self.lang, self.chunk, self.text_cb = lang, chunk, text_cb

class STTSentinel():
stime: float
text_cb: callable
def __init__(self, signal:str, text_cb:callable):
self.stime = monotonic()
self.signal, self.text_cb = signal, text_cb

class STTResult():
text: str
no_speech_prob: float
Expand Down Expand Up @@ -53,29 +60,44 @@ def stop(self):
with self.state_lock:
del self.stt, self.pending

def soundin(self, req:STTRequest):
if self.debug: print(f'STTSession.soundin({len(req.chunk.audio)=})')
if req.chunk.samplerate != self.stt.sample_rate:
req.chunk.resample(self.stt.sample_rate)
req.chunk.audio = req.chunk.audio.numpy()
def soundin(self, req:Union[STTRequest,STTSentinel]):
if self.debug:
if isinstance(req, STTRequest):
print(f'STTSession.soundin({len(req.chunk.audio)=})')
else:
print(f'STTSession.soundin({req=})')
if isinstance(req, STTRequest):
if req.chunk.samplerate != self.stt.sample_rate:
req.chunk.resample(self.stt.sample_rate)
req.chunk.audio = req.chunk.audio.numpy()
with self.state_lock:
if self.busy:
self.pending.append(req)
return
self.busy = True
if isinstance(req, STTRequest):
self.busy = True
else:
req.text_cb(result=req)
return
req.text_cb = partial(self.tts_out, req.text_cb)
self.stt.infer((req, self.context))

def tts_out(self, text_cb, result:STTResult):
results = [(text_cb, result)]
with self.state_lock:
if not hasattr(self, 'stt'):
return
if self.debug: print(f'STTSession.tts_out({result.text=})')
assert self.busy
if self.pending:
while self.pending:
req = self.pending.pop(0)
req.text_cb = partial(self.tts_out, req.text_cb)
self.stt.infer((req, self.context))
if isinstance(req, STTRequest):
req.text_cb = partial(self.tts_out, req.text_cb)
self.stt.infer((req, self.context))
break
if all(isinstance(r, STTRequest) for r in self.pending):
results.append((req.text_cb, req))
else:
self.busy = False
text_cb(result=result)
for cb, r in results:
cb(result=r)

0 comments on commit d1832b4

Please sign in to comment.