Add timeout hack to mitigate timeouts
capitancambio committed Jun 2, 2023
1 parent 6f83144 commit 8347257
Showing 2 changed files with 48 additions and 109 deletions.
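
In substance, the change lets a THRIFT_SOCKET_TIMEOUT environment variable supply a fallback socket timeout (in seconds) whenever no per-connection _socket_timeout kwarg is given; the rest of the diff is import reordering and line reflow. A minimal sketch of the pattern introduced below (resolve_timeout_ms is a hypothetical name for illustration, not part of the commit):

    import os

    # Read once at import time; None means "no override configured".
    THRIFT_SOCKET_TIMEOUT = os.getenv("THRIFT_SOCKET_TIMEOUT", None)

    def resolve_timeout_ms(kwargs):
        # A per-connection kwarg wins; otherwise fall back to the env var.
        timeout = kwargs.get("_socket_timeout") or THRIFT_SOCKET_TIMEOUT
        # Thrift's setTimeout expects milliseconds; None leaves timeouts disabled.
        return timeout and (float(timeout) * 1000.0)

    # With THRIFT_SOCKET_TIMEOUT=900 exported: resolve_timeout_ms({}) == 900000.0;
    # resolve_timeout_ms({"_socket_timeout": 5}) == 5000.0 either way.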
1 change: 1 addition & 0 deletions src/databricks/sql/auth/thrift_http_client.py
@@ -3,6 +3,7 @@
 
 
 import thrift
+import thrift.transport.THttpClient
 
 import urllib.parse, six, base64
 
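The explicit submodule import is presumably needed because `import thrift` alone does not bind the transport submodule; attribute access on a submodule that was never imported fails. A quick illustration of the Python behavior involved (not from the commit):

    import thrift
    import thrift.transport.THttpClient  # without this, the lookup below can raise AttributeError

    print(thrift.transport.THttpClient.THttpClient)  # the full attribute path now resolves
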
156 changes: 47 additions & 109 deletions src/databricks/sql/thrift_backend.py
@@ -1,34 +1,27 @@
-from decimal import Decimal
 import errno
 import logging
 import math
-import time
+import os
 import threading
-import lz4.frame
+import time
+from decimal import Decimal
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
 from typing import List, Union
 
+import databricks.sql.auth.thrift_http_client
+import lz4.frame
 import pyarrow
-import thrift.transport.THttpClient
 import thrift.protocol.TBinaryProtocol
+import thrift.transport.THttpClient
 import thrift.transport.TSocket
 import thrift.transport.TTransport
 
-import databricks.sql.auth.thrift_http_client
+from databricks.sql import *
 from databricks.sql.auth.authenticators import AuthProvider
 from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes
-from databricks.sql import *
-from databricks.sql.thrift_api.TCLIService.TCLIService import (
-    Client as TCLIServiceClient,
-)
-
-from databricks.sql.utils import (
-    ArrowQueue,
-    ExecuteResponse,
-    _bound,
-    RequestErrorInfo,
-    NoRetryReason,
-)
+from databricks.sql.thrift_api.TCLIService.TCLIService import \
+    Client as TCLIServiceClient
+from databricks.sql.utils import (ArrowQueue, ExecuteResponse, NoRetryReason,
+                                  RequestErrorInfo, _bound)
 
 logger = logging.getLogger(__name__)
 
@@ -38,6 +31,9 @@
 
 TIMESTAMP_AS_STRING_CONFIG = "spark.thriftserver.arrowBasedRowSet.timestampAsString"
 
+# HACK!
+THRIFT_SOCKET_TIMEOUT = os.getenv("THRIFT_SOCKET_TIMEOUT", None)
+
 # see Connection.__init__ for parameter descriptions.
 # - Min/Max avoids unsustainable configs (sane values are far more constrained)
 # - 900s attempts-duration lines up w ODBC/JDBC drivers (for cluster startup > 10 mins)
@@ -114,13 +110,9 @@ def __init__(
 
         self.staging_allowed_local_path = staging_allowed_local_path
         self._initialize_retry_args(kwargs)
-        self._use_arrow_native_complex_types = kwargs.get(
-            "_use_arrow_native_complex_types", True
-        )
+        self._use_arrow_native_complex_types = kwargs.get("_use_arrow_native_complex_types", True)
         self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True)
-        self._use_arrow_native_timestamps = kwargs.get(
-            "_use_arrow_native_timestamps", True
-        )
+        self._use_arrow_native_timestamps = kwargs.get("_use_arrow_native_timestamps", True)
 
         # Configure tls context
         ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
@@ -152,7 +144,10 @@ def __init__(
                 ssl_context=ssl_context,
             )
 
-        timeout = kwargs.get("_socket_timeout")
+        # HACK!
+        timeout = kwargs.get("_socket_timeout") or THRIFT_SOCKET_TIMEOUT
+        logger.info(f"Setting timeout HACK! to {timeout}")
+
         # setTimeout defaults to None (i.e. no timeout), and is expected in ms
         self._transport.setTimeout(timeout and (float(timeout) * 1000.0))
 
@@ -175,15 +170,11 @@ def _initialize_retry_args(self, kwargs):
             given_or_default = type_(kwargs.get(key, default))
             bound = _bound(min, max, given_or_default)
             setattr(self, key, bound)
-            logger.debug(
-                "retry parameter: {} given_or_default {}".format(key, given_or_default)
-            )
+            logger.debug("retry parameter: {} given_or_default {}".format(key, given_or_default))
             if bound != given_or_default:
                 logger.warning(
                     "Override out of policy retry parameter: "
-                    + "{} given {}, restricted to {}".format(
-                        key, given_or_default, bound
-                    )
+                    + "{} given {}, restricted to {}".format(key, given_or_default, bound)
                 )
 
         # Fail on retry delay min > max; consider later adding fail on min > duration?
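For context, _bound comes from databricks.sql.utils and clamps each user-supplied retry parameter into the policy range before the warning above fires. A sketch of the assumed behavior, not the actual implementation:

    def bound(minimum, maximum, x):
        # Clamp x into [minimum, maximum].
        return max(minimum, min(maximum, x))

    assert bound(1, 60, 900) == 60  # out-of-policy value gets restricted (and warned about)
    assert bound(1, 60, 30) == 30   # in-policy value passes through unchanged
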
@@ -211,9 +202,7 @@ def _extract_error_message_from_headers(headers):
         if THRIFT_ERROR_MESSAGE_HEADER in headers:
             err_msg = headers[THRIFT_ERROR_MESSAGE_HEADER]
         if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers:
-            if (
-                err_msg
-            ):  # We don't expect both to be set, but log both here just in case
+            if err_msg:  # We don't expect both to be set, but log both here just in case
                 err_msg = "Thriftserver error: {}, Databricks error: {}".format(
                     err_msg, headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER]
                 )
@@ -406,10 +395,7 @@ def _check_initial_namespace(self, catalog, schema, response):
         if not (catalog or schema):
             return
 
-        if (
-            response.serverProtocolVersion
-            < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4
-        ):
+        if response.serverProtocolVersion < ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4:
             raise InvalidServerResponseError(
                 "Setting initial namespace not supported by the DBR version, "
                 "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0."
@@ -424,10 +410,7 @@ def _check_initial_namespace(self, catalog, schema, response):
 
     def _check_session_configuration(self, session_configuration):
         # This client expects timetampsAsString to be false, so we do not allow users to modify that
-        if (
-            session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower()
-            != "false"
-        ):
+        if session_configuration.get(TIMESTAMP_AS_STRING_CONFIG, "false").lower() != "false":
             raise Error(
                 "Invalid session configuration: {} cannot be changed "
                 "while using the Databricks SQL connector, it must be false not {}".format(
@@ -439,18 +422,14 @@ def _check_session_configuration(self, session_configuration):
     def open_session(self, session_configuration, catalog, schema):
         try:
             self._transport.open()
-            session_configuration = {
-                k: str(v) for (k, v) in (session_configuration or {}).items()
-            }
+            session_configuration = {k: str(v) for (k, v) in (session_configuration or {}).items()}
             self._check_session_configuration(session_configuration)
             # We want to receive proper Timestamp arrow types.
             # We set it also in confOverlay in TExecuteStatementReq on a per query basic,
             # but it doesn't hurt to also set for the whole session.
             session_configuration[TIMESTAMP_AS_STRING_CONFIG] = "false"
             if catalog or schema:
-                initial_namespace = ttypes.TNamespace(
-                    catalogName=catalog, schemaName=schema
-                )
+                initial_namespace = ttypes.TNamespace(catalogName=catalog, schemaName=schema)
             else:
                 initial_namespace = None
 
@@ -476,9 +455,7 @@ def close_session(self, session_handle) -> None:
         finally:
             self._transport.close()
 
-    def _check_command_not_in_error_or_closed_state(
-        self, op_handle, get_operations_resp
-    ):
+    def _check_command_not_in_error_or_closed_state(self, op_handle, get_operations_resp):
         if get_operations_resp.operationState == ttypes.TOperationState.ERROR_STATE:
             if get_operations_resp.displayMessage:
                 raise ServerOperationError(
@@ -513,17 +490,11 @@ def _poll_for_status(self, op_handle):
 
     def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
         if t_row_set.columns is not None:
-            (
-                arrow_table,
-                num_rows,
-            ) = ThriftBackend._convert_column_based_set_to_arrow_table(
+            (arrow_table, num_rows,) = ThriftBackend._convert_column_based_set_to_arrow_table(
                 t_row_set.columns, description
             )
         elif t_row_set.arrowBatches is not None:
-            (
-                arrow_table,
-                num_rows,
-            ) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
+            (arrow_table, num_rows,) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
@@ -534,9 +505,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
     def _convert_decimals_in_arrow_table(table, description):
         for (i, col) in enumerate(table.itercolumns()):
             if description[i][1] == "decimal":
-                decimal_col = col.to_pandas().apply(
-                    lambda v: v if v is None else Decimal(v)
-                )
+                decimal_col = col.to_pandas().apply(lambda v: v if v is None else Decimal(v))
                 precision, scale = description[i][4], description[i][5]
                 assert scale is not None
                 assert precision is not None
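Background for this hunk: decimal columns arrive as strings in the Arrow results, so the method rebuilds a true decimal column using the precision and scale recorded in the column description. A self-contained sketch of the same steps with illustrative values (pandas is required for to_pandas(); the real method also swaps the rebuilt column back into the table):

    from decimal import Decimal
    import pyarrow

    col = pyarrow.array(["1.10", "2.25", None])  # decimals delivered as strings
    decimal_col = col.to_pandas().apply(lambda v: v if v is None else Decimal(v))
    dtype = pyarrow.decimal128(precision=3, scale=2)  # wide enough for any Spark decimal
    print(pyarrow.array(decimal_col, type=dtype))
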
@@ -549,9 +518,7 @@ def _convert_decimals_in_arrow_table(table, description):
         return table
 
     @staticmethod
-    def _convert_arrow_based_set_to_arrow_table(
-        arrow_batches, lz4_compressed, schema_bytes
-    ):
+    def _convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes):
         ba = bytearray()
         ba += schema_bytes
         n_rows = 0
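This method assembles a single Arrow IPC stream from the separately delivered schema bytes and record-batch bytes (decompressing each batch first when LZ4 compression is on). A minimal round trip showing why plain byte concatenation works (an illustration, not the commit's code):

    import pyarrow

    batch = pyarrow.RecordBatch.from_arrays([pyarrow.array([1, 2, 3])], names=["x"])
    ba = bytearray()
    ba += batch.schema.serialize().to_pybytes()  # standalone IPC schema message
    ba += batch.serialize().to_pybytes()         # batch message, schema not included
    table = pyarrow.ipc.open_stream(bytes(ba)).read_all()
    print(table.num_rows)  # 3
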
@@ -597,9 +564,7 @@ def _convert_column_to_arrow_array(t_col):
         for field in field_name_to_arrow_type.keys():
             wrapper = getattr(t_col, field)
             if wrapper:
-                return ThriftBackend._create_arrow_array(
-                    wrapper, field_name_to_arrow_type[field]
-                )
+                return ThriftBackend._create_arrow_array(wrapper, field_name_to_arrow_type[field])
 
         raise OperationalError("Empty TColumn instance {}".format(t_col))
 
@@ -654,9 +619,7 @@ def map_type(t_type_entry):
             else:
                 # Current thriftserver implementation should always return a primitiveEntry,
                 # even for complex types
-                raise OperationalError(
-                    "Thrift protocol error: t_type_entry not a primitiveEntry"
-                )
+                raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
 
         def convert_col(t_column_desc):
             return pyarrow.field(
@@ -674,9 +637,7 @@ def _col_to_description(col):
             # Drop _TYPE suffix
             cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower()
         else:
-            raise OperationalError(
-                "Thrift protocol error: t_type_entry not a primitiveEntry"
-            )
+            raise OperationalError("Thrift protocol error: t_type_entry not a primitiveEntry")
 
         if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
             qualifiers = type_entry.primitiveEntry.typeQualifiers.qualifiers
@@ -697,9 +658,7 @@ def _col_to_description(col):
 
     @staticmethod
     def _hive_schema_to_description(t_table_schema):
-        return [
-            ThriftBackend._col_to_description(col) for col in t_table_schema.columns
-        ]
+        return [ThriftBackend._col_to_description(col) for col in t_table_schema.columns]
 
     def _results_message_to_execute_response(self, resp, operation_state):
         if resp.directResults and resp.directResults.resultSetMetadata:
@@ -726,9 +685,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
             or (not direct_results.resultSet)
             or direct_results.resultSet.hasMoreRows
         )
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
-        )
+        description = self._hive_schema_to_description(t_result_set_metadata_resp.schema)
         schema_bytes = (
             t_result_set_metadata_resp.arrowSchema
             or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
@@ -768,8 +725,7 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
             op_handle, initial_operation_status_resp
         )
         operation_state = (
-            initial_operation_status_resp
-            and initial_operation_status_resp.operationState
+            initial_operation_status_resp and initial_operation_status_resp.operationState
         )
         while not operation_state or operation_state in [
             ttypes.TOperationState.RUNNING_STATE,
@@ -784,21 +740,13 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
     def _check_direct_results_for_error(t_spark_direct_results):
         if t_spark_direct_results:
             if t_spark_direct_results.operationStatus:
-                ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.operationStatus
-                )
+                ThriftBackend._check_response_for_error(t_spark_direct_results.operationStatus)
             if t_spark_direct_results.resultSetMetadata:
-                ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.resultSetMetadata
-                )
+                ThriftBackend._check_response_for_error(t_spark_direct_results.resultSetMetadata)
             if t_spark_direct_results.resultSet:
-                ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.resultSet
-                )
+                ThriftBackend._check_response_for_error(t_spark_direct_results.resultSet)
             if t_spark_direct_results.closeOperation:
-                ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.closeOperation
-                )
+                ThriftBackend._check_response_for_error(t_spark_direct_results.closeOperation)
 
     def execute_command(
         self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
@@ -817,9 +765,7 @@ def execute_command(
             sessionHandle=session_handle,
             statement=operation,
             runAsync=True,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
             canReadArrowResult=True,
             canDecompressLZ4Result=lz4_compression,
             canDownloadResult=False,
@@ -837,9 +783,7 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
 
         req = ttypes.TGetCatalogsReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
         )
         resp = self.make_request(self._client.GetCatalogs, req)
         return self._handle_execute_response(resp, cursor)
@@ -857,9 +801,7 @@ def get_schemas(
 
         req = ttypes.TGetSchemasReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
             catalogName=catalog_name,
             schemaName=schema_name,
         )
@@ -881,9 +823,7 @@ def get_tables(
 
         req = ttypes.TGetTablesReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
             catalogName=catalog_name,
             schemaName=schema_name,
             tableName=table_name,
@@ -907,9 +847,7 @@ def get_columns(
 
         req = ttypes.TGetColumnsReq(
             sessionHandle=session_handle,
-            getDirectResults=ttypes.TSparkGetDirectResults(
-                maxRows=max_rows, maxBytes=max_bytes
-            ),
+            getDirectResults=ttypes.TSparkGetDirectResults(maxRows=max_rows, maxBytes=max_bytes),
             catalogName=catalog_name,
             schemaName=schema_name,
             tableName=table_name,
