-
Notifications
You must be signed in to change notification settings - Fork 409
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add basic tests for Discord provider
- Loading branch information
1 parent
883b163
commit bd1e394
Showing
1 changed file
with
274 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -242,14 +242,19 @@ def handle_request( | |
|
||
# Parse and save the request details | ||
parsed_path = urllib.parse.urlparse(path) | ||
request_details = { | ||
'headers': {k.lower(): v for k, v in dict(handler.headers).items()}, | ||
'query_params': urllib.parse.parse_qs(parsed_path.query), | ||
'body': handler.rfile.read( | ||
int(handler.headers['Content-Length']) | ||
headers = {k.lower(): v for k, v in dict(handler.headers).items()} | ||
query_params = urllib.parse.parse_qs(parsed_path.query) | ||
if 'content-length' in headers: | ||
body = handler.rfile.read( | ||
int(headers['content-length']) | ||
).decode() | ||
if 'Content-Length' in handler.headers | ||
else None, | ||
else: | ||
body = None | ||
|
||
request_details = { | ||
'headers': headers, | ||
'query_params': query_params, | ||
'body': body, | ||
} | ||
self.requests[key].append(request_details) | ||
|
||
|
@@ -344,6 +349,7 @@ def __exit__(self, *exc): | |
GOOGLE_SECRET = 'c' * 32 | ||
AZURE_SECRET = 'c' * 32 | ||
APPLE_SECRET = 'c' * 32 | ||
DISCORD_SECRET = 'd' * 32 | ||
|
||
|
||
class TestHttpExtAuth(tb.ExtAuthTestCase): | ||
|
@@ -391,6 +397,12 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): | |
client_id := '{uuid.uuid4()}', | ||
}}; | ||
CONFIGURE CURRENT DATABASE | ||
INSERT ext::auth::DiscordOAuthProvider {{ | ||
secret := '{DISCORD_SECRET}', | ||
client_id := '{uuid.uuid4()}', | ||
}}; | ||
CONFIGURE CURRENT DATABASE | ||
INSERT ext::auth::EmailPasswordProviderConfig {{ | ||
require_verification := false, | ||
|
@@ -983,6 +995,261 @@ async def test_http_auth_ext_github_callback_failure_02(self): | |
"error=access_denied", | ||
) | ||
|
||
async def test_http_auth_ext_discord_authorize_01(self): | ||
with MockAuthProvider(), self.http_con() as http_con: | ||
provider_config = await self.get_builtin_provider_config_by_name( | ||
"oauth_discord" | ||
) | ||
provider_name = provider_config.name | ||
client_id = provider_config.client_id | ||
redirect_to = f"{self.http_addr}/some/path" | ||
challenge = ( | ||
base64.urlsafe_b64encode( | ||
hashlib.sha256( | ||
base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=') | ||
).digest() | ||
) | ||
.rstrip(b'=') | ||
.decode() | ||
) | ||
query = { | ||
"provider": provider_name, | ||
"redirect_to": redirect_to, | ||
"challenge": challenge, | ||
} | ||
|
||
_, headers, status = self.http_con_request( | ||
http_con, | ||
query, | ||
path="authorize", | ||
) | ||
|
||
self.assertEqual(status, 302) | ||
|
||
location = headers.get("location") | ||
assert location is not None | ||
url = urllib.parse.urlparse(location) | ||
qs = urllib.parse.parse_qs(url.query, keep_blank_values=True) | ||
self.assertEqual(url.scheme, "https") | ||
self.assertEqual(url.hostname, "discord.com") | ||
self.assertEqual(url.path, "/oauth2/authorize") | ||
self.assertEqual(qs.get("scope"), ["email identify "]) | ||
|
||
state = qs.get("state") | ||
assert state is not None | ||
|
||
claims = await self.extract_jwt_claims(state[0]) | ||
self.assertEqual(claims.get("provider"), provider_name) | ||
self.assertEqual(claims.get("iss"), self.http_addr) | ||
self.assertEqual(claims.get("redirect_to"), redirect_to) | ||
|
||
self.assertEqual( | ||
qs.get("redirect_uri"), [f"{self.http_addr}/callback"] | ||
) | ||
self.assertEqual(qs.get("client_id"), [client_id]) | ||
|
||
pkce = await self.con.query( | ||
""" | ||
select ext::auth::PKCEChallenge | ||
filter .challenge = <str>$challenge | ||
""", | ||
challenge=challenge, | ||
) | ||
self.assertEqual(len(pkce), 1) | ||
|
||
_, _, repeat_status = self.http_con_request( | ||
http_con, | ||
query, | ||
path="authorize", | ||
) | ||
self.assertEqual(repeat_status, 302) | ||
|
||
repeat_pkce = await self.con.query_single( | ||
""" | ||
select ext::auth::PKCEChallenge | ||
filter .challenge = <str>$challenge | ||
""", | ||
challenge=challenge, | ||
) | ||
self.assertEqual(pkce[0].id, repeat_pkce.id) | ||
|
||
async def test_http_auth_ext_discord_callback_01(self): | ||
with MockAuthProvider() as mock_provider, self.http_con() as http_con: | ||
provider_config = await self.get_builtin_provider_config_by_name( | ||
"oauth_discord" | ||
) | ||
provider_name = provider_config.name | ||
client_id = provider_config.client_id | ||
client_secret = DISCORD_SECRET | ||
|
||
now = utcnow() | ||
token_request = ( | ||
"POST", | ||
"https://discord.com", | ||
"/api/oauth2/token", | ||
) | ||
mock_provider.register_route_handler(*token_request)( | ||
( | ||
json.dumps( | ||
{ | ||
"access_token": "discord_access_token", | ||
"scope": "read:user", | ||
"token_type": "bearer", | ||
} | ||
), | ||
200, | ||
) | ||
) | ||
|
||
user_request = ("GET", "https://discord.com/api/v10", "/users/@me") | ||
mock_provider.register_route_handler(*user_request)( | ||
( | ||
json.dumps( | ||
{ | ||
"id": 1, | ||
"username": "dischord", | ||
"global_name": "Ian MacKaye", | ||
"email": "[email protected]", | ||
"picture": "https://example.com/example.jpg", | ||
} | ||
), | ||
200, | ||
) | ||
) | ||
|
||
challenge = ( | ||
base64.urlsafe_b64encode( | ||
hashlib.sha256( | ||
base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=') | ||
).digest() | ||
) | ||
.rstrip(b'=') | ||
.decode() | ||
) | ||
await self.con.query( | ||
""" | ||
insert ext::auth::PKCEChallenge { | ||
challenge := <str>$challenge, | ||
} | ||
""", | ||
challenge=challenge, | ||
) | ||
|
||
signing_key = await self.get_signing_key() | ||
|
||
expires_at = now + datetime.timedelta(minutes=5) | ||
state_claims = { | ||
"iss": self.http_addr, | ||
"provider": str(provider_name), | ||
"exp": expires_at.timestamp(), | ||
"redirect_to": f"{self.http_addr}/some/path", | ||
"challenge": challenge, | ||
} | ||
state_token = self.generate_state_value(state_claims, signing_key) | ||
|
||
data, headers, status = self.http_con_request( | ||
http_con, | ||
{"state": state_token, "code": "abc123"}, | ||
path="callback", | ||
) | ||
|
||
self.assertEqual(data, b"") | ||
self.assertEqual(status, 302) | ||
|
||
location = headers.get("location") | ||
assert location is not None | ||
server_url = urllib.parse.urlparse(self.http_addr) | ||
url = urllib.parse.urlparse(location) | ||
self.assertEqual(url.scheme, server_url.scheme) | ||
self.assertEqual(url.hostname, server_url.hostname) | ||
self.assertEqual(url.path, f"{server_url.path}/some/path") | ||
|
||
requests_for_token = mock_provider.requests[token_request] | ||
self.assertEqual(len(requests_for_token), 1) | ||
|
||
self.assertEqual( | ||
urllib.parse.parse_qs(requests_for_token[0]["body"]), | ||
{ | ||
"grant_type": ["authorization_code"], | ||
"code": ["abc123"], | ||
"client_id": [client_id], | ||
"client_secret": [client_secret], | ||
"redirect_uri": [f"{self.http_addr}/callback"], | ||
}, | ||
) | ||
|
||
requests_for_user = mock_provider.requests[user_request] | ||
self.assertEqual(len(requests_for_user), 1) | ||
self.assertEqual( | ||
requests_for_user[0]["headers"]["authorization"], | ||
"Bearer discord_access_token", | ||
) | ||
|
||
identity = await self.con.query( | ||
""" | ||
SELECT ext::auth::Identity | ||
FILTER .subject = '1' | ||
AND .issuer = 'https://discord.com' | ||
""" | ||
) | ||
self.assertEqual(len(identity), 1) | ||
|
||
session_claims = await self.extract_session_claims(headers) | ||
self.assertEqual(session_claims.get("sub"), str(identity[0].id)) | ||
self.assertEqual(session_claims.get("iss"), str(self.http_addr)) | ||
tomorrow = now + datetime.timedelta(hours=25) | ||
self.assertTrue(session_claims.get("exp") > now.timestamp()) | ||
self.assertTrue(session_claims.get("exp") < tomorrow.timestamp()) | ||
|
||
pkce_object = await self.con.query( | ||
""" | ||
SELECT ext::auth::PKCEChallenge | ||
{ id, auth_token, refresh_token } | ||
filter .identity.id = <uuid>$identity_id | ||
""", | ||
identity_id=identity[0].id, | ||
) | ||
|
||
self.assertEqual(len(pkce_object), 1) | ||
self.assertEqual(pkce_object[0].auth_token, "discord_access_token") | ||
self.assertIsNone(pkce_object[0].refresh_token) | ||
|
||
mock_provider.register_route_handler(*user_request)( | ||
( | ||
json.dumps( | ||
{ | ||
"id": 1, | ||
"login": "octocat", | ||
"name": "monalisa octocat", | ||
"email": "[email protected]", | ||
"avatar_url": "https://example.com/example.jpg", | ||
"updated_at": now.isoformat(), | ||
} | ||
), | ||
200, | ||
) | ||
) | ||
(_, new_headers, _) = self.http_con_request( | ||
http_con, | ||
{"state": state_token, "code": "abc123"}, | ||
path="callback", | ||
) | ||
|
||
same_identity = await self.con.query( | ||
""" | ||
SELECT ext::auth::Identity | ||
FILTER .subject = '1' | ||
AND .issuer = 'https://discord.com' | ||
""" | ||
) | ||
self.assertEqual(len(same_identity), 1) | ||
self.assertEqual(identity[0].id, same_identity[0].id) | ||
|
||
new_session_claims = await self.extract_session_claims(new_headers) | ||
self.assertTrue( | ||
new_session_claims.get("exp") > session_claims.get("exp") | ||
) | ||
|
||
async def test_http_auth_ext_google_callback_01(self) -> None: | ||
with MockAuthProvider() as mock_provider, self.http_con() as http_con: | ||
provider_config = await self.get_builtin_provider_config_by_name( | ||
|