Skip to content

Commit

Permalink
Redesign connection cache to clear connection after disconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
bhirsz committed Nov 8, 2023
1 parent 69dc9b1 commit da4203c
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/DatabaseLibrary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class DatabaseLibrary(ConnectionManager, Query, Assertion):
The library is basically compatible with any [https://peps.python.org/pep-0249|Python Database API Specification 2.0] module.
However, the actual implementation in existing Python modules is sometimes quite different, which requires custom handling in the library.
Therefore there are some modules, which are "natively" supported in the library - and others, which may work and may not.
Therefore, there are some modules, which are "natively" supported in the library - and others, which may work and may not.
See more on the [https://github.com/MarketSquare/Robotframework-Database-Library|project page on GitHub].
"""
Expand Down
39 changes: 31 additions & 8 deletions src/DatabaseLibrary/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class Assertion:
Assertion handles all the assertions of Database Library.
"""

def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None):
def check_if_exists_in_database(
self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: str = "default"
):
"""
Check if any row would be returned by given the input `selectStatement`. If there are no results, then this will
throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction
Expand Down Expand Up @@ -52,7 +54,9 @@ def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = Fal
msg or f"Expected to have have at least one row, but got 0 rows from: '{selectStatement}'"
)

def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None):
def check_if_not_exists_in_database(
self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: str = "default"
):
"""
This is the negation of `check_if_exists_in_database`.
Expand Down Expand Up @@ -86,7 +90,9 @@ def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool =
msg or f"Expected to have have no rows from '{selectStatement}', but got some rows: {query_results}"
)

def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None):
def row_count_is_0(
self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: str = "default"
):
"""
Check if any rows are returned from the submitted `selectStatement`. If there are, then this will throw an
AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction commit or
Expand Down Expand Up @@ -117,7 +123,12 @@ def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Opti
raise AssertionError(msg or f"Expected 0 rows, but {num_rows} were returned from: '{selectStatement}'")

def row_count_is_equal_to_x(
self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
self,
selectStatement: str,
numRows: str,
sansTran: bool = False,
msg: Optional[str] = None,
alias: str = "default",
):
"""
Check if the number of rows returned from `selectStatement` is equal to the value submitted. If not, then this
Expand Down Expand Up @@ -152,7 +163,12 @@ def row_count_is_equal_to_x(
)

def row_count_is_greater_than_x(
self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
self,
selectStatement: str,
numRows: str,
sansTran: bool = False,
msg: Optional[str] = None,
alias: str = "default",
):
"""
Check if the number of rows returned from `selectStatement` is greater than the value submitted. If not, then
Expand Down Expand Up @@ -187,7 +203,12 @@ def row_count_is_greater_than_x(
)

def row_count_is_less_than_x(
self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
self,
selectStatement: str,
numRows: str,
sansTran: bool = False,
msg: Optional[str] = None,
alias: str = "default",
):
"""
Check if the number of rows returned from `selectStatement` is less than the value submitted. If not, then this
Expand Down Expand Up @@ -221,7 +242,9 @@ def row_count_is_less_than_x(
msg or f"Expected less than {numRows} rows, but {num_rows} were returned from '{selectStatement}'"
)

def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None):
def table_must_exist(
self, tableName: str, sansTran: bool = False, msg: Optional[str] = None, alias: str = "default"
):
"""
Check if the table given exists in the database. Set optional input `sansTran` to True to run command without an
explicit transaction commit or rollback. The default error message can be overridden with the `msg` argument.
Expand All @@ -243,7 +266,7 @@ def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional
| Table Must Exist | first_name | msg=my error message |
"""
logger.info(f"Executing : Table Must Exist | {tableName}")
_, db_api_module_name = self._cache.switch(alias)
_, db_api_module_name = self._get_connection_with_alias(alias)
if db_api_module_name in ["cx_Oracle", "oracledb"]:
query = (
"SELECT * FROM all_objects WHERE object_type IN ('TABLE','VIEW') AND "
Expand Down
54 changes: 33 additions & 21 deletions src/DatabaseLibrary/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
# limitations under the License.

import importlib
from typing import Optional
from typing import Dict, Optional

try:
import ConfigParser
except:
import configparser as ConfigParser

from robot.api import logger
from robot.utils import ConnectionCache


class ConnectionManager:
Expand All @@ -30,8 +29,15 @@ class ConnectionManager:
"""

def __init__(self):
self.omit_trailing_semicolon = False
self._cache = ConnectionCache("No sessions created")
self.omit_trailing_semicolon: bool = False
self._connections: Dict = {}
self.default_alias: str = "default"

def _register_connection(self, conn_details, alias: str):
if alias in self._connections:
logger.warn(f"Overwriting not closed connection for alias = '{alias}'")
self._connections[alias] = conn_details
self.default_alias = alias

def connect_to_database(
self,
Expand All @@ -45,7 +51,7 @@ def connect_to_database(
dbDriver: Optional[str] = None,
dbConfigFile: Optional[str] = None,
driverMode: Optional[str] = None,
alias: Optional[str] = "default",
alias: str = "default",
):
"""
Loads the DB API 2.0 module given `dbapiModuleName` then uses it to
Expand Down Expand Up @@ -261,10 +267,10 @@ def connect_to_database(
host=dbHost,
port=dbPort,
)
self._cache.register((db_connection, db_api_module_name), alias=alias)
self._register_connection((db_connection, db_api_module_name), alias)

def connect_to_database_using_custom_params(
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: Optional[str] = "default"
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default"
):
"""
Loads the DB API 2.0 module given `dbapiModuleName` then uses it to
Expand Down Expand Up @@ -299,10 +305,10 @@ def connect_to_database_using_custom_params(
)

db_connection = eval(db_connect_string)
self._cache.register((db_connection, db_api_module_name), alias=alias)
self._register_connection((db_connection, db_api_module_name), alias)

def connect_to_database_using_custom_connection_string(
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: Optional[str] = "default"
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default"
):
"""
Loads the DB API 2.0 module given `dbapiModuleName` then uses it to
Expand All @@ -323,9 +329,9 @@ def connect_to_database_using_custom_connection_string(
f"'{db_connect_string}')"
)
db_connection = db_api_2.connect(db_connect_string)
self._cache.register((db_connection, db_api_module_name), alias=alias)
self._register_connection((db_connection, db_api_module_name), alias)

def disconnect_from_database(self, error_if_no_connection: bool = False, alias: Optional[str] = "default"):
def disconnect_from_database(self, error_if_no_connection: bool = False, alias: str = "default"):
"""
Disconnects from the database.
Expand All @@ -339,12 +345,9 @@ def disconnect_from_database(self, error_if_no_connection: bool = False, alias:
"""
logger.info("Executing : Disconnect From Database")
try:
db_connection, _ = self._cache.switch(alias)
except RuntimeError: # Non-existing index or alias
db_connection = None
if db_connection:
db_connection, _ = self._connections.pop(alias)
db_connection.close()
else:
except KeyError: # Non-existing alias
log_msg = "No open database connection to close"
if error_if_no_connection:
raise ConnectionError(log_msg) from None
Expand All @@ -358,12 +361,12 @@ def disconnect_from_all_databases(self):
| Disconnect From All Databases | # disconnects from all connections to the database |
"""
logger.info("Executing : Disconnect From All Databases")
for db_connection, _ in self._cache:
for db_connection, _ in self._connections.values():
if db_connection:
db_connection.close()
self._cache.empty_cache()
self._connections = {}

def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = "default"):
def set_auto_commit(self, autoCommit: bool = True, alias: str = "default"):
"""
Turn the autocommit on the database connection ON or OFF.
Expand All @@ -381,7 +384,7 @@ def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = "defau
| Set Auto Commit | False
"""
logger.info("Executing : Set Auto Commit")
db_connection, _ = self._cache.switch(alias)
db_connection, _ = self._get_connection_with_alias(alias)
db_connection.autocommit = autoCommit

def switch_database(self, alias):
Expand All @@ -392,4 +395,13 @@ def switch_database(self, alias):
| Switch Database | my_alias |
| Switch Database | alias=my_alias |
"""
self._cache.switch(alias)
if alias not in self._connections:
raise ValueError(f"Alias '{alias}' not found in existing connections.")
self.default_alias = alias

def _get_connection_with_alias(self, alias: str):
if alias not in self._connections:
if alias == "default":
raise ValueError(f"No database connection is open.")
raise ValueError(f"Alias '{alias}' not found in existing connections.")
return self._connections[alias]
30 changes: 16 additions & 14 deletions src/DatabaseLibrary/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Query:
Query handles all the querying done by the Database Library.
"""

def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False, alias: Optional[str] = None):
def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False, alias: str = "default"):
"""
Uses the input `selectStatement` to query for the values that will be returned as a list of tuples. Set optional
input `sansTran` to True to run command without an explicit transaction commit or rollback.
Expand Down Expand Up @@ -59,7 +59,7 @@ def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| @{queryResults} | Query | SELECT * FROM person | True |
"""
db_connection, _ = self._cache.switch(alias)
db_connection, _ = self._get_connection_with_alias(alias)
cur = None
try:
cur = db_connection.cursor()
Expand All @@ -74,7 +74,7 @@ def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool
if cur and not sansTran:
db_connection.rollback()

def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None):
def row_count(self, selectStatement: str, sansTran: bool = False, alias: str = "default"):
"""
Uses the input `selectStatement` to query the database and returns the number of rows from the query. Set
optional input `sansTran` to True to run command without an explicit transaction commit or rollback.
Expand Down Expand Up @@ -102,7 +102,7 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| ${rowCount} | Row Count | SELECT * FROM person | True |
"""
db_connection, db_api_module_name = self._cache.switch(alias)
db_connection, db_api_module_name = self._get_connection_with_alias(alias)
cur = None
try:
cur = db_connection.cursor()
Expand All @@ -116,7 +116,7 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona
if cur and not sansTran:
db_connection.rollback()

def description(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None):
def description(self, selectStatement: str, sansTran: bool = False, alias: str = "default"):
"""
Uses the input `selectStatement` to query a table in the db which will be used to determine the description. Set
optional input `sansTran` to True to run command without an explicit transaction commit or rollback.
Expand All @@ -138,7 +138,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| @{queryResults} | Description | SELECT * FROM person | True |
"""
db_connection, _ = self._cache.switch(alias)
db_connection, _ = self._get_connection_with_alias(alias)
cur = None
try:
cur = db_connection.cursor()
Expand All @@ -153,7 +153,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio
if cur and not sansTran:
db_connection.rollback()

def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, alias: Optional[str] = None):
def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, alias: str = "default"):
"""
Delete all the rows within a given table. Set optional input `sansTran` to True to run command without an
explicit transaction commit or rollback.
Expand All @@ -173,7 +173,7 @@ def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, ali
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| Delete All Rows From Table | person | True |
"""
db_connection, _ = self._cache.switch(alias)
db_connection, _ = self._get_connection_with_alias(alias)
cur = None
query = f"DELETE FROM {tableName}"
try:
Expand All @@ -190,7 +190,7 @@ def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, ali
if cur and not sansTran:
db_connection.rollback()

def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, alias: Optional[str] = None):
def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, alias: str = "default"):
"""
Executes the content of the `sqlScriptFileName` as SQL commands. Useful for setting the database to a known
state before running your tests, or clearing out your test data after running each a test. Set optional input
Expand Down Expand Up @@ -249,7 +249,7 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql | True |
"""
db_connection, _ = self._cache.switch(alias)
db_connection, _ = self._get_connection_with_alias(alias)
with open(sqlScriptFileName, encoding="UTF-8") as sql_file:
cur = None
try:
Expand Down Expand Up @@ -315,7 +315,7 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali
if cur and not sansTran:
db_connection.rollback()

def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Optional[str] = None):
def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: str = "default"):
"""
Executes the sqlString as SQL commands. Useful to pass arguments to your sql. Set optional input `sansTran` to
True to run command without an explicit transaction commit or rollback.
Expand All @@ -332,7 +332,7 @@ def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Opti
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| Execute Sql String | DELETE FROM person_employee_table; DELETE FROM person_table | True |
"""
db_connection, _ = self._cache.switch(alias)
db_connection, _ = self._get_connection_with_alias(alias)
cur = None
try:
cur = db_connection.cursor()
Expand All @@ -344,7 +344,9 @@ def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Opti
if cur and not sansTran:
db_connection.rollback()

def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = None, sansTran: bool = False, alias: Optional[str] = None):
def call_stored_procedure(
self, spName: str, spParams: Optional[List[str]] = None, sansTran: bool = False, alias: str = "default"
):
"""
Calls a stored procedure `spName` with the `spParams` - a *list* of parameters the procedure requires.
Use the special *CURSOR* value for OUT params, which should receive result sets -
Expand Down Expand Up @@ -383,7 +385,7 @@ def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = Non
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| @{Param values} @{Result sets} = | Call Stored Procedure | DBName.SchemaName.StoredProcName | ${Params} | True |
"""
db_connection, db_api_module_name = self._cache.switch(alias)
db_connection, db_api_module_name = self._get_connection_with_alias(alias)
if spParams is None:
spParams = []
cur = None
Expand Down
2 changes: 1 addition & 1 deletion test/tests/common_tests/aliased_connection.robot
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Switch Not Existing Alias

Execute SQL Script - Insert Data In Person table
Connect To DB alias=aliased_conn
${output}= Execute SQL Script ${CURDIR}/../insert_data_in_person_table.sql alias=aliased_conn
${output}= Execute SQL Script ../../resources/insert_data_in_person_table.sql alias=aliased_conn
Should Be Equal As Strings ${output} None

Check If Exists In DB - Franz Allan
Expand Down

0 comments on commit da4203c

Please sign in to comment.