diff --git a/cozepy/websockets/chat/__init__.py b/cozepy/websockets/chat/__init__.py
index 7bd5083..c045afe 100644
--- a/cozepy/websockets/chat/__init__.py
+++ b/cozepy/websockets/chat/__init__.py
@@ -6,7 +6,7 @@
 from cozepy.auth import Auth
 from cozepy.log import log_warning
 from cozepy.request import Requester
-from cozepy.util import remove_url_trailing_slash
+from cozepy.util import remove_none_values, remove_url_trailing_slash
 from cozepy.websockets.audio.transcriptions import (
     InputAudioBufferAppendEvent,
     InputAudioBufferCompletedEvent,
@@ -167,6 +167,7 @@ def __init__(
         auth: Auth,
         requester: Requester,
         bot_id: str,
         on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]],
+        workflow_id: Union[str, None] = None,
         **kwargs,
     ):
@@ -192,9 +193,12 @@ def __init__(
             auth=auth,
             requester=requester,
             path="v1/chat",
-            query={
-                "bot_id": bot_id,
-            },
+            query=remove_none_values(
+                {
+                    "bot_id": bot_id,
+                    "workflow_id": workflow_id,
+                }
+            ),
             on_event=on_event,  # type: ignore
             wait_events=[WebsocketsEventType.CONVERSATION_CHAT_COMPLETED],
             **kwargs,
@@ -323,6 +327,7 @@ def create(
         self,
         *,
         bot_id: str,
         on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]],
+        workflow_id: Union[str, None] = None,
         **kwargs,
     ) -> WebsocketsChatClient:
@@ -331,6 +336,7 @@ def create(
             auth=self._auth,
             requester=self._requester,
             bot_id=bot_id,
+            workflow_id=workflow_id,
             on_event=on_event,  # type: ignore
             **kwargs,
         )
@@ -397,6 +403,7 @@ def __init__(
         auth: Auth,
         requester: Requester,
         bot_id: str,
         on_event: Union[AsyncWebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]],
+        workflow_id: Union[str, None] = None,
         **kwargs,
     ):
@@ -422,9 +429,12 @@ def __init__(
             auth=auth,
             requester=requester,
             path="v1/chat",
-            query={
-                "bot_id": bot_id,
-            },
+            query=remove_none_values(
+                {
+                    "bot_id": bot_id,
+                    "workflow_id": workflow_id,
+                }
+            ),
             on_event=on_event,  # type: ignore
             wait_events=[WebsocketsEventType.CONVERSATION_CHAT_COMPLETED],
             **kwargs,
@@ -553,6 +563,7 @@ def create(
         self,
         *,
         bot_id: str,
         on_event: Union[AsyncWebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]],
+        workflow_id: Union[str, None] = None,
         **kwargs,
     ) -> AsyncWebsocketsChatClient:
@@ -561,6 +572,7 @@ def create(
             auth=self._auth,
             requester=self._requester,
             bot_id=bot_id,
+            workflow_id=workflow_id,
             on_event=on_event,  # type: ignore
             **kwargs,
         )
diff --git a/examples/benchmark_ark_text.py b/examples/benchmark_ark_text.py
new file mode 100644
index 0000000..946f345
--- /dev/null
+++ b/examples/benchmark_ark_text.py
@@ -0,0 +1,79 @@
+"""Benchmark first-token latency of an Ark chat-completions endpoint.
+
+Configuration is read from the environment: ARK_EP (endpoint/model id),
+ARK_TOKEN (API key) and, optionally, COZE_TEXT (the prompt to send).
+"""
+
+import asyncio
+import os
+import time
+from typing import List, Tuple
+
+
+def get_current_time_ms() -> int:
+    """Return the current wall-clock time in whole milliseconds."""
+    return int(time.time() * 1000)
+
+
+def cal_latency(latency_list: List[int]) -> str:
+    """Format the mean latency, leaving out the single slowest sample.
+
+    Returns "0" for an empty (or None) list and the lone value when only one
+    sample has been collected so far.
+    """
+    if not latency_list:
+        return "0"
+    if len(latency_list) == 1:
+        return f"{latency_list[0]}"
+    res = sorted(latency_list)
+    # Drop the largest sample so one outlier does not skew the mean.
+    return "%2d" % (sum(res[:-1]) / (len(res) - 1))
+
+
+def test_latency(ep: str, token: str, text: str) -> Tuple[str, str, int]:
+    """Send one streaming chat request and time the first content chunk.
+
+    Returns a ``(logid, content, latency_ms)`` tuple; ``logid`` is always
+    empty here. ``content`` is empty when the stream ends without any text.
+    """
+    from volcenginesdkarkruntime import Ark
+
+    client = Ark(base_url="https://ark.cn-beijing.volces.com/api/v3", api_key=token)
+    start = get_current_time_ms()
+    stream = client.chat.completions.create(
+        model=ep,
+        messages=[
+            {"role": "user", "content": text},
+        ],
+        stream=True,
+    )
+    for chunk in stream:
+        if not chunk.choices:
+            continue
+
+        if chunk.choices[0].delta.content:
+            return "", chunk.choices[0].delta.content, get_current_time_ms() - start
+
+    # The stream finished without text; still return a tuple so callers can unpack.
+    return "", "", get_current_time_ms() - start
+
+
+async def main():
+    """Run the benchmark and print the running mean latency per request."""
+    ep = os.getenv("ARK_EP")
+    token = os.getenv("ARK_TOKEN")
+    if not ep or not token:
+        raise SystemExit("ARK_EP and ARK_TOKEN must be set")
+    text = os.getenv("COZE_TEXT") or "讲个笑话"
+
+    times = 50
+    text_latency = []
+    for i in range(times):
+        # Keep the prompt in `text` untouched so every request sends the same input.
+        logid, reply, latency = test_latency(ep, token, text)
+        text_latency.append(latency)
+        print(f"[latency.ark.text] {i}, latency: {cal_latency(text_latency)} ms, log: {logid}, text: {reply}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/benchmark_websockets_chat.py b/examples/benchmark_websockets_chat.py
index 2ad7b99..daff190 100644
--- a/examples/benchmark_websockets_chat.py
+++ b/examples/benchmark_websockets_chat.py
@@ -108,7 +108,6 @@ async def generate_audio(coze: AsyncCoze, text: str) -> List[bytes]:
         sample_rate=24000,
         **kwargs,
     )
-    content.write_to_file("test.wav")
     return [data for data in content._raw_response.iter_bytes(chunk_size=1024)]
 
 