1
- from decimal import Decimal
2
1
import errno
3
2
import logging
4
3
import math
5
- import time
4
+ import os
6
5
import threading
7
- import lz4 .frame
6
+ import time
7
+ from decimal import Decimal
8
8
from ssl import CERT_NONE , CERT_REQUIRED , create_default_context
9
9
from typing import List , Union
10
10
11
+ import databricks .sql .auth .thrift_http_client
12
+ import lz4 .frame
11
13
import pyarrow
12
- import thrift .transport .THttpClient
13
14
import thrift .protocol .TBinaryProtocol
15
+ import thrift .transport .THttpClient
14
16
import thrift .transport .TSocket
15
17
import thrift .transport .TTransport
16
-
17
- import databricks .sql .auth .thrift_http_client
18
+ from databricks .sql import *
18
19
from databricks .sql .auth .authenticators import AuthProvider
19
20
from databricks .sql .thrift_api .TCLIService import TCLIService , ttypes
20
- from databricks .sql import *
21
- from databricks .sql .thrift_api .TCLIService .TCLIService import (
22
- Client as TCLIServiceClient ,
23
- )
24
-
25
- from databricks .sql .utils import (
26
- ArrowQueue ,
27
- ExecuteResponse ,
28
- _bound ,
29
- RequestErrorInfo ,
30
- NoRetryReason ,
31
- )
21
+ from databricks .sql .thrift_api .TCLIService .TCLIService import \
22
+ Client as TCLIServiceClient
23
+ from databricks .sql .utils import (ArrowQueue , ExecuteResponse , NoRetryReason ,
24
+ RequestErrorInfo , _bound )
32
25
33
26
logger = logging .getLogger (__name__ )
34
27
38
31
39
32
TIMESTAMP_AS_STRING_CONFIG = "spark.thriftserver.arrowBasedRowSet.timestampAsString"
40
33
34
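+ # HACK: optional socket timeout override (in seconds), read from the THRIFT_SOCKET_TIMEOUT environment variable.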
35
+ THRIFT_SOCKET_TIMEOUT = os .getenv ("THRIFT_SOCKET_TIMEOUT" , None )
36
+
41
37
# see Connection.__init__ for parameter descriptions.
42
38
# - Min/Max avoids unsustainable configs (sane values are far more constrained)
43
39
# - 900s attempts-duration lines up with ODBC/JDBC drivers (for cluster startup > 10 mins)
@@ -114,13 +110,9 @@ def __init__(
114
110
115
111
self .staging_allowed_local_path = staging_allowed_local_path
116
112
self ._initialize_retry_args (kwargs )
117
- self ._use_arrow_native_complex_types = kwargs .get (
118
- "_use_arrow_native_complex_types" , True
119
- )
113
+ self ._use_arrow_native_complex_types = kwargs .get ("_use_arrow_native_complex_types" , True )
120
114
self ._use_arrow_native_decimals = kwargs .get ("_use_arrow_native_decimals" , True )
121
- self ._use_arrow_native_timestamps = kwargs .get (
122
- "_use_arrow_native_timestamps" , True
123
- )
115
+ self ._use_arrow_native_timestamps = kwargs .get ("_use_arrow_native_timestamps" , True )
124
116
125
117
# Configure tls context
126
118
ssl_context = create_default_context (cafile = kwargs .get ("_tls_trusted_ca_file" ))
@@ -152,7 +144,10 @@ def __init__(
152
144
ssl_context = ssl_context ,
153
145
)
154
146
155
- timeout = kwargs .get ("_socket_timeout" )
147
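+ # HACK: fall back to the THRIFT_SOCKET_TIMEOUT environment variable when the _socket_timeout kwarg is missing or falsy.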
148
+ timeout = kwargs .get ("_socket_timeout" ) or THRIFT_SOCKET_TIMEOUT
149
+ logger .info (f"Setting timeout HACK! to { timeout } " )
150
+
156
151
# setTimeout defaults to None (i.e. no timeout), and is expected in ms
157
152
self ._transport .setTimeout (timeout and (float (timeout ) * 1000.0 ))
158
153
@@ -175,15 +170,11 @@ def _initialize_retry_args(self, kwargs):
175
170
given_or_default = type_ (kwargs .get (key , default ))
176
171
bound = _bound (min , max , given_or_default )
177
172
setattr (self , key , bound )
178
- logger .debug (
179
- "retry parameter: {} given_or_default {}" .format (key , given_or_default )
180
- )
173
+ logger .debug ("retry parameter: {} given_or_default {}" .format (key , given_or_default ))
181
174
if bound != given_or_default :
182
175
logger .warning (
183
176
"Override out of policy retry parameter: "
184
- + "{} given {}, restricted to {}" .format (
185
- key , given_or_default , bound
186
- )
177
+ + "{} given {}, restricted to {}" .format (key , given_or_default , bound )
187
178
)
188
179
189
180
# Fail on retry delay min > max; consider later adding fail on min > duration?
@@ -211,9 +202,7 @@ def _extract_error_message_from_headers(headers):
211
202
if THRIFT_ERROR_MESSAGE_HEADER in headers :
212
203
err_msg = headers [THRIFT_ERROR_MESSAGE_HEADER ]
213
204
if DATABRICKS_ERROR_OR_REDIRECT_HEADER in headers :
214
- if (
215
- err_msg
216
- ): # We don't expect both to be set, but log both here just in case
205
+ if err_msg : # We don't expect both to be set, but log both here just in case
217
206
err_msg = "Thriftserver error: {}, Databricks error: {}" .format (
218
207
err_msg , headers [DATABRICKS_ERROR_OR_REDIRECT_HEADER ]
219
208
)
@@ -406,10 +395,7 @@ def _check_initial_namespace(self, catalog, schema, response):
406
395
if not (catalog or schema ):
407
396
return
408
397
409
- if (
410
- response .serverProtocolVersion
411
- < ttypes .TProtocolVersion .SPARK_CLI_SERVICE_PROTOCOL_V4
412
- ):
398
+ if response .serverProtocolVersion < ttypes .TProtocolVersion .SPARK_CLI_SERVICE_PROTOCOL_V4 :
413
399
raise InvalidServerResponseError (
414
400
"Setting initial namespace not supported by the DBR version, "
415
401
"Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0."
@@ -424,10 +410,7 @@ def _check_initial_namespace(self, catalog, schema, response):
424
410
425
411
def _check_session_configuration (self , session_configuration ):
426
412
# This client expects timestampAsString to be false, so we do not allow users to modify it
427
- if (
428
- session_configuration .get (TIMESTAMP_AS_STRING_CONFIG , "false" ).lower ()
429
- != "false"
430
- ):
413
+ if session_configuration .get (TIMESTAMP_AS_STRING_CONFIG , "false" ).lower () != "false" :
431
414
raise Error (
432
415
"Invalid session configuration: {} cannot be changed "
433
416
"while using the Databricks SQL connector, it must be false not {}" .format (
@@ -439,18 +422,14 @@ def _check_session_configuration(self, session_configuration):
439
422
def open_session (self , session_configuration , catalog , schema ):
440
423
try :
441
424
self ._transport .open ()
442
- session_configuration = {
443
- k : str (v ) for (k , v ) in (session_configuration or {}).items ()
444
- }
425
+ session_configuration = {k : str (v ) for (k , v ) in (session_configuration or {}).items ()}
445
426
self ._check_session_configuration (session_configuration )
446
427
# We want to receive proper Timestamp arrow types.
447
428
# We also set it in confOverlay in TExecuteStatementReq on a per-query basis,
448
429
# but it doesn't hurt to also set for the whole session.
449
430
session_configuration [TIMESTAMP_AS_STRING_CONFIG ] = "false"
450
431
if catalog or schema :
451
- initial_namespace = ttypes .TNamespace (
452
- catalogName = catalog , schemaName = schema
453
- )
432
+ initial_namespace = ttypes .TNamespace (catalogName = catalog , schemaName = schema )
454
433
else :
455
434
initial_namespace = None
456
435
@@ -476,9 +455,7 @@ def close_session(self, session_handle) -> None:
476
455
finally :
477
456
self ._transport .close ()
478
457
479
- def _check_command_not_in_error_or_closed_state (
480
- self , op_handle , get_operations_resp
481
- ):
458
+ def _check_command_not_in_error_or_closed_state (self , op_handle , get_operations_resp ):
482
459
if get_operations_resp .operationState == ttypes .TOperationState .ERROR_STATE :
483
460
if get_operations_resp .displayMessage :
484
461
raise ServerOperationError (
@@ -513,17 +490,11 @@ def _poll_for_status(self, op_handle):
513
490
514
491
def _create_arrow_table (self , t_row_set , lz4_compressed , schema_bytes , description ):
515
492
if t_row_set .columns is not None :
516
- (
517
- arrow_table ,
518
- num_rows ,
519
- ) = ThriftBackend ._convert_column_based_set_to_arrow_table (
493
+ (arrow_table , num_rows ,) = ThriftBackend ._convert_column_based_set_to_arrow_table (
520
494
t_row_set .columns , description
521
495
)
522
496
elif t_row_set .arrowBatches is not None :
523
- (
524
- arrow_table ,
525
- num_rows ,
526
- ) = ThriftBackend ._convert_arrow_based_set_to_arrow_table (
497
+ (arrow_table , num_rows ,) = ThriftBackend ._convert_arrow_based_set_to_arrow_table (
527
498
t_row_set .arrowBatches , lz4_compressed , schema_bytes
528
499
)
529
500
else :
@@ -534,9 +505,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
534
505
def _convert_decimals_in_arrow_table (table , description ):
535
506
for (i , col ) in enumerate (table .itercolumns ()):
536
507
if description [i ][1 ] == "decimal" :
537
- decimal_col = col .to_pandas ().apply (
538
- lambda v : v if v is None else Decimal (v )
539
- )
508
+ decimal_col = col .to_pandas ().apply (lambda v : v if v is None else Decimal (v ))
540
509
precision , scale = description [i ][4 ], description [i ][5 ]
541
510
assert scale is not None
542
511
assert precision is not None
@@ -549,9 +518,7 @@ def _convert_decimals_in_arrow_table(table, description):
549
518
return table
550
519
551
520
@staticmethod
552
- def _convert_arrow_based_set_to_arrow_table (
553
- arrow_batches , lz4_compressed , schema_bytes
554
- ):
521
+ def _convert_arrow_based_set_to_arrow_table (arrow_batches , lz4_compressed , schema_bytes ):
555
522
ba = bytearray ()
556
523
ba += schema_bytes
557
524
n_rows = 0
@@ -597,9 +564,7 @@ def _convert_column_to_arrow_array(t_col):
597
564
for field in field_name_to_arrow_type .keys ():
598
565
wrapper = getattr (t_col , field )
599
566
if wrapper :
600
- return ThriftBackend ._create_arrow_array (
601
- wrapper , field_name_to_arrow_type [field ]
602
- )
567
+ return ThriftBackend ._create_arrow_array (wrapper , field_name_to_arrow_type [field ])
603
568
604
569
raise OperationalError ("Empty TColumn instance {}" .format (t_col ))
605
570
@@ -654,9 +619,7 @@ def map_type(t_type_entry):
654
619
else :
655
620
# Current thriftserver implementation should always return a primitiveEntry,
656
621
# even for complex types
657
- raise OperationalError (
658
- "Thrift protocol error: t_type_entry not a primitiveEntry"
659
- )
622
+ raise OperationalError ("Thrift protocol error: t_type_entry not a primitiveEntry" )
660
623
661
624
def convert_col (t_column_desc ):
662
625
return pyarrow .field (
@@ -674,9 +637,7 @@ def _col_to_description(col):
674
637
# Drop _TYPE suffix
675
638
cleaned_type = (name [:- 5 ] if name .endswith ("_TYPE" ) else name ).lower ()
676
639
else :
677
- raise OperationalError (
678
- "Thrift protocol error: t_type_entry not a primitiveEntry"
679
- )
640
+ raise OperationalError ("Thrift protocol error: t_type_entry not a primitiveEntry" )
680
641
681
642
if type_entry .primitiveEntry .type == ttypes .TTypeId .DECIMAL_TYPE :
682
643
qualifiers = type_entry .primitiveEntry .typeQualifiers .qualifiers
@@ -697,9 +658,7 @@ def _col_to_description(col):
697
658
698
659
@staticmethod
699
660
def _hive_schema_to_description (t_table_schema ):
700
- return [
701
- ThriftBackend ._col_to_description (col ) for col in t_table_schema .columns
702
- ]
661
+ return [ThriftBackend ._col_to_description (col ) for col in t_table_schema .columns ]
703
662
704
663
def _results_message_to_execute_response (self , resp , operation_state ):
705
664
if resp .directResults and resp .directResults .resultSetMetadata :
@@ -726,9 +685,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
726
685
or (not direct_results .resultSet )
727
686
or direct_results .resultSet .hasMoreRows
728
687
)
729
- description = self ._hive_schema_to_description (
730
- t_result_set_metadata_resp .schema
731
- )
688
+ description = self ._hive_schema_to_description (t_result_set_metadata_resp .schema )
732
689
schema_bytes = (
733
690
t_result_set_metadata_resp .arrowSchema
734
691
or self ._hive_schema_to_arrow_schema (t_result_set_metadata_resp .schema )
@@ -768,8 +725,7 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
768
725
op_handle , initial_operation_status_resp
769
726
)
770
727
operation_state = (
771
- initial_operation_status_resp
772
- and initial_operation_status_resp .operationState
728
+ initial_operation_status_resp and initial_operation_status_resp .operationState
773
729
)
774
730
while not operation_state or operation_state in [
775
731
ttypes .TOperationState .RUNNING_STATE ,
@@ -784,21 +740,13 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
784
740
def _check_direct_results_for_error (t_spark_direct_results ):
785
741
if t_spark_direct_results :
786
742
if t_spark_direct_results .operationStatus :
787
- ThriftBackend ._check_response_for_error (
788
- t_spark_direct_results .operationStatus
789
- )
743
+ ThriftBackend ._check_response_for_error (t_spark_direct_results .operationStatus )
790
744
if t_spark_direct_results .resultSetMetadata :
791
- ThriftBackend ._check_response_for_error (
792
- t_spark_direct_results .resultSetMetadata
793
- )
745
+ ThriftBackend ._check_response_for_error (t_spark_direct_results .resultSetMetadata )
794
746
if t_spark_direct_results .resultSet :
795
- ThriftBackend ._check_response_for_error (
796
- t_spark_direct_results .resultSet
797
- )
747
+ ThriftBackend ._check_response_for_error (t_spark_direct_results .resultSet )
798
748
if t_spark_direct_results .closeOperation :
799
- ThriftBackend ._check_response_for_error (
800
- t_spark_direct_results .closeOperation
801
- )
749
+ ThriftBackend ._check_response_for_error (t_spark_direct_results .closeOperation )
802
750
803
751
def execute_command (
804
752
self , operation , session_handle , max_rows , max_bytes , lz4_compression , cursor
@@ -817,9 +765,7 @@ def execute_command(
817
765
sessionHandle = session_handle ,
818
766
statement = operation ,
819
767
runAsync = True ,
820
- getDirectResults = ttypes .TSparkGetDirectResults (
821
- maxRows = max_rows , maxBytes = max_bytes
822
- ),
768
+ getDirectResults = ttypes .TSparkGetDirectResults (maxRows = max_rows , maxBytes = max_bytes ),
823
769
canReadArrowResult = True ,
824
770
canDecompressLZ4Result = lz4_compression ,
825
771
canDownloadResult = False ,
@@ -837,9 +783,7 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
837
783
838
784
req = ttypes .TGetCatalogsReq (
839
785
sessionHandle = session_handle ,
840
- getDirectResults = ttypes .TSparkGetDirectResults (
841
- maxRows = max_rows , maxBytes = max_bytes
842
- ),
786
+ getDirectResults = ttypes .TSparkGetDirectResults (maxRows = max_rows , maxBytes = max_bytes ),
843
787
)
844
788
resp = self .make_request (self ._client .GetCatalogs , req )
845
789
return self ._handle_execute_response (resp , cursor )
@@ -857,9 +801,7 @@ def get_schemas(
857
801
858
802
req = ttypes .TGetSchemasReq (
859
803
sessionHandle = session_handle ,
860
- getDirectResults = ttypes .TSparkGetDirectResults (
861
- maxRows = max_rows , maxBytes = max_bytes
862
- ),
804
+ getDirectResults = ttypes .TSparkGetDirectResults (maxRows = max_rows , maxBytes = max_bytes ),
863
805
catalogName = catalog_name ,
864
806
schemaName = schema_name ,
865
807
)
@@ -881,9 +823,7 @@ def get_tables(
881
823
882
824
req = ttypes .TGetTablesReq (
883
825
sessionHandle = session_handle ,
884
- getDirectResults = ttypes .TSparkGetDirectResults (
885
- maxRows = max_rows , maxBytes = max_bytes
886
- ),
826
+ getDirectResults = ttypes .TSparkGetDirectResults (maxRows = max_rows , maxBytes = max_bytes ),
887
827
catalogName = catalog_name ,
888
828
schemaName = schema_name ,
889
829
tableName = table_name ,
@@ -907,9 +847,7 @@ def get_columns(
907
847
908
848
req = ttypes .TGetColumnsReq (
909
849
sessionHandle = session_handle ,
910
- getDirectResults = ttypes .TSparkGetDirectResults (
911
- maxRows = max_rows , maxBytes = max_bytes
912
- ),
850
+ getDirectResults = ttypes .TSparkGetDirectResults (maxRows = max_rows , maxBytes = max_bytes ),
913
851
catalogName = catalog_name ,
914
852
schemaName = schema_name ,
915
853
tableName = table_name ,
0 commit comments