From d7b66b426bb4895284c9dc10557c8ed6f6d47bbb Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Fri, 12 Apr 2024 13:50:20 -0700 Subject: [PATCH] tests/ext_auth: Move mock server startup to `setUp()` All tests in the test case use it, so it's an appropriate thing to do, and allows us to avoid setting the ContextVar in the mock server guts. --- tests/test_http_ext_auth.py | 155 +++++++++++++++++++++--------------- 1 file changed, 90 insertions(+), 65 deletions(-) diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 9966d89c6b73..a2015484b745 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -234,6 +234,12 @@ def __init__(self) -> None: ResponseType | Callable[[MockHttpServerHandler], ResponseType], ] = {} self.requests: dict[tuple[str, str, str], list[dict[str, Any]]] = {} + self.url: Optional[str] = None + + def get_base_url(self) -> str: + if self.url is None: + raise RuntimeError("mock server is not running") + return self.url def register_route_handler( self, @@ -342,12 +348,15 @@ def handle_request( handler.end_headers() handler.wfile.write(data) - def __enter__(self): + def start(self): assert not hasattr(self, '_http_runner') self._http_runner = threading.Thread(target=self._http_worker) self._http_runner.start() self.has_started.wait() - HTTP_TEST_PORT.set(f'http://{self._address[0]}:{self._address[1]}/') + self.url = f'http://{self._address[0]}:{self._address[1]}/' + + def __enter__(self): + self.start() return self def _http_worker(self): @@ -360,11 +369,15 @@ def _http_worker(self): self._http_server.serve_forever(poll_interval=0.01) self._http_server.server_close() - def __exit__(self, *exc): + def stop(self): self._http_server.shutdown() self._http_runner.join() self._http_runner = None + def __exit__(self, *exc): + self.stop() + self.url = None + SIGNING_KEY = 'a' * 32 GITHUB_SECRET = 'b' * 32 @@ -478,6 +491,18 @@ def setUpClass(cls): cls._wait_for_db_config('ext::auth::AuthConfig::providers') ) + mock_provider: MockAuthProvider + + def setUp(self): + self.mock_provider = MockAuthProvider() + self.mock_provider.start() + HTTP_TEST_PORT.set(self.mock_provider.get_base_url()) + + def tearDown(self): + if self.mock_provider is not None: + self.mock_provider.stop() + self.mock_provider = None + @classmethod def get_setup_script(cls): res = super().get_setup_script() @@ -588,7 +613,7 @@ async def extract_session_claims(self, headers: dict[str, str]): return claims async def test_http_auth_ext_github_authorize_01(self): - with MockAuthProvider(), self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_github" ) @@ -666,7 +691,7 @@ async def test_http_auth_ext_github_authorize_01(self): self.assertEqual(pkce[0].id, repeat_pkce.id) async def test_http_auth_ext_github_callback_missing_provider_01(self): - with MockAuthProvider(), self.http_con() as http_con: + with self.http_con() as http_con: signing_key = await self.get_signing_key() expires_at = utcnow() + datetime.timedelta(minutes=5) @@ -687,7 +712,7 @@ async def test_http_auth_ext_github_callback_missing_provider_01(self): self.assertEqual(status, 400) async def test_http_auth_ext_github_callback_wrong_key_01(self): - with MockAuthProvider(), self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_github" ) @@ -715,7 +740,7 @@ async def test_http_auth_ext_github_callback_wrong_key_01(self): self.assertEqual(status, 400) async def test_http_auth_ext_github_unknown_provider_01(self): - with MockAuthProvider(), self.http_con() as http_con: + with self.http_con() as http_con: signing_key = await self.get_signing_key() expires_at = utcnow() + datetime.timedelta(minutes=5) @@ -748,7 +773,7 @@ async def test_http_auth_ext_github_unknown_provider_01(self): ) async def test_http_auth_ext_github_callback_01(self): - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_github" ) @@ -762,7 +787,7 @@ async def test_http_auth_ext_github_callback_01(self): "https://github.com", "/login/oauth/access_token", ) - mock_provider.register_route_handler(*token_request)( + self.mock_provider.register_route_handler(*token_request)( ( json.dumps( { @@ -776,7 +801,7 @@ async def test_http_auth_ext_github_callback_01(self): ) user_request = ("GET", "https://api.github.com", "/user") - mock_provider.register_route_handler(*user_request)( + self.mock_provider.register_route_handler(*user_request)( ( json.dumps( { @@ -839,7 +864,7 @@ async def test_http_auth_ext_github_callback_01(self): self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_token = mock_provider.requests[token_request] + requests_for_token = self.mock_provider.requests[token_request] self.assertEqual(len(requests_for_token), 1) self.assertEqual( json.loads(requests_for_token[0]["body"]), @@ -852,7 +877,7 @@ async def test_http_auth_ext_github_callback_01(self): }, ) - requests_for_user = mock_provider.requests[user_request] + requests_for_user = self.mock_provider.requests[user_request] self.assertEqual(len(requests_for_user), 1) self.assertEqual( requests_for_user[0]["headers"]["authorization"], @@ -888,7 +913,7 @@ async def test_http_auth_ext_github_callback_01(self): self.assertEqual(pkce_object[0].auth_token, "github_access_token") self.assertIsNone(pkce_object[0].refresh_token) - mock_provider.register_route_handler(*user_request)( + self.mock_provider.register_route_handler(*user_request)( ( json.dumps( { @@ -925,7 +950,7 @@ async def test_http_auth_ext_github_callback_01(self): ) async def test_http_auth_ext_github_callback_failure_01(self): - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_github" ) @@ -937,7 +962,7 @@ async def test_http_auth_ext_github_callback_failure_01(self): "https://github.com", "/login/oauth/access_token", ) - mock_provider.register_route_handler(*token_request)( + self.mock_provider.register_route_handler(*token_request)( ( json.dumps( { @@ -991,7 +1016,7 @@ async def test_http_auth_ext_github_callback_failure_01(self): ) async def test_http_auth_ext_github_callback_failure_02(self): - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_github" ) @@ -1003,7 +1028,7 @@ async def test_http_auth_ext_github_callback_failure_02(self): "https://github.com", "/login/oauth/access_token", ) - mock_provider.register_route_handler(*token_request)( + self.mock_provider.register_route_handler(*token_request)( ( json.dumps( { @@ -1052,7 +1077,7 @@ async def test_http_auth_ext_github_callback_failure_02(self): ) async def test_http_auth_ext_discord_authorize_01(self): - with MockAuthProvider(), self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_discord" ) @@ -1130,7 +1155,7 @@ async def test_http_auth_ext_discord_authorize_01(self): self.assertEqual(pkce[0].id, repeat_pkce.id) async def test_http_auth_ext_discord_callback_01(self): - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_discord" ) @@ -1144,7 +1169,7 @@ async def test_http_auth_ext_discord_callback_01(self): "https://discord.com", "/api/oauth2/token", ) - mock_provider.register_route_handler(*token_request)( + self.mock_provider.register_route_handler(*token_request)( ( json.dumps( { @@ -1158,7 +1183,7 @@ async def test_http_auth_ext_discord_callback_01(self): ) user_request = ("GET", "https://discord.com/api/v10", "/users/@me") - mock_provider.register_route_handler(*user_request)( + self.mock_provider.register_route_handler(*user_request)( ( json.dumps( { @@ -1220,7 +1245,7 @@ async def test_http_auth_ext_discord_callback_01(self): self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_token = mock_provider.requests[token_request] + requests_for_token = self.mock_provider.requests[token_request] self.assertEqual(len(requests_for_token), 1) self.assertEqual( @@ -1234,7 +1259,7 @@ async def test_http_auth_ext_discord_callback_01(self): }, ) - requests_for_user = mock_provider.requests[user_request] + requests_for_user = self.mock_provider.requests[user_request] self.assertEqual(len(requests_for_user), 1) self.assertEqual( requests_for_user[0]["headers"]["authorization"], @@ -1270,7 +1295,7 @@ async def test_http_auth_ext_discord_callback_01(self): self.assertEqual(pkce_object[0].auth_token, "discord_access_token") self.assertIsNone(pkce_object[0].refresh_token) - mock_provider.register_route_handler(*user_request)( + self.mock_provider.register_route_handler(*user_request)( ( json.dumps( { @@ -1307,7 +1332,7 @@ async def test_http_auth_ext_discord_callback_01(self): ) async def test_http_auth_ext_google_callback_01(self) -> None: - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_google" ) @@ -1322,7 +1347,7 @@ async def test_http_auth_ext_google_callback_01(self) -> None: "https://accounts.google.com", "/.well-known/openid-configuration", ) - mock_provider.register_route_handler(*discovery_request)( + self.mock_provider.register_route_handler(*discovery_request)( ( json.dumps(GOOGLE_DISCOVERY_DOCUMENT), 200, @@ -1343,7 +1368,7 @@ async def test_http_auth_ext_google_callback_01(self) -> None: private_keys=False, as_dict=True ) - mock_provider.register_route_handler(*jwks_request)( + self.mock_provider.register_route_handler(*jwks_request)( ( json.dumps(jwk_set), 200, @@ -1366,7 +1391,7 @@ async def test_http_auth_ext_google_callback_01(self) -> None: id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) id_token.make_signed_token(k) - mock_provider.register_route_handler(*token_request)( + self.mock_provider.register_route_handler(*token_request)( ( json.dumps( { @@ -1427,10 +1452,10 @@ async def test_http_auth_ext_google_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = mock_provider.requests[discovery_request] + requests_for_discovery = self.mock_provider.requests[discovery_request] self.assertEqual(len(requests_for_discovery), 2) - requests_for_token = mock_provider.requests[token_request] + requests_for_token = self.mock_provider.requests[token_request] self.assertEqual(len(requests_for_token), 1) self.assertEqual( json.loads(requests_for_token[0]["body"]), @@ -1460,7 +1485,7 @@ async def test_http_auth_ext_google_callback_01(self) -> None: self.assertTrue(session_claims.get("exp") < tomorrow.timestamp()) async def test_http_auth_ext_google_authorize_01(self): - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_google" ) @@ -1481,7 +1506,7 @@ async def test_http_auth_ext_google_authorize_01(self): "https://accounts.google.com", "/.well-known/openid-configuration", ) - mock_provider.register_route_handler(*discovery_request)( + self.mock_provider.register_route_handler(*discovery_request)( ( json.dumps(GOOGLE_DISCOVERY_DOCUMENT), 200, @@ -1523,7 +1548,7 @@ async def test_http_auth_ext_google_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = mock_provider.requests[discovery_request] + requests_for_discovery = self.mock_provider.requests[discovery_request] self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -1536,7 +1561,7 @@ async def test_http_auth_ext_google_authorize_01(self): self.assertEqual(len(pkce), 1) async def test_http_auth_ext_azure_authorize_01(self): - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_azure" ) @@ -1549,7 +1574,7 @@ async def test_http_auth_ext_azure_authorize_01(self): "https://login.microsoftonline.com/common/v2.0", "/.well-known/openid-configuration", ) - mock_provider.register_route_handler(*discovery_request)( + self.mock_provider.register_route_handler(*discovery_request)( ( json.dumps(AZURE_DISCOVERY_DOCUMENT), 200, @@ -1593,7 +1618,7 @@ async def test_http_auth_ext_azure_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = mock_provider.requests[discovery_request] + requests_for_discovery = self.mock_provider.requests[discovery_request] self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -1606,7 +1631,7 @@ async def test_http_auth_ext_azure_authorize_01(self): self.assertEqual(len(pkce), 1) async def test_http_auth_ext_azure_callback_01(self) -> None: - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_azure" ) @@ -1621,7 +1646,7 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: "https://login.microsoftonline.com/common/v2.0", "/.well-known/openid-configuration", ) - mock_provider.register_route_handler(*discovery_request)( + self.mock_provider.register_route_handler(*discovery_request)( ( json.dumps(AZURE_DISCOVERY_DOCUMENT), 200, @@ -1641,7 +1666,7 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: private_keys=False, as_dict=True ) - mock_provider.register_route_handler(*jwks_request)( + self.mock_provider.register_route_handler(*jwks_request)( ( json.dumps(jwk_set), 200, @@ -1664,7 +1689,7 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) id_token.make_signed_token(k) - mock_provider.register_route_handler(*token_request)( + self.mock_provider.register_route_handler(*token_request)( ( json.dumps( { @@ -1725,10 +1750,10 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = mock_provider.requests[discovery_request] + requests_for_discovery = self.mock_provider.requests[discovery_request] self.assertEqual(len(requests_for_discovery), 2) - requests_for_token = mock_provider.requests[token_request] + requests_for_token = self.mock_provider.requests[token_request] self.assertEqual(len(requests_for_token), 1) self.assertEqual( urllib.parse.parse_qs(requests_for_token[0]["body"]), @@ -1742,7 +1767,7 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: ) async def test_http_auth_ext_apple_authorize_01(self): - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_apple" ) @@ -1763,7 +1788,7 @@ async def test_http_auth_ext_apple_authorize_01(self): "https://appleid.apple.com", "/.well-known/openid-configuration", ) - mock_provider.register_route_handler(*discovery_request)( + self.mock_provider.register_route_handler(*discovery_request)( ( json.dumps(APPLE_DISCOVERY_DOCUMENT), 200, @@ -1805,7 +1830,7 @@ async def test_http_auth_ext_apple_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = mock_provider.requests[discovery_request] + requests_for_discovery = self.mock_provider.requests[discovery_request] self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -1818,7 +1843,7 @@ async def test_http_auth_ext_apple_authorize_01(self): self.assertEqual(len(pkce), 1) async def test_http_auth_ext_apple_callback_01(self) -> None: - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_apple" ) @@ -1833,7 +1858,7 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: "https://appleid.apple.com", "/.well-known/openid-configuration", ) - mock_provider.register_route_handler(*discovery_request)( + self.mock_provider.register_route_handler(*discovery_request)( ( json.dumps(APPLE_DISCOVERY_DOCUMENT), 200, @@ -1853,7 +1878,7 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: private_keys=False, as_dict=True ) - mock_provider.register_route_handler(*jwks_request)( + self.mock_provider.register_route_handler(*jwks_request)( ( json.dumps(jwk_set), 200, @@ -1876,7 +1901,7 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) id_token.make_signed_token(k) - mock_provider.register_route_handler(*token_request)( + self.mock_provider.register_route_handler(*token_request)( ( json.dumps( { @@ -1942,10 +1967,10 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = mock_provider.requests[discovery_request] + requests_for_discovery = self.mock_provider.requests[discovery_request] self.assertEqual(len(requests_for_discovery), 2) - requests_for_token = mock_provider.requests[token_request] + requests_for_token = self.mock_provider.requests[token_request] self.assertEqual(len(requests_for_token), 1) self.assertEqual( urllib.parse.parse_qs(requests_for_token[0]["body"]), @@ -1961,7 +1986,7 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: async def test_http_auth_ext_apple_callback_redirect_on_signup_02( self, ) -> None: - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_apple" ) @@ -1975,7 +2000,7 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( "https://appleid.apple.com", "/.well-known/openid-configuration", ) - mock_provider.register_route_handler(*discovery_request)( + self.mock_provider.register_route_handler(*discovery_request)( ( json.dumps(APPLE_DISCOVERY_DOCUMENT), 200, @@ -1995,7 +2020,7 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( private_keys=False, as_dict=True ) - mock_provider.register_route_handler(*jwks_request)( + self.mock_provider.register_route_handler(*jwks_request)( ( json.dumps(jwk_set), 200, @@ -2018,7 +2043,7 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) id_token.make_signed_token(k) - mock_provider.register_route_handler(*token_request)( + self.mock_provider.register_route_handler(*token_request)( ( json.dumps( { @@ -2108,7 +2133,7 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( self.assertEqual(url.path, f"{server_url.path}/some/path") async def test_http_auth_ext_slack_callback_01(self) -> None: - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_slack" ) @@ -2123,7 +2148,7 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: "https://slack.com", "/.well-known/openid-configuration", ) - mock_provider.register_route_handler(*discovery_request)( + self.mock_provider.register_route_handler(*discovery_request)( ( json.dumps(SLACK_DISCOVERY_DOCUMENT), 200, @@ -2144,7 +2169,7 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: private_keys=False, as_dict=True ) - mock_provider.register_route_handler(*jwks_request)( + self.mock_provider.register_route_handler(*jwks_request)( ( json.dumps(jwk_set), 200, @@ -2167,7 +2192,7 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) id_token.make_signed_token(k) - mock_provider.register_route_handler(*token_request)( + self.mock_provider.register_route_handler(*token_request)( ( json.dumps( { @@ -2228,10 +2253,10 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = mock_provider.requests[discovery_request] + requests_for_discovery = self.mock_provider.requests[discovery_request] self.assertEqual(len(requests_for_discovery), 2) - requests_for_token = mock_provider.requests[token_request] + requests_for_token = self.mock_provider.requests[token_request] self.assertEqual(len(requests_for_token), 1) self.assertEqual( urllib.parse.parse_qs(requests_for_token[0]["body"]), @@ -2261,7 +2286,7 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: self.assertTrue(session_claims.get("exp") < tomorrow.timestamp()) async def test_http_auth_ext_slack_authorize_01(self): - with MockAuthProvider() as mock_provider, self.http_con() as http_con: + with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( "oauth_slack" ) @@ -2282,7 +2307,7 @@ async def test_http_auth_ext_slack_authorize_01(self): "https://slack.com", "/.well-known/openid-configuration", ) - mock_provider.register_route_handler(*discovery_request)( + self.mock_provider.register_route_handler(*discovery_request)( ( json.dumps(SLACK_DISCOVERY_DOCUMENT), 200, @@ -2324,7 +2349,7 @@ async def test_http_auth_ext_slack_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = mock_provider.requests[discovery_request] + requests_for_discovery = self.mock_provider.requests[discovery_request] self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -3933,7 +3958,7 @@ async def test_client_token_identity_card(self): ) async def test_http_auth_ext_static_files(self): - with MockAuthProvider(), self.http_con() as http_con: + with self.http_con() as http_con: _, _, status = self.http_con_request( http_con, path="ui/_static/icon_github.svg",