diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index d2ca1c1ca6..1b8e1833d3 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -1,6 +1,5 @@ import math import dataclasses -from abc import abstractmethod from base64 import b64encode from typing import ( TYPE_CHECKING, @@ -17,6 +16,7 @@ from typing_extensions import Annotated from requests.auth import AuthBase from requests import PreparedRequest, Session as BaseSession # noqa: I251 +from abc import abstractmethod from dlt.common import logger from dlt.common.exceptions import MissingDependencyException @@ -46,6 +46,15 @@ def __bool__(self) -> bool: # to be evaluated as False in requests.sessions.Session.prepare_request() return True + def mask_secret(self, secret: Optional[str]) -> str: + if secret is None: + return "None" + return f"{secret[0]}*****{secret[-1]}" + + @abstractmethod + def mask_secrets(self) -> str: + pass + @configspec class BearerTokenAuth(AuthConfigBase): @@ -67,6 +76,9 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: request.headers["Authorization"] = f"Bearer {self.token}" return request + def mask_secrets(self) -> str: + return self.mask_secret(self.token) + @configspec class APIKeyAuth(AuthConfigBase): @@ -95,6 +107,9 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: raise NotImplementedError() return request + def mask_secrets(self) -> str: + return self.mask_secret(self.api_key) + @configspec class HttpBasicAuth(AuthConfigBase): @@ -117,10 +132,16 @@ def parse_native_representation(self, value: Any) -> None: ) def __call__(self, request: PreparedRequest) -> PreparedRequest: - encoded = b64encode(f"{self.username}:{self.password}".encode()).decode() + encoded = b64encode(self._format().encode()).decode() request.headers["Authorization"] = f"Basic {encoded}" return request + def mask_secrets(self) -> str: + return self.mask_secret(self._format()) + + def _format(self) -> str: + return f"{self.username}:{self.password}" + @configspec class OAuth2AuthBase(AuthConfigBase): @@ -143,6 +164,9 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: request.headers["Authorization"] = f"Bearer {self.access_token}" return request + def mask_secrets(self) -> str: + return self.mask_secret(self.access_token) + @configspec class OAuth2ClientCredentials(OAuth2AuthBase): @@ -213,6 +237,15 @@ def parse_expiration_in_seconds(self, response_json: Any) -> int: def parse_access_token(self, response_json: Any) -> str: return str(response_json.get("access_token")) + def mask_secrets(self) -> str: + return str( + { + "access_token": self.mask_secret(self.access_token), + "client_id": self.mask_secret(self.client_id), + "client_secret": self.mask_secret(self.client_secret), + } + ) + @configspec class OAuthJWTAuth(BearerTokenAuth): diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 8fafde60d9..791e6fa92b 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -155,6 +155,35 @@ def response_hook(response: Response, *args: Any, **kwargs: Any) -> None: pages = list(pages_iter) assert pages == [] + def test_auth_masks_bearer_token(self): + auth = BearerTokenAuth(cast(TSecretStrValue, "test-token")) + assert auth.mask_secrets() == "t*****n" + + def test_auth_masks_api_key(self): + auth = APIKeyAuth(api_key=cast(TSecretStrValue, "test-token")) + assert auth.mask_secrets() == "t*****n" + + def test_auth_masks_http_basic(self): + auth = HttpBasicAuth("a_user", cast(TSecretStrValue, "test-token")) + assert auth.mask_secrets() == "a*****n" + + def test_auth_masks_oauth2(self): + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, + ) + + assert auth.mask_secrets() == str( + {"access_token": "None", "client_id": "t*****d", "client_secret": "t*****t"} + ) + + auth.access_token = cast(TSecretStrValue, "test-token") + assert auth.mask_secrets() == str( + {"access_token": "t*****n", "client_id": "t*****d", "client_secret": "t*****t"} + ) + def test_basic_auth_success(self, rest_client: RESTClient): response = rest_client.get( "/protected/posts/basic-auth", @@ -349,6 +378,9 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: request.headers["Authorization"] = f"Bearer {self.token}" return request + def mask_secrets(self) -> str: + return self.mask_secret(self.token) + class CustomAuthAuthBase(AuthBase): def __init__(self, token: str): self.token = token @@ -357,6 +389,8 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: request.headers["Authorization"] = f"Bearer {self.token}" return request + + auth_list = [ CustomAuthConfigBase("test-token"), CustomAuthAuthBase("test-token"),