diff --git a/cozepy/__init__.py b/cozepy/__init__.py index c8e3908..096a197 100644 --- a/cozepy/__init__.py +++ b/cozepy/__init__.py @@ -124,6 +124,8 @@ ChatUpdateEvent, ConversationAudioDeltaEvent, ConversationAudioTranscriptCompletedEvent, + ConversationChatCanceledEvent, + ConversationChatCancelEvent, ConversationChatCompletedEvent, ConversationChatCreatedEvent, ConversationChatRequiresActionEvent, @@ -257,12 +259,14 @@ # websockets.chat "ChatUpdateEvent", "ConversationChatSubmitToolOutputsEvent", + "ConversationChatCancelEvent", "ConversationChatCreatedEvent", "ConversationMessageDeltaEvent", "ConversationAudioTranscriptCompletedEvent", "ConversationChatRequiresActionEvent", "ConversationAudioDeltaEvent", "ConversationChatCompletedEvent", + "ConversationChatCanceledEvent", "WebsocketsChatEventHandler", "WebsocketsChatClient", "AsyncWebsocketsChatEventHandler", diff --git a/cozepy/websockets/chat/__init__.py b/cozepy/websockets/chat/__init__.py index 4e0448b..a76508a 100644 --- a/cozepy/websockets/chat/__init__.py +++ b/cozepy/websockets/chat/__init__.py @@ -54,6 +54,11 @@ class Data(BaseModel): data: Data +# req +class ConversationChatCancelEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CANCEL + + # resp class ChatCreatedEvent(WebsocketsEvent): event_type: WebsocketsEventType = WebsocketsEventType.CHAT_CREATED @@ -119,6 +124,11 @@ class ConversationChatCompletedEvent(WebsocketsEvent): data: Chat +# resp +class ConversationChatCanceledEvent(WebsocketsEvent): + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CANCELED + + class WebsocketsChatEventHandler(WebsocketsBaseEventHandler): def on_chat_created(self, cli: "WebsocketsChatClient", event: ChatCreatedEvent): pass @@ -160,6 +170,9 @@ def on_conversation_audio_completed(self, cli: "WebsocketsChatClient", event: Co def on_conversation_chat_completed(self, cli: "WebsocketsChatClient", event: ConversationChatCompletedEvent): pass + def on_conversation_chat_canceled(self, cli: "WebsocketsChatClient", event: ConversationChatCanceledEvent): + pass + class WebsocketsChatClient(WebsocketsBaseClient): def __init__( @@ -187,6 +200,7 @@ def __init__( WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, WebsocketsEventType.CONVERSATION_AUDIO_COMPLETED: on_event.on_conversation_audio_completed, WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, + WebsocketsEventType.CONVERSATION_CHAT_CANCELED: on_event.on_conversation_chat_canceled, } ) super().__init__( @@ -211,6 +225,9 @@ def chat_update(self, data: ChatUpdateEvent.Data) -> None: def conversation_chat_submit_tool_outputs(self, data: ConversationChatSubmitToolOutputsEvent.Data): self._input_queue.put(ConversationChatSubmitToolOutputsEvent.model_validate({"data": data})) + def conversation_chat_cancel(self): + self._input_queue.put(ConversationChatCancelEvent.model_validate({})) + def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None: self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) @@ -313,6 +330,13 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: "data": Chat.model_validate(data), } ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CANCELED.value: + return ConversationChatCanceledEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) else: log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) return None @@ -396,6 +420,11 @@ async def on_conversation_chat_completed( ): pass + async def on_conversation_chat_canceled( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCanceledEvent + ): + pass + class AsyncWebsocketsChatClient(AsyncWebsocketsBaseClient): def __init__( @@ -423,6 +452,7 @@ def __init__( WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, WebsocketsEventType.CONVERSATION_AUDIO_COMPLETED: on_event.on_conversation_audio_completed, WebsocketsEventType.CONVERSATION_CHAT_COMPLETED: on_event.on_conversation_chat_completed, + WebsocketsEventType.CONVERSATION_CHAT_CANCELED: on_event.on_conversation_chat_canceled, } ) super().__init__( @@ -447,6 +477,9 @@ async def chat_update(self, data: ChatUpdateEvent.Data) -> None: async def conversation_chat_submit_tool_outputs(self, data: ConversationChatSubmitToolOutputsEvent.Data) -> None: await self._input_queue.put(ConversationChatSubmitToolOutputsEvent.model_validate({"data": data})) + async def conversation_chat_cancel(self) -> None: + await self._input_queue.put(ConversationChatCancelEvent.model_validate({})) + async def input_audio_buffer_append(self, data: InputAudioBufferAppendEvent.Data) -> None: await self._input_queue.put(InputAudioBufferAppendEvent.model_validate({"data": data})) @@ -549,6 +582,13 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: "data": Chat.model_validate(data), } ) + elif event_type == WebsocketsEventType.CONVERSATION_CHAT_CANCELED.value: + return ConversationChatCanceledEvent.model_validate( + { + "id": event_id, + "detail": detail, + } + ) else: log_warning("[%s] unknown event, type=%s, logid=%s", self._path, event_type, detail.logid) return None diff --git a/cozepy/websockets/ws.py b/cozepy/websockets/ws.py index 70e6c28..51e9c1f 100644 --- a/cozepy/websockets/ws.py +++ b/cozepy/websockets/ws.py @@ -91,6 +91,7 @@ class WebsocketsEventType(str, Enum): # INPUT_AUDIO_BUFFER_COMPLETE = "input_audio_buffer.complete" # no audio send, start chat CHAT_UPDATE = "chat.update" # send chat config to server CONVERSATION_CHAT_SUBMIT_TOOL_OUTPUTS = "conversation.chat.submit_tool_outputs" # send tool outputs to server + CONVERSATION_CHAT_CANCEL = "conversation.chat.cancel" # send cancel chat to server # resp CHAT_CREATED = "chat.created" CHAT_UPDATED = "chat.updated" @@ -104,6 +105,7 @@ class WebsocketsEventType(str, Enum): CONVERSATION_AUDIO_DELTA = "conversation.audio.delta" # get agent audio message update CONVERSATION_AUDIO_COMPLETED = "conversation.audio.completed" CONVERSATION_CHAT_COMPLETED = "conversation.chat.completed" # all message received, can close connection + CONVERSATION_CHAT_CANCELED = "conversation.chat.canceled" # chat canceled class WebsocketsEvent(CozeModel, ABC): diff --git a/examples/websockets_chat.py b/examples/websockets_chat.py index d216773..3ffcc92 100644 --- a/examples/websockets_chat.py +++ b/examples/websockets_chat.py @@ -12,9 +12,11 @@ AudioFormat, ConversationAudioDeltaEvent, ConversationChatCompletedEvent, + ConversationChatCanceledEvent, ConversationChatCreatedEvent, ConversationChatRequiresActionEvent, ConversationChatSubmitToolOutputsEvent, + ConversationChatCancelEvent, ConversationMessageDeltaEvent, DeviceOAuthApp, InputAudioBufferAppendEvent, @@ -113,6 +115,11 @@ async def on_conversation_chat_completed( log_info("[examples] Saving audio data to output.wav") write_pcm_to_wav_file(b"".join(self.delta), "output.wav") + async def on_conversation_chat_canceled( + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCanceledEvent + ): + log_info("[examples] chat canceled") + def wrap_coze_speech_to_iterator(coze: AsyncCoze, text: str): async def iterator(): diff --git a/examples/websockets_chat_realtime_gui.py b/examples/websockets_chat_realtime_gui.py index 0d858d6..1ca5336 100644 --- a/examples/websockets_chat_realtime_gui.py +++ b/examples/websockets_chat_realtime_gui.py @@ -19,6 +19,7 @@ ChatUpdateEvent, ConversationAudioDeltaEvent, ConversationChatCompletedEvent, + ConversationChatCanceledEvent, InputAudio, InputAudioBufferAppendEvent, TokenAuth, @@ -295,6 +296,14 @@ async def on_conversation_chat_completed( except Exception as e: print(f"完成对话错误: {e}") + async def on_conversation_chat_canceled( + self, cli: AsyncWebsocketsChatClient, event: ConversationChatCanceledEvent + ): + try: + print("打断") + except Exception as e: + print(f"对话打断错误: {e}") + kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}") self.chat_client = self.coze.websockets.chat.create( bot_id=os.getenv("COZE_BOT_ID"),