Skip to content

Commit

Permalink
support async auth of jwt and fix relative code
Browse files Browse the repository at this point in the history
  • Loading branch information
admin committed Mar 5, 2025
1 parent 642625a commit 0e9acd1
Show file tree
Hide file tree
Showing 8 changed files with 388 additions and 411 deletions.
6 changes: 3 additions & 3 deletions cozepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
from .audio.transcriptions import CreateTranscriptionsResp
from .audio.voices import Voice
from .auth import (
AsyncAuth,
AsyncDeviceOAuthApp,
AsyncJWTAuth,
AsyncJWTOAuthApp,
AsyncPKCEOAuthApp,
AsyncTokenAuth,
AsyncWebOAuthApp,
Auth,
AsyncAuth,
DeviceAuthCode,
DeviceOAuthApp,
JWTAuth,
AsyncJWTAuth,
JWTOAuthApp,
OAuthApp,
OAuthToken,
PKCEOAuthApp,
Scope,
TokenAuth,
AsyncTokenAuth,
WebOAuthApp,
load_oauth_app_from_config,
)
Expand Down
108 changes: 54 additions & 54 deletions cozepy/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ def __init__(self, client_id: str, base_url: str, www_base_url: str):
self._requester = Requester()

def _get_oauth_url(
self,
redirect_uri: str,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
state: str = "",
workspace_id: Optional[str] = None,
):
params = {
"response_type": "code",
Expand Down Expand Up @@ -163,10 +163,10 @@ def __init__(self, client_id: str, client_secret: str, base_url: str = COZE_COM_
super().__init__(client_id, base_url, www_base_url=www_base_url)

def get_oauth_url(
self,
redirect_uri: str,
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
state: str = "",
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand All @@ -185,9 +185,9 @@ def get_oauth_url(
)

def get_access_token(
self,
redirect_uri: str,
code: str,
self,
redirect_uri: str,
code: str,
) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
Expand Down Expand Up @@ -225,10 +225,10 @@ def __init__(self, client_id: str, client_secret: str, base_url: str = COZE_COM_
super().__init__(client_id, base_url, www_base_url=www_base_url)

def get_oauth_url(
self,
redirect_uri: str,
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
state: str = "",
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand All @@ -247,9 +247,9 @@ def get_oauth_url(
)

async def get_access_token(
self,
redirect_uri: str,
code: str,
self,
redirect_uri: str,
code: str,
) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
Expand Down Expand Up @@ -289,7 +289,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur
super().__init__(client_id, base_url, www_base_url="")

def get_access_token(
self, ttl: int = 900, scope: Optional[Scope] = None, session_name: Optional[str] = None
self, ttl: int = 900, scope: Optional[Scope] = None, session_name: Optional[str] = None
) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
Expand Down Expand Up @@ -332,7 +332,7 @@ def __init__(self, client_id: str, private_key: str, public_key_id: str, base_ur
super().__init__(client_id, base_url, www_base_url="")

async def get_access_token(
self, ttl: int, scope: Optional[Scope] = None, session_name: Optional[str] = None
self, ttl: int, scope: Optional[Scope] = None, session_name: Optional[str] = None
) -> OAuthToken:
"""
Get the token by jwt with jwt auth flow.
Expand Down Expand Up @@ -366,12 +366,12 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u
)

def get_oauth_url(
self,
redirect_uri: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
state: str = "",
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand Down Expand Up @@ -433,12 +433,12 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u
)

def get_oauth_url(
self,
redirect_uri: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
state: str = "",
workspace_id: Optional[str] = None,
self,
redirect_uri: str,
code_verifier: str,
code_challenge_method: Literal["plain", "S256"] = "plain",
state: str = "",
workspace_id: Optional[str] = None,
):
"""
Get the pkce flow authorized url.
Expand Down Expand Up @@ -500,8 +500,8 @@ def __init__(self, client_id: str, base_url: str = COZE_COM_BASE_URL, www_base_u
)

def get_device_code(
self,
workspace_id: Optional[str] = None,
self,
workspace_id: Optional[str] = None,
) -> DeviceAuthCode:
"""
Get the pkce flow authorized url.
Expand Down Expand Up @@ -735,13 +735,13 @@ class JWTAuth(Auth):
"""

def __init__(
self,
client_id: Optional[str] = None,
private_key: Optional[str] = None,
public_key_id: Optional[str] = None,
ttl: int = 7200,
base_url: str = COZE_COM_BASE_URL,
oauth_app: Optional[JWTOAuthApp] = None,
self,
client_id: Optional[str] = None,
private_key: Optional[str] = None,
public_key_id: Optional[str] = None,
ttl: int = 7200,
base_url: str = COZE_COM_BASE_URL,
oauth_app: Optional[JWTOAuthApp] = None,
):
assert ttl > 0
self._ttl = ttl
Expand Down Expand Up @@ -784,7 +784,7 @@ class AsyncAuth(abc.ABC):

@property
@abc.abstractmethod
async def token_type(self) -> str:
def token_type(self) -> str:
"""
The authorization type used in the http request header.
Expand All @@ -809,7 +809,7 @@ async def authentication(self, headers: dict) -> None:
:param headers: http headers
:return: None
"""
headers["Authorization"] = f"{await self.token_type} {await self.token}"
headers["Authorization"] = f"{self.token_type} {await self.token}"


class AsyncTokenAuth(AsyncAuth):
Expand All @@ -823,7 +823,7 @@ def __init__(self, token: str):
self._token = token

@property
async def token_type(self) -> str:
def token_type(self) -> str:
return "Bearer"

@property
Expand All @@ -837,13 +837,13 @@ class AsyncJWTAuth(AsyncAuth):
"""

def __init__(
self,
client_id: Optional[str] = None,
private_key: Optional[str] = None,
public_key_id: Optional[str] = None,
ttl: int = 7200,
base_url: str = COZE_COM_BASE_URL,
oauth_app: Optional[AsyncJWTOAuthApp] = None,
self,
client_id: Optional[str] = None,
private_key: Optional[str] = None,
public_key_id: Optional[str] = None,
ttl: int = 7200,
base_url: str = COZE_COM_BASE_URL,
oauth_app: Optional[AsyncJWTOAuthApp] = None,
):
assert ttl > 0
self._ttl = ttl
Expand All @@ -862,7 +862,7 @@ def __init__(
)

@property
async def token_type(self) -> str:
def token_type(self) -> str:
return "Bearer"

@property
Expand Down
Loading

0 comments on commit 0e9acd1

Please sign in to comment.