diff --git a/aws_okta_keyman/aws.py b/aws_okta_keyman/aws.py index 80484db..91e961b 100644 --- a/aws_okta_keyman/aws.py +++ b/aws_okta_keyman/aws.py @@ -179,7 +179,6 @@ def available_roles(self): multiple_accounts = False first_account = '' formatted_roles = [] - i = 0 for role in self.assertion.roles(): account = role['role'].split(':')[4] role_name = role['role'].split(':')[5].split('/')[1] @@ -187,20 +186,27 @@ def available_roles(self): 'account': account, 'role_name': role_name, 'arn': role['role'], - 'principle': role['principle'], - 'roleIdx': i + 'principle': role['principle'] }) if first_account == '': first_account = account elif first_account != account: multiple_accounts = True - i = i + 1 if multiple_accounts: formatted_roles = self.account_ids_to_names(formatted_roles) - self.roles = sorted(formatted_roles, - key=lambda k: (k['account'], k['role_name'])) + formatted_roles = sorted(formatted_roles, + key=lambda k: (k['account'], k['role_name'])) + + # set the role role index after sorting + i = 0 + for role in formatted_roles: + role['roleIdx'] = i + i = i + 1 + + self.roles = formatted_roles + return self.roles def assume_role(self, print_only=False): diff --git a/aws_okta_keyman/test/aws_test.py b/aws_okta_keyman/test/aws_test.py index 60be405..f4f1de1 100644 --- a/aws_okta_keyman/test/aws_test.py +++ b/aws_okta_keyman/test/aws_test.py @@ -271,10 +271,15 @@ def test_assume_role_multiple(self, mock_write): def test_assume_role_preset(self, mock_write): mock_write.return_value = None assertion = mock.Mock() - assertion.roles.return_value = [{'arn': '', 'principle': ''}] + + roles = [{'role': '::::1:role/role1', 'principle': '', 'arn': '1'}, + {'role': '::::1:role/role2', 'principle': '', 'arn': '2'}, + {'role': '::::1:role/role3', 'principle': '', 'arn': '3'}] + + assertion.roles.return_value = roles session = aws.Session('BogusAssertion') - session.role = 0 - session.roles = [{'arn': '', 'principle': ''}] + session.role = 1 + session.roles = roles session.assertion = assertion sts = {'Credentials': {'AccessKeyId': 'AKI', @@ -296,6 +301,13 @@ def test_assume_role_preset(self, mock_write): mock_write.assert_has_calls([ mock.call() ]) + session.sts.assert_has_calls([ + mock.call.assume_role_with_saml( + RoleArn='2', + PrincipalArn='', + SAMLAssertion=mock.ANY, + DurationSeconds=3600) + ]) @mock.patch('aws_okta_keyman.aws.Session._print_creds') @mock.patch('aws_okta_keyman.aws.Session._write') @@ -420,23 +432,29 @@ def test_export_creds_to_var_string(self): self.assertEqual(ret, expected) def test_available_roles(self): - roles = [{'role': '::::1:role/role', 'principle': ''}, - {'role': '::::1:role/role', 'principle': ''}] + roles = [{'role': '::::1:role/role1', 'principle': ''}, + {'role': '::::1:role/role3', 'principle': ''}, + {'role': '::::1:role/role2', 'principle': ''}] session = aws.Session('BogusAssertion') session.assertion = mock.MagicMock() session.assertion.roles.return_value = roles - expected = [ - {'account': '1', 'role_name': 'role', - 'principle': '', 'arn': '::::1:role/role', - 'roleIdx': 0}, - {'account': '1', 'role_name': 'role', - 'principle': '', 'arn': '::::1:role/role', - 'roleIdx': 1} - ] result = session.available_roles() print(result) + + expected = [ + {'account': '1', 'role_name': 'role1', + 'principle': '', 'arn': '::::1:role/role1', + 'roleIdx': 0}, + {'account': '1', 'role_name': 'role2', + 'principle': '', 'arn': '::::1:role/role2', + 'roleIdx': 1}, + {'account': '1', 'role_name': 'role3', + 'principle': '', 'arn': '::::1:role/role3', + 'roleIdx': 2} + ] + self.assertEqual(expected, result) def test_available_roles_multiple_accounts(self): @@ -453,9 +471,9 @@ def test_available_roles_multiple_accounts(self): session.account_ids_to_names.return_value = roles_full expected = [ {'account': '1', 'role_name': 'role', - 'principle': '', 'arn': '::::1:role/role'}, + 'principle': '', 'arn': '::::1:role/role', 'roleIdx': 0}, {'account': '2', 'role_name': 'role', - 'principle': '', 'arn': '::::2:role/role'} + 'principle': '', 'arn': '::::2:role/role', 'roleIdx': 1} ] result = session.available_roles()