diff --git a/cozepy/__init__.py b/cozepy/__init__.py index 778e152..c8e3908 100644 --- a/cozepy/__init__.py +++ b/cozepy/__init__.py @@ -123,6 +123,7 @@ AsyncWebsocketsChatEventHandler, ChatUpdateEvent, ConversationAudioDeltaEvent, + ConversationAudioTranscriptCompletedEvent, ConversationChatCompletedEvent, ConversationChatCreatedEvent, ConversationChatRequiresActionEvent, @@ -258,6 +259,7 @@ "ConversationChatSubmitToolOutputsEvent", "ConversationChatCreatedEvent", "ConversationMessageDeltaEvent", + "ConversationAudioTranscriptCompletedEvent", "ConversationChatRequiresActionEvent", "ConversationAudioDeltaEvent", "ConversationChatCompletedEvent", diff --git a/cozepy/websockets/chat/__init__.py b/cozepy/websockets/chat/__init__.py index 2e0e73b..7bd5083 100644 --- a/cozepy/websockets/chat/__init__.py +++ b/cozepy/websockets/chat/__init__.py @@ -81,6 +81,15 @@ class ConversationMessageDeltaEvent(WebsocketsEvent): data: Message +# resp +class ConversationAudioTranscriptCompletedEvent(WebsocketsEvent): + class Data(BaseModel): + content: str + + event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_AUDIO_TRANSCRIPT_COMPLETED + data: Data + + # resp class ConversationMessageCompletedEvent(WebsocketsEvent): event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_MESSAGE_COMPLETED @@ -128,6 +137,11 @@ def on_conversation_chat_in_progress(self, cli: "WebsocketsChatClient", event: C def on_conversation_message_delta(self, cli: "WebsocketsChatClient", event: ConversationMessageDeltaEvent): pass + def on_conversation_audio_transcript_completed( + self, cli: "WebsocketsChatClient", event: ConversationAudioTranscriptCompletedEvent + ): + pass + def on_conversation_message_completed(self, cli: "WebsocketsChatClient", event: ConversationMessageCompletedEvent): pass @@ -165,6 +179,7 @@ def __init__( WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, WebsocketsEventType.CONVERSATION_CHAT_IN_PROGRESS: on_event.on_conversation_chat_in_progress, WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, + WebsocketsEventType.CONVERSATION_AUDIO_TRANSCRIPT_COMPLETED: on_event.on_conversation_audio_transcript_completed, WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, WebsocketsEventType.CONVERSATION_MESSAGE_COMPLETED: on_event.on_conversation_message_completed, WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, @@ -247,6 +262,14 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: "data": Message.model_validate(data), } ) + elif event_type == WebsocketsEventType.CONVERSATION_AUDIO_TRANSCRIPT_COMPLETED.value: + return ConversationAudioTranscriptCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": ConversationAudioTranscriptCompletedEvent.Data.model_validate(data), + } + ) elif event_type == WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION.value: return ConversationChatRequiresActionEvent.model_validate( { @@ -338,6 +361,11 @@ async def on_conversation_message_delta( ): pass + async def on_conversation_audio_transcript_completed( + self, cli: "AsyncWebsocketsChatClient", event: ConversationAudioTranscriptCompletedEvent + ): + pass + async def on_conversation_chat_requires_action( self, cli: "AsyncWebsocketsChatClient", event: ConversationChatRequiresActionEvent ): @@ -381,6 +409,7 @@ def __init__( WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created, WebsocketsEventType.CONVERSATION_CHAT_IN_PROGRESS: on_event.on_conversation_chat_in_progress, WebsocketsEventType.CONVERSATION_MESSAGE_DELTA: on_event.on_conversation_message_delta, + WebsocketsEventType.CONVERSATION_AUDIO_TRANSCRIPT_COMPLETED: on_event.on_conversation_audio_transcript_completed, WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION: on_event.on_conversation_chat_requires_action, WebsocketsEventType.CONVERSATION_MESSAGE_COMPLETED: on_event.on_conversation_message_completed, WebsocketsEventType.CONVERSATION_AUDIO_DELTA: on_event.on_conversation_audio_delta, @@ -463,6 +492,14 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: "data": Message.model_validate(data), } ) + elif event_type == WebsocketsEventType.CONVERSATION_AUDIO_TRANSCRIPT_COMPLETED.value: + return ConversationAudioTranscriptCompletedEvent.model_validate( + { + "id": event_id, + "detail": detail, + "data": ConversationAudioTranscriptCompletedEvent.Data.model_validate(data), + } + ) elif event_type == WebsocketsEventType.CONVERSATION_CHAT_REQUIRES_ACTION.value: return ConversationChatRequiresActionEvent.model_validate( { diff --git a/cozepy/websockets/ws.py b/cozepy/websockets/ws.py index 318383f..70e6c28 100644 --- a/cozepy/websockets/ws.py +++ b/cozepy/websockets/ws.py @@ -99,6 +99,7 @@ class WebsocketsEventType(str, Enum): CONVERSATION_CHAT_IN_PROGRESS = "conversation.chat.in_progress" CONVERSATION_MESSAGE_DELTA = "conversation.message.delta" # get agent text message update CONVERSATION_CHAT_REQUIRES_ACTION = "conversation.chat.requires_action" # need plugin submit + CONVERSATION_AUDIO_TRANSCRIPT_COMPLETED = "conversation.audio_transcript.completed" CONVERSATION_MESSAGE_COMPLETED = "conversation.message.completed" CONVERSATION_AUDIO_DELTA = "conversation.audio.delta" # get agent audio message update CONVERSATION_AUDIO_COMPLETED = "conversation.audio.completed" diff --git a/examples/benchmark_text_chat.py b/examples/benchmark_text_chat.py new file mode 100644 index 0000000..5cd9732 --- /dev/null +++ b/examples/benchmark_text_chat.py @@ -0,0 +1,103 @@ +import asyncio +import json +import logging +import os +import time +from typing import List, Optional + +from cozepy import ( + COZE_CN_BASE_URL, + ChatEventType, + Coze, + DeviceOAuthApp, + Message, + TokenAuth, + setup_logging, +) + + +def get_coze_api_base() -> str: + # The default access is api.coze.com, but if you need to access api.coze.cn, + # please use base_url to configure the api endpoint to access + coze_api_base = os.getenv("COZE_API_BASE") + if coze_api_base: + return coze_api_base + + return COZE_CN_BASE_URL # default + + +def get_coze_api_token(workspace_id: Optional[str] = None) -> str: + # Get an access_token through personal access token or oauth. + coze_api_token = os.getenv("COZE_API_TOKEN") + if coze_api_token: + return coze_api_token + + coze_api_base = get_coze_api_base() + + device_oauth_app = DeviceOAuthApp(client_id="57294420732781205987760324720643.app.coze", base_url=coze_api_base) + device_code = device_oauth_app.get_device_code(workspace_id) + print(f"Please Open: {device_code.verification_url} to get the access token") + return device_oauth_app.get_access_token(device_code=device_code.device_code, poll=True).access_token + + +def setup_examples_logger(): + coze_log = os.getenv("COZE_LOG") + if coze_log: + setup_logging(logging.getLevelNamesMapping().get(coze_log.upper(), logging.INFO)) + + +def get_current_time_ms(): + return int(time.time() * 1000) + + +setup_examples_logger() + +kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}") + + +def cal_latency(latency_list: List[int]) -> str: + if latency_list is None or len(latency_list) == 0: + return "0" + if len(latency_list) == 1: + return f"{latency_list[0]}" + res = latency_list.copy() + res.sort() + return "%2d" % ((sum(res[:-1]) * 1.0) / (len(res) - 1)) + + +async def test_latency(coze: Coze, bot_id: str, text: str) -> (str, str, int): + start = get_current_time_ms() + stream = coze.chat.stream( + bot_id=bot_id, + user_id="user id", + additional_messages=[ + Message.build_user_question_text(text), + ], + ) + for event in stream: + if event.event == ChatEventType.CONVERSATION_MESSAGE_DELTA: + return stream.response.logid, event.message.content, get_current_time_ms() - start + + +async def main(): + coze_api_token = get_coze_api_token() + coze_api_base = get_coze_api_base() + bot_id = os.getenv("COZE_BOT_ID") + text = os.getenv("COZE_TEXT") or "讲个笑话" + + # Initialize Coze client + coze = Coze( + auth=TokenAuth(coze_api_token), + base_url=coze_api_base, + ) + + times = 50 + text_latency = [] + for i in range(times): + logid, text, latency = await test_latency(coze, bot_id, text) + text_latency.append(latency) + print(f"[latency.text] {i}, latency: {cal_latency(text_latency)} ms, log: {logid}, text: {text}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/benchmark_websockets_chat.py b/examples/benchmark_websockets_chat.py new file mode 100644 index 0000000..2ad7b99 --- /dev/null +++ b/examples/benchmark_websockets_chat.py @@ -0,0 +1,184 @@ +import asyncio +import json +import logging +import os +import time +from typing import List, Optional + +from cozepy import ( + COZE_CN_BASE_URL, + AsyncCoze, + AsyncWebsocketsChatClient, + AsyncWebsocketsChatEventHandler, + AudioFormat, + ConversationAudioDeltaEvent, + ConversationAudioTranscriptCompletedEvent, + ConversationChatCreatedEvent, + ConversationMessageDeltaEvent, + DeviceOAuthApp, + InputAudioBufferAppendEvent, + TokenAuth, + WebsocketsEventType, + setup_logging, +) +from cozepy.log import log_info + + +def get_coze_api_base() -> str: + # The default access is api.coze.com, but if you need to access api.coze.cn, + # please use base_url to configure the api endpoint to access + coze_api_base = os.getenv("COZE_API_BASE") + if coze_api_base: + return coze_api_base + + return COZE_CN_BASE_URL # default + + +def get_coze_api_token(workspace_id: Optional[str] = None) -> str: + # Get an access_token through personal access token or oauth. + coze_api_token = os.getenv("COZE_API_TOKEN") + if coze_api_token: + return coze_api_token + + coze_api_base = get_coze_api_base() + + device_oauth_app = DeviceOAuthApp(client_id="57294420732781205987760324720643.app.coze", base_url=coze_api_base) + device_code = device_oauth_app.get_device_code(workspace_id) + print(f"Please Open: {device_code.verification_url} to get the access token") + return device_oauth_app.get_access_token(device_code=device_code.device_code, poll=True).access_token + + +def setup_examples_logger(): + coze_log = os.getenv("COZE_LOG") + if coze_log: + setup_logging(logging.getLevelNamesMapping().get(coze_log.upper(), logging.INFO)) + + +def get_current_time_ms(): + return int(time.time() * 1000) + + +setup_examples_logger() + +kwargs = json.loads(os.getenv("COZE_KWARGS") or "{}") + + +class AsyncWebsocketsChatEventHandlerSub(AsyncWebsocketsChatEventHandler): + """ + Class is not required, you can also use Dict to set callback + """ + + logid = "" + input_audio_buffer_completed_at = 0 + conversation_chat_created_at = 0 + conversation_audio_transcript_completed = 0 + text_first_token = 0 + audio_first_token = 0 + + async def on_error(self, cli: AsyncWebsocketsChatClient, e: Exception): + import traceback + + log_info(f"Error occurred: {str(e)}") + log_info(f"Stack trace:\n{traceback.format_exc()}") + + async def on_conversation_chat_created(self, cli: AsyncWebsocketsChatClient, event: ConversationChatCreatedEvent): + self.logid = event.detail.logid + self.conversation_chat_created_at = get_current_time_ms() + + async def on_conversation_audio_transcript_completed( + self, cli: AsyncWebsocketsChatClient, event: ConversationAudioTranscriptCompletedEvent + ): + self.conversation_audio_transcript_completed = get_current_time_ms() + + async def on_conversation_message_delta(self, cli: AsyncWebsocketsChatClient, event: ConversationMessageDeltaEvent): + if self.text_first_token == 0: + self.text_first_token = get_current_time_ms() + + async def on_conversation_audio_delta(self, cli: AsyncWebsocketsChatClient, event: ConversationAudioDeltaEvent): + if self.audio_first_token == 0: + self.audio_first_token = get_current_time_ms() + + +async def generate_audio(coze: AsyncCoze, text: str) -> List[bytes]: + voices = await coze.audio.voices.list(**kwargs) + content = await coze.audio.speech.create( + input=text, + voice_id=voices.items[0].voice_id, + response_format=AudioFormat.WAV, + sample_rate=24000, + **kwargs, + ) + content.write_to_file("test.wav") + return [data for data in content._raw_response.iter_bytes(chunk_size=1024)] + + +def cal_latency(latency_list: List[int]) -> str: + if latency_list is None or len(latency_list) == 0: + return "0" + if len(latency_list) == 1: + return f"{latency_list[0]}" + res = latency_list.copy() + res.sort() + return "%2d" % ((sum(res[:-1]) * 1.0) / (len(res) - 1)) + + +async def test_latency(coze: AsyncCoze, bot_id: str, audios: List[bytes]) -> AsyncWebsocketsChatEventHandlerSub: + handler = AsyncWebsocketsChatEventHandlerSub() + chat = coze.websockets.chat.create( + bot_id=bot_id, + on_event=handler, + **kwargs, + ) + + # Create and connect WebSocket client + async with chat() as client: + # Read and send audio data + for delta in audios: + await client.input_audio_buffer_append( + InputAudioBufferAppendEvent.Data.model_validate( + { + "delta": delta, + } + ) + ) + await asyncio.sleep(len(delta) * 1.0 / 24000 / 2) + + await client.input_audio_buffer_complete() + handler.input_audio_buffer_completed_at = int(time.time() * 1000) + await client.wait( + events=[WebsocketsEventType.CONVERSATION_MESSAGE_DELTA, WebsocketsEventType.CONVERSATION_AUDIO_DELTA] + ) + + return handler + + +async def main(): + coze_api_token = get_coze_api_token() + coze_api_base = get_coze_api_base() + bot_id = os.getenv("COZE_BOT_ID") + text = os.getenv("COZE_TEXT") or "讲个笑话" + + # Initialize Coze client + coze = AsyncCoze( + auth=TokenAuth(coze_api_token), + base_url=coze_api_base, + ) + # Initialize Audio + audios = await generate_audio(coze, text) + + times = 50 + text_latency = [] + audio_latency = [] + asr_latency = [] + for i in range(times): + handler = await test_latency(coze, bot_id, audios) + asr_latency.append(handler.conversation_audio_transcript_completed - handler.input_audio_buffer_completed_at) + text_latency.append(handler.text_first_token - handler.input_audio_buffer_completed_at) + audio_latency.append(handler.audio_first_token - handler.input_audio_buffer_completed_at) + print( + f"[latency.ws] {i}, asr: {cal_latency(asr_latency)}, text: {cal_latency(text_latency)} ms, audio: {cal_latency(audio_latency)} ms, log: {handler.logid}" + ) + + +if __name__ == "__main__": + asyncio.run(main())