Skip to content

Commit

Permalink
feat: add message api (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
chyroc authored Sep 26, 2024
1 parent 11e5581 commit 78d5c54
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 10 deletions.
9 changes: 9 additions & 0 deletions cozepy/conversation/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester):
self._base_url = base_url
self._auth = auth
self._requester = requester
self._message = None

@property
def message(self):
if not self._message:
from .message import MessageClient

self._message = MessageClient(self._base_url, self._auth, self._requester)
return self._message

def create(self, *, messages: List[Message] = None, meta_data: Dict[str, str] = None) -> Conversation:
"""
Expand Down
128 changes: 128 additions & 0 deletions cozepy/conversation/v1/message/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import Dict, List

from cozepy import MessageRole, MessageContentType, Message
from cozepy.auth import Auth
from cozepy.model import CozeModel, LastIDPaged
from cozepy.request import Requester


class MessageClient(object):
"""
Message class.
"""

def __init__(self, base_url: str, auth: Auth, requester: Requester):
self._base_url = base_url
self._auth = auth
self._requester = requester

def create(
self,
*,
conversation_id: str,
role: MessageRole,
content: str,
content_type: MessageContentType,
meta_data: Dict[str, str] = None,
) -> Message:
url = f"{self._base_url}/v1/conversation/message/create"
params = {
"conversation_id": conversation_id,
}
body = {
"role": role,
"content": content,
"content_type": content_type,
"meta_data": meta_data,
}

return self._requester.request("post", url, Message, params=params, body=body)

def list(
self,
*,
conversation_id: str,
order: str = "desc",
chat_id: str = None,
before_id: str = None,
after_id: str = None,
limit: int = 50,
) -> LastIDPaged[Message]:
url = f"{self._base_url}/v1/conversation/message/list"
params = {
"conversation_id": conversation_id,
}
body = {
"order": order,
"chat_id": chat_id,
"before_id": before_id,
"after_id": after_id,
"limit": limit,
}

res = self._requester.request("post", url, self._PrivateListMessageResp, params=params, body=body)
return LastIDPaged(res.items, res.first_id, res.last_id, res.has_more)

def retrieve(
self,
*,
conversation_id: str,
message_id: str,
) -> Message:
url = f"{self._base_url}/v1/conversation/message/retrieve"
params = {
"conversation_id": conversation_id,
"message_id": message_id,
}

return self._requester.request("get", url, Message, params=params)

def update(
self,
*,
conversation_id: str,
message_id: str,
content: str = None,
content_type: MessageContentType = None,
meta_data: Dict[str, str] = None,
) -> Message:
url = f"{self._base_url}/v1/conversation/message/modify"
params = {
"conversation_id": conversation_id,
"message_id": message_id,
}
body = {
"content": content,
"content_type": content_type,
"meta_data": meta_data,
}

return self._requester.request("post", url, Message, params=params, body=body, data_field="message")

def delete(
self,
*,
conversation_id: str,
message_id: str,
content: str = None,
content_type: MessageContentType = None,
meta_data: Dict[str, str] = None,
) -> Message:
url = f"{self._base_url}/v1/conversation/message/delete"
params = {
"conversation_id": conversation_id,
"message_id": message_id,
}
body = {
"content": content,
"content_type": content_type,
"meta_data": meta_data,
}

return self._requester.request("post", url, Message, params=params, body=body)

class _PrivateListMessageResp(CozeModel):
first_id: str
last_id: str
has_more: bool
items: List[Message]
19 changes: 16 additions & 3 deletions cozepy/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from enum import Enum
from typing import TypeVar, Generic, List, Optional, Dict

from pydantic import BaseModel, ConfigDict
from typing import TypeVar, Generic, List, Optional, Dict

T = TypeVar("T", bound=BaseModel)

Expand Down Expand Up @@ -52,6 +51,18 @@ def __repr__(self):
)


class LastIDPaged(PagedBase[T]):
def __init__(self, items: List[T], first_id: str = "", last_id: str = "", has_more: bool = None):
has_more = has_more if has_more is not None else last_id != ""
super().__init__(items, has_more)
self.first_id = first_id
self.last_id = last_id
self.has_more = has_more

def __repr__(self):
return f"LastIDPaged(items={self.items}, first_id={self.first_id}, last_id={self.last_id}, has_more={self.has_more})"


class MessageRole(str, Enum):
# Indicates that the content of the message is sent by the user.
user = "user"
Expand Down Expand Up @@ -82,6 +93,8 @@ class MessageType(str, Enum):
# 多 answer 场景下,服务端会返回一个 verbose 包,对应的 content 为 JSON 格式,content.msg_type =generate_answer_finish 代表全部 answer 回复完成。不支持在请求中作为入参。
verbose = "verbose"

unknown = ""


class MessageContentType(str, Enum):
# Text.
Expand Down Expand Up @@ -126,7 +139,7 @@ class Message(CozeModel):
# The entity that sent this message.
role: MessageRole
# The type of message.
type: MessageType
type: MessageType = ""
# The content of the message. It supports various types of content, including plain text, multimodal (a mix of text, images, and files), message cards, and more.
# 消息的内容,支持纯文本、多模态(文本、图片、文件混合输入)、卡片等多种类型的内容。
content: str
Expand Down
27 changes: 20 additions & 7 deletions cozepy/request.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Tuple, Optional, Union, List, get_origin, get_args, Iterator

import requests
from requests import Response
from typing import TYPE_CHECKING, Tuple, Optional, Union, List, get_origin, get_args, Iterator

if TYPE_CHECKING:
from cozepy.auth import Auth
Expand Down Expand Up @@ -37,6 +36,7 @@ def request(
headers: dict = None,
body: dict = None,
stream: bool = False,
data_field: str = "data",
) -> Union[T, List[T], Iterator[bytes]]:
"""
Send a request to the server.
Expand All @@ -49,7 +49,7 @@ def request(
if stream:
return r.iter_lines()

code, msg, data = self.__parse_requests_code_msg(r)
code, msg, data = self.__parse_requests_code_msg(r, data_field)

if code is not None and code > 0:
# TODO: Exception 自定义类型
Expand All @@ -72,17 +72,30 @@ async def arequest(self, method: str, path: str, **kwargs) -> dict:
"""
pass

def __parse_requests_code_msg(self, r: Response) -> Tuple[Optional[int], str, Optional[T]]:
def __parse_requests_code_msg(
self, r: Response, data_field: str = "data"
) -> Tuple[Optional[int], str, Optional[T]]:
try:
json = r.json()
except:
r.raise_for_status()
return

if "code" in json and "msg" in json and int(json["code"]) > 0:
return int(json["code"]), json["msg"], json.get("data") or None
return int(json["code"]), json["msg"], json.get(data_field) or None
if "error_message" in json and json["error_message"] != "":
return None, json["error_message"], None
if "data" in json:
return 0, "", json["data"]
if data_field in json:
if "first_id" in json:
return (
0,
"",
{
"first_id": json["first_id"],
"has_more": json["has_more"],
"last_id": json["last_id"],
"items": json["data"],
},
)
return 0, "", json[data_field]
return 0, "", json
58 changes: 58 additions & 0 deletions tests/test_conversation_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import time
import unittest
from cozepy import Coze, COZE_CN_BASE_URL, Message
from tests.config import fixed_token_auth


@unittest.skip("not available in not cn")
def test_conversation_message():
cli = Coze(auth=fixed_token_auth, base_url=COZE_CN_BASE_URL)

# create conversation
conversation = cli.conversation.v1.create(
messages=[
Message.user_text_message("who are you?"),
Message.assistant_text_message("i am your friend bob."),
]
)
assert conversation is not None

# retrieve conversation
conversation_retrieve = cli.conversation.v1.retrieve(conversation_id=conversation.id)
assert conversation.id == conversation_retrieve.id

# create message
user_input = Message.user_text_message("nice to meet you.")
message = cli.conversation.v1.message.create(
conversation_id=conversation.id,
role=user_input.role,
content=user_input.content,
content_type=user_input.content_type,
)
assert message is not None
assert message.id != ""

time.sleep(1)

# retrieve message
message_retrieve = cli.conversation.v1.message.retrieve(conversation_id=conversation.id, message_id=message.id)
assert message_retrieve is not None
assert message.id == message_retrieve.id

# list message
message_list = cli.conversation.v1.message.list(conversation_id=conversation.id, message_id=message.id)
assert len(message_list) > 2

# update message
cli.conversation.v1.message.update(
conversation_id=conversation.id,
message_id=message.id,
content="wow, nice to meet you",
content_type=message.content_type,
)

# delete message
cli.conversation.v1.message.delete(
conversation_id=conversation.id,
message_id=message.id,
)

0 comments on commit 78d5c54

Please sign in to comment.