From 2ba559810f6dc22aabdb3c7e32928cf907996296 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 31 Dec 2024 17:39:51 +0530 Subject: [PATCH 01/19] init --- examples/multimodal-agent/gemini_agent.py | 3 +- .../agents/multimodal/agent_playout.py | 51 ++++-- .../agents/multimodal/multimodal_agent.py | 156 +++++++++++++----- 3 files changed, 157 insertions(+), 53 deletions(-) diff --git a/examples/multimodal-agent/gemini_agent.py b/examples/multimodal-agent/gemini_agent.py index 81a474609..f9ac707f1 100644 --- a/examples/multimodal-agent/gemini_agent.py +++ b/examples/multimodal-agent/gemini_agent.py @@ -14,7 +14,7 @@ llm, multimodal, ) -from livekit.plugins import google +from livekit.plugins import deepgram, google load_dotenv() @@ -60,6 +60,7 @@ async def get_weather( ), fnc_ctx=fnc_ctx, chat_ctx=chat_ctx, + stt=deepgram.STT(), ) agent.start(ctx.room, participant) diff --git a/livekit-agents/livekit/agents/multimodal/agent_playout.py b/livekit-agents/livekit/agents/multimodal/agent_playout.py index f1dbda1e7..8ed5a162c 100644 --- a/livekit-agents/livekit/agents/multimodal/agent_playout.py +++ b/livekit-agents/livekit/agents/multimodal/agent_playout.py @@ -4,11 +4,13 @@ from typing import AsyncIterable, Literal from livekit import rtc -from livekit.agents import transcription, utils +from livekit.agents import stt, transcription, utils from ..log import logger -EventTypes = Literal["playout_started", "playout_stopped"] +EventTypes = Literal[ + "playout_started", "playout_stopped", "final_transcript", "interim_transcript" +] class PlayoutHandle: @@ -68,9 +70,17 @@ def interrupt(self) -> None: class AgentPlayout(utils.EventEmitter[EventTypes]): - def __init__(self, *, audio_source: rtc.AudioSource) -> None: + def __init__( + self, + *, + audio_source: rtc.AudioSource, + stt: stt.STT, + stt_forwarder: stt.STTForwarder, + ) -> None: super().__init__() self._source = audio_source + self._stt = stt + self._stt_forwarder = stt_forwarder self._playout_atask: asyncio.Task[None] | None = None def play( @@ -106,6 +116,7 @@ async def _playout_task( await utils.aio.gracefully_cancel(old_task) first_frame = True + stt_stream = self._stt.stream() if self._stt is not None else None @utils.log_exceptions(logger=logger) async def _play_text_stream(): @@ -134,36 +145,54 @@ async def _capture_task(): handle._tr_fwd.push_audio(frame) for f in bstream.write(frame.data.tobytes()): + if stt_stream is not None: + stt_stream.push_frame(f) handle._pushed_duration += f.samples_per_channel / f.sample_rate await self._source.capture_frame(f) for f in bstream.flush(): handle._pushed_duration += f.samples_per_channel / f.sample_rate + if stt_stream is not None: + stt_stream.push_frame(f) await self._source.capture_frame(f) handle._tr_fwd.mark_audio_segment_end() await self._source.wait_for_playout() + async def _stt_stream_co() -> None: + if stt_stream is not None: + async for ev in stt_stream: + self._stt_forwarder.update(ev) + + if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: + self.emit("final_transcript", ev) + elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: + self.emit("interim_transcript", ev) + read_text_task = asyncio.create_task(_play_text_stream()) - capture_task = asyncio.create_task(_capture_task()) + tasks = [ + asyncio.create_task(_capture_task()), + asyncio.create_task(_stt_stream_co()), + ] try: - await asyncio.wait( - [capture_task, handle._int_fut], + done, _ = await asyncio.wait( + [asyncio.gather(*tasks), handle._int_fut], return_when=asyncio.FIRST_COMPLETED, ) - finally: - await utils.aio.gracefully_cancel(capture_task) handle._total_played_time = ( handle._pushed_duration - self._source.queued_duration ) - if handle.interrupted or capture_task.exception(): - self._source.clear_queue() # make sure to remove any queued frames + for task in done: + if handle.interrupted or task.exception(): + self._source.clear_queue() # make sure to remove any queued frames + break - await utils.aio.gracefully_cancel(read_text_task) + finally: + await utils.aio.gracefully_cancel(*tasks, read_text_task) # make sure the text_data.sentence_stream is closed handle._tr_fwd.mark_text_segment_end() diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index f02bb2e64..e9223a7c9 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -16,9 +16,10 @@ import aiohttp from livekit import rtc -from livekit.agents import llm, stt, tokenize, transcription, utils, vad +from livekit.agents import llm, stt, tokenize, utils, vad from livekit.agents.llm import ChatMessage from livekit.agents.metrics import MultimodalLLMMetrics +from livekit.agents.transcription import STTSegmentsForwarder, TTSSegmentsForwarder from ..log import logger from ..types import ATTRIBUTE_AGENT_STATE, AgentState @@ -35,6 +36,8 @@ "function_calls_collected", "function_calls_finished", "metrics_collected", + "final_transcript", + "interim_transcript", ] @@ -143,6 +146,7 @@ def __init__( vad: vad.VAD | None = None, chat_ctx: llm.ChatContext | None = None, fnc_ctx: llm.FunctionContext | None = None, + stt: stt.STT | None = None, transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(), max_text_response_retries: int = 5, loop: asyncio.AbstractEventLoop | None = None, @@ -175,8 +179,18 @@ def __init__( transcription=transcription, ) + # if not stt.capabilities.streaming: + # from .. import stt as speech_to_text + + # stt = speech_to_text.StreamAdapter( + # stt=stt, + # vad=vad, + # ) + self._stt = stt + self.on("final_transcript", self._on_final_transcript) + # audio input - self._read_micro_atask: asyncio.Task | None = None + self._recognize_atask: asyncio.Task | None = None self._subscribed_track: rtc.RemoteAudioTrack | None = None self._input_audio_ch = utils.aio.Chan[rtc.AudioFrame]() @@ -196,6 +210,10 @@ def __init__( def vad(self) -> vad.VAD | None: return self._vad + @property + def stt(self) -> stt.STT | None: + return self._stt + @property def fnc_ctx(self) -> llm.FunctionContext | None: return self._session.fnc_ctx @@ -252,7 +270,7 @@ async def _init_and_start(): @self._session.on("response_content_added") def _on_content_added(message: _ContentProto): - tr_fwd = transcription.TTSSegmentsForwarder( + tr_fwd = TTSSegmentsForwarder( room=self._room, participant=self._room.local_participant, speed=self._opts.transcription.agent_transcription_speed, @@ -319,11 +337,7 @@ def _input_speech_transcription_completed(ev: _InputTranscriptionProto): ev.item_id, user_msg.content ) - self.emit("user_speech_committed", user_msg) - logger.debug( - "committed user speech", - extra={"user_transcript": ev.transcript}, - ) + self._emit_speech_committed("user", ev.transcript) @self._session.on("input_speech_started") def _input_speech_started(): @@ -378,8 +392,21 @@ async def _run_task(delay: float) -> None: async def _main_task(self) -> None: self._update_state("initializing") self._audio_source = rtc.AudioSource(24000, 1) + track = rtc.LocalAudioTrack.create_audio_track( + "assistant_voice", self._audio_source + ) + self._agent_publication = await self._room.local_participant.publish_track( + track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) + ) + self._agent_stt_forwarder = STTSegmentsForwarder( + room=self._room, + participant=self._room.local_participant, + track=track, + ) self._agent_playout = agent_playout.AgentPlayout( - audio_source=self._audio_source + audio_source=self._audio_source, + stt=self._stt, + stt_forwarder=self._agent_stt_forwarder, ) def _on_playout_started() -> None: @@ -395,38 +422,25 @@ def _on_playout_stopped(interrupted: bool) -> None: if interrupted: collected_text += "..." - msg = ChatMessage.create( - text=collected_text, - role="assistant", - id=self._playing_handle.item_id, - ) - if self._model.capabilities.supports_truncate: + if self._model.capabilities.supports_truncate and collected_text: + msg = ChatMessage.create( + text=collected_text, + role="assistant", + id=self._playing_handle.item_id, + ) self._session._update_conversation_item_content( self._playing_handle.item_id, msg.content ) - if interrupted: - self.emit("agent_speech_interrupted", msg) - else: - self.emit("agent_speech_committed", msg) + self._emit_speech_committed("agent", collected_text, interrupted) - logger.debug( - "committed agent speech", - extra={ - "agent_transcript": collected_text, - "interrupted": interrupted, - }, - ) + def _on_final_transcript(ev: stt.SpeechEvent): + self._emit_speech_committed("agent", ev.alternatives[0].text) self._agent_playout.on("playout_started", _on_playout_started) self._agent_playout.on("playout_stopped", _on_playout_stopped) + self._agent_playout.on("final_transcript", _on_final_transcript) - track = rtc.LocalAudioTrack.create_audio_track( - "assistant_voice", self._audio_source - ) - self._agent_publication = await self._room.local_participant.publish_track( - track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) - ) await self._agent_publication.wait_for_subscription() @@ -455,10 +469,7 @@ def _link_participant(self, participant_identity: str) -> None: self._subscribe_to_microphone() - async def _micro_task(self, track: rtc.LocalAudioTrack) -> None: - stream_24khz = rtc.AudioStream(track, sample_rate=24000, num_channels=1) - async for ev in stream_24khz: - self._input_audio_ch.send_nowait(ev.frame) + def _subscribe_to_microphone(self, *args, **kwargs) -> None: """Subscribe to the participant microphone if found""" @@ -478,22 +489,85 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None: and publication.track != self._subscribed_track ): self._subscribed_track = publication.track # type: ignore - self._stt_forwarder = transcription.STTSegmentsForwarder( + stream_24khz = rtc.AudioStream( + self._subscribed_track, sample_rate=24000, num_channels=1 + ) + self._stt_forwarder = STTSegmentsForwarder( room=self._room, participant=self._linked_participant, track=self._subscribed_track, ) - if self._read_micro_atask is not None: - self._read_micro_atask.cancel() + if self._recognize_atask is not None: + self._recognize_atask.cancel() - self._read_micro_atask = asyncio.create_task( - self._micro_task(self._subscribed_track) # type: ignore + self._recognize_atask = asyncio.create_task( + self._recognize_task(stream_24khz) ) break + + + @utils.log_exceptions(logger=logger) + async def _recognize_task(self, audio_stream: rtc.AudioStream) -> None: + """ + Receive the frames from the user audio stream. + """ + + stt_stream = self._stt.stream() if self._stt is not None else None + + async def _micro_task() -> None: + async for ev in audio_stream: + if stt_stream is not None: + stt_stream.push_frame(ev.frame) + self._input_audio_ch.send_nowait(ev.frame) + + async def _stt_stream_co() -> None: + if stt_stream is not None: + async for ev in stt_stream: + self._stt_forwarder.update(ev) + + if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: + self.emit("final_transcript", ev) + elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: + self.emit("interim_transcript", ev) + + tasks = [ + asyncio.create_task(_micro_task()), + asyncio.create_task(_stt_stream_co()), + ] + try: + await asyncio.gather(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) + if stt_stream is not None: + await stt_stream.aclose() + def _ensure_session(self) -> aiohttp.ClientSession: if not self._http_session: self._http_session = utils.http_context.http_session() return self._http_session + + def _on_final_transcript(self, ev: stt.SpeechEvent): + self._emit_speech_committed("user", ev.alternatives[0].text) + + def _emit_speech_committed( + self, speaker: Literal["user", "agent"], msg: str, interrupted: bool = False + ): + + if speaker == "user": + self.emit("user_speech_committed", msg) + else: + if interrupted: + self.emit("agent_speech_interrupted", msg) + else: + self.emit("agent_speech_committed", msg) + + logger.debug( + f"committed {speaker} speech", + extra={ + f"{speaker}_transcript": msg, + "interrupted": interrupted, + }, + ) From 48893010b61afbf9d0dcbc87e95c583f4aaec7a9 Mon Sep 17 00:00:00 2001 From: jayesh Date: Thu, 2 Jan 2025 19:18:09 +0530 Subject: [PATCH 02/19] updates --- examples/multimodal-agent/gemini_agent.py | 5 +- .../agents/multimodal/agent_playout.py | 2 +- .../agents/multimodal/multimodal_agent.py | 18 +-- .../plugins/google/beta/realtime/__init__.py | 2 + .../plugins/google/beta/realtime/stt.py | 151 ++++++++++++++++++ 5 files changed, 163 insertions(+), 15 deletions(-) create mode 100644 livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py diff --git a/examples/multimodal-agent/gemini_agent.py b/examples/multimodal-agent/gemini_agent.py index f9ac707f1..d597e48a9 100644 --- a/examples/multimodal-agent/gemini_agent.py +++ b/examples/multimodal-agent/gemini_agent.py @@ -14,7 +14,7 @@ llm, multimodal, ) -from livekit.plugins import deepgram, google +from livekit.plugins import google, silero load_dotenv() @@ -60,7 +60,8 @@ async def get_weather( ), fnc_ctx=fnc_ctx, chat_ctx=chat_ctx, - stt=deepgram.STT(), + stt=google.beta.realtime.STT(), + vad=silero.VAD.load(), ) agent.start(ctx.room, participant) diff --git a/livekit-agents/livekit/agents/multimodal/agent_playout.py b/livekit-agents/livekit/agents/multimodal/agent_playout.py index 8ed5a162c..a01d5b0ea 100644 --- a/livekit-agents/livekit/agents/multimodal/agent_playout.py +++ b/livekit-agents/livekit/agents/multimodal/agent_playout.py @@ -188,7 +188,7 @@ async def _stt_stream_co() -> None: for task in done: if handle.interrupted or task.exception(): - self._source.clear_queue() # make sure to remove any queued frames + self._source.clear_queue() # make sure to remove any queued frames break finally: diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index e9223a7c9..efab5757e 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -179,13 +179,13 @@ def __init__( transcription=transcription, ) - # if not stt.capabilities.streaming: - # from .. import stt as speech_to_text + if not stt.capabilities.streaming: + from .. import stt as speech_to_text - # stt = speech_to_text.StreamAdapter( - # stt=stt, - # vad=vad, - # ) + stt = speech_to_text.StreamAdapter( + stt=stt, + vad=vad, + ) self._stt = stt self.on("final_transcript", self._on_final_transcript) @@ -441,7 +441,6 @@ def _on_final_transcript(ev: stt.SpeechEvent): self._agent_playout.on("playout_stopped", _on_playout_stopped) self._agent_playout.on("final_transcript", _on_final_transcript) - await self._agent_publication.wait_for_subscription() bstream = utils.audio.AudioByteStream( @@ -469,8 +468,6 @@ def _link_participant(self, participant_identity: str) -> None: self._subscribe_to_microphone() - - def _subscribe_to_microphone(self, *args, **kwargs) -> None: """Subscribe to the participant microphone if found""" @@ -506,8 +503,6 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None: ) break - - @utils.log_exceptions(logger=logger) async def _recognize_task(self, audio_stream: rtc.AudioStream) -> None: """ @@ -555,7 +550,6 @@ def _on_final_transcript(self, ev: stt.SpeechEvent): def _emit_speech_committed( self, speaker: Literal["user", "agent"], msg: str, interrupted: bool = False ): - if speaker == "user": self.emit("user_speech_committed", msg) else: diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py index e95a86917..a4c7b215c 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py @@ -5,6 +5,7 @@ Voice, ) from .realtime_api import RealtimeModel +from .stt import STT __all__ = [ "RealtimeModel", @@ -12,4 +13,5 @@ "LiveAPIModels", "ResponseModality", "Voice", + "STT", ] diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py new file mode 100644 index 000000000..9153039a7 --- /dev/null +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py @@ -0,0 +1,151 @@ +# Copyright 2023 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass + +from livekit import rtc +from livekit.agents import ( + APIConnectionError, + APIConnectOptions, + stt, + utils, +) + +from google import genai +from google.genai import types + +from .api_proto import LiveAPIModels + +SAMPLE_RATE = 16000 + +SYSTEM_INSTRUCTIONS = """ +You are an **Audio Transcriber**. Your task is to convert audio content into accurate and precise text. + +**Guidelines:** + +1. **Transcription Only:** + - Transcribe spoken words exactly as they are. + - Exclude any non-speech sounds (e.g., background noise, music). + +2. **Response Format:** + - Provide only the transcription without any additional text or explanations. + - If the audio is unclear or inaudible, respond with: `...` + +3. **Accuracy:** + - Ensure the transcription is free from errors. + - Maintain the original meaning and context of the speech. + +4. **Clarity:** + - Use proper punctuation and formatting to enhance readability. + - Preserve the original speaker's intent and tone as much as possible. + +**Do Not:** +- Add any explanations, comments, or additional information. +- Include timestamps, speaker labels, or annotations unless specified. +""" + + +@dataclass +class STTOptions: + language: str + detect_language: bool + system_instructions: str + model: LiveAPIModels + + +class STT(stt.STT): + def __init__( + self, + *, + api_key: str | None = None, + language: str = "en-US", + detect_language: bool = True, + system_instructions: str = SYSTEM_INSTRUCTIONS, + model: LiveAPIModels = "gemini-2.0-flash-exp", + ): + """ + Create a new instance of Google Realtime STT. + """ + super().__init__( + capabilities=stt.STTCapabilities(streaming=False, interim_results=False) + ) + + self._config = STTOptions( + language=language, + model=model, + system_instructions=system_instructions, + detect_language=detect_language, + ) + self._api_key = api_key or os.getenv("GOOGLE_API_KEY") + self._client = genai.Client( + api_key=self._api_key, + ) + + async def _recognize_impl( + self, + buffer: utils.AudioBuffer, + *, + language: str | None, + conn_options: APIConnectOptions, + ) -> stt.SpeechEvent: + try: + data = rtc.combine_audio_frames(buffer).to_wav_bytes() + resp = await self._client.aio.models.generate_content( + model=self._config.model, + contents=types.Content( + parts=[ + types.Part( + text=self._config.system_instructions, + ), + types.Part( + inline_data=types.Blob( + mime_type="audio/wav", + data=data, + ) + ), + ], + ), + config=types.GenerationConfigDict( + response_mime_type="application/json", + response_schema={ + "required": [ + "transcribed_text", + "confidence_score", + "language", + ], + "properties": { + "transcribed_text": {"type": "STRING"}, + "confidence_score": {"type": "NUMBER"}, + "language": {"type": "STRING"}, + }, + "type": "OBJECT", + }, + ), + ) + resp = json.loads(resp.text) + return stt.SpeechEvent( + type=stt.SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[ + stt.SpeechData( + text=resp.get("transcribed_text") or "", + language=resp.get("language") or self._config.language, + ) + ], + ) + except Exception as e: + raise APIConnectionError() from e From 06d60baf2e2db3cbe81e9260c54d20ec22755088 Mon Sep 17 00:00:00 2001 From: jayesh Date: Thu, 2 Jan 2025 19:24:18 +0530 Subject: [PATCH 03/19] updates --- livekit-agents/livekit/agents/multimodal/multimodal_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index efab5757e..afc65cdee 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -179,7 +179,7 @@ def __init__( transcription=transcription, ) - if not stt.capabilities.streaming: + if stt is not None and not stt.capabilities.streaming: from .. import stt as speech_to_text stt = speech_to_text.StreamAdapter( From 48002617e74c1313c08c96d29c2b74682d1218ab Mon Sep 17 00:00:00 2001 From: jayesh Date: Thu, 2 Jan 2025 19:58:44 +0530 Subject: [PATCH 04/19] updates --- livekit-agents/livekit/agents/multimodal/agent_playout.py | 6 +++--- .../livekit/agents/multimodal/multimodal_agent.py | 4 ++-- .../livekit/plugins/google/beta/realtime/stt.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/livekit-agents/livekit/agents/multimodal/agent_playout.py b/livekit-agents/livekit/agents/multimodal/agent_playout.py index a01d5b0ea..4ce66cae1 100644 --- a/livekit-agents/livekit/agents/multimodal/agent_playout.py +++ b/livekit-agents/livekit/agents/multimodal/agent_playout.py @@ -74,8 +74,8 @@ def __init__( self, *, audio_source: rtc.AudioSource, - stt: stt.STT, - stt_forwarder: stt.STTForwarder, + stt: stt.STT | None, + stt_forwarder: transcription.STTSegmentsForwarder | None, ) -> None: super().__init__() self._source = audio_source @@ -180,7 +180,7 @@ async def _stt_stream_co() -> None: done, _ = await asyncio.wait( [asyncio.gather(*tasks), handle._int_fut], return_when=asyncio.FIRST_COMPLETED, - ) + ) # type: ignore handle._total_played_time = ( handle._pushed_duration - self._source.queued_duration diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index afc65cdee..2d24cfd85 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -179,7 +179,7 @@ def __init__( transcription=transcription, ) - if stt is not None and not stt.capabilities.streaming: + if stt and vad and not stt.capabilities.streaming: from .. import stt as speech_to_text stt = speech_to_text.StreamAdapter( @@ -488,7 +488,7 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None: self._subscribed_track = publication.track # type: ignore stream_24khz = rtc.AudioStream( self._subscribed_track, sample_rate=24000, num_channels=1 - ) + ) # type: ignore self._stt_forwarder = STTSegmentsForwarder( room=self._room, participant=self._linked_participant, diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py index 9153039a7..4b791f408 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py @@ -26,8 +26,8 @@ utils, ) -from google import genai -from google.genai import types +from google import genai # type: ignore +from google.genai import types # type: ignore from .api_proto import LiveAPIModels From 97f5040105f426551622f9fb010daaf8a986b56b Mon Sep 17 00:00:00 2001 From: jayesh Date: Thu, 2 Jan 2025 20:47:40 +0530 Subject: [PATCH 05/19] updates --- .../agents/multimodal/agent_playout.py | 6 ++-- .../agents/multimodal/multimodal_agent.py | 23 +++++++-------- .../plugins/google/beta/realtime/stt.py | 28 +++++-------------- 3 files changed, 20 insertions(+), 37 deletions(-) diff --git a/livekit-agents/livekit/agents/multimodal/agent_playout.py b/livekit-agents/livekit/agents/multimodal/agent_playout.py index 4ce66cae1..ba3b0405a 100644 --- a/livekit-agents/livekit/agents/multimodal/agent_playout.py +++ b/livekit-agents/livekit/agents/multimodal/agent_playout.py @@ -161,14 +161,14 @@ async def _capture_task(): await self._source.wait_for_playout() async def _stt_stream_co() -> None: - if stt_stream is not None: + if stt_stream and self._stt_forwarder is not None: async for ev in stt_stream: self._stt_forwarder.update(ev) if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: - self.emit("final_transcript", ev) + self.emit("final_transcript", ev.alternatives[0].text) elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: - self.emit("interim_transcript", ev) + self.emit("interim_transcript", ev.alternatives[0].text) read_text_task = asyncio.create_task(_play_text_stream()) diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index 2d24cfd85..ca11d729d 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -434,8 +434,8 @@ def _on_playout_stopped(interrupted: bool) -> None: self._emit_speech_committed("agent", collected_text, interrupted) - def _on_final_transcript(ev: stt.SpeechEvent): - self._emit_speech_committed("agent", ev.alternatives[0].text) + def _on_final_transcript(text: str): + self._emit_speech_committed("agent", text) self._agent_playout.on("playout_started", _on_playout_started) self._agent_playout.on("playout_stopped", _on_playout_stopped) @@ -486,9 +486,6 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None: and publication.track != self._subscribed_track ): self._subscribed_track = publication.track # type: ignore - stream_24khz = rtc.AudioStream( - self._subscribed_track, sample_rate=24000, num_channels=1 - ) # type: ignore self._stt_forwarder = STTSegmentsForwarder( room=self._room, participant=self._linked_participant, @@ -499,20 +496,20 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None: self._recognize_atask.cancel() self._recognize_atask = asyncio.create_task( - self._recognize_task(stream_24khz) + self._recognize_task(self._subscribed_track) # type: ignore ) break @utils.log_exceptions(logger=logger) - async def _recognize_task(self, audio_stream: rtc.AudioStream) -> None: + async def _recognize_task(self, track: rtc.LocalAudioTrack) -> None: """ Receive the frames from the user audio stream. """ - + stream_24khz = rtc.AudioStream(track, sample_rate=24000, num_channels=1) stt_stream = self._stt.stream() if self._stt is not None else None async def _micro_task() -> None: - async for ev in audio_stream: + async for ev in stream_24khz: if stt_stream is not None: stt_stream.push_frame(ev.frame) self._input_audio_ch.send_nowait(ev.frame) @@ -523,9 +520,9 @@ async def _stt_stream_co() -> None: self._stt_forwarder.update(ev) if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: - self.emit("final_transcript", ev) + self.emit("final_transcript", ev.alternatives[0].text) elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: - self.emit("interim_transcript", ev) + self.emit("interim_transcript", ev.alternatives[0].text) tasks = [ asyncio.create_task(_micro_task()), @@ -544,8 +541,8 @@ def _ensure_session(self) -> aiohttp.ClientSession: return self._http_session - def _on_final_transcript(self, ev: stt.SpeechEvent): - self._emit_speech_committed("user", ev.alternatives[0].text) + def _on_final_transcript(self, text: str): + self._emit_speech_committed("user", text) def _emit_speech_committed( self, speaker: Literal["user", "agent"], msg: str, interrupted: bool = False diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py index 4b791f408..d4b0dc395 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py @@ -36,27 +36,13 @@ SYSTEM_INSTRUCTIONS = """ You are an **Audio Transcriber**. Your task is to convert audio content into accurate and precise text. -**Guidelines:** - -1. **Transcription Only:** - - Transcribe spoken words exactly as they are. - - Exclude any non-speech sounds (e.g., background noise, music). - -2. **Response Format:** - - Provide only the transcription without any additional text or explanations. - - If the audio is unclear or inaudible, respond with: `...` - -3. **Accuracy:** - - Ensure the transcription is free from errors. - - Maintain the original meaning and context of the speech. - -4. **Clarity:** - - Use proper punctuation and formatting to enhance readability. - - Preserve the original speaker's intent and tone as much as possible. - -**Do Not:** -- Add any explanations, comments, or additional information. -- Include timestamps, speaker labels, or annotations unless specified. +- Transcribe verbatim; exclude non-speech sounds. +- Provide only transcription; no extra text or explanations. +- If audio is unclear, respond with: `...` +- Ensure error-free transcription, preserving meaning and context. +- Use proper punctuation and formatting. +- Do not add explanations, comments, or extra information. +- Do not include timestamps, speaker labels, or annotations unless specified. """ From f5249f430397ca23e3b58e1d155264cdfed5a579 Mon Sep 17 00:00:00 2001 From: jayesh Date: Thu, 2 Jan 2025 21:36:48 +0530 Subject: [PATCH 06/19] added testcase --- tests/test_stt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_stt.py b/tests/test_stt.py index d1f340b1e..104a908d3 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -29,6 +29,7 @@ ), pytest.param(lambda: openai.STT(), id="openai"), pytest.param(lambda: fal.WizperSTT(), id="fal"), + pytest.param(lambda: google.beta.realtime.STT(), id="google-realtime"), ] From 61da56a97b27d6993f067526dcc9d25172c5488a Mon Sep 17 00:00:00 2001 From: jayesh Date: Fri, 3 Jan 2025 17:41:40 +0530 Subject: [PATCH 07/19] updates --- .../google/beta/realtime/realtime_api.py | 27 +++++--- .../plugins/google/beta/realtime/stt.py | 65 ++++++++++++------- 2 files changed, 61 insertions(+), 31 deletions(-) diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py index 40bb0d7a1..86b09f93d 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py @@ -89,7 +89,7 @@ def __init__( voice: Voice | str = "Puck", modalities: ResponseModality = "AUDIO", vertexai: bool = False, - project: str | None = None, + project_id: str | None = None, location: str | None = None, candidate_count: int = 1, temperature: float | None = None, @@ -111,7 +111,7 @@ def __init__( voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck". temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8. vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False. - project (str or None, optional): The project to use for the API. Defaults to None. (for vertexai) + project_id (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai) location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai) candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1. top_p (float, optional): The top-p value for response generation @@ -130,21 +130,30 @@ def __init__( self._model = model self._loop = loop or asyncio.get_event_loop() self._api_key = api_key or os.environ.get("GOOGLE_API_KEY") - self._vertexai = vertexai - self._project_id = project or os.environ.get("GOOGLE_PROJECT") + self._project_id = project_id or os.environ.get("GOOGLE_PROJECT_ID") self._location = location or os.environ.get("GOOGLE_LOCATION") - if self._api_key is None and not self._vertexai: - raise ValueError("GOOGLE_API_KEY is not set") + if vertexai: + if not self._project_id or not self._location: + raise ValueError( + "Project and location are required for VertexAI either via project and location or GOOGLE_PROJECT_ID and GOOGLE_LOCATION environment variables" + ) + self._api_key = None # VertexAI does not require an API key + + else: + if not self._api_key: + raise ValueError( + "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable" + ) self._rt_sessions: list[GeminiRealtimeSession] = [] self._opts = ModelOptions( model=model, - api_key=api_key, + api_key=self._api_key, voice=voice, response_modalities=modalities, vertexai=vertexai, - project=project, - location=location, + project=self._project_id, + location=self._location, candidate_count=candidate_count, temperature=temperature, max_output_tokens=max_output_tokens, diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py index d4b0dc395..b292d3f14 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py @@ -49,7 +49,6 @@ @dataclass class STTOptions: language: str - detect_language: bool system_instructions: str model: LiveAPIModels @@ -60,52 +59,74 @@ def __init__( *, api_key: str | None = None, language: str = "en-US", - detect_language: bool = True, + vertexai: bool = False, + project_id: str | None = None, + location: str = "us-central1", system_instructions: str = SYSTEM_INSTRUCTIONS, model: LiveAPIModels = "gemini-2.0-flash-exp", ): """ - Create a new instance of Google Realtime STT. + Create a new instance of Google Realtime STT. you must provide either api_key or vertexai with project_id. api key and project id can be set via environment variables or via the arguments. + Args: + api_key (str, optional) : The API key to use for the API. + vertexai(bool, optional) : Whether to use VertexAI. + project_id(str, optional) : The project id to use for the vertex ai. + location (str, optional) : The location to use for the vertex ai. defaults to us-central1 + system_instructions (str, optional) : custom system instructions to use for the transcription. + language (str, optional) : The language of the audio. defaults to en-US + model (LiveAPIModels, optional) : The model to use for the transcription. defaults to gemini-2.0-flash-exp """ super().__init__( capabilities=stt.STTCapabilities(streaming=False, interim_results=False) ) self._config = STTOptions( - language=language, - model=model, - system_instructions=system_instructions, - detect_language=detect_language, + language=language, model=model, system_instructions=system_instructions ) - self._api_key = api_key or os.getenv("GOOGLE_API_KEY") + if vertexai: + self._project_id = project_id or os.getenv("GOOGLE_PROJECT_ID") + self._location = location or os.getenv("GOOGLE_LOCATION") + if not self._project_id or not self._location: + raise ValueError( + "Project and location are required for VertexAI either via project and location or GOOGLE_PROJECT_ID and GOOGLE_LOCATION environment variables" + ) + self._api_key = None # VertexAI does not require an API key + else: + self._api_key = api_key or os.getenv("GOOGLE_API_KEY") + if not self._api_key: + raise ValueError( + "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable" + ) self._client = genai.Client( api_key=self._api_key, + vertexai=vertexai, + project=self._project_id, + location=self._location, ) async def _recognize_impl( self, buffer: utils.AudioBuffer, *, - language: str | None, conn_options: APIConnectOptions, + language: str | None = None, ) -> stt.SpeechEvent: try: + instructions = self._config.system_instructions + if language: + instructions += "The language of the audio is " + language data = rtc.combine_audio_frames(buffer).to_wav_bytes() resp = await self._client.aio.models.generate_content( model=self._config.model, - contents=types.Content( - parts=[ - types.Part( - text=self._config.system_instructions, - ), - types.Part( - inline_data=types.Blob( - mime_type="audio/wav", - data=data, - ) - ), - ], - ), + contents=[ + types.Part(text=instructions), + types.Part( + inline_data=types.Blob( + mime_type="audio/wav", + data=data, + ) + ), + ], config=types.GenerationConfigDict( response_mime_type="application/json", response_schema={ From a88281b56327b14f53c2024fd4f3e0a592a039d1 Mon Sep 17 00:00:00 2001 From: jayesh Date: Sat, 11 Jan 2025 01:50:00 +0530 Subject: [PATCH 08/19] updates --- examples/multimodal-agent/gemini_agent.py | 7 +- .../agents/multimodal/agent_playout.py | 54 ++---- .../agents/multimodal/multimodal_agent.py | 97 +++-------- .../plugins/google/beta/realtime/__init__.py | 2 - .../plugins/google/beta/realtime/api_proto.py | 44 +++++ .../google/beta/realtime/realtime_api.py | 102 ++++++++++- .../plugins/google/beta/realtime/stt.py | 158 ------------------ .../google/beta/realtime/transcriber.py | 135 +++++++++++++++ 8 files changed, 315 insertions(+), 284 deletions(-) delete mode 100644 livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py create mode 100644 livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py diff --git a/examples/multimodal-agent/gemini_agent.py b/examples/multimodal-agent/gemini_agent.py index d597e48a9..55258bd09 100644 --- a/examples/multimodal-agent/gemini_agent.py +++ b/examples/multimodal-agent/gemini_agent.py @@ -14,7 +14,7 @@ llm, multimodal, ) -from livekit.plugins import google, silero +from livekit.plugins import google load_dotenv() @@ -57,11 +57,12 @@ async def get_weather( voice="Charon", temperature=0.8, instructions="You are a helpful assistant", + api_key=api_key, ), fnc_ctx=fnc_ctx, chat_ctx=chat_ctx, - stt=google.beta.realtime.STT(), - vad=silero.VAD.load(), + # stt=google.beta.realtime.STT(), + # vad=silero.VAD.load(), ) agent.start(ctx.room, participant) diff --git a/livekit-agents/livekit/agents/multimodal/agent_playout.py b/livekit-agents/livekit/agents/multimodal/agent_playout.py index ba3b0405a..df46b48f1 100644 --- a/livekit-agents/livekit/agents/multimodal/agent_playout.py +++ b/livekit-agents/livekit/agents/multimodal/agent_playout.py @@ -4,13 +4,11 @@ from typing import AsyncIterable, Literal from livekit import rtc -from livekit.agents import stt, transcription, utils +from livekit.agents import transcription, utils from ..log import logger -EventTypes = Literal[ - "playout_started", "playout_stopped", "final_transcript", "interim_transcript" -] +EventTypes = Literal["playout_started", "playout_stopped"] class PlayoutHandle: @@ -70,17 +68,9 @@ def interrupt(self) -> None: class AgentPlayout(utils.EventEmitter[EventTypes]): - def __init__( - self, - *, - audio_source: rtc.AudioSource, - stt: stt.STT | None, - stt_forwarder: transcription.STTSegmentsForwarder | None, - ) -> None: + def __init__(self, *, audio_source: rtc.AudioSource) -> None: super().__init__() self._source = audio_source - self._stt = stt - self._stt_forwarder = stt_forwarder self._playout_atask: asyncio.Task[None] | None = None def play( @@ -116,7 +106,6 @@ async def _playout_task( await utils.aio.gracefully_cancel(old_task) first_frame = True - stt_stream = self._stt.stream() if self._stt is not None else None @utils.log_exceptions(logger=logger) async def _play_text_stream(): @@ -145,54 +134,37 @@ async def _capture_task(): handle._tr_fwd.push_audio(frame) for f in bstream.write(frame.data.tobytes()): - if stt_stream is not None: - stt_stream.push_frame(f) handle._pushed_duration += f.samples_per_channel / f.sample_rate await self._source.capture_frame(f) for f in bstream.flush(): handle._pushed_duration += f.samples_per_channel / f.sample_rate - if stt_stream is not None: - stt_stream.push_frame(f) await self._source.capture_frame(f) handle._tr_fwd.mark_audio_segment_end() await self._source.wait_for_playout() - async def _stt_stream_co() -> None: - if stt_stream and self._stt_forwarder is not None: - async for ev in stt_stream: - self._stt_forwarder.update(ev) - - if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: - self.emit("final_transcript", ev.alternatives[0].text) - elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: - self.emit("interim_transcript", ev.alternatives[0].text) - read_text_task = asyncio.create_task(_play_text_stream()) + capture_task = asyncio.create_task(_capture_task()) - tasks = [ - asyncio.create_task(_capture_task()), - asyncio.create_task(_stt_stream_co()), - ] try: - done, _ = await asyncio.wait( - [asyncio.gather(*tasks), handle._int_fut], + await asyncio.wait( + [capture_task, handle._int_fut], return_when=asyncio.FIRST_COMPLETED, - ) # type: ignore + ) + + finally: + await utils.aio.gracefully_cancel(capture_task) handle._total_played_time = ( handle._pushed_duration - self._source.queued_duration ) - for task in done: - if handle.interrupted or task.exception(): - self._source.clear_queue() # make sure to remove any queued frames - break + if handle.interrupted or capture_task.exception(): + self._source.clear_queue() # make sure to remove any queued frames - finally: - await utils.aio.gracefully_cancel(*tasks, read_text_task) + await utils.aio.gracefully_cancel(read_text_task) # make sure the text_data.sentence_stream is closed handle._tr_fwd.mark_text_segment_end() diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index ca11d729d..2529a1255 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -16,10 +16,9 @@ import aiohttp from livekit import rtc -from livekit.agents import llm, stt, tokenize, utils, vad +from livekit.agents import llm, stt, tokenize, transcription, utils, vad from livekit.agents.llm import ChatMessage from livekit.agents.metrics import MultimodalLLMMetrics -from livekit.agents.transcription import STTSegmentsForwarder, TTSSegmentsForwarder from ..log import logger from ..types import ATTRIBUTE_AGENT_STATE, AgentState @@ -36,8 +35,6 @@ "function_calls_collected", "function_calls_finished", "metrics_collected", - "final_transcript", - "interim_transcript", ] @@ -146,7 +143,6 @@ def __init__( vad: vad.VAD | None = None, chat_ctx: llm.ChatContext | None = None, fnc_ctx: llm.FunctionContext | None = None, - stt: stt.STT | None = None, transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(), max_text_response_retries: int = 5, loop: asyncio.AbstractEventLoop | None = None, @@ -179,18 +175,8 @@ def __init__( transcription=transcription, ) - if stt and vad and not stt.capabilities.streaming: - from .. import stt as speech_to_text - - stt = speech_to_text.StreamAdapter( - stt=stt, - vad=vad, - ) - self._stt = stt - self.on("final_transcript", self._on_final_transcript) - # audio input - self._recognize_atask: asyncio.Task | None = None + self._read_micro_atask: asyncio.Task | None = None self._subscribed_track: rtc.RemoteAudioTrack | None = None self._input_audio_ch = utils.aio.Chan[rtc.AudioFrame]() @@ -210,10 +196,6 @@ def __init__( def vad(self) -> vad.VAD | None: return self._vad - @property - def stt(self) -> stt.STT | None: - return self._stt - @property def fnc_ctx(self) -> llm.FunctionContext | None: return self._session.fnc_ctx @@ -270,7 +252,7 @@ async def _init_and_start(): @self._session.on("response_content_added") def _on_content_added(message: _ContentProto): - tr_fwd = TTSSegmentsForwarder( + tr_fwd = transcription.TTSSegmentsForwarder( room=self._room, participant=self._room.local_participant, speed=self._opts.transcription.agent_transcription_speed, @@ -339,6 +321,16 @@ def _input_speech_transcription_completed(ev: _InputTranscriptionProto): self._emit_speech_committed("user", ev.transcript) + @self._session.on("agent_speech_transcription_completed") + def _agent_speech_transcription_completed(ev: _InputTranscriptionProto): + self._agent_stt_forwarder.update( + stt.SpeechEvent( + type=stt.SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[stt.SpeechData(language="", text=ev.transcript)], + ) + ) + self._emit_speech_committed("agent", ev.transcript) + @self._session.on("input_speech_started") def _input_speech_started(): self.emit("user_started_speaking") @@ -398,15 +390,13 @@ async def _main_task(self) -> None: self._agent_publication = await self._room.local_participant.publish_track( track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) ) - self._agent_stt_forwarder = STTSegmentsForwarder( + self._agent_stt_forwarder = transcription.STTSegmentsForwarder( room=self._room, participant=self._room.local_participant, track=track, ) self._agent_playout = agent_playout.AgentPlayout( - audio_source=self._audio_source, - stt=self._stt, - stt_forwarder=self._agent_stt_forwarder, + audio_source=self._audio_source ) def _on_playout_started() -> None: @@ -434,12 +424,8 @@ def _on_playout_stopped(interrupted: bool) -> None: self._emit_speech_committed("agent", collected_text, interrupted) - def _on_final_transcript(text: str): - self._emit_speech_committed("agent", text) - self._agent_playout.on("playout_started", _on_playout_started) self._agent_playout.on("playout_stopped", _on_playout_stopped) - self._agent_playout.on("final_transcript", _on_final_transcript) await self._agent_publication.wait_for_subscription() @@ -468,6 +454,11 @@ def _link_participant(self, participant_identity: str) -> None: self._subscribe_to_microphone() + async def _micro_task(self, track: rtc.LocalAudioTrack) -> None: + stream_24khz = rtc.AudioStream(track, sample_rate=24000, num_channels=1) + async for ev in stream_24khz: + self._input_audio_ch.send_nowait(ev.frame) + def _subscribe_to_microphone(self, *args, **kwargs) -> None: """Subscribe to the participant microphone if found""" @@ -486,64 +477,26 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None: and publication.track != self._subscribed_track ): self._subscribed_track = publication.track # type: ignore - self._stt_forwarder = STTSegmentsForwarder( + self._stt_forwarder = transcription.STTSegmentsForwarder( room=self._room, participant=self._linked_participant, track=self._subscribed_track, ) - if self._recognize_atask is not None: - self._recognize_atask.cancel() + if self._read_micro_atask is not None: + self._read_micro_atask.cancel() - self._recognize_atask = asyncio.create_task( - self._recognize_task(self._subscribed_track) # type: ignore + self._read_micro_atask = asyncio.create_task( + self._micro_task(self._subscribed_track) # type: ignore ) break - @utils.log_exceptions(logger=logger) - async def _recognize_task(self, track: rtc.LocalAudioTrack) -> None: - """ - Receive the frames from the user audio stream. - """ - stream_24khz = rtc.AudioStream(track, sample_rate=24000, num_channels=1) - stt_stream = self._stt.stream() if self._stt is not None else None - - async def _micro_task() -> None: - async for ev in stream_24khz: - if stt_stream is not None: - stt_stream.push_frame(ev.frame) - self._input_audio_ch.send_nowait(ev.frame) - - async def _stt_stream_co() -> None: - if stt_stream is not None: - async for ev in stt_stream: - self._stt_forwarder.update(ev) - - if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: - self.emit("final_transcript", ev.alternatives[0].text) - elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: - self.emit("interim_transcript", ev.alternatives[0].text) - - tasks = [ - asyncio.create_task(_micro_task()), - asyncio.create_task(_stt_stream_co()), - ] - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) - if stt_stream is not None: - await stt_stream.aclose() - def _ensure_session(self) -> aiohttp.ClientSession: if not self._http_session: self._http_session = utils.http_context.http_session() return self._http_session - def _on_final_transcript(self, text: str): - self._emit_speech_committed("user", text) - def _emit_speech_committed( self, speaker: Literal["user", "agent"], msg: str, interrupted: bool = False ): diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py index a4c7b215c..e95a86917 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py @@ -5,7 +5,6 @@ Voice, ) from .realtime_api import RealtimeModel -from .stt import STT __all__ = [ "RealtimeModel", @@ -13,5 +12,4 @@ "LiveAPIModels", "ResponseModality", "Voice", - "STT", ] diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py index c02fb3859..f79569df4 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py @@ -1,10 +1,22 @@ from __future__ import annotations import inspect +import json from typing import Any, Dict, List, Literal, Sequence, Union +from livekit.agents import llm + from google.genai import types # type: ignore +__all__ = [ + "ClientEvents", + "LiveAPIModels", + "ResponseModality", + "Voice", + "_build_gemini_ctx", + "_build_tools", +] + LiveAPIModels = Literal["gemini-2.0-flash-exp"] Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"] @@ -77,3 +89,35 @@ def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclarationDict]: function_declarations.append(func_decl) return function_declarations + + +def _build_gemini_ctx(chat_ctx: llm.ChatContext) -> List[types.Content]: + content = None + turns = [] + + for msg in chat_ctx.messages: + role = None + if msg.role == "assistant": + role = "model" + elif msg.role in {"system", "user"}: + role = "user" + elif msg.role == "tool": + continue + + if content and content.role == role: + if isinstance(msg.content, str): + content.parts.append(types.Part(text=msg.content)) + elif isinstance(msg.content, dict): + content.parts.append(types.Part(text=json.dumps(msg.content))) + elif isinstance(msg.content, list): + for item in msg.content: + if isinstance(item, str): + content.parts.append(types.Part(text=item)) + else: + content = types.Content( + parts=[types.Part(text=msg.content)], + role=role, + ) + turns.append(content) + + return turns diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py index 86b09f93d..d3e66a92d 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py @@ -15,6 +15,7 @@ from google.genai.types import ( # type: ignore FunctionResponse, GenerationConfigDict, + LiveClientContent, LiveClientToolResponse, LiveConnectConfigDict, PrebuiltVoiceConfig, @@ -28,8 +29,10 @@ LiveAPIModels, ResponseModality, Voice, + _build_gemini_ctx, _build_tools, ) +from .transcriber import TranscriberSession, TranscriptionContent EventTypes = Literal[ "start_session", @@ -39,6 +42,9 @@ "function_calls_collected", "function_calls_finished", "function_calls_cancelled", + "input_speech_transcription_completed", + "agent_speech_transcription_completed", + "agent_speech_transcription_interrupted", ] @@ -55,6 +61,12 @@ class GeminiContent: content_type: Literal["text", "audio"] +@dataclass +class InputTranscription: + item_id: str + transcript: str + + @dataclass class Capabilities: supports_truncate: bool @@ -77,6 +89,7 @@ class ModelOptions: presence_penalty: float | None frequency_penalty: float | None instructions: str + enable_transcription: bool class RealtimeModel: @@ -88,6 +101,7 @@ def __init__( api_key: str | None = None, voice: Voice | str = "Puck", modalities: ResponseModality = "AUDIO", + enable_transcription: bool = True, vertexai: bool = False, project_id: str | None = None, location: str | None = None, @@ -109,6 +123,7 @@ def __init__( modalities (ResponseModality): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"]. model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp". voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck". + enable_transcription (bool, optional): Whether to enable transcription. Defaults to True temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8. vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False. project_id (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai) @@ -150,6 +165,7 @@ def __init__( model=model, api_key=self._api_key, voice=voice, + enable_transcription=enable_transcription, response_modalities=modalities, vertexai=vertexai, project=self._project_id, @@ -224,7 +240,6 @@ def __init__( tools.append({"function_declarations": functions}) self._config = LiveConnectConfigDict( - model=self._opts.model, response_modalities=self._opts.response_modalities, generation_config=GenerationConfigDict( candidate_count=self._opts.candidate_count, @@ -255,10 +270,20 @@ def __init__( self._main_atask = asyncio.create_task( self._main_task(), name="gemini-realtime-session" ) - # dummy task to wait for the session to be initialized # TODO: sync chat ctx - self._init_sync_task = asyncio.create_task( - asyncio.sleep(0), name="gemini-realtime-session-init" - ) + if self._opts.enable_transcription: + self._transcriber = TranscriberSession( + client=self._client, model=self._opts.model + ) + self._agent_transcriber = TranscriberSession( + client=self._client, model=self._opts.model + ) + self._transcriber.on("input_speech_done", self._on_input_speech_done) + self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done) + self._agent_transcriber.on( + "input_speech_interrupted", self._on_agent_speech_interrupted + ) + # init dummy task + self._init_sync_task = asyncio.create_task(asyncio.sleep(0)) self._send_ch = utils.aio.Chan[ClientEvents]() self._active_response_id = None @@ -277,19 +302,78 @@ def fnc_ctx(self) -> llm.FunctionContext | None: def fnc_ctx(self, value: llm.FunctionContext | None) -> None: self._fnc_ctx = value + def _update_conversation_item_content(self, item_id: str, content: str) -> None: + pass + def _push_audio(self, frame: rtc.AudioFrame) -> None: - data = base64.b64encode(frame.data).decode("utf-8") - self._queue_msg({"mime_type": "audio/pcm", "data": data}) + if self._opts.enable_transcription: + self._transcriber._push_audio(frame) + else: + data = base64.b64encode(frame.data).decode("utf-8") + self._queue_msg({"mime_type": "audio/pcm", "data": data}) - def _queue_msg(self, msg: dict) -> None: + def _queue_msg(self, msg: ClientEvents) -> None: self._send_ch.send_nowait(msg) + def create_conversation( + self, chat_ctx: llm.ChatContext | llm.ChatMessage, turn_complete: bool = True + ) -> None: + if isinstance(chat_ctx, llm.ChatMessage): + new_chat_ctx = llm.ChatContext() + new_chat_ctx.append(text=chat_ctx.content, role=chat_ctx.role) + else: + new_chat_ctx = chat_ctx + gemini_ctx = _build_gemini_ctx(new_chat_ctx) + client_content = LiveClientContent( + turn_complete=turn_complete, + turns=gemini_ctx, + ) + self._queue_msg(client_content) + def chat_ctx_copy(self) -> llm.ChatContext: return self._chat_ctx.copy() async def set_chat_ctx(self, ctx: llm.ChatContext) -> None: self._chat_ctx = ctx.copy() + def _on_input_speech_done(self, content: TranscriptionContent) -> None: + self.emit( + "input_speech_transcription_completed", + InputTranscription( + item_id=content.response_id, + transcript=content.text, + ), + ) + + self._chat_ctx.append(text=content.text, role="user") + conversation = _build_gemini_ctx(self._chat_ctx) + + client_content = LiveClientContent( + turn_complete=True, + turns=conversation, + ) + self._queue_msg(client_content) + + def _on_agent_speech_done(self, content: TranscriptionContent) -> None: + self.emit( + "agent_speech_transcription_completed", + InputTranscription( + item_id=content.response_id, + transcript=content.text, + ), + ) + self._chat_ctx.append(text=content.text, role="assistant") + + def _on_agent_speech_interrupted(self, content: TranscriptionContent) -> None: + self.emit( + "agent_speech_transcription_completed", + InputTranscription( + item_id=content.response_id, + transcript=content.text, + ), + ) + self._chat_ctx.append(text=content.text, role="assistant") + @utils.log_exceptions(logger=logger) async def _main_task(self): @utils.log_exceptions(logger=logger) @@ -335,6 +419,8 @@ async def _recv_task(): samples_per_channel=len(part.inline_data.data) // 2, ) + if self._opts.enable_transcription: + self._agent_transcriber._push_audio(frame) content.audio_stream.send_nowait(frame) if server_content.interrupted or server_content.turn_complete: diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py deleted file mode 100644 index b292d3f14..000000000 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/stt.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2023 LiveKit, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import json -import os -from dataclasses import dataclass - -from livekit import rtc -from livekit.agents import ( - APIConnectionError, - APIConnectOptions, - stt, - utils, -) - -from google import genai # type: ignore -from google.genai import types # type: ignore - -from .api_proto import LiveAPIModels - -SAMPLE_RATE = 16000 - -SYSTEM_INSTRUCTIONS = """ -You are an **Audio Transcriber**. Your task is to convert audio content into accurate and precise text. - -- Transcribe verbatim; exclude non-speech sounds. -- Provide only transcription; no extra text or explanations. -- If audio is unclear, respond with: `...` -- Ensure error-free transcription, preserving meaning and context. -- Use proper punctuation and formatting. -- Do not add explanations, comments, or extra information. -- Do not include timestamps, speaker labels, or annotations unless specified. -""" - - -@dataclass -class STTOptions: - language: str - system_instructions: str - model: LiveAPIModels - - -class STT(stt.STT): - def __init__( - self, - *, - api_key: str | None = None, - language: str = "en-US", - vertexai: bool = False, - project_id: str | None = None, - location: str = "us-central1", - system_instructions: str = SYSTEM_INSTRUCTIONS, - model: LiveAPIModels = "gemini-2.0-flash-exp", - ): - """ - Create a new instance of Google Realtime STT. you must provide either api_key or vertexai with project_id. api key and project id can be set via environment variables or via the arguments. - Args: - api_key (str, optional) : The API key to use for the API. - vertexai(bool, optional) : Whether to use VertexAI. - project_id(str, optional) : The project id to use for the vertex ai. - location (str, optional) : The location to use for the vertex ai. defaults to us-central1 - system_instructions (str, optional) : custom system instructions to use for the transcription. - language (str, optional) : The language of the audio. defaults to en-US - model (LiveAPIModels, optional) : The model to use for the transcription. defaults to gemini-2.0-flash-exp - """ - super().__init__( - capabilities=stt.STTCapabilities(streaming=False, interim_results=False) - ) - - self._config = STTOptions( - language=language, model=model, system_instructions=system_instructions - ) - if vertexai: - self._project_id = project_id or os.getenv("GOOGLE_PROJECT_ID") - self._location = location or os.getenv("GOOGLE_LOCATION") - if not self._project_id or not self._location: - raise ValueError( - "Project and location are required for VertexAI either via project and location or GOOGLE_PROJECT_ID and GOOGLE_LOCATION environment variables" - ) - self._api_key = None # VertexAI does not require an API key - else: - self._api_key = api_key or os.getenv("GOOGLE_API_KEY") - if not self._api_key: - raise ValueError( - "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable" - ) - self._client = genai.Client( - api_key=self._api_key, - vertexai=vertexai, - project=self._project_id, - location=self._location, - ) - - async def _recognize_impl( - self, - buffer: utils.AudioBuffer, - *, - conn_options: APIConnectOptions, - language: str | None = None, - ) -> stt.SpeechEvent: - try: - instructions = self._config.system_instructions - if language: - instructions += "The language of the audio is " + language - data = rtc.combine_audio_frames(buffer).to_wav_bytes() - resp = await self._client.aio.models.generate_content( - model=self._config.model, - contents=[ - types.Part(text=instructions), - types.Part( - inline_data=types.Blob( - mime_type="audio/wav", - data=data, - ) - ), - ], - config=types.GenerationConfigDict( - response_mime_type="application/json", - response_schema={ - "required": [ - "transcribed_text", - "confidence_score", - "language", - ], - "properties": { - "transcribed_text": {"type": "STRING"}, - "confidence_score": {"type": "NUMBER"}, - "language": {"type": "STRING"}, - }, - "type": "OBJECT", - }, - ), - ) - resp = json.loads(resp.text) - return stt.SpeechEvent( - type=stt.SpeechEventType.FINAL_TRANSCRIPT, - alternatives=[ - stt.SpeechData( - text=resp.get("transcribed_text") or "", - language=resp.get("language") or self._config.language, - ) - ], - ) - except Exception as e: - raise APIConnectionError() from e diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py new file mode 100644 index 000000000..e17a7908b --- /dev/null +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import asyncio +import base64 +from dataclasses import dataclass +from typing import Literal + +from livekit import rtc +from livekit.agents import utils + +from google import genai # type: ignore +from google.genai import types # type: ignore + +from ...log import logger +from .api_proto import ClientEvents, LiveAPIModels + +EventTypes = Literal[ + "input_speech_started", + "input_speech_interrupted", + "input_speech_done", +] + +SYSTEM_INSTRUCTIONS = """ +You are an **Audio Transcriber**. Your task is to convert audio content into accurate and precise text. + +- Transcribe verbatim; exclude non-speech sounds. +- Provide only transcription; no extra text or explanations. +- If audio is unclear, respond with: `...` +- Ensure error-free transcription, preserving meaning and context. +- Use proper punctuation and formatting. +- Do not add explanations, comments, or extra information. +- Do not include timestamps, speaker labels, or annotations unless specified. +""" + + +@dataclass +class TranscriptionContent: + response_id: str + text: str + + +class TranscriberSession(utils.EventEmitter[EventTypes]): + def __init__( + self, + *, + client: genai.Client, + model: LiveAPIModels, + ): + """ + Initializes a TranscriberSession instance for interacting with Google's Realtime API. + """ + super().__init__() + self._client = client + self._model = model + + self._config = types.LiveConnectConfigDict( + response_modalities="TEXT", + system_instruction=SYSTEM_INSTRUCTIONS, + generation_config=types.GenerationConfigDict( + temperature=0.0, + ), + ) + self._main_atask = asyncio.create_task( + self._main_task(), name="gemini-realtime-transcriber" + ) + self._send_ch = utils.aio.Chan[ClientEvents]() + self._active_response_id = None + + def _push_audio(self, frame: rtc.AudioFrame) -> None: + data = base64.b64encode(frame.data).decode("utf-8") + self._queue_msg({"mime_type": "audio/pcm", "data": data}) + + def _queue_msg(self, msg: dict) -> None: + self._send_ch.send_nowait(msg) + + async def aclose(self) -> None: + if self._send_ch.closed: + return + + self._send_ch.close() + await self._main_atask + + @utils.log_exceptions(logger=logger) + async def _main_task(self): + @utils.log_exceptions(logger=logger) + async def _send_task(): + async for msg in self._send_ch: + await self._session.send(msg) + + @utils.log_exceptions(logger=logger) + async def _recv_task(): + while True: + async for response in self._session.receive(): + if self._active_response_id is None: + self._active_response_id = utils.shortuuid() + content = TranscriptionContent( + response_id=self._active_response_id, + text="", + ) + self.emit("input_speech_started", content) + + server_content = response.server_content + if server_content: + model_turn = server_content.model_turn + if model_turn: + for part in model_turn.parts: + if part.text: + content.text += part.text + + if server_content.interrupted or server_content.turn_complete: + if server_content.interrupted: + self.emit("input_speech_interrupted", content) + elif server_content.turn_complete: + self.emit("input_speech_done", content) + + self._active_response_id = None + + async with self._client.aio.live.connect( + model=self._model, config=self._config + ) as session: + self._session = session + tasks = [ + asyncio.create_task( + _send_task(), name="gemini-realtime-transcriber-send" + ), + asyncio.create_task( + _recv_task(), name="gemini-realtime-transcriber-recv" + ), + ] + + try: + await asyncio.gather(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) + await self._session.close() From 3ac01ac69eba5a3844b7dca617488ee3611afd72 Mon Sep 17 00:00:00 2001 From: jayesh Date: Sat, 11 Jan 2025 01:53:05 +0530 Subject: [PATCH 09/19] updates --- examples/multimodal-agent/gemini_agent.py | 71 ----------------------- 1 file changed, 71 deletions(-) delete mode 100644 examples/multimodal-agent/gemini_agent.py diff --git a/examples/multimodal-agent/gemini_agent.py b/examples/multimodal-agent/gemini_agent.py deleted file mode 100644 index 55258bd09..000000000 --- a/examples/multimodal-agent/gemini_agent.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Annotated - -import aiohttp -from dotenv import load_dotenv -from livekit.agents import ( - AutoSubscribe, - JobContext, - WorkerOptions, - WorkerType, - cli, - llm, - multimodal, -) -from livekit.plugins import google - -load_dotenv() - -logger = logging.getLogger("my-worker") -logger.setLevel(logging.INFO) - - -async def entrypoint(ctx: JobContext): - logger.info("starting entrypoint") - - fnc_ctx = llm.FunctionContext() - - @fnc_ctx.ai_callable() - async def get_weather( - location: Annotated[ - str, llm.TypeInfo(description="The location to get the weather for") - ], - ): - """Called when the user asks about the weather. This function will return the weather for the given location.""" - logger.info(f"getting weather for {location}") - url = f"https://wttr.in/{location}?format=%C+%t" - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - weather_data = await response.text() - # # response from the function call is returned to the LLM - return f"The weather in {location} is {weather_data}." - else: - raise Exception( - f"Failed to get weather data, status code: {response.status}" - ) - - await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) - participant = await ctx.wait_for_participant() - - chat_ctx = llm.ChatContext() - - agent = multimodal.MultimodalAgent( - model=google.beta.realtime.RealtimeModel( - voice="Charon", - temperature=0.8, - instructions="You are a helpful assistant", - api_key=api_key, - ), - fnc_ctx=fnc_ctx, - chat_ctx=chat_ctx, - # stt=google.beta.realtime.STT(), - # vad=silero.VAD.load(), - ) - agent.start(ctx.room, participant) - - -if __name__ == "__main__": - cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM)) From aee4c1c3934e4dc12308fd7ad5b226e2ae33cb93 Mon Sep 17 00:00:00 2001 From: jayesh Date: Sat, 11 Jan 2025 02:00:54 +0530 Subject: [PATCH 10/19] updates --- tests/test_stt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_stt.py b/tests/test_stt.py index 104a908d3..d1f340b1e 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -29,7 +29,6 @@ ), pytest.param(lambda: openai.STT(), id="openai"), pytest.param(lambda: fal.WizperSTT(), id="fal"), - pytest.param(lambda: google.beta.realtime.STT(), id="google-realtime"), ] From f42cbe19222d6f6f79374a104e8b262bc82db74a Mon Sep 17 00:00:00 2001 From: jayesh Date: Sat, 11 Jan 2025 04:47:49 +0530 Subject: [PATCH 11/19] updates --- examples/multimodal-agent/gemini_agent.py | 77 +++++++++++++++++++ .../agents/multimodal/agent_playout.py | 1 - 2 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 examples/multimodal-agent/gemini_agent.py diff --git a/examples/multimodal-agent/gemini_agent.py b/examples/multimodal-agent/gemini_agent.py new file mode 100644 index 000000000..eceb8f99c --- /dev/null +++ b/examples/multimodal-agent/gemini_agent.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import logging +from typing import Annotated + +import aiohttp +from dotenv import load_dotenv +from livekit.agents import ( + AutoSubscribe, + JobContext, + WorkerOptions, + WorkerType, + cli, + llm, + multimodal, +) +from livekit.plugins import google + +load_dotenv() + +logger = logging.getLogger("my-worker") +logger.setLevel(logging.INFO) + + +async def entrypoint(ctx: JobContext): + logger.info("starting entrypoint") + + fnc_ctx = llm.FunctionContext() + + @fnc_ctx.ai_callable() + async def get_weather( + location: Annotated[ + str, llm.TypeInfo(description="The location to get the weather for") + ], + ): + """Called when the user asks about the weather. This function will return the weather for the given location.""" + logger.info(f"getting weather for {location}") + url = f"https://wttr.in/{location}?format=%C+%t" + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + weather_data = await response.text() + # # response from the function call is returned to the LLM + return f"The weather in {location} is {weather_data}." + else: + raise Exception( + f"Failed to get weather data, status code: {response.status}" + ) + + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + participant = await ctx.wait_for_participant() + + # chat_ctx is used to serve as initial context, Agent will start the conversation first if chat_ctx is provided + chat_ctx = llm.ChatContext() + chat_ctx.append(text="What is LiveKit?", role="user") + chat_ctx.append(text="LiveKit is the platform for building realtime AI. The main use cases are to build AI voice agents. LiveKit also powers livestreaming apps, robotics, and video conferencing.", role="assistant") + chat_ctx.append(text="What is the LiveKit Agents framework?", role="user") + + agent = multimodal.MultimodalAgent( + model=google.beta.realtime.RealtimeModel( + voice="Charon", + temperature=0.8, + instructions=""" + You are a helpful assistant + Here are some helpful information about LiveKit and its products and services: + - LiveKit is the platform for building realtime AI. The main use cases are to build AI voice agents. LiveKit also powers livestreaming apps, robotics, and video conferencing. + - LiveKit provides an Agents framework for building server-side AI agents, client SDKs for building frontends, and LiveKit Cloud is a global network that transports voice, video, and data traffic in realtime. + """, + ), + fnc_ctx=fnc_ctx, + chat_ctx=chat_ctx, + ) + agent.start(ctx.room, participant) + + +if __name__ == "__main__": + cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM)) \ No newline at end of file diff --git a/livekit-agents/livekit/agents/multimodal/agent_playout.py b/livekit-agents/livekit/agents/multimodal/agent_playout.py index df46b48f1..f1dbda1e7 100644 --- a/livekit-agents/livekit/agents/multimodal/agent_playout.py +++ b/livekit-agents/livekit/agents/multimodal/agent_playout.py @@ -153,7 +153,6 @@ async def _capture_task(): [capture_task, handle._int_fut], return_when=asyncio.FIRST_COMPLETED, ) - finally: await utils.aio.gracefully_cancel(capture_task) From 99c4fa7f35b07931d830c8018ba45499c740d3b7 Mon Sep 17 00:00:00 2001 From: jayesh Date: Sat, 11 Jan 2025 05:05:46 +0530 Subject: [PATCH 12/19] updates --- .../livekit/plugins/google/beta/realtime/realtime_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py index d3e66a92d..e39ceef5e 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py @@ -286,6 +286,8 @@ def __init__( self._init_sync_task = asyncio.create_task(asyncio.sleep(0)) self._send_ch = utils.aio.Chan[ClientEvents]() self._active_response_id = None + if chat_ctx: + self.create_conversation(chat_ctx) async def aclose(self) -> None: if self._send_ch.closed: From 39110370f452c13509da319e597a9504ca1e05a2 Mon Sep 17 00:00:00 2001 From: jayesh Date: Sat, 11 Jan 2025 15:50:12 +0530 Subject: [PATCH 13/19] addressing sdk update --- .../livekit/plugins/google/beta/realtime/realtime_api.py | 2 +- .../livekit/plugins/google/beta/realtime/transcriber.py | 2 +- livekit-plugins/livekit-plugins-google/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py index e39ceef5e..2e706f35b 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py @@ -381,7 +381,7 @@ async def _main_task(self): @utils.log_exceptions(logger=logger) async def _send_task(): async for msg in self._send_ch: - await self._session.send(msg) + await self._session.send(input=msg) await self._session.send(".", end_of_turn=True) diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py index e17a7908b..89bc243dc 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py @@ -85,7 +85,7 @@ async def _main_task(self): @utils.log_exceptions(logger=logger) async def _send_task(): async for msg in self._send_ch: - await self._session.send(msg) + await self._session.send(input=msg) @utils.log_exceptions(logger=logger) async def _recv_task(): diff --git a/livekit-plugins/livekit-plugins-google/setup.py b/livekit-plugins/livekit-plugins-google/setup.py index 0db8addce..b3f6adf0d 100644 --- a/livekit-plugins/livekit-plugins-google/setup.py +++ b/livekit-plugins/livekit-plugins-google/setup.py @@ -51,7 +51,7 @@ "google-auth >= 2, < 3", "google-cloud-speech >= 2, < 3", "google-cloud-texttospeech >= 2, < 3", - "google-genai >= 0.3.0", + "google-genai == 0.4.0", "livekit-agents>=0.12.3", ], package_data={"livekit.plugins.google": ["py.typed"]}, From 7eb6766b520ae70f8671225270e7b3d754e30ee8 Mon Sep 17 00:00:00 2001 From: jayesh Date: Sat, 11 Jan 2025 16:10:07 +0530 Subject: [PATCH 14/19] ruff --- examples/multimodal-agent/gemini_agent.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/multimodal-agent/gemini_agent.py b/examples/multimodal-agent/gemini_agent.py index eceb8f99c..dbee55618 100644 --- a/examples/multimodal-agent/gemini_agent.py +++ b/examples/multimodal-agent/gemini_agent.py @@ -53,7 +53,10 @@ async def get_weather( # chat_ctx is used to serve as initial context, Agent will start the conversation first if chat_ctx is provided chat_ctx = llm.ChatContext() chat_ctx.append(text="What is LiveKit?", role="user") - chat_ctx.append(text="LiveKit is the platform for building realtime AI. The main use cases are to build AI voice agents. LiveKit also powers livestreaming apps, robotics, and video conferencing.", role="assistant") + chat_ctx.append( + text="LiveKit is the platform for building realtime AI. The main use cases are to build AI voice agents. LiveKit also powers livestreaming apps, robotics, and video conferencing.", + role="assistant", + ) chat_ctx.append(text="What is the LiveKit Agents framework?", role="user") agent = multimodal.MultimodalAgent( @@ -74,4 +77,4 @@ async def get_weather( if __name__ == "__main__": - cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM)) \ No newline at end of file + cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM)) From e8617e8937c89fe50480d41e1b1aed3ddd05f1b5 Mon Sep 17 00:00:00 2001 From: jayesh Date: Sun, 12 Jan 2025 19:47:51 +0530 Subject: [PATCH 15/19] updates --- examples/multimodal-agent/gemini_agent.py | 2 +- .../agents/multimodal/multimodal_agent.py | 25 +++++--- .../google/beta/realtime/realtime_api.py | 60 ++++++------------- .../google/beta/realtime/transcriber.py | 9 +-- 4 files changed, 37 insertions(+), 59 deletions(-) diff --git a/examples/multimodal-agent/gemini_agent.py b/examples/multimodal-agent/gemini_agent.py index dbee55618..0b7a191d6 100644 --- a/examples/multimodal-agent/gemini_agent.py +++ b/examples/multimodal-agent/gemini_agent.py @@ -61,7 +61,7 @@ async def get_weather( agent = multimodal.MultimodalAgent( model=google.beta.realtime.RealtimeModel( - voice="Charon", + voice="Puck", temperature=0.8, instructions=""" You are a helpful assistant diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index 2529a1255..a9eb1eedb 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -311,13 +311,14 @@ def _input_speech_transcription_completed(ev: _InputTranscriptionProto): alternatives=[stt.SpeechData(language="", text=ev.transcript)], ) ) - user_msg = ChatMessage.create( - text=ev.transcript, role="user", id=ev.item_id - ) + if self._model.capabilities.supports_truncate: + user_msg = ChatMessage.create( + text=ev.transcript, role="user", id=ev.item_id + ) - self._session._update_conversation_item_content( - ev.item_id, user_msg.content - ) + self._session._update_conversation_item_content( + ev.item_id, user_msg.content + ) self._emit_speech_committed("user", ev.transcript) @@ -331,6 +332,12 @@ def _agent_speech_transcription_completed(ev: _InputTranscriptionProto): ) self._emit_speech_committed("agent", ev.transcript) + @self._session.on("agent_speech_completed") + def _agent_speech_completed(): + self._update_state("listening") + if self._playing_handle is not None and not self._playing_handle.done(): + self._playing_handle.interrupt() + @self._session.on("input_speech_started") def _input_speech_started(): self.emit("user_started_speaking") @@ -371,9 +378,9 @@ async def _run_task(delay: float) -> None: await asyncio.sleep(delay) if self._room.isconnected(): - await self._room.local_participant.set_attributes( - {ATTRIBUTE_AGENT_STATE: state} - ) + await self._room.local_participant.set_attributes({ + ATTRIBUTE_AGENT_STATE: state + }) if self._update_state_task is not None: self._update_state_task.cancel() diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py index 2e706f35b..92c9ad305 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py @@ -44,7 +44,6 @@ "function_calls_cancelled", "input_speech_transcription_completed", "agent_speech_transcription_completed", - "agent_speech_transcription_interrupted", ] @@ -233,6 +232,7 @@ def __init__( self._chat_ctx = chat_ctx self._fnc_ctx = fnc_ctx self._fnc_tasks = utils.aio.TaskSet() + self._is_interrupted = False tools = [] if self._fnc_ctx is not None: @@ -279,15 +279,12 @@ def __init__( ) self._transcriber.on("input_speech_done", self._on_input_speech_done) self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done) - self._agent_transcriber.on( - "input_speech_interrupted", self._on_agent_speech_interrupted - ) # init dummy task self._init_sync_task = asyncio.create_task(asyncio.sleep(0)) self._send_ch = utils.aio.Chan[ClientEvents]() self._active_response_id = None if chat_ctx: - self.create_conversation(chat_ctx) + self.generate_reply(chat_ctx) async def aclose(self) -> None: if self._send_ch.closed: @@ -304,20 +301,16 @@ def fnc_ctx(self) -> llm.FunctionContext | None: def fnc_ctx(self, value: llm.FunctionContext | None) -> None: self._fnc_ctx = value - def _update_conversation_item_content(self, item_id: str, content: str) -> None: - pass - def _push_audio(self, frame: rtc.AudioFrame) -> None: if self._opts.enable_transcription: self._transcriber._push_audio(frame) - else: - data = base64.b64encode(frame.data).decode("utf-8") - self._queue_msg({"mime_type": "audio/pcm", "data": data}) + data = base64.b64encode(frame.data).decode("utf-8") + self._queue_msg({"mime_type": "audio/pcm", "data": data}) def _queue_msg(self, msg: ClientEvents) -> None: self._send_ch.send_nowait(msg) - def create_conversation( + def generate_reply( self, chat_ctx: llm.ChatContext | llm.ChatMessage, turn_complete: bool = True ) -> None: if isinstance(chat_ctx, llm.ChatMessage): @@ -348,33 +341,17 @@ def _on_input_speech_done(self, content: TranscriptionContent) -> None: ) self._chat_ctx.append(text=content.text, role="user") - conversation = _build_gemini_ctx(self._chat_ctx) - - client_content = LiveClientContent( - turn_complete=True, - turns=conversation, - ) - self._queue_msg(client_content) def _on_agent_speech_done(self, content: TranscriptionContent) -> None: - self.emit( - "agent_speech_transcription_completed", - InputTranscription( - item_id=content.response_id, - transcript=content.text, - ), - ) - self._chat_ctx.append(text=content.text, role="assistant") - - def _on_agent_speech_interrupted(self, content: TranscriptionContent) -> None: - self.emit( - "agent_speech_transcription_completed", - InputTranscription( - item_id=content.response_id, - transcript=content.text, - ), - ) - self._chat_ctx.append(text=content.text, role="assistant") + if not self._is_interrupted: + self.emit( + "agent_speech_transcription_completed", + InputTranscription( + item_id=content.response_id, + transcript=content.text, + ), + ) + self._chat_ctx.append(text=content.text, role="assistant") @utils.log_exceptions(logger=logger) async def _main_task(self): @@ -390,6 +367,7 @@ async def _recv_task(): while True: async for response in self._session.receive(): if self._active_response_id is None: + self._is_interrupted = False self._active_response_id = utils.shortuuid() text_stream = utils.aio.Chan[str]() audio_stream = utils.aio.Chan[rtc.AudioFrame]() @@ -430,10 +408,8 @@ async def _recv_task(): if isinstance(stream, utils.aio.Chan): stream.close() - if server_content.interrupted: - self.emit("input_speech_started") - elif server_content.turn_complete: - self.emit("response_content_done", content) + self.emit("agent_speech_completed") + self._is_interrupted = True self._active_response_id = None @@ -516,6 +492,6 @@ async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str) ) ] ) - await self._session.send(tool_response) + await self._session.send(input=tool_response) self.emit("function_calls_finished", [called_fnc]) diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py index 89bc243dc..5b80de8d2 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py @@ -16,7 +16,6 @@ EventTypes = Literal[ "input_speech_started", - "input_speech_interrupted", "input_speech_done", ] @@ -107,12 +106,8 @@ async def _recv_task(): if part.text: content.text += part.text - if server_content.interrupted or server_content.turn_complete: - if server_content.interrupted: - self.emit("input_speech_interrupted", content) - elif server_content.turn_complete: - self.emit("input_speech_done", content) - + if server_content.turn_complete: + self.emit("input_speech_done", content) self._active_response_id = None async with self._client.aio.live.connect( From a6b378b77c9031c7cd1b061def7cdc170dcdb4aa Mon Sep 17 00:00:00 2001 From: jayesh Date: Mon, 13 Jan 2025 16:10:17 +0530 Subject: [PATCH 16/19] updates --- livekit-agents/livekit/agents/multimodal/multimodal_agent.py | 2 ++ .../livekit/plugins/google/beta/realtime/realtime_api.py | 1 + 2 files changed, 3 insertions(+) diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index a9eb1eedb..42d675558 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -332,6 +332,8 @@ def _agent_speech_transcription_completed(ev: _InputTranscriptionProto): ) self._emit_speech_committed("agent", ev.transcript) + # Similar to _input_speech_started, this handles updating the state to "listening" when the agent's speech is complete. + # However, since Gemini doesn't support VAD events, we are not emitting the `user_started_speaking` event here. @self._session.on("agent_speech_completed") def _agent_speech_completed(): self._update_state("listening") diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py index 92c9ad305..8ec674945 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py @@ -44,6 +44,7 @@ "function_calls_cancelled", "input_speech_transcription_completed", "agent_speech_transcription_completed", + "agent_speech_completed", ] From c624fbaad61c05a2ee7b85f1ebe95917a491c9d4 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 14 Jan 2025 18:21:50 +0530 Subject: [PATCH 17/19] updates --- .../plugins/google/beta/realtime/api_proto.py | 2 +- .../google/beta/realtime/realtime_api.py | 29 +++--- .../google/beta/realtime/transcriber.py | 88 ++++++++++++------- .../livekit-plugins-google/mypy.ini | 2 + 4 files changed, 80 insertions(+), 41 deletions(-) create mode 100644 livekit-plugins/livekit-plugins-google/mypy.ini diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py index f79569df4..80ebb1ea3 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py @@ -6,7 +6,7 @@ from livekit.agents import llm -from google.genai import types # type: ignore +from google.genai import types __all__ = [ "ClientEvents", diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py index 8ec674945..6b3f5211d 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py @@ -11,8 +11,8 @@ from livekit.agents import llm, utils from livekit.agents.llm.function_context import _create_ai_function_info -from google import genai # type: ignore -from google.genai.types import ( # type: ignore +from google import genai +from google.genai.types import ( FunctionResponse, GenerationConfigDict, LiveClientContent, @@ -89,7 +89,8 @@ class ModelOptions: presence_penalty: float | None frequency_penalty: float | None instructions: str - enable_transcription: bool + enable_user_audio_transcription: bool + enable_agent_audio_transcription: bool class RealtimeModel: @@ -101,7 +102,8 @@ def __init__( api_key: str | None = None, voice: Voice | str = "Puck", modalities: ResponseModality = "AUDIO", - enable_transcription: bool = True, + enable_user_audio_transcription: bool = True, + enable_agent_audio_transcription: bool = True, vertexai: bool = False, project_id: str | None = None, location: str | None = None, @@ -123,7 +125,8 @@ def __init__( modalities (ResponseModality): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"]. model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp". voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck". - enable_transcription (bool, optional): Whether to enable transcription. Defaults to True + enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True + enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8. vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False. project_id (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai) @@ -165,7 +168,8 @@ def __init__( model=model, api_key=self._api_key, voice=voice, - enable_transcription=enable_transcription, + enable_user_audio_transcription=enable_user_audio_transcription, + enable_agent_audio_transcription=enable_agent_audio_transcription, response_modalities=modalities, vertexai=vertexai, project=self._project_id, @@ -271,14 +275,15 @@ def __init__( self._main_atask = asyncio.create_task( self._main_task(), name="gemini-realtime-session" ) - if self._opts.enable_transcription: + if self._opts.enable_user_audio_transcription: self._transcriber = TranscriberSession( client=self._client, model=self._opts.model ) + self._transcriber.on("input_speech_done", self._on_input_speech_done) + if self._opts.enable_agent_audio_transcription: self._agent_transcriber = TranscriberSession( client=self._client, model=self._opts.model ) - self._transcriber.on("input_speech_done", self._on_input_speech_done) self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done) # init dummy task self._init_sync_task = asyncio.create_task(asyncio.sleep(0)) @@ -303,7 +308,7 @@ def fnc_ctx(self, value: llm.FunctionContext | None) -> None: self._fnc_ctx = value def _push_audio(self, frame: rtc.AudioFrame) -> None: - if self._opts.enable_transcription: + if self._opts.enable_user_audio_transcription: self._transcriber._push_audio(frame) data = base64.b64encode(frame.data).decode("utf-8") self._queue_msg({"mime_type": "audio/pcm", "data": data}) @@ -400,7 +405,7 @@ async def _recv_task(): samples_per_channel=len(part.inline_data.data) // 2, ) - if self._opts.enable_transcription: + if self._opts.enable_agent_audio_transcription: self._agent_transcriber._push_audio(frame) content.audio_stream.send_nowait(frame) @@ -461,6 +466,10 @@ async def _recv_task(): finally: await utils.aio.gracefully_cancel(*tasks) await self._session.close() + if self._opts.enable_user_audio_transcription: + await self._transcriber.aclose() + if self._opts.enable_agent_audio_transcription: + await self._agent_transcriber.aclose() @utils.log_exceptions(logger=logger) async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str): diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py index 5b80de8d2..3e6602d47 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/transcriber.py @@ -5,11 +5,12 @@ from dataclasses import dataclass from typing import Literal +import websockets from livekit import rtc from livekit.agents import utils -from google import genai # type: ignore -from google.genai import types # type: ignore +from google import genai +from google.genai import types from ...log import logger from .api_proto import ClientEvents, LiveAPIModels @@ -19,7 +20,9 @@ "input_speech_done", ] -SYSTEM_INSTRUCTIONS = """ +DEFAULT_LANGUAGE = "English" + +SYSTEM_INSTRUCTIONS = f""" You are an **Audio Transcriber**. Your task is to convert audio content into accurate and precise text. - Transcribe verbatim; exclude non-speech sounds. @@ -29,6 +32,8 @@ - Use proper punctuation and formatting. - Do not add explanations, comments, or extra information. - Do not include timestamps, speaker labels, or annotations unless specified. + +- Audio Language: {DEFAULT_LANGUAGE} """ @@ -43,7 +48,7 @@ def __init__( self, *, client: genai.Client, - model: LiveAPIModels, + model: LiveAPIModels | str, ): """ Initializes a TranscriberSession instance for interacting with Google's Realtime API. @@ -51,6 +56,7 @@ def __init__( super().__init__() self._client = client self._model = model + self._closed = False self._config = types.LiveConnectConfigDict( response_modalities="TEXT", @@ -66,16 +72,19 @@ def __init__( self._active_response_id = None def _push_audio(self, frame: rtc.AudioFrame) -> None: + if self._closed: + return data = base64.b64encode(frame.data).decode("utf-8") self._queue_msg({"mime_type": "audio/pcm", "data": data}) - def _queue_msg(self, msg: dict) -> None: - self._send_ch.send_nowait(msg) + def _queue_msg(self, msg: ClientEvents) -> None: + if not self._closed: + self._send_ch.send_nowait(msg) async def aclose(self) -> None: if self._send_ch.closed: return - + self._closed = True self._send_ch.close() await self._main_atask @@ -83,32 +92,51 @@ async def aclose(self) -> None: async def _main_task(self): @utils.log_exceptions(logger=logger) async def _send_task(): - async for msg in self._send_ch: - await self._session.send(input=msg) + try: + async for msg in self._send_ch: + if self._closed: + break + await self._session.send(input=msg) + except websockets.exceptions.ConnectionClosedError as e: + logger.exception(f"Transcriber session closed in _send_task: {e}") + self._closed = True + except Exception as e: + logger.exception(f"Uncaught error in transcriber _send_task: {e}") + self._closed = True @utils.log_exceptions(logger=logger) async def _recv_task(): - while True: - async for response in self._session.receive(): - if self._active_response_id is None: - self._active_response_id = utils.shortuuid() - content = TranscriptionContent( - response_id=self._active_response_id, - text="", - ) - self.emit("input_speech_started", content) - - server_content = response.server_content - if server_content: - model_turn = server_content.model_turn - if model_turn: - for part in model_turn.parts: - if part.text: - content.text += part.text - - if server_content.turn_complete: - self.emit("input_speech_done", content) - self._active_response_id = None + try: + while not self._closed: + async for response in self._session.receive(): + if self._closed: + break + if self._active_response_id is None: + self._active_response_id = utils.shortuuid() + content = TranscriptionContent( + response_id=self._active_response_id, + text="", + ) + self.emit("input_speech_started", content) + + server_content = response.server_content + if server_content: + model_turn = server_content.model_turn + if model_turn: + for part in model_turn.parts: + if part.text: + content.text += part.text + + if server_content.turn_complete: + self.emit("input_speech_done", content) + self._active_response_id = None + + except websockets.exceptions.ConnectionClosedError as e: + logger.exception(f"Transcriber session closed in _recv_task: {e}") + self._closed = True + except Exception as e: + logger.exception(f"Uncaught error in transcriber _recv_task: {e}") + self._closed = True async with self._client.aio.live.connect( model=self._model, config=self._config diff --git a/livekit-plugins/livekit-plugins-google/mypy.ini b/livekit-plugins/livekit-plugins-google/mypy.ini new file mode 100644 index 000000000..9e4e991a2 --- /dev/null +++ b/livekit-plugins/livekit-plugins-google/mypy.ini @@ -0,0 +1,2 @@ +[mypy-google.genai.*] +ignore_missing_imports = True From fc10e5e7913cbc4b59f6c2b977273795262293f3 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 14 Jan 2025 18:24:01 +0530 Subject: [PATCH 18/19] updates --- .../livekit/agents/multimodal/multimodal_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index 42d675558..58bba5f11 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -380,9 +380,9 @@ async def _run_task(delay: float) -> None: await asyncio.sleep(delay) if self._room.isconnected(): - await self._room.local_participant.set_attributes({ - ATTRIBUTE_AGENT_STATE: state - }) + await self._room.local_participant.set_attributes( + {ATTRIBUTE_AGENT_STATE: state} + ) if self._update_state_task is not None: self._update_state_task.cancel() From 9e3b0207e1bd3df2f37f7f7950fe9fec6d550f76 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 14 Jan 2025 18:34:45 +0530 Subject: [PATCH 19/19] updates --- .../plugins/google/beta/realtime/realtime_api.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py index 6b3f5211d..58216cdf5 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py @@ -317,13 +317,15 @@ def _queue_msg(self, msg: ClientEvents) -> None: self._send_ch.send_nowait(msg) def generate_reply( - self, chat_ctx: llm.ChatContext | llm.ChatMessage, turn_complete: bool = True + self, ctx: llm.ChatContext | llm.ChatMessage, turn_complete: bool = True ) -> None: - if isinstance(chat_ctx, llm.ChatMessage): + if isinstance(ctx, llm.ChatMessage) and isinstance(ctx.content, str): new_chat_ctx = llm.ChatContext() - new_chat_ctx.append(text=chat_ctx.content, role=chat_ctx.role) + new_chat_ctx.append(text=ctx.content, role=ctx.role) + elif isinstance(ctx, llm.ChatContext): + new_chat_ctx = ctx else: - new_chat_ctx = chat_ctx + raise ValueError("Invalid chat context") gemini_ctx = _build_gemini_ctx(new_chat_ctx) client_content = LiveClientContent( turn_complete=turn_complete,