From da4203c754668dcd0cd6361c7ff7334b1aa0219b Mon Sep 17 00:00:00 2001 From: Bartlomiej Hirsz Date: Wed, 8 Nov 2023 10:58:02 +0100 Subject: [PATCH] Redesign connection cache to clear connection after disconnect --- src/DatabaseLibrary/__init__.py | 2 +- src/DatabaseLibrary/assertion.py | 39 +++++++++++--- src/DatabaseLibrary/connection_manager.py | 54 +++++++++++-------- src/DatabaseLibrary/query.py | 30 ++++++----- .../common_tests/aliased_connection.robot | 2 +- 5 files changed, 82 insertions(+), 45 deletions(-) diff --git a/src/DatabaseLibrary/__init__.py b/src/DatabaseLibrary/__init__.py index efb757f1..cd70e0b4 100644 --- a/src/DatabaseLibrary/__init__.py +++ b/src/DatabaseLibrary/__init__.py @@ -69,7 +69,7 @@ class DatabaseLibrary(ConnectionManager, Query, Assertion): The library is basically compatible with any [https://peps.python.org/pep-0249|Python Database API Specification 2.0] module. However, the actual implementation in existing Python modules is sometimes quite different, which requires custom handling in the library. - Therefore there are some modules, which are "natively" supported in the library - and others, which may work and may not. + Therefore, there are some modules, which are "natively" supported in the library - and others, which may work and may not. See more on the [https://github.com/MarketSquare/Robotframework-Database-Library|project page on GitHub]. """ diff --git a/src/DatabaseLibrary/assertion.py b/src/DatabaseLibrary/assertion.py index ea586943..5bff3262 100644 --- a/src/DatabaseLibrary/assertion.py +++ b/src/DatabaseLibrary/assertion.py @@ -21,7 +21,9 @@ class Assertion: Assertion handles all the assertions of Database Library. """ - def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None): + def check_if_exists_in_database( + self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: str = "default" + ): """ Check if any row would be returned by given the input `selectStatement`. If there are no results, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction @@ -52,7 +54,9 @@ def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = Fal msg or f"Expected to have have at least one row, but got 0 rows from: '{selectStatement}'" ) - def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None): + def check_if_not_exists_in_database( + self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: str = "default" + ): """ This is the negation of `check_if_exists_in_database`. @@ -86,7 +90,9 @@ def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool = msg or f"Expected to have have no rows from '{selectStatement}', but got some rows: {query_results}" ) - def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None): + def row_count_is_0( + self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: str = "default" + ): """ Check if any rows are returned from the submitted `selectStatement`. If there are, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction commit or @@ -117,7 +123,12 @@ def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Opti raise AssertionError(msg or f"Expected 0 rows, but {num_rows} were returned from: '{selectStatement}'") def row_count_is_equal_to_x( - self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + self, + selectStatement: str, + numRows: str, + sansTran: bool = False, + msg: Optional[str] = None, + alias: str = "default", ): """ Check if the number of rows returned from `selectStatement` is equal to the value submitted. If not, then this @@ -152,7 +163,12 @@ def row_count_is_equal_to_x( ) def row_count_is_greater_than_x( - self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + self, + selectStatement: str, + numRows: str, + sansTran: bool = False, + msg: Optional[str] = None, + alias: str = "default", ): """ Check if the number of rows returned from `selectStatement` is greater than the value submitted. If not, then @@ -187,7 +203,12 @@ def row_count_is_greater_than_x( ) def row_count_is_less_than_x( - self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + self, + selectStatement: str, + numRows: str, + sansTran: bool = False, + msg: Optional[str] = None, + alias: str = "default", ): """ Check if the number of rows returned from `selectStatement` is less than the value submitted. If not, then this @@ -221,7 +242,9 @@ def row_count_is_less_than_x( msg or f"Expected less than {numRows} rows, but {num_rows} were returned from '{selectStatement}'" ) - def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None): + def table_must_exist( + self, tableName: str, sansTran: bool = False, msg: Optional[str] = None, alias: str = "default" + ): """ Check if the table given exists in the database. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. The default error message can be overridden with the `msg` argument. @@ -243,7 +266,7 @@ def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional | Table Must Exist | first_name | msg=my error message | """ logger.info(f"Executing : Table Must Exist | {tableName}") - _, db_api_module_name = self._cache.switch(alias) + _, db_api_module_name = self._get_connection_with_alias(alias) if db_api_module_name in ["cx_Oracle", "oracledb"]: query = ( "SELECT * FROM all_objects WHERE object_type IN ('TABLE','VIEW') AND " diff --git a/src/DatabaseLibrary/connection_manager.py b/src/DatabaseLibrary/connection_manager.py index 13259b05..000e21b9 100644 --- a/src/DatabaseLibrary/connection_manager.py +++ b/src/DatabaseLibrary/connection_manager.py @@ -13,7 +13,7 @@ # limitations under the License. import importlib -from typing import Optional +from typing import Dict, Optional try: import ConfigParser @@ -21,7 +21,6 @@ import configparser as ConfigParser from robot.api import logger -from robot.utils import ConnectionCache class ConnectionManager: @@ -30,8 +29,15 @@ class ConnectionManager: """ def __init__(self): - self.omit_trailing_semicolon = False - self._cache = ConnectionCache("No sessions created") + self.omit_trailing_semicolon: bool = False + self._connections: Dict = {} + self.default_alias: str = "default" + + def _register_connection(self, conn_details, alias: str): + if alias in self._connections: + logger.warn(f"Overwriting not closed connection for alias = '{alias}'") + self._connections[alias] = conn_details + self.default_alias = alias def connect_to_database( self, @@ -45,7 +51,7 @@ def connect_to_database( dbDriver: Optional[str] = None, dbConfigFile: Optional[str] = None, driverMode: Optional[str] = None, - alias: Optional[str] = "default", + alias: str = "default", ): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to @@ -261,10 +267,10 @@ def connect_to_database( host=dbHost, port=dbPort, ) - self._cache.register((db_connection, db_api_module_name), alias=alias) + self._register_connection((db_connection, db_api_module_name), alias) def connect_to_database_using_custom_params( - self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: Optional[str] = "default" + self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default" ): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to @@ -299,10 +305,10 @@ def connect_to_database_using_custom_params( ) db_connection = eval(db_connect_string) - self._cache.register((db_connection, db_api_module_name), alias=alias) + self._register_connection((db_connection, db_api_module_name), alias) def connect_to_database_using_custom_connection_string( - self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: Optional[str] = "default" + self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default" ): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to @@ -323,9 +329,9 @@ def connect_to_database_using_custom_connection_string( f"'{db_connect_string}')" ) db_connection = db_api_2.connect(db_connect_string) - self._cache.register((db_connection, db_api_module_name), alias=alias) + self._register_connection((db_connection, db_api_module_name), alias) - def disconnect_from_database(self, error_if_no_connection: bool = False, alias: Optional[str] = "default"): + def disconnect_from_database(self, error_if_no_connection: bool = False, alias: str = "default"): """ Disconnects from the database. @@ -339,12 +345,9 @@ def disconnect_from_database(self, error_if_no_connection: bool = False, alias: """ logger.info("Executing : Disconnect From Database") try: - db_connection, _ = self._cache.switch(alias) - except RuntimeError: # Non-existing index or alias - db_connection = None - if db_connection: + db_connection, _ = self._connections.pop(alias) db_connection.close() - else: + except KeyError: # Non-existing alias log_msg = "No open database connection to close" if error_if_no_connection: raise ConnectionError(log_msg) from None @@ -358,12 +361,12 @@ def disconnect_from_all_databases(self): | Disconnect From All Databases | # disconnects from all connections to the database | """ logger.info("Executing : Disconnect From All Databases") - for db_connection, _ in self._cache: + for db_connection, _ in self._connections.values(): if db_connection: db_connection.close() - self._cache.empty_cache() + self._connections = {} - def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = "default"): + def set_auto_commit(self, autoCommit: bool = True, alias: str = "default"): """ Turn the autocommit on the database connection ON or OFF. @@ -381,7 +384,7 @@ def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = "defau | Set Auto Commit | False """ logger.info("Executing : Set Auto Commit") - db_connection, _ = self._cache.switch(alias) + db_connection, _ = self._get_connection_with_alias(alias) db_connection.autocommit = autoCommit def switch_database(self, alias): @@ -392,4 +395,13 @@ def switch_database(self, alias): | Switch Database | my_alias | | Switch Database | alias=my_alias | """ - self._cache.switch(alias) + if alias not in self._connections: + raise ValueError(f"Alias '{alias}' not found in existing connections.") + self.default_alias = alias + + def _get_connection_with_alias(self, alias: str): + if alias not in self._connections: + if alias == "default": + raise ValueError(f"No database connection is open.") + raise ValueError(f"Alias '{alias}' not found in existing connections.") + return self._connections[alias] diff --git a/src/DatabaseLibrary/query.py b/src/DatabaseLibrary/query.py index a59ba536..3871c36f 100644 --- a/src/DatabaseLibrary/query.py +++ b/src/DatabaseLibrary/query.py @@ -24,7 +24,7 @@ class Query: Query handles all the querying done by the Database Library. """ - def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False, alias: Optional[str] = None): + def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False, alias: str = "default"): """ Uses the input `selectStatement` to query for the values that will be returned as a list of tuples. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -59,7 +59,7 @@ def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool Using optional `sansTran` to run command without an explicit transaction commit or rollback: | @{queryResults} | Query | SELECT * FROM person | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection, _ = self._get_connection_with_alias(alias) cur = None try: cur = db_connection.cursor() @@ -74,7 +74,7 @@ def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool if cur and not sansTran: db_connection.rollback() - def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None): + def row_count(self, selectStatement: str, sansTran: bool = False, alias: str = "default"): """ Uses the input `selectStatement` to query the database and returns the number of rows from the query. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -102,7 +102,7 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona Using optional `sansTran` to run command without an explicit transaction commit or rollback: | ${rowCount} | Row Count | SELECT * FROM person | True | """ - db_connection, db_api_module_name = self._cache.switch(alias) + db_connection, db_api_module_name = self._get_connection_with_alias(alias) cur = None try: cur = db_connection.cursor() @@ -116,7 +116,7 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona if cur and not sansTran: db_connection.rollback() - def description(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None): + def description(self, selectStatement: str, sansTran: bool = False, alias: str = "default"): """ Uses the input `selectStatement` to query a table in the db which will be used to determine the description. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -138,7 +138,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio Using optional `sansTran` to run command without an explicit transaction commit or rollback: | @{queryResults} | Description | SELECT * FROM person | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection, _ = self._get_connection_with_alias(alias) cur = None try: cur = db_connection.cursor() @@ -153,7 +153,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio if cur and not sansTran: db_connection.rollback() - def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, alias: Optional[str] = None): + def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, alias: str = "default"): """ Delete all the rows within a given table. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -173,7 +173,7 @@ def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, ali Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Delete All Rows From Table | person | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection, _ = self._get_connection_with_alias(alias) cur = None query = f"DELETE FROM {tableName}" try: @@ -190,7 +190,7 @@ def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, ali if cur and not sansTran: db_connection.rollback() - def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, alias: Optional[str] = None): + def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, alias: str = "default"): """ Executes the content of the `sqlScriptFileName` as SQL commands. Useful for setting the database to a known state before running your tests, or clearing out your test data after running each a test. Set optional input @@ -249,7 +249,7 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection, _ = self._get_connection_with_alias(alias) with open(sqlScriptFileName, encoding="UTF-8") as sql_file: cur = None try: @@ -315,7 +315,7 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali if cur and not sansTran: db_connection.rollback() - def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Optional[str] = None): + def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: str = "default"): """ Executes the sqlString as SQL commands. Useful to pass arguments to your sql. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -332,7 +332,7 @@ def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Opti Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Execute Sql String | DELETE FROM person_employee_table; DELETE FROM person_table | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection, _ = self._get_connection_with_alias(alias) cur = None try: cur = db_connection.cursor() @@ -344,7 +344,9 @@ def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Opti if cur and not sansTran: db_connection.rollback() - def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = None, sansTran: bool = False, alias: Optional[str] = None): + def call_stored_procedure( + self, spName: str, spParams: Optional[List[str]] = None, sansTran: bool = False, alias: str = "default" + ): """ Calls a stored procedure `spName` with the `spParams` - a *list* of parameters the procedure requires. Use the special *CURSOR* value for OUT params, which should receive result sets - @@ -383,7 +385,7 @@ def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = Non Using optional `sansTran` to run command without an explicit transaction commit or rollback: | @{Param values} @{Result sets} = | Call Stored Procedure | DBName.SchemaName.StoredProcName | ${Params} | True | """ - db_connection, db_api_module_name = self._cache.switch(alias) + db_connection, db_api_module_name = self._get_connection_with_alias(alias) if spParams is None: spParams = [] cur = None diff --git a/test/tests/common_tests/aliased_connection.robot b/test/tests/common_tests/aliased_connection.robot index 29b89026..1ff6ee64 100644 --- a/test/tests/common_tests/aliased_connection.robot +++ b/test/tests/common_tests/aliased_connection.robot @@ -45,7 +45,7 @@ Switch Not Existing Alias Execute SQL Script - Insert Data In Person table Connect To DB alias=aliased_conn - ${output}= Execute SQL Script ${CURDIR}/../insert_data_in_person_table.sql alias=aliased_conn + ${output}= Execute SQL Script ../../resources/insert_data_in_person_table.sql alias=aliased_conn Should Be Equal As Strings ${output} None Check If Exists In DB - Franz Allan