diff --git a/src/DatabaseLibrary/assertion.py b/src/DatabaseLibrary/assertion.py index 0de14a50..185c2ba2 100644 --- a/src/DatabaseLibrary/assertion.py +++ b/src/DatabaseLibrary/assertion.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from robot.api import logger @@ -20,7 +21,7 @@ class Assertion: Assertion handles all the assertions of Database Library. """ - def check_if_exists_in_database(self, selectStatement, sansTran=False, msg=None): + def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None): """ Check if any row would be returned by given the input `selectStatement`. If there are no results, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction @@ -50,7 +51,7 @@ def check_if_exists_in_database(self, selectStatement, sansTran=False, msg=None) msg or f"Expected to have have at least one row, " f"but got 0 rows from: '{selectStatement}'" ) - def check_if_not_exists_in_database(self, selectStatement, sansTran=False, msg=None): + def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None): """ This is the negation of `check_if_exists_in_database`. @@ -83,7 +84,7 @@ def check_if_not_exists_in_database(self, selectStatement, sansTran=False, msg=N msg or f"Expected to have have no rows from '{selectStatement}', but got some rows: {query_results}" ) - def row_count_is_0(self, selectStatement, sansTran=False, msg=None): + def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None): """ Check if any rows are returned from the submitted `selectStatement`. If there are, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction commit or @@ -112,7 +113,9 @@ def row_count_is_0(self, selectStatement, sansTran=False, msg=None): if num_rows > 0: raise AssertionError(msg or f"Expected 0 rows, but {num_rows} were returned from: '{selectStatement}'") - def row_count_is_equal_to_x(self, selectStatement, numRows, sansTran=False, msg=None): + def row_count_is_equal_to_x( + self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None + ): """ Check if the number of rows returned from `selectStatement` is equal to the value submitted. If not, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit @@ -144,7 +147,9 @@ def row_count_is_equal_to_x(self, selectStatement, numRows, sansTran=False, msg= msg or f"Expected {numRows} rows, but {num_rows} were returned from: '{selectStatement}'" ) - def row_count_is_greater_than_x(self, selectStatement, numRows, sansTran=False, msg=None): + def row_count_is_greater_than_x( + self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None + ): """ Check if the number of rows returned from `selectStatement` is greater than the value submitted. If not, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit @@ -176,7 +181,9 @@ def row_count_is_greater_than_x(self, selectStatement, numRows, sansTran=False, msg or f"Expected more than {numRows} rows, but {num_rows} were returned from '{selectStatement}'" ) - def row_count_is_less_than_x(self, selectStatement, numRows, sansTran=False, msg=None): + def row_count_is_less_than_x( + self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None + ): """ Check if the number of rows returned from `selectStatement` is less than the value submitted. If not, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit @@ -208,7 +215,7 @@ def row_count_is_less_than_x(self, selectStatement, numRows, sansTran=False, msg msg or f"Expected less than {numRows} rows, but {num_rows} were returned from '{selectStatement}'" ) - def table_must_exist(self, tableName, sansTran=False, msg=None): + def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional[str] = None): """ Check if the table given exists in the database. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. The default error message can be overridden with the `msg` argument. diff --git a/src/DatabaseLibrary/connection_manager.py b/src/DatabaseLibrary/connection_manager.py index ad005d2b..003a38a2 100644 --- a/src/DatabaseLibrary/connection_manager.py +++ b/src/DatabaseLibrary/connection_manager.py @@ -13,6 +13,7 @@ # limitations under the License. import importlib +from typing import Optional try: import ConfigParser @@ -37,15 +38,15 @@ def __init__(self): def connect_to_database( self, - dbapiModuleName=None, - dbName=None, - dbUsername=None, - dbPassword=None, - dbHost=None, - dbPort=None, - dbCharset=None, - dbDriver=None, - dbConfigFile=None, + dbapiModuleName: Optional[str] = None, + dbName: Optional[str] = None, + dbUsername: Optional[str] = None, + dbPassword: Optional[str] = None, + dbHost: Optional[str] = None, + dbPort: Optional[int] = None, + dbCharset: Optional[str] = None, + dbDriver: Optional[str] = None, + dbConfigFile: Optional[str] = None, ): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to @@ -234,7 +235,9 @@ def connect_to_database( port=dbPort, ) - def connect_to_database_using_custom_params(self, dbapiModuleName=None, db_connect_string=""): + def connect_to_database_using_custom_params( + self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "" + ): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to connect to the database using the map string `db_connect_string` @@ -269,7 +272,9 @@ def connect_to_database_using_custom_params(self, dbapiModuleName=None, db_conne self._dbconnection = eval(db_connect_string) - def connect_to_database_using_custom_connection_string(self, dbapiModuleName=None, db_connect_string=""): + def connect_to_database_using_custom_connection_string( + self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "" + ): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to connect to the database using the `db_connect_string` @@ -290,7 +295,7 @@ def connect_to_database_using_custom_connection_string(self, dbapiModuleName=Non ) self._dbconnection = db_api_2.connect(db_connect_string) - def disconnect_from_database(self, error_if_no_connection=False): + def disconnect_from_database(self, error_if_no_connection: bool = False): """ Disconnects from the database. @@ -311,7 +316,7 @@ def disconnect_from_database(self, error_if_no_connection=False): self._dbconnection.close() self._dbconnection = None - def set_auto_commit(self, autoCommit=True): + def set_auto_commit(self, autoCommit: bool = True): """ Turn the autocommit on the database connection ON or OFF. diff --git a/src/DatabaseLibrary/query.py b/src/DatabaseLibrary/query.py index 5065f0f4..f4c47867 100644 --- a/src/DatabaseLibrary/query.py +++ b/src/DatabaseLibrary/query.py @@ -14,6 +14,7 @@ import inspect import sys +from typing import List, Optional from robot.api import logger @@ -23,7 +24,7 @@ class Query: Query handles all the querying done by the Database Library. """ - def query(self, selectStatement, sansTran=False, returnAsDict=False): + def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False): """ Uses the input `selectStatement` to query for the values that will be returned as a list of tuples. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -71,7 +72,7 @@ def query(self, selectStatement, sansTran=False, returnAsDict=False): if cur and not sansTran: self._dbconnection.rollback() - def row_count(self, selectStatement, sansTran=False): + def row_count(self, selectStatement: str, sansTran: bool = False): """ Uses the input `selectStatement` to query the database and returns the number of rows from the query. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -111,7 +112,7 @@ def row_count(self, selectStatement, sansTran=False): if cur and not sansTran: self._dbconnection.rollback() - def description(self, selectStatement, sansTran=False): + def description(self, selectStatement: str, sansTran: bool = False): """ Uses the input `selectStatement` to query a table in the db which will be used to determine the description. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -146,7 +147,7 @@ def description(self, selectStatement, sansTran=False): if cur and not sansTran: self._dbconnection.rollback() - def delete_all_rows_from_table(self, tableName, sansTran=False): + def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False): """ Delete all the rows within a given table. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -181,7 +182,7 @@ def delete_all_rows_from_table(self, tableName, sansTran=False): if cur and not sansTran: self._dbconnection.rollback() - def execute_sql_script(self, sqlScriptFileName, sansTran=False): + def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False): """ Executes the content of the `sqlScriptFileName` as SQL commands. Useful for setting the database to a known state before running your tests, or clearing out your test data after running each a test. Set optional input @@ -304,7 +305,7 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): if cur and not sansTran: self._dbconnection.rollback() - def execute_sql_string(self, sqlString, sansTran=False): + def execute_sql_string(self, sqlString: str, sansTran: bool = False): """ Executes the sqlString as SQL commands. Useful to pass arguments to your sql. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -331,7 +332,7 @@ def execute_sql_string(self, sqlString, sansTran=False): if cur and not sansTran: self._dbconnection.rollback() - def call_stored_procedure(self, spName, spParams=None, sansTran=False): + def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = None, sansTran: bool = False): """ Calls a stored procedure `spName` with the `spParams` - a *list* of parameters the procedure requires. Use the special *CURSOR* value for OUT params, which should receive result sets - @@ -469,7 +470,7 @@ def call_stored_procedure(self, spName, spParams=None, sansTran=False): if cur and not sansTran: self._dbconnection.rollback() - def __execute_sql(self, cur, sql_statement, omit_trailing_semicolon=None): + def __execute_sql(self, cur, sql_statement: str, omit_trailing_semicolon: Optional[bool] = None): """ Runs the `sql_statement` using `cur` as Cursor object. Use `omit_trailing_semicolon` parameter (bool) for explicit instruction,