Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implements secret masking for auth helpers #1643

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import math
import dataclasses
from abc import abstractmethod
from base64 import b64encode
from typing import (
TYPE_CHECKING,
Expand All @@ -17,6 +16,7 @@
from typing_extensions import Annotated
from requests.auth import AuthBase
from requests import PreparedRequest, Session as BaseSession # noqa: I251
from abc import abstractmethod

from dlt.common import logger
from dlt.common.exceptions import MissingDependencyException
Expand Down Expand Up @@ -46,6 +46,15 @@ def __bool__(self) -> bool:
# to be evaluated as False in requests.sessions.Session.prepare_request()
return True

def mask_secret(self, secret: Optional[str]) -> str:
if secret is None:
return "None"
return f"{secret[0]}*****{secret[-1]}"

@abstractmethod
def mask_secrets(self) -> str:
pass


@configspec
class BearerTokenAuth(AuthConfigBase):
Expand All @@ -67,6 +76,9 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

def mask_secrets(self) -> str:
return self.mask_secret(self.token)


@configspec
class APIKeyAuth(AuthConfigBase):
Expand Down Expand Up @@ -95,6 +107,9 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
raise NotImplementedError()
return request

def mask_secrets(self) -> str:
return self.mask_secret(self.api_key)


@configspec
class HttpBasicAuth(AuthConfigBase):
Expand All @@ -117,10 +132,16 @@ def parse_native_representation(self, value: Any) -> None:
)

def __call__(self, request: PreparedRequest) -> PreparedRequest:
encoded = b64encode(f"{self.username}:{self.password}".encode()).decode()
encoded = b64encode(self._format().encode()).decode()
request.headers["Authorization"] = f"Basic {encoded}"
return request

def mask_secrets(self) -> str:
return self.mask_secret(self._format())

def _format(self) -> str:
return f"{self.username}:{self.password}"


@configspec
class OAuth2AuthBase(AuthConfigBase):
Expand All @@ -143,6 +164,9 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.access_token}"
return request

def mask_secrets(self) -> str:
return self.mask_secret(self.access_token)


@configspec
class OAuth2ClientCredentials(OAuth2AuthBase):
Expand Down Expand Up @@ -213,6 +237,15 @@ def parse_expiration_in_seconds(self, response_json: Any) -> int:
def parse_access_token(self, response_json: Any) -> str:
return str(response_json.get("access_token"))

def mask_secrets(self) -> str:
return str(
{
"access_token": self.mask_secret(self.access_token),
"client_id": self.mask_secret(self.client_id),
"client_secret": self.mask_secret(self.client_secret),
}
)


@configspec
class OAuthJWTAuth(BearerTokenAuth):
Expand Down
34 changes: 34 additions & 0 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,35 @@ def response_hook(response: Response, *args: Any, **kwargs: Any) -> None:
pages = list(pages_iter)
assert pages == []

def test_auth_masks_bearer_token(self):
auth = BearerTokenAuth(cast(TSecretStrValue, "test-token"))
assert auth.mask_secrets() == "t*****n"

def test_auth_masks_api_key(self):
auth = APIKeyAuth(api_key=cast(TSecretStrValue, "test-token"))
assert auth.mask_secrets() == "t*****n"

def test_auth_masks_http_basic(self):
auth = HttpBasicAuth("a_user", cast(TSecretStrValue, "test-token"))
assert auth.mask_secrets() == "a*****n"

def test_auth_masks_oauth2(self):
auth = OAuth2ClientCredentials(
access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"),
client_id=cast(TSecretStrValue, "test-client-id"),
client_secret=cast(TSecretStrValue, "test-client-secret"),
session=Client().session,
)

assert auth.mask_secrets() == str(
{"access_token": "None", "client_id": "t*****d", "client_secret": "t*****t"}
)

auth.access_token = cast(TSecretStrValue, "test-token")
assert auth.mask_secrets() == str(
{"access_token": "t*****n", "client_id": "t*****d", "client_secret": "t*****t"}
)

def test_basic_auth_success(self, rest_client: RESTClient):
response = rest_client.get(
"/protected/posts/basic-auth",
Expand Down Expand Up @@ -349,6 +378,9 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

def mask_secrets(self) -> str:
return self.mask_secret(self.token)

class CustomAuthAuthBase(AuthBase):
def __init__(self, token: str):
self.token = token
Expand All @@ -357,6 +389,8 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request



auth_list = [
CustomAuthConfigBase("test-token"),
CustomAuthAuthBase("test-token"),
Expand Down
Loading