Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ingestion): Fixed STS token refresh mechanism for sagemaker source #11252

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 79 additions & 20 deletions metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime, timedelta, timezone
import logging
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import boto3
from boto3.session import Session
from botocore.config import DEFAULT_TIMEOUT, Config
from botocore.utils import fix_s3_host
Expand All @@ -21,6 +21,8 @@
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sts import STSClient

logger = logging.getLogger(__name__)


class AwsAssumeRoleConfig(PermissiveConfigModel):
# Using the PermissiveConfigModel to allow the user to pass additional arguments.
Expand All @@ -40,13 +42,13 @@ def assume_role(
credentials: Optional[dict] = None,
) -> dict:
credentials = credentials or {}
sts_client: "STSClient" = boto3.client(
"sts",
region_name=aws_region,
session = Session(
aws_access_key_id=credentials.get("AccessKeyId"),
aws_secret_access_key=credentials.get("SecretAccessKey"),
aws_session_token=credentials.get("SessionToken"),
region_name=aws_region,
)
sts_client: STSClient = session.client("sts")

assume_role_args: dict = {
**dict(
Expand All @@ -64,6 +66,25 @@ def assume_role(
AUTODETECT_CREDENTIALS_DOC_LINK = "Can be auto-detected, see [the AWS boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html) for details."


class LazyLogEvaluator:
    """
    Defers a costly function call until the value is actually rendered by the
    logging framework (i.e. only when the effective level is DEBUG or lower).

    The result is computed at most once: repeated rendering of the same record
    (e.g. a console handler and a file handler both formatting it) reuses the
    cached value instead of re-invoking the callback.
    """

    def __init__(self, callback, *args):
        # Callable that produces the value to log, plus the positional
        # arguments to pass it when (and if) it is finally invoked.
        self.callback = callback
        self.args = args
        # Cache flag + slot so the callback runs at most once per instance.
        self.value_stored = False
        self.value = None

    def __repr__(self):
        if not self.value_stored:
            self.value = self.callback(*self.args)
            self.value_stored = True
        # Coerce to str: __repr__ must return a string, but the callback may
        # legitimately return None (e.g. get_caller_identity when no Arn is
        # present), which would otherwise raise
        # "TypeError: __repr__ returned non-string".
        return str(self.value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this store mechanism used anywhere?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I missed it, but you are not reusing this anywhere.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used by DEBUG logger to avoid costly call to sts if we are not on DEBUG level. See here:
https://github.com/datahub-project/datahub/pull/11252/files#diff-658ffa764c667e22fae4a46946899527226fa9e0ce7464d35cf56bd4c2a29726R220

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it was more because I saw you cache the value in LazyLogEvaluator, but the LazyLogEvaluator instance is never reused.
It is not a problem as it can be useful in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is reused actually. We create it once when the log.debug line is hit. Then the repr function is called (assuming DEBUG level is set) to print to the console. And then it is called again (now using the cached value) because, besides logging to the console, we also log to a file if debugging is turned on, so caching is, in fact, used.



class AwsConnectionConfig(ConfigModel):
"""
Common AWS credentials config.
Expand Down Expand Up @@ -118,10 +139,12 @@ class AwsConnectionConfig(ConfigModel):
description="Advanced AWS configuration options. These are passed directly to [botocore.config.Config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html).",
)

def allowed_cred_refresh(self) -> bool:
if self._normalized_aws_roles():
return True
return False
@staticmethod
def get_caller_identity(session: Session) -> Optional[str]:
logger.info("Retrieving identity of session: %s", session.profile_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this change over time, or can we cache this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic of the ingestor itself does not need it, we extract the caller identity from the session to make sure we acquired correct role (in debug mode only)

sts_client = session.client("sts")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer to use format string instead of parameter if possible

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.info("Retrieving identity of session: %s", session.profile_name)
logger.info(f"Retrieving identity of session: {session.profile_name}")

response = sts_client.get_caller_identity()
return response.get("Arn")

def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:
if not self.aws_role:
Expand All @@ -136,22 +159,38 @@ def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:
]

def get_session(self) -> Session:
logger.debug("Beginning session retrieval")
if self.aws_access_key_id and self.aws_secret_access_key:
session = Session(
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
region_name=self.aws_region,
)
logger.debug(
"Authenticated using access key, I am: %s",
LazyLogEvaluator(self.get_caller_identity, session),
)
elif self.aws_profile:
session = Session(
region_name=self.aws_region, profile_name=self.aws_profile
)
logger.debug(
"Authenticated using AWS profile, I am: %s",
LazyLogEvaluator(self.get_caller_identity, session),
)
else:
# Use boto3's credential autodetection.
session = Session(region_name=self.aws_region)
logger.debug(
"Authenticated using auto-detection, I am: %s",
LazyLogEvaluator(self.get_caller_identity, session),
)

if self._normalized_aws_roles():
logger.debug(
"Detected normalized aws roles list: %s", self._normalized_aws_roles()
)
# Use existing session credentials to start the chain of role assumption.
current_credentials = session.get_credentials()
credentials = {
Expand All @@ -161,29 +200,49 @@ def get_session(self) -> Session:
}

for role in self._normalized_aws_roles():
if self._should_refresh_credentials():
credentials = assume_role(
role,
self.aws_region,
credentials=credentials,
)
if isinstance(credentials["Expiration"], datetime):
self._credentials_expiration = credentials["Expiration"]
credentials = assume_role(
role,
self.aws_region,
credentials=credentials,
)

session = Session(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
region_name=self.aws_region,
)
if isinstance(credentials["Expiration"], datetime):
self._credentials_expiration = credentials["Expiration"]

logger.debug(
"Final session: %s",
LazyLogEvaluator(self.get_caller_identity, session),
)
return session

def _should_refresh_credentials(self) -> bool:
def should_refresh_credentials(self) -> bool:
    """
    Decide whether the assumed-role session credentials should be re-acquired.

    Returns:
        False when no role assumption is configured (nothing to refresh);
        True when roles are configured and either no expiration has been
        recorded yet or the recorded expiration is less than 5 minutes away.
    """
    logger.debug("Checking whether we should refresh credentials")
    if not self._normalized_aws_roles():
        logger.debug(
            "Didn't recognize any aws roles to assume, deciding not to refresh"
        )
        return False
    if self._credentials_expiration is None:
        # No expiration recorded means we have never assumed the role(s)
        # (or the STS response lacked one), so refresh to be safe.
        logger.debug("No credentials expiration time recorded")
        return True
    # Use the expiration timestamp's own tzinfo so we never subtract a naive
    # datetime from an aware one (mixing the two raises TypeError).
    time_now = datetime.now(self._credentials_expiration.tzinfo)
    remaining_time = self._credentials_expiration - time_now
    should_refresh = remaining_time < timedelta(minutes=5)
    logger.debug(
        "Current credentials expiration: %s | Current time: %s | Remaining time: %s | Therefore should "
        "we refresh? %s",
        self._credentials_expiration,
        time_now,
        remaining_time,
        "YES" if should_refresh else "NO",
    )
    return should_refresh

def get_credentials(self) -> Dict[str, Optional[str]]:
credentials = self.get_session().get_credentials()
Expand Down
23 changes: 13 additions & 10 deletions metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(self, config: SagemakerSourceConfig, ctx: PipelineContext):
super().__init__(config, ctx)
self.source_config = config
self.report = SagemakerSourceReport()
self.sagemaker_client = config.sagemaker_client
self.env = config.env
self.client_factory = ClientFactory(config)

Expand All @@ -77,14 +76,18 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
# get common lineage graph
lineage_processor = LineageProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
sagemaker_client=self.client_factory.get_client,
env=self.env,
report=self.report,
)
lineage = lineage_processor.get_lineage()

# extract feature groups if specified
if self.source_config.extract_feature_groups:
feature_group_processor = FeatureGroupProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
sagemaker_client=self.client_factory.get_client,
env=self.env,
report=self.report,
)
yield from feature_group_processor.get_workunits()

Expand All @@ -100,7 +103,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
env=self.env,
report=self.report,
job_type_filter=self.source_config.extract_jobs,
aws_region=self.sagemaker_client.meta.region_name,
aws_region=self.client_factory.get_client().meta.region_name,
)
yield from job_processor.get_workunits()

Expand All @@ -110,13 +113,13 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
# extract models if specified
if self.source_config.extract_models:
model_processor = ModelProcessor(
sagemaker_client=self.sagemaker_client,
sagemaker_client=self.client_factory.get_client,
env=self.env,
report=self.report,
model_image_to_jobs=model_image_to_jobs,
model_name_to_jobs=model_name_to_jobs,
lineage=lineage,
aws_region=self.sagemaker_client.meta.region_name,
aws_region=self.client_factory.get_client().meta.region_name,
)
yield from model_processor.get_workunits()

Expand All @@ -127,10 +130,10 @@ def get_report(self):
class ClientFactory:
    """
    Hands out a SageMaker client, re-creating it only when no client exists
    yet or the underlying assumed-role credentials are close to expiry.

    This is the STS token refresh mechanism: processors hold a reference to
    get_client (a callable) instead of a fixed client, so each call can
    transparently swap in a client backed by fresh credentials.
    """

    def __init__(self, config: "SagemakerSourceConfig"):
        self.config = config
        # Lazily created; populated on the first get_client() call.
        self._cached_client: Optional["SageMakerClient"] = None

    def get_client(self) -> "SageMakerClient":
        # Rebuild the client on first use, or when the config reports that the
        # assumed-role (STS) credentials need refreshing.
        if not self._cached_client or self.config.should_refresh_credentials():
            self._cached_client = self.config.get_sagemaker_client()
        assert self._cached_client is not None
        return self._cached_client
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ class SagemakerSourceConfig(
# Custom Stateful Ingestion settings
stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None

@property
def sagemaker_client(self):
return self.get_sagemaker_client()


@dataclass
class SagemakerSourceReport(StaleEntityRemovalSourceReport):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterable, List
from typing import TYPE_CHECKING, Callable, Iterable, List

import datahub.emitter.mce_builder as builder
from datahub.ingestion.api.workunit import MetadataWorkUnit
Expand Down Expand Up @@ -28,10 +29,12 @@
FeatureGroupSummaryTypeDef,
)

logger = logging.getLogger(__name__)


@dataclass
class FeatureGroupProcessor:
sagemaker_client: "SageMakerClient"
sagemaker_client: Callable[[], "SageMakerClient"]
env: str
report: SagemakerSourceReport

Expand All @@ -41,10 +44,13 @@ def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]:
"""

feature_groups = []

logger.debug("Attempting to get all feature groups")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_feature_groups
paginator = self.sagemaker_client.get_paginator("list_feature_groups")
paginator = self.sagemaker_client().get_paginator("list_feature_groups")
for page in paginator.paginate():
logger.debug(
"Retrieved %s feature groups", len(page["FeatureGroupSummaries"])
)
feature_groups += page["FeatureGroupSummaries"]

return feature_groups
Expand All @@ -55,9 +61,9 @@ def get_feature_group_details(
"""
Get details of a feature group (including list of component features).
"""

logger.debug("Attempting to describe feature group: %s", feature_group_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here as above, please prefer fstrings where possible

# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_feature_group
feature_group = self.sagemaker_client.describe_feature_group(
feature_group = self.sagemaker_client().describe_feature_group(
FeatureGroupName=feature_group_name
)

Expand All @@ -66,12 +72,19 @@ def get_feature_group_details(

# paginate over feature group features
while next_token:
next_features = self.sagemaker_client.describe_feature_group(
logger.debug(
"Iterating over another token to retrieve full feature group description for: %s",
feature_group_name,
)
next_features = self.sagemaker_client().describe_feature_group(
FeatureGroupName=feature_group_name, NextToken=next_token
)
feature_group["FeatureDefinitions"] += next_features["FeatureDefinitions"]
next_token = feature_group.get("NextToken", "")

logger.debug(
"Retrieved full description for feature group: %s", feature_group_name
)
return feature_group

def get_feature_group_wu(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
Expand Down Expand Up @@ -49,6 +50,8 @@
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

logger = logging.getLogger(__name__)

JobInfo = TypeVar(
"JobInfo",
AutoMlJobInfo,
Expand Down Expand Up @@ -171,9 +174,11 @@ class JobProcessor:

def get_jobs(self, job_type: JobType, job_spec: JobInfo) -> List[Any]:
jobs = []
logger.debug("Attempting to retrieve all jobs for type %s", job_type)
paginator = self.sagemaker_client().get_paginator(job_spec.list_command)
for page in paginator.paginate():
page_jobs: List[Any] = page[job_spec.list_key]
logger.debug("Retrieved %s jobs", len(page_jobs))

for job in page_jobs:
job_name = (
Expand Down Expand Up @@ -269,6 +274,11 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
describe_command = job_type_to_info[job_type].describe_command
describe_name_key = job_type_to_info[job_type].describe_name_key

logger.debug(
"Retrieving description for job: %s using command: %s",
job_name,
describe_command,
)
return getattr(self.sagemaker_client(), describe_command)(
**{describe_name_key: job_name}
)
Expand Down
Loading
Loading