-
Notifications
You must be signed in to change notification settings - Fork 61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Integration][AWS] | Fix InvalidToken Exceptions Due to Improper Token Refresh Calls #1190
base: main
Are you sure you want to change the base?
Changes from 21 commits
c4c8e36
02c9c6f
7f6187b
75d6190
629af39
ddc123b
3bac44a
f794011
6e8ae57
186a768
a700b11
2e65ce4
17c9cdf
5cc7258
7e5e81e
072894b
004a3d8
ecbd9a2
b79b4de
69c176e
418176e
2c59250
ec9fd92
e8a9de8
afa2dec
12029a8
098c2d3
7d113c5
d0a122d
6b486c1
aa06383
e07a666
9a816a0
baca035
655a49c
bb63690
42f9456
7246b49
0c5db5d
3204bd4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,13 @@ | ||
from typing import AsyncIterator, Optional, Iterable | ||
from typing import AsyncIterator, Optional, Iterable, List, Dict, Any | ||
import aioboto3 | ||
from aiobotocore.credentials import AioRefreshableCredentials | ||
from aiobotocore.session import get_session | ||
from types_aiobotocore_sts import STSClient | ||
|
||
from functools import partial | ||
|
||
|
||
ASSUME_ROLE_DURATION_SECONDS = 3600 # 1 hour | ||
|
||
|
||
class AwsCredentials: | ||
|
@@ -9,18 +17,97 @@ def __init__( | |
access_key_id: str, | ||
secret_access_key: str, | ||
session_token: Optional[str] = None, | ||
role_arn: Optional[str] = None, | ||
session_name: Optional[str] = None, | ||
): | ||
""" | ||
Represents AWS credentials for an account, with support for automatic refreshing. | ||
|
||
:param account_id: AWS account ID. | ||
:param access_key_id: AWS access key ID. | ||
:param secret_access_key: AWS secret access key. | ||
:param session_token: AWS session token (for temporary credentials). | ||
:param role_arn: ARN of the role to assume for refreshing credentials. | ||
:param session_name: Name for the assumed role session. | ||
""" | ||
self.account_id = account_id | ||
self.access_key_id = access_key_id | ||
self.secret_access_key = secret_access_key | ||
self.session_token = session_token | ||
self.enabled_regions: list[str] = [] | ||
self.default_regions: list[str] = [] | ||
self.role_arn = role_arn | ||
self.session_name = session_name | ||
self.enabled_regions: List[str] = [] | ||
self.default_regions: List[str] = [] | ||
|
||
async def update_enabled_regions(self) -> None: | ||
session = aioboto3.Session( | ||
self.access_key_id, self.secret_access_key, self.session_token | ||
async def _refresh_credentials(self, sts_client: STSClient) -> Dict[str, Any]: | ||
""" | ||
Refreshes AWS credentials by re-assuming the role to get new credentials. | ||
|
||
:return: A dictionary containing the new credentials and their expiration time. | ||
""" | ||
response = await sts_client.assume_role( | ||
RoleArn=str(self.role_arn), | ||
RoleSessionName=str(self.session_name), | ||
DurationSeconds=ASSUME_ROLE_DURATION_SECONDS, | ||
) | ||
credentials = response["Credentials"] | ||
self.access_key_id = credentials["AccessKeyId"] | ||
self.secret_access_key = credentials["SecretAccessKey"] | ||
self.session_token = credentials["SessionToken"] | ||
expiry_time = credentials["Expiration"].isoformat() | ||
return { | ||
"access_key": self.access_key_id, | ||
"secret_key": self.secret_access_key, | ||
"token": self.session_token, | ||
"expiry_time": expiry_time, | ||
} | ||
|
||
async def create_refreshable_session( | ||
self, region: Optional[str] = None | ||
) -> aioboto3.Session: | ||
""" | ||
Creates an aioboto3 Session with refreshable credentials. | ||
|
||
:param region: AWS region for the session. | ||
:return: An aioboto3 Session object. | ||
""" | ||
if self.is_role(): | ||
session = aioboto3.Session( | ||
aws_access_key_id=self.access_key_id, | ||
aws_secret_access_key=self.secret_access_key, | ||
aws_session_token=self.session_token, | ||
) | ||
async with session.client("sts") as sts_client: | ||
initial_credentials = await self._refresh_credentials(sts_client) | ||
refresh_credentials = partial(self._refresh_credentials, sts_client) | ||
refreshable_credentials = ( | ||
AioRefreshableCredentials.create_from_metadata( | ||
metadata=initial_credentials, | ||
refresh_using=refresh_credentials, | ||
method="sts-assume-role", | ||
) | ||
) | ||
botocore_session = get_session() | ||
setattr(botocore_session, "_credentials", refreshable_credentials) | ||
if region: | ||
botocore_session.set_config_variable("region", region) | ||
autorefresh_session = aioboto3.Session( | ||
botocore_session=botocore_session | ||
) | ||
return autorefresh_session | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a pretty complex part of the code what do you think about adding maybe some comments about why this code exists and what is happening baisically? |
||
session = aioboto3.Session( | ||
aws_access_key_id=self.access_key_id, | ||
aws_secret_access_key=self.secret_access_key, | ||
region_name=region, | ||
) | ||
return session | ||
|
||
async def update_enabled_regions(self) -> None: | ||
""" | ||
Updates the list of enabled regions for the AWS account. | ||
""" | ||
session = await self.create_refreshable_session() | ||
async with session.client("account") as account_client: | ||
response = await account_client.list_regions( | ||
RegionOptStatusContains=["ENABLED", "ENABLED_BY_DEFAULT"] | ||
|
@@ -33,24 +120,22 @@ async def update_enabled_regions(self) -> None: | |
if region["RegionOptStatus"] == "ENABLED_BY_DEFAULT" | ||
] | ||
|
||
def is_role(self) -> bool: | ||
return self.session_token is not None | ||
|
||
async def create_session(self, region: Optional[str] = None) -> aioboto3.Session: | ||
if self.is_role(): | ||
return aioboto3.Session( | ||
self.access_key_id, self.secret_access_key, self.session_token, region | ||
) | ||
else: | ||
return aioboto3.Session( | ||
aws_access_key_id=self.access_key_id, | ||
aws_secret_access_key=self.secret_access_key, | ||
region_name=region, | ||
) | ||
|
||
async def create_session_for_each_region( | ||
async def create_refreshable_session_for_each_region( | ||
self, allowed_regions: Optional[Iterable[str]] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see why change the name pf this function, from an API standpoint the person using this function it will not matter if it's refreshable or not, WDYT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree |
||
) -> AsyncIterator[aioboto3.Session]: | ||
""" | ||
Creates refreshable sessions for each allowed or enabled region. | ||
|
||
:param allowed_regions: Iterable of region names to create sessions for. | ||
:yield: An aioboto3 Session for each region. | ||
""" | ||
regions = allowed_regions or self.enabled_regions | ||
for region in regions: | ||
yield await self.create_session(region) | ||
yield await self.create_refreshable_session(region) | ||
|
||
def is_role(self) -> bool: | ||
""" | ||
Checks if the credentials are for an assumed role. | ||
:return: True if the credentials are for a role, False otherwise. | ||
""" | ||
return bool(self.session_token and self.role_arn) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,7 +47,9 @@ async def reset(self) -> None: | |
application_credentials = await self._get_application_credentials() | ||
await application_credentials.update_enabled_regions() | ||
self._application_account_id = application_credentials.account_id | ||
self._application_session = await application_credentials.create_session() | ||
self._application_session = ( | ||
await application_credentials.create_refreshable_session() | ||
) | ||
|
||
self._aws_credentials.append(application_credentials) | ||
self._aws_accessible_accounts.append( | ||
|
@@ -98,7 +100,7 @@ async def _get_organization_session(self) -> aioboto3.Session | None: | |
organizations_client = await sts_client.assume_role( | ||
RoleArn=organization_role_arn, | ||
RoleSessionName="OceanOrgAssumeRoleSession", | ||
DurationSeconds=ASSUME_ROLE_DURATION_SECONDS, | ||
DurationSeconds=self._assume_role_duration_seconds(), | ||
) | ||
|
||
credentials = organizations_client["Credentials"] | ||
|
@@ -121,6 +123,10 @@ async def _get_organization_session(self) -> aioboto3.Session | None: | |
def _get_account_read_role_name(self) -> str: | ||
return ocean.integration_config.get("account_read_role_name", "") | ||
|
||
@staticmethod | ||
def _assume_role_duration_seconds() -> int: | ||
return int(ocean.integration_config.get("assume_role_duration", 900)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't this default to ASSUME_ROLE_DURATION_SECONDS? |
||
|
||
async def _update_available_access_credentials(self) -> None: | ||
logger.info("Updating AWS credentials") | ||
async with ( | ||
|
@@ -152,26 +158,36 @@ async def _update_available_access_credentials(self) -> None: | |
async def _assume_role_and_update_credentials( | ||
self, sts_client: STSClient, account: dict[str, Any] | ||
) -> None: | ||
""" | ||
Assumes a role in a member account and updates the credentials list. | ||
|
||
:param account: A dictionary containing account information. | ||
""" | ||
role_name = self._get_account_read_role_name() | ||
role_arn = f'arn:aws:iam::{account["Id"]}:role/{role_name}' | ||
role_session_name = "OceanMemberAssumeRoleSession" | ||
|
||
try: | ||
account_role = await sts_client.assume_role( | ||
RoleArn=f'arn:aws:iam::{account["Id"]}:role/{self._get_account_read_role_name()}', | ||
RoleSessionName="OceanMemberAssumeRoleSession", | ||
response = await sts_client.assume_role( | ||
RoleArn=role_arn, | ||
RoleSessionName=role_session_name, | ||
DurationSeconds=ASSUME_ROLE_DURATION_SECONDS, | ||
) | ||
raw_credentials = account_role["Credentials"] | ||
credentials = AwsCredentials( | ||
credentials = response["Credentials"] | ||
aws_credentials = AwsCredentials( | ||
account_id=account["Id"], | ||
access_key_id=raw_credentials["AccessKeyId"], | ||
secret_access_key=raw_credentials["SecretAccessKey"], | ||
session_token=raw_credentials["SessionToken"], | ||
access_key_id=credentials["AccessKeyId"], | ||
secret_access_key=credentials["SecretAccessKey"], | ||
session_token=credentials["SessionToken"], | ||
role_arn=role_arn, | ||
session_name=role_session_name, | ||
) | ||
await credentials.update_enabled_regions() | ||
self._aws_credentials.append(credentials) | ||
await aws_credentials.update_enabled_regions() | ||
self._aws_credentials.append(aws_credentials) | ||
self._aws_accessible_accounts.append(account) | ||
except sts_client.exceptions.ClientError as e: | ||
if is_access_denied_exception(e): | ||
logger.info(f"Cannot assume role in account {account['Id']}. Skipping.") | ||
pass # Skip the account if assume_role fails due to permission issues or non-existent role | ||
else: | ||
raise | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is using the same
sts_client
here thread safe?I mean since we are running in parallel, can sharing the
sts_client
result in unexpected errors?