Skip to content

Commit

Permalink
adds session tagging possibility
Browse files Browse the repository at this point in the history
  • Loading branch information
pbinczyk-lcloud authored and mtskillman committed Aug 12, 2023
1 parent 1fdf508 commit e972fbe
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
29 changes: 28 additions & 1 deletion awsume/awsumepy/default_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def add_arguments(config: dict, parser: argparse.ArgumentParser):
dest='list_profiles',
help='List profiles, "more" for detail (slow)',
)
parser.add_argument('-t', '--tags',
action='store',
default=None,
metavar='Key=Value,...',
dest='session_tags',
help='A Key=Value list of session tags',
)
parser.add_argument('--refresh-autocomplete',
action='store_true',
dest='refresh_autocomplete',
Expand Down Expand Up @@ -275,6 +282,20 @@ def post_add_arguments(config: dict, arguments: argparse.Namespace, parser: argp
else:
arguments.target_profile_name = arguments.profile_name

if arguments.session_tags:
tags = []
for tag in arguments.session_tags.split(","):
kv = tag.split('=')
try:
tags.append({
'Key': kv[0],
'Value': kv[1] or ''
})
except IndexError:
parser.error('--tags must be a valid string of Key=Value,... tags')
if len(tags) > 0:
arguments.session_tags = tags


@hookimpl(tryfirst=True)
def collect_aws_profiles(config: dict, arguments: argparse.Namespace, credentials_file: str, config_file: str):
Expand Down Expand Up @@ -308,7 +329,7 @@ def assume_role_from_cli(config: dict, arguments: dict, profiles: dict):
logger.debug('Session name: {}'.format(session_name))
if not arguments.source_profile:
logger.debug('Using current credentials to assume role')
role_session = aws_lib.assume_role({}, arguments.role_arn, session_name, region=region, external_id=arguments.external_id, role_duration=role_duration)
role_session = aws_lib.assume_role({}, arguments.role_arn, session_name, region=region, external_id=arguments.external_id, role_duration=role_duration, tags=arguments.session_tags)
else:
logger.debug('Using the source_profile from the cli to call assume_role')
source_profile = profiles.get(arguments.source_profile)
Expand All @@ -331,6 +352,7 @@ def assume_role_from_cli(config: dict, arguments: dict, profiles: dict):
role_duration=role_duration,
mfa_serial=mfa_serial,
mfa_token=arguments.mfa_token,
tags=arguments.session_tags
)
else:
logger.debug('MFA not needed, assuming role from with profile creds')
Expand All @@ -341,6 +363,7 @@ def assume_role_from_cli(config: dict, arguments: dict, profiles: dict):
region=region,
external_id=arguments.external_id,
role_duration=role_duration,
tags=arguments.session_tags
)
else:
logger.debug('Using default role duration')
Expand All @@ -364,6 +387,7 @@ def assume_role_from_cli(config: dict, arguments: dict, profiles: dict):
region=region,
external_id=arguments.external_id,
role_duration=role_duration,
tags=arguments.session_tags
)
return role_session

Expand All @@ -382,6 +406,7 @@ def get_assume_role_credentials(config: dict, arguments: argparse.Namespace, pro
region=region,
external_id=external_id,
role_duration=role_duration,
tags=arguments.session_tags
)
if 'SourceExpiration' in source_credentials:
role_session['SourceExpiration'] = source_credentials['SourceExpiration']
Expand Down Expand Up @@ -426,6 +451,7 @@ def get_assume_role_credentials_mfa_required(config: dict, arguments: argparse.N
region=region,
external_id=external_id,
role_duration=role_duration,
tags=arguments.session_tags
)

if 'SourceExpiration' in source_session:
Expand Down Expand Up @@ -457,6 +483,7 @@ def get_assume_role_credentials_mfa_required_large_custom_duration(config: dict,
role_duration=role_duration,
mfa_serial=mfa_serial,
mfa_token=arguments.mfa_token,
tags=arguments.session_tags
)
return role_session

Expand Down
3 changes: 3 additions & 0 deletions awsume/awsumepy/lib/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def assume_role(
role_duration: int = None,
mfa_serial: str = None,
mfa_token: str = None,
tags: list | None = None,
) -> dict:
if len(session_name) < 2:
session_name = session_name.center(2, '_')
Expand All @@ -62,6 +63,8 @@ def assume_role(
if mfa_serial:
kwargs['SerialNumber'] = mfa_serial
kwargs['TokenCode'] = mfa_token or profile_lib.get_mfa_token()
if tags:
kwargs["Tags"] = tags
logger.debug('Assuming role now')
role_session = role_sts_client.assume_role(**kwargs).get('Credentials')
logger.debug('Received role credentials')
Expand Down

0 comments on commit e972fbe

Please sign in to comment.