diff --git a/api/app_factory.py b/api/app_factory.py index 60a584798b608a..46a101c4ab3ecf 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -15,6 +15,7 @@ from flask import Flask, Response, request from flask_cors import CORS +from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import Unauthorized import contexts @@ -120,11 +121,17 @@ def load_user_from_request(request_from_flask_login): user_id = decoded.get("user_id") logged_in_account = AccountService.load_logged_in_account(account_id=user_id) - if logged_in_account: - contexts.tenant_id.set(logged_in_account.current_tenant_id) return logged_in_account +@user_logged_in.connect +@user_loaded_from_request.connect +def on_user_logged_in(_sender, user): + """Called when a user logged in.""" + if user: + contexts.tenant_id.set(user.current_tenant_id) + + @login_manager.unauthorized_handler def unauthorized_handler(): """Handle unauthorized requests."""