
Complete Fetch Phase (EXTERNAL_LINKS disposition and ARROW format) #598

Draft
wants to merge 66 commits into base: fetch-json-inline
66 commits
6ec265f
[squashed from cloudfetch-sea] introduce external links + arrow funct…
varun-edachali-dbx Jun 16, 2025
b2ad5e6
reduce responsibility of Queue
varun-edachali-dbx Jun 16, 2025
66d0df6
reduce repetition in arrow table creation
varun-edachali-dbx Jun 16, 2025
eb7ec80
reduce redundant code in CloudFetchQueue
varun-edachali-dbx Jun 16, 2025
a3a8a4a
move chunk link progression to separate func
varun-edachali-dbx Jun 16, 2025
ea79bc8
remove redundant log
varun-edachali-dbx Jun 16, 2025
5b49405
improve logging
varun-edachali-dbx Jun 16, 2025
015fb76
remove reliance on schema_bytes in SEA
varun-edachali-dbx Jun 16, 2025
5380c7a
use more fetch methods
varun-edachali-dbx Jun 16, 2025
27b781f
remove redundant schema_bytes from parent constructor
varun-edachali-dbx Jun 16, 2025
238dc0a
only call get_chunk_link with non null chunk index
varun-edachali-dbx Jun 16, 2025
b3bb07e
align SeaResultSet structure with ThriftResultSet
varun-edachali-dbx Jun 16, 2025
13e6346
remove _fill_result_buffer from SeaResultSet
varun-edachali-dbx Jun 16, 2025
f90b4d4
reduce code repetition
varun-edachali-dbx Jun 16, 2025
fb53dd9
pre-fetch next chunk link on processing current
varun-edachali-dbx Jun 17, 2025
d893877
reduce nesting
varun-edachali-dbx Jun 17, 2025
a165f1c
line break after multi line pydoc
varun-edachali-dbx Jun 17, 2025
d68e4ea
re-introduce schema_bytes for better abstraction (likely temporary)
varun-edachali-dbx Jun 17, 2025
be17812
Merge branch 'fetch-json-inline' into ext-links-sea
varun-edachali-dbx Jun 17, 2025
d33e5fd
Merge branch 'fetch-json-inline' into ext-links-sea
varun-edachali-dbx Jun 17, 2025
e3cef5c
add GetChunksResponse
varun-edachali-dbx Jun 17, 2025
ac50669
remove changes to sea test
varun-edachali-dbx Jun 17, 2025
03cdc4f
re-introduce accidentally removed description extraction method
varun-edachali-dbx Jun 17, 2025
e1842d8
fix type errors (ssl_options, CHUNK_PATH_WITH_ID..., etc.)
varun-edachali-dbx Jun 17, 2025
89a46af
access ssl_options through connection
varun-edachali-dbx Jun 17, 2025
1d0b28b
DEBUG level
varun-edachali-dbx Jun 17, 2025
c8820d4
remove explicit multi chunk test
varun-edachali-dbx Jun 17, 2025
fe47787
move cloud fetch queues back into utils.py
varun-edachali-dbx Jun 17, 2025
74f59b7
remove excess docstrings
varun-edachali-dbx Jun 17, 2025
4b456b2
move ThriftCloudFetchQueue above SeaCloudFetchQueue
varun-edachali-dbx Jun 17, 2025
a4447a1
Merge branch 'fetch-json-inline' into ext-links-sea
varun-edachali-dbx Jun 17, 2025
4883aff
correct patch module path in cloud fetch queue tests
varun-edachali-dbx Jun 17, 2025
cd3378c
correct add_link docstring
varun-edachali-dbx Jun 17, 2025
bc467d1
Merge branch 'fetch-json-inline' into ext-links-sea
varun-edachali-dbx Jun 17, 2025
dd7dc6a
convert complex types to string if not _use_arrow_native_complex_types
varun-edachali-dbx Jun 23, 2025
dabba55
Merge branch 'fetch-json-inline' into ext-links-sea
varun-edachali-dbx Jun 23, 2025
48ad7b3
Revert "Merge branch 'fetch-json-inline' into ext-links-sea"
varun-edachali-dbx Jun 23, 2025
a1f9b9c
reduce verbosity of ResultSetFilter docstring
varun-edachali-dbx Jun 23, 2025
3a999c0
Merge branch 'fetch-json-inline' into ext-links-sea
varun-edachali-dbx Jun 23, 2025
c313c2b
Revert "Merge branch 'fetch-json-inline' into ext-links-sea"
varun-edachali-dbx Jun 23, 2025
3bc615e
Revert "reduce verbosity of ResultSetFilter docstring"
varun-edachali-dbx Jun 23, 2025
b6e1a10
Reapply "Merge branch 'fetch-json-inline' into ext-links-sea"
varun-edachali-dbx Jun 23, 2025
2df3d39
Revert "Merge branch 'fetch-json-inline' into ext-links-sea"
varun-edachali-dbx Jun 23, 2025
5e75fb5
remove unnecessary filters changes
varun-edachali-dbx Jun 23, 2025
20822e4
remove unnecessary backend changes
varun-edachali-dbx Jun 23, 2025
802d045
remove constants changes
varun-edachali-dbx Jun 23, 2025
f3f795a
remove changes in filters tests
varun-edachali-dbx Jun 23, 2025
f6c5950
remove unit test backend and JSON queue changes
varun-edachali-dbx Jun 23, 2025
d210ccd
remove changes in sea result set testing
varun-edachali-dbx Jun 23, 2025
22a953e
Revert "remove changes in sea result set testing"
varun-edachali-dbx Jun 23, 2025
3aed144
Revert "remove unit test backend and JSON queue changes"
varun-edachali-dbx Jun 23, 2025
0fe4da4
Revert "remove changes in filters tests"
varun-edachali-dbx Jun 23, 2025
0e3c0a1
Revert "remove constants changes"
varun-edachali-dbx Jun 23, 2025
93edb93
Revert "remove unnecessary backend changes"
varun-edachali-dbx Jun 23, 2025
871a44f
Revert "remove unnecessary filters changes"
varun-edachali-dbx Jun 23, 2025
08ca60d
Merge branch 'fetch-json-inline' into ext-links-sea
varun-edachali-dbx Jun 23, 2025
8c5cc77
working version
varun-edachali-dbx Jun 23, 2025
7f5c715
adopt _wait_until_command_done
varun-edachali-dbx Jun 23, 2025
9ef5fad
introduce metadata commands
varun-edachali-dbx Jun 23, 2025
44183db
use new backend structure
varun-edachali-dbx Jun 23, 2025
d59b351
constrain backend diff
varun-edachali-dbx Jun 23, 2025
1edc80a
remove changes to filters
varun-edachali-dbx Jun 23, 2025
f82658a
make _parse methods in models internal
varun-edachali-dbx Jun 23, 2025
54eb0a4
reduce changes in unit tests
varun-edachali-dbx Jun 23, 2025
8a138e8
allow empty schema bytes for alignment with SEA
varun-edachali-dbx Jun 25, 2025
82f9d6b
pass is_vl_op to Sea backend ExecuteResponse
varun-edachali-dbx Jun 25, 2025
11 changes: 10 additions & 1 deletion src/databricks/sql/backend/databricks_client.py
@@ -11,6 +11,8 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING

from databricks.sql.types import SSLOptions

if TYPE_CHECKING:
from databricks.sql.client import Cursor

@@ -25,6 +27,13 @@


class DatabricksClient(ABC):
def __init__(self, ssl_options: SSLOptions, **kwargs):
self._use_arrow_native_complex_types = kwargs.get(
"_use_arrow_native_complex_types", True
)
self._max_download_threads = kwargs.get("max_download_threads", 10)
self._ssl_options = ssl_options

# == Connection and Session Management ==
@abstractmethod
def open_session(
@@ -82,7 +91,7 @@ def execute_command(
lz4_compression: bool,
cursor: "Cursor",
use_cloud_fetch: bool,
parameters: List,
parameters: List[ttypes.TSparkParameter],
async_op: bool,
enforce_embedded_schema_correctness: bool,
) -> Union["ResultSet", None]:
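The shared constructor above centralizes option handling that each backend previously re-parsed from kwargs. A minimal sketch of the pattern, assuming SSLOptions constructs with defaults; the stub class is hypothetical, since the real DatabricksClient is abstract:

```python
from databricks.sql.types import SSLOptions

# Hypothetical stand-in that mirrors DatabricksClient.__init__; the real
# class is abstract and cannot be instantiated directly.
class _BackendStub:
    def __init__(self, ssl_options: SSLOptions, **kwargs):
        self._use_arrow_native_complex_types = kwargs.get(
            "_use_arrow_native_complex_types", True
        )
        self._max_download_threads = kwargs.get("max_download_threads", 10)
        self._ssl_options = ssl_options

stub = _BackendStub(SSLOptions(), max_download_threads=4)
assert stub._max_download_threads == 4
assert stub._use_arrow_native_complex_types is True  # kwargs default
```

Subclasses such as SeaDatabricksClient and ThriftDatabricksClient now call super().__init__(...) instead of duplicating these kwargs.get defaults, as the later hunks in this diff show.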
41 changes: 36 additions & 5 deletions src/databricks/sql/backend/sea/backend.py
@@ -3,7 +3,7 @@
import re
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set

from databricks.sql.backend.sea.models.base import ResultManifest
from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest
from databricks.sql.backend.sea.utils.constants import (
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
ResultFormat,
@@ -41,6 +41,7 @@
GetStatementResponse,
CreateSessionResponse,
)
from databricks.sql.backend.sea.models.responses import GetChunksResponse

logger = logging.getLogger(__name__)

@@ -85,6 +86,7 @@ class SeaDatabricksClient(DatabricksClient):
STATEMENT_PATH = BASE_PATH + "statements"
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"

# SEA constants
POLL_INTERVAL_SECONDS = 0.2
@@ -119,7 +121,7 @@ def __init__(
http_path,
)

self._max_download_threads = kwargs.get("max_download_threads", 10)
super().__init__(ssl_options=ssl_options, **kwargs)

# Extract warehouse ID from http_path
self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -131,7 +133,7 @@
http_path=http_path,
http_headers=http_headers,
auth_provider=auth_provider,
ssl_options=ssl_options,
ssl_options=self._ssl_options,
**kwargs,
)

@@ -342,7 +344,7 @@ def _results_message_to_execute_response(

# Check for compression
lz4_compressed = (
response.manifest.result_compression == ResultCompression.LZ4_FRAME
response.manifest.result_compression == ResultCompression.LZ4_FRAME.value
)

execute_response = ExecuteResponse(
@@ -351,7 +353,7 @@
description=description,
has_been_closed_server_side=False,
lz4_compressed=lz4_compressed,
is_staging_operation=False,
is_staging_operation=response.manifest.is_volume_operation,
arrow_schema_bytes=None,
result_format=response.manifest.format,
)
@@ -620,6 +622,35 @@ def get_execution_result(
manifest=response.manifest,
)

def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
"""
Get the external link for the specified chunk index.

Args:
    statement_id: The statement ID
    chunk_index: The index of the chunk whose link is requested

Returns:
ExternalLink: External link for the chunk
"""

response_data = self.http_client._make_request(
method="GET",
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
)
response = GetChunksResponse.from_dict(response_data)

links = response.external_links
link = next((l for l in links if l.chunk_index == chunk_index), None)
if not link:
raise ServerOperationError(
f"No link found for chunk index {chunk_index}",
{
"operation-id": statement_id,
"diagnostic-info": None,
},
)

return link

# == Metadata Operations ==

def get_catalogs(
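A sketch of how a caller can walk a result chunk by chunk with get_chunk_link; client and statement_id stand for an open SeaDatabricksClient and a completed statement, and process is a placeholder:

```python
# Follow next_chunk_index from link to link: each ExternalLink advertises
# the next chunk, so one request per chunk suffices.
link = client.get_chunk_link(statement_id, 0)
while True:
    process(link.external_link)  # e.g. hand the presigned URL to a downloader
    if link.next_chunk_index is None:
        break  # no further chunks advertised
    link = client.get_chunk_link(statement_id, link.next_chunk_index)
```

This mirrors the queue behavior described in the commits above: pre-fetching the next chunk link while the current chunk is being processed.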
2 changes: 2 additions & 0 deletions src/databricks/sql/backend/sea/models/__init__.py
@@ -27,6 +27,7 @@
ExecuteStatementResponse,
GetStatementResponse,
CreateSessionResponse,
GetChunksResponse,
)

__all__ = [
@@ -49,4 +50,5 @@
"ExecuteStatementResponse",
"GetStatementResponse",
"CreateSessionResponse",
"GetChunksResponse",
]
37 changes: 36 additions & 1 deletion src/databricks/sql/backend/sea/models/responses.py
@@ -4,7 +4,7 @@
These models define the structures used in SEA API responses.
"""

from typing import Dict, Any
from typing import Dict, Any, List
from dataclasses import dataclass

from databricks.sql.backend.types import CommandState
@@ -154,3 +154,38 @@ class CreateSessionResponse:
def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse":
"""Create a CreateSessionResponse from a dictionary."""
return cls(session_id=data.get("session_id", ""))


@dataclass
class GetChunksResponse:
"""Response from getting chunks for a statement."""

statement_id: str
external_links: List[ExternalLink]

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse":
"""Create a GetChunksResponse from a dictionary."""
external_links = []
if "external_links" in data:
for link_data in data["external_links"]:
external_links.append(
ExternalLink(
external_link=link_data.get("external_link", ""),
expiration=link_data.get("expiration", ""),
chunk_index=link_data.get("chunk_index", 0),
byte_count=link_data.get("byte_count", 0),
row_count=link_data.get("row_count", 0),
row_offset=link_data.get("row_offset", 0),
next_chunk_index=link_data.get("next_chunk_index"),
next_chunk_internal_link=link_data.get(
"next_chunk_internal_link"
),
http_headers=link_data.get("http_headers"),
)
)

return cls(
statement_id=data.get("statement_id", ""),
external_links=external_links,
)
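For illustration, here is a payload shaped after the keys from_dict reads; all values are invented:

```python
from databricks.sql.backend.sea.models import GetChunksResponse

payload = {
    "statement_id": "stmt-123",
    "external_links": [
        {
            "external_link": "https://storage.example.com/chunk-0",
            "expiration": "2025-06-17T00:00:00Z",
            "chunk_index": 0,
            "byte_count": 1024,
            "row_count": 100,
            "row_offset": 0,
            "next_chunk_index": 1,
        }
    ],
}

response = GetChunksResponse.from_dict(payload)
assert response.statement_id == "stmt-123"
assert response.external_links[0].next_chunk_index == 1
# Optional keys that are absent fall back to the .get() defaults, so
# next_chunk_internal_link and http_headers are None here.
```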
11 changes: 3 additions & 8 deletions src/databricks/sql/backend/thrift_backend.py
@@ -40,7 +40,6 @@
)

from databricks.sql.utils import (
ThriftResultSetQueueFactory,
_bound,
RequestErrorInfo,
NoRetryReason,
@@ -148,6 +147,8 @@ def __init__(
http_path,
)

super().__init__(ssl_options, **kwargs)

port = port or 443
if kwargs.get("_connection_uri"):
uri = kwargs.get("_connection_uri")
@@ -161,19 +162,13 @@
raise ValueError("No valid connection settings.")

self._initialize_retry_args(kwargs)
self._use_arrow_native_complex_types = kwargs.get(
"_use_arrow_native_complex_types", True
)

self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True)
self._use_arrow_native_timestamps = kwargs.get(
"_use_arrow_native_timestamps", True
)

# Cloud fetch
self._max_download_threads = kwargs.get("max_download_threads", 10)

self._ssl_options = ssl_options

self._auth_provider = auth_provider

# Connector version 3 retry approach
18 changes: 18 additions & 0 deletions src/databricks/sql/cloudfetch/download_manager.py
@@ -101,6 +101,24 @@ def _schedule_downloads(self):
task = self._thread_pool.submit(handler.run)
self._download_tasks.append(task)

def add_link(self, link: TSparkArrowResultLink):
"""
Add more links to the download manager.

Args:
link: Link to add
"""

if link.rowCount <= 0:
return

logger.debug(
"ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
link.startRowOffset, link.rowCount
)
)
self._pending_links.append(link)

def _shutdown_manager(self):
# Clear download handlers and shutdown the thread pool
self._pending_links = []
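A usage sketch for add_link, assuming the TSparkArrowResultLink fields from the Thrift definitions; manager is an existing download manager instance and the values are illustrative:

```python
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

link = TSparkArrowResultLink(
    fileLink="https://storage.example.com/chunk-1",
    startRowOffset=100,
    rowCount=100,
    bytesNum=2048,
    expiryTime=1_750_000_000,
)
manager.add_link(link)  # links with rowCount <= 0 are dropped by the guard
```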
73 changes: 60 additions & 13 deletions src/databricks/sql/result_set.py
@@ -1,11 +1,18 @@
from abc import ABC, abstractmethod
from typing import List, Optional, TYPE_CHECKING
import json
from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING

import logging
import time
import pandas

from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
from databricks.sql.backend.sea.models.base import (
ExternalLink,
ResultData,
ResultManifest,
)
from databricks.sql.utils import SeaResultSetQueueFactory

try:
import pyarrow
@@ -16,14 +23,10 @@
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
from databricks.sql.client import Connection
from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.types import Row
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
from databricks.sql.utils import (
ColumnTable,
ColumnQueue,
JsonQueue,
SeaResultSetQueueFactory,
)
from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError
from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue
from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse

logger = logging.getLogger(__name__)
@@ -252,7 +255,7 @@ def __init__(
description=execute_response.description,
is_staging_operation=execute_response.is_staging_operation,
lz4_compressed=execute_response.lz4_compressed,
arrow_schema_bytes=execute_response.arrow_schema_bytes,
arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
)

# Initialize results queue if not provided
@@ -476,6 +479,7 @@ def __init__(
result_data,
manifest,
str(execute_response.command_id.to_sea_statement_id()),
ssl_options=connection.session.ssl_options,
description=execute_response.description,
max_download_threads=sea_client.max_download_threads,
sea_client=sea_client,
@@ -548,6 +552,43 @@ def fetchall_json(self):

return results

def _convert_complex_types_to_string(
self, rows: "pyarrow.Table"
) -> "pyarrow.Table":
"""
Convert complex types (array, struct, map) to string representation.

Args:
rows: Input PyArrow table

Returns:
PyArrow table with complex types converted to strings
"""

if not pyarrow:
return rows

def convert_complex_column_to_string(col: "pyarrow.Array") -> "pyarrow.Array":
python_values = col.to_pylist()
json_strings = [
(None if val is None else json.dumps(val)) for val in python_values
]
return pyarrow.array(json_strings, type=pyarrow.string())

converted_columns = []
for col in rows.columns:
converted_col = col
if (
pyarrow.types.is_list(col.type)
or pyarrow.types.is_large_list(col.type)
or pyarrow.types.is_struct(col.type)
or pyarrow.types.is_map(col.type)
):
converted_col = convert_complex_column_to_string(col)
converted_columns.append(converted_col)

return pyarrow.Table.from_arrays(converted_columns, names=rows.column_names)
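The conversion serializes each complex value to JSON text, so string-typed complex columns still carry a faithful representation. A self-contained round trip of the same idea:

```python
import json

import pyarrow

table = pyarrow.table({"ids": [[1, 2], None, [3]]})  # list<int64> column
values = table.column("ids").to_pylist()
as_json = pyarrow.array(
    [None if v is None else json.dumps(v) for v in values],
    type=pyarrow.string(),
)
print(as_json.to_pylist())  # ['[1, 2]', None, '[3]']
```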

def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
Fetch the next set of rows as an Arrow table.
@@ -568,6 +609,9 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
results = self.results.next_n_rows(size)
self._next_row_index += results.num_rows

if not self.backend._use_arrow_native_complex_types:
results = self._convert_complex_types_to_string(results)

return results

def fetchall_arrow(self) -> "pyarrow.Table":
@@ -577,6 +621,9 @@ def fetchall_arrow(self) -> "pyarrow.Table":
results = self.results.remaining_rows()
self._next_row_index += results.num_rows

if not self.backend._use_arrow_native_complex_types:
results = self._convert_complex_types_to_string(results)

return results

def fetchone(self) -> Optional[Row]:
@@ -590,7 +637,7 @@ def fetchone(self) -> Optional[Row]:
if isinstance(self.results, JsonQueue):
res = self._convert_json_table(self.fetchmany_json(1))
else:
raise NotImplementedError("fetchone only supported for JSON data")
res = self._convert_arrow_table(self.fetchmany_arrow(1))

return res[0] if res else None

@@ -610,7 +657,7 @@ def fetchmany(self, size: int) -> List[Row]:
if isinstance(self.results, JsonQueue):
return self._convert_json_table(self.fetchmany_json(size))
else:
raise NotImplementedError("fetchmany only supported for JSON data")
return self._convert_arrow_table(self.fetchmany_arrow(size))

def fetchall(self) -> List[Row]:
"""
Expand All @@ -622,4 +669,4 @@ def fetchall(self) -> List[Row]:
if isinstance(self.results, JsonQueue):
return self._convert_json_table(self.fetchall_json())
else:
raise NotImplementedError("fetchall only supported for JSON data")
return self._convert_arrow_table(self.fetchall_arrow())
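With the NotImplementedError branches replaced, the row-oriented cursor API behaves the same whether the queue holds JSON rows or Arrow batches. A hedged end-to-end sketch; the connection parameters are placeholders, and how the SEA backend is selected is omitted here:

```python
from databricks import sql

with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<warehouse-http-path>",
    access_token="<token>",
) as conn:
    with conn.cursor() as cursor:
        cursor.execute("SELECT * FROM range(10)")
        first = cursor.fetchone()   # one Row, converted from Arrow if needed
        some = cursor.fetchmany(4)  # next four rows
        rest = cursor.fetchall()    # remaining rows
```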