diff --git a/.gitignore b/.gitignore
index 2dd1059..77b80aa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,5 @@ dist/
 scripts/
 .cache/
 output.wav
+response.wav
+temp_response.pcm
\ No newline at end of file
diff --git a/cozepy/chat/__init__.py b/cozepy/chat/__init__.py
index 5fb140f..0040aad 100644
--- a/cozepy/chat/__init__.py
+++ b/cozepy/chat/__init__.py
@@ -191,7 +191,7 @@ def build_assistant_answer(content: str, meta_data: Optional[Dict[str, str]] = N
     def get_audio(self) -> Optional[bytes]:
         if self.content_type == MessageContentType.AUDIO:
             return base64.b64decode(self.content)
-        return None
+        return b""


 class ChatStatus(str, Enum):
diff --git a/cozepy/websockets/audio/speech/__init__.py b/cozepy/websockets/audio/speech/__init__.py
index 83f020b..7a533f7 100644
--- a/cozepy/websockets/audio/speech/__init__.py
+++ b/cozepy/websockets/audio/speech/__init__.py
@@ -121,7 +121,7 @@ def speech_update(self, event: SpeechUpdateEvent) -> None:
         self._input_queue.put(event)

     def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
-        event_id = message.get("event_id") or ""
+        event_id = message.get("id") or ""
         detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
         event_type = message.get("event_type") or ""
         data = message.get("data") or {}
@@ -235,7 +235,7 @@ async def speech_update(self, data: SpeechUpdateEvent.Data) -> None:
         await self._input_queue.put(SpeechUpdateEvent.model_validate({"data": data}))

     def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
-        event_id = message.get("event_id") or ""
+        event_id = message.get("id") or ""
         detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
         event_type = message.get("event_type") or ""
         data = message.get("data") or {}
diff --git a/cozepy/websockets/audio/transcriptions/__init__.py b/cozepy/websockets/audio/transcriptions/__init__.py
index 72f2f42..4f03b14 100644
--- a/cozepy/websockets/audio/transcriptions/__init__.py
+++ b/cozepy/websockets/audio/transcriptions/__init__.py
@@ -30,6 +30,16 @@ def serialize_delta(self, delta: bytes, _info):
     event_type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND
     data: Data

+    def _dump_without_delta(self):
+        return {
+            "id": self.id,
+            "type": self.event_type.value,
+            "detail": self.detail,
+            "data": {
+                "delta_length": len(self.data.delta) if self.data and self.data.delta else 0,
+            },
+        }
+

 # req
 class InputAudioBufferCompleteEvent(WebsocketsEvent):
@@ -127,7 +137,7 @@ def input_audio_buffer_complete(self) -> None:
         self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))

     def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
-        event_id = message.get("event_id") or ""
+        event_id = message.get("id") or ""
         event_type = message.get("event_type") or ""
         detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
         data = message.get("data") or {}
@@ -250,7 +260,7 @@ async def input_audio_buffer_complete(self) -> None:
         await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))

     def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
-        event_id = message.get("event_id") or ""
+        event_id = message.get("id") or ""
         event_type = message.get("event_type") or ""
         detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
         data = message.get("data") or {}
diff --git a/cozepy/websockets/chat/__init__.py b/cozepy/websockets/chat/__init__.py
index bc40485..2e0e73b 100644
--- a/cozepy/websockets/chat/__init__.py
+++ b/cozepy/websockets/chat/__init__.py
@@ -58,6 +58,12 @@ class ChatCreatedEvent(WebsocketsEvent):
     event_type: WebsocketsEventType = WebsocketsEventType.CHAT_CREATED


+# resp
+class ChatUpdatedEvent(WebsocketsEvent):
+    event_type: WebsocketsEventType = WebsocketsEventType.CHAT_UPDATED
+    data: ChatUpdateEvent.Data
+
+
 # resp
 class ConversationChatCreatedEvent(WebsocketsEvent):
     event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CREATED
@@ -107,6 +113,9 @@ class WebsocketsChatEventHandler(WebsocketsBaseEventHandler):
     def on_chat_created(self, cli: "WebsocketsChatClient", event: ChatCreatedEvent):
         pass

+    def on_chat_updated(self, cli: "WebsocketsChatClient", event: ChatUpdatedEvent):
+        pass
+
     def on_input_audio_buffer_completed(self, cli: "WebsocketsChatClient", event: InputAudioBufferCompletedEvent):
         pass

@@ -151,6 +160,7 @@ def __init__(
             on_event = on_event.to_dict(
                 {
                     WebsocketsEventType.CHAT_CREATED: on_event.on_chat_created,
+                    WebsocketsEventType.CHAT_UPDATED: on_event.on_chat_updated,
                     WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed,
                     WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created,
                     WebsocketsEventType.CONVERSATION_CHAT_IN_PROGRESS: on_event.on_conversation_chat_in_progress,
@@ -188,7 +198,7 @@ def input_audio_buffer_complete(self) -> None:
         self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))

     def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
-        event_id = message.get("event_id") or ""
+        event_id = message.get("id") or ""
         detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
         event_type = message.get("event_type") or ""
         data = message.get("data") or {}
@@ -199,6 +209,14 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
                     "detail": detail,
                 }
             )
+        elif event_type == WebsocketsEventType.CHAT_UPDATED.value:
+            return ChatUpdatedEvent.model_validate(
+                {
+                    "id": event_id,
+                    "detail": detail,
+                    "data": ChatUpdateEvent.Data.model_validate(data),
+                }
+            )
         elif event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value:
             return InputAudioBufferCompletedEvent.model_validate(
                 {
@@ -299,6 +317,9 @@ class AsyncWebsocketsChatEventHandler(AsyncWebsocketsBaseEventHandler):
     async def on_chat_created(self, cli: "AsyncWebsocketsChatClient", event: ChatCreatedEvent):
         pass

+    async def on_chat_updated(self, cli: "AsyncWebsocketsChatClient", event: ChatUpdatedEvent):
+        pass
+
     async def on_input_audio_buffer_completed(
         self, cli: "AsyncWebsocketsChatClient", event: InputAudioBufferCompletedEvent
     ):
@@ -355,6 +376,7 @@ def __init__(
             on_event = on_event.to_dict(
                 {
                     WebsocketsEventType.CHAT_CREATED: on_event.on_chat_created,
+                    WebsocketsEventType.CHAT_UPDATED: on_event.on_chat_updated,
                     WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed,
                     WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created,
                     WebsocketsEventType.CONVERSATION_CHAT_IN_PROGRESS: on_event.on_conversation_chat_in_progress,
@@ -392,7 +414,7 @@ async def input_audio_buffer_complete(self) -> None:
         await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))

     def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
-        event_id = message.get("event_id") or ""
+        event_id = message.get("id") or ""
         detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
         event_type = message.get("event_type") or ""
         data = message.get("data") or {}
@@ -403,6 +425,14 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
"detail": detail, } ) + elif event_type == WebsocketsEventType.CHAT_UPDATED.value: + return ChatUpdatedEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": ChatUpdateEvent.Data.model_validate(data), + } + ) elif event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value: return InputAudioBufferCompletedEvent.model_validate( { diff --git a/cozepy/websockets/ws.py b/cozepy/websockets/ws.py index 1182f37..318383f 100644 --- a/cozepy/websockets/ws.py +++ b/cozepy/websockets/ws.py @@ -93,6 +93,7 @@ class WebsocketsEventType(str, Enum): CONVERSATION_CHAT_SUBMIT_TOOL_OUTPUTS = "conversation.chat.submit_tool_outputs" # send tool outputs to server # resp CHAT_CREATED = "chat.created" + CHAT_UPDATED = "chat.updated" # INPUT_AUDIO_BUFFER_COMPLETED = "input_audio_buffer.completed" # received `input_audio_buffer.complete` event CONVERSATION_CHAT_CREATED = "conversation.chat.created" # audio ast completed, chat started CONVERSATION_CHAT_IN_PROGRESS = "conversation.chat.in_progress" @@ -109,7 +110,7 @@ class Detail(BaseModel): logid: Optional[str] = None event_type: WebsocketsEventType - event_id: Optional[str] = None + id: Optional[str] = None detail: Optional[Detail] = None @@ -118,7 +119,7 @@ class WebsocketsErrorEvent(WebsocketsEvent): data: CozeAPIError -class InputAudio(CozeModel): +class InputAudio(BaseModel): format: Optional[str] codec: Optional[str] sample_rate: Optional[int] @@ -266,7 +267,7 @@ def _receive_loop(self) -> None: self._handle_error(e) def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: - event_id = message.get("event_id") or "" + event_id = message.get("id") or "" event_type = message.get("event_type") or "" detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) data = message.get("data") or {} @@ -466,7 +467,7 @@ async def _receive_loop(self) -> None: await self._handle_error(e) def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: - event_id = message.get("event_id") or "" + event_id = message.get("id") or "" event_type = message.get("event_type") or "" detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {}) data = message.get("data") or {} @@ -553,7 +554,15 @@ async def _close(self) -> None: async def _send_event(self, event: Optional[WebsocketsEvent] = None) -> None: if not event or not self._ws: return - log_debug("[%s] send event, type=%s", self._path, event.event_type.value) + if event.event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND: + log_debug( + "[%s] send event, type=%s, event=%s", + self._path, + event.event_type.value, + json.dumps(event._dump_without_delta()), # type: ignore + ) + else: + log_debug("[%s] send event, type=%s, event=%s", self._path, event.event_type.value, event.model_dump_json()) await self._ws.send(event.model_dump_json()) diff --git a/examples/websockets_chat_realtime_gui.py b/examples/websockets_chat_realtime_gui.py new file mode 100644 index 0000000..5f4ba3a --- /dev/null +++ b/examples/websockets_chat_realtime_gui.py @@ -0,0 +1,409 @@ +import asyncio +import json +import os +import queue +import threading +import time +import tkinter as tk +from tkinter import scrolledtext, ttk +from typing import Optional + +import pyaudio + +from cozepy import ( + COZE_CN_BASE_URL, + AsyncCoze, + AsyncWebsocketsChatClient, + AsyncWebsocketsChatEventHandler, + ChatUpdateEvent, + ConversationAudioDeltaEvent, + ConversationChatCompletedEvent, + InputAudioBufferAppendEvent, + TokenAuth, +) +from cozepy.websockets.ws import InputAudio +from 
+
+# Audio parameters
+CHUNK = 1024
+FORMAT = pyaudio.paInt16
+CHANNELS = 1
+RATE = 24000
+INPUT_BLOCK_TIME = 0.05  # 50ms per block
+
+setup_examples_logger()
+
+
+class ModernAudioChatGUI:
+    def __init__(self, root):
+        self.root = root
+        self.root.title("Voice Assistant")
+        self.root.geometry("600x800")  # Set the window size
+
+        # Configure theme styles
+        style = ttk.Style()
+        style.configure("Custom.TButton", padding=10, font=("Helvetica", 12))
+        style.configure("Custom.TLabel", font=("Helvetica", 11))
+
+        # Initialize PyAudio
+        self.p = pyaudio.PyAudio()
+        self.recording = False
+        self.stream: Optional[pyaudio.Stream] = None
+        self.audio_queue = queue.Queue()
+
+        # Audio playback queue
+        self.playback_queue = queue.Queue()
+        self.is_playing = False
+        self.playback_stream = None
+
+        # Build the GUI widgets
+        self.setup_gui()
+
+        # Initialize the Coze client
+        self.coze = AsyncCoze(
+            auth=TokenAuth(os.getenv("COZE_API_TOKEN")),
+            base_url=os.getenv("COZE_API_BASE", COZE_CN_BASE_URL),
+        )
+
+        # Create the asyncio event loop
+        self.loop = asyncio.new_event_loop()
+        self.chat_client: Optional[AsyncWebsocketsChatClient] = None
+
+        # Run the async event loop in a background thread
+        threading.Thread(target=self.run_async_loop, daemon=True).start()
+
+        # Handle window close
+        self.root.protocol("WM_DELETE_WINDOW", self.on_closing)
+
+        # Start the playback thread
+        threading.Thread(target=self.playback_loop, daemon=True).start()
+
+    def setup_gui(self):
+        # Main frame
+        self.main_frame = ttk.Frame(self.root, padding="20")
+        self.main_frame.pack(fill=tk.BOTH, expand=True)
+
+        # Chat history display area
+        self.chat_frame = ttk.Frame(self.main_frame)
+        self.chat_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 20))
+
+        self.chat_display = scrolledtext.ScrolledText(
+            self.chat_frame, wrap=tk.WORD, height=20, font=("Helvetica", 11), bg="#f5f5f5"
+        )
+        self.chat_display.pack(fill=tk.BOTH, expand=True)
+
+        # Status display area
+        self.status_frame = ttk.Frame(self.main_frame)
+        self.status_frame.pack(fill=tk.X, pady=(0, 20))
+
+        self.status_label = ttk.Label(self.status_frame, text="Ready", style="Custom.TLabel")
+        self.status_label.pack()
+
+        # Volume indicator
+        self.volume_bar = ttk.Progressbar(self.status_frame, mode="determinate", length=200)
+        self.volume_bar.pack(pady=10)
+
+        # Button control area
+        self.button_frame = ttk.Frame(self.main_frame)
+        self.button_frame.pack(fill=tk.X)
+
+        # Start-call button
+        self.start_button = ttk.Button(
+            self.button_frame, text="Start Call", command=self.start_chat, style="Custom.TButton"
+        )
+        self.start_button.pack(side=tk.LEFT, padx=5)
+
+        # Send button
+        self.send_button = ttk.Button(
+            self.button_frame, text="Send", command=self.send_audio, state=tk.DISABLED, style="Custom.TButton"
+        )
+        self.send_button.pack(side=tk.LEFT, padx=5)
+
+        # End button
+        self.end_button = ttk.Button(
+            self.button_frame, text="End", command=self.end_chat, state=tk.DISABLED, style="Custom.TButton"
+        )
+        self.end_button.pack(side=tk.LEFT, padx=5)
+
+    def update_chat_display(self, message: str, is_user: bool = True):
+        self.chat_display.insert(tk.END, f"{'You' if is_user else 'AI'}: {message}\n")
+        self.chat_display.see(tk.END)  # Auto-scroll to the bottom
+
+    def start_chat(self):
+        self.start_button.config(state=tk.DISABLED)
+        self.send_button.config(state=tk.NORMAL)
+        self.end_button.config(state=tk.NORMAL)
+
+        # Start recording
+        self.start_recording()
+        self.status_label.config(text="Recording...")
+        self.update_chat_display("Started a new conversation", is_user=False)
+
+    def end_chat(self):
+        # Stop recording
+        if self.recording:
+            self.stop_recording()
+
+        # Close the WebSocket connection
+        self.loop.call_soon_threadsafe(self.close_connection)
+
+        # Reset the UI
+        self.start_button.config(state=tk.NORMAL)
+        self.send_button.config(state=tk.DISABLED)
+        self.end_button.config(state=tk.DISABLED)
+        self.status_label.config(text="Ready")
+        self.update_chat_display("Conversation ended", is_user=False)
+
+    def close_connection(self):
+        async def close():
+            if self.chat_client:
+                await self.chat_client.close()
+                self.chat_client = None
+
+        asyncio.run_coroutine_threadsafe(close(), self.loop)
+
+    def start_recording(self):
+        try:
+            self.recording = True
+
+            # Compute the input buffer size
+            input_frames_per_block = int(RATE * INPUT_BLOCK_TIME)
+
+            # Open the audio stream
+            self.stream = self.p.open(
+                format=FORMAT,
+                channels=CHANNELS,
+                rate=RATE,
+                input=True,
+                frames_per_buffer=input_frames_per_block,
+                stream_callback=self.audio_callback,
+            )
+
+            # Start the WebSocket connection
+            self.loop.call_soon_threadsafe(self.start_websocket_connection)
+
+        except Exception as e:
+            print(f"Failed to start recording: {e}")
+            self.recording = False
+            self.status_label.config(text="Failed to start recording")
+            self.start_button.config(state=tk.NORMAL)
+            self.send_button.config(state=tk.DISABLED)
+            self.end_button.config(state=tk.DISABLED)
+
+    def audio_callback(self, in_data, frame_count, time_info, status):
+        if self.recording:
+            try:
+                self.audio_queue.put(in_data)
+
+                # Update the volume indicator
+                amplitude = max(
+                    abs(int.from_bytes(in_data[i : i + 2], "little", signed=True)) for i in range(0, len(in_data), 2)
+                )
+                volume = min(100, int(amplitude / 32768 * 100))
+                self.root.after(0, lambda v=volume: self.volume_bar.configure(value=v))
+
+            except Exception as e:
+                print(f"Recording callback error: {e}")
+
+        return (None, pyaudio.paContinue)
+
+    def stop_recording(self):
+        try:
+            self.recording = False
+            if self.stream is not None and self.stream.is_active():
+                self.stream.stop_stream()
+                self.stream.close()
+                self.stream = None
+        except Exception as e:
+            print(f"Failed to stop recording: {e}")
+        finally:
+            self.stream = None
+
+    def on_closing(self):
+        # Stop recording
+        self.stop_recording()
+
+        # Stop playback
+        self.is_playing = False
+        if self.playback_stream:
+            self.playback_stream.stop_stream()
+            self.playback_stream.close()
+
+        # Close the WebSocket connection
+        if self.chat_client:
+            self.loop.call_soon_threadsafe(self.close_connection)
+
+        # Terminate PyAudio
+        if self.p:
+            self.p.terminate()
+
+        # Destroy the window
+        self.root.destroy()
+
+    def run_async_loop(self):
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
+
+    def start_websocket_connection(self):
+        async def start():
+            class ChatEventHandler(AsyncWebsocketsChatEventHandler):
+                def __init__(self, gui):
+                    self.gui = gui
+                    self.is_first_audio = True
+                    self.temp_file = open("temp_response.pcm", "wb")
+
+                async def on_conversation_audio_delta(
+                    self, cli: AsyncWebsocketsChatClient, event: ConversationAudioDeltaEvent
+                ):
+                    try:
+                        audio_data = event.data.get_audio()
+                        if audio_data:
+                            # Write to the temporary file
+                            self.temp_file.write(audio_data)
+                            self.temp_file.flush()
+
+                            # Start playback on the first chunk of audio
+                            if self.is_first_audio:
+                                self.is_first_audio = False
+                                self.gui.start_streaming_playback()
+
+                            # Put the audio data on the playback queue
+                            self.gui.playback_queue.put(audio_data)
+                    except Exception as e:
+                        print(f"Failed to handle audio data: {e}")
+
+                async def on_conversation_chat_completed(
+                    self, cli: AsyncWebsocketsChatClient, event: ConversationChatCompletedEvent
+                ):
+                    try:
+                        # Close the temporary file
+                        self.temp_file.close()
+
+                        # Mark the end of playback
+                        self.gui.playback_queue.put(None)
+
+                        # Resume recording
+                        self.gui.root.after(1000, self.gui.resume_recording)
+                    except Exception as e:
+                        print(f"Failed to complete the chat: {e}")
+
+            kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}")
+            self.chat_client = self.coze.websockets.chat.create(
+                bot_id=os.getenv("COZE_BOT_ID"),
+                on_event=ChatEventHandler(self),
+                **kwargs,
+            )
+
+            async with self.chat_client() as client:
+                await client.chat_update(
+                    ChatUpdateEvent.Data.model_validate(
+                        {
+                            "input_audio": InputAudio.model_validate(
+                                {
+                                    "format": "pcm",
+                                    "sample_rate": RATE,
+                                    "channel": CHANNELS,
+                                    "bit_depth": 16,
+                                    "codec": "pcm",
+                                }
+                            ),
+                        }
+                    )
+                )
+                while self.chat_client:
+                    if not self.audio_queue.empty():
+                        audio_data = self.audio_queue.get()
+                        await client.input_audio_buffer_append(
+                            InputAudioBufferAppendEvent.Data.model_validate(
+                                {
+                                    "delta": audio_data,
+                                }
+                            )
+                        )
+                    await asyncio.sleep(0.1)
+
+        asyncio.run_coroutine_threadsafe(start(), self.loop)
+
+    def resume_recording(self):
+        # Restart recording
+        self.start_recording()
+        self.send_button.config(state=tk.NORMAL)
+        self.status_label.config(text="Recording...")
+
+    def complete_audio(self):
+        async def complete():
+            while not self.audio_queue.empty():
+                await asyncio.sleep(0.1)
+            if self.chat_client:
+                await self.chat_client.input_audio_buffer_complete()
+                await self.chat_client.wait()
+
+        asyncio.run_coroutine_threadsafe(complete(), self.loop)
+
+    def start_streaming_playback(self):
+        """Start streaming playback."""
+        self.status_label.config(text="Playing response...")
+        self.update_chat_display("Responding...", is_user=False)
+        self.is_playing = True
+
+    def playback_loop(self):
+        """Audio playback loop."""
+        while True:
+            try:
+                if self.is_playing:
+                    # Get audio data from the queue
+                    audio_data = self.playback_queue.get()
+
+                    # None marks the end of playback
+                    if audio_data is None:
+                        if self.playback_stream:
+                            self.playback_stream.stop_stream()
+                            self.playback_stream.close()
+                            self.playback_stream = None
+                        self.is_playing = False
+                        continue
+
+                    # Create the playback stream if it does not exist yet
+                    if not self.playback_stream:
+                        self.playback_stream = self.p.open(
+                            format=FORMAT, channels=CHANNELS, rate=RATE, output=True, frames_per_buffer=CHUNK
+                        )
+
+                    # Play the audio data
+                    self.playback_stream.write(audio_data)
+
+            except Exception as e:
+                print(f"Playback error: {e}")
+                self.is_playing = False
+                if self.playback_stream:
+                    try:
+                        self.playback_stream.stop_stream()
+                        self.playback_stream.close()
+                    except Exception:
+                        pass
+                    self.playback_stream = None
+
+            # Sleep briefly to avoid hogging the CPU
+            time.sleep(0.001)
+
+    def send_audio(self):
+        # Stop recording
+        self.stop_recording()
+
+        # Disable the send button
+        self.send_button.config(state=tk.DISABLED)
+        self.status_label.config(text="Sending...")
+        self.update_chat_display("Voice message sent", is_user=True)
+
+        # Send the completion event
+        self.loop.call_soon_threadsafe(self.complete_audio)
+
+
+def main():
+    root = tk.Tk()
+    ModernAudioChatGUI(root)
+    root.mainloop()
+
+
+if __name__ == "__main__":
+    main()