Commit 8347257

Add timeout hack to mitigate timeouts
1 parent 6f83144 commit 8347257

File tree: 2 files changed, +48 -109 lines


src/databricks/sql/auth/thrift_http_client.py
Lines changed: 1 addition & 0 deletions

@@ -3,6 +3,7 @@
 
 
 import thrift
+import thrift.transport.THttpClient
 
 import urllib.parse, six, base64
 
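
Note on the one-line change above: `import thrift` by itself does not make the `thrift.transport.THttpClient` submodule reachable as an attribute; a submodule is only bound once something imports it explicitly. The explicit import makes this module self-contained, plausibly needed because the import reordering in thrift_backend.py (below) now loads databricks.sql.auth.thrift_http_client before thrift.transport.THttpClient. A minimal standalone sketch of the mechanics, not code from this commit:

    import thrift
    # Without the next line, thrift.transport.THttpClient may raise
    # AttributeError unless some other module already imported it.
    import thrift.transport.THttpClient

    # Now the submodule and its THttpClient class are resolvable:
    base_cls = thrift.transport.THttpClient.THttpClient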

src/databricks/sql/thrift_backend.py
Lines changed: 47 additions & 109 deletions
@@ -1,34 +1,27 @@
-from decimal import Decimal
 import errno
 import logging
 import math
-import time
+import os
 import threading
-import lz4.frame
+import time
+from decimal import Decimal
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
 from typing import List, Union
 
+import databricks.sql.auth.thrift_http_client
+import lz4.frame
 import pyarrow
-import thrift.transport.THttpClient
 import thrift.protocol.TBinaryProtocol
+import thrift.transport.THttpClient
 import thrift.transport.TSocket
 import thrift.transport.TTransport
-
-import databricks.sql.auth.thrift_http_client
+from databricks.sql import *
 from databricks.sql.auth.authenticators import AuthProvider
 from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes
-from databricks.sql import *
-from databricks.sql.thrift_api.TCLIService.TCLIService import (
-    Client as TCLIServiceClient,
-)
-
-from databricks.sql.utils import (
-    ArrowQueue,
-    ExecuteResponse,
-    _bound,
-    RequestErrorInfo,
-    NoRetryReason,
-)
+from databricks.sql.thrift_api.TCLIService.TCLIService import \
+    Client as TCLIServiceClient
+from databricks.sql.utils import (ArrowQueue, ExecuteResponse, NoRetryReason,
+                                  RequestErrorInfo, _bound)
 
 logger = logging.getLogger(__name__)
 
@@ -38,6 +31,9 @@
 
 TIMESTAMP_AS_STRING_CONFIG = "spark.thriftserver.arrowBasedRowSet.timestampAsString"
 
+# HACK!
+THRIFT_SOCKET_TIMEOUT = os.getenv("THRIFT_SOCKET_TIMEOUT", None)
+
 # see Connection.__init__ for parameter descriptions.
 # - Min/Max avoids unsustainable configs (sane values are far more constrained)
 # - 900s attempts-duration lines up w ODBC/JDBC drivers (for cluster startup > 10 mins)
@@ -114,13 +110,9 @@ def __init__(
 
         self.staging_allowed_local_path = staging_allowed_local_path
         self._initialize_retry_args(kwargs)
-        self._use_arrow_native_complex_types = kwargs.get(
-            "_use_arrow_native_complex_types", True
-        )
+        self._use_arrow_native_complex_types = kwargs.get("_use_arrow_native_complex_types", True)
         self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True)
-        self._use_arrow_native_timestamps = kwargs.get(
-            "_use_arrow_native_timestamps", True
-        )
+        self._use_arrow_native_timestamps = kwargs.get("_use_arrow_native_timestamps", True)
 
         # Configure tls context
         ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
@@ -152,7 +144,10 @@ def __init__(
             ssl_context=ssl_context,
         )
 
-        timeout = kwargs.get("_socket_timeout")
+        # HACK!
+        timeout = kwargs.get("_socket_timeout") or THRIFT_SOCKET_TIMEOUT
+        logger.info(f"Setting timeout HACK! to {timeout}")
+
         # setTimeout defaults to None (i.e. no timeout), and is expected in ms
         self._transport.setTimeout(timeout and (float(timeout) * 1000.0))
 
@@ -175,15 +170,11 @@ def _initialize_retry_args(self, kwargs):
             given_or_default = type_(kwargs.get(key, default))
             bound = _bound(min, max, given_or_default)
             setattr(self, key, bound)
-            logger.debug(
-                "retry parameter: {} given_or_default {}".format(key, given_or_default)
-            )
+            logger.debug("retry parameter: {} given_or_default {}".format(key, given_or_default))
             if bound != given_or_default:
                 logger.warning(
                     "Override out of policy retry parameter: "
-                    + "{} given {}, restricted to {}".format(
-                        key, given_or_default, bound
-                    )
+                    + "{} given {}, restricted to {}".format(key, given_or_default, bound)
                 )
 
         # Fail on retry delay min > max; consider later adding fail on min > duration?
@@ -211,9 +202,7 @@ def _extract_error_message_from_headers(headers):
         if THRIFT_ERROR_MESSAGE_HEADER in headers:
             err_msg = headers[THRIFT_ERROR_MESSAGE_HEADER]
         if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers:
-            if (
-                err_msg
-            ):  # We don't expect both to be set, but log both here just in case
+            if err_msg:  # We don't expect both to be set, but log both here just in case
                 err_msg = "Thriftserver error: {}, Databricks error: {}".format(
                     err_msg, headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER]
                 )
@@ -406,10 +395,7 @@ def _check_initial_namespace(self, catalog, schema, response):
         if not (catalog or schema):
             return
 
-        if (
-            response.serverProtocolVersion
-            < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4
-        ):
+        if response.serverProtocolVersion < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4:
             raise InvalidServerResponseError(
                 "Setting initial namespace not supported by the DBR version, "
                 "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0."
@@ -424,10 +410,7 @@ def _check_initial_namespace(self, catalog, schema, response):
 
     def _check_session_configuration(self, session_configuration):
         # This client expects timetampsAsString to be false, so we do not allow users to modify that
-        if (
-            session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower()
-            != "false"
-        ):
+        if session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() != "false":
             raise Error(
                 "Invalid session configuration: {} cannot be changed "
                 "while using the Databricks SQL connector, it must be false not {}".format(
@@ -439,18 +422,14 @@ def _check_session_configuration(self, session_configuration):
     def open_session(self, session_configuration, catalog, schema):
         try:
             self._transport.open()
-            session_configuration = {
-                k: str(v) for (k, v) in (session_configuration or {}).items()
-            }
+            session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()}
             self._check_session_configuration(session_configuration)
             # We want to receive proper Timestamp arrow types.
             # We set it also in confOverlay in TExecuteStatementReq on a per query basic,
             # but it doesn't hurt to also set for the whole session.
             session_configuration[TIMESTAMP_AS_STRING_CONFIG] = "false"
             if catalog or schema:
-                initial_namespace = ttypes.TNamespace(
-                    catalogName=catalog, schemaName=schema
-                )
+                initial_namespace = ttypes.TNamespace(catalogName=catalog, schemaName=schema)
             else:
                 initial_namespace = None
 
@@ -476,9 +455,7 @@ def close_session(self, session_handle) -> None:
         finally:
             self._transport.close()
 
-    def _check_command_not_in_error_or_closed_state(
-        self, op_handle, get_operations_resp
-    ):
+    def _check_command_not_in_error_or_closed_state(self, op_handle, get_operations_resp):
         if get_operations_resp.operationState == ttypes.TOperationState.ERROR_STATE:
             if get_operations_resp.displayMessage:
                 raise ServerOperationError(
@@ -513,17 +490,11 @@ def _poll_for_status(self, op_handle):
 
     def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
         if t_row_set.columns is not None:
-            (
-                arrow_table,
-                num_rows,
-            ) = ThriftBackend._convert_column_based_set_to_arrow_table(
+            (arrow_table, num_rows,) = ThriftBackend._convert_column_based_set_to_arrow_table(
                 t_row_set.columns, description
             )
         elif t_row_set.arrowBatches is not None:
-            (
-                arrow_table,
-                num_rows,
-            ) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
+            (arrow_table, num_rows,) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
@@ -534,9 +505,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
     def _convert_decimals_in_arrow_table(table, description):
         for (i, col) in enumerate(table.itercolumns()):
             if description[i][1] == "decimal":
-                decimal_col = col.to_pandas().apply(
-                    lambda v: v if v is None else Decimal(v)
-                )
+                decimal_col = col.to_pandas().apply(lambda v: v if v is None else Decimal(v))
                 precision, scale = description[i][4], description[i][5]
                 assert scale is not None
                 assert precision is not None
@@ -549,9 +518,7 @@ def _convert_decimals_in_arrow_table(table, description):
         return table
 
     @staticmethod
-    def _convert_arrow_based_set_to_arrow_table(
-        arrow_batches, lz4_compressed, schema_bytes
-    ):
+    def _convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes):
         ba = bytearray()
         ba += schema_bytes
         n_rows = 0
@@ -597,9 +564,7 @@ def _convert_column_to_arrow_array(t_col):
         for field in field_name_to_arrow_type.keys():
             wrapper = getattr(t_col, field)
             if wrapper:
-                return ThriftBackend._create_arrow_array(
-                    wrapper, field_name_to_arrow_type[field]
-                )
+                return ThriftBackend._create_arrow_array(wrapper, field_name_to_arrow_type[field])
 
         raise OperationalError("Empty TColumn instance {}".format(t_col))
 
@@ -654,9 +619,7 @@ def map_type(t_type_entry):
             else:
                 # Current thriftserver implementation should always return a primitiveEntry,
                 # even for complex types
-                raise OperationalError(
-                    "Thrift protocol error: t_type_entry not a primitiveEntry"
-                )
+                raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
 
         def convert_col(t_column_desc):
             return pyarrow.field(
@@ -674,9 +637,7 @@ def _col_to_description(col):
             # Drop _TYPE suffix
             cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower()
         else:
-            raise OperationalError(
-                "Thrift protocol error: t_type_entry not a primitiveEntry"
-            )
+            raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
 
         if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
             qualifiers = type_entry.primitiveEntry.typeQualifiers.qualifiers
@@ -697,9 +658,7 @@ def _col_to_description(col):
 
     @staticmethod
     def _hive_schema_to_description(t_table_schema):
-        return [
-            ThriftBackend._col_to_description(col) for col in t_table_schema.columns
-        ]
+        return [ThriftBackend._col_to_description(col) for col in t_table_schema.columns]
 
     def _results_message_to_execute_response(self, resp, operation_state):
         if resp.directResults and resp.directResults.resultSetMetadata:
@@ -726,9 +685,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
             or (not direct_results.resultSet)
             or direct_results.resultSet.hasMoreRows
         )
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
-        )
+        description = self._hive_schema_to_description(t_result_set_metadata_resp.schema)
         schema_bytes = (
             t_result_set_metadata_resp.arrowSchema
             or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
@@ -768,8 +725,7 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
                 op_handle, initial_operation_status_resp
             )
         operation_state = (
-            initial_operation_status_resp
-            and initial_operation_status_resp.operationState
+            initial_operation_status_resp and initial_operation_status_resp.operationState
         )
         while not operation_state or operation_state in [
             ttypes.TOperationState.RUNNING_STATE,
@@ -784,21 +740,13 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
     def _check_direct_results_for_error(t_spark_direct_results):
         if t_spark_direct_results:
             if t_spark_direct_results.operationStatus:
-                ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.operationStatus
-                )
+                ThriftBackend._check_response_for_error(t_spark_direct_results.operationStatus)
             if t_spark_direct_results.resultSetMetadata:
-                ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.resultSetMetadata
-                )
+                ThriftBackend._check_response_for_error(t_spark_direct_results.resultSetMetadata)
             if t_spark_direct_results.resultSet:
-                ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.resultSet
-                )
+                ThriftBackend._check_response_for_error(t_spark_direct_results.resultSet)
             if t_spark_direct_results.closeOperation:
-                ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.closeOperation
-                )
+                ThriftBackend._check_response_for_error(t_spark_direct_results.closeOperation)
 
     def execute_command(
         self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
@@ -817,9 +765,7 @@ def execute_command(
             sessionHandle=session_handle,
             statement=operation,
             runAsync=True,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
             canReadArrowResult=True,
             canDecompressLZ4Result=lz4_compression,
             canDownloadResult=False,
@@ -837,9 +783,7 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
 
         req = ttypes.TGetCatalogsReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
         )
         resp = self.make_request(self._client.GetCatalogs, req)
         return self._handle_execute_response(resp, cursor)
@@ -857,9 +801,7 @@ def get_schemas(
 
         req = ttypes.TGetSchemasReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
             catalogName=catalog_name,
             schemaName=schema_name,
         )
@@ -881,9 +823,7 @@ def get_tables(
 
         req = ttypes.TGetTablesReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
             catalogName=catalog_name,
             schemaName=schema_name,
             tableName=table_name,
@@ -907,9 +847,7 @@ def get_columns(
 
         req = ttypes.TGetColumnsReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
             catalogName=catalog_name,
             schemaName=schema_name,
             tableName=table_name,
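
The "HACK!" pieces above combine a module-level environment read (THRIFT_SOCKET_TIMEOUT) with an `or` fallback in __init__ and a seconds-to-milliseconds conversion for Thrift's setTimeout. A standalone sketch of that resolution logic under the same names (resolve_timeout_ms itself is illustrative, not from the repo):

    import os

    # Read once at import time, as in the diff; the value is a string or None.
    THRIFT_SOCKET_TIMEOUT = os.getenv("THRIFT_SOCKET_TIMEOUT", None)

    def resolve_timeout_ms(kwargs):
        # Caller-supplied _socket_timeout wins; otherwise fall back to the env var.
        # Caveat of `or`: an explicit 0 (or empty string) also falls through.
        timeout = kwargs.get("_socket_timeout") or THRIFT_SOCKET_TIMEOUT
        # setTimeout expects milliseconds; None passes through to mean "no timeout".
        return timeout and (float(timeout) * 1000.0)

    # Assuming THRIFT_SOCKET_TIMEOUT is unset in the environment:
    assert resolve_timeout_ms({}) is None
    assert resolve_timeout_ms({"_socket_timeout": 2}) == 2000.0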
