diff --git a/cozepy/__init__.py b/cozepy/__init__.py index 80e4246..182f4ea 100644 --- a/cozepy/__init__.py +++ b/cozepy/__init__.py @@ -1,4 +1,4 @@ -from .auth import ApplicationOAuth, Auth, TokenAuth +from .auth import ApplicationOAuth, Auth, TokenAuth, JWTAuth from .config import COZE_COM_BASE_URL, COZE_CN_BASE_URL from .coze import Coze from .model import ( @@ -16,6 +16,7 @@ "ApplicationOAuth", "Auth", "TokenAuth", + "JWTAuth", "COZE_COM_BASE_URL", "COZE_CN_BASE_URL", "Coze", diff --git a/cozepy/auth.py b/cozepy/auth.py index 9781f43..40ebc31 100644 --- a/cozepy/auth.py +++ b/cozepy/auth.py @@ -56,7 +56,7 @@ def __init__(self, client_id: str, client_secret: str = "", base_url: str = COZE self._token = "" self._requester = Requester() - def jwt_auth(self, private_key: str, kid: str, ttl: int): + def jwt_auth(self, private_key: str, kid: str, ttl: int) -> OAuthToken: """ Get the token by jwt with jwt auth flow. """ @@ -137,3 +137,40 @@ def token_type(self) -> str: @property def token(self) -> str: return self._token + + +class JWTAuth(Auth): + """ + The JWT auth flow. + """ + + def __init__(self, client_id: str, private_key: str, kid: str, ttl: int = 7200, base_url: str = COZE_COM_BASE_URL): + assert isinstance(client_id, str) + assert isinstance(private_key, str) + assert isinstance(kid, str) + assert isinstance(ttl, int) + assert ttl > 0 + assert isinstance(base_url, str) + + self._client_id = client_id + self._private_key = private_key + self._kid = kid + self._ttl = ttl + self._base_url = base_url + self._token = None + self._oauth_cli = ApplicationOAuth(self._client_id, base_url=self._base_url) + + @property + def token_type(self) -> str: + return "Bearer" + + @property + def token(self) -> str: + token = self._generate_token() + return token.access_token + + def _generate_token(self): + if self._token is not None and int(time.time()) < self._token.expires_in: + return self._token + self._token = self._oauth_cli.jwt_auth(self._private_key, self._kid, self._ttl) + return self._token diff --git a/cozepy/chat.py b/cozepy/chat.py deleted file mode 100644 index 4eef407..0000000 --- a/cozepy/chat.py +++ /dev/null @@ -1,190 +0,0 @@ -import json -from enum import Enum -from typing import Dict, List, Iterator, Union - -from .auth import Auth -from .model import Message, Chat, CozeModel -from .request import Requester - - -class Event(str, Enum): - # Event for creating a conversation, indicating the start of the conversation. - # 创建对话的事件,表示对话开始。 - conversation_chat_created = "conversation.chat.created" - - # The server is processing the conversation. - # 服务端正在处理对话。 - conversation_chat_in_progress = "conversation.chat.in_progress" - - # Incremental message, usually an incremental message when type=answer. - # 增量消息,通常是 type=answer 时的增量消息。 - conversation_message_delta = "conversation.message.delta" - - # The message has been completely replied to. At this point, the streaming package contains the spliced results of all message.delta, and each message is in a completed state. - # message 已回复完成。此时流式包中带有所有 message.delta 的拼接结果,且每个消息均为 completed 状态。 - conversation_message_completed = "conversation.message.completed" - - # The conversation is completed. - # 对话完成。 - conversation_chat_completed = "conversation.chat.completed" - - # This event is used to mark a failed conversation. - # 此事件用于标识对话失败。 - conversation_chat_failed = "conversation.chat.failed" - - # The conversation is interrupted and requires the user to report the execution results of the tool. - # 对话中断,需要使用方上报工具的执行结果。 - conversation_chat_requires_action = "conversation.chat.requires_action" - - # Error events during the streaming response process. For detailed explanations of code and msg, please refer to Error codes. - # 流式响应过程中的错误事件。关于 code 和 msg 的详细说明,可参考错误码。 - error = "error" - - # The streaming response for this session ended normally. - # 本次会话的流式返回正常结束。 - done = "done" - - -class ChatEvent(CozeModel): - event: Event - chat: Chat = None - message: Message = None - - -class ChatIterator(object): - def __init__(self, iters: Iterator[bytes]): - self._iters = iters - - def __iter__(self): - return self - - def __next__(self) -> ChatEvent: - event = "" - data = "" - line = "" - times = 0 - - while times < 2: - line = next(self._iters).decode("utf-8") - if line == "": - continue - elif line.startswith("event:"): - if event == "": - event = line[6:] - else: - raise Exception(f"invalid event: {line}") - elif line.startswith("data:"): - if data == "": - data = line[5:] - else: - raise Exception(f"invalid event: {line}") - else: - raise Exception(f"invalid event: {line}") - - times += 1 - - if event == Event.done: - raise StopIteration - elif event == Event.error: - raise Exception(f"error event: {line}") - elif event in [Event.conversation_message_delta, Event.conversation_message_completed]: - return ChatEvent(event=event, message=Message.model_validate(json.loads(data))) - elif event in [ - Event.conversation_chat_created, - Event.conversation_chat_in_progress, - Event.conversation_chat_completed, - Event.conversation_chat_failed, - Event.conversation_chat_requires_action, - ]: - return ChatEvent(event=event, chat=Chat.model_validate(json.loads(data))) - else: - raise Exception(f"unknown event: {line}") - - -class ChatClient(object): - def __init__(self, base_url: str, auth: Auth, requester: Requester): - self._base_url = base_url - self._auth = auth - self._requester = requester - - def chat_v3( - self, - *, - bot_id: str, - user_id: str, - additional_messages: List[Message] = None, - stream: bool = False, - custom_variables: Dict[str, str] = None, - auto_save_history: bool = True, - meta_data: Dict[str, str] = None, - conversation_id: str = None, - ) -> Union[Chat, ChatIterator]: - """ - Create a conversation. - Conversation is an interaction between a bot and a user, including one or more messages. - """ - url = f"{self._base_url}/v3/chat" - body = { - "bot_id": bot_id, - "user_id": user_id, - "additional_messages": [i.model_dump() for i in additional_messages] if additional_messages else [], - "stream": stream, - "custom_variables": custom_variables, - "auto_save_history": auto_save_history, - "conversation_id": conversation_id if conversation_id else None, - "meta_data": meta_data, - } - if not stream: - return self._requester.request("post", url, Chat, body=body, stream=stream) - - return ChatIterator(self._requester.request("post", url, Chat, body=body, stream=stream)) - - def get_v3( - self, - *, - conversation_id: str, - chat_id: str, - ) -> Chat: - """ - Create a conversation. - Conversation is an interaction between a bot and a user, including one or more messages. - """ - url = f"{self._base_url}/v3/chat/retrieve" - params = { - "conversation_id": conversation_id, - "chat_id": chat_id, - } - return self._requester.request("post", url, Chat, params=params) - - def list_message_v3( - self, - *, - conversation_id: str, - chat_id: str, - ) -> List[Message]: - """ - Create a conversation. - Conversation is an interaction between a bot and a user, including one or more messages. - """ - url = f"{self._base_url}/v3/chat/message/list" - params = { - "conversation_id": conversation_id, - "chat_id": chat_id, - } - return self._requester.request("post", url, List[Message], params=params) - - def cancel_v3( - self, - *, - conversation_id: str, - chat_id: str, - ) -> Chat: - """ - Call this API to cancel an ongoing chat. - """ - url = f"{self._base_url}/v3/chat/cancel" - params = { - "conversation_id": conversation_id, - "chat_id": chat_id, - } - return self._requester.request("post", url, Chat, params=params) diff --git a/cozepy/request.py b/cozepy/request.py index a047ae4..72901a2 100644 --- a/cozepy/request.py +++ b/cozepy/request.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Tuple, Optional, Union, List, get_origin, get_args +from typing import TYPE_CHECKING, Tuple, Optional, Union, List, get_origin, get_args, Iterator import requests from requests import Response @@ -37,7 +37,7 @@ def request( headers: dict = None, body: dict = None, stream: bool = False, - ) -> Union[T, List[T]]: + ) -> Union[T, List[T], Iterator[bytes]]: """ Send a request to the server. """ diff --git a/tests/config.py b/tests/config.py new file mode 100644 index 0000000..e8b6881 --- /dev/null +++ b/tests/config.py @@ -0,0 +1,20 @@ +import os + +from cozepy import ApplicationOAuth, COZE_CN_BASE_URL, JWTAuth, TokenAuth + +COZE_JWT_AUTH_CLIENT_ID = os.getenv("COZE_JWT_AUTH_CLIENT_ID").strip() +COZE_JWT_AUTH_PRIVATE_KEY = os.getenv("COZE_JWT_AUTH_PRIVATE_KEY").strip() +COZE_JWT_AUTH_KEY_ID = os.getenv("COZE_JWT_AUTH_KEY_ID").strip() + +COZE_TOKEN = os.getenv("COZE_TOKEN").strip() + +app_oauth = ApplicationOAuth( + COZE_JWT_AUTH_CLIENT_ID, + base_url=COZE_CN_BASE_URL, +) + +fixed_token_auth = TokenAuth(COZE_TOKEN) + +jwt_auth = JWTAuth( + COZE_JWT_AUTH_CLIENT_ID, COZE_JWT_AUTH_PRIVATE_KEY, COZE_JWT_AUTH_KEY_ID, 30, base_url=COZE_CN_BASE_URL +) diff --git a/tests/test_auth.py b/tests/test_auth.py index ec56c35..c6bf2f6 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,20 +1,15 @@ -import os import time -from cozepy import ApplicationOAuth, COZE_CN_BASE_URL +from tests.config import app_oauth, COZE_JWT_AUTH_KEY_ID, COZE_JWT_AUTH_PRIVATE_KEY, jwt_auth -def test_jwt_auth(): - client_id = os.getenv("COZE_JWT_AUTH_CLIENT_ID") - private_key = os.getenv("COZE_JWT_AUTH_PRIVATE_KEY") - key_id = os.getenv("COZE_JWT_AUTH_KEY_ID") - - app = ApplicationOAuth( - client_id, - base_url=COZE_CN_BASE_URL, - ) - token = app.jwt_auth(private_key, key_id, 30) +def test_jwt_app_oauth(): + token = app_oauth.jwt_auth(COZE_JWT_AUTH_PRIVATE_KEY, COZE_JWT_AUTH_KEY_ID, 30) assert token.access_token != "" assert token.token_type == "Bearer" assert token.expires_in - int(time.time()) <= 31 assert token.refresh_token == "" + + +def test_jwt_auth(): + assert jwt_auth.token != "" diff --git a/tests/test_bot.py b/tests/test_bot.py index f458816..54d6005 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -1,17 +1,45 @@ import os from unittest import TestCase -from cozepy import TokenAuth, Coze, COZE_CN_BASE_URL +from cozepy import Coze, COZE_CN_BASE_URL +from tests.config import fixed_token_auth, jwt_auth class TestBotClient(TestCase): - def test_list_published_bots_v1(self): + def test_bot_v1_list(self): space_id = os.getenv("SPACE_ID_1").strip() - token = os.getenv("COZE_TOKEN").strip() - auth = TokenAuth(token) - cli = Coze(auth=auth, base_url=COZE_CN_BASE_URL) + cli_list = [ + # fixed token + Coze(auth=fixed_token_auth, base_url=COZE_CN_BASE_URL), + # jwt auth + Coze(auth=jwt_auth, base_url=COZE_CN_BASE_URL), + ] + for cli in cli_list: + res = cli.bot.v1.list(space_id=space_id, page_size=2) + assert res.total > 1 + assert res.has_more + assert len(res.items) > 1 + + def test_bot_v1_get_online_info(self): + bot_id = self.bot_id + + cli_list = [ + # fixed token + Coze(auth=fixed_token_auth, base_url=COZE_CN_BASE_URL), + # jwt auth + Coze(auth=jwt_auth, base_url=COZE_CN_BASE_URL), + ] + for cli in cli_list: + bot = cli.bot.v1.get_online_info(bot_id=bot_id) + assert bot is not None + assert bot.bot_id == bot_id + + @property + def bot_id(self) -> str: + space_id = os.getenv("SPACE_ID_1").strip() + + # fixed token + cli = Coze(auth=fixed_token_auth, base_url=COZE_CN_BASE_URL) res = cli.bot.v1.list(space_id=space_id, page_size=2) - assert res.total > 1 - assert res.has_more - assert len(res.items) > 1 + return res.items[0].bot_id diff --git a/tests/test_chat.py b/tests/test_chat.py index febd5cd..edbced3 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,16 +1,15 @@ import os -from cozepy import TokenAuth, Coze, COZE_CN_BASE_URL, Message +from cozepy import Coze, COZE_CN_BASE_URL, Message from cozepy.auth import _random_hex from cozepy.chat.v3 import ChatIterator, Event +from tests.config import fixed_token_auth def test_chat_v3_not_stream(): - token = os.getenv("COZE_TOKEN").strip() bot_id = os.getenv("COZE_BOT_ID_TRANSLATE").strip() - auth = TokenAuth(token) - cli = Coze(auth=auth, base_url=COZE_CN_BASE_URL) + cli = Coze(auth=fixed_token_auth, base_url=COZE_CN_BASE_URL) chat = cli.chat.v3.create( bot_id=bot_id, @@ -31,11 +30,9 @@ def test_chat_v3_not_stream(): def test_chat_v3_stream(): - token = os.getenv("COZE_TOKEN").strip() bot_id = os.getenv("COZE_BOT_ID_TRANSLATE").strip() - auth = TokenAuth(token) - cli = Coze(auth=auth, base_url=COZE_CN_BASE_URL) + cli = Coze(auth=fixed_token_auth, base_url=COZE_CN_BASE_URL) chat_iter: ChatIterator = cli.chat.v3.create( bot_id=bot_id,