Skip to content

Commit

Permalink
Redesign connection cache to clear connection after disconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
bhirsz committed Nov 8, 2023
1 parent 69dc9b1 commit cc846ac
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 107 deletions.
2 changes: 1 addition & 1 deletion src/DatabaseLibrary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class DatabaseLibrary(ConnectionManager, Query, Assertion):
The library is basically compatible with any [https://peps.python.org/pep-0249|Python Database API Specification 2.0] module.
However, the actual implementation in existing Python modules is sometimes quite different, which requires custom handling in the library.
Therefore there are some modules, which are "natively" supported in the library - and others, which may work and may not.
Therefore, there are some modules, which are "natively" supported in the library - and others, which may work and may not.
See more on the [https://github.com/MarketSquare/Robotframework-Database-Library|project page on GitHub].
"""
Expand Down
47 changes: 35 additions & 12 deletions src/DatabaseLibrary/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class Assertion:
Assertion handles all the assertions of Database Library.
"""

def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None):
def check_if_exists_in_database(
self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
):
"""
Check if any row would be returned by given the input `selectStatement`. If there are no results, then this will
throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction
Expand Down Expand Up @@ -52,7 +54,9 @@ def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = Fal
msg or f"Expected to have have at least one row, but got 0 rows from: '{selectStatement}'"
)

def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None):
def check_if_not_exists_in_database(
self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
):
"""
This is the negation of `check_if_exists_in_database`.
Expand Down Expand Up @@ -86,7 +90,9 @@ def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool =
msg or f"Expected to have have no rows from '{selectStatement}', but got some rows: {query_results}"
)

def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None):
def row_count_is_0(
self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
):
"""
Check if any rows are returned from the submitted `selectStatement`. If there are, then this will throw an
AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction commit or
Expand Down Expand Up @@ -117,7 +123,12 @@ def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Opti
raise AssertionError(msg or f"Expected 0 rows, but {num_rows} were returned from: '{selectStatement}'")

def row_count_is_equal_to_x(
self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
self,
selectStatement: str,
numRows: str,
sansTran: bool = False,
msg: Optional[str] = None,
alias: str = "default",
):
"""
Check if the number of rows returned from `selectStatement` is equal to the value submitted. If not, then this
Expand Down Expand Up @@ -152,7 +163,12 @@ def row_count_is_equal_to_x(
)

def row_count_is_greater_than_x(
self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
self,
selectStatement: str,
numRows: str,
sansTran: bool = False,
msg: Optional[str] = None,
alias: str = "default",
):
"""
Check if the number of rows returned from `selectStatement` is greater than the value submitted. If not, then
Expand Down Expand Up @@ -187,7 +203,12 @@ def row_count_is_greater_than_x(
)

def row_count_is_less_than_x(
self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
self,
selectStatement: str,
numRows: str,
sansTran: bool = False,
msg: Optional[str] = None,
alias: str = "default",
):
"""
Check if the number of rows returned from `selectStatement` is less than the value submitted. If not, then this
Expand Down Expand Up @@ -221,7 +242,9 @@ def row_count_is_less_than_x(
msg or f"Expected less than {numRows} rows, but {num_rows} were returned from '{selectStatement}'"
)

def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None):
def table_must_exist(
self, tableName: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
):
"""
Check if the table given exists in the database. Set optional input `sansTran` to True to run command without an
explicit transaction commit or rollback. The default error message can be overridden with the `msg` argument.
Expand All @@ -243,20 +266,20 @@ def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional
| Table Must Exist | first_name | msg=my error message |
"""
logger.info(f"Executing : Table Must Exist | {tableName}")
_, db_api_module_name = self._cache.switch(alias)
if db_api_module_name in ["cx_Oracle", "oracledb"]:
db_connection = self._get_connection_with_alias(alias)
if db_connection.module_name in ["cx_Oracle", "oracledb"]:
query = (
"SELECT * FROM all_objects WHERE object_type IN ('TABLE','VIEW') AND "
f"owner = SYS_CONTEXT('USERENV', 'SESSION_USER') AND object_name = UPPER('{tableName}')"
)
table_exists = self.row_count(query, sansTran, alias=alias) > 0
elif db_api_module_name in ["sqlite3"]:
elif db_connection.module_name in ["sqlite3"]:
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{tableName}' COLLATE NOCASE"
table_exists = self.row_count(query, sansTran, alias=alias) > 0
elif db_api_module_name in ["ibm_db", "ibm_db_dbi"]:
elif db_connection.module_name in ["ibm_db", "ibm_db_dbi"]:
query = f"SELECT name FROM SYSIBM.SYSTABLES WHERE type='T' AND name=UPPER('{tableName}')"
table_exists = self.row_count(query, sansTran, alias=alias) > 0
elif db_api_module_name in ["teradata"]:
elif db_connection.module_name in ["teradata"]:
query = f"SELECT TableName FROM DBC.TablesV WHERE TableKind='T' AND TableName='{tableName}'"
table_exists = self.row_count(query, sansTran, alias=alias) > 0
else:
Expand Down
81 changes: 55 additions & 26 deletions src/DatabaseLibrary/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,21 @@
# limitations under the License.

import importlib
from typing import Optional
from dataclasses import dataclass
from typing import Any, Dict, Optional

try:
import ConfigParser
except:
import configparser as ConfigParser

from robot.api import logger
from robot.utils import ConnectionCache


@dataclass
class Connection:
client: Any
module_name: str


class ConnectionManager:
Expand All @@ -30,8 +36,14 @@ class ConnectionManager:
"""

def __init__(self):
self.omit_trailing_semicolon = False
self._cache = ConnectionCache("No sessions created")
self.omit_trailing_semicolon: bool = False
self._connections: Dict[str, Connection] = {}
self.default_alias: str = "default"

def _register_connection(self, client: Any, module_name: str, alias: str):
if alias in self._connections:
logger.warn(f"Overwriting not closed connection for alias = '{alias}'")
self._connections[alias] = Connection(client, module_name)

def connect_to_database(
self,
Expand All @@ -45,7 +57,7 @@ def connect_to_database(
dbDriver: Optional[str] = None,
dbConfigFile: Optional[str] = None,
driverMode: Optional[str] = None,
alias: Optional[str] = "default",
alias: str = "default",
):
"""
Loads the DB API 2.0 module given `dbapiModuleName` then uses it to
Expand Down Expand Up @@ -261,10 +273,10 @@ def connect_to_database(
host=dbHost,
port=dbPort,
)
self._cache.register((db_connection, db_api_module_name), alias=alias)
self._register_connection(db_connection, db_api_module_name, alias)

def connect_to_database_using_custom_params(
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: Optional[str] = "default"
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default"
):
"""
Loads the DB API 2.0 module given `dbapiModuleName` then uses it to
Expand Down Expand Up @@ -299,10 +311,10 @@ def connect_to_database_using_custom_params(
)

db_connection = eval(db_connect_string)
self._cache.register((db_connection, db_api_module_name), alias=alias)
self._register_connection(db_connection, db_api_module_name, alias)

def connect_to_database_using_custom_connection_string(
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: Optional[str] = "default"
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default"
):
"""
Loads the DB API 2.0 module given `dbapiModuleName` then uses it to
Expand All @@ -323,9 +335,9 @@ def connect_to_database_using_custom_connection_string(
f"'{db_connect_string}')"
)
db_connection = db_api_2.connect(db_connect_string)
self._cache.register((db_connection, db_api_module_name), alias=alias)
self._register_connection(db_connection, db_api_module_name, alias)

def disconnect_from_database(self, error_if_no_connection: bool = False, alias: Optional[str] = "default"):
def disconnect_from_database(self, error_if_no_connection: bool = False, alias: Optional[str] = None):
"""
Disconnects from the database.
Expand All @@ -338,13 +350,12 @@ def disconnect_from_database(self, error_if_no_connection: bool = False, alias:
| Disconnect From Database | alias=my_alias | # disconnects from current connection to the database |
"""
logger.info("Executing : Disconnect From Database")
if not alias:
alias = self.default_alias
try:
db_connection, _ = self._cache.switch(alias)
except RuntimeError: # Non-existing index or alias
db_connection = None
if db_connection:
db_connection.close()
else:
db_connection = self._connections.pop(alias)
db_connection.client.close()
except KeyError: # Non-existing alias
log_msg = "No open database connection to close"
if error_if_no_connection:
raise ConnectionError(log_msg) from None
Expand All @@ -358,12 +369,11 @@ def disconnect_from_all_databases(self):
| Disconnect From All Databases | # disconnects from all connections to the database |
"""
logger.info("Executing : Disconnect From All Databases")
for db_connection, _ in self._cache:
if db_connection:
db_connection.close()
self._cache.empty_cache()
for db_connection in self._connections.values():
db_connection.client.close()
self._connections = {}

def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = "default"):
def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = None):
"""
Turn the autocommit on the database connection ON or OFF.
Expand All @@ -381,15 +391,34 @@ def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = "defau
| Set Auto Commit | False
"""
logger.info("Executing : Set Auto Commit")
db_connection, _ = self._cache.switch(alias)
db_connection.autocommit = autoCommit
db_connection = self._get_connection_with_alias(alias)
db_connection.client.autocommit = autoCommit

def switch_database(self, alias):
def switch_database(self, alias: str):
"""
Switch default database.
Example:
| Switch Database | my_alias |
| Switch Database | alias=my_alias |
"""
self._cache.switch(alias)
if alias not in self._connections:
raise ValueError(f"Alias '{alias}' not found in existing connections.")
self.default_alias = alias

def _get_connection_with_alias(self, alias: Optional[str]) -> Connection:
"""
Return connection with given alias.
If alias is not provided, it will return default connection.
If there is no default connection, it will return last opened connection.
"""
if not self._connections:
raise ValueError(f"No database connection is open.")
if not alias:
if self.default_alias in self._connections:
return self._connections[self.default_alias]
return list(self._connections.values())[-1]
if alias not in self._connections:
raise ValueError(f"Alias '{alias}' not found in existing connections.")
return self._connections[alias]
Loading

0 comments on commit cc846ac

Please sign in to comment.