diff --git a/README.rst b/README.rst index e0effe09..5f54235a 100644 --- a/README.rst +++ b/README.rst @@ -199,7 +199,7 @@ Dictionary. Current available keys are: - connection_retry_backoff_time - Integer. Sets the back off time in seconds for reries of + Integer. Sets the back off time in seconds for retries of the database connection process. Default value is ``5``. - query_timeout @@ -207,6 +207,17 @@ Dictionary. Current available keys are: Integer. Sets the timeout in seconds for the database query. Default value is ``0`` which disables the timeout. +- failover_partner + + String. Same as HOST but for failover partner. + Default is not specified which disable the failover partner. + +- connection_retry_failover_backoff_time + + Integer. Sets the back off time in seconds before trying + to connect to failover partner. + Default value is ``0`` which disables the timeout. + backend-specific settings ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/sql_server/pyodbc/base.py b/sql_server/pyodbc/base.py index 1b0e813e..0fcb6793 100644 --- a/sql_server/pyodbc/base.py +++ b/sql_server/pyodbc/base.py @@ -170,6 +170,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): '49919', '49920', ) + _unrecoverable_error_numbers = ( + '18486', # account is locked + '18487', # password expired + '18488', # password should be changed + '18452', # login from untrusted domain + ) def __init__(self, *args, **kwargs): super(DatabaseWrapper, self).__init__(*args, **kwargs) @@ -230,7 +236,7 @@ def get_connection_params(self): conn_params['NAME'] = 'master' return conn_params - def get_new_connection(self, conn_params): + def _get_connection_strings(self, conn_params, options): database = conn_params['NAME'] host = conn_params.get('HOST', 'localhost') user = conn_params.get('USER', None) @@ -238,7 +244,6 @@ def get_new_connection(self, conn_params): port = conn_params.get('PORT', None) default_driver = 'SQL Server' if os.name == 'nt' else 'FreeTDS' - options = conn_params.get('OPTIONS', {}) driver = options.get('driver', default_driver) dsn = options.get('dsn', None) @@ -298,24 +303,56 @@ def get_new_connection(self, conn_params): if options.get('extra_params', None): connstr += ';' + options['extra_params'] + failover_host = options.get('failover_partner', None) + failover_connstr = None + if failover_host: + failover_connstr = connstr.replace(host, failover_host, 1) + + return connstr, failover_connstr + + def get_new_connection(self, conn_params): + options = conn_params.get('OPTIONS', {}) + + connstr, failover_connstr = self._get_connection_strings(conn_params, options) + unicode_results = options.get('unicode_results', False) timeout = options.get('connection_timeout', 0) retries = options.get('connection_retries', 5) backoff_time = options.get('connection_retry_backoff_time', 5) + failover_backoff_time = options.get('connection_retry_failover_backoff_time', 0) query_timeout = options.get('query_timeout', 0) conn = None retry_count = 0 need_to_retry = False + failover = False + error_numbers, failover_error_numbers = '', '' while conn is None: try: conn = Database.connect(connstr, unicode_results=unicode_results, timeout=timeout) except Exception as e: + current_error_numbers = e.args[1] + for error_number in self._unrecoverable_error_numbers: # never retry upon receiving unrecoverable code + if error_number in current_error_numbers: + raise + + if not failover: + error_numbers = current_error_numbers + else: + failover_error_numbers = current_error_numbers + + if failover_connstr: # retry with failover if available + connstr, failover_connstr = failover_connstr, connstr + failover = not failover + if failover: + time.sleep(failover_backoff_time) + continue + for error_number in self._transient_error_numbers: - if error_number in e.args[1]: - if error_number in e.args[1] and retry_count < retries: + if error_number in error_numbers or error_number in failover_error_numbers: + if retry_count < retries: time.sleep(backoff_time) need_to_retry = True retry_count = retry_count + 1