Skip to content

Commit

Permalink
创建AsyncJWTAuth类支持异步更新鉴权码
Browse files Browse the repository at this point in the history
修复相关的鉴权调用,包括Requester和上层异步调用
  • Loading branch information
admin committed Mar 3, 2025
1 parent 7178c2c commit 8986e68
Show file tree
Hide file tree
Showing 20 changed files with 688 additions and 537 deletions.
6 changes: 2 additions & 4 deletions cozepy/audio/voices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
"page_size": i_page_size,
},
cast=_PrivateListVoiceData,
is_async=False,
stream=False,
)

Expand Down Expand Up @@ -225,8 +224,8 @@ async def list(
url = f"{self._base_url}/v1/audio/voices"
headers: Optional[dict] = kwargs.get("headers")

def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
return self._requester.make_request(
async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
return await self._requester.amake_request(
"GET",
url,
params={
Expand All @@ -236,7 +235,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
},
headers=headers,
cast=_PrivateListVoiceData,
is_async=False,
stream=False,
)

Expand Down
190 changes: 140 additions & 50 deletions cozepy/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ def __init__(self, client_id: str, base_url: str, www_base_url: str):
self._requester = Requester()

def _get_oauth_url(
self,
redirect_uri: str,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
state: str = "",
workspace_id: Optional[str] = None,
):
params = {
"response_type": "code",
Expand Down Expand Up @@ -163,10 +163,10 @@ def __init__(self, client_id: str, client_secret: str, base_url: str = COZE_COM_
super().__init__(client_id, base_url, www_base_url=www_base_url)

def get_oauth_url(
self,
redirect_uri: str,
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
state: str = "",
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand All @@ -185,9 +185,9 @@ def get_oauth_url(
)

def get_access_token(
self,
redirect_uri: str,
code: str,
self,
redirect_uri: str,
code: str,
) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
Expand Down Expand Up @@ -225,10 +225,10 @@ def __init__(self, client_id: str, client_secret: str, base_url: str = COZE_COM_
super().__init__(client_id, base_url, www_base_url=www_base_url)

def get_oauth_url(
self,
redirect_uri: str,
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
state: str = "",
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand All @@ -247,9 +247,9 @@ def get_oauth_url(
)

async def get_access_token(
self,
redirect_uri: str,
code: str,
self,
redirect_uri: str,
code: str,
) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
Expand Down Expand Up @@ -289,7 +289,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur
super().__init__(client_id, base_url, www_base_url="")

def get_access_token(
self, ttl: int = 900, scope: Optional[Scope] = None, session_name: Optional[str] = None
self, ttl: int = 900, scope: Optional[Scope] = None, session_name: Optional[str] = None
) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
Expand Down Expand Up @@ -332,7 +332,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur
super().__init__(client_id, base_url, www_base_url="")

async def get_access_token(
self, ttl: int, scope: Optional[Scope] = None, session_name: Optional[str] = None
self, ttl: int, scope: Optional[Scope] = None, session_name: Optional[str] = None
) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
Expand Down Expand Up @@ -366,12 +366,12 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u
)

def get_oauth_url(
self,
redirect_uri: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
state: str = "",
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand Down Expand Up @@ -433,12 +433,12 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u
)

def get_oauth_url(
self,
redirect_uri: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
state: str = "",
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand Down Expand Up @@ -500,8 +500,8 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u
)

def get_device_code(
self,
workspace_id: Optional[str] = None,
self,
workspace_id: Optional[str] = None,
) -> DeviceAuthCode:
"""
Get the pkce flow authorized url.
Expand Down Expand Up @@ -680,7 +680,6 @@ class Auth(abc.ABC):
It provides the abstract methods for getting the token type and token.
"""

@property
@abc.abstractmethod
def token_type(self) -> str:
"""
Expand All @@ -691,7 +690,6 @@ def token_type(self) -> str:
:return: token type
"""

@property
@abc.abstractmethod
def token(self) -> str:
"""
Expand All @@ -707,7 +705,7 @@ def authentication(self, headers: dict) -> None:
:param headers: http headers
:return: None
"""
headers["Authorization"] = f"{self.token_type} {self.token}"
headers["Authorization"] = f"{self.token_type()} {self.token()}"


class TokenAuth(Auth):
Expand All @@ -720,11 +718,9 @@ def __init__(self, token: str):
assert len(token) > 0
self._token = token

@property
def token_type(self) -> str:
return "Bearer"

@property
def token(self) -> str:
return self._token

Expand All @@ -735,13 +731,13 @@ class JWTAuth(Auth):
"""

def __init__(
self,
client_id: Optional[str] = None,
private_key: Optional[str] = None,
public_key_id: Optional[str] = None,
ttl: int = 7200,
base_url: str = COZE_COM_BASE_URL,
oauth_app: Optional[JWTOAuthApp] = None,
self,
client_id: Optional[str] = None,
private_key: Optional[str] = None,
public_key_id: Optional[str] = None,
ttl: int = 7200,
base_url: str = COZE_COM_BASE_URL,
oauth_app: Optional[JWTOAuthApp] = None,
):
assert ttl > 0
self._ttl = ttl
Expand All @@ -759,11 +755,9 @@ def __init__(
client_id, private_key, public_key_id, base_url=remove_url_trailing_slash(base_url)
)

@property
def token_type(self) -> str:
return "Bearer"

@property
def token(self) -> str:
token = self._generate_token()
return token.access_token
Expand All @@ -773,3 +767,99 @@ def _generate_token(self):
return self._token
self._token = self._oauth_cli.get_access_token(self._ttl)
return self._token


class AsyncAuth(abc.ABC):
"""
This class is the base class for all authorization types.
It provides the abstract methods for getting the token type and token.
"""

@abc.abstractmethod
async def token_type(self) -> str:
"""
The authorization type used in the http request header.
eg: Bearer, Basic, etc.
:return: token type
"""

@abc.abstractmethod
async def token(self) -> str:
"""
The token used in the http request header.
:return: token
"""

async def authentication(self, headers: dict) -> None:
"""
Construct the authorization header in the http headers.
:param headers: http headers
:return: None
"""
headers["Authorization"] = f"{await self.token_type()} {await self.token()}"


class AsyncTokenAuth(AsyncAuth):
"""
The fixed access token auth flow.
"""

def __init__(self, token: str):
assert isinstance(token, str)
assert len(token) > 0
self._token = token

async def token_type(self) -> str:
return "Bearer"

async def token(self) -> str:
return self._token


class AsyncJWTAuth(AsyncAuth):
"""
The JWT auth flow.
"""

def __init__(
self,
client_id: Optional[str] = None,
private_key: Optional[str] = None,
public_key_id: Optional[str] = None,
ttl: int = 7200,
base_url: str = COZE_COM_BASE_URL,
oauth_app: Optional[AsyncJWTOAuthApp] = None,
):
assert ttl > 0
self._ttl = ttl
self._token = None

if oauth_app:
self._oauth_cli = oauth_app
else:
assert isinstance(client_id, str)
assert isinstance(private_key, str)
assert isinstance(public_key_id, str)
assert isinstance(ttl, int)
assert isinstance(base_url, str)
self._oauth_cli = AsyncJWTOAuthApp(
client_id, private_key, public_key_id, base_url=remove_url_trailing_slash(base_url)
)

async def token_type(self) -> str:
return "Bearer"

async def token(self) -> str:
token = await self._generate_token()
return token.access_token

async def _generate_token(self):
if self._token is not None and int(time.time()) < self._token.expires_in:
return self._token
self._token = await self._oauth_cli.get_access_token(self._ttl)
return self._token
6 changes: 2 additions & 4 deletions cozepy/bots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
"page_index": i_page_num,
},
cast=_PrivateListBotsData,
is_async=False,
stream=False,
)

Expand Down Expand Up @@ -413,8 +412,8 @@ async def list(self, *, space_id: str, page_num: int = 1, page_size: int = 20) -
"""
url = f"{self._base_url}/v1/space/published_bots_list"

def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
return self._requester.make_request(
async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
return await self._requester.amake_request(
"GET",
url,
params={
Expand All @@ -423,7 +422,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
"page_index": i_page_num,
},
cast=_PrivateListBotsData,
is_async=False,
stream=False,
)

Expand Down
6 changes: 2 additions & 4 deletions cozepy/conversations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
"page_size": i_page_size,
},
cast=_PrivateListConversationResp,
is_async=False,
stream=False,
)

Expand Down Expand Up @@ -174,8 +173,8 @@ async def list(
):
url = f"{self._base_url}/v1/conversations"

def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
return self._requester.make_request(
async def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
return await self._requester.amake_request(
"GET",
url,
params={
Expand All @@ -184,7 +183,6 @@ def request_maker(i_page_num: int, i_page_size: int) -> HTTPRequest:
"page_size": i_page_size,
},
cast=_PrivateListConversationResp,
is_async=False,
stream=False,
)

Expand Down
Loading

0 comments on commit 8986e68

Please sign in to comment.