Skip to content

Commit

Permalink
fix: Resolve WebSocket issues and add real-time call example (#160)
Browse files: browse the repository at this point in the history.
- Add support for chat update events in WebSocket communication
- Introduce new real-time audio chat GUI example application
  • Loading branch information
chyroc authored Jan 13, 2025
1 parent da63a6d commit 6c5d7b8
Show file tree
Hide file tree
Showing 7 changed files with 472 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ dist/
scripts/
.cache/
output.wav
response.wav
temp_response.pcm
2 changes: 1 addition & 1 deletion cozepy/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def build_assistant_answer(content: str, meta_data: Optional[Dict[str, str]] = N
def get_audio(self) -> Optional[bytes]:
if self.content_type == MessageContentType.AUDIO:
return base64.b64decode(self.content)
return None
return b""


class ChatStatus(str, Enum):
Expand Down
4 changes: 2 additions & 2 deletions cozepy/websockets/audio/speech/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def speech_update(self, event: SpeechUpdateEvent) -> None:
self._input_queue.put(event)

def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
event_id = message.get("event_id") or ""
event_id = message.get("id") or ""
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
event_type = message.get("event_type") or ""
data = message.get("data") or {}
Expand Down Expand Up @@ -235,7 +235,7 @@ async def speech_update(self, data: SpeechUpdateEvent.Data) -> None:
await self._input_queue.put(SpeechUpdateEvent.model_validate({"data": data}))

def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
event_id = message.get("event_id") or ""
event_id = message.get("id") or ""
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
event_type = message.get("event_type") or ""
data = message.get("data") or {}
Expand Down
14 changes: 12 additions & 2 deletions cozepy/websockets/audio/transcriptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def serialize_delta(self, delta: bytes, _info):
event_type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND
data: Data

def _dump_without_delta(self):
return {
"id": self.id,
"type": self.event_type.value,
"detail": self.detail,
"data": {
"delta_length": len(self.data.delta) if self.data and self.data.delta else 0,
},
}


# req
class InputAudioBufferCompleteEvent(WebsocketsEvent):
Expand Down Expand Up @@ -127,7 +137,7 @@ def input_audio_buffer_complete(self) -> None:
self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))

def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
event_id = message.get("event_id") or ""
event_id = message.get("id") or ""
event_type = message.get("event_type") or ""
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
data = message.get("data") or {}
Expand Down Expand Up @@ -250,7 +260,7 @@ async def input_audio_buffer_complete(self) -> None:
await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))

def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
event_id = message.get("event_id") or ""
event_id = message.get("id") or ""
event_type = message.get("event_type") or ""
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
data = message.get("data") or {}
Expand Down
34 changes: 32 additions & 2 deletions cozepy/websockets/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ class ChatCreatedEvent(WebsocketsEvent):
event_type: WebsocketsEventType = WebsocketsEventType.CHAT_CREATED


# resp
# Event model tagged with WebsocketsEventType.CHAT_UPDATED ("chat.updated").
class ChatUpdatedEvent(WebsocketsEvent):
event_type: WebsocketsEventType = WebsocketsEventType.CHAT_UPDATED
# Payload is validated with the Data model declared on ChatUpdateEvent.
data: ChatUpdateEvent.Data


# resp
class ConversationChatCreatedEvent(WebsocketsEvent):
event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CREATED
Expand Down Expand Up @@ -107,6 +113,9 @@ class WebsocketsChatEventHandler(WebsocketsBaseEventHandler):
def on_chat_created(self, cli: "WebsocketsChatClient", event: ChatCreatedEvent):
pass

# Handler registered under WebsocketsEventType.CHAT_UPDATED in the sync chat client's
# event map. The base implementation does nothing; override it to react to the event.
def on_chat_updated(self, cli: "WebsocketsChatClient", event: ChatUpdatedEvent):
pass

def on_input_audio_buffer_completed(self, cli: "WebsocketsChatClient", event: InputAudioBufferCompletedEvent):
pass

Expand Down Expand Up @@ -151,6 +160,7 @@ def __init__(
on_event = on_event.to_dict(
{
WebsocketsEventType.CHAT_CREATED: on_event.on_chat_created,
WebsocketsEventType.CHAT_UPDATED: on_event.on_chat_updated,
WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed,
WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created,
WebsocketsEventType.CONVERSATION_CHAT_IN_PROGRESS: on_event.on_conversation_chat_in_progress,
Expand Down Expand Up @@ -188,7 +198,7 @@ def input_audio_buffer_complete(self) -> None:
self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))

def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
event_id = message.get("event_id") or ""
event_id = message.get("id") or ""
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
event_type = message.get("event_type") or ""
data = message.get("data") or {}
Expand All @@ -199,6 +209,14 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
"detail": detail,
}
)
elif event_type == WebsocketsEventType.CHAT_UPDATED.value:
return ChatUpdatedEvent.model_validate(
{
"id": event_id,
"detail": detail,
"data": ChatUpdateEvent.Data.model_validate(data),
}
)
elif event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value:
return InputAudioBufferCompletedEvent.model_validate(
{
Expand Down Expand Up @@ -299,6 +317,9 @@ class AsyncWebsocketsChatEventHandler(AsyncWebsocketsBaseEventHandler):
async def on_chat_created(self, cli: "AsyncWebsocketsChatClient", event: ChatCreatedEvent):
pass

# Handler registered under WebsocketsEventType.CHAT_UPDATED in the async chat client's
# event map. The base implementation does nothing; override it to react to the event.
async def on_chat_updated(self, cli: "AsyncWebsocketsChatClient", event: ChatUpdatedEvent):
pass

async def on_input_audio_buffer_completed(
self, cli: "AsyncWebsocketsChatClient", event: InputAudioBufferCompletedEvent
):
Expand Down Expand Up @@ -355,6 +376,7 @@ def __init__(
on_event = on_event.to_dict(
{
WebsocketsEventType.CHAT_CREATED: on_event.on_chat_created,
WebsocketsEventType.CHAT_UPDATED: on_event.on_chat_updated,
WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed,
WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created,
WebsocketsEventType.CONVERSATION_CHAT_IN_PROGRESS: on_event.on_conversation_chat_in_progress,
Expand Down Expand Up @@ -392,7 +414,7 @@ async def input_audio_buffer_complete(self) -> None:
await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))

def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
event_id = message.get("event_id") or ""
event_id = message.get("id") or ""
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
event_type = message.get("event_type") or ""
data = message.get("data") or {}
Expand All @@ -403,6 +425,14 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
"detail": detail,
}
)
elif event_type == WebsocketsEventType.CHAT_UPDATED.value:
return ChatUpdatedEvent.model_validate(
{
"id": event_id,
"detail": detail,
"data": ChatUpdateEvent.Data.model_validate(data),
}
)
elif event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value:
return InputAudioBufferCompletedEvent.model_validate(
{
Expand Down
19 changes: 14 additions & 5 deletions cozepy/websockets/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class WebsocketsEventType(str, Enum):
CONVERSATION_CHAT_SUBMIT_TOOL_OUTPUTS = "conversation.chat.submit_tool_outputs" # send tool outputs to server
# resp
CHAT_CREATED = "chat.created"
CHAT_UPDATED = "chat.updated"
# INPUT_AUDIO_BUFFER_COMPLETED = "input_audio_buffer.completed" # received `input_audio_buffer.complete` event
CONVERSATION_CHAT_CREATED = "conversation.chat.created" # audio ast completed, chat started
CONVERSATION_CHAT_IN_PROGRESS = "conversation.chat.in_progress"
Expand All @@ -109,7 +110,7 @@ class Detail(BaseModel):
logid: Optional[str] = None

event_type: WebsocketsEventType
event_id: Optional[str] = None
id: Optional[str] = None
detail: Optional[Detail] = None


Expand All @@ -118,7 +119,7 @@ class WebsocketsErrorEvent(WebsocketsEvent):
data: CozeAPIError


class InputAudio(CozeModel):
class InputAudio(BaseModel):
format: Optional[str]
codec: Optional[str]
sample_rate: Optional[int]
Expand Down Expand Up @@ -266,7 +267,7 @@ def _receive_loop(self) -> None:
self._handle_error(e)

def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]:
event_id = message.get("event_id") or ""
event_id = message.get("id") or ""
event_type = message.get("event_type") or ""
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
data = message.get("data") or {}
Expand Down Expand Up @@ -466,7 +467,7 @@ async def _receive_loop(self) -> None:
await self._handle_error(e)

def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]:
event_id = message.get("event_id") or ""
event_id = message.get("id") or ""
event_type = message.get("event_type") or ""
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
data = message.get("data") or {}
Expand Down Expand Up @@ -553,7 +554,15 @@ async def _close(self) -> None:
async def _send_event(self, event: Optional[WebsocketsEvent] = None) -> None:
if not event or not self._ws:
return
log_debug("[%s] send event, type=%s", self._path, event.event_type.value)
if event.event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND:
log_debug(
"[%s] send event, type=%s, event=%s",
self._path,
event.event_type.value,
json.dumps(event._dump_without_delta()), # type: ignore
)
else:
log_debug("[%s] send event, type=%s, event=%s", self._path, event.event_type.value, event.model_dump_json())
await self._ws.send(event.model_dump_json())


Expand Down
Loading

0 comments on commit 6c5d7b8

Please sign in to comment.