Skip to content

Commit

Permalink
feat: Add WebSocket conversation.chat.cancel support (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbb1234567890 authored Mar 5, 2025
1 parent 61c83d3 commit fa3ba2e
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 0 deletions.
4 changes: 4 additions & 0 deletions cozepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@
ChatUpdateEvent,
ConversationAudioDeltaEvent,
ConversationAudioTranscriptCompletedEvent,
ConversationChatCanceledEvent,
ConversationChatCancelEvent,
ConversationChatCompletedEvent,
ConversationChatCreatedEvent,
ConversationChatRequiresActionEvent,
Expand Down Expand Up @@ -257,12 +259,14 @@
# websockets.chat
"ChatUpdateEvent",
"ConversationChatSubmitToolOutputsEvent",
"ConversationChatCancelEvent",
"ConversationChatCreatedEvent",
"ConversationMessageDeltaEvent",
"ConversationAudioTranscriptCompletedEvent",
"ConversationChatRequiresActionEvent",
"ConversationAudioDeltaEvent",
"ConversationChatCompletedEvent",
"ConversationChatCanceledEvent",
"WebsocketsChatEventHandler",
"WebsocketsChatClient",
"AsyncWebsocketsChatEventHandler",
Expand Down
40 changes: 40 additions & 0 deletions cozepy/websockets/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class Data(BaseModel):
data: Data


# req
class ConversationChatCancelEvent(WebsocketsEvent):
event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CANCEL


# resp
class ChatCreatedEvent(WebsocketsEvent):
event_type: WebsocketsEventType = WebsocketsEventType.CHAT_CREATED
Expand Down Expand Up @@ -119,6 +124,11 @@ class ConversationChatCompletedEvent(WebsocketsEvent):
data: Chat


# resp
class ConversationChatCanceledEvent(WebsocketsEvent):
event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CANCELED


class WebsocketsChatEventHandler(WebsocketsBaseEventHandler):
def on_chat_created(self, cli: "WebsocketsChatClient", event: ChatCreatedEvent):
pass
Expand Down Expand Up @@ -160,6 +170,9 @@ def on_conversation_audio_completed(self, cli: "WebsocketsChatClient", event: Co
def on_conversation_chat_completed(self, cli: "WebsocketsChatClient", event: ConversationChatCompletedEvent):
pass

def on_conversation_chat_canceled(self, cli: "WebsocketsChatClient", event: ConversationChatCanceledEvent):
pass


class WebsocketsChatClient(WebsocketsBaseClient):
def __init__(
Expand Down Expand Up @@ -187,6 +200,7 @@ def __init__(
WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta,
WebsocketsEventType.CONVERSATION_AUDIO_COMPLETED: on_event.on_conversation_audio_completed,
WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed,
WebsocketsEventType.CONVERSATION_CHAT_CANCELED: on_event.on_conversation_chat_canceled,
}
)
super().__init__(
Expand All @@ -211,6 +225,9 @@ def chat_update(self, data: ChatUpdateEvent.Data) -> None:
def conversation_chat_submit_tool_outputs(self, data: ConversationChatSubmitToolOutputsEvent.Data):
self._input_queue.put(ConversationChatSubmitToolOutputsEvent.model_validate({"data": data}))

def conversation_chat_cancel(self):
self._input_queue.put(ConversationChatCancelEvent.model_validate({}))

def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None:
self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data}))

Expand Down Expand Up @@ -313,6 +330,13 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
"data": Chat.model_validate(data),
}
)
elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CANCELED.value:
return ConversationChatCanceledEvent.model_validate(
{
"id": event_id,
"detail": detail,
}
)
else:
log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid)
return None
Expand Down Expand Up @@ -396,6 +420,11 @@ async def on_conversation_chat_completed(
):
pass

async def on_conversation_chat_canceled(
self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCanceledEvent
):
pass


class AsyncWebsocketsChatClient(AsyncWebsocketsBaseClient):
def __init__(
Expand Down Expand Up @@ -423,6 +452,7 @@ def __init__(
WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta,
WebsocketsEventType.CONVERSATION_AUDIO_COMPLETED: on_event.on_conversation_audio_completed,
WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed,
WebsocketsEventType.CONVERSATION_CHAT_CANCELED: on_event.on_conversation_chat_canceled,
}
)
super().__init__(
Expand All @@ -447,6 +477,9 @@ async def chat_update(self, data: ChatUpdateEvent.Data) -> None:
async def conversation_chat_submit_tool_outputs(self, data: ConversationChatSubmitToolOutputsEvent.Data) -> None:
await self._input_queue.put(ConversationChatSubmitToolOutputsEvent.model_validate({"data": data}))

async def conversation_chat_cancel(self) -> None:
await self._input_queue.put(ConversationChatCancelEvent.model_validate({}))

async def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None:
await self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data}))

Expand Down Expand Up @@ -549,6 +582,13 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
"data": Chat.model_validate(data),
}
)
elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CANCELED.value:
return ConversationChatCanceledEvent.model_validate(
{
"id": event_id,
"detail": detail,
}
)
else:
log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid)
return None
Expand Down
2 changes: 2 additions & 0 deletions cozepy/websockets/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class WebsocketsEventType(str, Enum):
# INPUT_AUDIO_BUFFER_COMPLETE = "input_audio_buffer.complete" # no audio send, start chat
CHAT_UPDATE = "chat.update" # send chat config to server
CONVERSATION_CHAT_SUBMIT_TOOL_OUTPUTS = "conversation.chat.submit_tool_outputs" # send tool outputs to server
CONVERSATION_CHAT_CANCEL = "conversation.chat.cancel" # send cancel chat to server
# resp
CHAT_CREATED = "chat.created"
CHAT_UPDATED = "chat.updated"
Expand All @@ -104,6 +105,7 @@ class WebsocketsEventType(str, Enum):
CONVERSATION_AUDIO_DELTA = "conversation.audio.delta" # get agent audio message update
CONVERSATION_AUDIO_COMPLETED = "conversation.audio.completed"
CONVERSATION_CHAT_COMPLETED = "conversation.chat.completed" # all message received, can close connection
CONVERSATION_CHAT_CANCELED = "conversation.chat.canceled" # chat canceled


class WebsocketsEvent(CozeModel, ABC):
Expand Down
7 changes: 7 additions & 0 deletions examples/websockets_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
AudioFormat,
ConversationAudioDeltaEvent,
ConversationChatCompletedEvent,
ConversationChatCanceledEvent,
ConversationChatCreatedEvent,
ConversationChatRequiresActionEvent,
ConversationChatSubmitToolOutputsEvent,
ConversationChatCancelEvent,
ConversationMessageDeltaEvent,
DeviceOAuthApp,
InputAudioBufferAppendEvent,
Expand Down Expand Up @@ -113,6 +115,11 @@ async def on_conversation_chat_completed(
log_info("[examples] Saving audio data to output.wav")
write_pcm_to_wav_file(b"".join(self.delta), "output.wav")

async def on_conversation_chat_canceled(
self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCanceledEvent
):
log_info("[examples] chat canceled")


def wrap_coze_speech_to_iterator(coze: AsyncCoze, text: str):
async def iterator():
Expand Down
9 changes: 9 additions & 0 deletions examples/websockets_chat_realtime_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ChatUpdateEvent,
ConversationAudioDeltaEvent,
ConversationChatCompletedEvent,
ConversationChatCanceledEvent,
InputAudio,
InputAudioBufferAppendEvent,
TokenAuth,
Expand Down Expand Up @@ -295,6 +296,14 @@ async def on_conversation_chat_completed(
except Exception as e:
print(f"完成对话错误: {e}")

async def on_conversation_chat_canceled(
self, cli: AsyncWebsocketsChatClient, event: ConversationChatCanceledEvent
):
try:
print("打断")
except Exception as e:
print(f"对话打断错误: {e}")

kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}")
self.chat_client = self.coze.websockets.chat.create(
bot_id=os.getenv("COZE_BOT_ID"),
Expand Down

0 comments on commit fa3ba2e

Please sign in to comment.