diff --git a/api/src/pcapi/connectors/discord.py b/api/src/pcapi/connectors/discord.py index 31af2664296..0cf11856f32 100644 --- a/api/src/pcapi/connectors/discord.py +++ b/api/src/pcapi/connectors/discord.py @@ -7,16 +7,30 @@ DISCORD_CLIENT_ID = "1261948740915433574" DISCORD_CLIENT_SECRET = settings.DISCORD_CLIENT_SECRET DISCORD_CALLBACK_URI = f"{settings.API_URL}/auth/discord/callback" -DISCORD_REDIRECT_SUCCESS = f"{settings.API_URL}/auth/discord/success" -DISCORD_FULL_REDIRECT_URI = ( - f"https://discord.com/api/oauth2/authorize" - f"?client_id={DISCORD_CLIENT_ID}" - f"&redirect_uri={DISCORD_CALLBACK_URI}" - f"&response_type=code" - f"&scope=identify%20guilds.join" -) DISCORD_HOME_URI = f"https://discord.com/channels/{DISCORD_GUILD_ID}/@home" +DISCORD_API_URI = "https://discord.com/api" + + +def build_discord_redirection_uri(user_id: int) -> str: + base_uri = f"{DISCORD_API_URI}/oauth2/authorize" + client_id = DISCORD_CLIENT_ID + redirect_uri = DISCORD_CALLBACK_URI + response_type = "code" + scope = "identify%20guilds.join" + + return f"{base_uri}?client_id={client_id}&redirect_uri={redirect_uri}&response_type={response_type}&scope={scope}&state={user_id}" + + +def get_user_id(access_token: str) -> str | None: + url = f"{DISCORD_API_URI}/oauth2/@me" + user_response = requests.get(url, headers={"Authorization": f"Bearer {access_token}"}) + user_response.raise_for_status() + try: + return user_response.json()["user"]["id"] + except KeyError: + return None + def retrieve_access_token(code: str) -> str | None: data = { @@ -34,19 +48,14 @@ def retrieve_access_token(code: str) -> str | None: return access_token -def add_to_server(access_token: str) -> None: +def add_to_server(access_token: str, user_discord_id: str) -> None: """ Adds the user to the pass culture discord server Our server is identified by the DISCORD_GUILD_ID """ - user_response = requests.get( - "https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {access_token}"} - ) - user_response.raise_for_status() - - user_id = user_response.json()["id"] data = {"access_token": access_token} - url = f"https://discord.com/api/guilds/{DISCORD_GUILD_ID}/members/{user_id}" + url = f"https://discord.com/api/guilds/{DISCORD_GUILD_ID}/members/{user_discord_id}" headers = {"Authorization": f"Bot {DISCORD_BOT_TOKEN}", "Content-Type": "application/json"} - requests.put(url, json=data, headers=headers) + response = requests.put(url, json=data, headers=headers) + response.raise_for_status() diff --git a/api/src/pcapi/routes/auth/discord.py b/api/src/pcapi/routes/auth/discord.py index bf10d1849b1..ed496cd5568 100644 --- a/api/src/pcapi/routes/auth/discord.py +++ b/api/src/pcapi/routes/auth/discord.py @@ -4,11 +4,13 @@ from flask_wtf.csrf import CSRFProtect from werkzeug.wrappers.response import Response +from pcapi import repository from pcapi.connectors import discord as discord_connector from pcapi.core.users import exceptions as users_exceptions from pcapi.core.users import models as user_models from pcapi.core.users import repository as users_repo from pcapi.models import db +import pcapi.routes.auth.exceptions as auth_exceptions from pcapi.routes.auth.forms.forms import SigninForm from pcapi.utils import requests @@ -16,44 +18,95 @@ from . import utils +ERROR_STRING_PREFIX = "Erreur d'authentification Discord: " + + @blueprint.auth_blueprint.route("/discord/signin", methods=["GET"]) def discord_signin() -> str: form = SigninForm() - form.discord_id.data = request.args.get("discord_id") - form.redirect_url.data = discord_connector.DISCORD_FULL_REDIRECT_URI if error_message := request.args.get("error"): form.error_message = error_message return render_template("discord_signin.html", form=form) +def redirect_with_error(error_message: str) -> Response: + repository.mark_transaction_as_invalid() + return redirect(f"/auth/discord/signin?error={error_message}", code=303) + + +def handle_http_error(error: requests.exceptions.HTTPError) -> Response: + error_message = "" + if error.response: + error_message = error.response.json().get("error_description") + if not error_message: + error_message = error.response.text + error_message += "Tu peux réessayer ou contacter le support." + return redirect_with_error(ERROR_STRING_PREFIX + error_message) + + @blueprint.auth_blueprint.route("/discord/callback", methods=["GET"]) +@repository.atomic() def discord_call_back() -> str | Response | None: - # Webhook called by the discord server once the discord authentication is successful code = request.args.get("code") - ERROR_STRING_PREFIX = "Erreur d'authentification Discord: " + user_id = request.args.get("state") + if not code: - return redirect(f"/auth/discord/signin?error={ERROR_STRING_PREFIX}code non récupéré", code=303) + return redirect_with_error(f"{ERROR_STRING_PREFIX}code non récupéré") + if not user_id: + return redirect_with_error(f"{ERROR_STRING_PREFIX}user_id pass Culture non récupéré") try: access_token = discord_connector.retrieve_access_token(code) except requests.exceptions.HTTPError as e: - return redirect(f"/auth/discord/signin?error={ERROR_STRING_PREFIX}{e.response.json().get('error')}", code=303) + return handle_http_error(e) if not access_token: - return redirect(f"/auth/discord/signin?error={ERROR_STRING_PREFIX}access token non récupéré", code=303) + return redirect_with_error(f"{ERROR_STRING_PREFIX}access token non récupéré") + + try: + user_discord_id = discord_connector.get_user_id(access_token) + except requests.exceptions.HTTPError as e: + return handle_http_error(e) + + if not user_discord_id: + return redirect_with_error(f"{ERROR_STRING_PREFIX}discord id non récupéré") + + try: + update_discord_user(user_id, user_discord_id) + except auth_exceptions.DiscordUserAlreadyLinked: + return redirect_with_error("Ce compte Discord est déjà lié à un autre compte pass Culture.") + except auth_exceptions.UserNotAllowed: + return redirect_with_error("Accès refusé au serveur Discord. Contacte le support pour plus d'informations") try: - discord_connector.add_to_server(access_token) + discord_connector.add_to_server(access_token, user_discord_id) except requests.exceptions.HTTPError as e: - return redirect( - f"/auth/discord/signin?error={ERROR_STRING_PREFIX}{e.response.json().get('message')}", - code=303, - ) + return handle_http_error(e) return redirect(discord_connector.DISCORD_HOME_URI, code=303) +def update_discord_user(user_id: str, discord_id: str) -> None: + already_linked_user = user_models.DiscordUser.query.filter_by(discordId=discord_id).first() + if already_linked_user: + raise auth_exceptions.DiscordUserAlreadyLinked() + + user = user_models.User.query.get(user_id) + discord_user = user.discordUser + + if discord_user is None: + # We still add the user to the database even if he doesn't have access to the discord server + discord_user = user_models.DiscordUser(userId=user.id, discordId=discord_id, hasAccess=False) + db.session.add(discord_user) + raise auth_exceptions.UserNotAllowed() + + if not discord_user.hasAccess: + raise auth_exceptions.UserNotAllowed() + + discord_user.discordId = discord_id + + @blueprint.auth_blueprint.route("/discord/signin", methods=["POST"]) def discord_signin_post() -> str | Response | None: csrf = CSRFProtect() @@ -65,8 +118,6 @@ def discord_signin_post() -> str | Response | None: email = form.email.data password = form.password.data - discord_id = form.discord_id.data - url_redirection = form.redirect_url.data try: user = users_repo.get_user_with_credentials(email, password, allow_inactive=True) @@ -86,17 +137,5 @@ def discord_signin_post() -> str | Response | None: form.error_message = "Le compte a été anonymisé" return render_template("discord_signin.html", form=form) - discord_user = user.discordUser - if discord_user is None or not discord_user.hasAccess: - if discord_user is None: - discord_user = user_models.DiscordUser(userId=user.id, discordId=discord_id, hasAccess=False) - db.session.add(discord_user) - db.session.commit() - form.error_message = "Accès refusé au serveur Discord. Contacte le support pour plus d'informations" - return render_template("discord_signin.html", form=form) - if discord_user.is_active: - return redirect(url_redirection) - - discord_user.discordId = discord_id - db.session.commit() + url_redirection = discord_connector.build_discord_redirection_uri(user.id) return redirect(url_redirection) diff --git a/api/src/pcapi/routes/auth/exceptions.py b/api/src/pcapi/routes/auth/exceptions.py new file mode 100644 index 00000000000..44a4555c05d --- /dev/null +++ b/api/src/pcapi/routes/auth/exceptions.py @@ -0,0 +1,10 @@ +class DiscordUserAlreadyLinked(Exception): + pass + + +class UserNotAllowed(Exception): + pass + + +class DiscordException(Exception): + pass diff --git a/api/src/pcapi/routes/auth/forms/forms.py b/api/src/pcapi/routes/auth/forms/forms.py index 5b6f1b7839a..f3f124dc860 100644 --- a/api/src/pcapi/routes/auth/forms/forms.py +++ b/api/src/pcapi/routes/auth/forms/forms.py @@ -5,6 +5,4 @@ class SigninForm(PCForm): email = fields.PCEmailField("Adresse email") password = fields.PCPasswordField("Mot de passe") - discord_id = fields.PCHiddenField("discord_id") - redirect_url = fields.PCLongHiddenField("redirect_url") error_message: str = "" diff --git a/api/src/pcapi/routes/auth/templates/discord_signin.html b/api/src/pcapi/routes/auth/templates/discord_signin.html index 6f36723904c..e52f21c4fa2 100644 --- a/api/src/pcapi/routes/auth/templates/discord_signin.html +++ b/api/src/pcapi/routes/auth/templates/discord_signin.html @@ -20,8 +20,6 @@

Pour accéder au serveur Discord du pass Culture

data-turbo="false">
-
{{ form.redirect_url }}
-
{{ form.discord_id }}
{{ form.email }}
{{ form.password }}
diff --git a/api/tests/routes/auth/discord_test.py b/api/tests/routes/auth/discord_test.py index 3b66c49fc0b..0cf5972d58e 100644 --- a/api/tests/routes/auth/discord_test.py +++ b/api/tests/routes/auth/discord_test.py @@ -7,11 +7,13 @@ import pytest from pcapi import settings +from pcapi.connectors import discord as discord_connector from pcapi.core.history import factories as history_factories from pcapi.core.testing import assert_num_queries from pcapi.core.testing import override_settings from pcapi.core.users import constants as users_constants from pcapi.core.users import factories as users_factories +from pcapi.utils import requests pytestmark = pytest.mark.usefixtures("db_session") @@ -59,108 +61,94 @@ def post_to_endpoint( return client.post(url, form=form, headers=headers, follow_redirects=follow_redirects) + def test_build_discord_redirection_uri(self): + assert ( + discord_connector.build_discord_redirection_uri("1") + == f"https://discord.com/api/oauth2/authorize?client_id={discord_connector.DISCORD_CLIENT_ID}&redirect_uri={discord_connector.DISCORD_CALLBACK_URI}&response_type=code&scope=identify%20guilds.join&state=1" + ) + @override_settings(DISCORD_JWT_PUBLIC_KEY=public_key_pem, DISCORD_JWT_PRIVATE_KEY=private_key_pem) - def test_successful_discord_signing(self, client, db_session): - redirect_url = "https://test.com" + def test_redirect_to_discord_on_post(self, client): form_data = { "email": "user@test.com", "password": settings.TEST_DEFAULT_PASSWORD, - "discord_id": "1234", - "redirect_url": redirect_url, } - user = users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=True) - discord_user = users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=True, isBanned=False) + user = users_factories.BeneficiaryFactory( + email=form_data["email"], password=form_data["password"], isActive=True + ) response = self.post_to_endpoint(client, form=form_data) assert response.status_code == 302 - assert response.location == redirect_url - - db_session.refresh(discord_user) - assert discord_user.discordId == form_data["discord_id"] + assert response.location == discord_connector.build_discord_redirection_uri(user.id) @unittest.mock.patch( - "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="discord_access_token" + "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="access_token" ) + @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.get_user_id", return_value="discord_user_id") @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.add_to_server") - def test_discord_webhook(self, mock_add_to_server, mock_retrieve_access_token, client): - client.get(url_for("auth.discord_call_back", code="discord_code")) + def test_discord_webhook_success( + self, mock_add_to_server, mock_get_user_id, mock_retrieve_access_token, client, db_session + ): + user = users_factories.BeneficiaryFactory() + discord_user = users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=True, isBanned=False) + + client.get(url_for("auth.discord_call_back", code="discord_code", state=str(user.id))) assert mock_retrieve_access_token.call_count == 1 assert mock_retrieve_access_token.call_args[0][0] == "discord_code" - assert mock_add_to_server.call_count == 1 - assert mock_add_to_server.call_args[0][0] == "discord_access_token" - - @override_settings(DISCORD_JWT_PUBLIC_KEY=public_key_pem, DISCORD_JWT_PRIVATE_KEY=private_key_pem) - def test_has_access_is_false(self, client, db_session): - redirect_url = "https://test.com" - form_data = { - "email": "user@test.com", - "password": settings.TEST_DEFAULT_PASSWORD, - "discord_id": "1234", - "redirect_url": redirect_url, - } - - user = users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=True) - discord_user = users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=False, isBanned=False) - - response = self.post_to_endpoint(client, form=form_data) - - assert response.status_code == 200 - assert response.location is None + assert mock_get_user_id.call_count == 1 + assert mock_get_user_id.call_args[0][0] == "access_token" - response_data = response.data.decode("utf-8") - assert "Accès refusé au serveur Discord. Contacte le support pour plus d'informations" in response_data + assert mock_add_to_server.call_count == 1 + assert mock_add_to_server.call_args[0][0] == "access_token" + assert mock_add_to_server.call_args[0][1] == "discord_user_id" db_session.refresh(discord_user) - assert discord_user.discordId is None + assert discord_user.discordId == "discord_user_id" - @override_settings(DISCORD_JWT_PUBLIC_KEY=public_key_pem, DISCORD_JWT_PRIVATE_KEY=private_key_pem) - def test_discord_user_is_none(self, client): - redirect_url = "https://test.com" - form_data = { - "email": "user@test.com", - "password": settings.TEST_DEFAULT_PASSWORD, - "discord_id": "1234", - "redirect_url": redirect_url, - } - users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=True) + @unittest.mock.patch( + "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="access_token" + ) + @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.get_user_id", return_value="discord_user_id") + @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.add_to_server") + def test_has_access_is_false(self, _mock_add_to_server, _mock_get_user_id, _mock_retrieve_access_token, client): + user = users_factories.BeneficiaryFactory() + users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=False, isBanned=False) - response = self.post_to_endpoint(client, form=form_data) + response = client.get(url_for("auth.discord_call_back", code="discord_code", state=str(user.id))) - assert response.status_code == 200 + expected_query_params = ( + "Accès refusé au serveur Discord. Contacte le support pour plus d'informations".replace(" ", "%20") + .replace("è", "%C3%A8") + .replace("é", "%C3%A9") + ) - assert response.status_code == 200 - assert response.location is None + assert response.status_code == 303 + assert f"/auth/discord/signin?error={expected_query_params}" in response.location - response_data = response.data.decode("utf-8") - assert "Accès refusé au serveur Discord. Contacte le support pour plus d'informations" in response_data + @unittest.mock.patch( + "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="access_token" + ) + @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.get_user_id", return_value="discord_user_id") + @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.add_to_server") + def test_discord_user_is_none(self, _mock_add_to_server, _mock_get_user_id, _mock_retrieve_access_token, client): + user = users_factories.BeneficiaryFactory() - @override_settings(DISCORD_JWT_PUBLIC_KEY=public_key_pem, DISCORD_JWT_PRIVATE_KEY=private_key_pem) - def test_discord_user_is_active(self, client): - redirect_url = "https://test.com" - form_data = { - "email": "user@test.com", - "password": settings.TEST_DEFAULT_PASSWORD, - "discord_id": "1234", - "redirect_url": redirect_url, - } - user = users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=True) - users_factories.DiscordUserFactory(user=user, hasAccess=True, isBanned=False) + response = client.get(url_for("auth.discord_call_back", code="discord_code", state=str(user.id))) - response = self.post_to_endpoint(client, form=form_data) + expected_query_params = ( + "Accès refusé au serveur Discord. Contacte le support pour plus d'informations".replace(" ", "%20") + .replace("è", "%C3%A8") + .replace("é", "%C3%A9") + ) - assert response.status_code == 302 - assert response.location == redirect_url + assert response.status_code == 303 + assert f"/auth/discord/signin?error={expected_query_params}" in response.location def test_account_anonymized_user_request_account_state(self, client): - form_data = { - "email": "user@test.com", - "password": settings.TEST_DEFAULT_PASSWORD, - "discord_id": "1234", - "redirect_url": "https://test.com", - } + form_data = {"email": "user@test.com", "password": settings.TEST_DEFAULT_PASSWORD} users_factories.AnonymizedUserFactory( email=form_data["email"], password=form_data["password"], @@ -172,12 +160,7 @@ def test_account_anonymized_user_request_account_state(self, client): assert "Le compte a été anonymisé" in response_data def test_wrong_password(self, client): - form_data = { - "email": "user@test.com", - "password": "wrong_password", - "discord_id": "1234", - "redirect_url": "https://test.com", - } + form_data = {"email": "user@test.com", "password": "wrong_password"} users_factories.AnonymizedUserFactory( email=form_data["email"], password=settings.TEST_DEFAULT_PASSWORD, @@ -189,12 +172,7 @@ def test_wrong_password(self, client): assert "Identifiant ou Mot de passe incorrect" in response_data def test_account_deleted_account_state(self, client): - form_data = { - "email": "user@test.com", - "password": settings.TEST_DEFAULT_PASSWORD, - "discord_id": "1234", - "redirect_url": "https://test.com", - } + form_data = {"email": "user@test.com", "password": settings.TEST_DEFAULT_PASSWORD} user = users_factories.UserFactory(email=form_data["email"], password=form_data["password"], isActive=False) history_factories.SuspendedUserActionHistoryFactory(user=user, reason=users_constants.SuspensionReason.DELETED) @@ -205,12 +183,7 @@ def test_account_deleted_account_state(self, client): assert "Le compte a été supprimé" in response_data def test_inactive_user_signin(self, client): - form_data = { - "email": "user@test.com", - "password": settings.TEST_DEFAULT_PASSWORD, - "discord_id": "1234", - "redirect_url": "https://test.com", - } + form_data = {"email": "user@test.com", "password": settings.TEST_DEFAULT_PASSWORD} users_factories.BaseUserFactory(email=form_data["email"], password=form_data["password"]) response = self.post_to_endpoint(client, form=form_data) assert response.status_code == 200 @@ -221,12 +194,7 @@ def test_inactive_user_signin(self, client): ) def test_unknown_user_logs_in(self, client): - form_data = { - "email": "user@test.com", - "password": settings.TEST_DEFAULT_PASSWORD, - "discord_id": "1234", - "redirect_url": "https://test.com", - } + form_data = {"email": "user@test.com", "password": settings.TEST_DEFAULT_PASSWORD} response = self.post_to_endpoint(client, form=form_data) assert response.status_code == 200 @@ -235,12 +203,7 @@ def test_unknown_user_logs_in(self, client): def test_user_without_password_logs_in(self, client): user = users_factories.UserFactory(password=None, isActive=True) - form_data = { - "email": user.email, - "password": settings.TEST_DEFAULT_PASSWORD, - "discord_id": "1234", - "redirect_url": "https://test.com", - } + form_data = {"email": user.email, "password": settings.TEST_DEFAULT_PASSWORD} response = self.post_to_endpoint(client, form=form_data) assert response.status_code == 200 @@ -248,10 +211,30 @@ def test_user_without_password_logs_in(self, client): assert "Identifiant ou Mot de passe incorrect" in response_data def test_user_logs_in_with_missing_fields(self, client): - form_data = { - "email": "user@test.com", - "discord_id": "1234", - "redirect_url": "https://test.com", - } + form_data = {"email": "user@test.com"} response = self.post_to_endpoint(client, form=form_data) assert response.status_code == 200 + + response_data = response.data.decode("utf-8") + assert "Mot de passe : Information obligatoire" in response_data + + @unittest.mock.patch( + "pcapi.routes.auth.discord.discord_connector.retrieve_access_token", return_value="access_token" + ) + @unittest.mock.patch("pcapi.routes.auth.discord.discord_connector.get_user_id", return_value="discord_user_id") + @unittest.mock.patch( + "pcapi.routes.auth.discord.discord_connector.add_to_server", + side_effect=requests.exceptions.HTTPError(), + ) + def test_error_adding_user_to_server_rollbacks( + self, _mock_add_to_server, _mock_get_user_id, _mock_retrieve_access_token, client, db_session + ): + user = users_factories.BeneficiaryFactory() + discord_user = users_factories.DiscordUserFactory(user=user, discordId=None, hasAccess=True, isBanned=False) + + response = client.get(url_for("auth.discord_call_back", code="discord_code", state=str(user.id))) + + assert response.status_code == 303 + + db_session.refresh(discord_user) + assert discord_user.discordId is None