diff --git a/api/src/pcapi/connectors/discord.py b/api/src/pcapi/connectors/discord.py
index 31af2664296..0cf11856f32 100644
--- a/api/src/pcapi/connectors/discord.py
+++ b/api/src/pcapi/connectors/discord.py
@@ -7,16 +7,30 @@
DISCORD_CLIENT_ID = "1261948740915433574"
DISCORD_CLIENT_SECRET = settings.DISCORD_CLIENT_SECRET
DISCORD_CALLBACK_URI = f"{settings.API_URL}/auth/discord/callback"
-DISCORD_REDIRECT_SUCCESS = f"{settings.API_URL}/auth/discord/success"
-DISCORD_FULL_REDIRECT_URI = (
- f"https://discord.com/api/oauth2/authorize"
- f"?client_id={DISCORD_CLIENT_ID}"
- f"&redirect_uri={DISCORD_CALLBACK_URI}"
- f"&response_type=code"
- f"&scope=identify%20guilds.join"
-)
DISCORD_HOME_URI = f"https://discord.com/channels/{DISCORD_GUILD_ID}/@home"
+DISCORD_API_URI = "https://discord.com/api"
+
+
+def build_discord_redirection_uri(user_id: int) -> str:
+ base_uri = f"{DISCORD_API_URI}/oauth2/authorize"
+ client_id = DISCORD_CLIENT_ID
+ redirect_uri = DISCORD_CALLBACK_URI
+ response_type = "code"
+ scope = "identify%20guilds.join"
+
+ return f"{base_uri}?client_id={client_id}&redirect_uri={redirect_uri}&response_type={response_type}&scope={scope}&state={user_id}"
+
+
+def get_user_id(access_token: str) -> str | None:
+ url = f"{DISCORD_API_URI}/oauth2/@me"
+ user_response = requests.get(url, headers={"Authorization": f"Bearer {access_token}"})
+ user_response.raise_for_status()
+ try:
+ return user_response.json()["user"]["id"]
+ except KeyError:
+ return None
+
def retrieve_access_token(code: str) -> str | None:
data = {
@@ -34,19 +48,14 @@ def retrieve_access_token(code: str) -> str | None:
return access_token
-def add_to_server(access_token: str) -> None:
+def add_to_server(access_token: str, user_discord_id: str) -> None:
"""
Adds the user to the pass culture discord server
Our server is identified by the DISCORD_GUILD_ID
"""
- user_response = requests.get(
- "https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {access_token}"}
- )
- user_response.raise_for_status()
-
- user_id = user_response.json()["id"]
data = {"access_token": access_token}
- url = f"https://discord.com/api/guilds/{DISCORD_GUILD_ID}/members/{user_id}"
+ url = f"https://discord.com/api/guilds/{DISCORD_GUILD_ID}/members/{user_discord_id}"
headers = {"Authorization": f"Bot {DISCORD_BOT_TOKEN}", "Content-Type": "application/json"}
- requests.put(url, json=data, headers=headers)
+ response = requests.put(url, json=data, headers=headers)
+ response.raise_for_status()
diff --git a/api/src/pcapi/routes/auth/discord.py b/api/src/pcapi/routes/auth/discord.py
index bf10d1849b1..ed496cd5568 100644
--- a/api/src/pcapi/routes/auth/discord.py
+++ b/api/src/pcapi/routes/auth/discord.py
@@ -4,11 +4,13 @@
from flask_wtf.csrf import CSRFProtect
from werkzeug.wrappers.response import Response
+from pcapi import repository
from pcapi.connectors import discord as discord_connector
from pcapi.core.users import exceptions as users_exceptions
from pcapi.core.users import models as user_models
from pcapi.core.users import repository as users_repo
from pcapi.models import db
+import pcapi.routes.auth.exceptions as auth_exceptions
from pcapi.routes.auth.forms.forms import SigninForm
from pcapi.utils import requests
@@ -16,44 +18,95 @@
from . import utils
+ERROR_STRING_PREFIX = "Erreur d'authentification Discord: "
+
+
@blueprint.auth_blueprint.route("/discord/signin", methods=["GET"])
def discord_signin() -> str:
form = SigninForm()
- form.discord_id.data = request.args.get("discord_id")
- form.redirect_url.data = discord_connector.DISCORD_FULL_REDIRECT_URI
if error_message := request.args.get("error"):
form.error_message = error_message
return render_template("discord_signin.html", form=form)
+def redirect_with_error(error_message: str) -> Response:
+ repository.mark_transaction_as_invalid()
+ return redirect(f"/auth/discord/signin?error={error_message}", code=303)
+
+
+def handle_http_error(error: requests.exceptions.HTTPError) -> Response:
+ error_message = ""
+ if error.response:
+ error_message = error.response.json().get("error_description")
+ if not error_message:
+ error_message = error.response.text
+ error_message += "Tu peux réessayer ou contacter le support."
+ return redirect_with_error(ERROR_STRING_PREFIX + error_message)
+
+
@blueprint.auth_blueprint.route("/discord/callback", methods=["GET"])
+@repository.atomic()
def discord_call_back() -> str | Response | None:
- # Webhook called by the discord server once the discord authentication is successful
code = request.args.get("code")
- ERROR_STRING_PREFIX = "Erreur d'authentification Discord: "
+ user_id = request.args.get("state")
+
if not code:
- return redirect(f"/auth/discord/signin?error={ERROR_STRING_PREFIX}code non récupéré", code=303)
+ return redirect_with_error(f"{ERROR_STRING_PREFIX}code non récupéré")
+ if not user_id:
+ return redirect_with_error(f"{ERROR_STRING_PREFIX}user_id pass Culture non récupéré")
try:
access_token = discord_connector.retrieve_access_token(code)
except requests.exceptions.HTTPError as e:
- return redirect(f"/auth/discord/signin?error={ERROR_STRING_PREFIX}{e.response.json().get('error')}", code=303)
+ return handle_http_error(e)
if not access_token:
- return redirect(f"/auth/discord/signin?error={ERROR_STRING_PREFIX}access token non récupéré", code=303)
+ return redirect_with_error(f"{ERROR_STRING_PREFIX}access token non récupéré")
+
+ try:
+ user_discord_id = discord_connector.get_user_id(access_token)
+ except requests.exceptions.HTTPError as e:
+ return handle_http_error(e)
+
+ if not user_discord_id:
+ return redirect_with_error(f"{ERROR_STRING_PREFIX}discord id non récupéré")
+
+ try:
+ update_discord_user(user_id, user_discord_id)
+ except auth_exceptions.DiscordUserAlreadyLinked:
+ return redirect_with_error("Ce compte Discord est déjà lié à un autre compte pass Culture.")
+ except auth_exceptions.UserNotAllowed:
+ return redirect_with_error("Accès refusé au serveur Discord. Contacte le support pour plus d'informations")
try:
- discord_connector.add_to_server(access_token)
+ discord_connector.add_to_server(access_token, user_discord_id)
except requests.exceptions.HTTPError as e:
- return redirect(
- f"/auth/discord/signin?error={ERROR_STRING_PREFIX}{e.response.json().get('message')}",
- code=303,
- )
+ return handle_http_error(e)
return redirect(discord_connector.DISCORD_HOME_URI, code=303)
+def update_discord_user(user_id: str, discord_id: str) -> None:
+ already_linked_user = user_models.DiscordUser.query.filter_by(discordId=discord_id).first()
+ if already_linked_user:
+ raise auth_exceptions.DiscordUserAlreadyLinked()
+
+ user = user_models.User.query.get(user_id)
+ discord_user = user.discordUser
+
+ if discord_user is None:
+ # We still add the user to the database even if he doesn't have access to the discord server
+ discord_user = user_models.DiscordUser(userId=user.id, discordId=discord_id, hasAccess=False)
+ db.session.add(discord_user)
+ raise auth_exceptions.UserNotAllowed()
+
+ if not discord_user.hasAccess:
+ raise auth_exceptions.UserNotAllowed()
+
+ discord_user.discordId = discord_id
+
+
@blueprint.auth_blueprint.route("/discord/signin", methods=["POST"])
def discord_signin_post() -> str | Response | None:
csrf = CSRFProtect()
@@ -65,8 +118,6 @@ def discord_signin_post() -> str | Response | None:
email = form.email.data
password = form.password.data
- discord_id = form.discord_id.data
- url_redirection = form.redirect_url.data
try:
user = users_repo.get_user_with_credentials(email, password, allow_inactive=True)
@@ -86,17 +137,5 @@ def discord_signin_post() -> str | Response | None:
form.error_message = "Le compte a été anonymisé"
return render_template("discord_signin.html", form=form)
- discord_user = user.discordUser
- if discord_user is None or not discord_user.hasAccess:
- if discord_user is None:
- discord_user = user_models.DiscordUser(userId=user.id, discordId=discord_id, hasAccess=False)
- db.session.add(discord_user)
- db.session.commit()
- form.error_message = "Accès refusé au serveur Discord. Contacte le support pour plus d'informations"
- return render_template("discord_signin.html", form=form)
- if discord_user.is_active:
- return redirect(url_redirection)
-
- discord_user.discordId = discord_id
- db.session.commit()
+ url_redirection = discord_connector.build_discord_redirection_uri(user.id)
return redirect(url_redirection)
diff --git a/api/src/pcapi/routes/auth/exceptions.py b/api/src/pcapi/routes/auth/exceptions.py
new file mode 100644
index 00000000000..44a4555c05d
--- /dev/null
+++ b/api/src/pcapi/routes/auth/exceptions.py
@@ -0,0 +1,10 @@
+class DiscordUserAlreadyLinked(Exception):
+ pass
+
+
+class UserNotAllowed(Exception):
+ pass
+
+
+class DiscordException(Exception):
+ pass
diff --git a/api/src/pcapi/routes/auth/forms/forms.py b/api/src/pcapi/routes/auth/forms/forms.py
index 5b6f1b7839a..f3f124dc860 100644
--- a/api/src/pcapi/routes/auth/forms/forms.py
+++ b/api/src/pcapi/routes/auth/forms/forms.py
@@ -5,6 +5,4 @@
class SigninForm(PCForm):
email = fields.PCEmailField("Adresse email")
password = fields.PCPasswordField("Mot de passe")
- discord_id = fields.PCHiddenField("discord_id")
- redirect_url = fields.PCLongHiddenField("redirect_url")
error_message: str = ""
diff --git a/api/src/pcapi/routes/auth/templates/discord_signin.html b/api/src/pcapi/routes/auth/templates/discord_signin.html
index 6f36723904c..e52f21c4fa2 100644
--- a/api/src/pcapi/routes/auth/templates/discord_signin.html
+++ b/api/src/pcapi/routes/auth/templates/discord_signin.html
@@ -20,8 +20,6 @@
-
{{ form.redirect_url }}
-
{{ form.discord_id }}
{{ form.email }}
{{ form.password }}
diff --git a/api/tests/routes/auth/discord_test.py b/api/tests/routes/auth/discord_test.py
index 3b66c49fc0b..0cf5972d58e 100644
--- a/api/tests/routes/auth/discord_test.py
+++ b/api/tests/routes/auth/discord_test.py
@@ -7,11 +7,13 @@
import pytest
from pcapi import settings
+from pcapi.connectors import discord as discord_connector
from pcapi.core.history import factories as history_factories
from pcapi.core.testing import assert_num_queries
from pcapi.core.testing import override_settings
from pcapi.core.users import constants as users_constants
from pcapi.core.users import factories as users_factories
+from pcapi.utils import requests
pytestmark = pytest.mark.usefixtures("db_session")
@@ -59,108 +61,94 @@ def post_to_endpoint(
return client.post(url, form=form, headers=headers, follow_redirects=follow_redirects)
+ def test_build_discord_redirection_uri(self):
+ assert (
+ discord_connector.build_discord_redirection_uri("1")
+ == f"https://discord.com/api/oauth2/authorize?client_id={discord_connector.DISCORD_CLIENT_ID}&redirect_uri={discord_connector.DISCORD_CALLBACK_URI}&response_type=code&scope=identify%20guilds.join&state=1"
+ )
+
@override_settings(DISCORD_JWT_PUBLIC_KEY=public_key_pem, DISCORD_JWT_PRIVATE_KEY=private_key_pem)
- def test_successful_discord_signing(self, client, db_session):
- redirect_url = "https://test.com"
+ def test_redirect_to_discord_on_post(self, client):
form_data = {
"email": "user@test.com",
"password": settings.TEST_DEFAULT_PASSWORD,
- "discord_id": "1234",
- "redirect_url": redirect_url,
}
- user = users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=True)
- discord_user = users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=True, isBanned=False)
+ user = users_factories.BeneficiaryFactory(
+ email=form_data["email"], password=form_data["password"], isActive=True
+ )
response = self.post_to_endpoint(client, form=form_data)
assert response.status_code == 302
- assert response.location == redirect_url
-
- db_session.refresh(discord_user)
- assert discord_user.discordId == form_data["discord_id"]
+ assert response.location == discord_connector.build_discord_redirection_uri(user.id)
@unittest.mock.patch(
- "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="discord_access_token"
+ "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="access_token"
)
+ @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.get_user_id", return_value="discord_user_id")
@unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.add_to_server")
- def test_discord_webhook(self, mock_add_to_server, mock_retrieve_access_token, client):
- client.get(url_for("auth.discord_call_back", code="discord_code"))
+ def test_discord_webhook_success(
+ self, mock_add_to_server, mock_get_user_id, mock_retrieve_access_token, client, db_session
+ ):
+ user = users_factories.BeneficiaryFactory()
+ discord_user = users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=True, isBanned=False)
+
+ client.get(url_for("auth.discord_call_back", code="discord_code", state=str(user.id)))
assert mock_retrieve_access_token.call_count == 1
assert mock_retrieve_access_token.call_args[0][0] == "discord_code"
- assert mock_add_to_server.call_count == 1
- assert mock_add_to_server.call_args[0][0] == "discord_access_token"
-
- @override_settings(DISCORD_JWT_PUBLIC_KEY=public_key_pem, DISCORD_JWT_PRIVATE_KEY=private_key_pem)
- def test_has_access_is_false(self, client, db_session):
- redirect_url = "https://test.com"
- form_data = {
- "email": "user@test.com",
- "password": settings.TEST_DEFAULT_PASSWORD,
- "discord_id": "1234",
- "redirect_url": redirect_url,
- }
-
- user = users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=True)
- discord_user = users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=False, isBanned=False)
-
- response = self.post_to_endpoint(client, form=form_data)
-
- assert response.status_code == 200
- assert response.location is None
+ assert mock_get_user_id.call_count == 1
+ assert mock_get_user_id.call_args[0][0] == "access_token"
- response_data = response.data.decode("utf-8")
- assert "Accès refusé au serveur Discord. Contacte le support pour plus d'informations" in response_data
+ assert mock_add_to_server.call_count == 1
+ assert mock_add_to_server.call_args[0][0] == "access_token"
+ assert mock_add_to_server.call_args[0][1] == "discord_user_id"
db_session.refresh(discord_user)
- assert discord_user.discordId is None
+ assert discord_user.discordId == "discord_user_id"
- @override_settings(DISCORD_JWT_PUBLIC_KEY=public_key_pem, DISCORD_JWT_PRIVATE_KEY=private_key_pem)
- def test_discord_user_is_none(self, client):
- redirect_url = "https://test.com"
- form_data = {
- "email": "user@test.com",
- "password": settings.TEST_DEFAULT_PASSWORD,
- "discord_id": "1234",
- "redirect_url": redirect_url,
- }
- users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=True)
+ @unittest.mock.patch(
+ "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="access_token"
+ )
+ @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.get_user_id", return_value="discord_user_id")
+ @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.add_to_server")
+ def test_has_access_is_false(self, _mock_add_to_server, _mock_get_user_id, _mock_retrieve_access_token, client):
+ user = users_factories.BeneficiaryFactory()
+ users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=False, isBanned=False)
- response = self.post_to_endpoint(client, form=form_data)
+ response = client.get(url_for("auth.discord_call_back", code="discord_code", state=str(user.id)))
- assert response.status_code == 200
+ expected_query_params = (
+ "Accès refusé au serveur Discord. Contacte le support pour plus d'informations".replace(" ", "%20")
+ .replace("è", "%C3%A8")
+ .replace("é", "%C3%A9")
+ )
- assert response.status_code == 200
- assert response.location is None
+ assert response.status_code == 303
+ assert f"/auth/discord/signin?error={expected_query_params}" in response.location
- response_data = response.data.decode("utf-8")
- assert "Accès refusé au serveur Discord. Contacte le support pour plus d'informations" in response_data
+ @unittest.mock.patch(
+ "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="access_token"
+ )
+ @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.get_user_id", return_value="discord_user_id")
+ @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.add_to_server")
+ def test_discord_user_is_none(self, _mock_add_to_server, _mock_get_user_id, _mock_retrieve_access_token, client):
+ user = users_factories.BeneficiaryFactory()
- @override_settings(DISCORD_JWT_PUBLIC_KEY=public_key_pem, DISCORD_JWT_PRIVATE_KEY=private_key_pem)
- def test_discord_user_is_active(self, client):
- redirect_url = "https://test.com"
- form_data = {
- "email": "user@test.com",
- "password": settings.TEST_DEFAULT_PASSWORD,
- "discord_id": "1234",
- "redirect_url": redirect_url,
- }
- user = users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=True)
- users_factories.DiscordUserFactory(user=user, hasAccess=True, isBanned=False)
+ response = client.get(url_for("auth.discord_call_back", code="discord_code", state=str(user.id)))
- response = self.post_to_endpoint(client, form=form_data)
+ expected_query_params = (
+ "Accès refusé au serveur Discord. Contacte le support pour plus d'informations".replace(" ", "%20")
+ .replace("è", "%C3%A8")
+ .replace("é", "%C3%A9")
+ )
- assert response.status_code == 302
- assert response.location == redirect_url
+ assert response.status_code == 303
+ assert f"/auth/discord/signin?error={expected_query_params}" in response.location
def test_account_anonymized_user_request_account_state(self, client):
- form_data = {
- "email": "user@test.com",
- "password": settings.TEST_DEFAULT_PASSWORD,
- "discord_id": "1234",
- "redirect_url": "https://test.com",
- }
+ form_data = {"email": "user@test.com", "password": settings.TEST_DEFAULT_PASSWORD}
users_factories.AnonymizedUserFactory(
email=form_data["email"],
password=form_data["password"],
@@ -172,12 +160,7 @@ def test_account_anonymized_user_request_account_state(self, client):
assert "Le compte a été anonymisé" in response_data
def test_wrong_password(self, client):
- form_data = {
- "email": "user@test.com",
- "password": "wrong_password",
- "discord_id": "1234",
- "redirect_url": "https://test.com",
- }
+ form_data = {"email": "user@test.com", "password": "wrong_password"}
users_factories.AnonymizedUserFactory(
email=form_data["email"],
password=settings.TEST_DEFAULT_PASSWORD,
@@ -189,12 +172,7 @@ def test_wrong_password(self, client):
assert "Identifiant ou Mot de passe incorrect" in response_data
def test_account_deleted_account_state(self, client):
- form_data = {
- "email": "user@test.com",
- "password": settings.TEST_DEFAULT_PASSWORD,
- "discord_id": "1234",
- "redirect_url": "https://test.com",
- }
+ form_data = {"email": "user@test.com", "password": settings.TEST_DEFAULT_PASSWORD}
user = users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=False)
history_factories.SuspendedUserActionHistoryFactory(user=user, reason=users_constants.SuspensionReason.DELETED)
@@ -205,12 +183,7 @@ def test_account_deleted_account_state(self, client):
assert "Le compte a été supprimé" in response_data
def test_inactive_user_signin(self, client):
- form_data = {
- "email": "user@test.com",
- "password": settings.TEST_DEFAULT_PASSWORD,
- "discord_id": "1234",
- "redirect_url": "https://test.com",
- }
+ form_data = {"email": "user@test.com", "password": settings.TEST_DEFAULT_PASSWORD}
users_factories.BaseUserFactory(email=form_data["email"], password=form_data["password"])
response = self.post_to_endpoint(client, form=form_data)
assert response.status_code == 200
@@ -221,12 +194,7 @@ def test_inactive_user_signin(self, client):
)
def test_unknown_user_logs_in(self, client):
- form_data = {
- "email": "user@test.com",
- "password": settings.TEST_DEFAULT_PASSWORD,
- "discord_id": "1234",
- "redirect_url": "https://test.com",
- }
+ form_data = {"email": "user@test.com", "password": settings.TEST_DEFAULT_PASSWORD}
response = self.post_to_endpoint(client, form=form_data)
assert response.status_code == 200
@@ -235,12 +203,7 @@ def test_unknown_user_logs_in(self, client):
def test_user_without_password_logs_in(self, client):
user = users_factories.UserFactory(password=None, isActive=True)
- form_data = {
- "email": user.email,
- "password": settings.TEST_DEFAULT_PASSWORD,
- "discord_id": "1234",
- "redirect_url": "https://test.com",
- }
+ form_data = {"email": user.email, "password": settings.TEST_DEFAULT_PASSWORD}
response = self.post_to_endpoint(client, form=form_data)
assert response.status_code == 200
@@ -248,10 +211,30 @@ def test_user_without_password_logs_in(self, client):
assert "Identifiant ou Mot de passe incorrect" in response_data
def test_user_logs_in_with_missing_fields(self, client):
- form_data = {
- "email": "user@test.com",
- "discord_id": "1234",
- "redirect_url": "https://test.com",
- }
+ form_data = {"email": "user@test.com"}
response = self.post_to_endpoint(client, form=form_data)
assert response.status_code == 200
+
+ response_data = response.data.decode("utf-8")
+ assert "Mot de passe : Information obligatoire" in response_data
+
+ @unittest.mock.patch(
+ "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="access_token"
+ )
+ @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.get_user_id", return_value="discord_user_id")
+ @unittest.mock.patch(
+ "pcapi.routes.auth.discord.discord_connector.add_to_server",
+ side_effect=requests.exceptions.HTTPError(),
+ )
+ def test_error_adding_user_to_server_rollbacks(
+ self, _mock_add_to_server, _mock_get_user_id, _mock_retrieve_access_token, client, db_session
+ ):
+ user = users_factories.BeneficiaryFactory()
+ discord_user = users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=True, isBanned=False)
+
+ response = client.get(url_for("auth.discord_call_back", code="discord_code", state=str(user.id)))
+
+ assert response.status_code == 303
+
+ db_session.refresh(discord_user)
+ assert discord_user.discordId is None