fixes for cloud fetch
backport to version 2

Signed-off-by: Andre Furlan <[email protected]>
rcypher-databricks authored and andrefurlan-db committed Feb 16, 2024
1 parent a737ef3 commit e3d4efe
Showing 8 changed files with 252 additions and 141 deletions.
33 changes: 33 additions & 0 deletions examples/custom_logger.py
@@ -0,0 +1,33 @@
from databricks import sql
import os
import logging


logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("pysqllogs.log")
fh.setFormatter(logging.Formatter("%(asctime)s %(process)d %(thread)d %(message)s"))
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
use_cloud_fetch=True,
    max_download_threads=2
) as connection:

with connection.cursor(arraysize=1000, buffer_size_bytes=54857600) as cursor:
print(
"executing query: SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2"
)
cursor.execute("SELECT * FROM range(0, 20000000) AS t1 LEFT JOIN (SELECT 1) AS t2")
try:
while True:
row = cursor.fetchone()
if row is None:
break
print(f"row: {row}")
except sql.exc.ResultSetDownloadError as e:
print(f"error: {e}")
40 changes: 8 additions & 32 deletions src/databricks/sql/cloudfetch/download_manager.py
@@ -8,6 +8,7 @@
ResultSetDownloadHandler,
DownloadableResultSettings,
)
from databricks.sql.exc import ResultSetDownloadError
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)
@@ -34,8 +35,6 @@ def __init__(self, max_download_threads: int, lz4_compressed: bool):
self.download_handlers: List[ResultSetDownloadHandler] = []
self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
self.fetch_need_retry = False
self.num_consecutive_result_file_download_retries = 0

def add_file_links(
self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
@@ -81,13 +80,15 @@ def get_next_downloaded_file(

# Find next file
idx = self._find_next_file_index(next_row_offset)
# No remaining download handler covers the requested row offset
if idx is None:
self._shutdown_manager()
logger.debug("could not find next file index")
return None
handler = self.download_handlers[idx]

# Check (and wait) for download status
if self._check_if_download_successful(handler):
if handler.is_file_download_successful():
# Buffer should be empty so set buffer to new ArrowQueue with result_file
result = DownloadedFile(
handler.result_file,
Expand All @@ -97,9 +98,11 @@ def get_next_downloaded_file(
self.download_handlers.pop(idx)
# Return True upon successful download to continue loop and not force a retry
return result
# Download was not successful for next download item, force a retry
# Download was not successful for next download item. Fail
self._shutdown_manager()
return None
raise ResultSetDownloadError(
f"Download failed for result set starting at {next_row_offset}"
)

def _remove_past_handlers(self, next_row_offset: int):
# Any link whose start-to-end range does not include the next row to be fetched no longer needs downloading
@@ -133,33 +136,6 @@ def _find_next_file_index(self, next_row_offset: int):
]
return next_indices[0] if len(next_indices) > 0 else None

def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
# Check (and wait until download finishes) if download was successful
if not handler.is_file_download_successful():
if handler.is_link_expired:
self.fetch_need_retry = True
return False
elif handler.is_download_timedout:
# Consecutive file retries should not exceed threshold in settings
if (
self.num_consecutive_result_file_download_retries
>= self.downloadable_result_settings.max_consecutive_file_download_retries
):
self.fetch_need_retry = True
return False
self.num_consecutive_result_file_download_retries += 1

# Re-submit handler run to thread pool and recursively check download status
self.thread_pool.submit(handler.run)
return self._check_if_download_successful(handler)
else:
self.fetch_need_retry = True
return False

self.num_consecutive_result_file_download_retries = 0
self.fetch_need_retry = False
return True

def _shutdown_manager(self):
# Clear download handlers and shutdown the thread pool
self.download_handlers = []
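With this change, get_next_downloaded_file no longer sets retry flags and returns None on a failed download; it shuts the manager down and raises ResultSetDownloadError, which callers must now be prepared to catch. An illustrative sketch of caller-side handling (manager and offset are assumed names for an already-populated download manager and a row offset, not taken from this diff):

from databricks.sql.exc import ResultSetDownloadError

try:
    downloaded_file = manager.get_next_downloaded_file(offset)
    if downloaded_file is None:
        print("no remaining link covers this row offset")
except ResultSetDownloadError as e:
    # the manager has already shut down its thread pool at this point
    print(f"cloud fetch failed: {e}")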
115 changes: 89 additions & 26 deletions src/databricks/sql/cloudfetch/downloader.py
@@ -1,15 +1,18 @@
import logging
from dataclasses import dataclass

from datetime import datetime
import requests
import lz4.frame
import threading
import time

import os
from threading import get_ident
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)

DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))


@dataclass
class DownloadableResultSettings:
@@ -20,13 +23,17 @@ class DownloadableResultSettings:
is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
download_timeout (int): Timeout for download requests. Default 60 secs.
max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
max_retries (int): Maximum number of consecutive download attempts before failing. Default 5.
backoff_factor (int): Factor to increase wait time between retries.
"""

is_lz4_compressed: bool
link_expiry_buffer_secs: int = 0
download_timeout: int = 60
max_consecutive_file_download_retries: int = 0
download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
max_retries: int = 5
backoff_factor: int = 2


class ResultSetDownloadHandler(threading.Thread):
@@ -57,16 +64,21 @@ def is_file_download_successful(self) -> bool:
else None
)
try:
logger.debug(
f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

if not self.is_download_finished.wait(timeout=timeout):
self.is_download_timedout = True
logger.debug(
"Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
self.settings.download_timeout,
self.result_link.startRowOffset,
self.result_link.startRowOffset + self.result_link.rowCount,
)
f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}"
)
return False
# in some rare cases is_download_finished is not set even though the file was downloaded successfully
return self.is_file_downloaded_successfully

logger.debug(
f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
except Exception as e:
logger.error(e)
return False
@@ -81,24 +93,36 @@ def run(self):
"""
self._reset()

# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return
try:
# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return

session = requests.Session()
session.timeout = self.settings.download_timeout
logger.debug(
f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

try:
# Get the file via HTTP request
response = session.get(self.result_link.fileLink)
response = http_get_with_retry(
url=self.result_link.fileLink,
max_retries=self.settings.max_retries,
backoff_factor=self.settings.backoff_factor,
download_timeout=self.settings.download_timeout,
)

if not response.ok:
self.is_file_downloaded_successfully = False
if not response:
logger.error(
f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
return

logger.debug(
f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

# Save (and decompress if needed) the downloaded file
compressed_data = response.content
decompressed_data = (
@@ -109,15 +133,22 @@ def run(self):
self.result_file = decompressed_data

# The size of the downloaded file should match the size specified from TSparkArrowResultLink
self.is_file_downloaded_successfully = (
len(self.result_file) == self.result_link.bytesNum
success = len(self.result_file) == self.result_link.bytesNum
logger.debug(
f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
self.is_file_downloaded_successfully = success
except Exception as e:
logger.debug(
f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
logger.error(e)
self.is_file_downloaded_successfully = False

finally:
session and session.close()
logger.debug(
f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
# Wake any threads waiting on this event, which signals that the run is complete
self.is_download_finished.set()

@@ -145,6 +176,7 @@ def check_link_expired(
link.expiryTime < current_time
or link.expiryTime - current_time < expiry_buffer_secs
):
logger.debug("link expired")
return True
return False

@@ -171,3 +203,34 @@ def decompress_data(compressed_data: bytes) -> bytes:
uncompressed_data += data
start += num_bytes
return uncompressed_data


def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
attempts = 0

while attempts < max_retries:
try:
session = requests.Session()
# requests ignores a timeout set as a Session attribute; it must be
# passed per request, so forward download_timeout to session.get
response = session.get(url, timeout=download_timeout)

# Treat only HTTP 200 as success; other status codes are logged and retried
if response.status_code == 200:
return response
else:
logger.error(response)
except requests.RequestException as e:
print(f"request failed with exception: {e}")
finally:
session.close()
# Exponential backoff before the next attempt
wait_time = backoff_factor**attempts
logger.info(f"retrying in {wait_time} seconds...")
time.sleep(wait_time)

attempts += 1

logger.error(
f"exceeded maximum number of retries ({max_retries}) while downloading result."
)
return None
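With the defaults above (max_retries=5, backoff_factor=2), http_get_with_retry sleeps backoff_factor**attempts seconds after each failed try, i.e. 1, 2, 4, 8, and 16 seconds across the five attempts. A small sketch of that schedule, assuming those defaults:

# backoff schedule implied by http_get_with_retry's defaults
max_retries, backoff_factor = 5, 2
schedule = [backoff_factor**attempt for attempt in range(max_retries)]
print(schedule)  # [1, 2, 4, 8, 16]

The per-request timeout now also defaults to the DATABRICKS_CLOUD_FILE_TIMEOUT environment variable (60 seconds when unset), so slow links can be accommodated without a code change.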
4 changes: 4 additions & 0 deletions src/databricks/sql/exc.py
@@ -115,3 +115,7 @@ class SessionAlreadyClosedError(RequestError):

class CursorAlreadyClosedError(RequestError):
"""Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""


class ResultSetDownloadError(RequestError):
"""Thrown if there was an error during the download of a result set"""
16 changes: 14 additions & 2 deletions src/databricks/sql/thrift_backend.py
@@ -371,7 +371,9 @@ def attempt_request(attempt):

this_method_name = getattr(method, "__name__")

logger.debug("Sending request: {}(<REDACTED>)".format(this_method_name))
logger.debug(
"sending thrift request: {}(<REDACTED>)".format(this_method_name)
)
unsafe_logger.debug("Sending request: {}".format(request))

# These three lines are no-ops if the v3 retry policy is not in use
@@ -387,7 +389,9 @@ def attempt_request(attempt):

# We need to call type(response) here because thrift doesn't implement __name__ attributes for thrift responses
logger.debug(
"Received response: {}(<REDACTED>)".format(type(response).__name__)
"received thrift response: {}(<REDACTED>)".format(
type(response).__name__
)
)
unsafe_logger.debug("Received response: {}".format(response))
return response
@@ -741,6 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
if direct_results and direct_results.resultSet:
logger.debug(f"received direct results")
assert direct_results.resultSet.results.startRowOffset == 0
assert direct_results.resultSetMetadata

@@ -753,6 +758,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
description=description,
)
else:
logger.debug(f"must fetch results")
arrow_queue_opt = None
return ExecuteResponse(
arrow_queue=arrow_queue_opt,
@@ -816,6 +822,10 @@ def execute_command(
):
assert session_handle is not None

logger.debug(
f"executing: cloud fetch: {use_cloud_fetch}, max rows: {max_rows}, max bytes: {max_bytes}"
)

spark_arrow_types = ttypes.TSparkArrowTypes(
timestampAsArrow=self._use_arrow_native_timestamps,
decimalAsArrow=self._use_arrow_native_decimals,
@@ -930,6 +940,7 @@ def get_columns(
return self._handle_execute_response(resp, cursor)

def _handle_execute_response(self, resp, cursor):
logger.debug(f"got execute response")
cursor.active_op_handle = resp.operationHandle
self._check_direct_results_for_error(resp.directResults)

@@ -950,6 +961,7 @@ def fetch_results(
arrow_schema_bytes,
description,
):
logger.debug("started to fetch results")
assert op_handle is not None

req = ttypes.TFetchResultsReq(
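The new thrift-layer messages are emitted at DEBUG level through this module's logger (logging.getLogger(__name__), which resolves to databricks.sql.thrift_backend). A sketch of surfacing just these messages without enabling debug output for the whole package, using only the standard logging module:

import logging

logging.basicConfig(level=logging.INFO)  # root handler; other loggers stay at INFO
logging.getLogger("databricks.sql.thrift_backend").setLevel(logging.DEBUG)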
