Skip to content

Commit

Permalink
refactor: Merge Chat and Workflow Iterator into a generic Stream Iter…
Browse files Browse the repository at this point in the history
…ator (#32)

- Merged `ChatChatIterator` and `WorkflowIterator` into a single generic
Stream Iterator
- Unified iterator to support both Chat and Workflow module
  • Loading branch information
chyroc authored Sep 28, 2024
1 parent 23a6785 commit bc868be
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 136 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ coze.files.retrieve(file_id=file.id)
### Workflows

```python
from cozepy import Coze, TokenAuth, WorkflowEventType, WorkflowEventIterator
from cozepy import Coze, TokenAuth, Stream, WorkflowEvent, WorkflowEventType

coze = Coze(auth=TokenAuth("your_token"))

Expand All @@ -197,7 +197,7 @@ result = coze.workflows.runs.create(


# stream workflow run
def handle_workflow_iterator(iterator: WorkflowEventIterator):
def handle_workflow_iterator(iterator: Stream[WorkflowEvent]):
for event in iterator:
if event.event == WorkflowEventType.MESSAGE:
print('got message', event.message)
Expand Down
6 changes: 2 additions & 4 deletions cozepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from .chat import (
Chat,
ChatChatIterator,
ChatEvent,
ChatEventType,
ChatStatus,
Expand Down Expand Up @@ -47,6 +46,7 @@
from .model import (
LastIDPaged,
NumberPaged,
Stream,
TokenPaged,
)
from .request import HTTPClient
Expand All @@ -56,7 +56,6 @@
WorkflowEventError,
WorkflowEventInterrupt,
WorkflowEventInterruptData,
WorkflowEventIterator,
WorkflowEventMessage,
WorkflowEventType,
WorkflowRunResult,
Expand Down Expand Up @@ -91,7 +90,6 @@
"Message",
"Chat",
"ChatEvent",
"ChatChatIterator",
"ToolOutput",
# conversations
"Conversation",
Expand All @@ -115,7 +113,6 @@
"WorkflowEventInterrupt",
"WorkflowEventError",
"WorkflowEvent",
"WorkflowEventIterator",
# workspaces
"WorkspaceRoleType",
"WorkspaceType",
Expand All @@ -137,6 +134,7 @@
"TokenPaged",
"NumberPaged",
"LastIDPaged",
"Stream",
# request
"HTTPClient",
]
96 changes: 37 additions & 59 deletions cozepy/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from enum import Enum
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from cozepy.auth import Auth
from cozepy.exception import CozeEventError
from cozepy.model import CozeModel
from cozepy.model import CozeModel, Stream
from cozepy.request import Requester

if TYPE_CHECKING:
Expand Down Expand Up @@ -232,57 +231,28 @@ class ChatEvent(CozeModel):
message: Message = None


class ChatChatIterator(object):
def __init__(self, iters: Iterator[str]):
self._iters = iters

def __iter__(self):
return self

def __next__(self) -> ChatEvent:
event = ""
data = ""
line = ""
times = 0

while times < 2:
line = next(self._iters)
if line == "":
continue
elif line.startswith("event:"):
if event == "":
event = line[6:]
else:
raise CozeEventError("event", line)
elif line.startswith("data:"):
if data == "":
data = line[5:]
else:
raise CozeEventError("data", line)
else:
raise CozeEventError("", line)

times += 1

if event == ChatEventType.DONE:
raise StopIteration
elif event == ChatEventType.ERROR:
raise Exception(f"error event: {line}") # TODO: error struct format
elif event in [
ChatEventType.CONVERSATION_MESSAGE_DELTA,
ChatEventType.CONVERSATION_MESSAGE_COMPLETED,
]:
return ChatEvent(event=event, message=Message.model_validate_json(data))
elif event in [
ChatEventType.CONVERSATION_CHAT_CREATED,
ChatEventType.CONVERSATION_CHAT_IN_PROGRESS,
ChatEventType.CONVERSATION_CHAT_COMPLETED,
ChatEventType.CONVERSATION_CHAT_FAILED,
ChatEventType.CONVERSATION_CHAT_REQUIRES_ACTION,
]:
return ChatEvent(event=event, chat=Chat.model_validate_json(data))
else:
raise ValueError(f"invalid chat.event: {event}, {data}")
def _chat_stream_handler(data: Dict) -> ChatEvent:
event = data["event"]
data = data["data"]
if event == ChatEventType.DONE:
raise StopIteration
elif event == ChatEventType.ERROR:
raise Exception(f"error event: {data}") # TODO: error struct format
elif event in [
ChatEventType.CONVERSATION_MESSAGE_DELTA,
ChatEventType.CONVERSATION_MESSAGE_COMPLETED,
]:
return ChatEvent(event=event, message=Message.model_validate_json(data))
elif event in [
ChatEventType.CONVERSATION_CHAT_CREATED,
ChatEventType.CONVERSATION_CHAT_IN_PROGRESS,
ChatEventType.CONVERSATION_CHAT_COMPLETED,
ChatEventType.CONVERSATION_CHAT_FAILED,
ChatEventType.CONVERSATION_CHAT_REQUIRES_ACTION,
]:
return ChatEvent(event=event, chat=Chat.model_validate_json(data))
else:
raise ValueError(f"invalid chat.event: {event}, {data}")


class ToolOutput(CozeModel):
Expand Down Expand Up @@ -350,7 +320,7 @@ def stream(
auto_save_history: bool = True,
meta_data: Dict[str, str] = None,
conversation_id: str = None,
) -> ChatChatIterator:
) -> Stream[ChatEvent]:
"""
Call the Chat API with streaming to send messages to a published Coze bot.
Expand Down Expand Up @@ -390,7 +360,7 @@ def _create(
auto_save_history: bool = True,
meta_data: Dict[str, str] = None,
conversation_id: str = None,
) -> Union[Chat, ChatChatIterator]:
) -> Union[Chat, Stream[ChatEvent]]:
"""
Create a conversation.
Conversation is an interaction between a bot and a user, including one or more messages.
Expand All @@ -409,7 +379,11 @@ def _create(
if not stream:
return self._requester.request("post", url, Chat, body=body, stream=stream)

return ChatChatIterator(self._requester.request("post", url, Chat, body=body, stream=stream))
return Stream(
self._requester.request("post", url, Chat, body=body, stream=stream),
fields=["event", "data"],
handler=_chat_stream_handler,
)

def retrieve(
self,
Expand All @@ -436,7 +410,7 @@ def retrieve(

def submit_tool_outputs(
self, *, conversation_id: str, chat_id: str, tool_outputs: List[ToolOutput], stream: bool
) -> Union[Chat, ChatChatIterator]:
) -> Union[Chat, Stream[ChatEvent]]:
"""
Call this API to submit the results of tool execution.
Expand Down Expand Up @@ -466,7 +440,11 @@ def submit_tool_outputs(
if not stream:
return self._requester.request("post", url, Chat, params=params, body=body, stream=stream)

return ChatChatIterator(self._requester.request("post", url, Chat, params=params, body=body, stream=stream))
return Stream(
self._requester.request("post", url, Chat, params=params, body=body, stream=stream),
fields=["event", "data"],
handler=_chat_stream_handler,
)

def cancel(
self,
Expand Down
42 changes: 40 additions & 2 deletions cozepy/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Generic, List, TypeVar
from typing import Callable, Dict, Generic, Iterator, List, Tuple, TypeVar

from pydantic import BaseModel, ConfigDict

T = TypeVar("T", bound=BaseModel)
from cozepy.exception import CozeEventError

T = TypeVar("T")


class CozeModel(BaseModel):
Expand Down Expand Up @@ -67,3 +69,39 @@ def __init__(

def __repr__(self):
return f"LastIDPaged(items={self.items}, first_id={self.first_id}, last_id={self.last_id}, has_more={self.has_more})"


class Stream(Generic[T]):
def __init__(self, iters: Iterator[str], fields: List[str], handler: Callable[[Dict[str, str]], T]):
self._iters = iters
self._fields = fields
self._handler = handler

def __iter__(self):
return self

def __next__(self) -> T:
return self._handler(self._extra_event())

def _extra_event(self) -> Dict[str, str]:
data = dict(map(lambda x: (x, ""), self._fields))
times = 0

while times < len(data):
line = next(self._iters)
if line == "":
continue

field, value = self._extra_field_data(line, data)
data[field] = value
times += 1
return data

def _extra_field_data(self, line: str, data: Dict[str, str]) -> Tuple[str, str]:
for field in self._fields:
if line.startswith(field + ":"):
if data[field] == "":
return field, line[len(field) + 1 :].strip()
else:
raise CozeEventError(field, line)
raise CozeEventError("", line)
Loading

0 comments on commit bc868be

Please sign in to comment.