Skip to content

Commit

Permalink
Add basic tests for Discord provider
Browse files Browse the repository at this point in the history
  • Loading branch information
scotttrinh committed Jan 18, 2024
1 parent 883b163 commit bd1e394
Showing 1 changed file with 274 additions and 7 deletions.
281 changes: 274 additions & 7 deletions tests/test_http_ext_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,19 @@ def handle_request(

# Parse and save the request details
parsed_path = urllib.parse.urlparse(path)
request_details = {
'headers': {k.lower(): v for k, v in dict(handler.headers).items()},
'query_params': urllib.parse.parse_qs(parsed_path.query),
'body': handler.rfile.read(
int(handler.headers['Content-Length'])
headers = {k.lower(): v for k, v in dict(handler.headers).items()}
query_params = urllib.parse.parse_qs(parsed_path.query)
if 'content-length' in headers:
body = handler.rfile.read(
int(headers['content-length'])
).decode()
if 'Content-Length' in handler.headers
else None,
else:
body = None

request_details = {
'headers': headers,
'query_params': query_params,
'body': body,
}
self.requests[key].append(request_details)

Expand Down Expand Up @@ -344,6 +349,7 @@ def __exit__(self, *exc):
GOOGLE_SECRET = 'c' * 32
AZURE_SECRET = 'c' * 32
APPLE_SECRET = 'c' * 32
DISCORD_SECRET = 'd' * 32


class TestHttpExtAuth(tb.ExtAuthTestCase):
Expand Down Expand Up @@ -391,6 +397,12 @@ class TestHttpExtAuth(tb.ExtAuthTestCase):
client_id := '{uuid.uuid4()}',
}};
CONFIGURE CURRENT DATABASE
INSERT ext::auth::DiscordOAuthProvider {{
secret := '{DISCORD_SECRET}',
client_id := '{uuid.uuid4()}',
}};
CONFIGURE CURRENT DATABASE
INSERT ext::auth::EmailPasswordProviderConfig {{
require_verification := false,
Expand Down Expand Up @@ -983,6 +995,261 @@ async def test_http_auth_ext_github_callback_failure_02(self):
"error=access_denied",
)

async def test_http_auth_ext_discord_authorize_01(self):
with MockAuthProvider(), self.http_con() as http_con:
provider_config = await self.get_builtin_provider_config_by_name(
"oauth_discord"
)
provider_name = provider_config.name
client_id = provider_config.client_id
redirect_to = f"{self.http_addr}/some/path"
challenge = (
base64.urlsafe_b64encode(
hashlib.sha256(
base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=')
).digest()
)
.rstrip(b'=')
.decode()
)
query = {
"provider": provider_name,
"redirect_to": redirect_to,
"challenge": challenge,
}

_, headers, status = self.http_con_request(
http_con,
query,
path="authorize",
)

self.assertEqual(status, 302)

location = headers.get("location")
assert location is not None
url = urllib.parse.urlparse(location)
qs = urllib.parse.parse_qs(url.query, keep_blank_values=True)
self.assertEqual(url.scheme, "https")
self.assertEqual(url.hostname, "discord.com")
self.assertEqual(url.path, "/oauth2/authorize")
self.assertEqual(qs.get("scope"), ["email identify "])

state = qs.get("state")
assert state is not None

claims = await self.extract_jwt_claims(state[0])
self.assertEqual(claims.get("provider"), provider_name)
self.assertEqual(claims.get("iss"), self.http_addr)
self.assertEqual(claims.get("redirect_to"), redirect_to)

self.assertEqual(
qs.get("redirect_uri"), [f"{self.http_addr}/callback"]
)
self.assertEqual(qs.get("client_id"), [client_id])

pkce = await self.con.query(
"""
select ext::auth::PKCEChallenge
filter .challenge = <str>$challenge
""",
challenge=challenge,
)
self.assertEqual(len(pkce), 1)

_, _, repeat_status = self.http_con_request(
http_con,
query,
path="authorize",
)
self.assertEqual(repeat_status, 302)

repeat_pkce = await self.con.query_single(
"""
select ext::auth::PKCEChallenge
filter .challenge = <str>$challenge
""",
challenge=challenge,
)
self.assertEqual(pkce[0].id, repeat_pkce.id)

async def test_http_auth_ext_discord_callback_01(self):
with MockAuthProvider() as mock_provider, self.http_con() as http_con:
provider_config = await self.get_builtin_provider_config_by_name(
"oauth_discord"
)
provider_name = provider_config.name
client_id = provider_config.client_id
client_secret = DISCORD_SECRET

now = utcnow()
token_request = (
"POST",
"https://discord.com",
"/api/oauth2/token",
)
mock_provider.register_route_handler(*token_request)(
(
json.dumps(
{
"access_token": "discord_access_token",
"scope": "read:user",
"token_type": "bearer",
}
),
200,
)
)

user_request = ("GET", "https://discord.com/api/v10", "/users/@me")
mock_provider.register_route_handler(*user_request)(
(
json.dumps(
{
"id": 1,
"username": "dischord",
"global_name": "Ian MacKaye",
"email": "[email protected]",
"picture": "https://example.com/example.jpg",
}
),
200,
)
)

challenge = (
base64.urlsafe_b64encode(
hashlib.sha256(
base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=')
).digest()
)
.rstrip(b'=')
.decode()
)
await self.con.query(
"""
insert ext::auth::PKCEChallenge {
challenge := <str>$challenge,
}
""",
challenge=challenge,
)

signing_key = await self.get_signing_key()

expires_at = now + datetime.timedelta(minutes=5)
state_claims = {
"iss": self.http_addr,
"provider": str(provider_name),
"exp": expires_at.timestamp(),
"redirect_to": f"{self.http_addr}/some/path",
"challenge": challenge,
}
state_token = self.generate_state_value(state_claims, signing_key)

data, headers, status = self.http_con_request(
http_con,
{"state": state_token, "code": "abc123"},
path="callback",
)

self.assertEqual(data, b"")
self.assertEqual(status, 302)

location = headers.get("location")
assert location is not None
server_url = urllib.parse.urlparse(self.http_addr)
url = urllib.parse.urlparse(location)
self.assertEqual(url.scheme, server_url.scheme)
self.assertEqual(url.hostname, server_url.hostname)
self.assertEqual(url.path, f"{server_url.path}/some/path")

requests_for_token = mock_provider.requests[token_request]
self.assertEqual(len(requests_for_token), 1)

self.assertEqual(
urllib.parse.parse_qs(requests_for_token[0]["body"]),
{
"grant_type": ["authorization_code"],
"code": ["abc123"],
"client_id": [client_id],
"client_secret": [client_secret],
"redirect_uri": [f"{self.http_addr}/callback"],
},
)

requests_for_user = mock_provider.requests[user_request]
self.assertEqual(len(requests_for_user), 1)
self.assertEqual(
requests_for_user[0]["headers"]["authorization"],
"Bearer discord_access_token",
)

identity = await self.con.query(
"""
SELECT ext::auth::Identity
FILTER .subject = '1'
AND .issuer = 'https://discord.com'
"""
)
self.assertEqual(len(identity), 1)

session_claims = await self.extract_session_claims(headers)
self.assertEqual(session_claims.get("sub"), str(identity[0].id))
self.assertEqual(session_claims.get("iss"), str(self.http_addr))
tomorrow = now + datetime.timedelta(hours=25)
self.assertTrue(session_claims.get("exp") > now.timestamp())
self.assertTrue(session_claims.get("exp") < tomorrow.timestamp())

pkce_object = await self.con.query(
"""
SELECT ext::auth::PKCEChallenge
{ id, auth_token, refresh_token }
filter .identity.id = <uuid>$identity_id
""",
identity_id=identity[0].id,
)

self.assertEqual(len(pkce_object), 1)
self.assertEqual(pkce_object[0].auth_token, "discord_access_token")
self.assertIsNone(pkce_object[0].refresh_token)

mock_provider.register_route_handler(*user_request)(
(
json.dumps(
{
"id": 1,
"login": "octocat",
"name": "monalisa octocat",
"email": "[email protected]",
"avatar_url": "https://example.com/example.jpg",
"updated_at": now.isoformat(),
}
),
200,
)
)
(_, new_headers, _) = self.http_con_request(
http_con,
{"state": state_token, "code": "abc123"},
path="callback",
)

same_identity = await self.con.query(
"""
SELECT ext::auth::Identity
FILTER .subject = '1'
AND .issuer = 'https://discord.com'
"""
)
self.assertEqual(len(same_identity), 1)
self.assertEqual(identity[0].id, same_identity[0].id)

new_session_claims = await self.extract_session_claims(new_headers)
self.assertTrue(
new_session_claims.get("exp") > session_claims.get("exp")
)

async def test_http_auth_ext_google_callback_01(self) -> None:
with MockAuthProvider() as mock_provider, self.http_con() as http_con:
provider_config = await self.get_builtin_provider_config_by_name(
Expand Down

0 comments on commit bd1e394

Please sign in to comment.