diff --git a/.gitignore b/.gitignore index 5b338fe..d77cc3b 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,7 @@ venv/ ENV/ env.bak/ venv.bak/ +venv-awsprocesscreds/ # mypy .mypy_cache/ diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 16ed432..72b54df 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -218,6 +218,199 @@ def _get_value_of_first_tag(self, root, tag, attr, trait): class OktaAuthenticator(GenericFormsBasedAuthenticator): _AUTH_URL = '/api/v1/authn' + _ERROR_AUTH_CANCELLED = ( + 'Authentication cancelled' + ) + + _ERROR_LOCKED_OUT = ( + "You are locked out of your Okta account. Go to %s to unlock it." + ) + + _ERROR_PASSWORD_EXPIRED = ( + "Your password has expired. Go to %s to change it." + ) + + _ERROR_MFA_ENROLL = ( + "You need to enroll a MFA first." + ) + + _MSG_AUTH_CODE = ( + "Authentication code (RETURN to cancel): " + ) + + _MSG_ANSWER = ( + "Answer (RETURN to cancel): " + ) + + _MSG_SMS_CODE = ( + "SMS authentication code (RETURN to cancel, " + "'RESEND' to get new code sent): " + ) + + def get_response(self, prompt, allow_cancel=True): + response = self._password_prompter(prompt) + if allow_cancel and response == "": + raise SAMLError(self._ERROR_AUTH_CANCELLED) + return response + + def get_assertion_from_response(self, endpoint, parsed): + session_token = parsed['sessionToken'] + saml_url = endpoint + '?sessionToken=%s' % session_token + response = self._requests_session.get(saml_url) + logger.info( + 'Received HTTP response of status code: %s', response.status_code) + r = self._extract_saml_assertion_from_response(response.text) + logger.info( + 'Received the following SAML assertion: \n%s', r, + extra={'is_saml_assertion': True} + ) + return r + + def process_response(self, response, endpoint): + parsed = json.loads(response.text) + if response.status_code == 200: + return self.get_assertion_from_response(endpoint, parsed) + if response.status_code >= 400: + error = parsed["errorCauses"][0]["errorSummary"] + self.get_response("%s\r\nPress RETURN to continue\r\n" + % error, False) + return None + + def process_mfa_totp(self, endpoint, url, statetoken): + while True: + response = self.get_response(self._MSG_AUTH_CODE) + totp_response = self._requests_session.post( + url, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + data=json.dumps({'stateToken': statetoken, + 'passCode': response}) + ) + result = self.process_response(totp_response, endpoint) + if result is not None: + return result + + def process_mfa_push(self, endpoint, url, statetoken): + self.get_response(("Press RETURN when you are ready to request the " + "push notification"), False) + while True: + totp_response = self._requests_session.post( + url, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + data=json.dumps({'stateToken': statetoken}) + ) + totp_parsed = json.loads(totp_response.text) + if totp_parsed["status"] == "SUCCESS": + return self.get_assertion_from_response(endpoint, totp_parsed) + if totp_parsed["factorResult"] != "WAITING": + raise SAMLError(self._ERROR_AUTH_CANCELLED) + + def process_mfa_security_question(self, endpoint, url, statetoken): + while True: + response = self.get_response(self._MSG_ANSWER) + totp_response = self._requests_session.post( + url, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + data=json.dumps({'stateToken': statetoken, + 'answer': response}) + ) + result = self.process_response(totp_response, endpoint) + if result is not None: + return result + + def verify_sms_factor(self, url, statetoken, passcode): + body = {'stateToken': statetoken} + if passcode != "": + body['passCode'] = passcode + return self._requests_session.post( + url, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + data=json.dumps(body) + ) + + def process_mfa_sms(self, endpoint, url, statetoken): + # Need to trigger the initial code to be sent ... + self.verify_sms_factor(url, statetoken, "") + while True: + response = self.get_response(self._MSG_SMS_CODE) + # If the user has asked for the code to be resent, clear + # the response to retrigger sending the code. + if response == "RESEND": + response = "" + sms_response = self.verify_sms_factor(url, statetoken, response) + # If we've just requested a resend, don't check the result + # - just loop around to get the next response from the user. + if response != "": + result = self.process_response(sms_response, endpoint) + if result is not None: + return result + + def display_mfa_choices(self, parsed): + index = 1 + prompt = "" + for f in parsed["_embedded"]["factors"]: + if f["factorType"] == "token": + prompt += "%s: %s token\r\n" % (index, f["provider"]) + elif f["factorType"] == "token:software:totp": + prompt += ("%s: %s authenticator app\r\n" + % (index, f["provider"])) + elif f["factorType"] == "sms": + prompt += "%s: SMS text message\r\n" % index + elif f["factorType"] == "push": + prompt += "%s: Push notification\r\n" % index + elif f["factorType"] == "question": + prompt += "%s: Security question\r\n" % index + else: + prompt += "%s: %s %s\r\n" % (index, + f["provider"], + f["factorType"]) + index += 1 + return index, prompt + + def get_number(self, prompt): + response = self.get_response(prompt) + choice = 0 + try: + choice = int(response) + except ValueError: + pass + return choice + + def get_mfa_choice(self, parsed): + count, prompt = self.display_mfa_choices(parsed) + prompt = ("Please choose from the following authentication" + " choices:\r\n") + prompt + prompt += ("Enter the number corresponding to your choice " + "or press RETURN to cancel authentication: ") + while True: + choice = self.get_number(prompt) + if 0 < choice < count: + return choice + + def process_mfa_verification(self, endpoint, parsed): + # If we've only got one factor, pick that automatically + if len(parsed["_embedded"]["factors"]) == 1: + choice = 1 + else: + choice = self.get_mfa_choice(parsed) + factor = parsed["_embedded"]["factors"][choice - 1] + url = factor["_links"]["verify"]["href"] + statetoken = parsed["stateToken"] + if factor["factorType"] == "token:software:totp": + return self.process_mfa_totp(endpoint, url, statetoken) + if factor["factorType"] == "push": + return self.process_mfa_push(endpoint, url, statetoken) + if factor["factorType"] == "question": + return self.process_mfa_security_question(endpoint, + url, statetoken) + if factor["factorType"] == "sms": + return self.process_mfa_sms(endpoint, url, statetoken) + + raise SAMLError("Unsupported factor") + def retrieve_saml_assertion(self, config): self._validate_config_values(config) endpoint = config['saml_endpoint'] @@ -237,17 +430,27 @@ def retrieve_saml_assertion(self, config): 'password': password}) ) parsed = json.loads(response.text) - session_token = parsed['sessionToken'] - saml_url = endpoint + '?sessionToken=%s' % session_token - response = self._requests_session.get(saml_url) - logger.info( - 'Received HTTP response of status code: %s', response.status_code) - r = self._extract_saml_assertion_from_response(response.text) logger.info( - 'Received the following SAML assertion: \n%s', r, - extra={'is_saml_assertion': True} + 'Got status %s and response: %s', + response.status_code, response.text ) - return r + if response.status_code == 401: + raise SAMLError(self._ERROR_LOGIN_FAILED_NON_200 % + parsed["errorSummary"]) + if "status" in parsed: + if parsed["status"] == "SUCCESS": + return self.get_assertion_from_response(endpoint, parsed) + if parsed["status"] == "LOCKED_OUT": + raise SAMLError(self._ERROR_LOCKED_OUT % + parsed["_links"]["href"]) + if parsed["status"] == "PASSWORD_EXPIRED": + raise SAMLError(self._ERROR_PASSWORD_EXPIRED % + parsed["_links"]["href"]) + if parsed["status"] == "MFA_ENROLL": + raise SAMLError(self._ERROR_MFA_ENROLL) + if parsed["status"] == "MFA_REQUIRED": + return self.process_mfa_verification(endpoint, parsed) + raise SAMLError("Code logic failure") def is_suitable(self, config): return (config.get('saml_authentication_type') == 'form' and @@ -309,7 +512,6 @@ class SAMLCredentialFetcher(CachedCredentialFetcher): SAML_FORM_AUTHENTICATORS = { 'okta': OktaAuthenticator, 'adfs': ADFSFormsBasedAuthenticator - } def __init__(self, client_creator, provider_name, saml_config, diff --git a/tests/functional/test_saml.py b/tests/functional/test_saml.py index dcaf836..14ab407 100644 --- a/tests/functional/test_saml.py +++ b/tests/functional/test_saml.py @@ -9,7 +9,8 @@ from tests import create_assertion from awsprocesscreds.cli import saml, PrettyPrinterLogHandler -from awsprocesscreds.saml import SAMLCredentialFetcher +from awsprocesscreds.saml import SAMLCredentialFetcher, OktaAuthenticator, \ + SAMLError @pytest.fixture @@ -22,9 +23,646 @@ def argv(): ] +def test_get_response_1(): + def mock_prompter(prompt): + return "" + + authenticator = OktaAuthenticator(mock_prompter) + with pytest.raises(SAMLError): + authenticator.get_response("") + + +def test_get_response_2(): + def mock_prompter(prompt): + return "mock_result" + + authenticator = OktaAuthenticator(mock_prompter) + response = authenticator.get_response("") + assert response == "mock_result" + + +def test_get_response_3(): + def mock_prompter(prompt): + return "" + + authenticator = OktaAuthenticator(mock_prompter) + response = authenticator.get_response("", False) + assert response == "" + + +def test_process_response_1(mock_requests_session, assertion, prompter): + assertion_form = '
' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + mock_requests_session.get.return_value = assertion_response + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + + result = authenticator._authenticator.process_response( + token_response, "endpoint") + assert result == assertion.decode() + + +def test_process_response_2(mock_requests_session, assertion, prompter): + def mock_prompter(prompt): + assert prompt == "Mock error\r\nPress RETURN to continue\r\n" + return "" + + session_token = { + 'sessionToken': 'spam', + 'status': 'FAILED', + 'errorCauses': [ + { + 'errorSummary': "Mock error" + } + ] + } + token_response = mock.Mock( + spec=requests.Response, + status_code=400, + text=json.dumps(session_token) + ) + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + + result = authenticator._authenticator.process_response( + token_response, "endpoint") + assert result is None + + +def test_process_mfa_totp( + mock_requests_session, prompter, assertion, capsys): + def mock_prompter(prompt): + return "12345678" + + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + + result = authenticator._authenticator.process_mfa_totp( + "endpoint", "url", "statetoken") + assert result == assertion.decode() + + +def test_process_mfa_push_1( + mock_requests_session, prompter, assertion, capsys): + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + + result = authenticator._authenticator.process_mfa_push( + "endpoint", "url", "statetoken") + assert result == assertion.decode() + + +def test_process_mfa_push_2( + mock_requests_session, prompter, assertion, capsys): + session_token = { + 'sessionToken': 'spam', + 'status': 'CANCELLED', + 'factorResult': 'FAILED' + } + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + mock_requests_session.post.return_value = token_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + + with pytest.raises(SAMLError): + authenticator._authenticator.process_mfa_push( + "endpoint", "url", "statetoken") + + +def test_process_mfa_security_question( + mock_requests_session, prompter, assertion, capsys): + def mock_prompter(prompt): + return "security_answer" + + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + + result = authenticator._authenticator.process_mfa_security_question( + "endpoint", "url", "statetoken") + assert result == assertion.decode() + + +def test_verify_sms_factor( + mock_requests_session, prompter, assertion, capsys): + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + mock_requests_session.post.return_value = token_response + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + result = authenticator._authenticator.verify_sms_factor( + "url", "statetoken", "passcode") + assert result.status_code == 200 + test = json.loads(result.text) + assert test["status"] == "SUCCESS" + + +def test_process_mfa_sms( + mock_requests_session, prompter, assertion, capsys): + def mock_prompter(prompt): + return "12345678" + + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.verify_sms_factor", + return_value=token_response): + result = authenticator._authenticator.process_mfa_sms( + "endpoint", "url", "statetoken") + assert result == assertion.decode() + + +def test_display_mfa_choices( + mock_requests_session, prompter, assertion, capsys): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "token", + "provider": "OKTA" + }, + { + "factorType": "token:software:totp", + "provider": "OKTA" + }, + { + "factorType": "sms" + }, + { + "factorType": "push" + }, + { + "factorType": "question" + }, + { + "factorType": "blackboard", + "provider": "classroom" + } + ] + } + } + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + index, prompt = authenticator._authenticator.display_mfa_choices(parsed) + assert index == 7 + assert prompt == ( + "1: OKTA token\r\n" + "2: OKTA authenticator app\r\n" + "3: SMS text message\r\n" + "4: Push notification\r\n" + "5: Security question\r\n" + "6: classroom blackboard\r\n" + ) + + +def test_get_number_1(prompter): + def mock_prompter(prompt): + return "1" + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + response = authenticator._authenticator.get_number("") + assert response == 1 + + +def test_get_number_2(prompter): + def mock_prompter(prompt): + return "fred" + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + response = authenticator._authenticator.get_number("") + assert response == 0 + + +def test_get_mfa_choice( + mock_requests_session, prompter, assertion, capsys): + def mock_prompter(prompt): + assert prompt == ( + "Please choose from the following authentication choices:\r\n" + "1: SMS text message\r\n" + "Enter the number corresponding to your choice or press RETURN to " + "cancel authentication: " + ) + return "1" + + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "sms" + } + ] + } + } + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + response = authenticator._authenticator.get_mfa_choice(parsed) + assert response == 1 + + +def test_process_mfa_verification_1(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "unsupported", + "_links": { + "verify": { + "href": "href" + } + } + }, + { + "factorType": "unsupported" + } + ] + }, + "stateToken": "statetoken" + } + authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.get_mfa_choice", + return_value=1): + with pytest.raises(SAMLError): + authenticator.process_mfa_verification("endpoint", parsed) + + +def test_process_mfa_verification_2(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "token:software:totp", + "_links": { + "verify": { + "href": "href" + } + } + } + ] + }, + "stateToken": "statetoken" + } + authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.process_mfa_totp", + return_value="mock_call"): + result = authenticator.process_mfa_verification("endpoint", parsed) + assert result == "mock_call" + + +def test_process_mfa_verification_3(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "push", + "_links": { + "verify": { + "href": "href" + } + } + } + ] + }, + "stateToken": "statetoken" + } + authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.process_mfa_push", + return_value="mock_call"): + result = authenticator.process_mfa_verification("endpoint", parsed) + assert result == "mock_call" + + +def test_process_mfa_verification_4(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "question", + "_links": { + "verify": { + "href": "href" + } + } + } + ] + }, + "stateToken": "statetoken" + } + authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator." + "process_mfa_security_question", + return_value="mock_call"): + result = authenticator.process_mfa_verification("endpoint", parsed) + assert result == "mock_call" + + +def test_process_mfa_verification_5(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "sms", + "_links": { + "verify": { + "href": "href" + } + } + } + ] + }, + "stateToken": "statetoken" + } + authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.process_mfa_sms", + return_value="mock_call"): + result = authenticator.process_mfa_verification("endpoint", parsed) + assert result == "mock_call" + + +def test_retrieve_saml_assertion_1( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): + session_token = { + 'sessionToken': 'spam', + 'status': 'FAILED', + 'errorSummary': 'Testing failure' + } + token_response = mock.Mock( + spec=requests.Response, status_code=401, text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) + + +def test_retrieve_saml_assertion_2( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): + session_token = { + 'sessionToken': 'spam', + 'status': 'LOCKED_OUT', + '_links': { + 'href': 'href' + } + } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) + + +def test_retrieve_saml_assertion_3( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): + session_token = { + 'sessionToken': 'spam', + 'status': 'PASSWORD_EXPIRED', + '_links': { + 'href': 'href' + } + } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) + + +def test_retrieve_saml_assertion_4( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): + session_token = { + 'sessionToken': 'spam', + 'status': 'MFA_ENROLL' + } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) + + +def test_retrieve_saml_assertion_5( + mock_requests_session, argv, prompter, assertion, + client_creator, capsys, cache_dir): + session_token = { + 'sessionToken': 'spam', + 'status': 'MFA_REQUIRED' + } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.process_mfa_verification", + return_value=assertion): + saml(argv=argv, prompter=prompter, + client_creator=client_creator, + cache_dir=cache_dir) + + stdout, _ = capsys.readouterr() + assert stdout.endswith('\n') + + response = json.loads(stdout) + expected_response = { + "AccessKeyId": "foo", + "SecretAccessKey": "bar", + "SessionToken": "baz", + "Expiration": mock.ANY, + "Version": 1 + } + assert response == expected_response + + +def test_retrieve_saml_assertion_6( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): + session_token = { + 'sessionToken': 'spam' + } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) + assertion_form = '' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) + + def test_cli(mock_requests_session, argv, prompter, assertion, client_creator, capsys, cache_dir): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -55,7 +693,7 @@ def test_cli(mock_requests_session, argv, prompter, assertion, client_creator, def test_no_cache(mock_requests_session, argv, prompter, assertion, client_creator, capsys, cache_dir): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -92,7 +730,7 @@ def test_no_cache(mock_requests_session, argv, prompter, assertion, def test_verbose(mock_requests_session, argv, prompter, assertion, client_creator, cache_dir): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -123,7 +761,7 @@ def test_verbose(mock_requests_session, argv, prompter, assertion, def test_log_handler_parses_assertion(mock_requests_session, argv, prompter, client_creator, cache_dir, caplog): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -160,7 +798,7 @@ def test_log_handler_parses_assertion(mock_requests_session, argv, prompter, def test_log_handler_parses_dict(mock_requests_session, argv, prompter, client_creator, cache_dir, caplog): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -237,7 +875,7 @@ def test_unsupported_saml_provider(client_creator, prompter): def test_prompter_only_called_once(client_creator, prompter, assertion, mock_requests_session): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) diff --git a/tests/unit/test_saml.py b/tests/unit/test_saml.py index db2e218..a87d37f 100644 --- a/tests/unit/test_saml.py +++ b/tests/unit/test_saml.py @@ -373,7 +373,9 @@ def test_authn_requests_made(self, okta_auth, okta_config, session_token = 'mytoken' # 1st response is for authentication. mock_requests_session.post.return_value = mock.Mock( - text=json.dumps({"sessionToken": session_token}), + text=json.dumps( + {"sessionToken": session_token, "status": "SUCCESS"} + ), status_code=200 ) # 2nd response is to then retrieve the assertion.