-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(ingestion): Fixed STS token refresh mechanism for sagemaker source #11252
base: master
Are you sure you want to change the base?
Changes from 19 commits
138e542
f7e7bba
287a39a
aea9b35
6a7b8bc
07f7dcf
934d73f
61869c8
feb5c6d
5d5dcae
37a6b34
88ce18c
b8cebfd
8cd4294
4282cd0
52be907
5918411
18d0d73
3dd3c7e
e315673
f17e4a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,7 +1,7 @@ | ||||||
from datetime import datetime, timedelta, timezone | ||||||
import logging | ||||||
from datetime import datetime, timedelta | ||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | ||||||
|
||||||
import boto3 | ||||||
from boto3.session import Session | ||||||
from botocore.config import DEFAULT_TIMEOUT, Config | ||||||
from botocore.utils import fix_s3_host | ||||||
|
@@ -21,6 +21,8 @@ | |||||
from mypy_boto3_sagemaker import SageMakerClient | ||||||
from mypy_boto3_sts import STSClient | ||||||
|
||||||
logger = logging.getLogger(__name__) | ||||||
|
||||||
|
||||||
class AwsAssumeRoleConfig(PermissiveConfigModel): | ||||||
# Using the PermissiveConfigModel to allow the user to pass additional arguments. | ||||||
|
@@ -40,13 +42,13 @@ def assume_role( | |||||
credentials: Optional[dict] = None, | ||||||
) -> dict: | ||||||
credentials = credentials or {} | ||||||
sts_client: "STSClient" = boto3.client( | ||||||
"sts", | ||||||
region_name=aws_region, | ||||||
session = Session( | ||||||
aws_access_key_id=credentials.get("AccessKeyId"), | ||||||
aws_secret_access_key=credentials.get("SecretAccessKey"), | ||||||
aws_session_token=credentials.get("SessionToken"), | ||||||
region_name=aws_region, | ||||||
) | ||||||
sts_client: STSClient = session.client("sts") | ||||||
|
||||||
assume_role_args: dict = { | ||||||
**dict( | ||||||
|
@@ -64,6 +66,25 @@ def assume_role( | |||||
AUTODETECT_CREDENTIALS_DOC_LINK = "Can be auto-detected, see [the AWS boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html) for details." | ||||||
|
||||||
|
||||||
class LazyLogEvaluator:
    """
    Defer a costly call until a log record is actually formatted.

    Pass an instance as a ``%s`` argument to a logging call. The wrapped
    callback only runs if the record is rendered (i.e. the level is enabled),
    and its result is cached so that rendering the same record through several
    handlers triggers the callback just once.

    Args:
        callback: Callable producing the value to be logged.
        *args: Positional arguments forwarded to ``callback``.
    """

    def __init__(self, callback, *args: Any) -> None:
        self.callback = callback
        self.args = args
        # ``value_stored`` is tracked separately from ``value`` because the
        # callback may legitimately return None.
        self.value_stored = False
        self.value = None

    def __repr__(self) -> str:
        if not self.value_stored:
            self.value = self.callback(*self.args)
            self.value_stored = True
        # __repr__ must return a str; the callback may return None (or any
        # other non-string object), which would otherwise raise TypeError
        # while the log record is being formatted.
        return str(self.value)
|
||||||
|
||||||
class AwsConnectionConfig(ConfigModel): | ||||||
""" | ||||||
Common AWS credentials config. | ||||||
|
@@ -118,10 +139,12 @@ class AwsConnectionConfig(ConfigModel): | |||||
description="Advanced AWS configuration options. These are passed directly to [botocore.config.Config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html).", | ||||||
) | ||||||
|
||||||
def allowed_cred_refresh(self) -> bool: | ||||||
if self._normalized_aws_roles(): | ||||||
return True | ||||||
return False | ||||||
@staticmethod | ||||||
def get_caller_identity(session: Session) -> Optional[str]: | ||||||
logger.info("Retrieving identity of session: %s", session.profile_name) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can this change over time, or can we cache this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The logic of the ingestor itself does not need it; we extract the caller identity from the session to make sure we acquired the correct role (in debug mode only) |
||||||
sts_client = session.client("sts") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We prefer to use format string instead of parameter if possible There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
response = sts_client.get_caller_identity() | ||||||
return response.get("Arn") | ||||||
|
||||||
def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]: | ||||||
if not self.aws_role: | ||||||
|
@@ -136,22 +159,38 @@ def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]: | |||||
] | ||||||
|
||||||
def get_session(self) -> Session: | ||||||
logger.debug("Beginning session retrieval") | ||||||
if self.aws_access_key_id and self.aws_secret_access_key: | ||||||
session = Session( | ||||||
aws_access_key_id=self.aws_access_key_id, | ||||||
aws_secret_access_key=self.aws_secret_access_key, | ||||||
aws_session_token=self.aws_session_token, | ||||||
region_name=self.aws_region, | ||||||
) | ||||||
logger.debug( | ||||||
"Authenticated using access key, I am: %s", | ||||||
LazyLogEvaluator(self.get_caller_identity, session), | ||||||
) | ||||||
elif self.aws_profile: | ||||||
session = Session( | ||||||
region_name=self.aws_region, profile_name=self.aws_profile | ||||||
) | ||||||
logger.debug( | ||||||
"Authenticated using AWS profile, I am: %s", | ||||||
LazyLogEvaluator(self.get_caller_identity, session), | ||||||
) | ||||||
else: | ||||||
# Use boto3's credential autodetection. | ||||||
session = Session(region_name=self.aws_region) | ||||||
logger.debug( | ||||||
"Authenticated using auto-detection, I am: %s", | ||||||
LazyLogEvaluator(self.get_caller_identity, session), | ||||||
) | ||||||
|
||||||
if self._normalized_aws_roles(): | ||||||
logger.debug( | ||||||
"Detected normalized aws roles list: %s", self._normalized_aws_roles() | ||||||
) | ||||||
# Use existing session credentials to start the chain of role assumption. | ||||||
current_credentials = session.get_credentials() | ||||||
credentials = { | ||||||
|
@@ -161,29 +200,49 @@ def get_session(self) -> Session: | |||||
} | ||||||
|
||||||
for role in self._normalized_aws_roles(): | ||||||
if self._should_refresh_credentials(): | ||||||
credentials = assume_role( | ||||||
role, | ||||||
self.aws_region, | ||||||
credentials=credentials, | ||||||
) | ||||||
if isinstance(credentials["Expiration"], datetime): | ||||||
self._credentials_expiration = credentials["Expiration"] | ||||||
credentials = assume_role( | ||||||
role, | ||||||
self.aws_region, | ||||||
credentials=credentials, | ||||||
) | ||||||
|
||||||
session = Session( | ||||||
aws_access_key_id=credentials["AccessKeyId"], | ||||||
aws_secret_access_key=credentials["SecretAccessKey"], | ||||||
aws_session_token=credentials["SessionToken"], | ||||||
region_name=self.aws_region, | ||||||
) | ||||||
if isinstance(credentials["Expiration"], datetime): | ||||||
self._credentials_expiration = credentials["Expiration"] | ||||||
|
||||||
logger.debug( | ||||||
"Final session: %s", | ||||||
LazyLogEvaluator(self.get_caller_identity, session), | ||||||
) | ||||||
return session | ||||||
|
||||||
def _should_refresh_credentials(self) -> bool: | ||||||
def should_refresh_credentials(self) -> bool: | ||||||
logger.debug("Checking whether we should refresh credentials") | ||||||
if not self._normalized_aws_roles(): | ||||||
logger.debug( | ||||||
"Didn't recognize any aws roles to assume, deciding not to refresh" | ||||||
) | ||||||
return False | ||||||
if self._credentials_expiration is None: | ||||||
logger.debug("No credentials expiration time recorded") | ||||||
return True | ||||||
remaining_time = self._credentials_expiration - datetime.now(timezone.utc) | ||||||
return remaining_time < timedelta(minutes=5) | ||||||
time_now = datetime.now(self._credentials_expiration.tzinfo) | ||||||
remaining_time = self._credentials_expiration - time_now | ||||||
should_refresh = remaining_time < timedelta(minutes=5) | ||||||
logger.debug( | ||||||
"Current credentials expiration: %s | Current time: %s | Remaining time: %s | Therefor should " | ||||||
"we refresh? %s", | ||||||
self._credentials_expiration, | ||||||
time_now, | ||||||
remaining_time, | ||||||
"YES" if should_refresh else "NO", | ||||||
) | ||||||
return should_refresh | ||||||
|
||||||
def get_credentials(self) -> Dict[str, Optional[str]]: | ||||||
credentials = self.get_session().get_credentials() | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import logging | ||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Iterable, List | ||
from typing import TYPE_CHECKING, Callable, Iterable, List | ||
|
||
import datahub.emitter.mce_builder as builder | ||
from datahub.ingestion.api.workunit import MetadataWorkUnit | ||
|
@@ -28,10 +29,12 @@ | |
FeatureGroupSummaryTypeDef, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class FeatureGroupProcessor: | ||
sagemaker_client: "SageMakerClient" | ||
sagemaker_client: Callable[[], "SageMakerClient"] | ||
env: str | ||
report: SagemakerSourceReport | ||
|
||
|
@@ -41,10 +44,13 @@ def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]: | |
""" | ||
|
||
feature_groups = [] | ||
|
||
logger.debug("Attempting to get all feature groups") | ||
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_feature_groups | ||
paginator = self.sagemaker_client.get_paginator("list_feature_groups") | ||
paginator = self.sagemaker_client().get_paginator("list_feature_groups") | ||
for page in paginator.paginate(): | ||
logger.debug( | ||
"Retrieved %s feature groups", len(page["FeatureGroupSummaries"]) | ||
) | ||
feature_groups += page["FeatureGroupSummaries"] | ||
|
||
return feature_groups | ||
|
@@ -55,9 +61,9 @@ def get_feature_group_details( | |
""" | ||
Get details of a feature group (including list of component features). | ||
""" | ||
|
||
logger.debug("Attempting to describe feature group: %s", feature_group_name) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here as above, please prefer f-strings where possible |
||
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_feature_group | ||
feature_group = self.sagemaker_client.describe_feature_group( | ||
feature_group = self.sagemaker_client().describe_feature_group( | ||
FeatureGroupName=feature_group_name | ||
) | ||
|
||
|
@@ -66,12 +72,19 @@ def get_feature_group_details( | |
|
||
# paginate over feature group features | ||
while next_token: | ||
next_features = self.sagemaker_client.describe_feature_group( | ||
logger.debug( | ||
"Iterating over another token to retrieve full feature group description for: %s", | ||
feature_group_name, | ||
) | ||
next_features = self.sagemaker_client().describe_feature_group( | ||
FeatureGroupName=feature_group_name, NextToken=next_token | ||
) | ||
feature_group["FeatureDefinitions"] += next_features["FeatureDefinitions"] | ||
next_token = feature_group.get("NextToken", "") | ||
|
||
logger.debug( | ||
"Retrieved full description for feature group: %s", feature_group_name | ||
) | ||
return feature_group | ||
|
||
def get_feature_group_wu( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this store mechanism used anywhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe I missed it, but you are not reusing this anywhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is used by the `DEBUG` logger to avoid a costly call to STS if we are not on `DEBUG` level. See here: https://github.com/datahub-project/datahub/pull/11252/files#diff-658ffa764c667e22fae4a46946899527226fa9e0ce7464d35cf56bd4c2a29726R220
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it was more because I saw you cache the value in LazyLogEvaluator, but the LazyLogEvaluator instance is never reused.
It is not a problem as it can be useful in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is actually reused. We create it once when the `logger.debug` line is hit. Then `__repr__` is called (assuming the DEBUG level is set) to print to the console. It is then called again (now using the cached value) because, besides logging to the console, we also log to a file when debugging is turned on, so the caching is, in fact, used.