-
Notifications
You must be signed in to change notification settings - Fork 112
Implement SeaDatabricksClient (Complete Execution Spec) #590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 86 commits
138c2ae
3e3ab94
4a78165
0dac4aa
1b794c7
da5a6fe
686ade4
31e6c83
69ea238
66d7517
71feef9
ae9862f
d8aa69e
db139bc
b977b12
da615c0
0da04a6
ea9d456
8985c62
d9bcdbe
ee9fa1c
24c6152
67fd101
271fcaf
bf26ea3
ed7cf91
dae15e3
db5bbea
d5d3699
6137a3d
75b0773
4494dcd
4d0aeca
7cece5e
8977c06
0216d7a
4cb15fd
dee47f7
e385d5b
484064e
030edf8
3e22c6c
787f1f7
165c4f3
a6e40d0
52e3088
641c09b
8bd12d8
ffded6e
227f6b3
68657a3
3940eec
37813ba
267c9f4
2967119
47fd60d
982fdf2
9e14d48
be1997e
e8e8ee7
05ee4e7
cbace3f
c075b07
c62f76d
199402e
8ac574b
398ca70
b1acc5b
ef2a7ee
699942d
af8f74e
5540c5c
efe3881
36ab59b
1d57c99
df6dac2
ad0e527
ed446a0
38e4b5c
94879c0
1809956
da5260c
0385ffb
90bb09c
cd22389
82e0f8b
059657e
68d6276
91d28b2
c038f22
398909c
7ec43e1
ec95c76
5c1166a
df9f849
3eb582f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,49 @@ | ||
import logging | ||
import time | ||
import re | ||
from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set | ||
from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set | ||
|
||
from databricks.sql.backend.sea.utils.constants import ( | ||
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, | ||
ResultFormat, | ||
ResultDisposition, | ||
ResultCompression, | ||
WaitTimeout, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from databricks.sql.client import Cursor | ||
from databricks.sql.result_set import ResultSet | ||
|
||
from databricks.sql.backend.databricks_client import DatabricksClient | ||
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType | ||
from databricks.sql.backend.types import ( | ||
SessionId, | ||
CommandId, | ||
CommandState, | ||
BackendType, | ||
ExecuteResponse, | ||
) | ||
from databricks.sql.exc import ServerOperationError | ||
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient | ||
from databricks.sql.backend.sea.utils.constants import ( | ||
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, | ||
) | ||
from databricks.sql.thrift_api.TCLIService import ttypes | ||
from databricks.sql.types import SSLOptions | ||
|
||
from databricks.sql.backend.sea.models import ( | ||
ExecuteStatementRequest, | ||
GetStatementRequest, | ||
CancelStatementRequest, | ||
CloseStatementRequest, | ||
CreateSessionRequest, | ||
DeleteSessionRequest, | ||
StatementParameter, | ||
ExecuteStatementResponse, | ||
GetStatementResponse, | ||
CreateSessionResponse, | ||
) | ||
from databricks.sql.backend.sea.models.responses import ( | ||
parse_status, | ||
parse_manifest, | ||
parse_result, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -262,8 +286,79 @@ def get_allowed_session_configurations() -> List[str]: | |
""" | ||
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) | ||
|
||
# == Not Implemented Operations == | ||
# These methods will be implemented in future iterations | ||
def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]: | ||
varun-edachali-dbx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Extract column description from a manifest object. | ||
|
||
Args: | ||
manifest_obj: The ResultManifest object containing schema information | ||
|
||
Returns: | ||
Optional[List]: A list of column tuples or None if no columns are found | ||
""" | ||
|
||
schema_data = manifest_obj.schema | ||
columns_data = schema_data.get("columns", []) | ||
|
||
if not columns_data: | ||
return None | ||
|
||
columns = [] | ||
for col_data in columns_data: | ||
if not isinstance(col_data, dict): | ||
continue | ||
varun-edachali-dbx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) | ||
columns.append( | ||
( | ||
col_data.get("name", ""), # name | ||
col_data.get("type_name", ""), # type_code | ||
None, # display_size (not provided by SEA) | ||
None, # internal_size (not provided by SEA) | ||
varun-edachali-dbx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
col_data.get("precision"), # precision | ||
col_data.get("scale"), # scale | ||
col_data.get("nullable", True), # null_ok | ||
) | ||
) | ||
|
||
return columns if columns else None | ||
|
||
def _results_message_to_execute_response(self, sea_response, command_id):
    """
    Normalise a raw SEA response into an ExecuteResponse plus its parsed parts.

    Args:
        sea_response: The response from the SEA API, as handed to the
            ``parse_status`` / ``parse_manifest`` / ``parse_result`` helpers
        command_id: The command ID to record on the ExecuteResponse

    Returns:
        tuple: (ExecuteResponse, result data object, manifest object)
    """

    # Break the response into its three parsed components
    parsed_status = parse_status(sea_response)
    manifest = parse_manifest(sea_response)
    result_data = parse_result(sea_response)

    # Column description comes from the manifest schema
    column_description = self._extract_description_from_manifest(manifest)

    # The manifest reports which compression, if any, was applied
    is_lz4 = manifest.result_compression == "LZ4_FRAME"

    normalized = ExecuteResponse(
        command_id=command_id,
        status=parsed_status.state,
        description=column_description,
        has_been_closed_server_side=False,
        lz4_compressed=is_lz4,
        is_staging_operation=False,
        arrow_schema_bytes=None,
        result_format=manifest.format,
    )

    return normalized, result_data, manifest
|
||
def execute_command( | ||
self, | ||
|
@@ -274,41 +369,242 @@ def execute_command( | |
lz4_compression: bool, | ||
cursor: "Cursor", | ||
use_cloud_fetch: bool, | ||
parameters: List[ttypes.TSparkParameter], | ||
parameters: List, | ||
varun-edachali-dbx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
async_op: bool, | ||
enforce_embedded_schema_correctness: bool, | ||
): | ||
"""Not implemented yet.""" | ||
raise NotImplementedError( | ||
"execute_command is not yet implemented for SEA backend" | ||
) -> Union["ResultSet", None]: | ||
""" | ||
Execute a SQL command using the SEA backend. | ||
|
||
Args: | ||
operation: SQL command to execute | ||
session_id: Session identifier | ||
max_rows: Maximum number of rows to fetch | ||
max_bytes: Maximum number of bytes to fetch | ||
lz4_compression: Whether to use LZ4 compression | ||
cursor: Cursor executing the command | ||
use_cloud_fetch: Whether to use cloud fetch | ||
parameters: SQL parameters | ||
async_op: Whether to execute asynchronously | ||
enforce_embedded_schema_correctness: Whether to enforce schema correctness | ||
|
||
Returns: | ||
ResultSet: A SeaResultSet instance for the executed command | ||
""" | ||
|
||
if session_id.backend_type != BackendType.SEA: | ||
raise ValueError("Not a valid SEA session ID") | ||
|
||
sea_session_id = session_id.to_sea_session_id() | ||
|
||
# Convert parameters to StatementParameter objects | ||
sea_parameters = [] | ||
if parameters: | ||
for param in parameters: | ||
sea_parameters.append( | ||
StatementParameter( | ||
name=param.name, | ||
value=param.value, | ||
type=param.type if hasattr(param, "type") else None, | ||
) | ||
) | ||
|
||
format = ( | ||
ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY | ||
).value | ||
varun-edachali-dbx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
disposition = ( | ||
ResultDisposition.EXTERNAL_LINKS | ||
if use_cloud_fetch | ||
else ResultDisposition.INLINE | ||
).value | ||
result_compression = ( | ||
ResultCompression.LZ4_FRAME if lz4_compression else ResultCompression.NONE | ||
).value | ||
|
||
request = ExecuteStatementRequest( | ||
warehouse_id=self.warehouse_id, | ||
session_id=sea_session_id, | ||
statement=operation, | ||
disposition=disposition, | ||
format=format, | ||
wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value, | ||
on_wait_timeout="CONTINUE", | ||
row_limit=max_rows, | ||
parameters=sea_parameters if sea_parameters else None, | ||
result_compression=result_compression, | ||
) | ||
|
||
response_data = self.http_client._make_request( | ||
method="POST", path=self.STATEMENT_PATH, data=request.to_dict() | ||
) | ||
response = ExecuteStatementResponse.from_dict(response_data) | ||
statement_id = response.statement_id | ||
if not statement_id: | ||
raise ServerOperationError( | ||
"Failed to execute command: No statement ID returned", | ||
{ | ||
"operation-id": None, | ||
"diagnostic-info": None, | ||
}, | ||
) | ||
|
||
command_id = CommandId.from_sea_statement_id(statement_id) | ||
|
||
# Store the command ID in the cursor | ||
cursor.active_command_id = command_id | ||
|
||
# If async operation, return and let the client poll for results | ||
if async_op: | ||
return None | ||
varun-edachali-dbx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# For synchronous operation, wait for the statement to complete | ||
status = response.status | ||
state = status.state | ||
|
||
# Keep polling until we reach a terminal state | ||
while state in [CommandState.PENDING, CommandState.RUNNING]: | ||
time.sleep(0.5) # add a small delay to avoid excessive API calls | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's make this delay a constant, and later on maybe a connection property. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I removed it entirely for now because the Thrift implementation does not seem to have a polling interval. Should I add it back? I will add it as a constant if it should be present. |
||
state = self.get_query_state(command_id) | ||
|
||
if state != CommandState.SUCCEEDED: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need to handle any other states gracefully? Is it possible to reuse the state handling of Thrift, since I think CommandState is used in the Thrift client as well? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I gave the code another read, and it looks like the Thrift implementation was explicitly checking for command closure and failure. I'll add this to the code. I do not think we should reuse the state handling of Thrift just yet - in the current Python connector implementation it checks for some specific state that is not uniquely accounted for by […]. We should normalise the Exceptions raised before we abstract the state handling. Should we do this? For now, a quick solution would be to create our own characterisation by making […] |
||
raise ServerOperationError( | ||
f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", | ||
{ | ||
"operation-id": command_id.to_sea_statement_id(), | ||
"diagnostic-info": None, | ||
}, | ||
) | ||
|
||
return self.get_execution_result(command_id, cursor) | ||
|
||
def cancel_command(self, command_id: CommandId) -> None:
    """
    Send a cancel request for the statement identified by ``command_id``.

    Args:
        command_id: Command identifier to cancel

    Raises:
        ValueError: If the command ID does not belong to the SEA backend
    """

    if command_id.backend_type != BackendType.SEA:
        raise ValueError("Not a valid SEA command ID")

    statement_id = command_id.to_sea_statement_id()
    cancel_request = CancelStatementRequest(statement_id=statement_id)

    # The statement ID is carried both in the URL and in the request body
    endpoint = self.CANCEL_STATEMENT_PATH_WITH_ID.format(statement_id)
    payload = cancel_request.to_dict()
    self.http_client._make_request(method="POST", path=endpoint, data=payload)
|
||
def close_command(self, command_id: CommandId) -> None:
    """
    Send a DELETE request for the statement identified by ``command_id``.

    Args:
        command_id: Command identifier to close

    Raises:
        ValueError: If the command ID does not belong to the SEA backend
    """

    if command_id.backend_type != BackendType.SEA:
        raise ValueError("Not a valid SEA command ID")

    statement_id = command_id.to_sea_statement_id()
    close_request = CloseStatementRequest(statement_id=statement_id)

    # Closing is a DELETE on the statement resource itself
    endpoint = self.STATEMENT_PATH_WITH_ID.format(statement_id)
    payload = close_request.to_dict()
    self.http_client._make_request(method="DELETE", path=endpoint, data=payload)
|
||
def get_query_state(self, command_id: CommandId) -> CommandState:
    """
    Fetch the statement from the SEA service and report its current state.

    Args:
        command_id: Command identifier

    Returns:
        CommandState: The state carried by the parsed statement response

    Raises:
        ValueError: If the command ID does not belong to the SEA backend
    """

    if command_id.backend_type != BackendType.SEA:
        raise ValueError("Not a valid SEA command ID")

    statement_id = command_id.to_sea_statement_id()
    status_request = GetStatementRequest(statement_id=statement_id)

    endpoint = self.STATEMENT_PATH_WITH_ID.format(statement_id)
    payload = status_request.to_dict()
    raw_response = self.http_client._make_request(
        method="GET", path=endpoint, data=payload
    )

    # Only the state is of interest here; the rest of the response is dropped
    return GetStatementResponse.from_dict(raw_response).status.state
|
||
def get_execution_result(
    self,
    command_id: CommandId,
    cursor: "Cursor",
) -> "ResultSet":
    """
    Fetch a statement's result and wrap it in a SeaResultSet.

    Args:
        command_id: Command identifier
        cursor: Cursor executing the command; supplies the connection,
            buffer size and array size for the result set

    Returns:
        ResultSet: A SeaResultSet instance with the execution results

    Raises:
        ValueError: If the command ID does not belong to the SEA backend
    """

    if command_id.backend_type != BackendType.SEA:
        raise ValueError("Not a valid SEA command ID")

    statement_id = command_id.to_sea_statement_id()
    result_request = GetStatementRequest(statement_id=statement_id)

    # Retrieve the statement, including its result payload
    endpoint = self.STATEMENT_PATH_WITH_ID.format(statement_id)
    payload = result_request.to_dict()
    raw_response = self.http_client._make_request(
        method="GET", path=endpoint, data=payload
    )

    # Imported inside the function rather than at module level
    from databricks.sql.result_set import SeaResultSet

    # Split the raw response into the normalized response, data and manifest
    parsed = self._results_message_to_execute_response(raw_response, command_id)
    execute_response, result_data, manifest = parsed

    return SeaResultSet(
        connection=cursor.connection,
        execute_response=execute_response,
        sea_client=self,
        buffer_size_bytes=cursor.buffer_size_bytes,
        arraysize=cursor.arraysize,
        result_data=result_data,
        manifest=manifest,
    )
|
||
# == Metadata Operations == | ||
|
Uh oh!
There was an error while loading. Please reload this page.