From 8986e686cffdeb8457191f61f3432702d0a663e7 Mon Sep 17 00:00:00 2001 From: admin Date: Mon, 3 Mar 2025 18:14:38 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9B=E5=BB=BAAsyncJWTAuth=E7=B1=BB?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=BC=82=E6=AD=A5=E6=9B=B4=E6=96=B0=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E7=A0=81=20=E4=BF=AE=E5=A4=8D=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E7=9A=84=E9=89=B4=E6=9D=83=E8=B0=83=E7=94=A8=EF=BC=8C=E5=8C=85?= =?UTF-8?q?=E6=8B=ACRequester=E5=92=8C=E4=B8=8A=E5=B1=82=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cozepy/audio/voices/__init__.py | 6 +- cozepy/auth/__init__.py | 190 +++++-- cozepy/bots/__init__.py | 6 +- cozepy/conversations/__init__.py | 6 +- cozepy/conversations/message/__init__.py | 6 +- cozepy/coze.py | 4 +- cozepy/datasets/__init__.py | 118 +++-- cozepy/datasets/documents/__init__.py | 6 +- cozepy/datasets/images/__init__.py | 6 +- cozepy/knowledge/documents/__init__.py | 6 +- cozepy/model.py | 14 +- cozepy/request.py | 464 ++++++++++-------- cozepy/websockets/__init__.py | 3 +- cozepy/websockets/audio/__init__.py | 4 +- cozepy/websockets/audio/speech/__init__.py | 42 +- .../audio/transcriptions/__init__.py | 58 +-- cozepy/websockets/chat/__init__.py | 68 +-- cozepy/websockets/ws.py | 52 +- cozepy/workspaces/__init__.py | 6 +- tests/test_auth.py | 160 +++--- 20 files changed, 688 insertions(+), 537 deletions(-) diff --git a/cozepy/audio/voices/__init__.py b/cozepy/audio/voices/__init__.py index dcfb65b..9f25d24 100644 --- a/cozepy/audio/voices/__init__.py +++ b/cozepy/audio/voices/__init__.py @@ -140,7 +140,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "page_size": i_page_size, }, cast=_PrivateListVoiceData, - is_async=False, stream=False, ) @@ -225,8 +224,8 @@ async def list( url = f"{self._base_url}/v1/audio/voices" headers: Optional[dict] = kwargs.get("headers") - def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: - return self._requester.make_request( + async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: + return await self._requester.amake_request( "GET", url, params={ @@ -236,7 +235,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: }, headers=headers, cast=_PrivateListVoiceData, - is_async=False, stream=False, ) diff --git a/cozepy/auth/__init__.py b/cozepy/auth/__init__.py index f93617d..48060fb 100644 --- a/cozepy/auth/__init__.py +++ b/cozepy/auth/__init__.py @@ -77,12 +77,12 @@ def __init__(self, client_id: str, base_url: str, www_base_url: str): self._requester = Requester() def _get_oauth_url( - self, - redirect_uri: str, - code_challenge: Optional[str] = None, - code_challenge_method: Optional[str] = None, - state: str = "", - workspace_id: Optional[str] = None, + self, + redirect_uri: str, + code_challenge: Optional[str] = None, + code_challenge_method: Optional[str] = None, + state: str = "", + workspace_id: Optional[str] = None, ): params = { "response_type": "code", @@ -163,10 +163,10 @@ def __init__(self, client_id: str, client_secret: str, base_url: str = COZE_COM_ super().__init__(client_id, base_url, www_base_url=www_base_url) def get_oauth_url( - self, - redirect_uri: str, - state: str = "", - workspace_id: Optional[str] = None, + self, + redirect_uri: str, + state: str = "", + workspace_id: Optional[str] = None, ): """ Get the pkce flow authorized url. @@ -185,9 +185,9 @@ def get_oauth_url( ) def get_access_token( - self, - redirect_uri: str, - code: str, + self, + redirect_uri: str, + code: str, ) -> OAuthToken: """ Get the token by jwt with jwt auth flow. @@ -225,10 +225,10 @@ def __init__(self, client_id: str, client_secret: str, base_url: str = COZE_COM_ super().__init__(client_id, base_url, www_base_url=www_base_url) def get_oauth_url( - self, - redirect_uri: str, - state: str = "", - workspace_id: Optional[str] = None, + self, + redirect_uri: str, + state: str = "", + workspace_id: Optional[str] = None, ): """ Get the pkce flow authorized url. @@ -247,9 +247,9 @@ def get_oauth_url( ) async def get_access_token( - self, - redirect_uri: str, - code: str, + self, + redirect_uri: str, + code: str, ) -> OAuthToken: """ Get the token by jwt with jwt auth flow. @@ -289,7 +289,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur super().__init__(client_id, base_url, www_base_url="") def get_access_token( - self, ttl: int = 900, scope: Optional[Scope] = None, session_name: Optional[str] = None + self, ttl: int = 900, scope: Optional[Scope] = None, session_name: Optional[str] = None ) -> OAuthToken: """ Get the token by jwt with jwt auth flow. @@ -332,7 +332,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur super().__init__(client_id, base_url, www_base_url="") async def get_access_token( - self, ttl: int, scope: Optional[Scope] = None, session_name: Optional[str] = None + self, ttl: int, scope: Optional[Scope] = None, session_name: Optional[str] = None ) -> OAuthToken: """ Get the token by jwt with jwt auth flow. @@ -366,12 +366,12 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u ) def get_oauth_url( - self, - redirect_uri: str, - code_verifier: str, - code_challenge_method: Literal["plain", "S256"] = "plain", - state: str = "", - workspace_id: Optional[str] = None, + self, + redirect_uri: str, + code_verifier: str, + code_challenge_method: Literal["plain", "S256"] = "plain", + state: str = "", + workspace_id: Optional[str] = None, ): """ Get the pkce flow authorized url. @@ -433,12 +433,12 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u ) def get_oauth_url( - self, - redirect_uri: str, - code_verifier: str, - code_challenge_method: Literal["plain", "S256"] = "plain", - state: str = "", - workspace_id: Optional[str] = None, + self, + redirect_uri: str, + code_verifier: str, + code_challenge_method: Literal["plain", "S256"] = "plain", + state: str = "", + workspace_id: Optional[str] = None, ): """ Get the pkce flow authorized url. @@ -500,8 +500,8 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u ) def get_device_code( - self, - workspace_id: Optional[str] = None, + self, + workspace_id: Optional[str] = None, ) -> DeviceAuthCode: """ Get the pkce flow authorized url. @@ -680,7 +680,6 @@ class Auth(abc.ABC): It provides the abstract methods for getting the token type and token. """ - @property @abc.abstractmethod def token_type(self) -> str: """ @@ -691,7 +690,6 @@ def token_type(self) -> str: :return: token type """ - @property @abc.abstractmethod def token(self) -> str: """ @@ -707,7 +705,7 @@ def authentication(self, headers: dict) -> None: :param headers: http headers :return: None """ - headers["Authorization"] = f"{self.token_type} {self.token}" + headers["Authorization"] = f"{self.token_type()} {self.token()}" class TokenAuth(Auth): @@ -720,11 +718,9 @@ def __init__(self, token: str): assert len(token) > 0 self._token = token - @property def token_type(self) -> str: return "Bearer" - @property def token(self) -> str: return self._token @@ -735,13 +731,13 @@ class JWTAuth(Auth): """ def __init__( - self, - client_id: Optional[str] = None, - private_key: Optional[str] = None, - public_key_id: Optional[str] = None, - ttl: int = 7200, - base_url: str = COZE_COM_BASE_URL, - oauth_app: Optional[JWTOAuthApp] = None, + self, + client_id: Optional[str] = None, + private_key: Optional[str] = None, + public_key_id: Optional[str] = None, + ttl: int = 7200, + base_url: str = COZE_COM_BASE_URL, + oauth_app: Optional[JWTOAuthApp] = None, ): assert ttl > 0 self._ttl = ttl @@ -759,11 +755,9 @@ def __init__( client_id, private_key, public_key_id, base_url=remove_url_trailing_slash(base_url) ) - @property def token_type(self) -> str: return "Bearer" - @property def token(self) -> str: token = self._generate_token() return token.access_token @@ -773,3 +767,99 @@ def _generate_token(self): return self._token self._token = self._oauth_cli.get_access_token(self._ttl) return self._token + + +class AsyncAuth(abc.ABC): + """ + This class is the base class for all authorization types. + + It provides the abstract methods for getting the token type and token. + """ + + @abc.abstractmethod + async def token_type(self) -> str: + """ + The authorization type used in the http request header. + + eg: Bearer, Basic, etc. + + :return: token type + """ + + @abc.abstractmethod + async def token(self) -> str: + """ + The token used in the http request header. + + :return: token + """ + + async def authentication(self, headers: dict) -> None: + """ + Construct the authorization header in the http headers. + + :param headers: http headers + :return: None + """ + headers["Authorization"] = f"{await self.token_type()} {await self.token()}" + + +class AsyncTokenAuth(AsyncAuth): + """ + The fixed access token auth flow. + """ + + def __init__(self, token: str): + assert isinstance(token, str) + assert len(token) > 0 + self._token = token + + async def token_type(self) -> str: + return "Bearer" + + async def token(self) -> str: + return self._token + + +class AsyncJWTAuth(AsyncAuth): + """ + The JWT auth flow. + """ + + def __init__( + self, + client_id: Optional[str] = None, + private_key: Optional[str] = None, + public_key_id: Optional[str] = None, + ttl: int = 7200, + base_url: str = COZE_COM_BASE_URL, + oauth_app: Optional[AsyncJWTOAuthApp] = None, + ): + assert ttl > 0 + self._ttl = ttl + self._token = None + + if oauth_app: + self._oauth_cli = oauth_app + else: + assert isinstance(client_id, str) + assert isinstance(private_key, str) + assert isinstance(public_key_id, str) + assert isinstance(ttl, int) + assert isinstance(base_url, str) + self._oauth_cli = AsyncJWTOAuthApp( + client_id, private_key, public_key_id, base_url=remove_url_trailing_slash(base_url) + ) + + async def token_type(self) -> str: + return "Bearer" + + async def token(self) -> str: + token = await self._generate_token() + return token.access_token + + async def _generate_token(self): + if self._token is not None and int(time.time()) < self._token.expires_in: + return self._token + self._token = await self._oauth_cli.get_access_token(self._ttl) + return self._token diff --git a/cozepy/bots/__init__.py b/cozepy/bots/__init__.py index aa1ef34..180b699 100644 --- a/cozepy/bots/__init__.py +++ b/cozepy/bots/__init__.py @@ -268,7 +268,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "page_index": i_page_num, }, cast=_PrivateListBotsData, - is_async=False, stream=False, ) @@ -413,8 +412,8 @@ async def list(self, *, space_id: str, page_num: int = 1, page_size: int = 20) - """ url = f"{self._base_url}/v1/space/published_bots_list" - def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: - return self._requester.make_request( + async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: + return await self._requester.amake_request( "GET", url, params={ @@ -423,7 +422,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "page_index": i_page_num, }, cast=_PrivateListBotsData, - is_async=False, stream=False, ) diff --git a/cozepy/conversations/__init__.py b/cozepy/conversations/__init__.py index 7f2376b..710daad 100644 --- a/cozepy/conversations/__init__.py +++ b/cozepy/conversations/__init__.py @@ -89,7 +89,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "page_size": i_page_size, }, cast=_PrivateListConversationResp, - is_async=False, stream=False, ) @@ -174,8 +173,8 @@ async def list( ): url = f"{self._base_url}/v1/conversations" - def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: - return self._requester.make_request( + async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: + return await self._requester.amake_request( "GET", url, params={ @@ -184,7 +183,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "page_size": i_page_size, }, cast=_PrivateListConversationResp, - is_async=False, stream=False, ) diff --git a/cozepy/conversations/message/__init__.py b/cozepy/conversations/message/__init__.py index e45cb0b..36b42d4 100644 --- a/cozepy/conversations/message/__init__.py +++ b/cozepy/conversations/message/__init__.py @@ -115,7 +115,6 @@ def request_maker(i_before_id: str, i_after_id: str) -> HTTPRequest: }, params=params, cast=_PrivateListMessageResp, - is_async=False, stream=False, ) @@ -287,8 +286,8 @@ async def list( "conversation_id": conversation_id, } - def request_maker(i_before_id: str, i_after_id: str) -> HTTPRequest: - return self._requester.make_request( + async def request_maker(i_before_id: str, i_after_id: str) -> HTTPRequest: + return await self._requester.amake_request( "POST", url, json={ @@ -300,7 +299,6 @@ def request_maker(i_before_id: str, i_after_id: str) -> HTTPRequest: }, params=params, cast=_PrivateListMessageResp, - is_async=False, stream=False, ) diff --git a/cozepy/coze.py b/cozepy/coze.py index 28f1a24..333b5b2 100644 --- a/cozepy/coze.py +++ b/cozepy/coze.py @@ -1,7 +1,7 @@ import warnings from typing import TYPE_CHECKING, Optional -from cozepy.auth import Auth +from cozepy.auth import Auth, AsyncAuth from cozepy.config import COZE_COM_BASE_URL from cozepy.request import AsyncHTTPClient, Requester, SyncHTTPClient from cozepy.util import remove_url_trailing_slash @@ -143,7 +143,7 @@ def users(self) -> "UsersClient": class AsyncCoze(object): def __init__( self, - auth: Auth, + auth: AsyncAuth, base_url: str = COZE_COM_BASE_URL, http_client: Optional[AsyncHTTPClient] = None, ): diff --git a/cozepy/datasets/__init__.py b/cozepy/datasets/__init__.py index 59bd09b..20a8cd1 100644 --- a/cozepy/datasets/__init__.py +++ b/cozepy/datasets/__init__.py @@ -112,13 +112,13 @@ def images(self) -> "DatasetsImagesClient": return self._images def create( - self, - *, - name: str, - space_id: str, - format_type: DocumentFormatType, - description: Optional[str] = None, - icon_file_id: Optional[str] = None, + self, + *, + name: str, + space_id: str, + format_type: DocumentFormatType, + description: Optional[str] = None, + icon_file_id: Optional[str] = None, ) -> CreateDatasetResp: """ Create Dataset @@ -149,14 +149,14 @@ def create( ) def list( - self, - *, - space_id: str, - name: Optional[str] = None, - format_type: Optional[DocumentFormatType] = None, - page_num: int = 1, - page_size: int = 10, - **kwargs, + self, + *, + space_id: str, + name: Optional[str] = None, + format_type: Optional[DocumentFormatType] = None, + page_num: int = 1, + page_size: int = 10, + **kwargs, ) -> NumberPaged[Dataset]: """ List Datasets @@ -187,7 +187,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "page_num": i_page_num, }, cast=_PrivateListDatasetsData, - is_async=False, stream=False, ) @@ -199,12 +198,12 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: ) def update( - self, - *, - dataset_id: str, - name: str, - description: Optional[str] = None, - icon_file_id: Optional[str] = None, + self, + *, + dataset_id: str, + name: str, + description: Optional[str] = None, + icon_file_id: Optional[str] = None, ) -> UpdateDatasetRes: """ Update Dataset @@ -233,9 +232,9 @@ def update( ) def delete( - self, - *, - dataset_id: str, + self, + *, + dataset_id: str, ) -> DeleteDatasetRes: """ Delete Dataset @@ -256,10 +255,10 @@ def delete( ) def process( - self, - *, - dataset_id: str, - document_ids: List[str], + self, + *, + dataset_id: str, + document_ids: List[str], ) -> ListResponse[DocumentProgress]: """ Check the upload progress @@ -316,13 +315,13 @@ def images(self) -> "AsyncDatasetsImagesClient": return self._images async def create( - self, - *, - name: str, - space_id: str, - format_type: DocumentFormatType, - description: Optional[str] = None, - icon_file_id: Optional[str] = None, + self, + *, + name: str, + space_id: str, + format_type: DocumentFormatType, + description: Optional[str] = None, + icon_file_id: Optional[str] = None, ) -> CreateDatasetResp: """ Create Dataset @@ -350,14 +349,14 @@ async def create( ) async def list( - self, - *, - space_id: str, - name: Optional[str] = None, - format_type: Optional[DocumentFormatType] = None, - page_num: int = 1, - page_size: int = 10, - **kwargs, + self, + *, + space_id: str, + name: Optional[str] = None, + format_type: Optional[DocumentFormatType] = None, + page_num: int = 1, + page_size: int = 10, + **kwargs, ) -> AsyncNumberPaged[Dataset]: """ List Datasets @@ -375,8 +374,8 @@ async def list( url = f"{self._base_url}/v1/datasets" headers: Optional[dict] = kwargs.get("headers") - def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: - return self._requester.make_request( + async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: + return await self._requester.amake_request( "GET", url, headers=headers, @@ -388,7 +387,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "page_num": i_page_num, }, cast=_PrivateListDatasetsData, - is_async=False, stream=False, ) @@ -400,12 +398,12 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: ) async def update( - self, - *, - dataset_id: str, - name: str, - description: Optional[str] = None, - icon_file_id: Optional[str] = None, + self, + *, + dataset_id: str, + name: str, + description: Optional[str] = None, + icon_file_id: Optional[str] = None, ) -> UpdateDatasetRes: """ Update Dataset @@ -434,9 +432,9 @@ async def update( ) async def delete( - self, - *, - dataset_id: str, + self, + *, + dataset_id: str, ) -> DeleteDatasetRes: """ Delete Dataset @@ -457,10 +455,10 @@ async def delete( ) async def process( - self, - *, - dataset_id: str, - document_ids: List[str], + self, + *, + dataset_id: str, + document_ids: List[str], ) -> ListResponse[DocumentProgress]: """ Check the upload progress diff --git a/cozepy/datasets/documents/__init__.py b/cozepy/datasets/documents/__init__.py index b489f26..703c369 100644 --- a/cozepy/datasets/documents/__init__.py +++ b/cozepy/datasets/documents/__init__.py @@ -422,7 +422,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "size": i_page_size, }, cast=_PrivateListDocumentsData, - is_async=False, stream=False, ) @@ -560,8 +559,8 @@ async def list( url = f"{self._base_url}/open_api/knowledge/document/list" headers = {"Agw-Js-Conv": "str"} - def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: - return self._requester.make_request( + async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: + return await self._requester.amake_request( "POST", url, headers=headers, @@ -571,7 +570,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "size": i_page_size, }, cast=_PrivateListDocumentsData, - is_async=False, stream=False, ) diff --git a/cozepy/datasets/images/__init__.py b/cozepy/datasets/images/__init__.py index 2b256c8..94f9cbe 100644 --- a/cozepy/datasets/images/__init__.py +++ b/cozepy/datasets/images/__init__.py @@ -117,7 +117,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "has_caption": has_caption, }, cast=_PrivateListPhotosData, - is_async=False, stream=False, ) @@ -188,8 +187,8 @@ async def list( url = f"{self._base_url}/v1/datasets/{dataset_id}/images" headers: Optional[dict] = kwargs.get("headers") - def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: - return self._requester.make_request( + async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: + return await self._requester.amake_request( "get", url, headers=headers, @@ -200,7 +199,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "has_caption": has_caption, }, cast=_PrivateListPhotosData, - is_async=False, stream=False, ) diff --git a/cozepy/knowledge/documents/__init__.py b/cozepy/knowledge/documents/__init__.py index 816fab3..7a1fe14 100644 --- a/cozepy/knowledge/documents/__init__.py +++ b/cozepy/knowledge/documents/__init__.py @@ -200,7 +200,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "size": i_page_size, }, cast=_PrivateListDocumentsData, - is_async=False, stream=False, ) @@ -369,8 +368,8 @@ async def list( url = f"{self._base_url}/open_api/knowledge/document/list" headers = {"Agw-Js-Conv": "str"} - def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: - return self._requester.make_request( + async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: + return await self._requester.amake_request( "POST", url, headers=headers, @@ -380,7 +379,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "size": i_page_size, }, cast=_PrivateListDocumentsData, - is_async=False, stream=False, ) diff --git a/cozepy/model.py b/cozepy/model.py index 82b03f5..1284c91 100644 --- a/cozepy/model.py +++ b/cozepy/model.py @@ -14,7 +14,7 @@ TypeVar, Union, cast, - overload, + overload, Coroutine, ) import httpx @@ -310,7 +310,7 @@ def __init__( page_num: int, page_size: int, requestor: "Requester", - request_maker: Callable[[int, int], HTTPRequest], + request_maker: Callable[[int, int], Coroutine[None, None, HTTPRequest]], ): self.page_num = page_num self.page_size = page_size @@ -367,7 +367,7 @@ async def _fetch_page(self): """ if self._total is not None: return - request = self._request_maker(self.page_num, self.page_size) + request = await self._request_maker(self.page_num, self.page_size) res: NumberPagedResponse[T] = await self._requestor.asend(request) self._total = res.get_total() self._has_more = res.get_has_more() @@ -389,7 +389,7 @@ async def build( page_num: int, page_size: int, requestor: "Requester", - request_maker: Callable[[int, int], HTTPRequest], + request_maker: Callable[[int, int], Coroutine[None, None, HTTPRequest]], ) -> "AsyncNumberPaged[T]": page: AsyncNumberPaged[T] = AsyncNumberPaged( page_num=page_num, @@ -498,7 +498,7 @@ def __init__( before_id: str, after_id: str, requestor: "Requester", - request_maker: Callable[[str, str], HTTPRequest], + request_maker: Callable[[str, str], Coroutine[None, None, HTTPRequest]], ): self.before_id = before_id self.after_id = after_id @@ -551,7 +551,7 @@ async def build( before_id: str, after_id: str, requestor: "Requester", - request_maker: Callable[[str, str], HTTPRequest], + request_maker: Callable[[str, str], Coroutine[None, None, HTTPRequest]], ) -> "AsyncLastIDPaged[T]": page: AsyncLastIDPaged = AsyncLastIDPaged( before_id=before_id, @@ -566,7 +566,7 @@ async def _fetch_page(self): if self.last_id is not None or self._has_more is not None: return - request = self._request_maker(self.before_id, self.after_id) + request = await self._request_maker(self.before_id, self.after_id) res: LastIDPagedResponse[T] = await self._requestor.asend(request) self.first_id = res.get_first_id() diff --git a/cozepy/request.py b/cozepy/request.py index 0165582..114f605 100644 --- a/cozepy/request.py +++ b/cozepy/request.py @@ -28,7 +28,7 @@ from cozepy.version import coze_client_user_agent, user_agent if TYPE_CHECKING: - from cozepy.auth import Auth + from cozepy.auth import Auth, AsyncAuth T = TypeVar("T", bound=BaseModel) @@ -55,27 +55,28 @@ class Requester(object): """ def __init__( - self, - auth: Optional["Auth"] = None, - sync_client: Optional[SyncHTTPClient] = None, - async_client: Optional[AsyncHTTPClient] = None, + self, + auth: Optional["Auth"] = None, + async_auth: Optional["AsyncAuth"] = None, + sync_client: Optional[SyncHTTPClient] = None, + async_client: Optional[AsyncHTTPClient] = None, ): self._auth = auth + self._async_auth = async_auth self._sync_client = sync_client self._async_client = async_client def make_request( - self, - method: str, - url: str, - params: Optional[dict] = None, - headers: Optional[dict] = None, - json: Optional[dict] = None, - files: Optional[dict] = None, - cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None] = None, - data_field: str = "data", - is_async: Optional[bool] = None, - stream: bool = False, + self, + method: str, + url: str, + params: Optional[dict] = None, + headers: Optional[dict] = None, + json: Optional[dict] = None, + files: Optional[dict] = None, + cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None] = None, + data_field: str = "data", + stream: bool = False, ) -> HTTPRequest: if headers is None: headers = {} @@ -91,7 +92,7 @@ def make_request( params, json, stream, - is_async, + False, ) return HTTPRequest( @@ -101,7 +102,49 @@ def make_request( headers=headers, json_body=json, files=files, - is_async=is_async, + is_async=False, + stream=stream, + data_field=data_field, + cast=cast, + ) + + async def amake_request( + self, + method: str, + url: str, + params: Optional[dict] = None, + headers: Optional[dict] = None, + json: Optional[dict] = None, + files: Optional[dict] = None, + cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None] = None, + data_field: str = "data", + stream: bool = False, + ) -> HTTPRequest: + if headers is None: + headers = {} + headers["User-Agent"] = user_agent() + headers["X-Coze-Client-User-Agent"] = coze_client_user_agent() + if self._async_auth: + await self._async_auth.authentication(headers) + + log_debug( + "request %s#%s sending, params=%s, json=%s, stream=%s, async=%s", + method, + url, + params, + json, + stream, + True, + ) + + return HTTPRequest( + method=method, + url=url, + params=params, + headers=headers, + json_body=json, + files=files, + is_async=True, stream=stream, data_field=data_field, cast=cast, @@ -109,99 +152,105 @@ def make_request( @overload def request( - self, - method: str, - url: str, - stream: Literal[False], - cast: Type[T], - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> T: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: Type[T], + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> T: + ... @overload def request( - self, - method: str, - url: str, - stream: Literal[False], - cast: List[Type[T]], - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> List[T]: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: List[Type[T]], + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> List[T]: + ... @overload def request( - self, - method: str, - url: str, - stream: Literal[False], - cast: Type[ListResponse[T]], - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> ListResponse[T]: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: Type[ListResponse[T]], + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> ListResponse[T]: + ... @overload def request( - self, - method: str, - url: str, - stream: Literal[False], - cast: Type[FileHTTPResponse], - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> FileHTTPResponse: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: Type[FileHTTPResponse], + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> FileHTTPResponse: + ... @overload def request( - self, - method: str, - url: str, - stream: Literal[True], - cast: None, - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> IteratorHTTPResponse[str]: ... + self, + method: str, + url: str, + stream: Literal[True], + cast: None, + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> IteratorHTTPResponse[str]: + ... @overload def request( - self, - method: str, - url: str, - stream: Literal[False], - cast: None, - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> None: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: None, + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> None: + ... def request( - self, - method: str, - url: str, - stream: Literal[True, False], - cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], - params: Optional[dict] = None, - headers: Optional[dict] = None, - body: Optional[dict] = None, - files: Optional[dict] = None, - data_field: str = "data", + self, + method: str, + url: str, + stream: Literal[True, False], + cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], + params: Optional[dict] = None, + headers: Optional[dict] = None, + body: Optional[dict] = None, + files: Optional[dict] = None, + data_field: str = "data", ) -> Union[T, List[T], ListResponse[T], IteratorHTTPResponse[str], FileHTTPResponse, None]: """ Send a request to the server. @@ -218,14 +267,13 @@ def request( cast=cast, data_field=data_field, stream=stream, - is_async=False, ) return self.send(request) def send( - self, - request: HTTPRequest, + self, + request: HTTPRequest, ) -> Union[T, List[T], ListResponse[T], IteratorHTTPResponse[str], FileHTTPResponse, None]: """ Send a request to the server. @@ -242,106 +290,112 @@ def send( @overload async def arequest( - self, - method: str, - url: str, - stream: Literal[False], - cast: Type[T], - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> T: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: Type[T], + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> T: + ... @overload async def arequest( - self, - method: str, - url: str, - stream: Literal[False], - cast: List[Type[T]], - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> List[T]: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: List[Type[T]], + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> List[T]: + ... @overload async def arequest( - self, - method: str, - url: str, - stream: Literal[False], - cast: Type[ListResponse[T]], - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> ListResponse[T]: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: Type[ListResponse[T]], + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> ListResponse[T]: + ... @overload async def arequest( - self, - method: str, - url: str, - stream: Literal[False], - cast: Type[FileHTTPResponse], - params: dict = ..., - headers: Optional[dict] = ..., - body: dict = ..., - files: dict = ..., - data_field: str = ..., - ) -> FileHTTPResponse: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: Type[FileHTTPResponse], + params: dict = ..., + headers: Optional[dict] = ..., + body: dict = ..., + files: dict = ..., + data_field: str = ..., + ) -> FileHTTPResponse: + ... @overload async def arequest( - self, - method: str, - url: str, - stream: Literal[False], - cast: None, - params: Optional[dict] = ..., - headers: Optional[dict] = ..., - body: Optional[dict] = ..., - files: Optional[dict] = ..., - data_field: str = ..., - ) -> None: ... + self, + method: str, + url: str, + stream: Literal[False], + cast: None, + params: Optional[dict] = ..., + headers: Optional[dict] = ..., + body: Optional[dict] = ..., + files: Optional[dict] = ..., + data_field: str = ..., + ) -> None: + ... @overload async def arequest( - self, - method: str, - url: str, - stream: Literal[True], - cast: None, - params: Optional[dict] = ..., - headers: Optional[dict] = ..., - body: Optional[dict] = ..., - files: Optional[dict] = ..., - data_field: str = ..., - ) -> AsyncIteratorHTTPResponse[str]: ... + self, + method: str, + url: str, + stream: Literal[True], + cast: None, + params: Optional[dict] = ..., + headers: Optional[dict] = ..., + body: Optional[dict] = ..., + files: Optional[dict] = ..., + data_field: str = ..., + ) -> AsyncIteratorHTTPResponse[str]: + ... async def arequest( - self, - method: str, - url: str, - stream: Literal[True, False], - cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], - params: Optional[dict] = None, - headers: Optional[dict] = None, - body: Optional[dict] = None, - files: Optional[dict] = None, - data_field: str = "data", + self, + method: str, + url: str, + stream: Literal[True, False], + cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], + params: Optional[dict] = None, + headers: Optional[dict] = None, + body: Optional[dict] = None, + files: Optional[dict] = None, + data_field: str = "data", ) -> Union[T, List[T], ListResponse[T], AsyncIteratorHTTPResponse[str], FileHTTPResponse, None]: """ Send a request to the server. """ method = method.upper() - request = self.make_request( - method, url, params=params, headers=headers, json=body, files=files, stream=stream, is_async=True + request = await self.amake_request( + method, url, params=params, headers=headers, json=body, files=files, stream=stream ) response = await self.async_client.send(request.as_httpx, stream=stream) @@ -350,8 +404,8 @@ async def arequest( ) async def asend( - self, - request: HTTPRequest, + self, + request: HTTPRequest, ) -> Union[T, List[T], ListResponse[T], AsyncIteratorHTTPResponse[str], FileHTTPResponse, None]: return self._parse_response( method=request.method, @@ -377,37 +431,39 @@ def async_client(self) -> "AsyncHTTPClient": @overload def _parse_response( - self, - method: str, - url: str, - is_async: Literal[False], - response: httpx.Response, - cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], - stream: bool = ..., - data_field: str = ..., - ) -> Union[T, List[T], ListResponse[T], IteratorHTTPResponse[str], FileHTTPResponse, None]: ... + self, + method: str, + url: str, + is_async: Literal[False], + response: httpx.Response, + cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], + stream: bool = ..., + data_field: str = ..., + ) -> Union[T, List[T], ListResponse[T], IteratorHTTPResponse[str], FileHTTPResponse, None]: + ... @overload def _parse_response( - self, - method: str, - url: str, - is_async: Literal[True], - response: httpx.Response, - cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], - stream: bool = ..., - data_field: str = ..., - ) -> Union[T, List[T], ListResponse[T], AsyncIteratorHTTPResponse[str], FileHTTPResponse, None]: ... + self, + method: str, + url: str, + is_async: Literal[True], + response: httpx.Response, + cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], + stream: bool = ..., + data_field: str = ..., + ) -> Union[T, List[T], ListResponse[T], AsyncIteratorHTTPResponse[str], FileHTTPResponse, None]: + ... def _parse_response( - self, - method: str, - url: str, - is_async: Literal[True, False], - response: httpx.Response, - cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], - stream: bool = False, - data_field: str = "data", + self, + method: str, + url: str, + is_async: Literal[True, False], + response: httpx.Response, + cast: Union[Type[T], List[Type[T]], Type[ListResponse[T]], Type[FileHTTPResponse], None], + stream: bool = False, + data_field: str = "data", ) -> Union[ T, List[T], ListResponse[T], IteratorHTTPResponse[str], AsyncIteratorHTTPResponse[str], FileHTTPResponse, None ]: @@ -452,7 +508,7 @@ def _parse_response( return res # type: ignore def _parse_requests_code_msg( - self, method: str, url: str, response: Response, data_field: str = "data" + self, method: str, url: str, response: Response, data_field: str = "data" ) -> Tuple[Optional[int], str, Any]: try: response.read() diff --git a/cozepy/websockets/__init__.py b/cozepy/websockets/__init__.py index 6353636..918fe9c 100644 --- a/cozepy/websockets/__init__.py +++ b/cozepy/websockets/__init__.py @@ -4,6 +4,7 @@ from .audio import AsyncWebsocketsAudioClient, WebsocketsAudioClient from .chat import AsyncWebsocketsChatBuildClient, WebsocketsChatBuildClient +from ..auth import AsyncAuth class WebsocketsClient(object): @@ -30,7 +31,7 @@ def chat(self) -> WebsocketsChatBuildClient: class AsyncWebsocketsClient(object): - def __init__(self, base_url: str, auth: Auth, requester: Requester): + def __init__(self, base_url: str, auth: AsyncAuth, requester: Requester): self._base_url = http_base_url_to_ws(remove_url_trailing_slash(base_url)) self._auth = auth self._requester = requester diff --git a/cozepy/websockets/audio/__init__.py b/cozepy/websockets/audio/__init__.py index a02f3ea..4aeb48f 100644 --- a/cozepy/websockets/audio/__init__.py +++ b/cozepy/websockets/audio/__init__.py @@ -1,4 +1,4 @@ -from cozepy.auth import Auth +from cozepy.auth import Auth, AsyncAuth from cozepy.request import Requester from .speech import AsyncWebsocketsAudioSpeechBuildClient, WebsocketsAudioSpeechBuildClient @@ -29,7 +29,7 @@ def speech(self) -> "WebsocketsAudioSpeechBuildClient": class AsyncWebsocketsAudioClient(object): - def __init__(self, base_url: str, auth: Auth, requester: Requester): + def __init__(self, base_url: str, auth: AsyncAuth, requester: Requester): self._base_url = base_url self._auth = auth self._requester = requester diff --git a/cozepy/websockets/audio/speech/__init__.py b/cozepy/websockets/audio/speech/__init__.py index 7a533f7..7d976ec 100644 --- a/cozepy/websockets/audio/speech/__init__.py +++ b/cozepy/websockets/audio/speech/__init__.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, field_serializer -from cozepy.auth import Auth +from cozepy.auth import Auth, AsyncAuth from cozepy.log import log_warning from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash @@ -85,12 +85,12 @@ def on_speech_audio_completed(self, cli: "WebsocketsAudioSpeechClient", event: S class WebsocketsAudioSpeechClient(WebsocketsBaseClient): def __init__( - self, - base_url: str, - auth: Auth, - requester: Requester, - on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + base_url: str, + auth: Auth, + requester: Requester, + on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ): if isinstance(on_event, WebsocketsAudioSpeechEventHandler): on_event = on_event.to_dict( @@ -168,7 +168,7 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._requester = requester def create( - self, *, on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs + self, *, on_event: Union[WebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], **kwargs ) -> WebsocketsAudioSpeechClient: return WebsocketsAudioSpeechClient( base_url=self._base_url, @@ -184,7 +184,7 @@ async def on_speech_created(self, cli: "AsyncWebsocketsAudioSpeechClient", event pass async def on_input_text_buffer_completed( - self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent + self, cli: "AsyncWebsocketsAudioSpeechClient", event: InputTextBufferCompletedEvent ): pass @@ -192,19 +192,19 @@ async def on_speech_audio_update(self, cli: "AsyncWebsocketsAudioSpeechClient", pass async def on_speech_audio_completed( - self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent + self, cli: "AsyncWebsocketsAudioSpeechClient", event: SpeechAudioCompletedEvent ): pass class AsyncWebsocketsAudioSpeechClient(AsyncWebsocketsBaseClient): def __init__( - self, - base_url: str, - auth: Auth, - requester: Requester, - on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + base_url: str, + auth: AsyncAuth, + requester: Requester, + on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ): if isinstance(on_event, AsyncWebsocketsAudioSpeechEventHandler): on_event = on_event.to_dict( @@ -271,16 +271,16 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: class AsyncWebsocketsAudioSpeechBuildClient(object): - def __init__(self, base_url: str, auth: Auth, requester: Requester): + def __init__(self, base_url: str, auth: AsyncAuth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth self._requester = requester def create( - self, - *, - on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + *, + on_event: Union[AsyncWebsocketsAudioSpeechEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ) -> AsyncWebsocketsAudioSpeechClient: return AsyncWebsocketsAudioSpeechClient( base_url=self._base_url, diff --git a/cozepy/websockets/audio/transcriptions/__init__.py b/cozepy/websockets/audio/transcriptions/__init__.py index 4f03b14..6b3ab5b 100644 --- a/cozepy/websockets/audio/transcriptions/__init__.py +++ b/cozepy/websockets/audio/transcriptions/__init__.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, field_serializer -from cozepy.auth import Auth +from cozepy.auth import Auth, AsyncAuth from cozepy.log import log_warning from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash @@ -84,29 +84,29 @@ def on_transcriptions_created(self, cli: "WebsocketsAudioTranscriptionsClient", pass def on_input_audio_buffer_completed( - self, cli: "WebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent + self, cli: "WebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent ): pass def on_transcriptions_message_update( - self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent + self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent ): pass def on_transcriptions_message_completed( - self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent + self, cli: "WebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent ): pass class WebsocketsAudioTranscriptionsClient(WebsocketsBaseClient): def __init__( - self, - base_url: str, - auth: Auth, - requester: Requester, - on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + base_url: str, + auth: Auth, + requester: Requester, + on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ): if isinstance(on_event, WebsocketsAudioTranscriptionsEventHandler): on_event = on_event.to_dict( @@ -186,10 +186,10 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._requester = requester def create( - self, - *, - on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + *, + on_event: Union[WebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ) -> WebsocketsAudioTranscriptionsClient: return WebsocketsAudioTranscriptionsClient( base_url=self._base_url, @@ -202,34 +202,34 @@ def create( class AsyncWebsocketsAudioTranscriptionsEventHandler(AsyncWebsocketsBaseEventHandler): async def on_transcriptions_created( - self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsCreatedEvent + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsCreatedEvent ): pass async def on_input_audio_buffer_completed( - self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: InputAudioBufferCompletedEvent ): pass async def on_transcriptions_message_update( - self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageUpdateEvent ): pass async def on_transcriptions_message_completed( - self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent + self, cli: "AsyncWebsocketsAudioTranscriptionsClient", event: TranscriptionsMessageCompletedEvent ): pass class AsyncWebsocketsAudioTranscriptionsClient(AsyncWebsocketsBaseClient): def __init__( - self, - base_url: str, - auth: Auth, - requester: Requester, - on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + base_url: str, + auth: AsyncAuth, + requester: Requester, + on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ): if isinstance(on_event, AsyncWebsocketsAudioTranscriptionsEventHandler): on_event = on_event.to_dict( @@ -303,16 +303,16 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: class AsyncWebsocketsAudioTranscriptionsBuildClient(object): - def __init__(self, base_url: str, auth: Auth, requester: Requester): + def __init__(self, base_url: str, auth: AsyncAuth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth self._requester = requester def create( - self, - *, - on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + *, + on_event: Union[AsyncWebsocketsAudioTranscriptionsEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ) -> AsyncWebsocketsAudioTranscriptionsClient: return AsyncWebsocketsAudioTranscriptionsClient( base_url=self._base_url, diff --git a/cozepy/websockets/chat/__init__.py b/cozepy/websockets/chat/__init__.py index 2e0e73b..07bdbfc 100644 --- a/cozepy/websockets/chat/__init__.py +++ b/cozepy/websockets/chat/__init__.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from cozepy import Chat, Message, ToolOutput -from cozepy.auth import Auth +from cozepy.auth import Auth, AsyncAuth from cozepy.log import log_warning from cozepy.request import Requester from cozepy.util import remove_url_trailing_slash @@ -132,7 +132,7 @@ def on_conversation_message_completed(self, cli: "WebsocketsChatClient", event: pass def on_conversation_chat_requires_action( - self, cli: "WebsocketsChatClient", event: ConversationChatRequiresActionEvent + self, cli: "WebsocketsChatClient", event: ConversationChatRequiresActionEvent ): pass @@ -148,13 +148,13 @@ def on_conversation_chat_completed(self, cli: "WebsocketsChatClient", event: Con class WebsocketsChatClient(WebsocketsBaseClient): def __init__( - self, - base_url: str, - auth: Auth, - requester: Requester, - bot_id: str, - on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + base_url: str, + auth: Auth, + requester: Requester, + bot_id: str, + on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ): if isinstance(on_event, WebsocketsChatEventHandler): on_event = on_event.to_dict( @@ -297,11 +297,11 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): self._requester = requester def create( - self, - *, - bot_id: str, - on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + *, + bot_id: str, + on_event: Union[WebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ) -> WebsocketsChatClient: return WebsocketsChatClient( base_url=self._base_url, @@ -321,7 +321,7 @@ async def on_chat_updated(self, cli: "AsyncWebsocketsChatClient", event: ChatUpd pass async def on_input_audio_buffer_completed( - self, cli: "AsyncWebsocketsChatClient", event: InputAudioBufferCompletedEvent + self, cli: "AsyncWebsocketsChatClient", event: InputAudioBufferCompletedEvent ): pass @@ -329,22 +329,22 @@ async def on_conversation_chat_created(self, cli: "AsyncWebsocketsChatClient", e pass async def on_conversation_chat_in_progress( - self, cli: "AsyncWebsocketsChatClient", event: ConversationChatInProgressEvent + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatInProgressEvent ): pass async def on_conversation_message_delta( - self, cli: "AsyncWebsocketsChatClient", event: ConversationMessageDeltaEvent + self, cli: "AsyncWebsocketsChatClient", event: ConversationMessageDeltaEvent ): pass async def on_conversation_chat_requires_action( - self, cli: "AsyncWebsocketsChatClient", event: ConversationChatRequiresActionEvent + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatRequiresActionEvent ): pass async def on_conversation_message_completed( - self, cli: "AsyncWebsocketsChatClient", event: ConversationMessageCompletedEvent + self, cli: "AsyncWebsocketsChatClient", event: ConversationMessageCompletedEvent ): pass @@ -352,25 +352,25 @@ async def on_conversation_audio_delta(self, cli: "AsyncWebsocketsChatClient", ev pass async def on_conversation_audio_completed( - self, cli: "AsyncWebsocketsChatClient", event: ConversationAudioCompletedEvent + self, cli: "AsyncWebsocketsChatClient", event: ConversationAudioCompletedEvent ): pass async def on_conversation_chat_completed( - self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCompletedEvent + self, cli: "AsyncWebsocketsChatClient", event: ConversationChatCompletedEvent ): pass class AsyncWebsocketsChatClient(AsyncWebsocketsBaseClient): def __init__( - self, - base_url: str, - auth: Auth, - requester: Requester, - bot_id: str, - on_event: Union[AsyncWebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + base_url: str, + auth: AsyncAuth, + requester: Requester, + bot_id: str, + on_event: Union[AsyncWebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ): if isinstance(on_event, AsyncWebsocketsChatEventHandler): on_event = on_event.to_dict( @@ -507,17 +507,17 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: class AsyncWebsocketsChatBuildClient(object): - def __init__(self, base_url: str, auth: Auth, requester: Requester): + def __init__(self, base_url: str, auth: AsyncAuth, requester: Requester): self._base_url = remove_url_trailing_slash(base_url) self._auth = auth self._requester = requester def create( - self, - *, - bot_id: str, - on_event: Union[AsyncWebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], - **kwargs, + self, + *, + bot_id: str, + on_event: Union[AsyncWebsocketsChatEventHandler, Dict[WebsocketsEventType, Callable]], + **kwargs, ) -> AsyncWebsocketsChatClient: return AsyncWebsocketsChatClient( base_url=self._base_url, diff --git a/cozepy/websockets/ws.py b/cozepy/websockets/ws.py index 318383f..13742d8 100644 --- a/cozepy/websockets/ws.py +++ b/cozepy/websockets/ws.py @@ -10,6 +10,8 @@ from enum import Enum from typing import Callable, Dict, List, Optional, Set +from cozepy.auth import AsyncAuth + if sys.version_info >= (3, 8): # note: >=3.7,<3.8 not support asyncio from websockets import InvalidStatus @@ -22,6 +24,7 @@ warnings.warn("asyncio websockets requires Python >= 3.8") + class AsyncWebsocketClientConnection(object): def recv(self, *args, **kwargs): pass @@ -32,13 +35,14 @@ def send(self, *args, **kwargs): def close(self, *args, **kwargs): pass + def asyncio_connect(*args, **kwargs): pass + class InvalidStatus(object): pass - import websockets.sync.client from pydantic import BaseModel @@ -158,15 +162,15 @@ class State(str, Enum): CLOSED = "closed" def __init__( - self, - base_url: str, - auth: Auth, - requester: Requester, - path: str, - query: Optional[Dict[str, str]] = None, - on_event: Optional[Dict[WebsocketsEventType, Callable]] = None, - wait_events: Optional[List[WebsocketsEventType]] = None, - **kwargs, + self, + base_url: str, + auth: Auth, + requester: Requester, + path: str, + query: Optional[Dict[str, str]] = None, + on_event: Optional[Dict[WebsocketsEventType, Callable]] = None, + wait_events: Optional[List[WebsocketsEventType]] = None, + **kwargs, ): self._state = self.State.INITIALIZED self._base_url = remove_url_trailing_slash(base_url) @@ -200,7 +204,7 @@ def connect(self): raise ValueError(f"Cannot connect in {self._state.value} state") self._state = self.State.CONNECTING headers = { - "Authorization": f"Bearer {self._auth.token}", + "Authorization": f"Bearer {self._auth.token()}", "X-Coze-Client-User-Agent": coze_client_user_agent(), **(self._headers or {}), } @@ -283,7 +287,8 @@ def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: return self._load_event(message) @abc.abstractmethod - def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ... + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + ... def _wait_completed(self, events: List[WebsocketsEventType], wait_all: bool) -> None: while True: @@ -365,15 +370,15 @@ class State(str, Enum): CLOSED = "closed" def __init__( - self, - base_url: str, - auth: Auth, - requester: Requester, - path: str, - query: Optional[Dict[str, str]] = None, - on_event: Optional[Dict[WebsocketsEventType, Callable]] = None, - wait_events: Optional[List[WebsocketsEventType]] = None, - **kwargs, + self, + base_url: str, + auth: AsyncAuth, + requester: Requester, + path: str, + query: Optional[Dict[str, str]] = None, + on_event: Optional[Dict[WebsocketsEventType, Callable]] = None, + wait_events: Optional[List[WebsocketsEventType]] = None, + **kwargs, ): self._state = self.State.INITIALIZED self._base_url = remove_url_trailing_slash(base_url) @@ -405,7 +410,7 @@ async def connect(self): raise ValueError(f"Cannot connect in {self._state.value} state") self._state = self.State.CONNECTING headers = { - "Authorization": f"Bearer {self._auth.token}", + "Authorization": f"Bearer {await self._auth.token()}", "X-Coze-Client-User-Agent": coze_client_user_agent(), **(self._headers or {}), } @@ -483,7 +488,8 @@ def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]: return self._load_event(message) @abc.abstractmethod - def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: ... + def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]: + ... async def _wait_completed(self, wait_events: List[WebsocketsEventType], wait_all: bool) -> None: future: asyncio.Future[None] = asyncio.Future() diff --git a/cozepy/workspaces/__init__.py b/cozepy/workspaces/__init__.py index 4121052..a45168b 100644 --- a/cozepy/workspaces/__init__.py +++ b/cozepy/workspaces/__init__.py @@ -68,7 +68,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "page_num": i_page_num, }, cast=_PrivateListWorkspacesData, - is_async=False, stream=False, ) @@ -93,8 +92,8 @@ def __init__(self, base_url: str, auth: Auth, requester: Requester): async def list(self, *, page_num: int = 1, page_size: int = 20, headers=None) -> AsyncNumberPaged[Workspace]: url = f"{self._base_url}/v1/workspaces" - def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: - return self._requester.make_request( + async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: + return await self._requester.amake_request( "GET", url, headers=headers, @@ -103,7 +102,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest: "page_num": i_page_num, }, cast=_PrivateListWorkspacesData, - is_async=False, stream=False, ) diff --git a/tests/test_auth.py b/tests/test_auth.py index 6961c2d..f87fb53 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -19,8 +19,8 @@ Scope, WebOAuthApp, ) +from cozepy.auth import AsyncJWTAuth from cozepy.util import random_hex - from .test_util import read_file @@ -31,45 +31,45 @@ def test_get_oauth_url(self, respx_mock): url = app.get_oauth_url("https://example.com", "state") assert ( - "https://www.coze.com/api/permission/oauth2/authorize" - "?response_type=code&" - "client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&" - "state=state" - ) == url + "https://www.coze.com/api/permission/oauth2/authorize" + "?response_type=code&" + "client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&" + "state=state" + ) == url url = app.get_oauth_url("https://example.com", "state", workspace_id="this_is_id") assert ( - "https://www.coze.com/api/permission/oauth2/workspace_id/this_is_id/authorize" - "?response_type=code&" - "client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&" - "state=state" - ) == url + "https://www.coze.com/api/permission/oauth2/workspace_id/this_is_id/authorize" + "?response_type=code&" + "client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&" + "state=state" + ) == url def test_get_oauth_url_config_www_url(self, respx_mock): app = WebOAuthApp("client id", "client secret", www_base_url="https://example.com") url = app.get_oauth_url("https://example.com", "state") assert ( - "https://example.com/api/permission/oauth2/authorize" - "?response_type=code&" - "client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&" - "state=state" - ) == url + "https://example.com/api/permission/oauth2/authorize" + "?response_type=code&" + "client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&" + "state=state" + ) == url def test_get_oauth_url_config_custom_api_base_url(self, respx_mock): app = WebOAuthApp("client id", "client secret", base_url="https://api.example.com") url = app.get_oauth_url("https://example.com", "state") assert ( - "https://api.example.com/api/permission/oauth2/authorize" - "?response_type=code&" - "client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&" - "state=state" - ) == url + "https://api.example.com/api/permission/oauth2/authorize" + "?response_type=code&" + "client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&" + "state=state" + ) == url def test_get_access_token(self, respx_mock): app = WebOAuthApp("client id", "client secret") @@ -106,45 +106,45 @@ async def test_get_oauth_url(self, respx_mock): url = app.get_oauth_url("https://example.com", "state") assert ( - "https://www.coze.com/api/permission/oauth2/authorize" - "?response_type=code&" - "client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&" - "state=state" - ) == url + "https://www.coze.com/api/permission/oauth2/authorize" + "?response_type=code&" + "client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&" + "state=state" + ) == url url = app.get_oauth_url("https://example.com", "state", workspace_id="this_is_id") assert ( - "https://www.coze.com/api/permission/oauth2/workspace_id/this_is_id/authorize" - "?response_type=code&" - "client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&" - "state=state" - ) == url + "https://www.coze.com/api/permission/oauth2/workspace_id/this_is_id/authorize" + "?response_type=code&" + "client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&" + "state=state" + ) == url async def test_get_oauth_url_config_www_url(self, respx_mock): app = AsyncWebOAuthApp("client id", "client secret", www_base_url="https://example.com") url = app.get_oauth_url("https://example.com", "state") assert ( - "https://example.com/api/permission/oauth2/authorize" - "?response_type=code&" - "client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&" - "state=state" - ) == url + "https://example.com/api/permission/oauth2/authorize" + "?response_type=code&" + "client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&" + "state=state" + ) == url async def test_get_oauth_url_config_custom_api_base_url(self, respx_mock): app = AsyncWebOAuthApp("client id", "client secret", base_url="https://api.example.com") url = app.get_oauth_url("https://example.com", "state") assert ( - "https://api.example.com/api/permission/oauth2/authorize" - "?response_type=code&" - "client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&" - "state=state" - ) == url + "https://api.example.com/api/permission/oauth2/authorize" + "?response_type=code&" + "client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&" + "state=state" + ) == url async def test_get_access_token(self, respx_mock): app = AsyncWebOAuthApp("client id", "client secret") @@ -186,9 +186,9 @@ def test_jwt_auth(self, respx_mock): auth = JWTAuth("client id", private_key, "public key id") - assert "Bearer" == auth.token_type - assert mock_token == auth.token - assert mock_token == auth.token # get from cache + assert "Bearer" == auth.token_type() + assert mock_token == auth.token() + assert mock_token == auth.token() # get from cache def test_get_access_token(self, respx_mock): private_key = read_file("testdata/private_key.pem") @@ -207,6 +207,22 @@ def test_get_access_token(self, respx_mock): @pytest.mark.respx(base_url="https://api.coze.com") @pytest.mark.asyncio class TestAsyncJWTOAuthApp: + + async def test_jwt_auth(self, respx_mock): + private_key = read_file("testdata/private_key.pem") + mock_token = random_hex(20) + respx_mock.post("/api/permission/oauth2/token").mock( + httpx.Response( + 200, content=OAuthToken(access_token=mock_token, expires_in=int(time.time()) + 100).model_dump_json() + ) + ) + + auth = AsyncJWTAuth("client id", private_key, "public key id") + + assert "Bearer" == await auth.token_type() + assert mock_token == await auth.token() + assert mock_token == await auth.token() # get from cache + async def test_get_access_token(self, respx_mock): private_key = read_file("testdata/private_key.pem") app = AsyncJWTOAuthApp("client id", private_key, "public key id") @@ -228,21 +244,21 @@ def test_get_oauth_url(self, respx_mock): url = app.get_oauth_url("https://example.com", "code_verifier", "S256", state="state") assert ( - "https://www.coze.com/api/permission/oauth2/authorize?" - "response_type=code&client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&state=state&" - "code_challenge=73oehA2tBul5grZPhXUGQwNAjxh69zNES8bu2bVD0EM&code_challenge_method=S256" - ) == url + "https://www.coze.com/api/permission/oauth2/authorize?" + "response_type=code&client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&state=state&" + "code_challenge=73oehA2tBul5grZPhXUGQwNAjxh69zNES8bu2bVD0EM&code_challenge_method=S256" + ) == url url = app.get_oauth_url( "https://example.com", "code_verifier", "S256", state="state", workspace_id="this_is_id" ) assert ( - "https://www.coze.com/api/permission/oauth2/workspace_id/this_is_id/authorize?" - "response_type=code&client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&state=state&" - "code_challenge=73oehA2tBul5grZPhXUGQwNAjxh69zNES8bu2bVD0EM&code_challenge_method=S256" - ) == url + "https://www.coze.com/api/permission/oauth2/workspace_id/this_is_id/authorize?" + "response_type=code&client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&state=state&" + "code_challenge=73oehA2tBul5grZPhXUGQwNAjxh69zNES8bu2bVD0EM&code_challenge_method=S256" + ) == url def test_get_access_token(self, respx_mock): app = PKCEOAuthApp("client id") @@ -279,21 +295,21 @@ async def test_get_oauth_url(self, respx_mock): url = app.get_oauth_url("https://example.com", "code_verifier", "S256", state="state") assert ( - "https://www.coze.com/api/permission/oauth2/authorize?" - "response_type=code&client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&state=state&" - "code_challenge=73oehA2tBul5grZPhXUGQwNAjxh69zNES8bu2bVD0EM&code_challenge_method=S256" - ) == url + "https://www.coze.com/api/permission/oauth2/authorize?" + "response_type=code&client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&state=state&" + "code_challenge=73oehA2tBul5grZPhXUGQwNAjxh69zNES8bu2bVD0EM&code_challenge_method=S256" + ) == url url = app.get_oauth_url( "https://example.com", "code_verifier", "S256", workspace_id="this_is_id", state="state" ) assert ( - "https://www.coze.com/api/permission/oauth2/workspace_id/this_is_id/authorize?" - "response_type=code&client_id=client+id&" - "redirect_uri=https%3A%2F%2Fexample.com&state=state&" - "code_challenge=73oehA2tBul5grZPhXUGQwNAjxh69zNES8bu2bVD0EM&code_challenge_method=S256" - ) == url + "https://www.coze.com/api/permission/oauth2/workspace_id/this_is_id/authorize?" + "response_type=code&client_id=client+id&" + "redirect_uri=https%3A%2F%2Fexample.com&state=state&" + "code_challenge=73oehA2tBul5grZPhXUGQwNAjxh69zNES8bu2bVD0EM&code_challenge_method=S256" + ) == url async def test_get_access_token(self, respx_mock): app = AsyncPKCEOAuthApp("client id")