Skip to content

Commit

Permalink
feat: get and cancel chat v3
Browse files Browse the repository at this point in the history
  • Loading branch information
chyroc committed Sep 24, 2024
1 parent 325bc91 commit b7abe3b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 4 deletions.
50 changes: 50 additions & 0 deletions cozepy/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,53 @@ def chat_v3(
return self._requester.request("post", url, Chat, body=body, stream=stream)

return ChatIterator(self._requester.request("post", url, Chat, body=body, stream=stream))

def get_v3(
self,
*,
conversation_id: str,
chat_id: str,
) -> Chat:
"""
Create a conversation.
Conversation is an interaction between a bot and a user, including one or more messages.
"""
url = f"{self._base_url}/v3/chat/retrieve"
params = {
"conversation_id": conversation_id,
"chat_id": chat_id,
}
return self._requester.request("post", url, Chat, params=params)

def list_message_v3(
self,
*,
conversation_id: str,
chat_id: str,
) -> List[Message]:
"""
Create a conversation.
Conversation is an interaction between a bot and a user, including one or more messages.
"""
url = f"{self._base_url}/v3/chat/message/list"
params = {
"conversation_id": conversation_id,
"chat_id": chat_id,
}
return self._requester.request("post", url, List[Message], params=params)

def cancel_v3(
self,
*,
conversation_id: str,
chat_id: str,
) -> Chat:
"""
Call this API to cancel an ongoing chat.
"""
url = f"{self._base_url}/v3/chat/cancel"
params = {
"conversation_id": conversation_id,
"chat_id": chat_id,
}
return self._requester.request("post", url, Chat, params=params)
12 changes: 8 additions & 4 deletions cozepy/request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple, Optional, Iterator, Union
from typing import TYPE_CHECKING, Tuple, Optional, Union, List, get_origin, get_args

import requests
from requests import Response
Expand Down Expand Up @@ -32,12 +32,12 @@ def request(
self,
method: str,
url: str,
model: Type[T],
model: Union[Type[T], List[Type[T]]],
params: dict = None,
headers: dict = None,
body: dict = None,
stream: bool = False,
) -> Union[T, Iterator[bytes]]:
) -> Union[T, List[T]]:
"""
Send a request to the server.
"""
Expand All @@ -58,7 +58,11 @@ def request(
elif code is None and msg != "":
logid = r.headers.get("x-tt-logid")
raise Exception(f"{msg}, logid:{logid}")
return model.model_validate(data)
if get_origin(model) is list:
item_model = get_args(model)[0]
return [item_model.model_validate(item) for item in data]
else:
return model.model_validate(data)

async def arequest(self, method: str, path: str, **kwargs) -> dict:
"""
Expand Down
9 changes: 9 additions & 0 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from cozepy import TokenAuth, Coze, COZE_CN_BASE_URL, Message, ChatIterator, Event
from cozepy.auth import _random_hex
from cozepy.model import ChatStatus


def test_chat_v3_not_stream():
Expand All @@ -20,6 +21,14 @@ def test_chat_v3_not_stream():
assert chat is not None
assert chat.id != ""

while True:
chat = cli.chat.get_v3(conversation_id=chat.conversation_id, chat_id=chat.id)
if chat.status != ChatStatus.in_progress:
break
messages = cli.chat.list_message_v3(conversation_id=chat.conversation_id, chat_id=chat.id)
print(messages)
assert len(messages) > 0


def test_chat_v3_stream():
token = os.getenv("COZE_TOKEN").strip()
Expand Down

0 comments on commit b7abe3b

Please sign in to comment.