From d1832b4aac770bb8e09e70989a32d04f8bf7f7b5 Mon Sep 17 00:00:00 2001 From: Maksym Sobolyev Date: Sat, 11 Jan 2025 06:05:50 +0000 Subject: [PATCH] Add STTSentinel() to be added with some delay after the last active chunk. Use it to flush STT buffer to the LLM. --- Apps/AIAttendant/AIAActor.py | 11 +++--- Apps/AIAttendant/AIASession.py | 71 ++++++++++++++++++++++++---------- Cluster/InfernSTTActor.py | 6 +-- Cluster/STTSession.py | 44 +++++++++++++++------ 4 files changed, 92 insertions(+), 40 deletions(-) diff --git a/Apps/AIAttendant/AIAActor.py b/Apps/AIAttendant/AIAActor.py index b040441..96b3e2d 100644 --- a/Apps/AIAttendant/AIAActor.py +++ b/Apps/AIAttendant/AIAActor.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, List +from typing import Dict, Optional, List, Union from uuid import UUID from functools import partial @@ -11,7 +11,7 @@ from Cluster.InfernTTSActor import InfernTTSActor from Cluster.InfernSTTActor import InfernSTTActor from Cluster.InfernLLMActor import InfernLLMActor -from Cluster.STTSession import STTResult +from Cluster.STTSession import STTResult, STTSentinel from Cluster.LLMSession import LLMResult from SIP.RemoteSession import RemoteSessionOffer from Core.T2T.NumbersToWords import NumbersToWords @@ -74,9 +74,10 @@ def sess_term(self, sess_id:UUID, sip_sess_id:UUID, relaxed:bool=False): if len(self.thumbstones) > 100: self.thumbstones = self.thumbstones[-100:] - def text_in(self, sess_id:UUID, result:STTResult): - self.swriter.add_scalar(f'stt/inf_time', result.inf_time, self.nstts) - self.nstts += 1 + def text_in(self, sess_id:UUID, result:Union[STTResult,STTSentinel]): + if isinstance(result, STTResult): + self.swriter.add_scalar(f'stt/inf_time', result.inf_time, self.nstts) + self.nstts += 1 self._get_session(sess_id).text_in(result) def text_out(self, sess_id:UUID, result:LLMResult): diff --git a/Apps/AIAttendant/AIASession.py b/Apps/AIAttendant/AIASession.py index 83f75b6..aa9bef3 100644 --- a/Apps/AIAttendant/AIASession.py +++ b/Apps/AIAttendant/AIASession.py @@ -1,4 +1,4 @@ -from typing import Tuple, List, Optional, Dict +from typing import Tuple, List, Optional, Dict, Union from uuid import UUID, uuid4 from functools import partial import ray @@ -6,7 +6,7 @@ from nltk.tokenize import sent_tokenize from Cluster.TTSSession import TTSRequest -from Cluster.STTSession import STTRequest, STTResult +from Cluster.STTSession import STTRequest, STTResult, STTSentinel from Cluster.LLMSession import LLMRequest, LLMResult from Cluster.RemoteTTSSession import RemoteTTSSession from Cluster.InfernRTPActor import InfernRTPActor @@ -17,7 +17,9 @@ from Core.AudioChunk import AudioChunk from ..LiveTranslator.LTSession import _sess_term, TTSProxy -class STTProxy(): +class STTProxy(AudioInput): + from time import monotonic + last_chunk_time: Optional[float] = None debug = True stt_do: callable stt_done: callable @@ -25,8 +27,24 @@ def __init__(self, stt_actr, stt_lang, stt_sess_id, stt_done): self.stt_do = partial(stt_actr.stt_session_soundin.remote, sess_id=stt_sess_id) self.lang, self.stt_done = stt_lang, stt_done + def audio_in(self, chunk:AudioChunk): + if self.last_chunk_time is None: + return + if chunk.active: + self.last_chunk_time = None + return + if self.monotonic() - self.last_chunk_time < 2.0: + return + def stt_done(result:STTSentinel): + print(f'STTProxy: {result=}') + self.stt_done(result=result) + self.last_chunk_time = None + sreq = STTSentinel('flush', stt_done) + self.stt_do(req=sreq) + # This method runs in the context of the inbound RTP Actor - def __call__(self, chunk:AudioChunk): + def vad_chunk_in(self, chunk:AudioChunk): + self.last_chunk_time = self.monotonic() if self.debug: print(f'STTProxy: VAD: {len(chunk.audio)=} {chunk.track_id=}') def stt_done(result:STTResult): @@ -48,6 +66,7 @@ class AIASession(): say_buffer: List[TTSRequest] translator: Optional[Translator] stt_sess_term: callable + text_in_buffer: List[str] def __init__(self, aiaa:'AIAActor', new_sess:RemoteSessionOffer): self.id = uuid4() @@ -70,7 +89,7 @@ def __init__(self, aiaa:'AIAActor', new_sess:RemoteSessionOffer): self.translator = aiaa.translator text_cb = partial(aiaa.aia_actr.text_in.remote, sess_id=self.id) vad_handler = STTProxy(aiaa.stt_actr, aiaa.stt_lang, self.stt_sess_id, text_cb) - self.rtp_actr.rtp_session_connect.remote(self.rtp_sess_id, AudioInput(None, vad_handler)) + self.rtp_actr.rtp_session_connect.remote(self.rtp_sess_id, vad_handler) soundout = partial(self.rtp_actr.rtp_session_soundout.remote, self.rtp_sess_id) tts_soundout = TTSProxy(soundout) self.tts_sess.start(tts_soundout) @@ -81,7 +100,9 @@ def __init__(self, aiaa:'AIAActor', new_sess:RemoteSessionOffer): sess_id=self.llm_sess_id) si = new_sess.sess_info self.n2w = NumbersToWords() + self.text_in_buffer = [] self.text_to_llm(f'') + print(f'Agent {self.speaker} at your service.') def text_to_llm(self, text:str): req = LLMRequest(text, self.llm_text_cb) @@ -89,23 +110,31 @@ def text_to_llm(self, text:str): self.llm_session_textin(req=req) self.last_llm_req_id = req.id - def text_in(self, result:STTResult): - print(f'STT: "{result.text=}" {result.no_speech_prob=}') - nsp = result.no_speech_prob - if nsp > STTRequest.max_ns_prob or len(result.text) == 0: - if result.duration < 5.0: - return - text = f'' - else: - text = result.text - if len(self.say_buffer) > 1: - self.say_buffer = self.say_buffer[:1] - self.llm_session_context_add(content='', role='user') + def text_in(self, result:Union[STTResult,STTSentinel]): + if isinstance(result, STTResult): + if self.debug: + print(f'STT: "{result.text=}" {result.no_speech_prob=}') + nsp = result.no_speech_prob + if nsp > STTRequest.max_ns_prob or len(result.text) == 0: + if result.duration < 5.0: + return + text = f'' + else: + text = result.text + self.text_in_buffer.append(text) + if len(self.say_buffer) > 1: + self.say_buffer = self.say_buffer[:1] + self.llm_session_context_add(content='', role='user') + return + if len(self.text_in_buffer) == 0: + return + text = ' '.join(self.text_in_buffer) + self.text_in_buffer = [] self.text_to_llm(text) return def text_out(self, result:LLMResult): - print(f'text_out({result.text=})') + if self.debug: print(f'text_out({result.text=})') if result.req_id != self.last_llm_req_id: print(f'LLMResult for old req_id: {result.req_id}') return @@ -120,11 +149,11 @@ def text_out(self, result:LLMResult): self.tts_say(t) def _tts_say(self, tr:TTSRequest): - self.tts_sess.say(tr) - self.llm_session_context_add(content=tr.text[0], role='assistant') + self.tts_sess.say(tr) + self.llm_session_context_add(content=tr.text[0], role='assistant') def tts_say(self, text): - print(f'tts_say({text=})') + if self.debug: print(f'tts_say({text=})') text = self.n2w(text) tts_req = TTSRequest([text,], done_cb=self.tts_say_done_cb, speaker_id=self.speaker) self.say_buffer.append(tts_req) diff --git a/Cluster/InfernSTTActor.py b/Cluster/InfernSTTActor.py index 6d8580c..64263da 100644 --- a/Cluster/InfernSTTActor.py +++ b/Cluster/InfernSTTActor.py @@ -1,13 +1,13 @@ #try: import intel_extension_for_pytorch as ipex #except ModuleNotFoundError: ipex = None -from typing import Dict +from typing import Dict, Union from uuid import UUID import ray from Cluster.InfernSTTWorker import InfernSTTWorker -from Cluster.STTSession import STTSession, STTRequest +from Cluster.STTSession import STTSession, STTRequest, STTSentinel @ray.remote(num_gpus=0.25, resources={"stt": 1}) class InfernSTTActor(): @@ -47,7 +47,7 @@ def stt_session_end(self, sess_id): sess.stop() del self.sessions[sess_id] - def stt_session_soundin(self, sess_id, req:STTRequest): + def stt_session_soundin(self, sess_id, req:Union[STTRequest,STTSentinel]): if self.debug: print('InfernSTTActor.stt_session_soundin') sess = self.sessions[sess_id] sess.soundin(req) diff --git a/Cluster/STTSession.py b/Cluster/STTSession.py index f62e521..08c3a7d 100644 --- a/Cluster/STTSession.py +++ b/Cluster/STTSession.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Union from uuid import uuid4, UUID from fractions import Fraction from functools import partial @@ -19,6 +19,13 @@ def __init__(self, chunk:AudioChunk, text_cb:callable, lang:str): self.stime = monotonic() self.lang, self.chunk, self.text_cb = lang, chunk, text_cb +class STTSentinel(): + stime: float + text_cb: callable + def __init__(self, signal:str, text_cb:callable): + self.stime = monotonic() + self.signal, self.text_cb = signal, text_cb + class STTResult(): text: str no_speech_prob: float @@ -53,29 +60,44 @@ def stop(self): with self.state_lock: del self.stt, self.pending - def soundin(self, req:STTRequest): - if self.debug: print(f'STTSession.soundin({len(req.chunk.audio)=})') - if req.chunk.samplerate != self.stt.sample_rate: - req.chunk.resample(self.stt.sample_rate) - req.chunk.audio = req.chunk.audio.numpy() + def soundin(self, req:Union[STTRequest,STTSentinel]): + if self.debug: + if isinstance(req, STTRequest): + print(f'STTSession.soundin({len(req.chunk.audio)=})') + else: + print(f'STTSession.soundin({req=})') + if isinstance(req, STTRequest): + if req.chunk.samplerate != self.stt.sample_rate: + req.chunk.resample(self.stt.sample_rate) + req.chunk.audio = req.chunk.audio.numpy() with self.state_lock: if self.busy: self.pending.append(req) return - self.busy = True + if isinstance(req, STTRequest): + self.busy = True + else: + req.text_cb(result=req) + return req.text_cb = partial(self.tts_out, req.text_cb) self.stt.infer((req, self.context)) def tts_out(self, text_cb, result:STTResult): + results = [(text_cb, result)] with self.state_lock: if not hasattr(self, 'stt'): return if self.debug: print(f'STTSession.tts_out({result.text=})') assert self.busy - if self.pending: + while self.pending: req = self.pending.pop(0) - req.text_cb = partial(self.tts_out, req.text_cb) - self.stt.infer((req, self.context)) + if isinstance(req, STTRequest): + req.text_cb = partial(self.tts_out, req.text_cb) + self.stt.infer((req, self.context)) + break + if all(isinstance(r, STTRequest) for r in self.pending): + results.append((req.text_cb, req)) else: self.busy = False - text_cb(result=result) + for cb, r in results: + cb(result=r)