From bc868be9c4114d17f7d374a9c3d1706b1e9ef6ef Mon Sep 17 00:00:00 2001 From: chyroc Date: Sat, 28 Sep 2024 23:19:33 +0800 Subject: [PATCH] refactor: Merge Chat and Workflow Iterator into a generic Stream Iterator (#32) - Merged `ChatChatIterator` and `WorkflowIterator` into a single generic Stream Iterator - Unified iterator to support both Chat and Workflow module --- README.md | 4 +- cozepy/__init__.py | 6 +- cozepy/chat/__init__.py | 96 ++++++++++++------------------ cozepy/model.py | 42 ++++++++++++- cozepy/workflows/runs/__init__.py | 98 ++++++++++++------------------- tests/test_chat.py | 4 +- tests/test_workflow.py | 10 ++-- 7 files changed, 124 insertions(+), 136 deletions(-) diff --git a/README.md b/README.md index 8234257..9d82a1d 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ coze.files.retrieve(file_id=file.id) ### Workflows ```python -from cozepy import Coze, TokenAuth, WorkflowEventType, WorkflowEventIterator +from cozepy import Coze, TokenAuth, Stream, WorkflowEvent, WorkflowEventType coze = Coze(auth=TokenAuth("your_token")) @@ -197,7 +197,7 @@ result = coze.workflows.runs.create( # stream workflow run -def handle_workflow_iterator(iterator: WorkflowEventIterator): +def handle_workflow_iterator(iterator: Stream[WorkflowEvent]): for event in iterator: if event.event == WorkflowEventType.MESSAGE: print('got message', event.message) diff --git a/cozepy/__init__.py b/cozepy/__init__.py index 3c3b4dc..725494e 100644 --- a/cozepy/__init__.py +++ b/cozepy/__init__.py @@ -10,7 +10,6 @@ ) from .chat import ( Chat, - ChatChatIterator, ChatEvent, ChatEventType, ChatStatus, @@ -47,6 +46,7 @@ from .model import ( LastIDPaged, NumberPaged, + Stream, TokenPaged, ) from .request import HTTPClient @@ -56,7 +56,6 @@ WorkflowEventError, WorkflowEventInterrupt, WorkflowEventInterruptData, - WorkflowEventIterator, WorkflowEventMessage, WorkflowEventType, WorkflowRunResult, @@ -91,7 +90,6 @@ "Message", "Chat", "ChatEvent", - "ChatChatIterator", "ToolOutput", # conversations "Conversation", @@ -115,7 +113,6 @@ "WorkflowEventInterrupt", "WorkflowEventError", "WorkflowEvent", - "WorkflowEventIterator", # workspaces "WorkspaceRoleType", "WorkspaceType", @@ -137,6 +134,7 @@ "TokenPaged", "NumberPaged", "LastIDPaged", + "Stream", # request "HTTPClient", ] diff --git a/cozepy/chat/__init__.py b/cozepy/chat/__init__.py index 9d57534..1ccc1c1 100644 --- a/cozepy/chat/__init__.py +++ b/cozepy/chat/__init__.py @@ -1,9 +1,8 @@ from enum import Enum -from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union from cozepy.auth import Auth -from cozepy.exception import CozeEventError -from cozepy.model import CozeModel +from cozepy.model import CozeModel, Stream from cozepy.request import Requester if TYPE_CHECKING: @@ -232,57 +231,28 @@ class ChatEvent(CozeModel): message: Message = None -class ChatChatIterator(object): - def __init__(self, iters: Iterator[str]): - self._iters = iters - - def __iter__(self): - return self - - def __next__(self) -> ChatEvent: - event = "" - data = "" - line = "" - times = 0 - - while times < 2: - line = next(self._iters) - if line == "": - continue - elif line.startswith("event:"): - if event == "": - event = line[6:] - else: - raise CozeEventError("event", line) - elif line.startswith("data:"): - if data == "": - data = line[5:] - else: - raise CozeEventError("data", line) - else: - raise CozeEventError("", line) - - times += 1 - - if event == ChatEventType.DONE: - raise StopIteration - elif event == ChatEventType.ERROR: - raise Exception(f"error event: {line}") # TODO: error struct format - elif event in [ - ChatEventType.CONVERSATION_MESSAGE_DELTA, - ChatEventType.CONVERSATION_MESSAGE_COMPLETED, - ]: - return ChatEvent(event=event, message=Message.model_validate_json(data)) - elif event in [ - ChatEventType.CONVERSATION_CHAT_CREATED, - ChatEventType.CONVERSATION_CHAT_IN_PROGRESS, - ChatEventType.CONVERSATION_CHAT_COMPLETED, - ChatEventType.CONVERSATION_CHAT_FAILED, - ChatEventType.CONVERSATION_CHAT_REQUIRES_ACTION, - ]: - return ChatEvent(event=event, chat=Chat.model_validate_json(data)) - else: - raise ValueError(f"invalid chat.event: {event}, {data}") +def _chat_stream_handler(data: Dict) -> ChatEvent: + event = data["event"] + data = data["data"] + if event == ChatEventType.DONE: + raise StopIteration + elif event == ChatEventType.ERROR: + raise Exception(f"error event: {data}") # TODO: error struct format + elif event in [ + ChatEventType.CONVERSATION_MESSAGE_DELTA, + ChatEventType.CONVERSATION_MESSAGE_COMPLETED, + ]: + return ChatEvent(event=event, message=Message.model_validate_json(data)) + elif event in [ + ChatEventType.CONVERSATION_CHAT_CREATED, + ChatEventType.CONVERSATION_CHAT_IN_PROGRESS, + ChatEventType.CONVERSATION_CHAT_COMPLETED, + ChatEventType.CONVERSATION_CHAT_FAILED, + ChatEventType.CONVERSATION_CHAT_REQUIRES_ACTION, + ]: + return ChatEvent(event=event, chat=Chat.model_validate_json(data)) + else: + raise ValueError(f"invalid chat.event: {event}, {data}") class ToolOutput(CozeModel): @@ -350,7 +320,7 @@ def stream( auto_save_history: bool = True, meta_data: Dict[str, str] = None, conversation_id: str = None, - ) -> ChatChatIterator: + ) -> Stream[ChatEvent]: """ Call the Chat API with streaming to send messages to a published Coze bot. @@ -390,7 +360,7 @@ def _create( auto_save_history: bool = True, meta_data: Dict[str, str] = None, conversation_id: str = None, - ) -> Union[Chat, ChatChatIterator]: + ) -> Union[Chat, Stream[ChatEvent]]: """ Create a conversation. Conversation is an interaction between a bot and a user, including one or more messages. @@ -409,7 +379,11 @@ def _create( if not stream: return self._requester.request("post", url, Chat, body=body, stream=stream) - return ChatChatIterator(self._requester.request("post", url, Chat, body=body, stream=stream)) + return Stream( + self._requester.request("post", url, Chat, body=body, stream=stream), + fields=["event", "data"], + handler=_chat_stream_handler, + ) def retrieve( self, @@ -436,7 +410,7 @@ def retrieve( def submit_tool_outputs( self, *, conversation_id: str, chat_id: str, tool_outputs: List[ToolOutput], stream: bool - ) -> Union[Chat, ChatChatIterator]: + ) -> Union[Chat, Stream[ChatEvent]]: """ Call this API to submit the results of tool execution. @@ -466,7 +440,11 @@ def submit_tool_outputs( if not stream: return self._requester.request("post", url, Chat, params=params, body=body, stream=stream) - return ChatChatIterator(self._requester.request("post", url, Chat, params=params, body=body, stream=stream)) + return Stream( + self._requester.request("post", url, Chat, params=params, body=body, stream=stream), + fields=["event", "data"], + handler=_chat_stream_handler, + ) def cancel( self, diff --git a/cozepy/model.py b/cozepy/model.py index 6ebc93e..0fd713c 100644 --- a/cozepy/model.py +++ b/cozepy/model.py @@ -1,8 +1,10 @@ -from typing import Generic, List, TypeVar +from typing import Callable, Dict, Generic, Iterator, List, Tuple, TypeVar from pydantic import BaseModel, ConfigDict -T = TypeVar("T", bound=BaseModel) +from cozepy.exception import CozeEventError + +T = TypeVar("T") class CozeModel(BaseModel): @@ -67,3 +69,39 @@ def __init__( def __repr__(self): return f"LastIDPaged(items={self.items}, first_id={self.first_id}, last_id={self.last_id}, has_more={self.has_more})" + + +class Stream(Generic[T]): + def __init__(self, iters: Iterator[str], fields: List[str], handler: Callable[[Dict[str, str]], T]): + self._iters = iters + self._fields = fields + self._handler = handler + + def __iter__(self): + return self + + def __next__(self) -> T: + return self._handler(self._extra_event()) + + def _extra_event(self) -> Dict[str, str]: + data = dict(map(lambda x: (x, ""), self._fields)) + times = 0 + + while times < len(data): + line = next(self._iters) + if line == "": + continue + + field, value = self._extra_field_data(line, data) + data[field] = value + times += 1 + return data + + def _extra_field_data(self, line: str, data: Dict[str, str]) -> Tuple[str, str]: + for field in self._fields: + if line.startswith(field + ":"): + if data[field] == "": + return field, line[len(field) + 1 :].strip() + else: + raise CozeEventError(field, line) + raise CozeEventError("", line) diff --git a/cozepy/workflows/runs/__init__.py b/cozepy/workflows/runs/__init__.py index 8c43e63..dd0ba37 100644 --- a/cozepy/workflows/runs/__init__.py +++ b/cozepy/workflows/runs/__init__.py @@ -1,9 +1,8 @@ from enum import Enum -from typing import Any, Dict, Iterator +from typing import Any, Dict from cozepy.auth import Auth -from cozepy.exception import CozeEventError -from cozepy.model import CozeModel +from cozepy.model import CozeModel, Stream from cozepy.request import Requester @@ -100,61 +99,28 @@ class WorkflowEvent(CozeModel): error: WorkflowEventError = None -class WorkflowEventIterator(object): - def __init__(self, iters: Iterator[str]): - self._iters = iters - - def __iter__(self): - return self - - def __next__(self) -> WorkflowEvent: - id = "" - event = "" - data = "" - times = 0 - - while times < 3: - line = next(self._iters) - if line == "": - continue - elif line.startswith("id:"): - if event == "": - id = line[3:].strip() - else: - raise CozeEventError("id", line) - elif line.startswith("event:"): - if event == "": - event = line[6:].strip() - else: - raise CozeEventError("event", line) - elif line.startswith("data:"): - if data == "": - data = line[5:].strip() - else: - raise CozeEventError("data", line) - else: - raise CozeEventError("", line) - - times += 1 - - if event == WorkflowEventType.DONE: - raise StopIteration - elif event == WorkflowEventType.MESSAGE: - return WorkflowEvent( - id=id, - event=event, - message=WorkflowEventMessage.model_validate_json(data), - ) - elif event == WorkflowEventType.ERROR: - return WorkflowEvent(id=id, event=event, error=WorkflowEventError.model_validate_json(data)) - elif event == WorkflowEventType.INTERRUPT: - return WorkflowEvent( - id=id, - event=event, - interrupt=WorkflowEventInterrupt.model_validate_json(data), - ) - else: - raise ValueError(f"invalid workflows.event: {event}, {data}") +def _workflow_stream_handler(data: Dict[str, str]) -> WorkflowEvent: + id = data["id"] + event = data["event"] + data = data["data"] + if event == WorkflowEventType.DONE: + raise StopIteration + elif event == WorkflowEventType.MESSAGE: + return WorkflowEvent( + id=id, + event=event, + message=WorkflowEventMessage.model_validate_json(data), + ) + elif event == WorkflowEventType.ERROR: + return WorkflowEvent(id=id, event=event, error=WorkflowEventError.model_validate_json(data)) + elif event == WorkflowEventType.INTERRUPT: + return WorkflowEvent( + id=id, + event=event, + interrupt=WorkflowEventInterrupt.model_validate_json(data), + ) + else: + raise ValueError(f"invalid workflows.event: {event}, {data}") class WorkflowsClient(object): @@ -203,7 +169,7 @@ def stream( parameters: Dict[str, Any] = None, bot_id: str = None, ext: Dict[str, Any] = None, - ) -> WorkflowEventIterator: + ) -> Stream[WorkflowEvent]: """ Execute the published workflow with a streaming response method. @@ -225,7 +191,11 @@ def stream( "bot_id": bot_id, "ext": ext, } - return WorkflowEventIterator(self._requester.request("post", url, None, body=body, stream=True)) + return Stream( + self._requester.request("post", url, None, body=body, stream=True), + fields=["id", "event", "data"], + handler=_workflow_stream_handler, + ) def resume( self, @@ -234,7 +204,7 @@ def resume( event_id: str, resume_data: str, interrupt_type: int, - ) -> WorkflowEventIterator: + ) -> Stream[WorkflowEvent]: """ docs zh: https://www.coze.cn/docs/developer_guides/workflow_resume @@ -251,4 +221,8 @@ def resume( "resume_data": resume_data, "interrupt_type": interrupt_type, } - return WorkflowEventIterator(self._requester.request("post", url, None, body=body, stream=True)) + return Stream( + self._requester.request("post", url, None, body=body, stream=True), + fields=["id", "event", "data"], + handler=_workflow_stream_handler, + ) diff --git a/tests/test_chat.py b/tests/test_chat.py index c6b62e8..323e355 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,6 +1,6 @@ import os -from cozepy import COZE_CN_BASE_URL, ChatChatIterator, ChatEventType, Coze, Message +from cozepy import COZE_CN_BASE_URL, ChatEventType, Coze, Message from cozepy.auth import _random_hex from tests.config import fixed_token_auth @@ -32,7 +32,7 @@ def test_chat_stream(): cli = Coze(auth=fixed_token_auth, base_url=COZE_CN_BASE_URL) - chat_iter: ChatChatIterator = cli.chat.stream( + chat_iter = cli.chat.stream( bot_id=bot_id, user_id=_random_hex(10), additional_messages=[Message.user_text_message("Hi, how are you?")], diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 11fb239..b5bbd8f 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,6 +1,6 @@ import unittest -from cozepy import Coze, COZE_CN_BASE_URL, WorkflowEventType, WorkflowEventIterator +from cozepy import COZE_CN_BASE_URL, Coze, Stream, WorkflowEvent, WorkflowEventType from tests.config import fixed_token_auth @@ -30,13 +30,13 @@ def test_workflows(): }, ) - def handle_iter(iter: WorkflowEventIterator): + def handle_iter(iter: Stream[WorkflowEvent]): for item in iter: - if item.event == WorkflowEventType.message: + if item.event == WorkflowEventType.MESSAGE: print("msg", item.message) - elif item.event == WorkflowEventType.error: + elif item.event == WorkflowEventType.ERROR: print("error", item.error) - elif item.event == WorkflowEventType.interrupt: + elif item.event == WorkflowEventType.INTERRUPT: print("interrupt", item.interrupt) handle_iter( cli.workflows.runs.resume(