[draft] Support STT with Google realtime API #1321

Draft · wants to merge 19 commits into base: main
16 changes: 14 additions & 2 deletions examples/multimodal-agent/gemini_agent.py
@@ -50,13 +50,25 @@ async def get_weather(
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
participant = await ctx.wait_for_participant()

# chat_ctx serves as the initial context; the agent will start the conversation first if chat_ctx is provided
chat_ctx = llm.ChatContext()
chat_ctx.append(text="What is LiveKit?", role="user")
chat_ctx.append(
text="LiveKit is the platform for building realtime AI. The main use cases are to build AI voice agents. LiveKit also powers livestreaming apps, robotics, and video conferencing.",
role="assistant",
)
chat_ctx.append(text="What is the LiveKit Agents framework?", role="user")
Member:
Does the last message have to be user in order for Gemini to respond first?

Collaborator (author):
No, it can be either assistant or user.
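For illustration, a minimal sketch (not part of the diff) of the same chat_ctx seeding, but ending on an assistant turn; per the reply above, the agent still opens the conversation:

# Hypothetical variation of the example above: the last seeded message has the
# assistant role rather than the user role, and Gemini still responds first.
chat_ctx = llm.ChatContext()
chat_ctx.append(text="What is LiveKit?", role="user")
chat_ctx.append(
    text="LiveKit is the platform for building realtime AI.",
    role="assistant",
)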


agent = multimodal.MultimodalAgent(
model=google.beta.realtime.RealtimeModel(
voice="Charon",
voice="Puck",
temperature=0.8,
instructions="You are a helpful assistant",
instructions="""
You are a helpful assistant
Here is some helpful information about LiveKit and its products and services:
- LiveKit is the platform for building realtime AI. The main use cases are to build AI voice agents. LiveKit also powers livestreaming apps, robotics, and video conferencing.
- LiveKit provides an Agents framework for building server-side AI agents, client SDKs for building frontends, and LiveKit Cloud is a global network that transports voice, video, and data traffic in realtime.
""",
),
fnc_ctx=fnc_ctx,
chat_ctx=chat_ctx,
101 changes: 63 additions & 38 deletions livekit-agents/livekit/agents/multimodal/multimodal_agent.py
@@ -311,19 +311,32 @@ def _input_speech_transcription_completed(ev: _InputTranscriptionProto):
alternatives=[stt.SpeechData(language="", text=ev.transcript)],
)
)
user_msg = ChatMessage.create(
text=ev.transcript, role="user", id=ev.item_id
)
if self._model.capabilities.supports_truncate:
user_msg = ChatMessage.create(
Member commented on lines +314 to +315:
Why is this only done when it supports truncate? It seems you are trying to update an item instead of truncating it.

Collaborator (author):
Some methods are not implemented in Gemini. We maintain remote conversations in OpenAI, but not in Gemini, so we should avoid invoking those methods when using Gemini. The purpose of supports_truncate is to distinguish those two cases.

text=ev.transcript, role="user", id=ev.item_id
)

self._session._update_conversation_item_content(
ev.item_id, user_msg.content
)
self._session._update_conversation_item_content(
ev.item_id, user_msg.content
)

self._emit_speech_committed("user", ev.transcript)

self.emit("user_speech_committed", user_msg)
logger.debug(
"committed user speech",
extra={"user_transcript": ev.transcript},
@self._session.on("agent_speech_transcription_completed")
def _agent_speech_transcription_completed(ev: _InputTranscriptionProto):
self._agent_stt_forwarder.update(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[stt.SpeechData(language="", text=ev.transcript)],
)
)
self._emit_speech_committed("agent", ev.transcript)
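As a rough sketch of the supports_truncate gating discussed in the thread above (the RealtimeCapabilities name and helper below are illustrative, not the plugin's actual definitions):

# Illustrative sketch only: a per-model capability flag lets the shared
# MultimodalAgent skip remote conversation-item updates for models (like Gemini)
# that do not maintain conversation items the way OpenAI's realtime API does.
from dataclasses import dataclass

@dataclass
class RealtimeCapabilities:
    supports_truncate: bool  # True for OpenAI realtime, False for Gemini

def maybe_update_item(session, capabilities: RealtimeCapabilities, item_id: str, content) -> None:
    # Only call remote-conversation methods when the model supports them.
    if capabilities.supports_truncate:
        session._update_conversation_item_content(item_id, content)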

@self._session.on("agent_speech_completed")
def _agent_speech_completed():
self._update_state("listening")
if self._playing_handle is not None and not self._playing_handle.done():
Member:
Can you include comments on why this is needed?

self._playing_handle.interrupt()
Member:
Why should we interrupt here?

Collaborator (author):
Because we call this function when speech is interrupted as well. Google likely changed something: Gemini now returns server.turn_complete instead of server.interrupted when interrupted, which is confusing. In both cases we end up calling this function.
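Based on the explanation above, the handler shown earlier could be annotated roughly like this (a sketch with the reasoning as comments, not the committed code):

@self._session.on("agent_speech_completed")
def _agent_speech_completed():
    self._update_state("listening")
    # Gemini now reports server.turn_complete both when a turn finishes naturally
    # and when the user interrupts (instead of server.interrupted), so any audio
    # that is still playing out must be interrupted here as well.
    if self._playing_handle is not None and not self._playing_handle.done():
        self._playing_handle.interrupt()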


@self._session.on("input_speech_started")
def _input_speech_started():
@@ -365,9 +378,9 @@ async def _run_task(delay: float) -> None:
await asyncio.sleep(delay)

if self._room.isconnected():
await self._room.local_participant.set_attributes(
{ATTRIBUTE_AGENT_STATE: state}
)
await self._room.local_participant.set_attributes({
ATTRIBUTE_AGENT_STATE: state
})

if self._update_state_task is not None:
self._update_state_task.cancel()
@@ -378,6 +391,17 @@ async def _main_task(self) -> None:
async def _main_task(self) -> None:
self._update_state("initializing")
self._audio_source = rtc.AudioSource(24000, 1)
track = rtc.LocalAudioTrack.create_audio_track(
"assistant_voice", self._audio_source
)
self._agent_publication = await self._room.local_participant.publish_track(
track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
)
self._agent_stt_forwarder = transcription.STTSegmentsForwarder(
room=self._room,
participant=self._room.local_participant,
track=track,
)
self._agent_playout = agent_playout.AgentPlayout(
audio_source=self._audio_source
)
@@ -395,39 +419,21 @@ def _on_playout_stopped(interrupted: bool) -> None:
if interrupted:
collected_text += "..."

msg = ChatMessage.create(
text=collected_text,
role="assistant",
id=self._playing_handle.item_id,
)
if self._model.capabilities.supports_truncate:
if self._model.capabilities.supports_truncate and collected_text:
msg = ChatMessage.create(
text=collected_text,
role="assistant",
id=self._playing_handle.item_id,
)
self._session._update_conversation_item_content(
self._playing_handle.item_id, msg.content
)

if interrupted:
self.emit("agent_speech_interrupted", msg)
else:
self.emit("agent_speech_committed", msg)

logger.debug(
"committed agent speech",
extra={
"agent_transcript": collected_text,
"interrupted": interrupted,
},
)
self._emit_speech_committed("agent", collected_text, interrupted)

self._agent_playout.on("playout_started", _on_playout_started)
self._agent_playout.on("playout_stopped", _on_playout_stopped)

track = rtc.LocalAudioTrack.create_audio_track(
"assistant_voice", self._audio_source
)
self._agent_publication = await self._room.local_participant.publish_track(
track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
)

await self._agent_publication.wait_for_subscription()

bstream = utils.audio.AudioByteStream(
@@ -497,3 +503,22 @@ def _ensure_session(self) -> aiohttp.ClientSession:
self._http_session = utils.http_context.http_session()

return self._http_session

def _emit_speech_committed(
self, speaker: Literal["user", "agent"], msg: str, interrupted: bool = False
):
if speaker == "user":
self.emit("user_speech_committed", msg)
else:
if interrupted:
self.emit("agent_speech_interrupted", msg)
else:
self.emit("agent_speech_committed", msg)

logger.debug(
f"committed {speaker} speech",
extra={
f"{speaker}_transcript": msg,
"interrupted": interrupted,
},
)
@@ -1,10 +1,22 @@
from __future__ import annotations

import inspect
import json
from typing import Any, Dict, List, Literal, Sequence, Union

from livekit.agents import llm

from google.genai import types # type: ignore
Member (@theomonnom, Jan 13, 2025):
The code here is hard to follow; it's really sad we don't have types (it's unclear what the structure of the dicts is).


__all__ = [
"ClientEvents",
"LiveAPIModels",
"ResponseModality",
"Voice",
"_build_gemini_ctx",
"_build_tools",
]

LiveAPIModels = Literal["gemini-2.0-flash-exp"]

Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"]
Expand Down Expand Up @@ -77,3 +89,35 @@ def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclarationDict]:
function_declarations.append(func_decl)

return function_declarations


def _build_gemini_ctx(chat_ctx: llm.ChatContext) -> List[types.Content]:
content = None
turns = []

for msg in chat_ctx.messages:
role = None
if msg.role == "assistant":
role = "model"
elif msg.role in {"system", "user"}:
role = "user"
elif msg.role == "tool":
continue

if content and content.role == role:
if isinstance(msg.content, str):
content.parts.append(types.Part(text=msg.content))
elif isinstance(msg.content, dict):
content.parts.append(types.Part(text=json.dumps(msg.content)))
elif isinstance(msg.content, list):
for item in msg.content:
if isinstance(item, str):
content.parts.append(types.Part(text=item))
else:
content = types.Content(
parts=[types.Part(text=msg.content)],
role=role,
)
turns.append(content)

return turns
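
To make the output shape easier to follow (addressing the review comment above), a hypothetical usage sketch of _build_gemini_ctx: consecutive messages that map to the same Gemini role are merged into one Content turn, system and user messages both map to "user", and assistant messages map to "model".

# Hypothetical example (not part of the diff) showing the shape of the output.
ctx = llm.ChatContext()
ctx.append(text="You are a helpful assistant", role="system")
ctx.append(text="What is LiveKit?", role="user")
ctx.append(text="LiveKit is the platform for building realtime AI.", role="assistant")

turns = _build_gemini_ctx(ctx)
# turns is roughly:
# [
#     types.Content(role="user", parts=[
#         types.Part(text="You are a helpful assistant"),
#         types.Part(text="What is LiveKit?"),
#     ]),
#     types.Content(role="model", parts=[
#         types.Part(text="LiveKit is the platform for building realtime AI."),
#     ]),
# ]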