Skip to content

Commit

Permalink
Merge pull request #192 from MarketSquare/add_typing
Browse files Browse the repository at this point in the history
Add type hints
  • Loading branch information
amochin authored Nov 6, 2023
2 parents 2286aed + ce260dd commit 10fe9e2
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 28 deletions.
21 changes: 14 additions & 7 deletions src/DatabaseLibrary/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from robot.api import logger

Expand All @@ -20,7 +21,7 @@ class Assertion:
Assertion handles all the assertions of Database Library.
"""

def check_if_exists_in_database(self, selectStatement, sansTran=False, msg=None):
def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None):
"""
Check if any row would be returned by given the input `selectStatement`. If there are no results, then this will
throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction
Expand Down Expand Up @@ -50,7 +51,7 @@ def check_if_exists_in_database(self, selectStatement, sansTran=False, msg=None)
msg or f"Expected to have have at least one row, " f"but got 0 rows from: '{selectStatement}'"
)

def check_if_not_exists_in_database(self, selectStatement, sansTran=False, msg=None):
def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None):
"""
This is the negation of `check_if_exists_in_database`.
Expand Down Expand Up @@ -83,7 +84,7 @@ def check_if_not_exists_in_database(self, selectStatement, sansTran=False, msg=N
msg or f"Expected to have have no rows from '{selectStatement}', but got some rows: {query_results}"
)

def row_count_is_0(self, selectStatement, sansTran=False, msg=None):
def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None):
"""
Check if any rows are returned from the submitted `selectStatement`. If there are, then this will throw an
AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction commit or
Expand Down Expand Up @@ -112,7 +113,9 @@ def row_count_is_0(self, selectStatement, sansTran=False, msg=None):
if num_rows > 0:
raise AssertionError(msg or f"Expected 0 rows, but {num_rows} were returned from: '{selectStatement}'")

def row_count_is_equal_to_x(self, selectStatement, numRows, sansTran=False, msg=None):
def row_count_is_equal_to_x(
self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None
):
"""
Check if the number of rows returned from `selectStatement` is equal to the value submitted. If not, then this
will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit
Expand Down Expand Up @@ -144,7 +147,9 @@ def row_count_is_equal_to_x(self, selectStatement, numRows, sansTran=False, msg=
msg or f"Expected {numRows} rows, but {num_rows} were returned from: '{selectStatement}'"
)

def row_count_is_greater_than_x(self, selectStatement, numRows, sansTran=False, msg=None):
def row_count_is_greater_than_x(
self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None
):
"""
Check if the number of rows returned from `selectStatement` is greater than the value submitted. If not, then
this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit
Expand Down Expand Up @@ -176,7 +181,9 @@ def row_count_is_greater_than_x(self, selectStatement, numRows, sansTran=False,
msg or f"Expected more than {numRows} rows, but {num_rows} were returned from '{selectStatement}'"
)

def row_count_is_less_than_x(self, selectStatement, numRows, sansTran=False, msg=None):
def row_count_is_less_than_x(
self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None
):
"""
Check if the number of rows returned from `selectStatement` is less than the value submitted. If not, then this
will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit
Expand Down Expand Up @@ -208,7 +215,7 @@ def row_count_is_less_than_x(self, selectStatement, numRows, sansTran=False, msg
msg or f"Expected less than {numRows} rows, but {num_rows} were returned from '{selectStatement}'"
)

def table_must_exist(self, tableName, sansTran=False, msg=None):
def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional[str] = None):
"""
Check if the table given exists in the database. Set optional input `sansTran` to True to run command without an
explicit transaction commit or rollback. The default error message can be overridden with the `msg` argument.
Expand Down
31 changes: 18 additions & 13 deletions src/DatabaseLibrary/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import importlib
from typing import Optional

try:
import ConfigParser
Expand All @@ -37,15 +38,15 @@ def __init__(self):

def connect_to_database(
self,
dbapiModuleName=None,
dbName=None,
dbUsername=None,
dbPassword=None,
dbHost=None,
dbPort=None,
dbCharset=None,
dbDriver=None,
dbConfigFile=None,
dbapiModuleName: Optional[str] = None,
dbName: Optional[str] = None,
dbUsername: Optional[str] = None,
dbPassword: Optional[str] = None,
dbHost: Optional[str] = None,
dbPort: Optional[int] = None,
dbCharset: Optional[str] = None,
dbDriver: Optional[str] = None,
dbConfigFile: Optional[str] = None,
):
"""
Loads the DB API 2.0 module given `dbapiModuleName` then uses it to
Expand Down Expand Up @@ -234,7 +235,9 @@ def connect_to_database(
port=dbPort,
)

def connect_to_database_using_custom_params(self, dbapiModuleName=None, db_connect_string=""):
def connect_to_database_using_custom_params(
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = ""
):
"""
Loads the DB API 2.0 module given `dbapiModuleName` then uses it to
connect to the database using the map string `db_connect_string`
Expand Down Expand Up @@ -269,7 +272,9 @@ def connect_to_database_using_custom_params(self, dbapiModuleName=None, db_conne

self._dbconnection = eval(db_connect_string)

def connect_to_database_using_custom_connection_string(self, dbapiModuleName=None, db_connect_string=""):
def connect_to_database_using_custom_connection_string(
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = ""
):
"""
Loads the DB API 2.0 module given `dbapiModuleName` then uses it to
connect to the database using the `db_connect_string`
Expand All @@ -290,7 +295,7 @@ def connect_to_database_using_custom_connection_string(self, dbapiModuleName=Non
)
self._dbconnection = db_api_2.connect(db_connect_string)

def disconnect_from_database(self, error_if_no_connection=False):
def disconnect_from_database(self, error_if_no_connection: bool = False):
"""
Disconnects from the database.
Expand All @@ -311,7 +316,7 @@ def disconnect_from_database(self, error_if_no_connection=False):
self._dbconnection.close()
self._dbconnection = None

def set_auto_commit(self, autoCommit=True):
def set_auto_commit(self, autoCommit: bool = True):
"""
Turn the autocommit on the database connection ON or OFF.
Expand Down
17 changes: 9 additions & 8 deletions src/DatabaseLibrary/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import inspect
import sys
from typing import List, Optional

from robot.api import logger

Expand All @@ -23,7 +24,7 @@ class Query:
Query handles all the querying done by the Database Library.
"""

def query(self, selectStatement, sansTran=False, returnAsDict=False):
def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False):
"""
Uses the input `selectStatement` to query for the values that will be returned as a list of tuples. Set optional
input `sansTran` to True to run command without an explicit transaction commit or rollback.
Expand Down Expand Up @@ -71,7 +72,7 @@ def query(self, selectStatement, sansTran=False, returnAsDict=False):
if cur and not sansTran:
self._dbconnection.rollback()

def row_count(self, selectStatement, sansTran=False):
def row_count(self, selectStatement: str, sansTran: bool = False):
"""
Uses the input `selectStatement` to query the database and returns the number of rows from the query. Set
optional input `sansTran` to True to run command without an explicit transaction commit or rollback.
Expand Down Expand Up @@ -111,7 +112,7 @@ def row_count(self, selectStatement, sansTran=False):
if cur and not sansTran:
self._dbconnection.rollback()

def description(self, selectStatement, sansTran=False):
def description(self, selectStatement: str, sansTran: bool = False):
"""
Uses the input `selectStatement` to query a table in the db which will be used to determine the description. Set
optional input `sansTran` to True to run command without an explicit transaction commit or rollback.
Expand Down Expand Up @@ -146,7 +147,7 @@ def description(self, selectStatement, sansTran=False):
if cur and not sansTran:
self._dbconnection.rollback()

def delete_all_rows_from_table(self, tableName, sansTran=False):
def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False):
"""
Delete all the rows within a given table. Set optional input `sansTran` to True to run command without an
explicit transaction commit or rollback.
Expand Down Expand Up @@ -181,7 +182,7 @@ def delete_all_rows_from_table(self, tableName, sansTran=False):
if cur and not sansTran:
self._dbconnection.rollback()

def execute_sql_script(self, sqlScriptFileName, sansTran=False):
def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False):
"""
Executes the content of the `sqlScriptFileName` as SQL commands. Useful for setting the database to a known
state before running your tests, or clearing out your test data after running each a test. Set optional input
Expand Down Expand Up @@ -304,7 +305,7 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False):
if cur and not sansTran:
self._dbconnection.rollback()

def execute_sql_string(self, sqlString, sansTran=False):
def execute_sql_string(self, sqlString: str, sansTran: bool = False):
"""
Executes the sqlString as SQL commands. Useful to pass arguments to your sql. Set optional input `sansTran` to
True to run command without an explicit transaction commit or rollback.
Expand All @@ -331,7 +332,7 @@ def execute_sql_string(self, sqlString, sansTran=False):
if cur and not sansTran:
self._dbconnection.rollback()

def call_stored_procedure(self, spName, spParams=None, sansTran=False):
def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = None, sansTran: bool = False):
"""
Calls a stored procedure `spName` with the `spParams` - a *list* of parameters the procedure requires.
Use the special *CURSOR* value for OUT params, which should receive result sets -
Expand Down Expand Up @@ -469,7 +470,7 @@ def call_stored_procedure(self, spName, spParams=None, sansTran=False):
if cur and not sansTran:
self._dbconnection.rollback()

def __execute_sql(self, cur, sql_statement, omit_trailing_semicolon=None):
def __execute_sql(self, cur, sql_statement: str, omit_trailing_semicolon: Optional[bool] = None):
"""
Runs the `sql_statement` using `cur` as Cursor object.
Use `omit_trailing_semicolon` parameter (bool) for explicit instruction,
Expand Down

0 comments on commit 10fe9e2

Please sign in to comment.