diff --git a/karton/core/backend.py b/karton/core/backend.py index 3045b9d..11aa3fa 100644 --- a/karton/core/backend.py +++ b/karton/core/backend.py @@ -14,6 +14,7 @@ InstanceMetadataFetcher, InstanceMetadataProvider, ) +from botocore.session import get_session from redis import AuthenticationError, StrictRedis from redis.client import Pipeline from urllib3.response import HTTPResponse @@ -120,7 +121,6 @@ def __init__( config, identity=identity, service_info=service_info ) - session_token = None endpoint = config.get("s3", "address") access_key = config.get("s3", "access_key") secret_key = config.get("s3", "secret_key") @@ -136,22 +136,10 @@ def __init__( ) if iam_auth: - iam_providers = [ - ContainerProvider(), - InstanceMetadataProvider( - iam_role_fetcher=InstanceMetadataFetcher( - timeout=1000, num_attempts=2 - ) - ), - ] - - for provider in iam_providers: - creds = provider.load() - if creds: - access_key = creds.access_key - secret_key = creds.secret_key - session_token = creds.token - break + s3_client = self.iam_auth_s3(endpoint) + if s3_client: + self.s3 = s3_client + return if access_key is None or secret_key is None: raise RuntimeError( @@ -163,9 +151,26 @@ def __init__( endpoint_url=endpoint, aws_access_key_id=access_key, aws_secret_access_key=secret_key, - aws_session_token=session_token, ) + def iam_auth_s3(self, endpoint: str): + boto_session = get_session() + iam_providers = [ + ContainerProvider(), + InstanceMetadataProvider( + iam_role_fetcher=InstanceMetadataFetcher(timeout=1000, num_attempts=2) + ), + ] + + for provider in iam_providers: + creds = provider.load() + if creds: + boto_session._credentials = creds # type: ignore + return boto3.Session(botocore_session=boto_session).client( + "s3", + endpoint_url=endpoint, + ) + @staticmethod def _validate_identity(identity: str): disallowed_chars = [" ", "?"]