diff --git a/web/server/codechecker_server/api/authentication.py b/web/server/codechecker_server/api/authentication.py index e4f9629495..2186e37941 100644 --- a/web/server/codechecker_server/api/authentication.py +++ b/web/server/codechecker_server/api/authentication.py @@ -9,6 +9,10 @@ Handle Thrift requests for authentication. """ +import datetime +import sqlite3 +import os + from authlib.integrations.requests_client import OAuth2Session from authlib.common.security import generate_token from urllib.parse import urlparse, parse_qs @@ -48,6 +52,8 @@ def __init__(self, manager, auth_session, config_database): self.__manager = manager self.__auth_session = auth_session self.__config_db = config_database + self.__db_path = os.path.expanduser( + '~/.codechecker/state_codes.sqlite') def __require_privilaged_access(self): """ @@ -146,6 +152,71 @@ def getAccessControl(self): globalPermissions=global_permissions, productPermissions=product_permissions) + @timeit + def createdatabase(self): + """ + Create the SQLite database for storing the state codes + """ + + # Check if the database file exists + if os.path.exists(self.__db_path): + LOG.debug(f"Database of states {self.__db_path} already exists.") + return + + # Create the database and the table + # Create the database and the table + try: + conn = sqlite3.connect(self.__db_path) + conn.execute( + "CREATE TABLE state_codes (" + "ID INTEGER PRIMARY KEY AUTOINCREMENT, " + "state TEXT, " + "expires_at DATETIME)" + ) + conn.close() + LOG.debug("successfully created" + f" Database of states {self.__db_path}") + except sqlite3.Error as e: + LOG.error(f"An error occurred: {e}") + + @timeit + def insertState(self, state): + """ + Insert the state code into the database + """ + + # remove all the expired state codes from the database + try: + date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conn = sqlite3.connect(self.__db_path) + conn.execute("DELETE FROM state_codes " + "WHERE expires_at < DATETIME(\"" + date + "\")") + conn.commit() + conn.close() + except sqlite3.Error as e: + LOG.error(f"An error occurred: {e}") + + # Insert the state code into the database + try: + date = (datetime.datetime.now() + datetime.timedelta(minutes=15)) \ + .strftime("%Y-%m-%d %H:%M:%S") + conn = sqlite3.connect(self.__db_path) + # Insert the state code into the database + conn.execute("INSERT INTO state_codes (state, expires_at) " + "VALUES (?, ?)", (state, date)) + conn.commit() + state_id = conn.execute("SELECT ID FROM state_codes " + "WHERE state = ? AND expires_at = ?", + (state, date)).fetchone()[0] + conn.close() + LOG.debug(f"State {state[0]} inserted successfully.") + return state_id + except sqlite3.Error as e: + LOG.error(f"An error occurred: {e}") + raise codechecker_api_shared.ttypes.RequestFailed( + codechecker_api_shared.ttypes.ErrorCode.AUTH_DENIED, + "STATE insertion failed.") + @timeit def getOauthProviders(self): return self.__manager.get_oauth_providers() @@ -155,6 +226,13 @@ def createLink(self, provider): """ For creating a autehntication link for OAuth for specified provider """ + try: + self.createdatabase() + except Exception as ex: + LOG.error("Database creation failed: %s", str(ex)) + raise codechecker_api_shared.ttypes.RequestFailed( + codechecker_api_shared.ttypes.ErrorCode.AUTH_DENIED, + "Database creation failed.") oauth_config = self.__manager.get_oauth_config(provider) if not oauth_config.get('enabled'): raise codechecker_api_shared.ttypes.RequestFailed( @@ -177,12 +255,25 @@ def createLink(self, provider): # Create authorization URL nonce = generate_token() - url = session.create_authorization_url( - authorization_uri, nonce=nonce, state=stored_state)[0] - return url + url, state = session.create_authorization_url( + authorization_uri, nonce=nonce, state=stored_state) + + # Save the state and nonce to the database + state_id = self.insertState(state) + if not state_id: + raise codechecker_api_shared.ttypes.RequestFailed( + codechecker_api_shared.ttypes.ErrorCode.AUTH_DENIED, + "State code insertion failed.") + + LOG.debug(f"State {state} inserted successfully with ID {state_id}") + return url + "&state_id=" + str(state_id) @timeit def performLogin(self, auth_method, auth_string): + print("**********************") + print(auth_method, auth_string) + print("**********************") + if not auth_string: raise codechecker_api_shared.ttypes.RequestFailed( codechecker_api_shared.ttypes.ErrorCode.AUTH_DENIED, @@ -207,8 +298,32 @@ def performLogin(self, auth_method, auth_string): msg) elif auth_method == "oauth": + provider, url = auth_string.split("@") + url_new = urlparse(url) + parsed_query = parse_qs(url_new.query) + + code = parsed_query.get("code")[0] + state = parsed_query.get("state")[0] + state_id = parsed_query.get("state_id")[0] + + conn = sqlite3.connect(self.__db_path) + state_db = conn.execute("SELECT state " + "FROM state_codes " + "WHERE ID = " + state_id).fetchone()[0] + + # Delete the state from the database + conn.execute('DELETE FROM state_codes WHERE ID = ' + state_id) + conn.close() + + if state_db != state: + LOG.error("State code mismatch.") + raise codechecker_api_shared.ttypes.RequestFailed( + codechecker_api_shared.ttypes.ErrorCode.AUTH_DENIED, + "State code mismatch") + LOG.info("State code matched.") + oauth_config = self.__manager.get_oauth_config(provider) if not oauth_config.get('enabled'): LOG.error("OAuth authentication is " + @@ -233,6 +348,7 @@ def performLogin(self, auth_method, auth_string): client_secret, scope=scope, redirect_uri=redirect_uri) + except Exception as ex: LOG.error("OAuth2Session creation failed: %s", str(ex)) raise codechecker_api_shared.ttypes.RequestFailed( @@ -242,12 +358,6 @@ def performLogin(self, auth_method, auth_string): # FIXME: This is a workaround for the Microsoft OAuth2 provider # which doesn't correctly fetch the code from url. - url_new = urlparse(url) - parsed_query = parse_qs(url_new.query) - - code = parsed_query.get("code")[0] - state = parsed_query.get("state")[0] - url = url_new.scheme + "://" + url_new.netloc + url_new.path + \ "?code=" + code + "&state=" + state diff --git a/web/server/codechecker_server/session_manager.py b/web/server/codechecker_server/session_manager.py index 70d9a902bb..c2bd1bb226 100644 --- a/web/server/codechecker_server/session_manager.py +++ b/web/server/codechecker_server/session_manager.py @@ -682,16 +682,14 @@ def create_session(self, auth_string): return local_session def create_session_oauth(self, provider, username, token): - """ Creates a new session for the given auth-string. """ + """ + Creates a new session for the given auth-string + if the provider is enabled for OAuth authentication. + """ if not self.__is_method_enabled('oauth'): return False - # Try to get the user's previous session. - for sess in self.__sessions: - if sess.user == username: - return sess - providers = self.__auth_config.get( 'method_oauth', {}).get("providers", {}) diff --git a/web/server/vue-cli/src/views/Login.vue b/web/server/vue-cli/src/views/Login.vue index ebf5118a75..f2766b82e5 100644 --- a/web/server/vue-cli/src/views/Login.vue +++ b/web/server/vue-cli/src/views/Login.vue @@ -206,11 +206,12 @@ export default { return; } + const state_id = localStorage.getItem("state_id"); this.$store .dispatch(LOGIN, { type: "oauth", provider: provider, - url: window.location.href + url: window.location.href + "&state_id=" + state_id }) .then(() => { this.success = true; @@ -255,9 +256,10 @@ export default { const params = new URLSearchParams(url); localStorage.setItem("oauth_state", params.get("state")); + localStorage.setItem("state_id", + params.get("state_id")); window.location.href = url; - this.link = url; } else { this.errorMsg = `Server returned an invalid URL: ${url}`; this.error = true;