backport to version 2
andrefurlan-db committed Feb 16, 2024
1 parent a2f5939 commit 9cf92be
Showing 8 changed files with 73 additions and 56 deletions.
33 changes: 33 additions & 0 deletions examples/custom_logger.py
@@ -0,0 +1,33 @@
from databricks import sql
import os
import logging


logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("pysqllogs.log")
fh.setFormatter(logging.Formatter("%(asctime)s %(process)d %(thread)d %(message)s"))
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
use_cloud_fetch=True,
max_download_threads = 2
) as connection:

with connection.cursor(arraysize=1000, buffer_size_bytes=54857600) as cursor:
print(
"executing query: SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2"
)
cursor.execute("SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2")
try:
while True:
row = cursor.fetchone()
if row is None:
break
print(f"row: {row}")
except sql.exc.ResultSetDownloadError as e:
print(f"error: {e}")
31 changes: 7 additions & 24 deletions examples/query_execute.py
@@ -1,30 +1,13 @@
import threading
from databricks import sql
import os
import logging


logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler('pysqllogs.log')
fh.setFormatter(logging.Formatter("%(asctime)s %(process)d %(thread)d %(message)s"))
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
access_token = os.getenv("DATABRICKS_TOKEN"),
# max_download_threads = 2
) as connection:
access_token = os.getenv("DATABRICKS_TOKEN")) as connection:

with connection.cursor() as cursor:
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
result = cursor.fetchall()

with connection.cursor(
# arraysize=100
) as cursor:
# cursor.execute("SELECT * FROM range(0, 10000000) AS t1 LEFT JOIN (SELECT 1) AS t2")
cursor.execute("SELECT * FROM andre.plotly_iot_dashboard.bronze_sensors limit 1000001")
try:
result = cursor.fetchall()
print(f"result length: {len(result)}")
except sql.exc.ResultSetDownloadError as e:
print(f"error: {e}")
# buffer_size_bytes=4857600
for row in result:
print(row)
5 changes: 3 additions & 2 deletions src/databricks/sql/cloudfetch/download_manager.py
@@ -100,7 +100,9 @@ def get_next_downloaded_file(
return result
# Download was not successful for next download item. Fail
self._shutdown_manager()
raise ResultSetDownloadError(f"Download failed for result set starting at {next_row_offset}")
raise ResultSetDownloadError(
f"Download failed for result set starting at {next_row_offset}"
)

def _remove_past_handlers(self, next_row_offset: int):
# Any link in which its start to end range doesn't include the next row to be fetched does not need downloading
@@ -134,7 +136,6 @@ def _find_next_file_index(self, next_row_offset: int):
]
return next_indices[0] if len(next_indices) > 0 else None


def _shutdown_manager(self):
# Clear download handlers and shutdown the thread pool
self.download_handlers = []
15 changes: 6 additions & 9 deletions src/databricks/sql/cloudfetch/downloader.py
@@ -9,7 +9,6 @@
from threading import get_ident
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logging.basicConfig(format="%(asctime)s %(message)s")
logger = logging.getLogger(__name__)

DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))
@@ -93,7 +92,6 @@ def run(self):
file, and signals to waiting threads that the download is finished and whether it was successful.
"""
self._reset()


try:
# Check if link is already expired or is expiring
@@ -109,7 +107,7 @@

# Get the file via HTTP request
response = http_get_with_retry(
url=self.result_link.fileLink,
url=self.result_link.fileLink,
max_retries=self.settings.max_retries,
backoff_factor=self.settings.backoff_factor,
download_timeout=self.settings.download_timeout,
@@ -154,7 +152,6 @@ def run(self):
# Awaken threads waiting for this to be true which signals the run is complete
self.is_download_finished.set()


def _reset(self):
"""
Reset download-related flags for every retry of run()
@@ -179,9 +176,7 @@ def check_link_expired(
link.expiryTime < current_time
or link.expiryTime - current_time < expiry_buffer_secs
):
logger.debug(
f"{(os.getpid(), get_ident())} - link expired"
)
logger.debug("link expired")
return True
return False

@@ -229,11 +224,13 @@ def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=6
finally:
session.close()
# Exponential backoff before the next attempt
wait_time = backoff_factor ** attempts
wait_time = backoff_factor**attempts
logger.info(f"retrying in {wait_time} seconds...")
time.sleep(wait_time)

attempts += 1

logger.error(f"exceeded maximum number of retries ({max_retries}) while downloading result.")
logger.error(
f"exceeded maximum number of retries ({max_retries}) while downloading result."
)
return None
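
Note: the retry loop above sleeps for backoff_factor ** attempts seconds between attempts. Assuming the attempt counter starts at 0 and the defaults in the signature (max_retries=5, backoff_factor=2), the schedule works out as in this small sketch:

# Illustrative only: reproduces the wait schedule implied by the loop above,
# assuming the attempt counter starts at 0.
backoff_factor = 2
max_retries = 5
waits = [backoff_factor**attempt for attempt in range(max_retries)]
print(waits)       # [1, 2, 4, 8, 16]
print(sum(waits))  # about 31 seconds of backoff in total before giving up
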
2 changes: 1 addition & 1 deletion src/databricks/sql/exc.py
@@ -118,4 +118,4 @@ class CursorAlreadyClosedError(RequestError):


class ResultSetDownloadError(RequestError):
"""Thrown if there was an error during the download of a result set"""
"""Thrown if there was an error during the download of a result set"""
20 changes: 9 additions & 11 deletions src/databricks/sql/thrift_backend.py
@@ -371,7 +371,9 @@ def attempt_request(attempt):

this_method_name = getattr(method, "__name__")

logger.debug("sending thrift request: {}(<REDACTED>)".format(this_method_name))
logger.debug(
"sending thrift request: {}(<REDACTED>)".format(this_method_name)
)
unsafe_logger.debug("Sending request: {}".format(request))

# These three lines are no-ops if the v3 retry policy is not in use
@@ -387,7 +389,9 @@ def attempt_request(attempt):

# We need to call type(response) here because thrift doesn't implement __name__ attributes for thrift responses
logger.debug(
"received thrift response: {}(<REDACTED>)".format(type(response).__name__)
"received thrift response: {}(<REDACTED>)".format(
type(response).__name__
)
)
unsafe_logger.debug("Received response: {}".format(response))
return response
@@ -741,9 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
if direct_results and direct_results.resultSet:
logger.debug(
f"received direct results"
)
logger.debug(f"received direct results")
assert direct_results.resultSet.results.startRowOffset == 0
assert direct_results.resultSetMetadata

@@ -756,9 +758,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
description=description,
)
else:
logger.debug(
f"must fetch results"
)
logger.debug(f"must fetch results")
arrow_queue_opt = None
return ExecuteResponse(
arrow_queue=arrow_queue_opt,
@@ -940,9 +940,7 @@ def get_columns(
return self._handle_execute_response(resp, cursor)

def _handle_execute_response(self, resp, cursor):
logger.debug(
f"got execute response"
)
logger.debug(f"got execute response")
cursor.active_op_handle = resp.operationHandle
self._check_direct_results_for_error(resp.directResults)

3 changes: 2 additions & 1 deletion src/databricks/sql/utils.py
@@ -5,6 +5,7 @@
import datetime
import decimal
from enum import Enum
import logging
import lz4.frame
from typing import Dict, List, Union, Any
import pyarrow
@@ -18,7 +19,7 @@
)

BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]

logger = logging.getLogger(__name__)

class ResultSetQueue(ABC):
@abstractmethod
20 changes: 12 additions & 8 deletions tests/unit/test_downloader.py
@@ -39,7 +39,7 @@ def test_run_link_past_expiry_buffer(self, mock_time):
@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False, status_code=500))))
@patch('time.time', return_value=1000)
def test_run_get_response_not_ok(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0)
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, max_retries = 5, backoff_factor = 2)
settings.download_timeout = 0
settings.use_proxy = False
result_link = Mock(expiryTime=1001)
@@ -56,7 +56,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session):
return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200, content=b"1234567890" * 9))))
@patch('time.time', return_value=1000)
def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=False)
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=False, max_retries = 5, backoff_factor = 2)
result_link = Mock(bytesNum=100, expiryTime=1001)
result_link.startRowOffset = 0
result_link.rowCount = 100
@@ -70,7 +70,7 @@ def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session):
@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200))))
@patch('time.time', return_value=1000)
def test_run_compressed_data_length_incorrect(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, max_retries = 5, backoff_factor = 2)
settings.is_lz4_compressed = True
result_link = Mock(bytesNum=100, expiryTime=1001)
result_link.startRowOffset = 0
@@ -88,7 +88,7 @@ def test_run_compressed_data_length_incorrect(self, mock_time, mock_session):
return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200, content=b"1234567890" * 10))))
@patch('time.time', return_value=1000)
def test_run_uncompressed_successful(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, max_retries = 5, backoff_factor = 2)
settings.is_lz4_compressed = False
result_link = Mock(bytesNum=100, expiryTime=1001)
result_link.startRowOffset = 0
@@ -100,16 +100,20 @@ def test_run_uncompressed_successful(self, mock_time, mock_session):
assert d.is_file_downloaded_successfully
assert d.is_download_finished.is_set()

@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, status_code=200))))
@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(
ok=True,
content=b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00',
status_code=200
))))
@patch('time.time', return_value=1000)
def test_run_compressed_successful(self, mock_time, mock_session):
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, max_retries = 5, backoff_factor = 2)
settings.is_lz4_compressed = True
result_link = Mock(bytesNum=100, expiryTime=1001)
result_link.startRowOffset = 0
result_link.rowCount = 100
mock_session.return_value.get.return_value.content = \
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
# mock_session.return_value.get.return_value.content = \
# b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'

d = downloader.ResultSetDownloadHandler(settings, result_link)
d.run()
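
Note: the hard-coded byte string in test_run_compressed_successful is an LZ4-framed payload, presumably the compressed form of the same 100-byte b"1234567890" * 10 fixture used by the uncompressed test; the commit does not say how it was generated. A sketch of producing an equivalent payload with the lz4 package the driver already depends on:

# Assumption: the fixture was produced roughly like this; regenerate it this way
# if the test data ever needs to change.
import lz4.frame

payload = lz4.frame.compress(b"1234567890" * 10)
assert lz4.frame.decompress(payload) == b"1234567890" * 10
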
