Skip to content

Commit

Permalink
refact auth extends relation
Browse files Browse the repository at this point in the history
  • Loading branch information
admin committed Mar 5, 2025
1 parent d67a0d0 commit ab826b4
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 71 deletions.
2 changes: 2 additions & 0 deletions cozepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
OAuthToken,
PKCEOAuthApp,
Scope,
SyncAuth,
TokenAuth,
WebOAuthApp,
load_oauth_app_from_config,
Expand Down Expand Up @@ -175,6 +176,7 @@
"AsyncWebOAuthApp",
"Auth",
"AsyncAuth",
"SyncAuth",
"AsyncJWTAuth",
"AsyncTokenAuth",
"DeviceAuthCode",
Expand Down
98 changes: 57 additions & 41 deletions cozepy/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,15 @@ def token(self) -> str:
:return: token
"""

@property
@abc.abstractmethod
async def atoken(self) -> str:
"""
The token used in the http request header.
:return: token
"""

def authentication(self, headers: dict) -> None:
"""
Construct the authorization header in the http headers.
Expand All @@ -709,8 +718,52 @@ def authentication(self, headers: dict) -> None:
"""
headers["Authorization"] = f"{self.token_type} {self.token}"

async def aauthentication(self, headers: dict) -> None:
"""
Construct the authorization header in the http headers.
:param headers: http headers
:return: None
"""
headers["Authorization"] = f"{self.token_type} {await self.atoken}"


class SyncAuth(Auth, abc.ABC):
"""
This class is the base class for all SyncAuth authorization types.
It provides the abstract methods for getting the token type and sync token.
"""

@property
async def atoken(self) -> str:
"""
SyncAuth not need implementation.
:return: sync token for compatible
"""
return self.token


class AsyncAuth(Auth, abc.ABC):
"""
This class is the base class for all authorization types.
It provides the abstract methods for getting the token type and async token.
"""

@property
def token(self) -> str:
"""
AsyncAuth not need implementation.
Any compatible needed.
:return: empty
"""
return ""


class TokenAuth(Auth):
class TokenAuth(SyncAuth):
"""
The fixed access token auth flow.
"""
Expand All @@ -729,7 +782,7 @@ def token(self) -> str:
return self._token


class JWTAuth(Auth):
class JWTAuth(SyncAuth):
"""
The JWT auth flow.
"""
Expand Down Expand Up @@ -775,43 +828,6 @@ def _generate_token(self):
return self._token


class AsyncAuth(abc.ABC):
"""
This class is the base class for all authorization types.
It provides the abstract methods for getting the token type and token.
"""

@property
@abc.abstractmethod
def token_type(self) -> str:
"""
The authorization type used in the http request header.
eg: Bearer, Basic, etc.
:return: token type
"""

@property
@abc.abstractmethod
async def token(self) -> str:
"""
The token used in the http request header.
:return: token
"""

async def authentication(self, headers: dict) -> None:
"""
Construct the authorization header in the http headers.
:param headers: http headers
:return: None
"""
headers["Authorization"] = f"{self.token_type} {await self.token}"


class AsyncTokenAuth(AsyncAuth):
"""
The fixed access token auth flow.
Expand All @@ -827,7 +843,7 @@ def token_type(self) -> str:
return "Bearer"

@property
async def token(self) -> str:
async def atoken(self) -> str:
return self._token


Expand Down Expand Up @@ -866,7 +882,7 @@ def token_type(self) -> str:
return "Bearer"

@property
async def token(self) -> str:
async def atoken(self) -> str:
token = await self._generate_token()
return token.access_token

Expand Down
15 changes: 7 additions & 8 deletions cozepy/coze.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional

from cozepy.auth import AsyncAuth, Auth
from cozepy.auth import Auth, SyncAuth
from cozepy.config import COZE_COM_BASE_URL
from cozepy.request import AsyncHTTPClient, Requester, SyncHTTPClient
from cozepy.util import remove_url_trailing_slash
Expand Down Expand Up @@ -143,22 +143,21 @@ def users(self) -> "UsersClient":
class AsyncCoze(object):
def __init__(
self,
auth: Union[Auth, AsyncAuth],
auth: Auth,
base_url: str = COZE_COM_BASE_URL,
http_client: Optional[AsyncHTTPClient] = None,
):
self._auth = auth
self._base_url = remove_url_trailing_slash(base_url)
if isinstance(auth, Auth):
if isinstance(auth, SyncAuth):
warnings.warn(
"The 'coze.Auth' use for AsyncCoze is deprecated and will be removed in a future version. "
"The 'coze.SyncAuth' use for AsyncCoze is deprecated and will be removed in a future version. "
"Please use 'coze.AsyncAuth' instead.",
DeprecationWarning,
stacklevel=2,
)
self._requester = Requester(auth=auth, async_client=http_client)
else:
self._requester = Requester(async_auth=auth, async_client=http_client)

self._requester = Requester(auth=auth, async_client=http_client)

# service client
self._bots: Optional[AsyncBotsClient] = None
Expand Down
13 changes: 4 additions & 9 deletions cozepy/knowledge/documents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,10 @@
from typing import List, Optional

from cozepy.datasets.documents import (
Document, # noqa
DocumentBase, # noqa
DocumentChunkStrategy, # noqa
# noqa
# noqa
# noqa
# noqa
DocumentUpdateRule, # noqa
# noqa
Document,
DocumentBase,
DocumentChunkStrategy,
DocumentUpdateRule,
)
from cozepy.model import AsyncNumberPaged, CozeModel, HTTPRequest, NumberPaged, NumberPagedResponse
from cozepy.request import Requester
Expand Down
18 changes: 5 additions & 13 deletions cozepy/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from cozepy.version import coze_client_user_agent, user_agent

if TYPE_CHECKING:
from cozepy.auth import AsyncAuth, Auth
from cozepy.auth import Auth

T = TypeVar("T", bound=BaseModel)

Expand Down Expand Up @@ -57,12 +57,10 @@ class Requester(object):
def __init__(
self,
auth: Optional["Auth"] = None,
async_auth: Optional["AsyncAuth"] = None,
sync_client: Optional[SyncHTTPClient] = None,
async_client: Optional[AsyncHTTPClient] = None,
):
self._auth = auth
self._async_auth = async_auth
self._sync_client = sync_client
self._async_client = async_client

Expand All @@ -71,11 +69,8 @@ def auth_header(self, headers: dict):
self._auth.authentication(headers)

async def async_auth_header(self, headers: dict):
if self._async_auth:
await self._async_auth.authentication(headers)
elif self._auth:
# Compatible with old versions, the next version will be removed
self._auth.authentication(headers)
if self._auth:
await self._auth.aauthentication(headers)

def make_request(
self,
Expand Down Expand Up @@ -136,11 +131,8 @@ async def amake_request(
headers["User-Agent"] = user_agent()
headers["X-Coze-Client-User-Agent"] = coze_client_user_agent()

if self._async_auth:
await self._async_auth.authentication(headers)
elif self._auth:
# Compatible with old versions, the next version will be removed
self._auth.authentication(headers)
if self._auth:
await self._auth.aauthentication(headers)

log_debug(
"request %s#%s sending, params=%s, json=%s, stream=%s, async=%s",
Expand Down

0 comments on commit ab826b4

Please sign in to comment.