From 78d5c54f8340a571ba2c80080fb0783fdadee015 Mon Sep 17 00:00:00 2001 From: chyroc Date: Thu, 26 Sep 2024 11:20:43 +0800 Subject: [PATCH] feat: add message api (#14) --- cozepy/conversation/v1/__init__.py | 9 ++ cozepy/conversation/v1/message/__init__.py | 128 +++++++++++++++++++++ cozepy/model.py | 19 ++- cozepy/request.py | 27 +++-- tests/test_conversation_message.py | 58 ++++++++++ 5 files changed, 231 insertions(+), 10 deletions(-) create mode 100644 cozepy/conversation/v1/message/__init__.py create mode 100644 tests/test_conversation_message.py diff --git a/cozepy/conversation/v1/__init__.py b/cozepy/conversation/v1/__init__.py index 95d1b58..1c884cc 100644 --- a/cozepy/conversation/v1/__init__.py +++ b/cozepy/conversation/v1/__init__.py @@ -16,6 +16,15 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._base_url = base_url self._auth = auth self._requester = requester + self._message = None + + @property + def message(self): + if not self._message: + from .message import MessageClient + + self._message = MessageClient(self._base_url, self._auth, self._requester) + return self._message def create(self, *, messages: List[Message] = None, meta_data: Dict[str, str] = None) -> Conversation: """ diff --git a/cozepy/conversation/v1/message/__init__.py b/cozepy/conversation/v1/message/__init__.py new file mode 100644 index 0000000..db5aab7 --- /dev/null +++ b/cozepy/conversation/v1/message/__init__.py @@ -0,0 +1,128 @@ +from typing import Dict, List + +from cozepy import MessageRole, MessageContentType, Message +from cozepy.auth import Auth +from cozepy.model import CozeModel, LastIDPaged +from cozepy.request import Requester + + +class MessageClient(object): + """ + Message class. + """ + + def __init__(self, base_url: str, auth: Auth, requester: Requester): + self._base_url = base_url + self._auth = auth + self._requester = requester + + def create( + self, + *, + conversation_id: str, + role: MessageRole, + content: str, + content_type: MessageContentType, + meta_data: Dict[str, str] = None, + ) -> Message: + url = f"{self._base_url}/v1/conversation/message/create" + params = { + "conversation_id": conversation_id, + } + body = { + "role": role, + "content": content, + "content_type": content_type, + "meta_data": meta_data, + } + + return self._requester.request("post", url, Message, params=params, body=body) + + def list( + self, + *, + conversation_id: str, + order: str = "desc", + chat_id: str = None, + before_id: str = None, + after_id: str = None, + limit: int = 50, + ) -> LastIDPaged[Message]: + url = f"{self._base_url}/v1/conversation/message/list" + params = { + "conversation_id": conversation_id, + } + body = { + "order": order, + "chat_id": chat_id, + "before_id": before_id, + "after_id": after_id, + "limit": limit, + } + + res = self._requester.request("post", url, self._PrivateListMessageResp, params=params, body=body) + return LastIDPaged(res.items, res.first_id, res.last_id, res.has_more) + + def retrieve( + self, + *, + conversation_id: str, + message_id: str, + ) -> Message: + url = f"{self._base_url}/v1/conversation/message/retrieve" + params = { + "conversation_id": conversation_id, + "message_id": message_id, + } + + return self._requester.request("get", url, Message, params=params) + + def update( + self, + *, + conversation_id: str, + message_id: str, + content: str = None, + content_type: MessageContentType = None, + meta_data: Dict[str, str] = None, + ) -> Message: + url = f"{self._base_url}/v1/conversation/message/modify" + params = { + "conversation_id": conversation_id, + "message_id": message_id, + } + body = { + "content": content, + "content_type": content_type, + "meta_data": meta_data, + } + + return self._requester.request("post", url, Message, params=params, body=body, data_field="message") + + def delete( + self, + *, + conversation_id: str, + message_id: str, + content: str = None, + content_type: MessageContentType = None, + meta_data: Dict[str, str] = None, + ) -> Message: + url = f"{self._base_url}/v1/conversation/message/delete" + params = { + "conversation_id": conversation_id, + "message_id": message_id, + } + body = { + "content": content, + "content_type": content_type, + "meta_data": meta_data, + } + + return self._requester.request("post", url, Message, params=params, body=body) + + class _PrivateListMessageResp(CozeModel): + first_id: str + last_id: str + has_more: bool + items: List[Message] diff --git a/cozepy/model.py b/cozepy/model.py index 509a6ec..d36e00c 100644 --- a/cozepy/model.py +++ b/cozepy/model.py @@ -1,7 +1,6 @@ from enum import Enum -from typing import TypeVar, Generic, List, Optional, Dict - from pydantic import BaseModel, ConfigDict +from typing import TypeVar, Generic, List, Optional, Dict T = TypeVar("T", bound=BaseModel) @@ -52,6 +51,18 @@ def __repr__(self): ) +class LastIDPaged(PagedBase[T]): + def __init__(self, items: List[T], first_id: str = "", last_id: str = "", has_more: bool = None): + has_more = has_more if has_more is not None else last_id != "" + super().__init__(items, has_more) + self.first_id = first_id + self.last_id = last_id + self.has_more = has_more + + def __repr__(self): + return f"LastIDPaged(items={self.items}, first_id={self.first_id}, last_id={self.last_id}, has_more={self.has_more})" + + class MessageRole(str, Enum): # Indicates that the content of the message is sent by the user. user = "user" @@ -82,6 +93,8 @@ class MessageType(str, Enum): # 多 answer 场景下,服务端会返回一个 verbose 包,对应的 content 为 JSON 格式,content.msg_type =generate_answer_finish 代表全部 answer 回复完成。不支持在请求中作为入参。 verbose = "verbose" + unknown = "" + class MessageContentType(str, Enum): # Text. @@ -126,7 +139,7 @@ class Message(CozeModel): # The entity that sent this message. role: MessageRole # The type of message. - type: MessageType + type: MessageType = "" # The content of the message. It supports various types of content, including plain text, multimodal (a mix of text, images, and files), message cards, and more. # 消息的内容,支持纯文本、多模态(文本、图片、文件混合输入)、卡片等多种类型的内容。 content: str diff --git a/cozepy/request.py b/cozepy/request.py index 88710f0..847a3a8 100644 --- a/cozepy/request.py +++ b/cozepy/request.py @@ -1,7 +1,6 @@ -from typing import TYPE_CHECKING, Tuple, Optional, Union, List, get_origin, get_args, Iterator - import requests from requests import Response +from typing import TYPE_CHECKING, Tuple, Optional, Union, List, get_origin, get_args, Iterator if TYPE_CHECKING: from cozepy.auth import Auth @@ -37,6 +36,7 @@ def request( headers: dict = None, body: dict = None, stream: bool = False, + data_field: str = "data", ) -> Union[T, List[T], Iterator[bytes]]: """ Send a request to the server. @@ -49,7 +49,7 @@ def request( if stream: return r.iter_lines() - code, msg, data = self.__parse_requests_code_msg(r) + code, msg, data = self.__parse_requests_code_msg(r, data_field) if code is not None and code > 0: # TODO: Exception 自定义类型 @@ -72,7 +72,9 @@ async def arequest(self, method: str, path: str, **kwargs) -> dict: """ pass - def __parse_requests_code_msg(self, r: Response) -> Tuple[Optional[int], str, Optional[T]]: + def __parse_requests_code_msg( + self, r: Response, data_field: str = "data" + ) -> Tuple[Optional[int], str, Optional[T]]: try: json = r.json() except: @@ -80,9 +82,20 @@ def __parse_requests_code_msg(self, r: Response) -> Tuple[Optional[int], str, Op return if "code" in json and "msg" in json and int(json["code"]) > 0: - return int(json["code"]), json["msg"], json.get("data") or None + return int(json["code"]), json["msg"], json.get(data_field) or None if "error_message" in json and json["error_message"] != "": return None, json["error_message"], None - if "data" in json: - return 0, "", json["data"] + if data_field in json: + if "first_id" in json: + return ( + 0, + "", + { + "first_id": json["first_id"], + "has_more": json["has_more"], + "last_id": json["last_id"], + "items": json["data"], + }, + ) + return 0, "", json[data_field] return 0, "", json diff --git a/tests/test_conversation_message.py b/tests/test_conversation_message.py new file mode 100644 index 0000000..a3d8f45 --- /dev/null +++ b/tests/test_conversation_message.py @@ -0,0 +1,58 @@ +import time +import unittest +from cozepy import Coze, COZE_CN_BASE_URL, Message +from tests.config import fixed_token_auth + + +@unittest.skip("not available in not cn") +def test_conversation_message(): + cli = Coze(auth=fixed_token_auth, base_url=COZE_CN_BASE_URL) + + # create conversation + conversation = cli.conversation.v1.create( + messages=[ + Message.user_text_message("who are you?"), + Message.assistant_text_message("i am your friend bob."), + ] + ) + assert conversation is not None + + # retrieve conversation + conversation_retrieve = cli.conversation.v1.retrieve(conversation_id=conversation.id) + assert conversation.id == conversation_retrieve.id + + # create message + user_input = Message.user_text_message("nice to meet you.") + message = cli.conversation.v1.message.create( + conversation_id=conversation.id, + role=user_input.role, + content=user_input.content, + content_type=user_input.content_type, + ) + assert message is not None + assert message.id != "" + + time.sleep(1) + + # retrieve message + message_retrieve = cli.conversation.v1.message.retrieve(conversation_id=conversation.id, message_id=message.id) + assert message_retrieve is not None + assert message.id == message_retrieve.id + + # list message + message_list = cli.conversation.v1.message.list(conversation_id=conversation.id, message_id=message.id) + assert len(message_list) > 2 + + # update message + cli.conversation.v1.message.update( + conversation_id=conversation.id, + message_id=message.id, + content="wow, nice to meet you", + content_type=message.content_type, + ) + + # delete message + cli.conversation.v1.message.delete( + conversation_id=conversation.id, + message_id=message.id, + )