From 138c2aebab99659d1c970fa70e4a431fec78aae2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:24:22 +0000 Subject: [PATCH 001/105] [squash from exec-sea] bring over execution phase changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 ++- .../sql/backend/databricks_client.py | 30 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 360 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 110 ++- .../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 99 ++- src/databricks/sql/backend/types.py | 64 +- src/databricks/sql/client.py | 1 - src/databricks/sql/result_set.py | 234 ++++-- src/databricks/sql/session.py | 2 +- src/databricks/sql/utils.py | 7 - tests/unit/test_client.py | 22 +- tests/unit/test_fetches.py | 13 +- tests/unit/test_fetches_bench.py | 3 +- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 275 +++++++ tests/unit/test_session.py | 5 + tests/unit/test_thrift_backend.py | 55 +- 22 files changed, 2375 insertions(+), 366 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..87b62efea 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,34 +6,122 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + +def test_sea_query_exec(): + """ + Test executing a query using the SEA backend with result compression. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with result compression enabled and disabled, + and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + sys.exit(1) + + try: + # Test with compression enabled + logger.info("Creating connection with LZ4 compression enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, # Enable cloud fetch to use compression + enable_query_result_lz4_compression=True, # Enable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"backend type: {type(connection.session.backend)}") + + # Execute a simple query with compression enabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query with compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression enabled") + + # Test with compression disabled + logger.info("Creating connection with LZ4 compression disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, # Enable cloud fetch + enable_query_result_lz4_compression=False, # Disable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query with compression disabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query without compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query without compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression disabled") + + except Exception as e: + logger.error(f"Error during SEA query execution test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA query execution test with compression completed successfully") + + def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -42,25 +130,33 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + use_sea=True, + user_agent_entry="SEA-Test-Client", # add custom user agent + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") + if __name__ == "__main__": + # Test session management test_sea_session() + + # Test query execution with compression + test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -88,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. - - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. 
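+
+The SEA backend's ``get_tables`` applies these filters after fetching, to
+emulate server-side table-type filtering on the client. A minimal sketch,
+with illustrative values::
+
+    result = ResultSetFilter.filter_tables_by_type(result, ["TABLE", "VIEW"])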
+ +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
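+
+        Matching is case-insensitive, and the table type is read from column
+        index 5, mirroring the JDBC connector's result layout. Passing no
+        types falls back to the defaults, e.g. (illustrative)::
+
+            ResultSetFilter.filter_tables_by_type(result_set, None)
+            # behaves as table_types=["TABLE", "VIEW", "SYSTEM TABLE"]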
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..c7a4ed1b1 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,222 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. 
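+
+        Cloud fetch requests ``ARROW_STREAM`` results via an
+        ``EXTERNAL_LINKS`` disposition; otherwise results are inlined as
+        ``JSON_ARRAY``. In synchronous mode the call polls
+        ``get_query_state`` every 0.5s until a terminal state is reached.
+        A minimal synchronous call, with illustrative values::
+
+            rs = backend.execute_command(
+                operation="SELECT 1",
+                session_id=session_id,
+                max_rows=10000,
+                max_bytes=0,
+                lz4_compression=False,
+                cursor=cursor,
+                use_cloud_fetch=False,
+                parameters=[],
+                async_op=False,
+                enforce_embedded_schema_correctness=False,
+            )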
+ + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else "NONE" + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + byte_limit=max_bytes if max_bytes > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. 
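+
+        This sends a single POST to the statement's cancel endpoint and
+        returns without waiting for the cancellation to take effect, e.g.
+        (illustrative)::
+
+            backend.cancel_command(cursor.active_command_id)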
+ + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
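+
+        This issues one GET for the statement's current payload and wraps it
+        in a ``SeaResultSet``; it does not poll, so callers should first
+        confirm a terminal state, e.g. (illustrative)::
+
+            if backend.get_query_state(command_id) == CommandState.SUCCEEDED:
+                rs = backend.get_execution_result(command_id, cursor)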
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +514,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +539,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +574,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +622,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
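+
+All of them are plain dataclasses that can be constructed directly, e.g.
+(illustrative)::
+
+    status = StatementStatus(state=CommandState.SUCCEEDED)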
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..1c519d931 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,111 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + byte_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.byte_limit is not None and self.byte_limit > 0: + result["byte_limit"] = self.byte_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index de388f1d4..e03d6f235 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,11 +5,10 @@ import time import uuid import threading -from typing import List, Optional, Union, Any, TYPE_CHECKING +from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( @@ -17,8 +16,9 @@ SessionId, CommandId, BackendType, + guid_to_hex_id, + ExecuteResponse, ) -from databricks.sql.backend.utils import guid_to_hex_id try: import pyarrow @@ -42,7 +42,7 @@ ) from databricks.sql.utils import ( - ExecuteResponse, + ResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, @@ -53,6 +53,7 @@ ) from databricks.sql.types import SSLOptions from 
databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet logger = logging.getLogger(__name__) @@ -351,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. @@ -797,23 +797,27 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id = CommandId.from_thrift_handle(resp.operationHandle) - return ExecuteResponse( - arrow_queue=arrow_queue_opt, - status=CommandState.from_thrift_state(operation_state), - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, + status = CommandState.from_thrift_state(operation_state) + if status is None: + raise ValueError(f"Invalid operation state: {operation_state}") + + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -863,15 +867,14 @@ def get_execution_result( ) execute_response = ExecuteResponse( - arrow_queue=queue, - status=CommandState.from_thrift_state(resp.status), - has_been_closed_server_side=False, + command_id=command_id, + status=resp.status, + description=description, has_more_rows=has_more_rows, + results_queue=queue, + has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, ) return ThriftResultSet( @@ -881,6 +884,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -909,10 +913,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - state = CommandState.from_thrift_state(operation_state) - if state is None: - raise ValueError(f"Unknown command state: {operation_state}") - return state + return CommandState.from_thrift_state(operation_state) @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -947,8 +948,6 @@ def execute_command( async_op=False, enforce_embedded_schema_correctness=False, ) -> Union["ResultSet", None]: - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -995,7 +994,9 @@ def execute_command( 
self._handle_execute_response_async(resp, cursor) return None else: - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1004,6 +1005,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1013,8 +1015,6 @@ def get_catalogs( max_bytes: int, cursor: "Cursor", ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1027,7 +1027,9 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1036,6 +1038,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1047,8 +1050,6 @@ def get_schemas( catalog_name=None, schema_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1063,7 +1064,9 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1072,6 +1075,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1085,8 +1089,6 @@ def get_tables( table_name=None, table_types=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1103,7 +1105,9 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1112,6 +1116,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1125,8 +1130,6 @@ def get_columns( table_name=None, column_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1143,7 +1146,9 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1152,6 +1157,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def 
_handle_execute_response(self, resp, cursor): @@ -1165,7 +1171,12 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + ( + execute_response, + arrow_schema_bytes, + ) = self._results_message_to_execute_response(resp, final_operation_state) + execute_response.command_id = command_id + return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1226,7 +1237,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,28 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. + + Args: + state: SEA state string + + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -285,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". 
- - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -318,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None @@ -394,3 +394,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9f7c060a7..e145e4e58 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -24,7 +24,6 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( - ExecuteResponse, ParamEscaper, inject_parameters, transform_paramstyle, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0d8d3579..fc8595839 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,26 +1,23 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, TYPE_CHECKING +from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING import logging import time import pandas -from databricks.sql.backend.types import CommandId, CommandState - try: import pyarrow except ImportError: pyarrow = None if TYPE_CHECKING: - from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection - from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -34,32 +31,31 @@ class ResultSet(ABC): def __init__( self, - connection: "Connection", - backend: "DatabricksClient", - command_id: CommandId, - op_state: Optional[CommandState], - has_been_closed_server_side: bool, + connection, + backend, arraysize: int, buffer_size_bytes: int, + command_id=None, + status=None, + has_been_closed_server_side: bool = False, + has_more_rows: bool = False, + results_queue=None, + description=None, + is_staging_operation: bool = False, ): - """ - A ResultSet manages the results of a single command. 
- - :param connection: The parent connection that was used to execute this command - :param backend: The specialised backend client to be invoked in the fetch phase - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) - """ - self.command_id = command_id - self.op_state = op_state - self.has_been_closed_server_side = has_been_closed_server_side + """Initialize the base ResultSet with common properties.""" self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 - self.description = None + self.description = description + self.command_id = command_id + self.status = status + self.has_been_closed_server_side = has_been_closed_server_side + self._has_more_rows = has_more_rows + self.results = results_queue + self._is_staging_operation = is_staging_operation def __iter__(self): while True: @@ -74,10 +70,9 @@ def rownumber(self): return self._next_row_index @property - @abstractmethod def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" - pass + return self._is_staging_operation # Define abstract methods that concrete implementations must implement @abstractmethod @@ -101,12 +96,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -119,7 +114,7 @@ def close(self) -> None: """ try: if ( - self.op_state != CommandState.CLOSED + self.status != CommandState.CLOSED and not self.has_been_closed_server_side and self.connection.open ): @@ -129,7 +124,7 @@ def close(self) -> None: logger.info("Operation was canceled by a prior request") finally: self.has_been_closed_server_side = True - self.op_state = CommandState.CLOSED + self.status = CommandState.CLOSED class ThriftResultSet(ResultSet): @@ -138,11 +133,12 @@ class ThriftResultSet(ResultSet): def __init__( self, connection: "Connection", - execute_response: ExecuteResponse, + execute_response: "ExecuteResponse", thrift_client: "ThriftDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
@@ -154,37 +150,33 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ - super().__init__( - connection, - thrift_client, - execute_response.command_id, - execute_response.status, - execute_response.has_been_closed_server_side, - arraysize, - buffer_size_bytes, - ) - # Initialize ThriftResultSet-specific attributes - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.lz4_compressed = execute_response.lz4_compressed - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self._is_staging_operation = execute_response.is_staging_operation + self.lz4_compressed = execute_response.lz4_compressed - # Initialize results queue - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=thrift_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + ) + + # Initialize results queue if not provided + if not self.results: self._fill_results_buffer() def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, @@ -196,7 +188,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -248,7 +240,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -280,7 +272,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -305,7 +297,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -320,7 +312,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not 
self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -346,7 +338,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -389,24 +381,110 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) - @property - def is_staging_operation(self) -> bool: - """Whether this result set represents a staging operation.""" - return self._is_staging_operation - @staticmethod - def _get_schema_description(table_schema_message): +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection, + sea_client, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + execute_response=None, + sea_response=None, + ): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + 
is_staging_operation=is_staging_operation, + ) - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 7c33d9b2d..76aec4675 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -10,7 +10,7 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId +from databricks.sql.backend.types import SessionId, BackendType logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 2622b1172..edb13ef6d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -349,13 +349,6 @@ def _create_empty_table(self) -> "pyarrow.Table": return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) -ExecuteResponse = namedtuple( - "ExecuteResponse", - "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_id arrow_queue arrow_schema_bytes", -) - - def _bound(min_x, max_x, x): """Bound x by [min_x, max_x] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 1a7950870..090ec255e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -26,7 +26,7 @@ from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState -from databricks.sql.utils import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite @@ -121,10 +121,10 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Verify initial state self.assertEqual(real_result_set.has_been_closed_server_side, closed) - expected_op_state = ( + expected_status = ( CommandState.CLOSED if closed else CommandState.SUCCEEDED ) - self.assertEqual(real_result_set.op_state, expected_op_state) + 
self.assertEqual(real_result_set.status, expected_status) # Mock execute_command to return our real result set cursor.backend.execute_command = Mock(return_value=real_result_set) @@ -146,8 +146,8 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # 1. has_been_closed_server_side should always be True after close() self.assertTrue(real_result_set.has_been_closed_server_side) - # 2. op_state should always be CLOSED after close() - self.assertEqual(real_result_set.op_state, CommandState.CLOSED) + # 2. status should always be CLOSED after close() + self.assertEqual(real_result_set.status, CommandState.CLOSED) # 3. Backend close_command should be called appropriately if not closed: @@ -556,7 +556,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): self.assertEqual(instance.close_session.call_count, 0) cursor.close() - @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) + @patch("%s.backend.types.ExecuteResponse" % PACKAGE_NAME) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( @@ -678,10 +678,10 @@ def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) result_set.backend = Mock() - result_set.backend.CLOSED_OP_STATE = "CLOSED" + result_set.backend.CLOSED_OP_STATE = CommandState.CLOSED result_set.connection = Mock() result_set.connection.open = True - result_set.op_state = "RUNNING" + result_set.status = CommandState.RUNNING result_set.has_been_closed_server_side = False result_set.command_id = Mock() @@ -695,7 +695,7 @@ def __init__(self): try: try: if ( - result_set.op_state != result_set.backend.CLOSED_OP_STATE + result_set.status != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): @@ -705,7 +705,7 @@ def __init__(self): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.backend.CLOSED_OP_STATE + result_set.status = result_set.backend.CLOSED_OP_STATE result_set.backend.close_command.assert_called_once_with( result_set.command_id @@ -713,7 +713,7 @@ def __init__(self): assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.backend.CLOSED_OP_STATE + assert result_set.status == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 030510a64..7249a59e6 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -8,7 +8,8 @@ pa = None import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ThriftResultSet @@ -42,14 +43,13 @@ def make_dummy_result_set_from_initial_results(initial_results): rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_id=None, - arrow_queue=arrow_queue, - arrow_schema_bytes=schema.serialize().to_pybytes(), + results_queue=arrow_queue, is_staging_operation=False, ), thrift_client=None, @@ 
-88,6 +88,7 @@ def fetch_results( rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=False, has_more_rows=True, @@ -96,9 +97,7 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_id=None, - arrow_queue=None, - arrow_schema_bytes=None, + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index b302c00da..7e025cf82 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -10,7 +10,8 @@ import pytest import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + 
) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in 
table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..f666fd613 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def 
test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + 
NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57a2a61e3..b8de970db 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,11 +619,18 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 + + # Create a valid operation status + op_status = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=Mock(), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -644,7 +651,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -878,11 +885,12 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - results_message_response = thrift_backend._handle_execute_response( + execute_response, _ = 
thrift_backend._handle_execute_response( execute_resp, Mock() ) + self.assertEqual( - results_message_response.status, + execute_response.status, CommandState.SUCCEEDED, ) @@ -915,7 +923,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -943,15 +953,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -971,6 +987,12 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -1018,7 +1040,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -1150,7 +1172,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1184,7 +1206,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1215,7 +1237,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1255,7 +1277,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), 
ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1299,7 +1321,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1645,7 +1667,9 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2204,7 +2228,8 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", + return_value=(Mock(), Mock()), ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class From 3e3ab94e8fa3dd02e4b05b5fc35939aef57793a2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:31:37 +0000 Subject: [PATCH 002/105] remove excess test Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 +++----------------- 1 file changed, 14 insertions(+), 110 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 87b62efea..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,122 +6,34 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) - -def test_sea_query_exec(): - """ - Test executing a query using the SEA backend with result compression. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with result compression enabled and disabled, - and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - sys.exit(1) - - try: - # Test with compression enabled - logger.info("Creating connection with LZ4 compression enabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, # Enable cloud fetch to use compression - enable_query_result_lz4_compression=True, # Enable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"backend type: {type(connection.session.backend)}") - - # Execute a simple query with compression enabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query with compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression enabled") - - # Test with compression disabled - logger.info("Creating connection with LZ4 compression disabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, # Enable cloud fetch - enable_query_result_lz4_compression=False, # Disable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query with compression disabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query without compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query without compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression disabled") - - except Exception as e: - logger.error(f"Error during SEA query execution test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA query execution test with compression completed successfully") - - def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
-        )
+        logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.")
         sys.exit(1)
-
+    
     logger.info(f"Connecting to {server_hostname}")
     logger.info(f"HTTP Path: {http_path}")
     if catalog:
         logger.info(f"Using catalog: {catalog}")
-
+    
     try:
         logger.info("Creating connection with SEA backend...")
         connection = Connection(
@@ -130,33 +42,25 @@ def test_sea_session():
         access_token=access_token,
         catalog=catalog,
         schema="default",
-        use_sea=True,
-        user_agent_entry="SEA-Test-Client",  # add custom user agent
-    )
-
-    logger.info(
-        f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
+        use_sea=True, 
+        user_agent_entry="SEA-Test-Client" # add custom user agent
     )
+    
+    logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}")
     logger.info(f"backend type: {type(connection.session.backend)}")
-
+    
     # Close the connection
     logger.info("Closing the SEA session...")
     connection.close()
     logger.info("Successfully closed SEA session")
-
+    
 except Exception as e:
         logger.error(f"Error testing SEA session: {str(e)}")
         import traceback
-
         logger.error(traceback.format_exc())
         sys.exit(1)
-
+    
     logger.info("SEA session test completed successfully")
-
 if __name__ == "__main__":
-    # Test session management
     test_sea_session()
-
-    # Test query execution with compression
-    test_sea_query_exec()

From da5a6fe7511e927c511d61adb222b8a6a0da14d3 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 9 Jun 2025 06:33:02 +0000
Subject: [PATCH 003/105] add docstring

Signed-off-by: varun-edachali-dbx
---
 .../sql/backend/databricks_client.py          | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py
index bbca4c502..cd347d9ab 100644
--- a/src/databricks/sql/backend/databricks_client.py
+++ b/src/databricks/sql/backend/databricks_client.py
@@ -86,6 +86,33 @@ def execute_command(
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
     ) -> Union["ResultSet", None]:
+        """
+        Executes a SQL command or query within the specified session.
+        This method sends a SQL command to the server for execution and handles
+        the response. It can operate in both synchronous and asynchronous modes.
+
+        Args:
+            operation: The SQL command or query to execute
+            session_id: The session identifier in which to execute the command
+            max_rows: Maximum number of rows to fetch in a single fetch batch
+            max_bytes: Maximum number of bytes to fetch in a single fetch batch
+            lz4_compression: Whether to use LZ4 compression for result data
+            cursor: The cursor object that will handle the results
+            use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets
+            parameters: List of parameters to bind to the query
+            async_op: Whether to execute the command asynchronously
+            enforce_embedded_schema_correctness: Whether to enforce schema correctness
+
+        Returns:
+            If async_op is False, returns a ResultSet object containing the
+            query results and metadata. If async_op is True, returns None and the
+            results must be fetched later using get_execution_result().
+
+        Raises:
+            ValueError: If the session ID is invalid
+            OperationalError: If there's an error executing the command
+            ServerOperationError: If the server encounters an error during execution
+        """
         pass

     @abstractmethod

From 0dac4aaf90dba50151dd7565adee270a794e8330 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Mon, 9 Jun 2025 06:34:49 +0000
Subject: [PATCH 004/105] remove exec func in sea backend

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/sea/backend.py | 360 +++-------------------
 1 file changed, 35 insertions(+), 325 deletions(-)

diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index c7a4ed1b1..97d25a058 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -1,34 +1,23 @@
 import logging
 import re
-import uuid
-import time
-from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING
-
-from databricks.sql.backend.sea.utils.constants import (
-    ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
-)
+from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set

 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
-    from databricks.sql.result_set import ResultSet

 from databricks.sql.backend.databricks_client import DatabricksClient
 from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
-from databricks.sql.exc import Error, NotSupportedError, ServerOperationError
+from databricks.sql.exc import ServerOperationError
 from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
+from databricks.sql.backend.sea.utils.constants import (
+    ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
+)
 from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.types import SSLOptions

 from databricks.sql.backend.sea.models import (
-    ExecuteStatementRequest,
-    GetStatementRequest,
-    CancelStatementRequest,
-    CloseStatementRequest,
     CreateSessionRequest,
     DeleteSessionRequest,
-    StatementParameter,
-    ExecuteStatementResponse,
-    GetStatementResponse,
     CreateSessionResponse,
 )

@@ -66,9 +55,6 @@ def _filter_session_configuration(
 class SeaDatabricksClient(DatabricksClient):
     """
     Statement Execution API (SEA) implementation of the DatabricksClient interface.
-
-    This implementation provides session management functionality for SEA,
-    while other operations raise NotImplementedError.
     """

     # SEA API paths
@@ -288,222 +274,41 @@ def execute_command(
         lz4_compression: bool,
         cursor: "Cursor",
         use_cloud_fetch: bool,
-        parameters: List,
+        parameters: List[ttypes.TSparkParameter],
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
-    ) -> Union["ResultSet", None]:
-        """
-        Execute a SQL command using the SEA backend.
-
-        Args:
-            operation: SQL command to execute
-            session_id: Session identifier
-            max_rows: Maximum number of rows to fetch
-            max_bytes: Maximum number of bytes to fetch
-            lz4_compression: Whether to use LZ4 compression
-            cursor: Cursor executing the command
-            use_cloud_fetch: Whether to use cloud fetch
-            parameters: SQL parameters
-            async_op: Whether to execute asynchronously
-            enforce_embedded_schema_correctness: Whether to enforce schema correctness
-
-        Returns:
-            ResultSet: A SeaResultSet instance for the executed command
-        """
-        if session_id.backend_type != BackendType.SEA:
-            raise ValueError("Not a valid SEA session ID")
-
-        sea_session_id = session_id.to_sea_session_id()
-
-        # Convert parameters to StatementParameter objects
-        sea_parameters = []
-        if parameters:
-            for param in parameters:
-                sea_parameters.append(
-                    StatementParameter(
-                        name=param.name,
-                        value=param.value,
-                        type=param.type if hasattr(param, "type") else None,
-                    )
-                )
-
-        format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY"
-        disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE"
-        result_compression = "LZ4_FRAME" if lz4_compression else "NONE"
-
-        request = ExecuteStatementRequest(
-            warehouse_id=self.warehouse_id,
-            session_id=sea_session_id,
-            statement=operation,
-            disposition=disposition,
-            format=format,
-            wait_timeout="0s" if async_op else "10s",
-            on_wait_timeout="CONTINUE",
-            row_limit=max_rows if max_rows > 0 else None,
-            byte_limit=max_bytes if max_bytes > 0 else None,
-            parameters=sea_parameters if sea_parameters else None,
-            result_compression=result_compression,
-        )
-
-        response_data = self.http_client._make_request(
-            method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
+    ):
+        """Not implemented yet."""
+        raise NotImplementedError(
+            "execute_command is not yet implemented for SEA backend"
         )
-        response = ExecuteStatementResponse.from_dict(response_data)
-        statement_id = response.statement_id
-        if not statement_id:
-            raise ServerOperationError(
-                "Failed to execute command: No statement ID returned",
-                {
-                    "operation-id": None,
-                    "diagnostic-info": None,
-                },
-            )
-
-        command_id = CommandId.from_sea_statement_id(statement_id)
-
-        # Store the command ID in the cursor
-        cursor.active_command_id = command_id
-
-        # If async operation, return None and let the client poll for results
-        if async_op:
-            return None
-
-        # For synchronous operation, wait for the statement to complete
-        # Poll until the statement is done
-        status = response.status
-        state = status.state
-
-        # Keep polling until we reach a terminal state
-        while state in [CommandState.PENDING, CommandState.RUNNING]:
-            time.sleep(0.5)  # add a small delay to avoid excessive API calls
-            state = self.get_query_state(command_id)
-
-        if state != CommandState.SUCCEEDED:
-            raise ServerOperationError(
-                f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
-                {
-                    "operation-id": command_id.to_sea_statement_id(),
-                    "diagnostic-info": None,
-                },
-            )
-
-        return self.get_execution_result(command_id, cursor)

     def cancel_command(self, command_id: CommandId) -> None:
-        """
-        Cancel a running command.
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=cursor.connection, - sea_response=response_data, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -514,22 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -539,30 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -574,43 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - 
enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -622,33 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") From 1b794c7df6f5e414ef793a5da0f2b8ba19c9bc61 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:35:40 +0000 Subject: [PATCH 005/105] remove excess files Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 143 -------------- tests/unit/test_result_set_filter.py | 246 ----------------------- tests/unit/test_sea_result_set.py | 275 -------------------------- 3 files changed, 664 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 7f48b6179..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Client-side filtering utilities for Databricks SQL connector. - -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Callable, - TYPE_CHECKING, -) - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet - -from databricks.sql.result_set import SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. 
- - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data - return SeaResultSet( - connection=result_set.connection, - sea_response=filtered_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. - - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. - - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is typically in the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. 
- -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert 
"EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - 
arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index f666fd613..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,275 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response(self): - """Create a sample SEA response.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } - return mock_response - - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - 
buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - execute_response=execute_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == execute_response.sea_response - - def test_init_with_no_response(self, mock_connection, mock_sea_client): - """Test that initialization fails when neither response type is provided.""" - with pytest.raises(ValueError) as excinfo: - SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - assert "Either execute_response or sea_response must be provided" in str( - excinfo.value - ) - - def test_close(self, mock_connection, mock_sea_client, sea_response): - """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, sea_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, sea_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert 
result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, sea_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, sea_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() From da5a6fe7511e927c511d61adb222b8a6a0da14d3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:39:11 +0000 Subject: [PATCH 006/105] remove excess models Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/__init__.py | 30 ----- src/databricks/sql/backend/sea/models/base.py | 68 ----------- .../sql/backend/sea/models/requests.py | 110 +----------------- .../sql/backend/sea/models/responses.py | 95 +-------------- 4 files changed, 4 insertions(+), 299 deletions(-) delete mode 100644 src/databricks/sql/backend/sea/models/base.py diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b7c8bd399..c9310d367 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,49 +4,19 @@ This package contains data models for SEA API requests and responses. 
""" -from databricks.sql.backend.sea.models.base import ( - ServiceError, - StatementStatus, - ExternalLink, - ResultData, - ColumnInfo, - ResultManifest, -) - from databricks.sql.backend.sea.models.requests import ( - StatementParameter, - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) __all__ = [ - # Base models - "ServiceError", - "StatementStatus", - "ExternalLink", - "ResultData", - "ColumnInfo", - "ResultManifest", # Request models - "StatementParameter", - "ExecuteStatementRequest", - "GetStatementRequest", - "CancelStatementRequest", - "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models - "ExecuteStatementResponse", - "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py deleted file mode 100644 index 671f7be13..000000000 --- a/src/databricks/sql/backend/sea/models/base.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Base models for the SEA (Statement Execution API) backend. - -These models define the common structures used in SEA API requests and responses. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState - - -@dataclass -class ServiceError: - """Error information returned by the SEA API.""" - - message: str - error_code: Optional[str] = None - - -@dataclass -class StatementStatus: - """Status information for a statement execution.""" - - state: CommandState - error: Optional[ServiceError] = None - sql_state: Optional[str] = None - - -@dataclass -class ExternalLink: - """External link information for result data.""" - - external_link: str - expiration: str - chunk_index: int - - -@dataclass -class ResultData: - """Result data from a statement execution.""" - - data: Optional[List[List[Any]]] = None - external_links: Optional[List[ExternalLink]] = None - - -@dataclass -class ColumnInfo: - """Information about a column in the result set.""" - - name: str - type_name: str - type_text: str - nullable: bool = True - precision: Optional[int] = None - scale: Optional[int] = None - ordinal_position: Optional[int] = None - - -@dataclass -class ResultManifest: - """Manifest information for a result set.""" - - schema: List[ColumnInfo] - total_row_count: int - total_byte_count: int - truncated: bool = False - chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 1c519d931..7966cb502 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,111 +1,5 @@ -""" -Request models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API requests. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - - -@dataclass -class StatementParameter: - """Parameter for a SQL statement.""" - - name: str - value: Optional[str] = None - type: Optional[str] = None - - -@dataclass -class ExecuteStatementRequest: - """Request to execute a SQL statement.""" - - warehouse_id: str - statement: str - session_id: str - disposition: str = "EXTERNAL_LINKS" - format: str = "JSON_ARRAY" - wait_timeout: str = "10s" - on_wait_timeout: str = "CONTINUE" - row_limit: Optional[int] = None - byte_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - result: Dict[str, Any] = { - "warehouse_id": self.warehouse_id, - "session_id": self.session_id, - "statement": self.statement, - "disposition": self.disposition, - "format": self.format, - "wait_timeout": self.wait_timeout, - "on_wait_timeout": self.on_wait_timeout, - } - - if self.row_limit is not None and self.row_limit > 0: - result["row_limit"] = self.row_limit - - if self.byte_limit is not None and self.byte_limit > 0: - result["byte_limit"] = self.byte_limit - - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - - if self.result_compression: - result["result_compression"] = self.result_compression - - if self.parameters: - result["parameters"] = [ - { - "name": param.name, - **({"value": param.value} if param.value is not None else {}), - **({"type": param.type} if param.type is not None else {}), - } - for param in self.parameters - ] - - return result - - -@dataclass -class GetStatementRequest: - """Request to get information about a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CancelStatementRequest: - """Request to cancel a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CloseStatementRequest: - """Request to close a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} +from typing import Dict, Any, Optional +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d70459b9f..1bb54590f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,96 +1,5 @@ -""" -Response models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API responses. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState -from databricks.sql.backend.sea.models.base import ( - StatementStatus, - ResultManifest, - ResultData, - ServiceError, -) - - -@dataclass -class ExecuteStatementResponse: - """Response from executing a SQL statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": - """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) - - -@dataclass -class GetStatementResponse: - """Response from getting information about a statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": - """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) +from typing import Dict, Any +from dataclasses import dataclass @dataclass From 686ade4fbf8e43a053b61f27220066852682167e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:40:50 +0000 Subject: [PATCH 007/105] remove excess sea backend tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 755 ++++----------------------------- 1 file changed, 94 insertions(+), 661 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. 
-""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,650 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - 
mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = 
SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == 
sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with 
pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", 
return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) 
as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - 
"""Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 From 31e6c8305154e9c6384b422be35ac17b6f851e0c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:54:05 +0000 Subject: [PATCH 008/105] cleanup Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 8 +- src/databricks/sql/backend/types.py | 38 ++++---- src/databricks/sql/result_set.py | 91 ++++++++------------ 3 files changed, 65 insertions(+), 72 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e03d6f235..21a6befbe 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -913,7 +913,10 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - return CommandState.from_thrift_state(operation_state) + state = CommandState.from_thrift_state(operation_state) + if state is None: + raise ValueError(f"Invalid operation state: {operation_state}") + return state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -1175,7 +1178,6 @@ def _handle_execute_response(self, resp, cursor): execute_response, arrow_schema_bytes, ) = self._results_message_to_execute_response(resp, final_operation_state) - execute_response.command_id = command_id return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): @@ -1237,7 +1239,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..3107083fb 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -285,9 +285,6 @@ def __init__( backend_type: BackendType, guid: Any, secret: Optional[Any] = None, - operation_type: Optional[int] = None, - has_result_set: bool = False, - modified_row_count: Optional[int] = None, ): """ Initialize a CommandId. @@ -296,17 +293,34 @@ def __init__( backend_type: The type of backend (THRIFT or SEA) guid: The primary identifier for the command secret: The secret part of the identifier (only used for Thrift) - operation_type: The operation type (only used for Thrift) - has_result_set: Whether the command has a result set - modified_row_count: The number of rows modified by the command """ self.backend_type = backend_type self.guid = guid self.secret = secret - self.operation_type = operation_type - self.has_result_set = has_result_set - self.modified_row_count = modified_row_count + + + def __str__(self) -> str: + """ + Return a string representation of the CommandId. 
+ + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) @classmethod def from_thrift_handle(cls, operation_handle): @@ -329,9 +343,6 @@ def from_thrift_handle(cls, operation_handle): BackendType.THRIFT, guid_bytes, secret_bytes, - operation_handle.operationType, - operation_handle.hasResultSet, - operation_handle.modifiedRowCount, ) @classmethod @@ -364,9 +375,6 @@ def to_thrift_handle(self): handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) return ttypes.TOperationHandle( operationId=handle_identifier, - operationType=self.operation_type, - hasResultSet=self.has_result_set, - modifiedRowCount=self.modified_row_count, ) def to_sea_statement_id(self): diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index fc8595839..12ee129cf 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -5,6 +5,8 @@ import time import pandas +from databricks.sql.backend.sea.backend import SeaDatabricksClient + try: import pyarrow except ImportError: @@ -13,6 +15,7 @@ if TYPE_CHECKING: from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError @@ -31,21 +34,37 @@ class ResultSet(ABC): def __init__( self, - connection, - backend, + connection: "Connection", + backend: "DatabricksClient", arraysize: int, buffer_size_bytes: int, - command_id=None, - status=None, + command_id: CommandId, + status: CommandState, has_been_closed_server_side: bool = False, has_more_rows: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, ): - """Initialize the base ResultSet with common properties.""" + """ + A ResultSet manages the results of a single command. 
+ + Args: + connection: The parent connection + backend: The backend client + arraysize: The max number of rows to fetch at a time (PEP-249) + buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + command_id: The command ID + status: The command status + has_been_closed_server_side: Whether the command has been closed on the server + has_more_rows: Whether the command has more rows + results_queue: The results queue + description: column description of the results + is_staging_operation: Whether the command is a staging operation + """ + self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -240,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: @@ -387,12 +406,11 @@ class SeaResultSet(ResultSet): def __init__( self, - connection, - sea_client, + connection: "Connection", + execute_response: "ExecuteResponse", + sea_client: "SeaDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, - execute_response=None, - sea_response=None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -402,56 +420,21 @@ def __init__( sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) + execute_response: Response from the execute command """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - 
is_staging_operation=is_staging_operation, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, ) def _fill_results_buffer(self): From 69ea23811e03705998baba569bcda259a0646de5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:56:09 +0000 Subject: [PATCH 009/105] re-introduce get_schema_desc Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 1 - src/databricks/sql/result_set.py | 21 +++++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3107083fb..7a276c102 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -299,7 +299,6 @@ def __init__( self.guid = guid self.secret = secret - def __str__(self) -> str: """ Return a string representation of the CommandId. diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 12ee129cf..1fee995e5 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -59,12 +59,12 @@ def __init__( has_been_closed_server_side: Whether the command has been closed on the server has_more_rows: Whether the command has more rows results_queue: The results queue - description: column description of the results + description: column description of the results is_staging_operation: Whether the command is a staging operation """ self.connection = connection - self.backend = backend + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -400,6 +400,23 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) + @staticmethod + def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] + class SeaResultSet(ResultSet): """ResultSet implementation for SEA backend.""" From 66d75171991f9fcc98d541729a3127aea0d37a81 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:57:53 +0000 Subject: [PATCH 010/105] remove SeaResultSet Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 72 -------------------------------- 1 file changed, 72 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 1fee995e5..eaabcc186 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -416,75 +416,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query 
execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") From 71feef96b3a41889a5cd9313fc81910cebd7a084 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:01:22 +0000 Subject: [PATCH 011/105] clean imports and attributes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/databricks_client.py | 1 + src/databricks/sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/result_set.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index cd347d9ab..8fda71e1e 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -88,6 +88,7 @@ def execute_command( ) -> Union["ResultSet", None]: """ Executes a SQL command or query within the specified session. + This method sends a SQL command to the server for execution and handles the response. It can operate in both synchronous and asynchronous modes. 
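The docstring above captures the dual contract of execute_command: with
async_op=False the call blocks until completion and returns a ResultSet, while
with async_op=True it returns None, records the command ID on the cursor, and
leaves the caller to fetch results later via get_execution_result. A minimal
sketch of a caller driving both modes, using unittest.mock stand-ins for the
client, session, and cursor (this illustrates the interface only, not any
concrete backend):

from unittest.mock import MagicMock

client = MagicMock()      # stand-in for a DatabricksClient implementation
session_id = MagicMock()  # stand-in for a SessionId
cursor = MagicMock()      # stand-in for a Cursor

common_kwargs = dict(
    operation="SELECT 1",
    session_id=session_id,
    max_rows=100,
    max_bytes=1000,
    lz4_compression=False,
    cursor=cursor,
    use_cloud_fetch=False,
    parameters=[],
    enforce_embedded_schema_correctness=False,
)

# Synchronous mode: blocks until the statement completes and returns a ResultSet.
result_set = client.execute_command(**common_kwargs, async_op=False)

# Asynchronous mode: returns None immediately; the backend stores the command ID
# on cursor.active_command_id, and the results are fetched in a later call.
client.execute_command(**common_kwargs, async_op=True)
result_set = client.get_execution_result(cursor.active_command_id, cursor)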
diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index f0b931ee4..fe292919c 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, Union, List, Tuple +from typing import Callable, Dict, Any, Optional, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index eaabcc186..a33fc977d 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation From ae9862f90e7cf0a4949d6b1c7e04fdbae222c2d8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:05:53 +0000 Subject: [PATCH 012/105] pass CommandId to ExecResp Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 7 ++++++- src/databricks/sql/result_set.py | 10 +++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 21a6befbe..316cf24a0 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
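The "basic strategy" comment above compresses the retry loop in make_request
into three sentences: derive the number of available attempts, iterate over
them, and stop on success or final failure. Combined with the bound from the
docstring (stop once total elapsed time plus the next retry delay would exceed
the overall duration cap), the pattern can be sketched generically as follows.
This is an illustrative sketch only; the connector's actual loop additionally
decides which errors are retryable and computes per-attempt backoff:

import time

def call_with_retries(attempt_fn, max_attempts=5, delay_s=1.0, stop_after_s=30.0):
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    start = time.monotonic()
    last_exc = None
    for attempt in range(1, max_attempts + 1):
        try:
            return attempt_fn()  # success: stop retrying immediately
        except Exception as exc:  # a real client retries only retryable errors
            last_exc = exc
        elapsed = time.monotonic() - start
        # Give up early if attempts are exhausted or the next delay would
        # exceed the overall time budget.
        if attempt == max_attempts or elapsed + delay_s > stop_after_s:
            break
        time.sleep(delay_s)
    raise last_exc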
@@ -866,9 +867,13 @@ def get_execution_result( ssl_options=self._ssl_options, ) + status = CommandState.from_thrift_state(resp.status) + if status is None: + raise ValueError(f"Invalid operation state: {resp.status}") + execute_response = ExecuteResponse( command_id=command_id, - status=resp.status, + status=status, description=description, has_more_rows=has_more_rows, results_queue=queue, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a33fc977d..a0cb73732 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) From d8aa69e40438c33014e0d5afaec6a4175e64bea8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:08:04 +0000 Subject: [PATCH 013/105] remove changes in types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 57 +++++++++-------------------- 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 7a276c102..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -81,28 +80,6 @@ def from_thrift_state( else: return None - @classmethod - def from_sea_state(cls, state: str) -> Optional["CommandState"]: - """ - Map SEA state string to CommandState enum. 
- - Args: - state: SEA state string - - Returns: - CommandState: The corresponding CommandState enum value - """ - state_mapping = { - "PENDING": cls.PENDING, - "RUNNING": cls.RUNNING, - "SUCCEEDED": cls.SUCCEEDED, - "FAILED": cls.FAILED, - "CLOSED": cls.CLOSED, - "CANCELED": cls.CANCELLED, - } - - return state_mapping.get(state, None) - class BackendType(Enum): """ @@ -285,6 +262,9 @@ def __init__( backend_type: BackendType, guid: Any, secret: Optional[Any] = None, + operation_type: Optional[int] = None, + has_result_set: bool = False, + modified_row_count: Optional[int] = None, ): """ Initialize a CommandId. @@ -293,11 +273,17 @@ def __init__( backend_type: The type of backend (THRIFT or SEA) guid: The primary identifier for the command secret: The secret part of the identifier (only used for Thrift) + operation_type: The operation type (only used for Thrift) + has_result_set: Whether the command has a result set + modified_row_count: The number of rows modified by the command """ self.backend_type = backend_type self.guid = guid self.secret = secret + self.operation_type = operation_type + self.has_result_set = has_result_set + self.modified_row_count = modified_row_count def __str__(self) -> str: """ @@ -332,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -342,6 +329,9 @@ def from_thrift_handle(cls, operation_handle): BackendType.THRIFT, guid_bytes, secret_bytes, + operation_handle.operationType, + operation_handle.hasResultSet, + operation_handle.modifiedRowCount, ) @classmethod @@ -374,6 +364,9 @@ def to_thrift_handle(self): handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) return ttypes.TOperationHandle( operationId=handle_identifier, + operationType=self.operation_type, + hasResultSet=self.has_result_set, + modifiedRowCount=self.modified_row_count, ) def to_sea_statement_id(self): @@ -401,19 +394,3 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) - - -@dataclass -class ExecuteResponse: - """Response from executing a SQL command.""" - - command_id: CommandId - status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None - has_been_closed_server_side: bool = False - lz4_compressed: bool = True - is_staging_operation: bool = False From db139bc1179bb7cab6ec6f283cdfa0646b04b01b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:09:35 +0000 Subject: [PATCH 014/105] add back essential types (ExecResponse, from_sea_state) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 39 ++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..958eaa289 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,27 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. 
+ Args: + state: SEA state string + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + + class BackendType(Enum): """ @@ -394,3 +416,18 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False \ No newline at end of file From b977b1210a5d39543b8a3734128ba820e597337f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:11:23 +0000 Subject: [PATCH 015/105] fix fetch types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 4 ++-- src/databricks/sql/result_set.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 958eaa289..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -102,7 +102,6 @@ def from_sea_state(cls, state: str) -> Optional["CommandState"]: return state_mapping.get(state, None) - class BackendType(Enum): """ Enum representing the type of backend @@ -417,6 +416,7 @@ def to_hex_guid(self) -> str: else: return str(self.guid) + @dataclass class ExecuteResponse: """Response from executing a SQL command.""" @@ -430,4 +430,4 @@ class ExecuteResponse: results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True - is_staging_operation: bool = False \ No newline at end of file + is_staging_operation: bool = False diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0cb73732..e177d495f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass From da615c0db8ba2037c106b533331cf1ca1c9f49f9 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:12:45 +0000 Subject: [PATCH 016/105] excess imports Signed-off-by: varun-edachali-dbx --- tests/unit/test_session.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From 0da04a6f1086998927a28759fc67da4e2c8c71c6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:15:59 +0000 Subject: [PATCH 017/105] reduce 
diff by maintaining logs Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 316cf24a0..821559ad3 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -800,7 +800,7 @@ def _results_message_to_execute_response(self, resp, operation_state): status = CommandState.from_thrift_state(operation_state) if status is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return ( ExecuteResponse( From ea9d456ee9ca47434618a079698fa166b6c8a308 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 08:47:54 +0000 Subject: [PATCH 018/105] fix int test types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 4 +--- tests/e2e/common/retry_test_mixins.py | 2 +- tests/e2e/test_driver.py | 6 +++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 821559ad3..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -867,9 +867,7 @@ def get_execution_result( ssl_options=self._ssl_options, ) - status = CommandState.from_thrift_state(resp.status) - if status is None: - raise ValueError(f"Invalid operation state: {resp.status}") + status = self.get_query_state(command_id) execute_response = ExecuteResponse( command_id=command_id, diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b5d01a45d..dd509c062 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -326,7 +326,7 @@ def test_retry_abort_close_operation_on_404(self, caplog): with self.connection(extra_params={**self._retry_policy}) as conn: with conn.cursor() as curs: with patch( - "databricks.sql.utils.ExecuteResponse.has_been_closed_server_side", + "databricks.sql.backend.types.ExecuteResponse.has_been_closed_server_side", new_callable=PropertyMock, return_value=False, ): diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 22897644f..8cfed7c28 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -933,12 +933,12 @@ def test_result_set_close(self): result_set = cursor.active_result_set assert result_set is not None - initial_op_state = result_set.op_state + initial_op_state = result_set.status result_set.close() - assert result_set.op_state == CommandState.CLOSED - assert result_set.op_state != initial_op_state + assert result_set.status == CommandState.CLOSED + assert result_set.status != initial_op_state # Closing the result set again should be a no-op and not raise exceptions result_set.close() From 8985c624bcdbb7e0abfa73b7a1a2dbad15b4e1ec Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 08:55:24 +0000 Subject: [PATCH 019/105] [squashed from exec-sea] init execution func Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 ++- .../sql/backend/databricks_client.py | 28 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 360 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 110 ++- 
.../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/backend/types.py | 25 +- src/databricks/sql/result_set.py | 118 ++- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 275 +++++++ tests/unit/test_session.py | 5 + 15 files changed, 2166 insertions(+), 219 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..87b62efea 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,34 +6,122 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + +def test_sea_query_exec(): + """ + Test executing a query using the SEA backend with result compression. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with result compression enabled and disabled, + and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + sys.exit(1) + + try: + # Test with compression enabled + logger.info("Creating connection with LZ4 compression enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, # Enable cloud fetch to use compression + enable_query_result_lz4_compression=True, # Enable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"backend type: {type(connection.session.backend)}") + + # Execute a simple query with compression enabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query with compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression enabled") + + # Test with compression disabled + logger.info("Creating connection with LZ4 compression disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, # Enable cloud fetch + enable_query_result_lz4_compression=False, # Disable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query with compression disabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query without compression: SELECT 1 as test_value") + 
cursor.execute("SELECT 1 as test_value") + logger.info("Query without compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression disabled") + + except Exception as e: + logger.error(f"Error during SEA query execution test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA query execution test with compression completed successfully") + + def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -42,25 +130,33 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + use_sea=True, + user_agent_entry="SEA-Test-Client", # add custom user agent + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") + if __name__ == "__main__": + # Test session management test_sea_session() + + # Test query execution with compression + test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 8fda71e1e..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. 
- - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. 
+ + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. + + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..c7a4ed1b1 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ 
def _filter_session_configuration(
 class SeaDatabricksClient(DatabricksClient):
     """
     Statement Execution API (SEA) implementation of the DatabricksClient interface.
+
+    This implementation provides session management, statement execution,
+    and metadata operations for SEA.
     """
 
     # SEA API paths
@@ -274,41 +288,222 @@ def execute_command(
         lz4_compression: bool,
         cursor: "Cursor",
         use_cloud_fetch: bool,
-        parameters: List[ttypes.TSparkParameter],
+        parameters: List,
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
-    ):
-        """Not implemented yet."""
-        raise NotImplementedError(
-            "execute_command is not yet implemented for SEA backend"
+    ) -> Union["ResultSet", None]:
+        """
+        Execute a SQL command using the SEA backend.
+
+        Args:
+            operation: SQL command to execute
+            session_id: Session identifier
+            max_rows: Maximum number of rows to fetch
+            max_bytes: Maximum number of bytes to fetch
+            lz4_compression: Whether to use LZ4 compression
+            cursor: Cursor executing the command
+            use_cloud_fetch: Whether to use cloud fetch
+            parameters: SQL parameters
+            async_op: Whether to execute asynchronously
+            enforce_embedded_schema_correctness: Whether to enforce schema correctness
+
+        Returns:
+            ResultSet: A SeaResultSet instance for the executed command
+        """
+        if session_id.backend_type != BackendType.SEA:
+            raise ValueError("Not a valid SEA session ID")
+
+        sea_session_id = session_id.to_sea_session_id()
+
+        # Convert parameters to StatementParameter objects
+        sea_parameters = []
+        if parameters:
+            for param in parameters:
+                sea_parameters.append(
+                    StatementParameter(
+                        name=param.name,
+                        value=param.value,
+                        type=param.type if hasattr(param, "type") else None,
+                    )
+                )
+
+        format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY"
+        disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE"
+        result_compression = "LZ4_FRAME" if lz4_compression else "NONE"
+
+        request = ExecuteStatementRequest(
+            warehouse_id=self.warehouse_id,
+            session_id=sea_session_id,
+            statement=operation,
+            disposition=disposition,
+            format=format,
+            wait_timeout="0s" if async_op else "10s",
+            on_wait_timeout="CONTINUE",
+            row_limit=max_rows if max_rows > 0 else None,
+            byte_limit=max_bytes if max_bytes > 0 else None,
+            parameters=sea_parameters if sea_parameters else None,
+            result_compression=result_compression,
         )
+        response_data = self.http_client._make_request(
+            method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
+        )
+        response = ExecuteStatementResponse.from_dict(response_data)
+        statement_id = response.statement_id
+        if not statement_id:
+            raise ServerOperationError(
+                "Failed to execute command: No statement ID returned",
+                {
+                    "operation-id": None,
+                    "diagnostic-info": None,
+                },
+            )
+
+        command_id = CommandId.from_sea_statement_id(statement_id)
+
+        # Store the command ID in the cursor
+        cursor.active_command_id = command_id
+
+        # If async operation, return None and let the client poll for results
+        if async_op:
+            return None
+
+        # For synchronous operation, wait for the statement to complete
+        # Poll until the statement is done
+        status = response.status
+        state = status.state
+
+        # Keep polling until we reach a terminal state
+        while state in [CommandState.PENDING, CommandState.RUNNING]:
+            time.sleep(0.5)  # add a small delay to avoid excessive API calls
+            state = self.get_query_state(command_id)
+
+        if state != CommandState.SUCCEEDED:
+            raise ServerOperationError(
+                f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
+                {
+ 
"operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. + + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +514,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +539,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +574,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +622,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
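
To make the shapes concrete, here is how a SEA status payload maps onto the dataclasses this module defines below; a sketch with made-up values, assuming CommandState.from_sea_state maps "FAILED" to the corresponding enum member:

    from databricks.sql.backend.sea.models.base import ServiceError, StatementStatus
    from databricks.sql.backend.types import CommandState

    # Illustrative SEA status payload; values are made up.
    status_data = {
        "state": "FAILED",
        "error": {"message": "Syntax error", "error_code": "SYNTAX_ERROR"},
    }

    status = StatementStatus(
        state=CommandState.from_sea_state(status_data["state"]),
        error=ServiceError(
            message=status_data["error"]["message"],
            error_code=status_data["error"].get("error_code"),
        ),
    )
    assert status.error.error_code == "SYNTAX_ERROR"
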
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..1c519d931 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,111 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + byte_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.byte_limit is not None and self.byte_limit > 0: + result["byte_limit"] = self.byte_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..810c2e7a1 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
@@ -1242,7 +1241,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -85,8 +85,10 @@ def from_thrift_state( def from_sea_state(cls, state: str) -> Optional["CommandState"]: """ Map SEA state string to CommandState enum. + Args: state: SEA state string + Returns: CommandState: The corresponding CommandState enum value """ @@ -306,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". - - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -339,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index e177d495f..a4beda629 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( 
n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -403,16 +403,96 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + is_staging_operation=is_staging_operation, + ) - return [ - (column.name, map_col_type(column.datatype), None, None, 
None, None, None)
-            for column in table_schema_message.columns
-        ]
+    def _fill_results_buffer(self):
+        """Fill the results buffer from the backend."""
+        raise NotImplementedError(
+            "_fill_results_buffer is not implemented for SEA backend"
+        )
+
+    def fetchone(self) -> Optional[Row]:
+        """
+        Fetch the next row of a query result set, returning a single sequence,
+        or None when no more data is available.
+        """
+
+        raise NotImplementedError("fetchone is not implemented for SEA backend")
+
+    def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+        """
+        Fetch the next set of rows of a query result, returning a list of rows.
+
+        An empty sequence is returned when no more rows are available.
+        """
+
+        raise NotImplementedError("fetchmany is not implemented for SEA backend")
+
+    def fetchall(self) -> List[Row]:
+        """
+        Fetch all (remaining) rows of a query result, returning them as a list of rows.
+        """
+        raise NotImplementedError("fetchall is not implemented for SEA backend")
+
+    def fetchmany_arrow(self, size: int) -> Any:
+        """Fetch the next set of rows as an Arrow table."""
+        raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend")
+
+    def fetchall_arrow(self) -> Any:
+        """Fetch all remaining rows as an Arrow table."""
+        raise NotImplementedError("fetchall_arrow is not implemented for SEA backend")
diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py
new file mode 100644
index 000000000..e8eb2a757
--- /dev/null
+++ b/tests/unit/test_result_set_filter.py
@@ -0,0 +1,246 @@
+"""
+Tests for the ResultSetFilter class.
+
+This module contains tests for the ResultSetFilter class, which provides
+filtering capabilities for result sets returned by different backends.
+"""
+
+import pytest
+from unittest.mock import patch, MagicMock, Mock
+
+from databricks.sql.backend.filters import ResultSetFilter
+from databricks.sql.result_set import SeaResultSet
+from databricks.sql.backend.types import CommandId, CommandState, BackendType
+
+
+class TestResultSetFilter:
+    """Test suite for the ResultSetFilter class."""
+
+    @pytest.fixture
+    def mock_connection(self):
+        """Create a mock connection."""
+        connection = Mock()
+        connection.open = True
+        return connection
+
+    @pytest.fixture
+    def mock_sea_client(self):
+        """Create a mock SEA client."""
+        return Mock()
+
+    @pytest.fixture
+    def sea_response_with_tables(self):
+        """Create a sample SEA response with table data based on the server schema."""
+        return {
+            "statement_id": "test-statement-123",
+            "status": {"state": "SUCCEEDED"},
+            "manifest": {
+                "format": "JSON_ARRAY",
+                "schema": {
+                    "column_count": 7,
+                    "columns": [
+                        {"name": "namespace", "type_text": "STRING", "position": 0},
+                        {"name": "tableName", "type_text": "STRING", "position": 1},
+                        {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2},
+                        {"name": "information", "type_text": "STRING", "position": 3},
+                        {"name": "catalogName", "type_text": "STRING", "position": 4},
+                        {"name": "tableType", "type_text": "STRING", "position": 5},
+                        {"name": "remarks", "type_text": "STRING", "position": 6},
+                    ],
+                },
+                "total_row_count": 4,
+            },
+            "result": {
+                "data_array": [
+                    ["schema1", "table1", False, None, "catalog1", "TABLE", None],
+                    ["schema1", "table2", False, None, "catalog1", "VIEW", None],
+                    [
+                        "schema1",
+                        "table3",
+                        False,
+                        None,
+                        "catalog1",
+                        "SYSTEM TABLE",
+                        None,
+                    ],
+                    ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None],
+                ]
+            },
+        }
+
+    @pytest.fixture
+    def sea_result_set_with_tables(
+        self, mock_connection, 
mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + ) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + 
sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
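
The tests that follow share one pattern: stub the HTTP layer with a Mock, call the backend, then assert on the captured request. A minimal standalone sketch of that pattern; the path string here is a hypothetical placeholder, not the client's real STATEMENT_PATH:

    from unittest.mock import Mock

    http_client = Mock()
    http_client._make_request.return_value = {"statement_id": "test-statement-123"}

    # The code under test would issue this request; here we call the stub directly.
    http_client._make_request(method="POST", path="/statements", data={"statement": "SELECT 1"})

    # call_args captures the positional and keyword arguments of the last call.
    args, kwargs = http_client._make_request.call_args
    assert kwargs["method"] == "POST"
    assert kwargs["data"]["statement"] == "SELECT 1"
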
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..f666fd613 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def 
test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + 
NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From d9bcdbef396433e01b298fca9a27b1bce2b1414b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:03:13 +0000 Subject: [PATCH 020/105] remove irrelevant changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 +----- .../sql/backend/databricks_client.py | 30 ++ src/databricks/sql/backend/sea/backend.py | 360 ++---------------- .../sql/backend/sea/models/__init__.py | 30 -- src/databricks/sql/backend/sea/models/base.py | 68 ---- .../sql/backend/sea/models/requests.py | 110 +----- .../sql/backend/sea/models/responses.py | 95 +---- src/databricks/sql/backend/types.py | 64 ++-- 8 files changed, 107 insertions(+), 774 deletions(-) delete mode 100644 src/databricks/sql/backend/sea/models/base.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 87b62efea..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,122 +6,34 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) - -def test_sea_query_exec(): - """ - Test executing a query using the SEA backend with result compression. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with result compression enabled and disabled, - and verifies that execution completes successfully. 
- """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - sys.exit(1) - - try: - # Test with compression enabled - logger.info("Creating connection with LZ4 compression enabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, # Enable cloud fetch to use compression - enable_query_result_lz4_compression=True, # Enable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"backend type: {type(connection.session.backend)}") - - # Execute a simple query with compression enabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query with compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression enabled") - - # Test with compression disabled - logger.info("Creating connection with LZ4 compression disabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, # Enable cloud fetch - enable_query_result_lz4_compression=False, # Disable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query with compression disabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query without compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query without compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression disabled") - - except Exception as e: - logger.error(f"Error during SEA query execution test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA query execution test with compression completed successfully") - - def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -130,33 +42,25 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", # add custom user agent - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback - logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") - if __name__ == "__main__": - # Test session management test_sea_session() - - # Test query execution with compression - test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index bbca4c502..20b059fa7 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,6 +16,8 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState +from databricks.sql.utils import ExecuteResponse +from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -86,6 +88,34 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: + """ + Executes a SQL command or query within the specified session. + + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). 
+ + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ pass @abstractmethod diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c7a4ed1b1..97d25a058 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,34 +1,23 @@ import logging import re -import uuid -import time -from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING - -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, - StatementParameter, - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -66,9 +55,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. """ # SEA API paths @@ -288,222 +274,41 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List, + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: - """ - Execute a SQL command using the SEA backend. 
- - Args: - operation: SQL command to execute - session_id: Session identifier - max_rows: Maximum number of rows to fetch - max_bytes: Maximum number of bytes to fetch - lz4_compression: Whether to use LZ4 compression - cursor: Cursor executing the command - use_cloud_fetch: Whether to use cloud fetch - parameters: SQL parameters - async_op: Whether to execute asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - ResultSet: A SeaResultSet instance for the executed command - """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") - - sea_session_id = session_id.to_sea_session_id() - - # Convert parameters to StatementParameter objects - sea_parameters = [] - if parameters: - for param in parameters: - sea_parameters.append( - StatementParameter( - name=param.name, - value=param.value, - type=param.type if hasattr(param, "type") else None, - ) - ) - - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else "NONE" - - request = ExecuteStatementRequest( - warehouse_id=self.warehouse_id, - session_id=sea_session_id, - statement=operation, - disposition=disposition, - format=format, - wait_timeout="0s" if async_op else "10s", - on_wait_timeout="CONTINUE", - row_limit=max_rows if max_rows > 0 else None, - byte_limit=max_bytes if max_bytes > 0 else None, - parameters=sea_parameters if sea_parameters else None, - result_compression=result_compression, - ) - - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, - ) - - command_id = CommandId.from_sea_statement_id(statement_id) - - # Store the command ID in the cursor - cursor.active_command_id = command_id - - # If async operation, return None and let the client poll for results - if async_op: - return None - - # For synchronous operation, wait for the statement to complete - # Poll until the statement is done - status = response.status - state = status.state - - # Keep polling until we reach a terminal state - while state in [CommandState.PENDING, CommandState.RUNNING]: - time.sleep(0.5) # add a small delay to avoid excessive API calls - state = self.get_query_state(command_id) - - if state != CommandState.SUCCEEDED: - raise ServerOperationError( - f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", - { - "operation-id": command_id.to_sea_statement_id(), - "diagnostic-info": None, - }, - ) - - return self.get_execution_result(command_id, cursor) def cancel_command(self, command_id: CommandId) -> None: - """ - Cancel a running command. 
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=cursor.connection, - sea_response=response_data, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -514,22 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -539,30 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -574,43 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - 
enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -622,33 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b7c8bd399..c9310d367 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,49 +4,19 @@ This package contains data models for SEA API requests and responses. """ -from databricks.sql.backend.sea.models.base import ( - ServiceError, - StatementStatus, - ExternalLink, - ResultData, - ColumnInfo, - ResultManifest, -) - from databricks.sql.backend.sea.models.requests import ( - StatementParameter, - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) __all__ = [ - # Base models - "ServiceError", - "StatementStatus", - "ExternalLink", - "ResultData", - "ColumnInfo", - "ResultManifest", # Request models - "StatementParameter", - "ExecuteStatementRequest", - "GetStatementRequest", - "CancelStatementRequest", - "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models - "ExecuteStatementResponse", - "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py deleted file mode 100644 index 671f7be13..000000000 --- a/src/databricks/sql/backend/sea/models/base.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Base models for the SEA (Statement Execution API) backend. - -These models define the common structures used in SEA API requests and responses. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState - - -@dataclass -class ServiceError: - """Error information returned by the SEA API.""" - - message: str - error_code: Optional[str] = None - - -@dataclass -class StatementStatus: - """Status information for a statement execution.""" - - state: CommandState - error: Optional[ServiceError] = None - sql_state: Optional[str] = None - - -@dataclass -class ExternalLink: - """External link information for result data.""" - - external_link: str - expiration: str - chunk_index: int - - -@dataclass -class ResultData: - """Result data from a statement execution.""" - - data: Optional[List[List[Any]]] = None - external_links: Optional[List[ExternalLink]] = None - - -@dataclass -class ColumnInfo: - """Information about a column in the result set.""" - - name: str - type_name: str - type_text: str - nullable: bool = True - precision: Optional[int] = None - scale: Optional[int] = None - ordinal_position: Optional[int] = None - - -@dataclass -class ResultManifest: - """Manifest information for a result set.""" - - schema: List[ColumnInfo] - total_row_count: int - total_byte_count: int - truncated: bool = False - chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 1c519d931..7966cb502 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,111 +1,5 @@ -""" -Request models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API requests. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - - -@dataclass -class StatementParameter: - """Parameter for a SQL statement.""" - - name: str - value: Optional[str] = None - type: Optional[str] = None - - -@dataclass -class ExecuteStatementRequest: - """Request to execute a SQL statement.""" - - warehouse_id: str - statement: str - session_id: str - disposition: str = "EXTERNAL_LINKS" - format: str = "JSON_ARRAY" - wait_timeout: str = "10s" - on_wait_timeout: str = "CONTINUE" - row_limit: Optional[int] = None - byte_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - result: Dict[str, Any] = { - "warehouse_id": self.warehouse_id, - "session_id": self.session_id, - "statement": self.statement, - "disposition": self.disposition, - "format": self.format, - "wait_timeout": self.wait_timeout, - "on_wait_timeout": self.on_wait_timeout, - } - - if self.row_limit is not None and self.row_limit > 0: - result["row_limit"] = self.row_limit - - if self.byte_limit is not None and self.byte_limit > 0: - result["byte_limit"] = self.byte_limit - - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - - if self.result_compression: - result["result_compression"] = self.result_compression - - if self.parameters: - result["parameters"] = [ - { - "name": param.name, - **({"value": param.value} if param.value is not None else {}), - **({"type": param.type} if param.type is not None else {}), - } - for param in self.parameters - ] - - return result - - -@dataclass -class 
GetStatementRequest: - """Request to get information about a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CancelStatementRequest: - """Request to cancel a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CloseStatementRequest: - """Request to close a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} +from typing import Dict, Any, Optional +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d70459b9f..1bb54590f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,96 +1,5 @@ -""" -Response models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API responses. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState -from databricks.sql.backend.sea.models.base import ( - StatementStatus, - ResultManifest, - ResultData, - ServiceError, -) - - -@dataclass -class ExecuteStatementResponse: - """Response from executing a SQL statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": - """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) - - -@dataclass -class GetStatementResponse: - """Response from getting information about a statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": - """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - 
statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) +from typing import Dict, Any +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -81,28 +80,6 @@ def from_thrift_state( else: return None - @classmethod - def from_sea_state(cls, state: str) -> Optional["CommandState"]: - """ - Map SEA state string to CommandState enum. - - Args: - state: SEA state string - - Returns: - CommandState: The corresponding CommandState enum value - """ - state_mapping = { - "PENDING": cls.PENDING, - "RUNNING": cls.RUNNING, - "SUCCEEDED": cls.SUCCEEDED, - "FAILED": cls.FAILED, - "CLOSED": cls.CLOSED, - "CANCELED": cls.CANCELLED, - } - - return state_mapping.get(state, None) - class BackendType(Enum): """ @@ -308,6 +285,28 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count + def __str__(self) -> str: + """ + Return a string representation of the CommandId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) + @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -319,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -394,19 +394,3 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) - - -@dataclass -class ExecuteResponse: - """Response from executing a SQL command.""" - - command_id: CommandId - status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None - has_been_closed_server_side: bool = False - lz4_compressed: bool = True - is_staging_operation: bool = False From ee9fa1c972bad75557ac0671d5eef96c0a0cff21 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:03:59 +0000 Subject: [PATCH 021/105] remove ResultSetFilter functionality Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 143 --------------- tests/unit/test_result_set_filter.py | 246 -------------------------- 2 files changed, 389 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 7f48b6179..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Client-side filtering utilities for Databricks SQL connector. 
- -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Callable, - TYPE_CHECKING, -) - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet - -from databricks.sql.result_set import SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. - - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data - return SeaResultSet( - connection=result_set.connection, - sea_response=filtered_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. - - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. 
- - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is typically in the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. - -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert 
len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = 
copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 From 24c6152e9c2c003aa3074057c3d7d6e98d8d1916 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:06:23 +0000 Subject: [PATCH 022/105] remove more irrelevant changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 39 +- tests/unit/test_sea_backend.py | 755 ++++------------------------ 2 files changed, 132 insertions(+), 662 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,26 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. 
+ Args: + state: SEA state string + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -394,3 +415,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. -""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,650 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method 
- result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = 
sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # 
and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert 
result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def 
test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def 
test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == 
mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 From 67fd1012f9496724aa05183f82d9c92f0c40f1ed Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:10:48 +0000 Subject: [PATCH 023/105] remove more irrelevant changes Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 2 - src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/result_set.py | 91 +++++++++---------- 3 files changed, 44 insertions(+), 52 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..8fda71e1e 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 810c2e7a1..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
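
The comment above summarizes the retry strategy in make_request: derive a bounded
range of attempts, then iterate until success or final failure, also stopping early
when elapsed time plus the next retry delay would exceed
_retry_stop_after_attempts_duration. A minimal sketch of that shape — the names
call_with_retries, compute_delay, max_attempts and stop_after_duration are
hypothetical and not taken from the connector:

    import time

    def call_with_retries(request_fn, max_attempts, stop_after_duration, compute_delay):
        start = time.monotonic()
        for attempt in range(1, max_attempts + 1):
            try:
                return request_fn()
            except Exception:  # a real client would only retry retryable errors
                delay = compute_delay(attempt)  # hypothetical backoff helper
                elapsed = time.monotonic() - start
                # Give up on the last attempt, or when elapsed time plus the
                # next retry delay would exceed the overall duration budget.
                if attempt == max_attempts or elapsed + delay > stop_after_duration:
                    raise
                time.sleep(delay)
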
@@ -1241,7 +1242,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a4beda629..dd61408db 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -402,6 +402,33 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] + + +class SeaResultSet(ResultSet): + """ResultSet implementation for the SEA backend.""" + + def __init__( + self, + connection: "Connection", + execute_response: "ExecuteResponse", + 
sea_client: "SeaDatabricksClient", + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -413,53 +440,19 @@ def _get_schema_description(table_schema_message): execute_response: Response from the execute command (new style) sea_response: Direct SEA response (legacy style) """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - is_staging_operation=is_staging_operation, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, ) def _fill_results_buffer(self): From 271fcafbb04e7c5e08423b7536dac57f9595c5b6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:12:13 +0000 Subject: [PATCH 024/105] even more irrelevant changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- tests/unit/test_session.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index dd61408db..7ea4976c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from 
unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From bf26ea3e4dae441d0e82d1f55c3da36ee2282568 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:19:46 +0000 Subject: [PATCH 025/105] remove sea response as init option Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 103 ++++-------------------------- 1 file changed, 14 insertions(+), 89 deletions(-) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index f666fd613..02421a915 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -27,38 +27,6 @@ def mock_sea_client(self): """Create a mock SEA client.""" return Mock() - @pytest.fixture - def sea_response(self): - """Create a sample SEA response.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - @pytest.fixture def execute_response(self): """Create a sample execute response.""" @@ -72,78 +40,35 @@ def execute_response(self): ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } return mock_response - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): """Test initializing SeaResultSet with an execute response.""" result_set = SeaResultSet( connection=mock_connection, - sea_client=mock_sea_client, execute_response=execute_response, + sea_client=mock_sea_client, buffer_size_bytes=1000, arraysize=100, ) # Verify basic properties - assert result_set.statement_id == "test-statement-123" + assert result_set.command_id == execute_response.command_id assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA assert result_set.connection == mock_connection assert result_set.backend == mock_sea_client assert result_set.buffer_size_bytes == 1000 assert result_set.arraysize 
== 100 - assert result_set._response == execute_response.sea_response + assert result_set.description == execute_response.description - def test_init_with_no_response(self, mock_connection, mock_sea_client): - """Test that initialization fails when neither response type is provided.""" - with pytest.raises(ValueError) as excinfo: - SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - assert "Either execute_response or sea_response must be provided" in str( - excinfo.value - ) - - def test_close(self, mock_connection, mock_sea_client, sea_response): + def test_close(self, mock_connection, mock_sea_client, execute_response): """Test closing a result set.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -157,13 +82,13 @@ def test_close(self, mock_connection, mock_sea_client, sea_response): assert result_set.status == CommandState.CLOSED def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test closing a result set that has already been closed server-side.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -178,14 +103,14 @@ def test_close_when_already_closed_server_side( assert result_set.status == CommandState.CLOSED def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test closing a result set when the connection is closed.""" mock_connection.open = False result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -199,13 +124,13 @@ def test_close_when_connection_closed( assert result_set.status == CommandState.CLOSED def test_unimplemented_methods( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test that unimplemented methods raise NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -258,13 +183,13 @@ def test_unimplemented_methods( pass def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test that _fill_results_buffer raises NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -272,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() + result_set._fill_results_buffer() \ No newline at end of file From ed7cf9138e937774546fa0f3e793a6eb8768060a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 10:06:36 +0000 Subject: [PATCH 026/105] exec test example scripts Signed-off-by: varun-edachali-dbx --- 
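This commit swaps the single-session example for a runner that drives four
standalone test scripts. One way to invoke it, sketched with placeholder values —
the environment variable names come from the scripts themselves, while the
subprocess wrapper and the placeholder values are assumptions for illustration:

    import os
    import subprocess

    env = {
        **os.environ,
        "DATABRICKS_SERVER_HOSTNAME": "<workspace-hostname>",  # placeholder
        "DATABRICKS_HTTP_PATH": "<sql-endpoint-http-path>",    # placeholder
        "DATABRICKS_TOKEN": "<personal-access-token>",         # placeholder
        "DATABRICKS_CATALOG": "<catalog>",  # needed by the metadata test
    }
    # The runner exits non-zero if any individual test module fails.
    subprocess.run(
        ["python", "examples/experimental/sea_connector_test.py"], env=env, check=True
    )
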
examples/experimental/sea_connector_test.py | 147 ++++++++++------ examples/experimental/tests/__init__.py | 1 + .../tests/test_sea_async_query.py | 165 ++++++++++++++++++ .../experimental/tests/test_sea_metadata.py | 91 ++++++++++ .../experimental/tests/test_sea_session.py | 70 ++++++++ .../experimental/tests/test_sea_sync_query.py | 143 +++++++++++++++ 6 files changed, 566 insertions(+), 51 deletions(-) create mode 100644 examples/experimental/tests/__init__.py create mode 100644 examples/experimental/tests/test_sea_async_query.py create mode 100644 examples/experimental/tests/test_sea_metadata.py create mode 100644 examples/experimental/tests/test_sea_session.py create mode 100644 examples/experimental/tests/test_sea_sync_query.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..33b5af334 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,66 +1,111 @@ +""" +Main script to run all SEA connector tests. + +This script imports and runs all the individual test modules and displays +a summary of test results with visual indicators. +""" import os import sys import logging -from databricks.sql.client import Connection +import importlib.util +from typing import Dict, Callable, List, Tuple -logging.basicConfig(level=logging.DEBUG) +# Configure logging +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. +# Define test modules and their main test functions +TEST_MODULES = [ + "test_sea_session", + "test_sea_sync_query", + "test_sea_async_query", + "test_sea_metadata", +] + +def load_test_function(module_name: str) -> Callable: + """Load a test function from a module.""" + module_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "tests", + f"{module_name}.py" + ) + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Get the main test function (assuming it starts with "test_") + for name in dir(module): + if name.startswith("test_") and callable(getattr(module, name)): + # For sync and async query modules, we want the main function that runs both tests + if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": + return getattr(module, name) - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. 
+ # Fallback to the first test function found + for name in dir(module): + if name.startswith("test_") and callable(getattr(module, name)): + return getattr(module, name) - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ + raise ValueError(f"No test function found in module {module_name}") - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") +def run_tests() -> List[Tuple[str, bool]]: + """Run all tests and return results.""" + results = [] - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") - sys.exit(1) + for module_name in TEST_MODULES: + try: + test_func = load_test_function(module_name) + logger.info(f"\n{'=' * 50}") + logger.info(f"Running test: {module_name}") + logger.info(f"{'-' * 50}") + + success = test_func() + results.append((module_name, success)) + + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"Test {module_name}: {status}") + + except Exception as e: + logger.error(f"Error loading or running test {module_name}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + results.append((module_name, False)) - logger.info(f"Connecting to {server_hostname}") - logger.info(f"HTTP Path: {http_path}") - if catalog: - logger.info(f"Using catalog: {catalog}") + return results + +def print_summary(results: List[Tuple[str, bool]]) -> None: + """Print a summary of test results.""" + logger.info(f"\n{'=' * 50}") + logger.info("TEST SUMMARY") + logger.info(f"{'-' * 50}") - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent - ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") - logger.info(f"backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - sys.exit(1) + passed = sum(1 for _, success in results if success) + total = len(results) - logger.info("SEA session test completed successfully") + for module_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"{status} - {module_name}") + + logger.info(f"{'-' * 50}") + logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") + logger.info(f"{'=' * 50}") if __name__ == "__main__": - test_sea_session() + # Check if required environment variables are set + required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] + missing_vars = [var for var in required_vars if not os.environ.get(var)] + + if missing_vars: + logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error("Please set these variables before running the tests.") + sys.exit(1) + + # Run 
all tests + results = run_tests() + + # Print summary + print_summary(results) + + # Exit with appropriate status code + all_passed = all(success for _, success in results) + sys.exit(0 if all_passed else 1) \ No newline at end of file diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py new file mode 100644 index 000000000..5e1a8a58b --- /dev/null +++ b/examples/experimental/tests/__init__.py @@ -0,0 +1 @@ +# This file makes the tests directory a Python package \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py new file mode 100644 index 000000000..a4f3702f9 --- /dev/null +++ b/examples/experimental/tests/test_sea_async_query.py @@ -0,0 +1,165 @@ +""" +Test for SEA asynchronous query execution functionality. +""" +import os +import sys +import logging +import time +from databricks.sql.client import Connection +from databricks.sql.backend.types import CommandState + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_async_query_with_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info("Creating connection for asynchronous query execution with cloud fetch enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info("Executing asynchronous query with cloud fetch: SELECT 1 as test_value") + cursor.execute_async("SELECT 1 as test_value") + logger.info("Asynchronous query submitted successfully with cloud fetch enabled") + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info("Successfully retrieved asynchronous query results with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_without_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. 
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info("Creating connection for asynchronous query execution with cloud fetch disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info("Executing asynchronous query without cloud fetch: SELECT 1 as test_value") + cursor.execute_async("SELECT 1 as test_value") + logger.info("Asynchronous query submitted successfully with cloud fetch disabled") + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info("Successfully retrieved asynchronous query results with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_exec(): + """ + Run both asynchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() + logger.info(f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") + + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() + logger.info(f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_async_query_exec() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py new file mode 100644 index 000000000..ba760b61a --- /dev/null +++ b/examples/experimental/tests/test_sea_metadata.py @@ -0,0 +1,91 @@ +""" +Test for SEA metadata functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_metadata(): + """ + Test metadata operations using the SEA backend. 
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + and executes metadata operations like catalogs(), schemas(), tables(), and columns(). + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + if not catalog: + logger.error("DATABRICKS_CATALOG environment variable is required for metadata tests.") + return False + + try: + # Create connection + logger.info("Creating connection for metadata operations") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Test catalogs + cursor = connection.cursor() + logger.info("Fetching catalogs...") + cursor.catalogs() + logger.info("Successfully fetched catalogs") + + # Test schemas + logger.info(f"Fetching schemas for catalog '{catalog}'...") + cursor.schemas(catalog_name=catalog) + logger.info("Successfully fetched schemas") + + # Test tables + logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") + cursor.tables(catalog_name=catalog, schema_name="default") + logger.info("Successfully fetched tables") + + # Test columns for a specific table + # Using a common table that should exist in most environments + logger.info(f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'...") + cursor.columns(catalog_name=catalog, schema_name="default", table_name="information_schema") + logger.info("Successfully fetched columns") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA metadata test: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_metadata() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py new file mode 100644 index 000000000..c0f6817da --- /dev/null +++ b/examples/experimental/tests/test_sea_session.py @@ -0,0 +1,70 @@ +""" +Test for SEA session management functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"Backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_session() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py new file mode 100644 index 000000000..4879e587a --- /dev/null +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -0,0 +1,143 @@ +""" +Test for SEA synchronous query execution functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_sync_query_with_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info("Creating connection for synchronous query execution with cloud fetch enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info("Executing synchronous query with cloud fetch: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_without_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info("Creating connection for synchronous query execution with cloud fetch disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info("Executing synchronous query without cloud fetch: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_exec(): + """ + Run both synchronous query tests and return overall success. 
+ """ + with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() + logger.info(f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") + + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() + logger.info(f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_sync_query_exec() + sys.exit(0 if success else 1) \ No newline at end of file From dae15e37b6161740481084c405aeff84278c73cd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 10:10:23 +0000 Subject: [PATCH 027/105] formatting (black) Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 53 ++++++++------ examples/experimental/tests/__init__.py | 1 - .../tests/test_sea_async_query.py | 72 +++++++++++++------ .../experimental/tests/test_sea_metadata.py | 27 ++++--- .../experimental/tests/test_sea_session.py | 5 +- .../experimental/tests/test_sea_sync_query.py | 48 +++++++++---- 6 files changed, 133 insertions(+), 73 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 33b5af334..b03f8ff64 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -22,90 +22,99 @@ "test_sea_metadata", ] + def load_test_function(module_name: str) -> Callable: """Load a test function from a module.""" module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "tests", - f"{module_name}.py" + os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" ) - + spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Get the main test function (assuming it starts with "test_") for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): # For sync and async query modules, we want the main function that runs both tests if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": return getattr(module, name) - + # Fallback to the first test function found for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): return getattr(module, name) - + raise ValueError(f"No test function found in module {module_name}") + def run_tests() -> List[Tuple[str, bool]]: """Run all tests and return results.""" results = [] - + for module_name in TEST_MODULES: try: test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - + success = test_func() results.append((module_name, success)) - + status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"Test {module_name}: {status}") - + except Exception as e: logger.error(f"Error loading or running test {module_name}: {str(e)}") import traceback + logger.error(traceback.format_exc()) results.append((module_name, False)) - + return results + def print_summary(results: List[Tuple[str, bool]]) -> None: """Print a summary of test results.""" logger.info(f"\n{'=' * 50}") logger.info("TEST SUMMARY") logger.info(f"{'-' * 50}") - + passed = sum(1 for _, success in results if success) total = len(results) - + for module_name, success in results: status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"{status} - {module_name}") - + logger.info(f"{'-' * 
50}") logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") logger.info(f"{'=' * 50}") + if __name__ == "__main__": # Check if required environment variables are set - required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] + required_vars = [ + "DATABRICKS_SERVER_HOSTNAME", + "DATABRICKS_HTTP_PATH", + "DATABRICKS_TOKEN", + ] missing_vars = [var for var in required_vars if not os.environ.get(var)] - + if missing_vars: - logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error( + f"Missing required environment variables: {', '.join(missing_vars)}" + ) logger.error("Please set these variables before running the tests.") sys.exit(1) - + # Run all tests results = run_tests() - + # Print summary print_summary(results) - + # Exit with appropriate status code all_passed = all(success for _, success in results) - sys.exit(0 if all_passed else 1) \ No newline at end of file + sys.exit(0 if all_passed else 1) diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py index 5e1a8a58b..e69de29bb 100644 --- a/examples/experimental/tests/__init__.py +++ b/examples/experimental/tests/__init__.py @@ -1 +0,0 @@ -# This file makes the tests directory a Python package \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index a4f3702f9..a776377c3 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -33,7 +33,9 @@ def test_sea_async_query_with_cloud_fetch(): try: # Create connection with cloud fetch enabled - logger.info("Creating connection for asynchronous query execution with cloud fetch enabled") + logger.info( + "Creating connection for asynchronous query execution with cloud fetch enabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -51,30 +53,39 @@ def test_sea_async_query_with_cloud_fetch(): # Execute a simple query asynchronously cursor = connection.cursor() - logger.info("Executing asynchronous query with cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + ) cursor.execute_async("SELECT 1 as test_value") - logger.info("Asynchronous query submitted successfully with cloud fetch enabled") - + logger.info( + "Asynchronous query submitted successfully with cloud fetch enabled" + ) + # Check query state logger.info("Checking query state...") while cursor.is_query_pending(): logger.info("Query is still pending, waiting...") time.sleep(1) - + logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() - logger.info("Successfully retrieved asynchronous query results with cloud fetch enabled") - + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch enabled" + ) + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}") + logger.error( + f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -100,7 +111,9 @@ def test_sea_async_query_without_cloud_fetch(): try: # Create connection with cloud fetch disabled - logger.info("Creating connection for 
asynchronous query execution with cloud fetch disabled") + logger.info( + "Creating connection for asynchronous query execution with cloud fetch disabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -119,30 +132,39 @@ def test_sea_async_query_without_cloud_fetch(): # Execute a simple query asynchronously cursor = connection.cursor() - logger.info("Executing asynchronous query without cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + ) cursor.execute_async("SELECT 1 as test_value") - logger.info("Asynchronous query submitted successfully with cloud fetch disabled") - + logger.info( + "Asynchronous query submitted successfully with cloud fetch disabled" + ) + # Check query state logger.info("Checking query state...") while cursor.is_query_pending(): logger.info("Query is still pending, waiting...") time.sleep(1) - + logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() - logger.info("Successfully retrieved asynchronous query results with cloud fetch disabled") - + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch disabled" + ) + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}") + logger.error( + f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -152,14 +174,18 @@ def test_sea_async_query_exec(): Run both asynchronous query tests and return overall success. """ with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() - logger.info(f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() - logger.info(f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + return with_cloud_fetch_success and without_cloud_fetch_success if __name__ == "__main__": success = test_sea_async_query_exec() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index ba760b61a..c715e5984 100644 --- a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -28,9 +28,11 @@ def test_sea_metadata(): "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." ) return False - + if not catalog: - logger.error("DATABRICKS_CATALOG environment variable is required for metadata tests.") + logger.error( + "DATABRICKS_CATALOG environment variable is required for metadata tests." 
+ ) return False try: @@ -55,37 +57,42 @@ def test_sea_metadata(): logger.info("Fetching catalogs...") cursor.catalogs() logger.info("Successfully fetched catalogs") - + # Test schemas logger.info(f"Fetching schemas for catalog '{catalog}'...") cursor.schemas(catalog_name=catalog) logger.info("Successfully fetched schemas") - + # Test tables logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") cursor.tables(catalog_name=catalog, schema_name="default") logger.info("Successfully fetched tables") - + # Test columns for a specific table # Using a common table that should exist in most environments - logger.info(f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'...") - cursor.columns(catalog_name=catalog, schema_name="default", table_name="information_schema") + logger.info( + f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." + ) + cursor.columns( + catalog_name=catalog, schema_name="default", table_name="information_schema" + ) logger.info("Successfully fetched columns") - + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: logger.error(f"Error during SEA metadata test: {str(e)}") import traceback + logger.error(traceback.format_exc()) return False if __name__ == "__main__": success = test_sea_metadata() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py index c0f6817da..516c1bbb8 100644 --- a/examples/experimental/tests/test_sea_session.py +++ b/examples/experimental/tests/test_sea_session.py @@ -55,16 +55,17 @@ def test_sea_session(): logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) return False if __name__ == "__main__": success = test_sea_session() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 4879e587a..07be8aafc 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -31,7 +31,9 @@ def test_sea_sync_query_with_cloud_fetch(): try: # Create connection with cloud fetch enabled - logger.info("Creating connection for synchronous query execution with cloud fetch enabled") + logger.info( + "Creating connection for synchronous query execution with cloud fetch enabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -49,20 +51,25 @@ def test_sea_sync_query_with_cloud_fetch(): # Execute a simple query cursor = connection.cursor() - logger.info("Executing synchronous query with cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + ) cursor.execute("SELECT 1 as test_value") logger.info("Query executed successfully with cloud fetch enabled") - + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}") + logger.error( + f"Error during SEA 
synchronous query execution test with cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -88,7 +95,9 @@ def test_sea_sync_query_without_cloud_fetch(): try: # Create connection with cloud fetch disabled - logger.info("Creating connection for synchronous query execution with cloud fetch disabled") + logger.info( + "Creating connection for synchronous query execution with cloud fetch disabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -107,20 +116,25 @@ def test_sea_sync_query_without_cloud_fetch(): # Execute a simple query cursor = connection.cursor() - logger.info("Executing synchronous query without cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + ) cursor.execute("SELECT 1 as test_value") logger.info("Query executed successfully with cloud fetch disabled") - + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}") + logger.error( + f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -130,14 +144,18 @@ def test_sea_sync_query_exec(): Run both synchronous query tests and return overall success. """ with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() - logger.info(f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() - logger.info(f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + return with_cloud_fetch_success and without_cloud_fetch_success if __name__ == "__main__": success = test_sea_sync_query_exec() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) From db5bbea88eabcde2d0b86811391297baf8471c70 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 10:35:08 +0000 Subject: [PATCH 028/105] [squashed from sea-exec] merge sea stuffs Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 51 +- examples/experimental/tests/__init__.py | 1 + .../sql/backend/databricks_client.py | 28 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 359 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 106 ++- .../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/backend/types.py | 25 +- src/databricks/sql/result_set.py | 92 ++- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 30 +- tests/unit/test_session.py | 5 + 16 files changed, 1805 insertions(+), 232 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py diff 
--git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index b03f8ff64..128bc1aa1 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -22,99 +22,90 @@ "test_sea_metadata", ] - def load_test_function(module_name: str) -> Callable: """Load a test function from a module.""" module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" + os.path.dirname(os.path.abspath(__file__)), + "tests", + f"{module_name}.py" ) - + spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Get the main test function (assuming it starts with "test_") for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): # For sync and async query modules, we want the main function that runs both tests if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": return getattr(module, name) - + # Fallback to the first test function found for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): return getattr(module, name) - + raise ValueError(f"No test function found in module {module_name}") - def run_tests() -> List[Tuple[str, bool]]: """Run all tests and return results.""" results = [] - + for module_name in TEST_MODULES: try: test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - + success = test_func() results.append((module_name, success)) - + status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"Test {module_name}: {status}") - + except Exception as e: logger.error(f"Error loading or running test {module_name}: {str(e)}") import traceback - logger.error(traceback.format_exc()) results.append((module_name, False)) - + return results - def print_summary(results: List[Tuple[str, bool]]) -> None: """Print a summary of test results.""" logger.info(f"\n{'=' * 50}") logger.info("TEST SUMMARY") logger.info(f"{'-' * 50}") - + passed = sum(1 for _, success in results if success) total = len(results) - + for module_name, success in results: status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"{status} - {module_name}") - + logger.info(f"{'-' * 50}") logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") logger.info(f"{'=' * 50}") - if __name__ == "__main__": # Check if required environment variables are set - required_vars = [ - "DATABRICKS_SERVER_HOSTNAME", - "DATABRICKS_HTTP_PATH", - "DATABRICKS_TOKEN", - ] + required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] missing_vars = [var for var in required_vars if not os.environ.get(var)] - + if missing_vars: - logger.error( - f"Missing required environment variables: {', '.join(missing_vars)}" - ) + logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") logger.error("Please set these variables before running the tests.") sys.exit(1) - + # Run all tests results = run_tests() - + # Print summary print_summary(results) - + # Exit with appropriate status code all_passed = all(success for _, success in results) sys.exit(0 if all_passed else 1) diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py index e69de29bb..5e1a8a58b 100644 --- a/examples/experimental/tests/__init__.py +++ b/examples/experimental/tests/__init__.py @@ -0,0 +1 @@ +# This file 
makes the tests directory a Python package \ No newline at end of file diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 8fda71e1e..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. - - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. 
+ + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..10100e86e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,221 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. 
+ + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else None + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. 
+ + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +513,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +538,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +573,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +621,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..e26b32e0a 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,107 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..810c2e7a1 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
@@ -1242,7 +1241,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -85,8 +85,10 @@ def from_thrift_state( def from_sea_state(cls, state: str) -> Optional["CommandState"]: """ Map SEA state string to CommandState enum. + Args: state: SEA state string + Returns: CommandState: The corresponding CommandState enum value """ @@ -306,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". - - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -339,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 7ea4976c1..2d4f3f346 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( 
n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -403,14 +403,76 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + is_staging_operation=is_staging_operation, + ) + + def _fill_results_buffer(self): + """Fill the results buffer from the 
backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") return [ (column.name, map_col_type(column.datatype), None, None, None, None, None) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + ) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in 
table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + 
arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 02421a915..072b597a8 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -40,8 +40,36 @@ def execute_response(self): ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } return mock_response + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -197,4 +225,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() \ No newline at end of file + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From d5d3699cea5c5e67a48c5e789ebdd66964f1e975 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:44:58 +0000 Subject: [PATCH 029/105] remove excess changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 51 ++++++++++++--------- examples/experimental/tests/__init__.py | 1 - 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 128bc1aa1..b03f8ff64 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -22,90 +22,99 @@ "test_sea_metadata", ] + def load_test_function(module_name: str) -> Callable: """Load a test function from a module.""" module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "tests", - f"{module_name}.py" + os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" ) - 
+ spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Get the main test function (assuming it starts with "test_") for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): # For sync and async query modules, we want the main function that runs both tests if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": return getattr(module, name) - + # Fallback to the first test function found for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): return getattr(module, name) - + raise ValueError(f"No test function found in module {module_name}") + def run_tests() -> List[Tuple[str, bool]]: """Run all tests and return results.""" results = [] - + for module_name in TEST_MODULES: try: test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - + success = test_func() results.append((module_name, success)) - + status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"Test {module_name}: {status}") - + except Exception as e: logger.error(f"Error loading or running test {module_name}: {str(e)}") import traceback + logger.error(traceback.format_exc()) results.append((module_name, False)) - + return results + def print_summary(results: List[Tuple[str, bool]]) -> None: """Print a summary of test results.""" logger.info(f"\n{'=' * 50}") logger.info("TEST SUMMARY") logger.info(f"{'-' * 50}") - + passed = sum(1 for _, success in results if success) total = len(results) - + for module_name, success in results: status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"{status} - {module_name}") - + logger.info(f"{'-' * 50}") logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") logger.info(f"{'=' * 50}") + if __name__ == "__main__": # Check if required environment variables are set - required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] + required_vars = [ + "DATABRICKS_SERVER_HOSTNAME", + "DATABRICKS_HTTP_PATH", + "DATABRICKS_TOKEN", + ] missing_vars = [var for var in required_vars if not os.environ.get(var)] - + if missing_vars: - logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error( + f"Missing required environment variables: {', '.join(missing_vars)}" + ) logger.error("Please set these variables before running the tests.") sys.exit(1) - + # Run all tests results = run_tests() - + # Print summary print_summary(results) - + # Exit with appropriate status code all_passed = all(success for _, success in results) sys.exit(0 if all_passed else 1) diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py index 5e1a8a58b..e69de29bb 100644 --- a/examples/experimental/tests/__init__.py +++ b/examples/experimental/tests/__init__.py @@ -1 +0,0 @@ -# This file makes the tests directory a Python package \ No newline at end of file From 6137a3dca8ea8d0c2105a175b99f45e77fa25f5b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:47:07 +0000 Subject: [PATCH 030/105] remove excess removed docstring Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index bbca4c502..8fda71e1e 100644 --- 
a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,6 +86,34 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: + """ + Executes a SQL command or query within the specified session. + + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ pass @abstractmethod From 75b077320c196104e47af149b379ebc4e95463e3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:48:33 +0000 Subject: [PATCH 031/105] remove excess changes in backend Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 3 ++- src/databricks/sql/backend/types.py | 25 ++++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 810c2e7a1..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. @@ -1241,7 +1242,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -85,10 +85,8 @@ def from_thrift_state( def from_sea_state(cls, state: str) -> Optional["CommandState"]: """ Map SEA state string to CommandState enum. 
- Args: state: SEA state string - Returns: CommandState: The corresponding CommandState enum value """ @@ -308,6 +306,28 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count + def __str__(self) -> str: + """ + Return a string representation of the CommandId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) + @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -319,6 +339,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None From 4494dcda4a503e6138e5761bc6155114d840be86 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:50:56 +0000 Subject: [PATCH 032/105] remove excess imports Signed-off-by: varun-edachali-dbx --- tests/unit/test_session.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From 4d0aeca0a2e9d887274cbdbd19c6f471f1a381a8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:53:52 +0000 Subject: [PATCH 033/105] remove accidentally removed _get_schema_desc Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 74 +++----------------------------- 1 file changed, 6 insertions(+), 68 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 2d4f3f346..e0b0289e6 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -403,76 +403,14 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): """ - Initialize a SeaResultSet with the response from a SEA query execution. 
- - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) - """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - - # Call parent constructor with common attributes - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - is_staging_operation=is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. 
+ Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 """ - raise NotImplementedError("fetchone is not implemented for SEA backend") + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ return [ (column.name, map_col_type(column.datatype), None, None, None, None, None) From 7cece5e0870cd31943e72c86888d98ed4e09c17c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:56:24 +0000 Subject: [PATCH 034/105] remove unnecessary init with sea_response tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 072b597a8..02421a915 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -40,36 +40,8 @@ def execute_response(self): ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } return mock_response - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -225,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() + result_set._fill_results_buffer() \ No newline at end of file From 8977c06a27a68ae7c144a482e32c7bee1e18eaa3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:57:58 +0000 Subject: [PATCH 035/105] rmeove unnecessary changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index e0b0289e6..7ea4976c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" 
pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) From 0216d7ac6de96ece431f8bdd0d31c0acb1c28324 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 13:07:04 +0000 Subject: [PATCH 036/105] formatting (black) Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 02421a915..b691872af 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -197,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() \ No newline at end of file + result_set._fill_results_buffer() From 4cb15fdaa8318b046f2ac082edb10679e7c7a501 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 13:47:34 +0000 Subject: [PATCH 037/105] improved models and filters from cloudfetch-sea branch Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 61 +++++--- src/databricks/sql/backend/sea/models/base.py | 13 +- .../sql/backend/sea/models/requests.py | 16 +- .../sql/backend/sea/models/responses.py | 146 ++++++++++++++++-- 4 files changed, 187 insertions(+), 49 deletions(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 7f48b6179..32fa78be4 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -9,14 +9,20 @@ List, Optional, Any, + Dict, 
Callable, + TypeVar, + Generic, + cast, TYPE_CHECKING, ) -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet +from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.types import ExecuteResponse, CommandId +from databricks.sql.backend.sea.models.base import ResultData -from databricks.sql.result_set import SeaResultSet +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet, SeaResultSet logger = logging.getLogger(__name__) @@ -43,26 +49,35 @@ def _filter_sea_result_set( Returns: A filtered SEA result set """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data + # Get all remaining rows + original_index = result_set.results.cur_row_index + result_set.results.cur_row_index = 0 # Reset to beginning + all_rows = result_set.results.remaining_rows() + + # Filter rows + filtered_rows = [row for row in all_rows if filter_func(row)] + + # Import SeaResultSet here to avoid circular imports + from databricks.sql.result_set import SeaResultSet + + # Reuse the command_id from the original result set + command_id = result_set.command_id + + # Create an ExecuteResponse with the filtered data + execute_response = ExecuteResponse( + command_id=command_id, + status=result_set.status, + description=result_set.description, + has_more_rows=result_set._has_more_rows, + results_queue=JsonQueue(filtered_rows), + has_been_closed_server_side=result_set.has_been_closed_server_side, + lz4_compressed=False, + is_staging_operation=False, + ) + return SeaResultSet( connection=result_set.connection, - sea_response=filtered_response, + execute_response=execute_response, sea_client=result_set.backend, buffer_size_bytes=result_set.buffer_size_bytes, arraysize=result_set.arraysize, @@ -92,6 +107,8 @@ def filter_by_column_values( allowed_values = [v.upper() for v in allowed_values] # Determine the type of result set and apply appropriate filtering + from databricks.sql.result_set import SeaResultSet + if isinstance(result_set, SeaResultSet): return ResultSetFilter._filter_sea_result_set( result_set, @@ -137,7 +154,7 @@ def filter_tables_by_type( table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES ) - # Table type is typically in the 6th column (index 5) + # Table type is the 6th column (index 5) return ResultSetFilter.filter_by_column_values( result_set, 5, valid_types, case_sensitive=False ) diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index 671f7be13..6175b4ca0 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -34,6 +34,12 @@ class ExternalLink: external_link: str expiration: str chunk_index: int + byte_count: int = 0 + row_count: int = 0 + row_offset: int = 0 + next_chunk_index: Optional[int] = None + next_chunk_internal_link: Optional[str] = None + http_headers: Optional[Dict[str, str]] = None @dataclass @@ -61,8 +67,11 @@ class ColumnInfo: class 
ResultManifest: """Manifest information for a result set.""" - schema: List[ColumnInfo] + format: str + schema: Dict[str, Any] # Will contain column information total_row_count: int total_byte_count: int + total_chunk_count: int truncated: bool = False - chunk_count: Optional[int] = None + chunks: Optional[List[Dict[str, Any]]] = None + result_compression: Optional[str] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index e26b32e0a..58921d793 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -21,18 +21,16 @@ class StatementParameter: class ExecuteStatementRequest: """Request to execute a SQL statement.""" - warehouse_id: str - statement: str session_id: str + statement: str + warehouse_id: str disposition: str = "EXTERNAL_LINKS" format: str = "JSON_ARRAY" + result_compression: Optional[str] = None + parameters: Optional[List[StatementParameter]] = None wait_timeout: str = "10s" on_wait_timeout: str = "CONTINUE" row_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert the request to a dictionary for JSON serialization.""" @@ -49,12 +47,6 @@ def to_dict(self) -> Dict[str, Any]: if self.row_limit is not None and self.row_limit > 0: result["row_limit"] = self.row_limit - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - if self.result_compression: result["result_compression"] = self.result_compression diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d70459b9f..6b5067506 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -13,6 +13,8 @@ ResultManifest, ResultData, ServiceError, + ExternalLink, + ColumnInfo, ) @@ -37,20 +39,62 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": error_code=error_data.get("error_code"), ) - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") status = StatementStatus( - state=state, + state=CommandState.from_sea_state(status_data.get("state", "")), error=error, sql_state=status_data.get("sql_state"), ) + # Parse manifest + manifest = None + if "manifest" in data: + manifest_data = data["manifest"] + manifest = ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + # Parse result data + result = None + if "result" in data: + result_data = data["result"] + external_links = None + + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + 
row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers"), + ) + ) + + result = ResultData( + data=result_data.get("data_array"), + external_links=external_links, + ) + return cls( statement_id=data.get("statement_id", ""), status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed + manifest=manifest, + result=result, ) @@ -75,21 +119,62 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": error_code=error_data.get("error_code"), ) - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, + state=CommandState.from_sea_state(status_data.get("state", "")), error=error, sql_state=status_data.get("sql_state"), ) + # Parse manifest + manifest = None + if "manifest" in data: + manifest_data = data["manifest"] + manifest = ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + # Parse result data + result = None + if "result" in data: + result_data = data["result"] + external_links = None + + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers"), + ) + ) + + result = ResultData( + data=result_data.get("data_array"), + external_links=external_links, + ) + return cls( statement_id=data.get("statement_id", ""), status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed + manifest=manifest, + result=result, ) @@ -103,3 +188,38 @@ class CreateSessionResponse: def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse": """Create a CreateSessionResponse from a dictionary.""" return cls(session_id=data.get("session_id", "")) + + +@dataclass +class GetChunksResponse: + """Response from getting chunks for a statement.""" + + statement_id: str + external_links: List[ExternalLink] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse": + """Create a GetChunksResponse from a dictionary.""" + external_links = [] + if "external_links" in data: + for link_data in data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 
0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers"), + ) + ) + + return cls( + statement_id=data.get("statement_id", ""), + external_links=external_links, + ) From dee47f7f4558a8c7336c86bbd5a20bda3f4a9787 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 03:45:23 +0000 Subject: [PATCH 038/105] filters stuff (align with JDBC) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 18 ++++++------------ src/databricks/sql/backend/sea/backend.py | 3 --- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 32fa78be4..9fa0a5535 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -49,32 +49,26 @@ def _filter_sea_result_set( Returns: A filtered SEA result set """ - # Get all remaining rows - original_index = result_set.results.cur_row_index - result_set.results.cur_row_index = 0 # Reset to beginning + # Get all remaining rows from the current position (JDBC-aligned behavior) + # Note: This will only filter rows that haven't been read yet all_rows = result_set.results.remaining_rows() # Filter rows filtered_rows = [row for row in all_rows if filter_func(row)] - # Import SeaResultSet here to avoid circular imports - from databricks.sql.result_set import SeaResultSet - - # Reuse the command_id from the original result set - command_id = result_set.command_id - - # Create an ExecuteResponse with the filtered data execute_response = ExecuteResponse( - command_id=command_id, + command_id=result_set.command_id, status=result_set.status, description=result_set.description, - has_more_rows=result_set._has_more_rows, + has_more_rows=result_set.has_more_rows, results_queue=JsonQueue(filtered_rows), has_been_closed_server_side=result_set.has_been_closed_server_side, lz4_compressed=False, is_staging_operation=False, ) + from databricks.sql.result_set import SeaResultSet + return SeaResultSet( connection=result_set.connection, execute_response=execute_response, diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 10100e86e..a54337f0c 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -66,9 +66,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. 
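A short usage sketch of the JDBC-aligned filtering from the filters.py hunk above; because the filter now starts from `remaining_rows()`, rows that have already been fetched are not revisited. The `sea_result_set` object here is a hypothetical SeaResultSet:

    # Keep only rows whose table-type column matches, case-insensitively.
    # Column index 5 holds the table type, per the comment in the diff.
    filtered = ResultSetFilter.filter_by_column_values(
        sea_result_set,
        column_index=5,
        allowed_values=["TABLE", "VIEW"],
        case_sensitive=False,
    )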
""" # SEA API paths From e385d5b8b6f9be36183e763286f3406ca6c5c144 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 04:49:37 +0000 Subject: [PATCH 039/105] backend from cloudfetch-sea Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 375 +++++++++++------- .../sql/backend/sea/models/responses.py | 12 +- .../sql/backend/sea/utils/http_client.py | 2 +- 3 files changed, 233 insertions(+), 156 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index a54337f0c..c1f21448b 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,8 +1,8 @@ import logging -import re import uuid import time -from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING +import re +from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set from databricks.sql.backend.sea.utils.constants import ( ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, @@ -11,13 +11,26 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor from databricks.sql.result_set import ResultSet + from databricks.sql.backend.sea.models.responses import GetChunksResponse from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import ( + SessionId, + CommandId, + CommandState, + BackendType, + ExecuteResponse, +) from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions +from databricks.sql.utils import SeaResultSetQueueFactory +from databricks.sql.backend.sea.models.base import ( + ResultData, + ExternalLink, + ResultManifest, +) from databricks.sql.backend.sea.models import ( ExecuteStatementRequest, @@ -66,6 +79,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -75,6 +91,8 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH = BASE_PATH + "statements" STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks" + CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" def __init__( self, @@ -107,6 +125,7 @@ def __init__( ) self._max_download_threads = kwargs.get("max_download_threads", 10) + self.ssl_options = ssl_options # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -263,6 +282,19 @@ def get_default_session_configuration_value(name: str) -> Optional[str]: """ return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) + @staticmethod + def is_session_configuration_parameter_supported(name: str) -> bool: + """ + Check if a session configuration parameter is supported. 
+ + Args: + name: The name of the session configuration parameter + + Returns: + True if the parameter is supported, False otherwise + """ + return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP + @staticmethod def get_allowed_session_configurations() -> List[str]: """ @@ -273,8 +305,182 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - # == Not Implemented Operations == - # These methods will be implemented in future iterations + def _get_schema_bytes(self, sea_response) -> Optional[bytes]: + """ + Extract schema bytes from the SEA response. + + For ARROW format, we need to get the schema bytes from the first chunk. + If the first chunk is not available, we need to get it from the server. + + Args: + sea_response: The response from the SEA API + + Returns: + bytes: The schema bytes or None if not available + """ + import requests + import lz4.frame + + # Check if we have the first chunk in the response + result_data = sea_response.get("result", {}) + external_links = result_data.get("external_links", []) + + if not external_links: + return None + + # Find the first chunk (chunk_index = 0) + first_chunk = None + for link in external_links: + if link.get("chunk_index") == 0: + first_chunk = link + break + + if not first_chunk: + # Try to fetch the first chunk from the server + statement_id = sea_response.get("statement_id") + if not statement_id: + return None + + chunks_response = self.get_chunk_links(statement_id, 0) + if not chunks_response.external_links: + return None + + first_chunk = chunks_response.external_links[0].__dict__ + + # Download the first chunk to get the schema bytes + external_link = first_chunk.get("external_link") + http_headers = first_chunk.get("http_headers", {}) + + if not external_link: + return None + + # Use requests to download the first chunk + http_response = requests.get( + external_link, + headers=http_headers, + verify=self.ssl_options.tls_verify, + ) + + if http_response.status_code != 200: + raise Error(f"Failed to download schema bytes: {http_response.text}") + + # Extract schema bytes from the Arrow file + # The schema is at the beginning of the file + data = http_response.content + if sea_response.get("manifest", {}).get("result_compression") == "LZ4_FRAME": + data = lz4.frame.decompress(data) + + # Return the schema bytes + return data + + def _results_message_to_execute_response(self, sea_response, command_id): + """ + Convert a SEA response to an ExecuteResponse and extract result data. 
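For context on `_get_schema_bytes` above: the downloaded chunk is an Arrow IPC stream whose schema precedes the record batches, which is why the raw (decompressed) bytes can serve as schema bytes. A minimal sketch of reading such a chunk, assuming `pyarrow` and `lz4` are installed and reusing names from the surrounding code:

    import lz4.frame
    import pyarrow as pa

    data = http_response.content  # bytes of the first result chunk
    if manifest_data.get("result_compression") == "LZ4_FRAME":
        data = lz4.frame.decompress(data)  # frame decompression, as in the diff
    reader = pa.ipc.open_stream(data)  # an IPC stream begins with its schema
    schema = reader.schema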
+ + Args: + sea_response: The response from the SEA API + command_id: The command ID + + Returns: + tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response, + result data object, and manifest object + """ + # Extract status + status_data = sea_response.get("status", {}) + state = CommandState.from_sea_state(status_data.get("state", "")) + + # Extract description from manifest + description = None + manifest_data = sea_response.get("manifest", {}) + schema_data = manifest_data.get("schema", {}) + columns_data = schema_data.get("columns", []) + + if columns_data: + columns = [] + for col_data in columns_data: + if not isinstance(col_data, dict): + continue + + # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) + columns.append( + ( + col_data.get("name", ""), # name + col_data.get("type_name", ""), # type_code + None, # display_size (not provided by SEA) + None, # internal_size (not provided by SEA) + col_data.get("precision"), # precision + col_data.get("scale"), # scale + col_data.get("nullable", True), # null_ok + ) + ) + description = columns if columns else None + + # Extract schema bytes for Arrow format + schema_bytes = None + format = manifest_data.get("format") + if format == "ARROW_STREAM": + # For ARROW format, we need to get the schema bytes + schema_bytes = self._get_schema_bytes(sea_response) + + # Check for compression + lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" + + # Initialize result_data_obj and manifest_obj + result_data_obj = None + manifest_obj = None + + result_data = sea_response.get("result", {}) + if result_data: + # Convert external links + external_links = None + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers", {}), + ) + ) + + # Create the result data object + result_data_obj = ResultData( + data=result_data.get("data_array"), external_links=external_links + ) + + # Create the manifest object + manifest_obj = ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + execute_response = ExecuteResponse( + command_id=command_id, + status=state, + description=description, + has_been_closed_server_side=False, + lz4_compressed=lz4_compressed, + is_staging_operation=False, + arrow_schema_bytes=schema_bytes, + result_format=manifest_data.get("format"), + ) + + return execute_response, result_data_obj, manifest_obj def execute_command( self, @@ -336,7 +542,7 @@ def execute_command( format=format, wait_timeout="0s" if async_op else "10s", on_wait_timeout="CONTINUE", - row_limit=max_rows if max_rows > 0 else None, + row_limit=max_rows, parameters=sea_parameters if 
sea_parameters else None, result_compression=result_compression, ) @@ -494,157 +700,20 @@ def get_execution_result( # Create and return a SeaResultSet from databricks.sql.result_set import SeaResultSet + # Convert the response to an ExecuteResponse and extract result data + ( + execute_response, + result_data, + manifest, + ) = self._results_message_to_execute_response(response_data, command_id) + return SeaResultSet( connection=cursor.connection, - sea_response=response_data, + execute_response=execute_response, sea_client=self, buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, + result_data=result_data, + manifest=manifest, ) - # == Metadata Operations == - - def get_catalogs( - self, - session_id: SessionId, - max_rows: int, - max_bytes: int, - cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result - - def get_schemas( - self, - session_id: SessionId, - max_rows: int, - max_bytes: int, - cursor: "Cursor", - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result - - def get_tables( - self, - session_id: SessionId, - max_rows: int, - max_bytes: int, - cursor: "Cursor", - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - table_name: Optional[str] = None, - table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result - - def get_columns( - self, - session_id: SessionId, - max_rows: int, - max_bytes: int, - cursor: "Cursor", - catalog_name: Optional[str] = None, - 
schema_name: Optional[str] = None, - table_name: Optional[str] = None, - column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 6b5067506..d684a9c67 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -39,8 +39,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": error_code=error_data.get("error_code"), ) + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( - state=CommandState.from_sea_state(status_data.get("state", "")), + state=state, error=error, sql_state=status_data.get("sql_state"), ) @@ -119,8 +123,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": error_code=error_data.get("error_code"), ) + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( - state=CommandState.from_sea_state(status_data.get("state", "")), + state=state, error=error, sql_state=status_data.get("sql_state"), ) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index fe292919c..f0b931ee4 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, List, Tuple +from typing import Callable, Dict, Any, Optional, Union, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider From 484064ef8cd24e2f6c5cf9ec268d2cfb5597ea4d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 04:51:22 +0000 Subject: [PATCH 040/105] remove filtering, metadata ops Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 154 ----------- src/databricks/sql/backend/sea/backend.py | 1 - tests/unit/test_result_set_filter.py | 246 ------------------ tests/unit/test_sea_backend.py | 302 ---------------------- 4 files changed, 703 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 9fa0a5535..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,154 +0,0 @@ -""" -Client-side filtering utilities 
for Databricks SQL connector. - -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Dict, - Callable, - TypeVar, - Generic, - cast, - TYPE_CHECKING, -) - -from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory -from databricks.sql.backend.types import ExecuteResponse, CommandId -from databricks.sql.backend.sea.models.base import ResultData - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet, SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. - - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Get all remaining rows from the current position (JDBC-aligned behavior) - # Note: This will only filter rows that haven't been read yet - all_rows = result_set.results.remaining_rows() - - # Filter rows - filtered_rows = [row for row in all_rows if filter_func(row)] - - execute_response = ExecuteResponse( - command_id=result_set.command_id, - status=result_set.status, - description=result_set.description, - has_more_rows=result_set.has_more_rows, - results_queue=JsonQueue(filtered_rows), - has_been_closed_server_side=result_set.has_been_closed_server_side, - lz4_compressed=False, - is_staging_operation=False, - ) - - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=result_set.connection, - execute_response=execute_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. 
- - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - from databricks.sql.result_set import SeaResultSet - - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. - - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c1f21448b..80066ae82 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -716,4 +716,3 @@ def get_execution_result( result_data=result_data, manifest=manifest, ) - diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. - -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. 
-""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def 
test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - 
filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..2fa362b8e 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -546,305 +546,3 @@ def test_get_execution_result( assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( "test-statement-123" ) - - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command 
to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right 
parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is 
False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) From 030edf8df3db487b7af8d910ee51240d1339229e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 04:55:56 +0000 Subject: [PATCH 041/105] raise NotImplementedErrror for metadata ops Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 57 +++++++++++++++++++++-- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 80066ae82..b1ad7cf76 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,8 +1,7 @@ import logging -import uuid import time import re -from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set +from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set from databricks.sql.backend.sea.utils.constants import ( ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, @@ -23,9 +22,7 @@ ) from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions -from databricks.sql.utils import SeaResultSetQueueFactory from databricks.sql.backend.sea.models.base import ( ResultData, ExternalLink, @@ -716,3 +713,55 @@ def get_execution_result( result_data=result_data, manifest=manifest, ) + + # == Metadata Operations == + + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + raise NotImplementedError("get_catalogs is not implemented for SEA backend") + + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + raise NotImplementedError("get_schemas is not implemented for SEA backend") + + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + raise NotImplementedError("get_tables is not implemented for SEA backend") + + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + raise NotImplementedError("get_columns is not implemented for SEA backend") From 30f82666804d0104bb419836def6b56b5dda3f8e Mon 
Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:10:50 +0000 Subject: [PATCH 042/105] add metadata commands Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 167 ++++++++++++++++++++++ src/databricks/sql/backend/sea/backend.py | 103 ++++++++++++- tests/unit/test_filters.py | 120 ++++++++++++++++ 3 files changed, 386 insertions(+), 4 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 tests/unit/test_filters.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..2c0105aee --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,167 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Dict, + Callable, + TypeVar, + Generic, + cast, + TYPE_CHECKING, +) + +from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.types import ExecuteResponse, CommandId +from databricks.sql.backend.sea.models.base import ResultData +from databricks.sql.backend.sea.backend import SeaDatabricksClient + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet, SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. 
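+
+        The filter drains the remaining rows from the result set's queue,
+        applies filter_func to each row, and rebuilds a fresh SeaResultSet
+        around the surviving rows, reusing the original command_id and
+        result metadata.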
+ + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Get all remaining rows + all_rows = result_set.results.remaining_rows() + + # Filter rows + filtered_rows = [row for row in all_rows if filter_func(row)] + + # Import SeaResultSet here to avoid circular imports + from databricks.sql.result_set import SeaResultSet + + # Reuse the command_id from the original result set + command_id = result_set.command_id + + # Create an ExecuteResponse with the filtered data + execute_response = ExecuteResponse( + command_id=command_id, + status=result_set.status, + description=result_set.description, + has_been_closed_server_side=result_set.has_been_closed_server_side, + lz4_compressed=result_set.lz4_compressed, + arrow_schema_bytes=result_set._arrow_schema_bytes, + is_staging_operation=False, + ) + + # Create a new ResultData object with filtered data + from databricks.sql.backend.sea.models.base import ResultData + + result_data = ResultData(data=filtered_rows, external_links=None) + + # Create a new SeaResultSet with the filtered data + filtered_result_set = SeaResultSet( + connection=result_set.connection, + execute_response=execute_response, + sea_client=cast(SeaDatabricksClient, result_set.backend), + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + result_data=result_data, + ) + + return filtered_result_set + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + from databricks.sql.result_set import SeaResultSet + + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
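+
+        For example, table_types=["TABLE", "VIEW"] drops a row whose type
+        column reads "SYSTEM TABLE", while passing no table_types keeps the
+        default set of TABLE, VIEW and SYSTEM TABLE.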
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=True + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index b1ad7cf76..2807975cd 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -724,7 +724,20 @@ def get_catalogs( cursor: "Cursor", ) -> "ResultSet": """Get available catalogs by executing 'SHOW CATALOGS'.""" - raise NotImplementedError("get_catalogs is not implemented for SEA backend") + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -736,7 +749,28 @@ def get_schemas( schema_name: Optional[str] = None, ) -> "ResultSet": """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - raise NotImplementedError("get_schemas is not implemented for SEA backend") + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -750,7 +784,41 @@ def get_tables( table_types: Optional[List[str]] = None, ) -> "ResultSet": """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_tables is not implemented for SEA backend") + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -764,4 +832,31 @@ def get_columns( column_name: Optional[str] = None, ) -> "ResultSet": 
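+        # The operation string grows one clause per supplied pattern, mirroring
+        # get_schemas and get_tables above; with all four arguments it takes
+        # the form, e.g.:
+        #   SHOW COLUMNS IN CATALOG `cat` SCHEMA LIKE 'sch' TABLE LIKE 'tbl' LIKE 'col'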
"""Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_columns is not implemented for SEA backend") + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result \ No newline at end of file diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py new file mode 100644 index 000000000..49bd1c328 --- /dev/null +++ b/tests/unit/test_filters.py @@ -0,0 +1,120 @@ +""" +Tests for the ResultSetFilter class. +""" + +import unittest +from unittest.mock import MagicMock, patch +import sys +from typing import List, Dict, Any + +# Add the necessary path to import the filter module +sys.path.append("/home/varun.edachali/conn/databricks-sql-python/src") + +from databricks.sql.backend.filters import ResultSetFilter + + +class TestResultSetFilter(unittest.TestCase): + """Tests for the ResultSetFilter class.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a mock SeaResultSet + self.mock_sea_result_set = MagicMock() + self.mock_sea_result_set._response = { + "result": { + "data_array": [ + ["catalog1", "schema1", "table1", "TABLE", ""], + ["catalog1", "schema1", "table2", "VIEW", ""], + ["catalog1", "schema1", "table3", "SYSTEM TABLE", ""], + ["catalog1", "schema1", "table4", "EXTERNAL TABLE", ""], + ], + "row_count": 4, + } + } + + # Set up the connection and other required attributes + self.mock_sea_result_set.connection = MagicMock() + self.mock_sea_result_set.backend = MagicMock() + self.mock_sea_result_set.buffer_size_bytes = 1000 + self.mock_sea_result_set.arraysize = 100 + self.mock_sea_result_set.statement_id = "test-statement-id" + + # Create a mock CommandId + from databricks.sql.backend.types import CommandId, BackendType + + mock_command_id = CommandId(BackendType.SEA, "test-statement-id") + self.mock_sea_result_set.command_id = mock_command_id + + self.mock_sea_result_set.status = MagicMock() + self.mock_sea_result_set.description = [ + ("catalog_name", "string", None, None, None, None, True), + ("schema_name", "string", None, None, None, None, True), + ("table_name", "string", None, None, None, None, True), + ("table_type", "string", None, None, None, None, True), + ("remarks", "string", None, None, None, None, True), + ] + self.mock_sea_result_set.has_been_closed_server_side = False + + def test_filter_tables_by_type(self): + """Test filtering tables by type.""" + # Test with specific table types + table_types = ["TABLE", "VIEW"] + + # Make the mock_sea_result_set appear to be a SeaResultSet + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + # Set up the mock to return a new mock when instantiated + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + result = 
ResultSetFilter.filter_tables_by_type( + self.mock_sea_result_set, table_types + ) + + # Verify the filter was applied correctly + mock_sea_result_set_class.assert_called_once() + + def test_filter_tables_by_type_case_insensitive(self): + """Test filtering tables by type with case insensitivity.""" + # Test with lowercase table types + table_types = ["table", "view"] + + # Make the mock_sea_result_set appear to be a SeaResultSet + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + # Set up the mock to return a new mock when instantiated + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + result = ResultSetFilter.filter_tables_by_type( + self.mock_sea_result_set, table_types + ) + + # Verify the filter was applied correctly + mock_sea_result_set_class.assert_called_once() + + def test_filter_tables_by_type_default(self): + """Test filtering tables by type with default types.""" + # Make the mock_sea_result_set appear to be a SeaResultSet + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + # Set up the mock to return a new mock when instantiated + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + result = ResultSetFilter.filter_tables_by_type( + self.mock_sea_result_set, None + ) + + # Verify the filter was applied correctly + mock_sea_result_set_class.assert_called_once() + + +if __name__ == "__main__": + unittest.main() From 033ae73440dad3295ac097da5809eff4563be7b0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:12:04 +0000 Subject: [PATCH 043/105] formatting (black) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 2807975cd..1e4eb3253 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -859,4 +859,4 @@ def get_columns( enforce_embedded_schema_correctness=False, ) assert result is not None, "execute_command returned None in synchronous mode" - return result \ No newline at end of file + return result From 33821f46f0531fbc2bb08dc28002c33b46e0f485 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:41:54 +0000 Subject: [PATCH 044/105] add metadata command unit tests Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 1 - tests/unit/test_sea_backend.py | 442 ++++++++++++++++++++++++++ 2 files changed, 442 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 2c0105aee..ec91d87da 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -17,7 +17,6 @@ TYPE_CHECKING, ) -from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory from databricks.sql.backend.types import ExecuteResponse, CommandId from databricks.sql.backend.sea.models.base import ResultData from databricks.sql.backend.sea.backend import SeaDatabricksClient diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 2fa362b8e..0b6f10803 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -546,3 +546,445 @@ def 
test_get_execution_result( assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( "test-statement-123" ) + + # Tests for metadata commands + + def test_get_catalogs( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test getting catalogs metadata.""" + # Set up mock for execute_command + mock_result_set = Mock() + + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + def test_get_schemas( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test getting schemas metadata.""" + # Set up mock for execute_command + mock_result_set = Mock() + + # Test case 1: With catalog name only + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW SCHEMAS IN `test_catalog`", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 2: With catalog name and schema pattern + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema%", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema%'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 3: Missing catalog name should raise error + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, + ) + + assert "Catalog name is required" in str(excinfo.value) + + def test_get_tables( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test getting tables metadata.""" + # Set up mock for execute_command + mock_result_set = Mock() + + # Test case 1: With catalog name only + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Mock the get_tables method to avoid import errors + original_get_tables = sea_client.get_tables + try: + # Replace get_tables with a 
simple version that doesn't use ResultSetFilter + def mock_get_tables( + session_id, + max_rows, + max_bytes, + cursor, + catalog_name, + schema_name=None, + table_name=None, + table_types=None, + ): + if catalog_name is None: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + return sea_client.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + sea_client.get_tables = mock_get_tables + + # Call the method + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW TABLES IN CATALOG `test_catalog`", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 2: With catalog and schema name + mock_execute.reset_mock() + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 3: With catalog, schema, and table name + mock_execute.reset_mock() + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table%", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table%'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 4: With wildcard catalog + mock_execute.reset_mock() + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW TABLES IN ALL CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + 
async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 5: Missing catalog name should raise error + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, + ) + + assert "Catalog name is required" in str(excinfo.value) + finally: + # Restore the original method + sea_client.get_tables = original_get_tables + + def test_get_columns( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test getting columns metadata.""" + # Set up mock for execute_command + mock_result_set = Mock() + + # Test case 1: With catalog name only + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW COLUMNS IN CATALOG `test_catalog`", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 2: With catalog and schema name + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 3: With catalog, schema, and table name + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 4: With catalog, schema, table, and column name + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="col%", + ) + + # Verify the result + 
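+            # 'col%' is passed through verbatim as a SQL LIKE pattern, so the
+            # expected operation below matches columns by prefix rather than
+            # by exact name.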
assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'col%'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 5: Missing catalog name should raise error + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, + ) + + assert "Catalog name is required" in str(excinfo.value) From 3e22c6c4f297a3c83dbebba7c57e3bc8c0c5fe9a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 06:34:34 +0000 Subject: [PATCH 045/105] change to valid table name Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/test_sea_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index c715e5984..394c48b24 100644 --- a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -74,7 +74,7 @@ def test_sea_metadata(): f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." ) cursor.columns( - catalog_name=catalog, schema_name="default", table_name="information_schema" + catalog_name=catalog, schema_name="default", table_name="customer" ) logger.info("Successfully fetched columns") From 165c4f35ce69f282b03e6522c6ea72c6d0a8f5fc Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:18:39 +0000 Subject: [PATCH 046/105] remove unnecessary changes covered by #588 Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 11 +- src/databricks/sql/result_set.py | 73 ------- tests/unit/test_sea_result_set.py | 200 ------------------- tests/unit/test_thrift_backend.py | 32 +-- 4 files changed, 7 insertions(+), 309 deletions(-) delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d28a2c6fd..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -15,6 +15,7 @@ CommandId, ExecuteResponse, ) +from databricks.sql.backend.utils import guid_to_hex_id try: @@ -841,8 +842,6 @@ def get_execution_result( status = self.get_query_state(command_id) - status = self.get_query_state(command_id) - execute_response = ExecuteResponse( command_id=command_id, status=status, @@ -895,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return state @staticmethod @@ -1189,11 +1188,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - ( - execute_response, - arrow_schema_bytes, - ) = self._results_message_to_execute_response(resp, final_operation_state) - return
self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 97b10cbbe..cf6940bb2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -438,76 +438,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for the SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index 02421a915..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. 
-""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - return mock_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.command_id == execute_response.command_id - assert result_set.status == CommandState.SUCCEEDED - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set.description == execute_response.description - - def test_close(self, mock_connection, mock_sea_client, execute_response): - """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - 
mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() \ No newline at end of file diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 88adcd3e9..57b5e9b58 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,13 +619,6 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 - - # Create a valid operation status - op_status = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), @@ -927,9 +920,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -957,12 +948,6 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - 
operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) @@ -977,7 +962,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): t_execute_resp, Mock() ) - self.assertEqual(arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -997,12 +982,6 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req tcli_service_instance.GetOperationStatus.return_value = ( ttypes.TGetOperationStatusResp( @@ -1694,9 +1673,7 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() @@ -2256,8 +2233,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=(Mock(), Mock()), + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class, mock_result_set From a6e40d0dce9acd43c29e2de76f7d64ce96f775a8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:25:51 +0000 Subject: [PATCH 047/105] simplify test module Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 41 +++++++++------------ 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index b03f8ff64..3a8b163f5 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,20 +1,18 @@ """ Main script to run all SEA connector tests. -This script imports and runs all the individual test modules and displays +This script runs all the individual test modules and displays a summary of test results with visual indicators. 
""" import os import sys import logging -import importlib.util -from typing import Dict, Callable, List, Tuple +import subprocess +from typing import List, Tuple -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Define test modules and their main test functions TEST_MODULES = [ "test_sea_session", "test_sea_sync_query", @@ -23,29 +21,27 @@ ] -def load_test_function(module_name: str) -> Callable: - """Load a test function from a module.""" +def run_test_module(module_name: str) -> bool: + """Run a test module and return success status.""" module_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" ) - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + # Simply run the module as a script - each module handles its own test execution + result = subprocess.run( + [sys.executable, module_path], capture_output=True, text=True + ) - # Get the main test function (assuming it starts with "test_") - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - # For sync and async query modules, we want the main function that runs both tests - if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": - return getattr(module, name) + # Log the output from the test module + if result.stdout: + for line in result.stdout.strip().split("\n"): + logger.info(line) - # Fallback to the first test function found - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - return getattr(module, name) + if result.stderr: + for line in result.stderr.strip().split("\n"): + logger.error(line) - raise ValueError(f"No test function found in module {module_name}") + return result.returncode == 0 def run_tests() -> List[Tuple[str, bool]]: @@ -54,12 +50,11 @@ def run_tests() -> List[Tuple[str, bool]]: for module_name in TEST_MODULES: try: - test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - success = test_func() + success = run_test_module(module_name) results.append((module_name, success)) status = "✅ PASSED" if success else "❌ FAILED" From 52e3088b31d659064e740388bd2f25df1c3b158f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:26:23 +0000 Subject: [PATCH 048/105] logging -> debug level Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 3a8b163f5..edd171b05 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -10,7 +10,7 @@ import subprocess from typing import List, Tuple -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) TEST_MODULES = [ From 641c09b0d2a5fb5c79b3b696f767f81d0b5283e4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:28:18 +0000 Subject: [PATCH 049/105] change table name in log Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/test_sea_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index 394c48b24..a200d97d3 100644 --- 
a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -71,7 +71,7 @@ def test_sea_metadata(): # Test columns for a specific table # Using a common table that should exist in most environments logger.info( - f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." + f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." ) cursor.columns( catalog_name=catalog, schema_name="default", table_name="customer" From ffded6ee2c50eb2efc1cdd2e580d51e396ce2cdd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:39:37 +0000 Subject: [PATCH 050/105] remove unnecessary changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 168 +++---- examples/experimental/tests/__init__.py | 0 .../tests/test_sea_async_query.py | 191 -------- .../experimental/tests/test_sea_metadata.py | 98 ---- .../experimental/tests/test_sea_session.py | 71 --- .../experimental/tests/test_sea_sync_query.py | 161 ------- tests/unit/test_sea_backend.py | 453 ++++-------------- tests/unit/test_sea_result_set.py | 200 -------- tests/unit/test_thrift_backend.py | 32 +- 9 files changed, 155 insertions(+), 1219 deletions(-) delete mode 100644 examples/experimental/tests/__init__.py delete mode 100644 examples/experimental/tests/test_sea_async_query.py delete mode 100644 examples/experimental/tests/test_sea_metadata.py delete mode 100644 examples/experimental/tests/test_sea_session.py delete mode 100644 examples/experimental/tests/test_sea_sync_query.py delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index b03f8ff64..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,120 +1,66 @@ -""" -Main script to run all SEA connector tests. - -This script imports and runs all the individual test modules and displays -a summary of test results with visual indicators.
-""" import os import sys import logging -import importlib.util -from typing import Dict, Callable, List, Tuple +from databricks.sql.client import Connection -# Configure logging -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -# Define test modules and their main test functions -TEST_MODULES = [ - "test_sea_session", - "test_sea_sync_query", - "test_sea_async_query", - "test_sea_metadata", -] - - -def load_test_function(module_name: str) -> Callable: - """Load a test function from a module.""" - module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" - ) - - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Get the main test function (assuming it starts with "test_") - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - # For sync and async query modules, we want the main function that runs both tests - if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": - return getattr(module, name) - - # Fallback to the first test function found - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - return getattr(module, name) - - raise ValueError(f"No test function found in module {module_name}") - - -def run_tests() -> List[Tuple[str, bool]]: - """Run all tests and return results.""" - results = [] - - for module_name in TEST_MODULES: - try: - test_func = load_test_function(module_name) - logger.info(f"\n{'=' * 50}") - logger.info(f"Running test: {module_name}") - logger.info(f"{'-' * 50}") - - success = test_func() - results.append((module_name, success)) - - status = "✅ PASSED" if success else "❌ FAILED" - logger.info(f"Test {module_name}: {status}") - - except Exception as e: - logger.error(f"Error loading or running test {module_name}: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - results.append((module_name, False)) - - return results - - -def print_summary(results: List[Tuple[str, bool]]) -> None: - """Print a summary of test results.""" - logger.info(f"\n{'=' * 50}") - logger.info("TEST SUMMARY") - logger.info(f"{'-' * 50}") - - passed = sum(1 for _, success in results if success) - total = len(results) - - for module_name, success in results: - status = "✅ PASSED" if success else "❌ FAILED" - logger.info(f"{status} - {module_name}") - - logger.info(f"{'-' * 50}") - logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") - logger.info(f"{'=' * 50}") - - -if __name__ == "__main__": - # Check if required environment variables are set - required_vars = [ - "DATABRICKS_SERVER_HOSTNAME", - "DATABRICKS_HTTP_PATH", - "DATABRICKS_TOKEN", - ] - missing_vars = [var for var in required_vars if not os.environ.get(var)] - - if missing_vars: - logger.error( - f"Missing required environment variables: {', '.join(missing_vars)}" +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + sys.exit(1) + + logger.info(f"Connecting to {server_hostname}") + logger.info(f"HTTP Path: {http_path}") + if catalog: + logger.info(f"Using catalog: {catalog}") + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) - logger.error("Please set these variables before running the tests.") + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") + logger.info(f"backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) sys.exit(1) + + logger.info("SEA session test completed successfully") - # Run all tests - results = run_tests() - - # Print summary - print_summary(results) - - # Exit with appropriate status code - all_passed = all(success for _, success in results) - sys.exit(0 if all_passed else 1) +if __name__ == "__main__": + test_sea_session() diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py deleted file mode 100644 index a776377c3..000000000 --- a/examples/experimental/tests/test_sea_async_query.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Test for SEA asynchronous query execution functionality. -""" -import os -import sys -import logging -import time -from databricks.sql.client import Connection -from databricks.sql.backend.types import CommandState - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_async_query_with_cloud_fetch(): - """ - Test executing a query asynchronously using the SEA backend with cloud fetch enabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - # Create connection with cloud fetch enabled - logger.info( - "Creating connection for asynchronous query execution with cloud fetch enabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query asynchronously - cursor = connection.cursor() - logger.info( - "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" - ) - cursor.execute_async("SELECT 1 as test_value") - logger.info( - "Asynchronous query submitted successfully with cloud fetch enabled" - ) - - # Check query state - logger.info("Checking query state...") - while cursor.is_query_pending(): - logger.info("Query is still pending, waiting...") - time.sleep(1) - - logger.info("Query is no longer pending, getting results...") - cursor.get_async_execution_result() - logger.info( - "Successfully retrieved asynchronous query results with cloud fetch enabled" - ) - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_async_query_without_cloud_fetch(): - """ - Test executing a query asynchronously using the SEA backend with cloud fetch disabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - # Create connection with cloud fetch disabled - logger.info( - "Creating connection for asynchronous query execution with cloud fetch disabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, - enable_query_result_lz4_compression=False, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query asynchronously - cursor = connection.cursor() - logger.info( - "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" - ) - cursor.execute_async("SELECT 1 as test_value") - logger.info( - "Asynchronous query submitted successfully with cloud fetch disabled" - ) - - # Check query state - logger.info("Checking query state...") - while cursor.is_query_pending(): - logger.info("Query is still pending, waiting...") - time.sleep(1) - - logger.info("Query is no longer pending, getting results...") - cursor.get_async_execution_result() - logger.info( - "Successfully retrieved asynchronous query results with cloud fetch disabled" - ) - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_async_query_exec(): - """ - Run both asynchronous query tests and return overall success. - """ - with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() - logger.info( - f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" - ) - - without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() - logger.info( - f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" - ) - - return with_cloud_fetch_success and without_cloud_fetch_success - - -if __name__ == "__main__": - success = test_sea_async_query_exec() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py deleted file mode 100644 index c715e5984..000000000 --- a/examples/experimental/tests/test_sea_metadata.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -Test for SEA metadata functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_metadata(): - """ - Test metadata operations using the SEA backend. - - This function connects to a Databricks SQL endpoint using the SEA backend, - and executes metadata operations like catalogs(), schemas(), tables(), and columns(). - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - if not catalog: - logger.error( - "DATABRICKS_CATALOG environment variable is required for metadata tests." 
- ) - return False - - try: - # Create connection - logger.info("Creating connection for metadata operations") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Test catalogs - cursor = connection.cursor() - logger.info("Fetching catalogs...") - cursor.catalogs() - logger.info("Successfully fetched catalogs") - - # Test schemas - logger.info(f"Fetching schemas for catalog '{catalog}'...") - cursor.schemas(catalog_name=catalog) - logger.info("Successfully fetched schemas") - - # Test tables - logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") - cursor.tables(catalog_name=catalog, schema_name="default") - logger.info("Successfully fetched tables") - - # Test columns for a specific table - # Using a common table that should exist in most environments - logger.info( - f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." - ) - cursor.columns( - catalog_name=catalog, schema_name="default", table_name="information_schema" - ) - logger.info("Successfully fetched columns") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error(f"Error during SEA metadata test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return False - - -if __name__ == "__main__": - success = test_sea_metadata() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py deleted file mode 100644 index 516c1bbb8..000000000 --- a/examples/experimental/tests/test_sea_session.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Test for SEA session management functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. - - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. - - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"Backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return False - - -if __name__ == "__main__": - success = test_sea_session() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py deleted file mode 100644 index 07be8aafc..000000000 --- a/examples/experimental/tests/test_sea_sync_query.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -Test for SEA synchronous query execution functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_sync_query_with_cloud_fetch(): - """ - Test executing a query synchronously using the SEA backend with cloud fetch enabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch enabled - logger.info( - "Creating connection for synchronous query execution with cloud fetch enabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query - cursor = connection.cursor() - logger.info( - "Executing synchronous query with cloud fetch: SELECT 1 as test_value" - ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch enabled") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_sync_query_without_cloud_fetch(): - """ - Test executing a query synchronously using the SEA backend with cloud fetch disabled. 
- - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch disabled - logger.info( - "Creating connection for synchronous query execution with cloud fetch disabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, - enable_query_result_lz4_compression=False, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query - cursor = connection.cursor() - logger.info( - "Executing synchronous query without cloud fetch: SELECT 1 as test_value" - ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch disabled") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_sync_query_exec(): - """ - Run both synchronous query tests and return overall success. - """ - with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() - logger.info( - f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" - ) - - without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() - logger.info( - f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" - ) - - return with_cloud_fetch_success and without_cloud_fetch_success - - -if __name__ == "__main__": - success = test_sea_sync_query_exec() - sys.exit(0 if success else 1) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 2fa362b8e..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. 
-""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,348 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - 
mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = 
SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == 
sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - 
"statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index b691872af..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - return mock_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.command_id == execute_response.command_id - assert result_set.status == CommandState.SUCCEEDED - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set.description == execute_response.description - - def test_close(self, mock_connection, mock_sea_client, execute_response): - """Test closing a result set.""" - 
result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that _fill_results_buffer raises 
NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 88adcd3e9..57b5e9b58 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,13 +619,6 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 - - # Create a valid operation status - op_status = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), @@ -927,9 +920,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -957,12 +948,6 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) @@ -977,7 +962,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): t_execute_resp, Mock() ) - self.assertEqual(arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -997,12 +982,6 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req tcli_service_instance.GetOperationStatus.return_value = ( ttypes.TGetOperationStatusResp( @@ -1694,9 +1673,7 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() @@ -2256,8 +2233,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - 
return_value=(Mock(), Mock()),
+        "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response"
     )
     def test_execute_command_sets_complex_type_fields_correctly(
         self, mock_handle_execute_response, tcli_service_class, mock_result_set

From 227f6b36bd65cc8a7c903316334a18a8a8e249b1 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 11 Jun 2025 13:41:29 +0000
Subject: [PATCH 051/105] remove unnecessary backend changes

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/sea/backend.py    | 481 ++-----------------
 src/databricks/sql/backend/thrift_backend.py |  11 +-
 src/databricks/sql/result_set.py             |  73 ---
 3 files changed, 42 insertions(+), 523 deletions(-)

diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index b1ad7cf76..97d25a058 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -1,44 +1,23 @@
 import logging
-import time
 import re
-from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
-
-from databricks.sql.backend.sea.utils.constants import (
-    ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
-)
+from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set

 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
-    from databricks.sql.result_set import ResultSet
-    from databricks.sql.backend.sea.models.responses import GetChunksResponse

 from databricks.sql.backend.databricks_client import DatabricksClient
-from databricks.sql.backend.types import (
-    SessionId,
-    CommandId,
-    CommandState,
-    BackendType,
-    ExecuteResponse,
-)
-from databricks.sql.exc import Error, NotSupportedError, ServerOperationError
+from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
+from databricks.sql.exc import ServerOperationError
 from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
-from databricks.sql.types import SSLOptions
-
-from databricks.sql.backend.sea.models.base import (
-    ResultData,
-    ExternalLink,
-    ResultManifest,
+from databricks.sql.backend.sea.utils.constants import (
+    ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
 )
+from databricks.sql.thrift_api.TCLIService import ttypes
+from databricks.sql.types import SSLOptions

 from databricks.sql.backend.sea.models import (
-    ExecuteStatementRequest,
-    GetStatementRequest,
-    CancelStatementRequest,
-    CloseStatementRequest,
     CreateSessionRequest,
     DeleteSessionRequest,
-    StatementParameter,
-    ExecuteStatementResponse,
-    GetStatementResponse,
     CreateSessionResponse,
 )

@@ -76,9 +55,6 @@ def _filter_session_configuration(
 class SeaDatabricksClient(DatabricksClient):
     """
     Statement Execution API (SEA) implementation of the DatabricksClient interface.
-
-    This implementation provides session management functionality for SEA,
-    while other operations raise NotImplementedError.
""" # SEA API paths @@ -88,8 +64,6 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH = BASE_PATH + "statements" STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" - CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks" - CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" def __init__( self, @@ -122,7 +96,6 @@ def __init__( ) self._max_download_threads = kwargs.get("max_download_threads", 10) - self.ssl_options = ssl_options # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -279,19 +252,6 @@ def get_default_session_configuration_value(name: str) -> Optional[str]: """ return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) - @staticmethod - def is_session_configuration_parameter_supported(name: str) -> bool: - """ - Check if a session configuration parameter is supported. - - Args: - name: The name of the session configuration parameter - - Returns: - True if the parameter is supported, False otherwise - """ - return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP - @staticmethod def get_allowed_session_configurations() -> List[str]: """ @@ -302,182 +262,8 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - def _get_schema_bytes(self, sea_response) -> Optional[bytes]: - """ - Extract schema bytes from the SEA response. - - For ARROW format, we need to get the schema bytes from the first chunk. - If the first chunk is not available, we need to get it from the server. - - Args: - sea_response: The response from the SEA API - - Returns: - bytes: The schema bytes or None if not available - """ - import requests - import lz4.frame - - # Check if we have the first chunk in the response - result_data = sea_response.get("result", {}) - external_links = result_data.get("external_links", []) - - if not external_links: - return None - - # Find the first chunk (chunk_index = 0) - first_chunk = None - for link in external_links: - if link.get("chunk_index") == 0: - first_chunk = link - break - - if not first_chunk: - # Try to fetch the first chunk from the server - statement_id = sea_response.get("statement_id") - if not statement_id: - return None - - chunks_response = self.get_chunk_links(statement_id, 0) - if not chunks_response.external_links: - return None - - first_chunk = chunks_response.external_links[0].__dict__ - - # Download the first chunk to get the schema bytes - external_link = first_chunk.get("external_link") - http_headers = first_chunk.get("http_headers", {}) - - if not external_link: - return None - - # Use requests to download the first chunk - http_response = requests.get( - external_link, - headers=http_headers, - verify=self.ssl_options.tls_verify, - ) - - if http_response.status_code != 200: - raise Error(f"Failed to download schema bytes: {http_response.text}") - - # Extract schema bytes from the Arrow file - # The schema is at the beginning of the file - data = http_response.content - if sea_response.get("manifest", {}).get("result_compression") == "LZ4_FRAME": - data = lz4.frame.decompress(data) - - # Return the schema bytes - return data - - def _results_message_to_execute_response(self, sea_response, command_id): - """ - Convert a SEA response to an ExecuteResponse and extract result data. 
- - Args: - sea_response: The response from the SEA API - command_id: The command ID - - Returns: - tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response, - result data object, and manifest object - """ - # Extract status - status_data = sea_response.get("status", {}) - state = CommandState.from_sea_state(status_data.get("state", "")) - - # Extract description from manifest - description = None - manifest_data = sea_response.get("manifest", {}) - schema_data = manifest_data.get("schema", {}) - columns_data = schema_data.get("columns", []) - - if columns_data: - columns = [] - for col_data in columns_data: - if not isinstance(col_data, dict): - continue - - # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) - columns.append( - ( - col_data.get("name", ""), # name - col_data.get("type_name", ""), # type_code - None, # display_size (not provided by SEA) - None, # internal_size (not provided by SEA) - col_data.get("precision"), # precision - col_data.get("scale"), # scale - col_data.get("nullable", True), # null_ok - ) - ) - description = columns if columns else None - - # Extract schema bytes for Arrow format - schema_bytes = None - format = manifest_data.get("format") - if format == "ARROW_STREAM": - # For ARROW format, we need to get the schema bytes - schema_bytes = self._get_schema_bytes(sea_response) - - # Check for compression - lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" - - # Initialize result_data_obj and manifest_obj - result_data_obj = None - manifest_obj = None - - result_data = sea_response.get("result", {}) - if result_data: - # Convert external links - external_links = None - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers", {}), - ) - ) - - # Create the result data object - result_data_obj = ResultData( - data=result_data.get("data_array"), external_links=external_links - ) - - # Create the manifest object - manifest_obj = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - execute_response = ExecuteResponse( - command_id=command_id, - status=state, - description=description, - has_been_closed_server_side=False, - lz4_compressed=lz4_compressed, - is_staging_operation=False, - arrow_schema_bytes=schema_bytes, - result_format=manifest_data.get("format"), - ) - - return execute_response, result_data_obj, manifest_obj + # == Not Implemented Operations == + # These methods will be implemented in future iterations def execute_command( self, @@ -488,230 +274,41 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List, 
+ parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: - """ - Execute a SQL command using the SEA backend. - - Args: - operation: SQL command to execute - session_id: Session identifier - max_rows: Maximum number of rows to fetch - max_bytes: Maximum number of bytes to fetch - lz4_compression: Whether to use LZ4 compression - cursor: Cursor executing the command - use_cloud_fetch: Whether to use cloud fetch - parameters: SQL parameters - async_op: Whether to execute asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - ResultSet: A SeaResultSet instance for the executed command - """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") - - sea_session_id = session_id.to_sea_session_id() - - # Convert parameters to StatementParameter objects - sea_parameters = [] - if parameters: - for param in parameters: - sea_parameters.append( - StatementParameter( - name=param.name, - value=param.value, - type=param.type if hasattr(param, "type") else None, - ) - ) - - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else None - - request = ExecuteStatementRequest( - warehouse_id=self.warehouse_id, - session_id=sea_session_id, - statement=operation, - disposition=disposition, - format=format, - wait_timeout="0s" if async_op else "10s", - on_wait_timeout="CONTINUE", - row_limit=max_rows, - parameters=sea_parameters if sea_parameters else None, - result_compression=result_compression, + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" ) - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() - ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, - ) - - command_id = CommandId.from_sea_statement_id(statement_id) - - # Store the command ID in the cursor - cursor.active_command_id = command_id - - # If async operation, return None and let the client poll for results - if async_op: - return None - - # For synchronous operation, wait for the statement to complete - # Poll until the statement is done - status = response.status - state = status.state - - # Keep polling until we reach a terminal state - while state in [CommandState.PENDING, CommandState.RUNNING]: - time.sleep(0.5) # add a small delay to avoid excessive API calls - state = self.get_query_state(command_id) - - if state != CommandState.SUCCEEDED: - raise ServerOperationError( - f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", - { - "operation-id": command_id.to_sea_statement_id(), - "diagnostic-info": None, - }, - ) - - return self.get_execution_result(command_id, cursor) - def cancel_command(self, command_id: CommandId) -> None: - """ - Cancel a running command. 
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - # Convert the response to an ExecuteResponse and extract result data - ( - execute_response, - result_data, - manifest, - ) = self._results_message_to_execute_response(response_data, command_id) - - return SeaResultSet( - connection=cursor.connection, - execute_response=execute_response, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, - result_data=result_data, - manifest=manifest, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -722,9 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - raise NotImplementedError("get_catalogs is not implemented for SEA backend") + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -734,9 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - raise NotImplementedError("get_schemas is not implemented for SEA backend") + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -748,9 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_tables is not implemented for SEA backend") + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -762,6 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_columns is not implemented for SEA backend") + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d28a2c6fd..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -15,6 +15,7 @@ CommandId, ExecuteResponse, ) +from databricks.sql.backend.utils import guid_to_hex_id try: @@ -841,8 +842,6 @@ def get_execution_result( status = 
self.get_query_state(command_id) - status = self.get_query_state(command_id) - execute_response = ExecuteResponse( command_id=command_id, status=status, @@ -895,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return state @staticmethod @@ -1189,11 +1188,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - ( - execute_response, - arrow_schema_bytes, - ) = self._results_message_to_execute_response(resp, final_operation_state) - return execute_response, arrow_schema_bytes + return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 97b10cbbe..cf6940bb2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -438,76 +438,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for the SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. 
- """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") From 68657a3ba20080dde478b3e9d4b0940bdf4ca299 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 14:52:28 +0000 Subject: [PATCH 052/105] remove un-needed GetChunksResponse Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 1 - .../sql/backend/sea/models/responses.py | 35 ------------------- 2 files changed, 36 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index b1ad7cf76..6d627162d 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor from databricks.sql.result_set import ResultSet - from databricks.sql.backend.sea.models.responses import GetChunksResponse from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import ( diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d684a9c67..1f73df409 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -196,38 +196,3 @@ class CreateSessionResponse: def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse": """Create a CreateSessionResponse from a dictionary.""" return cls(session_id=data.get("session_id", "")) - - -@dataclass -class GetChunksResponse: - """Response from getting chunks for a statement.""" - - statement_id: str - external_links: List[ExternalLink] - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse": - """Create a GetChunksResponse from a dictionary.""" - external_links = [] - if "external_links" in data: - for link_data in data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - return cls( - statement_id=data.get("statement_id", ""), - external_links=external_links, - ) From 3940eecd0671deee86ef9b81a1853fcedaf31bb1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 14:53:15 +0000 Subject: [PATCH 053/105] remove un-needed GetChunksResponse only relevant in Fetch phase Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/responses.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d71262d1d..51f0d4452 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -196,38 +196,3 @@ class CreateSessionResponse: def from_dict(cls, data: Dict[str, Any]) -> 
"CreateSessionResponse": """Create a CreateSessionResponse from a dictionary.""" return cls(session_id=data.get("session_id", "")) - - -@dataclass -class GetChunksResponse: - """Response from getting chunks for a statement.""" - - statement_id: str - external_links: List[ExternalLink] - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse": - """Create a GetChunksResponse from a dictionary.""" - external_links = [] - if "external_links" in data: - for link_data in data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - return cls( - statement_id=data.get("statement_id", ""), - external_links=external_links, - ) From 37813ba6d1fe06d7f9f10d510a059b88dc552496 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 15:00:35 +0000 Subject: [PATCH 054/105] reduce code duplication in response parsing Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/responses.py | 219 +++++++----------- 1 file changed, 78 insertions(+), 141 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1f73df409..c16f19da3 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -4,8 +4,8 @@ These models define the structures used in SEA API responses. 
""" -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field +from typing import Dict, Any +from dataclasses import dataclass from databricks.sql.backend.types import CommandState from databricks.sql.backend.sea.models.base import ( @@ -14,91 +14,92 @@ ResultData, ServiceError, ExternalLink, - ColumnInfo, ) +def _parse_status(data: Dict[str, Any]) -> StatementStatus: + """Parse status from response data.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + return StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + +def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: + """Parse manifest from response data.""" + + manifest_data = data.get("manifest", {}) + return ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + +def _parse_result(data: Dict[str, Any]) -> ResultData: + """Parse result data from response data.""" + result_data = data.get("result", {}) + external_links = None + + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get("next_chunk_internal_link"), + http_headers=link_data.get("http_headers"), + ) + ) + + return ResultData( + data=result_data.get("data_array"), + external_links=external_links, + ) + + @dataclass class ExecuteStatementResponse: """Response from executing a SQL statement.""" statement_id: str status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None + manifest: ResultManifest + result: ResultData @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - # Parse manifest - manifest = None - if "manifest" in data: - manifest_data = data["manifest"] - manifest = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - 
total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - # Parse result data - result = None - if "result" in data: - result_data = data["result"] - external_links = None - - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - result = ResultData( - data=result_data.get("data_array"), - external_links=external_links, - ) - return cls( statement_id=data.get("statement_id", ""), - status=status, - manifest=manifest, - result=result, + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) @@ -108,81 +109,17 @@ class GetStatementResponse: statement_id: str status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None + manifest: ResultManifest + result: ResultData @classmethod def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - # Parse manifest - manifest = None - if "manifest" in data: - manifest_data = data["manifest"] - manifest = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - # Parse result data - result = None - if "result" in data: - result_data = data["result"] - external_links = None - - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - result = ResultData( - 
data=result_data.get("data_array"), - external_links=external_links, - ) - return cls( statement_id=data.get("statement_id", ""), - status=status, - manifest=manifest, - result=result, + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) From 267c9f44e55778af748749336c26bb06ce0ab33c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 15:01:29 +0000 Subject: [PATCH 055/105] reduce code duplication Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/responses.py | 221 +++++++----------- 1 file changed, 79 insertions(+), 142 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 51f0d4452..c16f19da3 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -4,8 +4,8 @@ These models define the structures used in SEA API responses. """ -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field +from typing import Dict, Any +from dataclasses import dataclass from databricks.sql.backend.types import CommandState from databricks.sql.backend.sea.models.base import ( @@ -14,91 +14,92 @@ ResultData, ServiceError, ExternalLink, - ColumnInfo, ) +def _parse_status(data: Dict[str, Any]) -> StatementStatus: + """Parse status from response data.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + return StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + +def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: + """Parse manifest from response data.""" + + manifest_data = data.get("manifest", {}) + return ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + +def _parse_result(data: Dict[str, Any]) -> ResultData: + """Parse result data from response data.""" + result_data = data.get("result", {}) + external_links = None + + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get("next_chunk_internal_link"), + http_headers=link_data.get("http_headers"), + ) + ) + + return ResultData( + data=result_data.get("data_array"), + external_links=external_links, + ) + + @dataclass class ExecuteStatementResponse: """Response from executing a SQL statement.""" statement_id: str status: StatementStatus - manifest: Optional[ResultManifest] 
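
With parsing consolidated into the module-level helpers, each from_dict reduces to three delegating calls. A hypothetical round-trip through the new code path (payload values invented; assumes CommandState.from_sea_state maps the "SUCCEEDED" state):

    from databricks.sql.backend.types import CommandState
    from databricks.sql.backend.sea.models.responses import ExecuteStatementResponse

    payload = {
        "statement_id": "stmt-123",  # invented identifier
        "status": {"state": "SUCCEEDED"},
        "manifest": {"format": "JSON_ARRAY", "total_row_count": 1,
                     "total_byte_count": 8, "total_chunk_count": 1},
        "result": {"data_array": [["1"]]},
    }
    resp = ExecuteStatementResponse.from_dict(payload)
    assert resp.status.state == CommandState.SUCCEEDED
    assert resp.result.data == [["1"]]
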
= None - result: Optional[ResultData] = None + manifest: ResultManifest + result: ResultData @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - # Parse manifest - manifest = None - if "manifest" in data: - manifest_data = data["manifest"] - manifest = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - # Parse result data - result = None - if "result" in data: - result_data = data["result"] - external_links = None - - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - result = ResultData( - data=result_data.get("data_array"), - external_links=external_links, - ) - return cls( statement_id=data.get("statement_id", ""), - status=status, - manifest=manifest, - result=result, + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) @@ -108,87 +109,23 @@ class GetStatementResponse: statement_id: str status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None + manifest: ResultManifest + result: ResultData @classmethod def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - # Parse manifest - manifest = None - if "manifest" in data: - manifest_data = data["manifest"] - manifest = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - 
total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - # Parse result data - result = None - if "result" in data: - result_data = data["result"] - external_links = None - - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - result = ResultData( - data=result_data.get("data_array"), - external_links=external_links, - ) - return cls( statement_id=data.get("statement_id", ""), - status=status, - manifest=manifest, - result=result, + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) @dataclass class CreateSessionResponse: - """Representation of the response from creating a new session.""" + """Response from creating a new session.""" session_id: str From 296711946a5dd735a655961984641ed2a19d0f2a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 15:03:07 +0000 Subject: [PATCH 056/105] more clear docstrings Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/models/requests.py | 10 +++++----- src/databricks/sql/backend/sea/models/responses.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index d9483e51a..4c5071dba 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -10,7 +10,7 @@ @dataclass class StatementParameter: - """Parameter for a SQL statement.""" + """Representation of a parameter for a SQL statement.""" name: str value: Optional[str] = None @@ -19,7 +19,7 @@ class StatementParameter: @dataclass class ExecuteStatementRequest: - """Request to execute a SQL statement.""" + """Representation of a request to execute a SQL statement.""" session_id: str statement: str @@ -65,7 +65,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class GetStatementRequest: - """Request to get information about a statement.""" + """Representation of a request to get information about a statement.""" statement_id: str @@ -76,7 +76,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CancelStatementRequest: - """Request to cancel a statement.""" + """Representation of a request to cancel a statement.""" statement_id: str @@ -87,7 +87,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CloseStatementRequest: - """Request to close a statement.""" + """Representation of a request to close a statement.""" statement_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index c16f19da3..a8cf0c998 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -85,7 +85,7 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: @dataclass class ExecuteStatementResponse: - """Response 
from executing a SQL statement.""" + """Representation of the response from executing a SQL statement.""" statement_id: str status: StatementStatus @@ -105,7 +105,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": @dataclass class GetStatementResponse: - """Response from getting information about a statement.""" + """Representation of the response from getting information about a statement.""" statement_id: str status: StatementStatus @@ -125,7 +125,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Response from creating a new session.""" + """Representation of the response from creating a new session.""" session_id: str From 47fd60d2b20fcaf1f39300a88224899edb2c0a58 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 15:25:24 +0000 Subject: [PATCH 057/105] introduce strongly typed ChunkInfo Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/models/base.py | 12 +++++++++++- .../sql/backend/sea/models/responses.py | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index 6175b4ca0..f63edba72 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -42,6 +42,16 @@ class ExternalLink: http_headers: Optional[Dict[str, str]] = None +@dataclass +class ChunkInfo: + """Information about a chunk in the result set.""" + + chunk_index: int + byte_count: int + row_offset: int + row_count: int + + @dataclass class ResultData: """Result data from a statement execution.""" @@ -73,5 +83,5 @@ class ResultManifest: total_byte_count: int total_chunk_count: int truncated: bool = False - chunks: Optional[List[Dict[str, Any]]] = None + chunks: Optional[List[ChunkInfo]] = None result_compression: Optional[str] = None diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index a8cf0c998..7388af193 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -14,6 +14,7 @@ ResultData, ServiceError, ExternalLink, + ChunkInfo, ) @@ -43,6 +44,18 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: """Parse manifest from response data.""" manifest_data = data.get("manifest", {}) + chunks = None + if "chunks" in manifest_data: + chunks = [ + ChunkInfo( + chunk_index=chunk.get("chunk_index", 0), + byte_count=chunk.get("byte_count", 0), + row_offset=chunk.get("row_offset", 0), + row_count=chunk.get("row_count", 0), + ) + for chunk in manifest_data.get("chunks", []) + ] + return ResultManifest( format=manifest_data.get("format", ""), schema=manifest_data.get("schema", {}), @@ -50,8 +63,9 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: total_byte_count=manifest_data.get("total_byte_count", 0), total_chunk_count=manifest_data.get("total_chunk_count", 0), truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), + chunks=chunks, result_compression=manifest_data.get("result_compression"), + is_volume_operation=manifest_data.get("is_volume_operation"), ) From 982fdf2df8480d6ddd8c93b5f8839e4cf5ccce2e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 03:08:31 +0000 Subject: [PATCH 058/105] remove is_volume_operation from response Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/models/responses.py 
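
Typing the manifest's chunk list means downstream code reads attributes instead of indexing into untyped dictionaries; _parse_manifest now builds ChunkInfo values from the raw manifest entries. A one-line illustration (values invented):

    from databricks.sql.backend.sea.models.base import ChunkInfo

    chunk = ChunkInfo(chunk_index=0, byte_count=1024, row_offset=0, row_count=100)
    assert chunk.row_count == 100  # attribute access; a misspelled name now fails loudly
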
| 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 7388af193..42dcd356a 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -65,7 +65,6 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: truncated=manifest_data.get("truncated", False), chunks=chunks, result_compression=manifest_data.get("result_compression"), - is_volume_operation=manifest_data.get("is_volume_operation"), ) From 9e14d48fdb03500ad13e098cd963d7a04dadd9a0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:06:47 +0000 Subject: [PATCH 059/105] add is_volume_op and more ResultData fields Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/models/base.py | 8 ++++++++ src/databricks/sql/backend/sea/models/responses.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index f63edba72..b12c26eb0 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -58,6 +58,13 @@ class ResultData: data: Optional[List[List[Any]]] = None external_links: Optional[List[ExternalLink]] = None + byte_count: Optional[int] = None + chunk_index: Optional[int] = None + next_chunk_index: Optional[int] = None + next_chunk_internal_link: Optional[str] = None + row_count: Optional[int] = None + row_offset: Optional[int] = None + attachment: Optional[bytes] = None @dataclass @@ -85,3 +92,4 @@ class ResultManifest: truncated: bool = False chunks: Optional[List[ChunkInfo]] = None result_compression: Optional[str] = None + is_volume_operation: Optional[bool] = None diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 42dcd356a..0baf27ab2 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -65,6 +65,7 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: truncated=manifest_data.get("truncated", False), chunks=chunks, result_compression=manifest_data.get("result_compression"), + is_volume_operation=manifest_data.get("is_volume_operation"), ) @@ -93,6 +94,13 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: return ResultData( data=result_data.get("data_array"), external_links=external_links, + byte_count=result_data.get("byte_count"), + chunk_index=result_data.get("chunk_index"), + next_chunk_index=result_data.get("next_chunk_index"), + next_chunk_internal_link=result_data.get("next_chunk_internal_link"), + row_count=result_data.get("row_count"), + row_offset=result_data.get("row_offset"), + attachment=result_data.get("attachment"), ) From 05ee4e78fe72c200e90842d5d916546b08a1a51c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:11:25 +0000 Subject: [PATCH 060/105] add test scripts Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/__init__.py | 0 .../tests/test_sea_async_query.py | 191 ++++++++++++++++++ .../experimental/tests/test_sea_metadata.py | 98 +++++++++ .../experimental/tests/test_sea_session.py | 71 +++++++ .../experimental/tests/test_sea_sync_query.py | 161 +++++++++++++++ 5 files changed, 521 insertions(+) create mode 100644 examples/experimental/tests/__init__.py create mode 100644 examples/experimental/tests/test_sea_async_query.py create mode 100644 
examples/experimental/tests/test_sea_metadata.py create mode 100644 examples/experimental/tests/test_sea_session.py create mode 100644 examples/experimental/tests/test_sea_sync_query.py diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py new file mode 100644 index 000000000..a776377c3 --- /dev/null +++ b/examples/experimental/tests/test_sea_async_query.py @@ -0,0 +1,191 @@ +""" +Test for SEA asynchronous query execution functionality. +""" +import os +import sys +import logging +import time +from databricks.sql.client import Connection +from databricks.sql.backend.types import CommandState + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_async_query_with_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info( + "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + ) + cursor.execute_async("SELECT 1 as test_value") + logger.info( + "Asynchronous query submitted successfully with cloud fetch enabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch enabled" + ) + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_without_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. 
+ """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch disabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info( + "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + ) + cursor.execute_async("SELECT 1 as test_value") + logger.info( + "Asynchronous query submitted successfully with cloud fetch disabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch disabled" + ) + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_exec(): + """ + Run both asynchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() + logger.info( + f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() + logger.info( + f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_async_query_exec() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py new file mode 100644 index 000000000..a200d97d3 --- /dev/null +++ b/examples/experimental/tests/test_sea_metadata.py @@ -0,0 +1,98 @@ +""" +Test for SEA metadata functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_metadata(): + """ + Test metadata operations using the SEA backend. + + This function connects to a Databricks SQL endpoint using the SEA backend, + and executes metadata operations like catalogs(), schemas(), tables(), and columns(). 
+ """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + if not catalog: + logger.error( + "DATABRICKS_CATALOG environment variable is required for metadata tests." + ) + return False + + try: + # Create connection + logger.info("Creating connection for metadata operations") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Test catalogs + cursor = connection.cursor() + logger.info("Fetching catalogs...") + cursor.catalogs() + logger.info("Successfully fetched catalogs") + + # Test schemas + logger.info(f"Fetching schemas for catalog '{catalog}'...") + cursor.schemas(catalog_name=catalog) + logger.info("Successfully fetched schemas") + + # Test tables + logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") + cursor.tables(catalog_name=catalog, schema_name="default") + logger.info("Successfully fetched tables") + + # Test columns for a specific table + # Using a common table that should exist in most environments + logger.info( + f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." + ) + cursor.columns( + catalog_name=catalog, schema_name="default", table_name="customer" + ) + logger.info("Successfully fetched columns") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA metadata test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_metadata() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py new file mode 100644 index 000000000..516c1bbb8 --- /dev/null +++ b/examples/experimental/tests/test_sea_session.py @@ -0,0 +1,71 @@ +""" +Test for SEA session management functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"Backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_session() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py new file mode 100644 index 000000000..07be8aafc --- /dev/null +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -0,0 +1,161 @@ +""" +Test for SEA synchronous query execution functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_sync_query_with_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for synchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info( + "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + ) + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_without_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info( + "Creating connection for synchronous query execution with cloud fetch disabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info( + "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + ) + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_exec(): + """ + Run both synchronous query tests and return overall success. 
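
All four scripts repeat the same environment check before connecting. A hypothetical helper that would centralize that preamble (not part of the patch; sketch only):

    import os

    REQUIRED = ("DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN")

    def missing_env(names=REQUIRED):
        """Return the names of any unset required environment variables."""
        return [n for n in names if not os.environ.get(n)]
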
+ """ + with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() + logger.info( + f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() + logger.info( + f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_sync_query_exec() + sys.exit(0 if success else 1) From 2952d8dc2de6adf25ac1c9dd358fc7f5bfc6f495 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:15:01 +0000 Subject: [PATCH 061/105] Revert "Merge branch 'sea-migration' into exec-models-sea" This reverts commit 8bd12d829ea13abf8fc1507fff8cb21751001c67, reversing changes made to 030edf8df3db487b7af8d910ee51240d1339229e. --- .../sql/backend/sea/models/requests.py | 4 +- .../sql/backend/sea/models/responses.py | 2 +- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 130 ++++++++---------- src/databricks/sql/backend/types.py | 8 +- src/databricks/sql/result_set.py | 86 +++++------- src/databricks/sql/utils.py | 6 +- tests/unit/test_client.py | 11 +- tests/unit/test_fetches.py | 39 +++--- tests/unit/test_fetches_bench.py | 2 +- tests/unit/test_thrift_backend.py | 106 +++++--------- 11 files changed, 159 insertions(+), 237 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 4c5071dba..8524275d4 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CreateSessionRequest: - """Representation of a request to create a new session.""" + """Request to create a new session.""" warehouse_id: str session_confs: Optional[Dict[str, str]] = None @@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DeleteSessionRequest: - """Representation of a request to delete a session.""" + """Request to delete a session.""" warehouse_id: str session_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 0baf27ab2..4dcd4af02 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -146,7 +146,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Representation of the response from creating a new session.""" + """Response from creating a new session.""" session_id: str diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index fe292919c..f0b931ee4 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, List, Tuple +from typing import Callable, Dict, Any, Optional, Union, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e824de1c2..48e9a115f 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,21 +3,24 
@@ import logging import math import time +import uuid import threading from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor +from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( CommandState, SessionId, CommandId, + BackendType, + guid_to_hex_id, ExecuteResponse, ) from databricks.sql.backend.utils import guid_to_hex_id - try: import pyarrow except ImportError: @@ -757,13 +760,11 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - - is_direct_results = ( + has_more_rows = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) - description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -779,25 +780,43 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + if direct_results and direct_results.resultSet: + assert direct_results.resultSet.results.startRowOffset == 0 + assert direct_results.resultSetMetadata + + arrow_queue_opt = ResultSetQueueFactory.build_queue( + row_set_type=t_result_set_metadata_resp.resultFormat, + t_row_set=direct_results.resultSet.results, + arrow_schema_bytes=schema_bytes, + max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, + ssl_options=self._ssl_options, + ) + else: + arrow_queue_opt = None + command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) if status is None: raise ValueError(f"Unknown command state: {operation_state}") - execute_response = ExecuteResponse( - command_id=command_id, - status=status, - description=description, - has_been_closed_server_side=has_been_closed_server_side, - lz4_compressed=lz4_compressed, - is_staging_operation=t_result_set_metadata_resp.isStagingOperation, - arrow_schema_bytes=schema_bytes, - result_format=t_result_set_metadata_resp.resultFormat, + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) - return execute_response, is_direct_results - def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -822,6 +841,9 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -836,21 +858,25 @@ def get_execution_result( else: schema_bytes = None - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - is_direct_results = resp.hasMoreRows - - status = self.get_query_state(command_id) + queue = ResultSetQueueFactory.build_queue( + row_set_type=resp.resultSetMetadata.resultFormat, + t_row_set=resp.results, + arrow_schema_bytes=schema_bytes, + 
max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, + ssl_options=self._ssl_options, + ) execute_response = ExecuteResponse( command_id=command_id, status=status, description=description, + has_more_rows=has_more_rows, + results_queue=queue, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - arrow_schema_bytes=schema_bytes, - result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -860,10 +886,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=resp.results, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -976,14 +999,10 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -991,10 +1010,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1016,14 +1032,10 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1031,10 +1043,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1060,14 +1069,10 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1075,10 +1080,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1108,14 +1110,10 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, is_direct_results = 
self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1123,10 +1121,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1156,14 +1151,10 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1171,10 +1162,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 93bd7d525..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,9 +423,11 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[List[Tuple]] = None + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False - arrow_schema_bytes: Optional[bytes] = None - result_format: Optional[Any] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index cf6940bb2..e177d495f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,7 +41,7 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - is_direct_results: bool = False, + has_more_rows: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, @@ -49,18 +49,18 @@ def __init__( """ A ResultSet manages the results of a single command. 
- Parameters: - :param connection: The parent connection - :param backend: The backend client - :param arraysize: The max number of rows to fetch at a time (PEP-249) - :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - :param command_id: The command ID - :param status: The command status - :param has_been_closed_server_side: Whether the command has been closed on the server - :param is_direct_results: Whether the command has more rows - :param results_queue: The results queue - :param description: column description of the results - :param is_staging_operation: Whether the command is a staging operation + Args: + connection: The parent connection + backend: The backend client + arraysize: The max number of rows to fetch at a time (PEP-249) + buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + command_id: The command ID + status: The command status + has_been_closed_server_side: Whether the command has been closed on the server + has_more_rows: Whether the command has more rows + results_queue: The results queue + description: column description of the results + is_staging_operation: Whether the command is a staging operation """ self.connection = connection @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation @@ -157,47 +157,25 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - t_row_set=None, - max_download_threads: int = 10, - ssl_options=None, - is_direct_results: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
- Parameters: - :param connection: The parent connection - :param execute_response: Response from the execute command - :param thrift_client: The ThriftDatabricksClient instance for direct access - :param buffer_size_bytes: Buffer size for fetching results - :param arraysize: Default number of rows to fetch - :param use_cloud_fetch: Whether to use cloud fetch for retrieving results - :param t_row_set: The TRowSet containing result data (if available) - :param max_download_threads: Maximum number of download threads for cloud fetch - :param ssl_options: SSL options for cloud fetch - :param is_direct_results: Whether there are more rows to fetch + Args: + connection: The parent connection + execute_response: Response from the execute command + thrift_client: The ThriftDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed - # Build the results queue if t_row_set is provided - results_queue = None - if t_row_set and execute_response.result_format is not None: - from databricks.sql.utils import ResultSetQueueFactory - - # Create the results queue using the provided format - results_queue = ResultSetQueueFactory.build_queue( - row_set_type=execute_response.result_format, - t_row_set=t_row_set, - arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", - max_download_threads=max_download_threads, - lz4_compressed=execute_response.lz4_compressed, - description=execute_response.description, - ssl_options=ssl_options, - ) - # Call parent constructor with common attributes super().__init__( connection=connection, @@ -207,8 +185,8 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - is_direct_results=is_direct_results, - results_queue=results_queue, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, ) @@ -218,7 +196,7 @@ def __init__( self._fill_results_buffer() def _fill_results_buffer(self): - results, is_direct_results = self.backend.fetch_results( + results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -229,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -313,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -338,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() 
partial_results = self.results.next_n_rows(n_remaining_rows) @@ -353,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.is_direct_results: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -379,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.is_direct_results: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index d7b1b74b4..edb13ef6d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import re import lz4.frame @@ -57,7 +57,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: Optional[List[List[Any]]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -206,7 +206,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: Optional[List[List[Any]]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. 
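[Editor's note: the renamed has_more_rows flag drives the refill loop seen in fetchmany_arrow/fetchall_arrow above: drain the current queue, and while the server reports more rows, refill the buffer and drain again. A self-contained toy rendering of that drain pattern, with a list-backed queue standing in for the real Arrow queue (all names here are illustrative, not the driver's API):

from typing import List

class ToyQueue:
    """Stand-in for the Arrow/columnar results queue."""
    def __init__(self, rows: List[int]) -> None:
        self.rows = rows
    def next_n_rows(self, n: int) -> List[int]:
        out, self.rows = self.rows[:n], self.rows[n:]
        return out

def fetch_all(batches: List[List[int]]) -> List[int]:
    results: List[int] = []
    index = 0
    queue = ToyQueue(batches[index])              # initial buffer
    has_more_rows = index < len(batches) - 1
    results.extend(queue.next_n_rows(len(queue.rows)))
    while has_more_rows:                          # mirrors: while ... and self.has_more_rows
        index += 1                                # _fill_results_buffer()
        queue = ToyQueue(batches[index])
        has_more_rows = index < len(batches) - 1  # flag returned alongside the new queue
        results.extend(queue.next_n_rows(len(queue.rows)))
    return results

assert fetch_all([[1, 2], [3], [4, 5]]) == [1, 2, 3, 4, 5]

End of note.]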
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2054d01d1..090ec255e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -48,7 +48,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - is_direct_results=True, + has_more_rows=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -104,7 +104,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None - mock_backend.fetch_results.return_value = (Mock(), False) # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -185,7 +184,6 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, @@ -212,7 +210,6 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -257,10 +254,7 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - mock_backend = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) - - result_set = ThriftResultSet(Mock(), Mock(), mock_backend) + result_set = ThriftResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -478,6 +472,7 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq + mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index a649941e1..7249a59e6 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -40,30 +40,25 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) - - # Create a mock backend that will return the queue when _fill_results_buffer is called - mock_thrift_backend = Mock(spec=ThriftDatabricksClient) - mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) - - num_cols = len(initial_results[0]) if initial_results else 0 - description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] - rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=True, - description=description, - lz4_compressed=True, + has_more_rows=False, + description=Mock(), + lz4_compressed=Mock(), + results_queue=arrow_queue, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, - t_row_set=None, + thrift_client=None, ) + num_cols = len(initial_results[0]) if initial_results else 0 + 
rs.description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] return rs @staticmethod @@ -90,19 +85,19 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 - description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] - rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=False, - description=description, - lz4_compressed=True, + has_more_rows=True, + description=[ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ], + lz4_compressed=Mock(), + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index e4a9e5cdd..7e025cf82 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - is_direct_results=False, + has_more_rows=False, description=Mock(), command_id=None, arrow_queue=arrow_queue, diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57b5e9b58..8274190fe 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -623,10 +623,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -835,10 +832,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value @@ -882,10 +878,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - ( - execute_response, - _, - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + execute_response, _ = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) + self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -951,14 +947,8 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) - tcli_service_instance.GetOperationStatus.return_value = ( - ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) @@ -983,14 
+973,8 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - tcli_service_instance.GetOperationStatus.return_value = ( - ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - ) thrift_backend = self._make_fake_thrift_backend() - _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -1004,10 +988,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1019,7 +1003,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, ), closeOperation=Mock(), @@ -1035,12 +1019,11 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - ( - execute_response, - has_more_rows_result, - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + execute_response, _ = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) - self.assertEqual(is_direct_results, has_more_rows_result) + self.assertEqual(has_more_rows, execute_response.has_more_rows) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1049,10 +1032,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1065,7 +1048,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1098,7 +1081,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( description=Mock(), ) - self.assertEqual(is_direct_results, has_more_rows_resp) + self.assertEqual(has_more_rows, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): @@ -1153,10 +1136,9 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): 
self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1169,15 +1151,14 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1189,10 +1170,9 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1205,13 +1185,12 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1222,10 +1201,9 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1238,8 +1216,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1251,7 +1228,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1264,10 +1241,9 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) - 
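[Editor's note: the test rewrites above all follow one shape: instead of patching ThriftResultSet and comparing against mock_result_set.return_value, the reverted form stubs _handle_execute_response and type-checks the returned object. A condensed sketch of that pattern, assuming (as in the surrounding tests) a backend whose Thrift client has already been patched out:

from unittest.mock import Mock

def exercise_execute(thrift_backend, cursor_mock, ResultSet):
    # Stub the handler so no Thrift round trip occurs; the backend then
    # builds a real result set from the two mocked return values.
    thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
    result = thrift_backend.execute_command("foo", Mock(), 100, 200, Mock(), cursor_mock)
    assert isinstance(result, ResultSet)
    thrift_backend._handle_execute_response.assert_called_once()

End of note.]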
@patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1280,8 +1256,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1295,7 +1270,7 @@ def test_get_tables_calls_client_and_handle_execute_response( table_types=["type1", "type2"], ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] @@ -1310,10 +1285,9 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1326,8 +1300,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1341,7 +1314,7 @@ def test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -2230,23 +2203,14 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class, mock_result_set + self, mock_handle_execute_response, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value - # Set up the mock to return a tuple with two values - mock_execute_response = Mock() - mock_arrow_schema = Mock() - mock_handle_execute_response.return_value = ( - mock_execute_response, - mock_arrow_schema, - ) - # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] From cbace3f52c025d2b414c4169555f9daeaa27581d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:20:12 +0000 Subject: [PATCH 062/105] Revert "Merge branch 'exec-models-sea' into exec-phase-sea" This 
reverts commit be1997e0d6b6cf0f5499db2381971ec3a015a2f7, reversing changes made to 37813ba6d1fe06d7f9f10d510a059b88dc552496. --- examples/experimental/sea_connector_test.py | 68 +-- src/databricks/sql/backend/sea/backend.py | 480 ++++++++++++++++-- src/databricks/sql/backend/sea/models/base.py | 20 +- .../sql/backend/sea/models/requests.py | 14 +- .../sql/backend/sea/models/responses.py | 29 +- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 137 +++-- src/databricks/sql/backend/types.py | 8 +- src/databricks/sql/result_set.py | 159 ++++-- src/databricks/sql/utils.py | 6 +- tests/unit/test_client.py | 11 +- tests/unit/test_fetches.py | 39 +- tests/unit/test_fetches_bench.py | 2 +- tests/unit/test_sea_backend.py | 453 +++++++++++++---- tests/unit/test_sea_result_set.py | 200 ++++++++ tests/unit/test_thrift_backend.py | 138 +++-- 16 files changed, 1300 insertions(+), 466 deletions(-) create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 2553a2b20..0db326894 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -10,7 +10,8 @@ import subprocess from typing import List, Tuple -logging.basicConfig(level=logging.DEBUG) +# Configure logging +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) TEST_MODULES = [ @@ -87,48 +88,29 @@ def print_summary(results: List[Tuple[str, bool]]) -> None: logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") logger.info(f"{'=' * 50}") - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") - sys.exit(1) - - logger.info(f"Connecting to {server_hostname}") - logger.info(f"HTTP Path: {http_path}") - if catalog: - logger.info(f"Using catalog: {catalog}") - - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + +if __name__ == "__main__": + # Check if required environment variables are set + required_vars = [ + "DATABRICKS_SERVER_HOSTNAME", + "DATABRICKS_HTTP_PATH", + "DATABRICKS_TOKEN", + ] + missing_vars = [var for var in required_vars if not os.environ.get(var)] + + if missing_vars: + logger.error( + f"Missing required environment variables: {', '.join(missing_vars)}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") - logger.info(f"backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - logger.error(traceback.format_exc()) + logger.error("Please set these variables before running the tests.") sys.exit(1) - - logger.info("SEA session test completed successfully") -if __name__ == "__main__": - test_sea_session() + # Run all 
tests + results = run_tests() + + # Print summary + print_summary(results) + + # Exit with appropriate status code + all_passed = all(success for _, success in results) + sys.exit(0 if all_passed else 1) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..6d627162d 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,43 @@ import logging +import time import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError -from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +from databricks.sql.backend.types import ( + SessionId, + CommandId, + CommandState, + BackendType, + ExecuteResponse, ) -from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.types import SSLOptions +from databricks.sql.backend.sea.models.base import ( + ResultData, + ExternalLink, + ResultManifest, +) from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +75,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -64,6 +87,8 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH = BASE_PATH + "statements" STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks" + CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" def __init__( self, @@ -96,6 +121,7 @@ def __init__( ) self._max_download_threads = kwargs.get("max_download_threads", 10) + self.ssl_options = ssl_options # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -252,6 +278,19 @@ def get_default_session_configuration_value(name: str) -> Optional[str]: """ return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) + @staticmethod + def is_session_configuration_parameter_supported(name: str) -> bool: + """ + Check if a session configuration parameter is supported. 
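[Editor's note: the three static session-configuration helpers on SeaDatabricksClient (default lookup, support check, allowed list) are all thin views over ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP. A short caller-side sketch; the parameter name "ANSI_MODE" is illustrative only:

from databricks.sql.backend.sea.backend import SeaDatabricksClient

name = "ANSI_MODE"  # illustrative parameter name, not necessarily in the map
if SeaDatabricksClient.is_session_configuration_parameter_supported(name):
    default = SeaDatabricksClient.get_default_session_configuration_value(name)
    print(f"{name} is supported; server default: {default}")
else:
    allowed = ", ".join(SeaDatabricksClient.get_allowed_session_configurations())
    print(f"{name} is not supported; allowed parameters: {allowed}")

End of note.]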
+ + Args: + name: The name of the session configuration parameter + + Returns: + True if the parameter is supported, False otherwise + """ + return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP + @staticmethod def get_allowed_session_configurations() -> List[str]: """ @@ -262,8 +301,182 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - # == Not Implemented Operations == - # These methods will be implemented in future iterations + def _get_schema_bytes(self, sea_response) -> Optional[bytes]: + """ + Extract schema bytes from the SEA response. + + For ARROW format, we need to get the schema bytes from the first chunk. + If the first chunk is not available, we need to get it from the server. + + Args: + sea_response: The response from the SEA API + + Returns: + bytes: The schema bytes or None if not available + """ + import requests + import lz4.frame + + # Check if we have the first chunk in the response + result_data = sea_response.get("result", {}) + external_links = result_data.get("external_links", []) + + if not external_links: + return None + + # Find the first chunk (chunk_index = 0) + first_chunk = None + for link in external_links: + if link.get("chunk_index") == 0: + first_chunk = link + break + + if not first_chunk: + # Try to fetch the first chunk from the server + statement_id = sea_response.get("statement_id") + if not statement_id: + return None + + chunks_response = self.get_chunk_links(statement_id, 0) + if not chunks_response.external_links: + return None + + first_chunk = chunks_response.external_links[0].__dict__ + + # Download the first chunk to get the schema bytes + external_link = first_chunk.get("external_link") + http_headers = first_chunk.get("http_headers", {}) + + if not external_link: + return None + + # Use requests to download the first chunk + http_response = requests.get( + external_link, + headers=http_headers, + verify=self.ssl_options.tls_verify, + ) + + if http_response.status_code != 200: + raise Error(f"Failed to download schema bytes: {http_response.text}") + + # Extract schema bytes from the Arrow file + # The schema is at the beginning of the file + data = http_response.content + if sea_response.get("manifest", {}).get("result_compression") == "LZ4_FRAME": + data = lz4.frame.decompress(data) + + # Return the schema bytes + return data + + def _results_message_to_execute_response(self, sea_response, command_id): + """ + Convert a SEA response to an ExecuteResponse and extract result data. 
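[Editor's note: the conversion method that follows builds PEP-249 style description tuples from the SEA manifest schema. A minimal standalone rendering of that per-column mapping, fed a made-up payload rather than a real server response:

def sea_column_to_description(col: dict) -> tuple:
    # (name, type_code, display_size, internal_size, precision, scale, null_ok)
    return (
        col.get("name", ""),
        col.get("type_name", ""),
        None,  # display_size (not provided by SEA)
        None,  # internal_size (not provided by SEA)
        col.get("precision"),
        col.get("scale"),
        col.get("nullable", True),
    )

print(sea_column_to_description({"name": "test_value", "type_name": "INT", "nullable": True}))
# -> ('test_value', 'INT', None, None, None, None, True)

End of note.]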
+ + Args: + sea_response: The response from the SEA API + command_id: The command ID + + Returns: + tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response, + result data object, and manifest object + """ + # Extract status + status_data = sea_response.get("status", {}) + state = CommandState.from_sea_state(status_data.get("state", "")) + + # Extract description from manifest + description = None + manifest_data = sea_response.get("manifest", {}) + schema_data = manifest_data.get("schema", {}) + columns_data = schema_data.get("columns", []) + + if columns_data: + columns = [] + for col_data in columns_data: + if not isinstance(col_data, dict): + continue + + # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) + columns.append( + ( + col_data.get("name", ""), # name + col_data.get("type_name", ""), # type_code + None, # display_size (not provided by SEA) + None, # internal_size (not provided by SEA) + col_data.get("precision"), # precision + col_data.get("scale"), # scale + col_data.get("nullable", True), # null_ok + ) + ) + description = columns if columns else None + + # Extract schema bytes for Arrow format + schema_bytes = None + format = manifest_data.get("format") + if format == "ARROW_STREAM": + # For ARROW format, we need to get the schema bytes + schema_bytes = self._get_schema_bytes(sea_response) + + # Check for compression + lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" + + # Initialize result_data_obj and manifest_obj + result_data_obj = None + manifest_obj = None + + result_data = sea_response.get("result", {}) + if result_data: + # Convert external links + external_links = None + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers", {}), + ) + ) + + # Create the result data object + result_data_obj = ResultData( + data=result_data.get("data_array"), external_links=external_links + ) + + # Create the manifest object + manifest_obj = ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + execute_response = ExecuteResponse( + command_id=command_id, + status=state, + description=description, + has_been_closed_server_side=False, + lz4_compressed=lz4_compressed, + is_staging_operation=False, + arrow_schema_bytes=schema_bytes, + result_format=manifest_data.get("format"), + ) + + return execute_response, result_data_obj, manifest_obj def execute_command( self, @@ -274,41 +487,230 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: 
bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. + + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else None + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. 
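[Editor's note: with async_op=True, execute_command above returns None immediately and stashes the command ID on the cursor, leaving the caller to poll. A rough sketch of how a caller might drive that asynchronous path, assuming an open SEA client, its SessionId, and a Cursor are already in hand; the polling loop mirrors the synchronous branch in the diff:

import time
from databricks.sql.backend.types import CommandState

client.execute_command(
    operation="SELECT 1",
    session_id=session_id,
    max_rows=10000,
    max_bytes=0,
    lz4_compression=False,
    cursor=cursor,
    use_cloud_fetch=False,
    parameters=[],
    async_op=True,  # returns None; the command ID is stored on the cursor
    enforce_embedded_schema_correctness=False,
)
state = client.get_query_state(cursor.active_command_id)
while state in (CommandState.PENDING, CommandState.RUNNING):
    time.sleep(0.5)  # small delay to avoid excessive API calls
    state = client.get_query_state(cursor.active_command_id)
result_set = client.get_execution_result(cursor.active_command_id, cursor)

End of note.]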
+ + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
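[Editor's note: read together, cancel_command, close_command, and get_query_state above are thin wrappers over the SEA statement endpoints. A compact summary of the routing they implement, using the path constants defined earlier on SeaDatabricksClient (a sketch of the mapping, not new behavior):

# method            HTTP verb  path (constants on SeaDatabricksClient)
# cancel_command    POST       CANCEL_STATEMENT_PATH_WITH_ID.format(statement_id)
# close_command     DELETE     STATEMENT_PATH_WITH_ID.format(statement_id)
# get_query_state   GET        STATEMENT_PATH_WITH_ID.format(statement_id)

Each call serializes the corresponding request dataclass with to_dict() and, where a body is returned, parses it with the matching from_dict() response model. End of note.]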
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + # Convert the response to an ExecuteResponse and extract result data + ( + execute_response, + result_data, + manifest, + ) = self._results_message_to_execute_response(response_data, command_id) + + return SeaResultSet( + connection=cursor.connection, + execute_response=execute_response, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + result_data=result_data, + manifest=manifest, ) # == Metadata Operations == @@ -319,9 +721,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + raise NotImplementedError("get_catalogs is not implemented for SEA backend") def get_schemas( self, @@ -331,9 +733,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + raise NotImplementedError("get_schemas is not implemented for SEA backend") def get_tables( self, @@ -345,9 +747,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + raise NotImplementedError("get_tables is not implemented for SEA backend") def get_columns( self, @@ -359,6 +761,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + raise NotImplementedError("get_columns is not implemented for SEA backend") diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index b12c26eb0..6175b4ca0 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -42,29 +42,12 @@ class ExternalLink: http_headers: Optional[Dict[str, str]] = None -@dataclass -class ChunkInfo: - """Information about a chunk in the result set.""" - - chunk_index: int - byte_count: int - row_offset: int - row_count: int - - @dataclass class ResultData: """Result data from a statement 
execution.""" data: Optional[List[List[Any]]] = None external_links: Optional[List[ExternalLink]] = None - byte_count: Optional[int] = None - chunk_index: Optional[int] = None - next_chunk_index: Optional[int] = None - next_chunk_internal_link: Optional[str] = None - row_count: Optional[int] = None - row_offset: Optional[int] = None - attachment: Optional[bytes] = None @dataclass @@ -90,6 +73,5 @@ class ResultManifest: total_byte_count: int total_chunk_count: int truncated: bool = False - chunks: Optional[List[ChunkInfo]] = None + chunks: Optional[List[Dict[str, Any]]] = None result_compression: Optional[str] = None - is_volume_operation: Optional[bool] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 4c5071dba..58921d793 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -10,7 +10,7 @@ @dataclass class StatementParameter: - """Representation of a parameter for a SQL statement.""" + """Parameter for a SQL statement.""" name: str value: Optional[str] = None @@ -19,7 +19,7 @@ class StatementParameter: @dataclass class ExecuteStatementRequest: - """Representation of a request to execute a SQL statement.""" + """Request to execute a SQL statement.""" session_id: str statement: str @@ -65,7 +65,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class GetStatementRequest: - """Representation of a request to get information about a statement.""" + """Request to get information about a statement.""" statement_id: str @@ -76,7 +76,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CancelStatementRequest: - """Representation of a request to cancel a statement.""" + """Request to cancel a statement.""" statement_id: str @@ -87,7 +87,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CloseStatementRequest: - """Representation of a request to close a statement.""" + """Request to close a statement.""" statement_id: str @@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CreateSessionRequest: - """Representation of a request to create a new session.""" + """Request to create a new session.""" warehouse_id: str session_confs: Optional[Dict[str, str]] = None @@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DeleteSessionRequest: - """Representation of a request to delete a session.""" + """Request to delete a session.""" warehouse_id: str session_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 0baf27ab2..c16f19da3 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -14,7 +14,6 @@ ResultData, ServiceError, ExternalLink, - ChunkInfo, ) @@ -44,18 +43,6 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: """Parse manifest from response data.""" manifest_data = data.get("manifest", {}) - chunks = None - if "chunks" in manifest_data: - chunks = [ - ChunkInfo( - chunk_index=chunk.get("chunk_index", 0), - byte_count=chunk.get("byte_count", 0), - row_offset=chunk.get("row_offset", 0), - row_count=chunk.get("row_count", 0), - ) - for chunk in manifest_data.get("chunks", []) - ] - return ResultManifest( format=manifest_data.get("format", ""), schema=manifest_data.get("schema", {}), @@ -63,9 +50,8 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: total_byte_count=manifest_data.get("total_byte_count", 0), 
total_chunk_count=manifest_data.get("total_chunk_count", 0), truncated=manifest_data.get("truncated", False), - chunks=chunks, + chunks=manifest_data.get("chunks"), result_compression=manifest_data.get("result_compression"), - is_volume_operation=manifest_data.get("is_volume_operation"), ) @@ -94,19 +80,12 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: return ResultData( data=result_data.get("data_array"), external_links=external_links, - byte_count=result_data.get("byte_count"), - chunk_index=result_data.get("chunk_index"), - next_chunk_index=result_data.get("next_chunk_index"), - next_chunk_internal_link=result_data.get("next_chunk_internal_link"), - row_count=result_data.get("row_count"), - row_offset=result_data.get("row_offset"), - attachment=result_data.get("attachment"), ) @dataclass class ExecuteStatementResponse: - """Representation of the response from executing a SQL statement.""" + """Response from executing a SQL statement.""" statement_id: str status: StatementStatus @@ -126,7 +105,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": @dataclass class GetStatementResponse: - """Representation of the response from getting information about a statement.""" + """Response from getting information about a statement.""" statement_id: str status: StatementStatus @@ -146,7 +125,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Representation of the response from creating a new session.""" + """Response from creating a new session.""" session_id: str diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index fe292919c..f0b931ee4 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, List, Tuple +from typing import Callable, Dict, Any, Optional, Union, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e824de1c2..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,20 +3,22 @@ import logging import math import time +import uuid import threading from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor +from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( CommandState, SessionId, CommandId, + BackendType, + guid_to_hex_id, ExecuteResponse, ) -from databricks.sql.backend.utils import guid_to_hex_id - try: import pyarrow @@ -757,13 +759,11 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - - is_direct_results = ( + has_more_rows = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) - description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -779,25 +779,43 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + if 
direct_results and direct_results.resultSet: + assert direct_results.resultSet.results.startRowOffset == 0 + assert direct_results.resultSetMetadata + + arrow_queue_opt = ResultSetQueueFactory.build_queue( + row_set_type=t_result_set_metadata_resp.resultFormat, + t_row_set=direct_results.resultSet.results, + arrow_schema_bytes=schema_bytes, + max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, + ssl_options=self._ssl_options, + ) + else: + arrow_queue_opt = None + command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) if status is None: raise ValueError(f"Unknown command state: {operation_state}") - execute_response = ExecuteResponse( - command_id=command_id, - status=status, - description=description, - has_been_closed_server_side=has_been_closed_server_side, - lz4_compressed=lz4_compressed, - is_staging_operation=t_result_set_metadata_resp.isStagingOperation, - arrow_schema_bytes=schema_bytes, - result_format=t_result_set_metadata_resp.resultFormat, + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) - return execute_response, is_direct_results - def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -822,6 +840,9 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -836,9 +857,15 @@ def get_execution_result( else: schema_bytes = None - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - is_direct_results = resp.hasMoreRows + queue = ResultSetQueueFactory.build_queue( + row_set_type=resp.resultSetMetadata.resultFormat, + t_row_set=resp.results, + arrow_schema_bytes=schema_bytes, + max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, + ssl_options=self._ssl_options, + ) status = self.get_query_state(command_id) @@ -846,11 +873,11 @@ def get_execution_result( command_id=command_id, status=status, description=description, + has_more_rows=has_more_rows, + results_queue=queue, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - arrow_schema_bytes=schema_bytes, - result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -860,10 +887,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=resp.results, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -894,7 +918,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise 
ValueError(f"Unknown command state: {operation_state}") + raise ValueError(f"Invalid operation state: {operation_state}") return state @staticmethod @@ -976,14 +1000,10 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -991,10 +1011,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1016,14 +1033,10 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1031,10 +1044,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1060,14 +1070,10 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1075,10 +1081,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1108,14 +1111,10 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1123,10 +1122,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1156,14 +1152,10 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, 
is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1171,10 +1163,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def _handle_execute_response(self, resp, cursor): @@ -1188,7 +1177,11 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + ( + execute_response, + arrow_schema_bytes, + ) = self._results_message_to_execute_response(resp, final_operation_state) + return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 93bd7d525..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,9 +423,11 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[List[Tuple]] = None + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False - arrow_schema_bytes: Optional[bytes] = None - result_format: Optional[Any] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index cf6940bb2..7ea4976c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,7 +41,7 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - is_direct_results: bool = False, + has_more_rows: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, @@ -49,18 +49,18 @@ def __init__( """ A ResultSet manages the results of a single command. 
- Parameters: - :param connection: The parent connection - :param backend: The backend client - :param arraysize: The max number of rows to fetch at a time (PEP-249) - :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - :param command_id: The command ID - :param status: The command status - :param has_been_closed_server_side: Whether the command has been closed on the server - :param is_direct_results: Whether the command has more rows - :param results_queue: The results queue - :param description: column description of the results - :param is_staging_operation: Whether the command is a staging operation + Args: + connection: The parent connection + backend: The backend client + arraysize: The max number of rows to fetch at a time (PEP-249) + buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + command_id: The command ID + status: The command status + has_been_closed_server_side: Whether the command has been closed on the server + has_more_rows: Whether the command has more rows + results_queue: The results queue + description: column description of the results + is_staging_operation: Whether the command is a staging operation """ self.connection = connection @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation @@ -157,47 +157,25 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - t_row_set=None, - max_download_threads: int = 10, - ssl_options=None, - is_direct_results: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
- Parameters: - :param connection: The parent connection - :param execute_response: Response from the execute command - :param thrift_client: The ThriftDatabricksClient instance for direct access - :param buffer_size_bytes: Buffer size for fetching results - :param arraysize: Default number of rows to fetch - :param use_cloud_fetch: Whether to use cloud fetch for retrieving results - :param t_row_set: The TRowSet containing result data (if available) - :param max_download_threads: Maximum number of download threads for cloud fetch - :param ssl_options: SSL options for cloud fetch - :param is_direct_results: Whether there are more rows to fetch + Args: + connection: The parent connection + execute_response: Response from the execute command + thrift_client: The ThriftDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed - # Build the results queue if t_row_set is provided - results_queue = None - if t_row_set and execute_response.result_format is not None: - from databricks.sql.utils import ResultSetQueueFactory - - # Create the results queue using the provided format - results_queue = ResultSetQueueFactory.build_queue( - row_set_type=execute_response.result_format, - t_row_set=t_row_set, - arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", - max_download_threads=max_download_threads, - lz4_compressed=execute_response.lz4_compressed, - description=execute_response.description, - ssl_options=ssl_options, - ) - # Call parent constructor with common attributes super().__init__( connection=connection, @@ -207,8 +185,8 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - is_direct_results=is_direct_results, - results_queue=results_queue, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, ) @@ -218,7 +196,7 @@ def __init__( self._fill_results_buffer() def _fill_results_buffer(self): - results, is_direct_results = self.backend.fetch_results( + results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -229,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -313,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -338,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() 
            partial_results = self.results.next_n_rows(n_remaining_rows)
@@ -353,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
 
-        while not self.has_been_closed_server_side and self.is_direct_results:
+        while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
             if isinstance(results, ColumnTable) and isinstance(
@@ -379,7 +357,7 @@ def fetchall_columnar(self):
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
 
-        while not self.has_been_closed_server_side and self.is_direct_results:
+        while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
             results = self.merge_columnar(results, partial_results)
@@ -438,3 +416,75 @@ def map_col_type(type_):
         (column.name, map_col_type(column.datatype), None, None, None, None, None)
         for column in table_schema_message.columns
     ]
+
+
+class SeaResultSet(ResultSet):
+    """ResultSet implementation for the SEA backend."""
+
+    def __init__(
+        self,
+        connection: "Connection",
+        execute_response: "ExecuteResponse",
+        sea_client: "SeaDatabricksClient",
+        buffer_size_bytes: int = 104857600,
+        arraysize: int = 10000,
+    ):
+        """
+        Initialize a SeaResultSet with the response from a SEA query execution.
+
+        Args:
+            connection: The parent connection
+            execute_response: Response from the execute command
+            sea_client: The SeaDatabricksClient instance for direct access
+            buffer_size_bytes: Buffer size for fetching results
+            arraysize: Default number of rows to fetch
+        """
+
+        super().__init__(
+            connection=connection,
+            backend=sea_client,
+            arraysize=arraysize,
+            buffer_size_bytes=buffer_size_bytes,
+            command_id=execute_response.command_id,
+            status=execute_response.status,
+            has_been_closed_server_side=execute_response.has_been_closed_server_side,
+            has_more_rows=execute_response.has_more_rows,
+            results_queue=execute_response.results_queue,
+            description=execute_response.description,
+            is_staging_operation=execute_response.is_staging_operation,
+        )
+
+    def _fill_results_buffer(self):
+        """Fill the results buffer from the backend."""
+        raise NotImplementedError("fetchone is not implemented for SEA backend")
+
+    def fetchone(self) -> Optional[Row]:
+        """
+        Fetch the next row of a query result set, returning a single sequence,
+        or None when no more data is available.
+        """
+
+        raise NotImplementedError("fetchone is not implemented for SEA backend")
+
+    def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+        """
+        Fetch the next set of rows of a query result, returning a list of rows.
+
+        An empty sequence is returned when no more rows are available.
+        """
+
+        raise NotImplementedError("fetchmany is not implemented for SEA backend")
+
+    def fetchall(self) -> List[Row]:
+        """
+        Fetch all (remaining) rows of a query result, returning them as a list of rows.
+ """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index d7b1b74b4..edb13ef6d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import re import lz4.frame @@ -57,7 +57,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: Optional[List[List[Any]]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -206,7 +206,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: Optional[List[List[Any]]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2054d01d1..090ec255e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -48,7 +48,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - is_direct_results=True, + has_more_rows=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -104,7 +104,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None - mock_backend.fetch_results.return_value = (Mock(), False) # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -185,7 +184,6 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, @@ -212,7 +210,6 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -257,10 +254,7 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - mock_backend = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) - - result_set = ThriftResultSet(Mock(), Mock(), mock_backend) + result_set = ThriftResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -478,6 +472,7 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq + 
mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index a649941e1..7249a59e6 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -40,30 +40,25 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) - - # Create a mock backend that will return the queue when _fill_results_buffer is called - mock_thrift_backend = Mock(spec=ThriftDatabricksClient) - mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) - - num_cols = len(initial_results[0]) if initial_results else 0 - description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] - rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=True, - description=description, - lz4_compressed=True, + has_more_rows=False, + description=Mock(), + lz4_compressed=Mock(), + results_queue=arrow_queue, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, - t_row_set=None, + thrift_client=None, ) + num_cols = len(initial_results[0]) if initial_results else 0 + rs.description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] return rs @staticmethod @@ -90,19 +85,19 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 - description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] - rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=False, - description=description, - lz4_compressed=True, + has_more_rows=True, + description=[ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ], + lz4_compressed=Mock(), + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index e4a9e5cdd..7e025cf82 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - is_direct_results=False, + has_more_rows=False, description=Mock(), command_id=None, arrow_queue=arrow_queue, diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..2fa362b8e 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,348 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg 
= mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response + + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, } - assert set(allowed_configs) == expected_keys - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify 
get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + } + mock_http_client._make_request.return_value = execute_response + + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert 
"parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the result of a command execution.""" + # Set up mock 
response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..b691872af --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,200 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + return mock_response + + def test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.command_id == execute_response.command_id + assert result_set.status == CommandState.SUCCEEDED + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set.description == execute_response.description + + def test_close(self, mock_connection, mock_sea_client, execute_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + 
execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, execute_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, execute_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, execute_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, execute_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + 
connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57b5e9b58..b8de970db 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,14 +619,18 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 + + # Create a valid operation status + op_status = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -835,10 +839,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value @@ -882,10 +885,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - ( - execute_response, - _, - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + execute_response, _ = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) + self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -920,7 +923,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -948,21 +953,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) - tcli_service_instance.GetOperationStatus.return_value = ( - ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + 
self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -982,15 +987,15 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - tcli_service_instance.GetOperationStatus.return_value = ( - ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, ) + tcli_service_instance.GetOperationStatus.return_value = op_state + tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req thrift_backend = self._make_fake_thrift_backend() - _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -1004,10 +1009,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1019,7 +1024,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, ), closeOperation=Mock(), @@ -1035,12 +1040,11 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - ( - execute_response, - has_more_rows_result, - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + execute_response, _ = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) - self.assertEqual(is_direct_results, has_more_rows_result) + self.assertEqual(has_more_rows, execute_response.has_more_rows) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1049,10 +1053,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1065,7 +1069,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, 
resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1098,7 +1102,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( description=Mock(), ) - self.assertEqual(is_direct_results, has_more_rows_resp) + self.assertEqual(has_more_rows, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): @@ -1153,10 +1157,9 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1169,15 +1172,14 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1189,10 +1191,9 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1205,13 +1206,12 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1222,10 +1222,9 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1238,8 +1237,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), 
Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1251,7 +1249,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1264,10 +1262,9 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1280,8 +1277,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1295,7 +1291,7 @@ def test_get_tables_calls_client_and_handle_execute_response( table_types=["type1", "type2"], ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] @@ -1310,10 +1306,9 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1326,8 +1321,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1341,7 +1335,7 @@ def test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -1673,7 +1667,9 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2230,23 +2226,15 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") 
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", + return_value=(Mock(), Mock()), ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class, mock_result_set + self, mock_handle_execute_response, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value - # Set up the mock to return a tuple with two values - mock_execute_response = Mock() - mock_arrow_schema = Mock() - mock_handle_execute_response.return_value = ( - mock_execute_response, - mock_arrow_schema, - ) - # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] From c075b07164aeaf3d571aeb35c6d7227b92436aeb Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:22:30 +0000 Subject: [PATCH 063/105] change logging level Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 0db326894..edd171b05 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -10,8 +10,7 @@ import subprocess from typing import List, Tuple -# Configure logging -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) TEST_MODULES = [ From c62f76dce2d17f842708489da04c7a8d4255cf06 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:37:12 +0000 Subject: [PATCH 064/105] remove un-necessary changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 161 ++++++------------ src/databricks/sql/backend/sea/models/base.py | 20 ++- .../sql/backend/sea/models/requests.py | 14 +- .../sql/backend/sea/models/responses.py | 29 +++- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 137 ++++++++------- src/databricks/sql/backend/types.py | 8 +- src/databricks/sql/result_set.py | 159 ++++++----------- src/databricks/sql/utils.py | 6 +- 9 files changed, 240 insertions(+), 296 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index edd171b05..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,115 +1,66 @@ -""" -Main script to run all SEA connector tests. - -This script runs all the individual test modules and displays -a summary of test results with visual indicators. 
-""" import os import sys import logging -import subprocess -from typing import List, Tuple +from databricks.sql.client import Connection logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -TEST_MODULES = [ - "test_sea_session", - "test_sea_sync_query", - "test_sea_async_query", - "test_sea_metadata", -] - - -def run_test_module(module_name: str) -> bool: - """Run a test module and return success status.""" - module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" - ) - - # Simply run the module as a script - each module handles its own test execution - result = subprocess.run( - [sys.executable, module_path], capture_output=True, text=True - ) - - # Log the output from the test module - if result.stdout: - for line in result.stdout.strip().split("\n"): - logger.info(line) - - if result.stderr: - for line in result.stderr.strip().split("\n"): - logger.error(line) - - return result.returncode == 0 - - -def run_tests() -> List[Tuple[str, bool]]: - """Run all tests and return results.""" - results = [] - - for module_name in TEST_MODULES: - try: - logger.info(f"\n{'=' * 50}") - logger.info(f"Running test: {module_name}") - logger.info(f"{'-' * 50}") - - success = run_test_module(module_name) - results.append((module_name, success)) - - status = "✅ PASSED" if success else "❌ FAILED" - logger.info(f"Test {module_name}: {status}") - - except Exception as e: - logger.error(f"Error loading or running test {module_name}: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - results.append((module_name, False)) - - return results - - -def print_summary(results: List[Tuple[str, bool]]) -> None: - """Print a summary of test results.""" - logger.info(f"\n{'=' * 50}") - logger.info("TEST SUMMARY") - logger.info(f"{'-' * 50}") - - passed = sum(1 for _, success in results if success) - total = len(results) - - for module_name, success in results: - status = "✅ PASSED" if success else "❌ FAILED" - logger.info(f"{status} - {module_name}") - - logger.info(f"{'-' * 50}") - logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") - logger.info(f"{'=' * 50}") - - -if __name__ == "__main__": - # Check if required environment variables are set - required_vars = [ - "DATABRICKS_SERVER_HOSTNAME", - "DATABRICKS_HTTP_PATH", - "DATABRICKS_TOKEN", - ] - missing_vars = [var for var in required_vars if not os.environ.get(var)] - - if missing_vars: - logger.error( - f"Missing required environment variables: {', '.join(missing_vars)}" +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + sys.exit(1) + + logger.info(f"Connecting to {server_hostname}") + logger.info(f"HTTP Path: {http_path}") + if catalog: + logger.info(f"Using catalog: {catalog}") + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) - logger.error("Please set these variables before running the tests.") + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") + logger.info(f"backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) sys.exit(1) + + logger.info("SEA session test completed successfully") - # Run all tests - results = run_tests() - - # Print summary - print_summary(results) - - # Exit with appropriate status code - all_passed = all(success for _, success in results) - sys.exit(0 if all_passed else 1) +if __name__ == "__main__": + test_sea_session() diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index 6175b4ca0..b12c26eb0 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -42,12 +42,29 @@ class ExternalLink: http_headers: Optional[Dict[str, str]] = None +@dataclass +class ChunkInfo: + """Information about a chunk in the result set.""" + + chunk_index: int + byte_count: int + row_offset: int + row_count: int + + @dataclass class ResultData: """Result data from a statement execution.""" data: Optional[List[List[Any]]] = None external_links: Optional[List[ExternalLink]] = None + byte_count: Optional[int] = None + chunk_index: Optional[int] = None + next_chunk_index: Optional[int] = None + next_chunk_internal_link: Optional[str] = None + row_count: Optional[int] = None + row_offset: Optional[int] = None + attachment: Optional[bytes] = None @dataclass @@ -73,5 +90,6 @@ class ResultManifest: total_byte_count: int total_chunk_count: int truncated: bool = False - chunks: Optional[List[Dict[str, Any]]] = None + chunks: Optional[List[ChunkInfo]] = None result_compression: Optional[str] = None + is_volume_operation: Optional[bool] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 58921d793..4c5071dba 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -10,7 +10,7 @@ @dataclass class StatementParameter: - """Parameter for a SQL statement.""" + 
"""Representation of a parameter for a SQL statement.""" name: str value: Optional[str] = None @@ -19,7 +19,7 @@ class StatementParameter: @dataclass class ExecuteStatementRequest: - """Request to execute a SQL statement.""" + """Representation of a request to execute a SQL statement.""" session_id: str statement: str @@ -65,7 +65,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class GetStatementRequest: - """Request to get information about a statement.""" + """Representation of a request to get information about a statement.""" statement_id: str @@ -76,7 +76,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CancelStatementRequest: - """Request to cancel a statement.""" + """Representation of a request to cancel a statement.""" statement_id: str @@ -87,7 +87,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CloseStatementRequest: - """Request to close a statement.""" + """Representation of a request to close a statement.""" statement_id: str @@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CreateSessionRequest: - """Request to create a new session.""" + """Representation of a request to create a new session.""" warehouse_id: str session_confs: Optional[Dict[str, str]] = None @@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DeleteSessionRequest: - """Request to delete a session.""" + """Representation of a request to delete a session.""" warehouse_id: str session_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index c16f19da3..0baf27ab2 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -14,6 +14,7 @@ ResultData, ServiceError, ExternalLink, + ChunkInfo, ) @@ -43,6 +44,18 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: """Parse manifest from response data.""" manifest_data = data.get("manifest", {}) + chunks = None + if "chunks" in manifest_data: + chunks = [ + ChunkInfo( + chunk_index=chunk.get("chunk_index", 0), + byte_count=chunk.get("byte_count", 0), + row_offset=chunk.get("row_offset", 0), + row_count=chunk.get("row_count", 0), + ) + for chunk in manifest_data.get("chunks", []) + ] + return ResultManifest( format=manifest_data.get("format", ""), schema=manifest_data.get("schema", {}), @@ -50,8 +63,9 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: total_byte_count=manifest_data.get("total_byte_count", 0), total_chunk_count=manifest_data.get("total_chunk_count", 0), truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), + chunks=chunks, result_compression=manifest_data.get("result_compression"), + is_volume_operation=manifest_data.get("is_volume_operation"), ) @@ -80,12 +94,19 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: return ResultData( data=result_data.get("data_array"), external_links=external_links, + byte_count=result_data.get("byte_count"), + chunk_index=result_data.get("chunk_index"), + next_chunk_index=result_data.get("next_chunk_index"), + next_chunk_internal_link=result_data.get("next_chunk_internal_link"), + row_count=result_data.get("row_count"), + row_offset=result_data.get("row_offset"), + attachment=result_data.get("attachment"), ) @dataclass class ExecuteStatementResponse: - """Response from executing a SQL statement.""" + """Representation of the response from executing a SQL statement.""" statement_id: str status: StatementStatus @@ -105,7 +126,7 @@ def from_dict(cls, data: 
Dict[str, Any]) -> "ExecuteStatementResponse": @dataclass class GetStatementResponse: - """Response from getting information about a statement.""" + """Representation of the response from getting information about a statement.""" statement_id: str status: StatementStatus @@ -125,7 +146,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Response from creating a new session.""" + """Representation of the response from creating a new session.""" session_id: str diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index f0b931ee4..fe292919c 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, Union, List, Tuple +from typing import Callable, Dict, Any, Optional, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,22 +3,20 @@ import logging import math import time -import uuid import threading from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor -from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( CommandState, SessionId, CommandId, - BackendType, - guid_to_hex_id, ExecuteResponse, ) +from databricks.sql.backend.utils import guid_to_hex_id + try: import pyarrow @@ -759,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - has_more_rows = ( + + is_direct_results = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) + description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -779,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - if direct_results and direct_results.resultSet: - assert direct_results.resultSet.results.startRowOffset == 0 - assert direct_results.resultSetMetadata - - arrow_queue_opt = ResultSetQueueFactory.build_queue( - row_set_type=t_result_set_metadata_resp.resultFormat, - t_row_set=direct_results.resultSet.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) - else: - arrow_queue_opt = None - command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) if status is None: raise ValueError(f"Unknown command state: {operation_state}") - return ( - ExecuteResponse( - command_id=command_id, - status=status, - description=description, - has_more_rows=has_more_rows, - results_queue=arrow_queue_opt, - has_been_closed_server_side=has_been_closed_server_side, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - ), - schema_bytes, + 
execute_response = ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=t_result_set_metadata_resp.isStagingOperation, + arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) + return execute_response, is_direct_results + def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -840,9 +822,6 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -857,15 +836,9 @@ def get_execution_result( else: schema_bytes = None - queue = ResultSetQueueFactory.build_queue( - row_set_type=resp.resultSetMetadata.resultFormat, - t_row_set=resp.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + is_direct_results = resp.hasMoreRows status = self.get_query_state(command_id) @@ -873,11 +846,11 @@ def get_execution_result( command_id=command_id, status=status, description=description, - has_more_rows=has_more_rows, - results_queue=queue, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, + arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -887,7 +860,10 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=schema_bytes, + t_row_set=resp.results, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -918,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return state @staticmethod @@ -1000,10 +976,14 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1011,7 +991,10 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_catalogs( @@ -1033,10 +1016,14 @@ def get_catalogs( ) resp = 
self.make_request(self._client.GetCatalogs, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1044,7 +1031,10 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_schemas( @@ -1070,10 +1060,14 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1081,7 +1075,10 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_tables( @@ -1111,10 +1108,14 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1122,7 +1123,10 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_columns( @@ -1152,10 +1156,14 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1163,7 +1171,10 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _handle_execute_response(self, resp, cursor): @@ -1177,11 +1188,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - ( - execute_response, - arrow_schema_bytes, - ) = self._results_message_to_execute_response(resp, final_operation_state) - return execute_response, 
arrow_schema_bytes + return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..93bd7d525 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,11 +423,9 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None + description: Optional[List[Tuple]] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False + arrow_schema_bytes: Optional[bytes] = None + result_format: Optional[Any] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 7ea4976c1..cf6940bb2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,7 +41,7 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - has_more_rows: bool = False, + is_direct_results: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, @@ -49,18 +49,18 @@ def __init__( """ A ResultSet manages the results of a single command. - Args: - connection: The parent connection - backend: The backend client - arraysize: The max number of rows to fetch at a time (PEP-249) - buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - command_id: The command ID - status: The command status - has_been_closed_server_side: Whether the command has been closed on the server - has_more_rows: Whether the command has more rows - results_queue: The results queue - description: column description of the results - is_staging_operation: Whether the command is a staging operation + Parameters: + :param connection: The parent connection + :param backend: The backend client + :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + :param command_id: The command ID + :param status: The command status + :param has_been_closed_server_side: Whether the command has been closed on the server + :param is_direct_results: Whether the command has more rows + :param results_queue: The results queue + :param description: column description of the results + :param is_staging_operation: Whether the command is a staging operation """ self.connection = connection @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results self.results = results_queue self._is_staging_operation = is_staging_operation @@ -157,25 +157,47 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - arrow_schema_bytes: Optional[bytes] = None, + t_row_set=None, + max_download_threads: int = 10, + ssl_options=None, + is_direct_results: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
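The substantive shift in this hunk: `ExecuteResponse` no longer carries a pre-built `results_queue`; the result set now builds one itself from the raw `TRowSet` when inline results arrived with the execute call. A condensed sketch of that decision, with `SimpleNamespace` standing in for the slimmed-down `ExecuteResponse` fields shown above:

from types import SimpleNamespace

execute_response = SimpleNamespace(  # stand-in for the new ExecuteResponse
    result_format=None, arrow_schema_bytes=None,
    lz4_compressed=True, description=[],
)
t_row_set = None  # no inline rows came back with the execute call

results_queue = None
if t_row_set and execute_response.result_format is not None:
    results_queue = object()  # the real code calls ResultSetQueueFactory.build_queue

assert results_queue is None  # rows will come from backend.fetch_results instead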
- Args: - connection: The parent connection - execute_response: Response from the execute command - thrift_client: The ThriftDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - use_cloud_fetch: Whether to use cloud fetch for retrieving results - arrow_schema_bytes: Arrow schema bytes for the result set + Parameters: + :param connection: The parent connection + :param execute_response: Response from the execute command + :param thrift_client: The ThriftDatabricksClient instance for direct access + :param buffer_size_bytes: Buffer size for fetching results + :param arraysize: Default number of rows to fetch + :param use_cloud_fetch: Whether to use cloud fetch for retrieving results + :param t_row_set: The TRowSet containing result data (if available) + :param max_download_threads: Maximum number of download threads for cloud fetch + :param ssl_options: SSL options for cloud fetch + :param is_direct_results: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = arrow_schema_bytes + self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed + # Build the results queue if t_row_set is provided + results_queue = None + if t_row_set and execute_response.result_format is not None: + from databricks.sql.utils import ResultSetQueueFactory + + # Create the results queue using the provided format + results_queue = ResultSetQueueFactory.build_queue( + row_set_type=execute_response.result_format, + t_row_set=t_row_set, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + max_download_threads=max_download_threads, + lz4_compressed=execute_response.lz4_compressed, + description=execute_response.description, + ssl_options=ssl_options, + ) + # Call parent constructor with common attributes super().__init__( connection=connection, @@ -185,8 +207,8 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, + is_direct_results=is_direct_results, + results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, ) @@ -196,7 +218,7 @@ def __init__( self._fill_results_buffer() def _fill_results_buffer(self): - results, has_more_rows = self.backend.fetch_results( + results, is_direct_results = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -207,7 +229,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +313,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +338,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() 
partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +353,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +379,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -416,76 +438,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for the SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. 
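The rename from `has_more_rows` to `is_direct_results` leaves the shape of the buffered fetch loops above unchanged. A standalone sketch of that loop, not the library's code; `fetch_batch` is a hypothetical callable returning `(rows, more)` pairs the way `fetch_results` does:

def drain(fetch_batch):
    """Keep refilling until the backend reports no further batches."""
    rows, more = fetch_batch()
    while more:
        batch, more = fetch_batch()
        rows += batch
    return rows

batches = iter([([1, 2], True), ([3], True), ([], False)])
assert drain(lambda: next(batches)) == [1, 2, 3]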
- """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index edb13ef6d..d7b1b74b4 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import re import lz4.frame @@ -57,7 +57,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -206,7 +206,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. From 199402eb6f09e8889cfb426935d2ac911543119a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:39:18 +0000 Subject: [PATCH 065/105] remove excess changes Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/__init__.py | 0 .../tests/test_sea_async_query.py | 191 ------------------ .../experimental/tests/test_sea_metadata.py | 98 --------- .../experimental/tests/test_sea_session.py | 71 ------- .../experimental/tests/test_sea_sync_query.py | 161 --------------- 5 files changed, 521 deletions(-) delete mode 100644 examples/experimental/tests/__init__.py delete mode 100644 examples/experimental/tests/test_sea_async_query.py delete mode 100644 examples/experimental/tests/test_sea_metadata.py delete mode 100644 examples/experimental/tests/test_sea_session.py delete mode 100644 examples/experimental/tests/test_sea_sync_query.py diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py deleted file mode 100644 index a776377c3..000000000 --- a/examples/experimental/tests/test_sea_async_query.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Test for SEA asynchronous query execution functionality. -""" -import os -import sys -import logging -import time -from databricks.sql.client import Connection -from databricks.sql.backend.types import CommandState - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_async_query_with_cloud_fetch(): - """ - Test executing a query asynchronously using the SEA backend with cloud fetch enabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. 
- """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch enabled - logger.info( - "Creating connection for asynchronous query execution with cloud fetch enabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query asynchronously - cursor = connection.cursor() - logger.info( - "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" - ) - cursor.execute_async("SELECT 1 as test_value") - logger.info( - "Asynchronous query submitted successfully with cloud fetch enabled" - ) - - # Check query state - logger.info("Checking query state...") - while cursor.is_query_pending(): - logger.info("Query is still pending, waiting...") - time.sleep(1) - - logger.info("Query is no longer pending, getting results...") - cursor.get_async_execution_result() - logger.info( - "Successfully retrieved asynchronous query results with cloud fetch enabled" - ) - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_async_query_without_cloud_fetch(): - """ - Test executing a query asynchronously using the SEA backend with cloud fetch disabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - # Create connection with cloud fetch disabled - logger.info( - "Creating connection for asynchronous query execution with cloud fetch disabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, - enable_query_result_lz4_compression=False, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query asynchronously - cursor = connection.cursor() - logger.info( - "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" - ) - cursor.execute_async("SELECT 1 as test_value") - logger.info( - "Asynchronous query submitted successfully with cloud fetch disabled" - ) - - # Check query state - logger.info("Checking query state...") - while cursor.is_query_pending(): - logger.info("Query is still pending, waiting...") - time.sleep(1) - - logger.info("Query is no longer pending, getting results...") - cursor.get_async_execution_result() - logger.info( - "Successfully retrieved asynchronous query results with cloud fetch disabled" - ) - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_async_query_exec(): - """ - Run both asynchronous query tests and return overall success. - """ - with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() - logger.info( - f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" - ) - - without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() - logger.info( - f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" - ) - - return with_cloud_fetch_success and without_cloud_fetch_success - - -if __name__ == "__main__": - success = test_sea_async_query_exec() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py deleted file mode 100644 index a200d97d3..000000000 --- a/examples/experimental/tests/test_sea_metadata.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -Test for SEA metadata functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_metadata(): - """ - Test metadata operations using the SEA backend. - - This function connects to a Databricks SQL endpoint using the SEA backend, - and executes metadata operations like catalogs(), schemas(), tables(), and columns(). - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - if not catalog: - logger.error( - "DATABRICKS_CATALOG environment variable is required for metadata tests." 
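For the record, the metadata example being removed here walks four cursor entry points. Condensed below (an open connection and a reachable catalog are assumed; the `customer` table name is carried over from the deleted code):

def probe_metadata(connection, catalog):
    cursor = connection.cursor()
    cursor.catalogs()                                   # list catalogs
    cursor.schemas(catalog_name=catalog)                # schemas in the catalog
    cursor.tables(catalog_name=catalog, schema_name="default")
    cursor.columns(
        catalog_name=catalog, schema_name="default", table_name="customer"
    )
    cursor.close()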
- ) - return False - - try: - # Create connection - logger.info("Creating connection for metadata operations") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Test catalogs - cursor = connection.cursor() - logger.info("Fetching catalogs...") - cursor.catalogs() - logger.info("Successfully fetched catalogs") - - # Test schemas - logger.info(f"Fetching schemas for catalog '{catalog}'...") - cursor.schemas(catalog_name=catalog) - logger.info("Successfully fetched schemas") - - # Test tables - logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") - cursor.tables(catalog_name=catalog, schema_name="default") - logger.info("Successfully fetched tables") - - # Test columns for a specific table - # Using a common table that should exist in most environments - logger.info( - f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." - ) - cursor.columns( - catalog_name=catalog, schema_name="default", table_name="customer" - ) - logger.info("Successfully fetched columns") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error(f"Error during SEA metadata test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return False - - -if __name__ == "__main__": - success = test_sea_metadata() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py deleted file mode 100644 index 516c1bbb8..000000000 --- a/examples/experimental/tests/test_sea_session.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Test for SEA session management functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. - - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. - - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"Backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return False - - -if __name__ == "__main__": - success = test_sea_session() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py deleted file mode 100644 index 07be8aafc..000000000 --- a/examples/experimental/tests/test_sea_sync_query.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -Test for SEA synchronous query execution functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_sync_query_with_cloud_fetch(): - """ - Test executing a query synchronously using the SEA backend with cloud fetch enabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch enabled - logger.info( - "Creating connection for synchronous query execution with cloud fetch enabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query - cursor = connection.cursor() - logger.info( - "Executing synchronous query with cloud fetch: SELECT 1 as test_value" - ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch enabled") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_sync_query_without_cloud_fetch(): - """ - Test executing a query synchronously using the SEA backend with cloud fetch disabled. 
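The two sync variants above and below differ only in two connection flags. A parameterized sketch of that configuration matrix; the keyword names are exactly the ones the deleted examples pass to `Connection`:

def sea_connection_kwargs(cloud_fetch: bool) -> dict:
    """The two configurations the deleted sync and async examples cover."""
    return dict(
        use_sea=True,
        user_agent_entry="SEA-Test-Client",
        use_cloud_fetch=cloud_fetch,
        # LZ4 compression is only exercised together with cloud fetch
        enable_query_result_lz4_compression=cloud_fetch,
    )

assert sea_connection_kwargs(False)["use_cloud_fetch"] is False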
- - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch disabled - logger.info( - "Creating connection for synchronous query execution with cloud fetch disabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, - enable_query_result_lz4_compression=False, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query - cursor = connection.cursor() - logger.info( - "Executing synchronous query without cloud fetch: SELECT 1 as test_value" - ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch disabled") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_sync_query_exec(): - """ - Run both synchronous query tests and return overall success. 
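Looking ahead to the unit-test changes in PATCH 066 below: several tests now prime `fetch_results` on the mocked backend, because `ThriftResultSet.__init__` performs an eager first buffer fill. A sketch of why the stub must return a pair, using only `unittest.mock`:

from unittest.mock import Mock

backend = Mock()
# (results_queue, is_direct_results): the result set unpacks exactly two values
backend.fetch_results.return_value = (Mock(), False)

queue, is_direct_results = backend.fetch_results()
assert is_direct_results is False  # no further round trips will be attempted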
- """ - with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() - logger.info( - f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" - ) - - without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() - logger.info( - f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" - ) - - return with_cloud_fetch_success and without_cloud_fetch_success - - -if __name__ == "__main__": - success = test_sea_sync_query_exec() - sys.exit(0 if success else 1) From 8ac574ba46d7e2349fba105857e9ca2b7963e32b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:41:22 +0000 Subject: [PATCH 066/105] remove excess changes Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 11 +- tests/unit/test_fetches.py | 39 +++--- tests/unit/test_fetches_bench.py | 2 +- tests/unit/test_sea_result_set.py | 200 ------------------------------ tests/unit/test_thrift_backend.py | 138 +++++++++++---------- 5 files changed, 106 insertions(+), 284 deletions(-) delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 090ec255e..2054d01d1 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -48,7 +48,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - has_more_rows=True, + is_direct_results=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -104,6 +104,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None + mock_backend.fetch_results.return_value = (Mock(), False) # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -184,6 +185,7 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, @@ -210,6 +212,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) + mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -254,7 +257,10 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = ThriftResultSet(Mock(), Mock(), Mock()) + mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) + + result_set = ThriftResultSet(Mock(), Mock(), mock_backend) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -472,7 +478,6 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq - mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 7249a59e6..a649941e1 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -40,25 +40,30 @@ def 
make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) + + # Create a mock backend that will return the queue when _fill_results_buffer is called + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) + mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + + num_cols = len(initial_results[0]) if initial_results else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=True, - has_more_rows=False, - description=Mock(), - lz4_compressed=Mock(), - results_queue=arrow_queue, + description=description, + lz4_compressed=True, is_staging_operation=False, ), - thrift_client=None, + thrift_client=mock_thrift_backend, + t_row_set=None, ) - num_cols = len(initial_results[0]) if initial_results else 0 - rs.description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] return rs @staticmethod @@ -85,19 +90,19 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=False, - has_more_rows=True, - description=[ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ], - lz4_compressed=Mock(), - results_queue=None, + description=description, + lz4_compressed=True, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 7e025cf82..e4a9e5cdd 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - has_more_rows=False, + is_direct_results=False, description=Mock(), command_id=None, arrow_queue=arrow_queue, diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index b691872af..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. 
-""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - return mock_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.command_id == execute_response.command_id - assert result_set.status == CommandState.SUCCEEDED - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set.description == execute_response.description - - def test_close(self, mock_connection, mock_sea_client, execute_response): - """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - 
mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index b8de970db..57b5e9b58 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,18 +619,14 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 - - # Create a valid operation status - op_status = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=op_status, + operationStatus=ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ), resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -839,9 +835,10 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") 
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value @@ -885,10 +882,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) - + ( + execute_response, + _, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -923,9 +920,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -953,21 +948,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) thrift_backend = self._make_fake_thrift_backend() - execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -987,15 +982,15 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) thrift_backend = self._make_fake_thrift_backend() - thrift_backend._handle_execute_response(t_execute_resp, Mock()) + _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -1009,10 +1004,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with 
self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1024,7 +1019,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, ), closeOperation=Mock(), @@ -1040,11 +1035,12 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) + ( + execute_response, + has_more_rows_result, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) - self.assertEqual(has_more_rows, execute_response.has_more_rows) + self.assertEqual(is_direct_results, has_more_rows_result) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1053,10 +1049,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1069,7 +1065,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1102,7 +1098,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( description=Mock(), ) - self.assertEqual(has_more_rows, has_more_rows_resp) + self.assertEqual(is_direct_results, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): @@ -1157,9 +1153,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1172,14 +1169,15 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to 
client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1191,9 +1189,10 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1206,12 +1205,13 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1222,9 +1222,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1237,7 +1238,8 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1249,7 +1251,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1262,9 +1264,10 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1277,7 +1280,8 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1291,7 +1295,7 @@ def test_get_tables_calls_client_and_handle_execute_response( table_types=["type1", "type2"], ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client 
req = tcli_service_instance.GetTables.call_args[0][0] @@ -1306,9 +1310,10 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1321,7 +1326,8 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1335,7 +1341,7 @@ def test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -1667,9 +1673,7 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() @@ -2226,15 +2230,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=(Mock(), Mock()), + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class + self, mock_handle_execute_response, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value + # Set up the mock to return a tuple with two values + mock_execute_response = Mock() + mock_arrow_schema = Mock() + mock_handle_execute_response.return_value = ( + mock_execute_response, + mock_arrow_schema, + ) + # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] From b1acc5bffd676c7382be86ad12db011a8ebb38b4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 10:46:57 +0000 Subject: [PATCH 067/105] remove _get_schema_bytes (for now) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 77 +---------------------- 1 file changed, 1 insertion(+), 76 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 6d627162d..1d31f2afe 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -301,74 +301,6 @@ def get_allowed_session_configurations() -> List[str]: """ 
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - def _get_schema_bytes(self, sea_response) -> Optional[bytes]: - """ - Extract schema bytes from the SEA response. - - For ARROW format, we need to get the schema bytes from the first chunk. - If the first chunk is not available, we need to get it from the server. - - Args: - sea_response: The response from the SEA API - - Returns: - bytes: The schema bytes or None if not available - """ - import requests - import lz4.frame - - # Check if we have the first chunk in the response - result_data = sea_response.get("result", {}) - external_links = result_data.get("external_links", []) - - if not external_links: - return None - - # Find the first chunk (chunk_index = 0) - first_chunk = None - for link in external_links: - if link.get("chunk_index") == 0: - first_chunk = link - break - - if not first_chunk: - # Try to fetch the first chunk from the server - statement_id = sea_response.get("statement_id") - if not statement_id: - return None - - chunks_response = self.get_chunk_links(statement_id, 0) - if not chunks_response.external_links: - return None - - first_chunk = chunks_response.external_links[0].__dict__ - - # Download the first chunk to get the schema bytes - external_link = first_chunk.get("external_link") - http_headers = first_chunk.get("http_headers", {}) - - if not external_link: - return None - - # Use requests to download the first chunk - http_response = requests.get( - external_link, - headers=http_headers, - verify=self.ssl_options.tls_verify, - ) - - if http_response.status_code != 200: - raise Error(f"Failed to download schema bytes: {http_response.text}") - - # Extract schema bytes from the Arrow file - # The schema is at the beginning of the file - data = http_response.content - if sea_response.get("manifest", {}).get("result_compression") == "LZ4_FRAME": - data = lz4.frame.decompress(data) - - # Return the schema bytes - return data - def _results_message_to_execute_response(self, sea_response, command_id): """ Convert a SEA response to an ExecuteResponse and extract result data. 
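
The hunk above drops the eager `_get_schema_bytes` round trip, and the hunks below defer schema extraction to the fetch phase, leaving `arrow_schema_bytes=None` until then. For context, a minimal sketch of what that deferred extraction can look like once the fetch path has downloaded the first chunk anyway; `schema_bytes_from_chunk` and its parameters are hypothetical names, not part of this patch, and the sketch assumes the chunk is a complete Arrow IPC stream, LZ4-framed when the manifest reports `LZ4_FRAME` (pyarrow and lz4 are already used elsewhere in this connector):

import lz4.frame
import pyarrow.ipc

def schema_bytes_from_chunk(data: bytes, lz4_compressed: bool) -> bytes:
    # Undo LZ4 framing first when the manifest reported LZ4_FRAME compression.
    if lz4_compressed:
        data = lz4.frame.decompress(data)
    # An Arrow IPC stream begins with the schema message, so the reader
    # exposes the schema without consuming any record batches.
    reader = pyarrow.ipc.open_stream(data)
    return reader.schema.serialize().to_pybytes()

Reading the schema off a chunk the fetch path downloads regardless avoids the extra request the removed helper made, which is presumably what makes dropping it safe here.
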
@@ -411,13 +343,6 @@ def _results_message_to_execute_response(self, sea_response, command_id): ) description = columns if columns else None - # Extract schema bytes for Arrow format - schema_bytes = None - format = manifest_data.get("format") - if format == "ARROW_STREAM": - # For ARROW format, we need to get the schema bytes - schema_bytes = self._get_schema_bytes(sea_response) - # Check for compression lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" @@ -472,7 +397,7 @@ def _results_message_to_execute_response(self, sea_response, command_id): has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=False, - arrow_schema_bytes=schema_bytes, + arrow_schema_bytes=None, # to be extracted during fetch phase for ARROW result_format=manifest_data.get("format"), ) From ef2a7eefcf158c6d033664fb5d844c40d07eb65e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 10:48:51 +0000 Subject: [PATCH 068/105] redundant comments Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 1d31f2afe..15941d296 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -487,12 +487,11 @@ def execute_command( # Store the command ID in the cursor cursor.active_command_id = command_id - # If async operation, return None and let the client poll for results + # If async operation, return and let the client poll for results if async_op: return None # For synchronous operation, wait for the statement to complete - # Poll until the statement is done status = response.status state = status.state From af8f74e9f3c8bce7d484d312e6f6123d5e770edd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:39:14 +0000 Subject: [PATCH 069/105] remove fetch phase methods Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 15941d296..42903d09d 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -87,8 +87,6 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH = BASE_PATH + "statements" STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" - CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks" - CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" def __init__( self, @@ -278,19 +276,6 @@ def get_default_session_configuration_value(name: str) -> Optional[str]: """ return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) - @staticmethod - def is_session_configuration_parameter_supported(name: str) -> bool: - """ - Check if a session configuration parameter is supported. 
-
-        Args:
-            name: The name of the session configuration parameter
-
-        Returns:
-            True if the parameter is supported, False otherwise
-        """
-        return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP
-
     @staticmethod
     def get_allowed_session_configurations() -> List[str]:
         """

From 5540c5c4a8198f5820e275a379110c13d86e0517 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Thu, 12 Jun 2025 11:45:56 +0000
Subject: [PATCH 070/105] reduce code repetition + introduce gaps after multi-line pydocs

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/sea/backend.py    | 78 +++++--------------
 .../sql/backend/sea/models/responses.py      | 18 ++---
 tests/unit/test_sea_backend.py               |  2 +-
 3 files changed, 30 insertions(+), 68 deletions(-)

diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 42903d09d..0e34d2470 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -40,6 +40,11 @@
     GetStatementResponse,
     CreateSessionResponse,
 )
+from databricks.sql.backend.sea.models.responses import (
+    parse_status,
+    parse_manifest,
+    parse_result,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -75,9 +80,6 @@ def _filter_session_configuration(
 class SeaDatabricksClient(DatabricksClient):
     """
     Statement Execution API (SEA) implementation of the DatabricksClient interface.
-
-    This implementation provides session management functionality for SEA,
-    while other operations raise NotImplementedError.
     """
 
     # SEA API paths
@@ -119,7 +121,6 @@ def __init__(
         )
 
         self._max_download_threads = kwargs.get("max_download_threads", 10)
-        self.ssl_options = ssl_options
 
         # Extract warehouse ID from http_path
         self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -298,16 +299,16 @@ def _results_message_to_execute_response(self, sea_response, command_id):
             tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
                    result data object, and manifest object
         """
-        # Extract status
-        status_data = sea_response.get("status", {})
-        state = CommandState.from_sea_state(status_data.get("state", ""))
 
-        # Extract description from manifest
+        # Parse the response
+        status = parse_status(sea_response)
+        manifest_obj = parse_manifest(sea_response)
+        result_data_obj = parse_result(sea_response)
+
+        # Extract description from manifest schema
         description = None
-        manifest_data = sea_response.get("manifest", {})
-        schema_data = manifest_data.get("schema", {})
+        schema_data = manifest_obj.schema
         columns_data = schema_data.get("columns", [])
-
         if columns_data:
             columns = []
             for col_data in columns_data:
@@ -329,61 +330,17 @@ def _results_message_to_execute_response(self, sea_response, command_id):
             description = columns if columns else None
 
         # Check for compression
-        lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME"
-
-        # Initialize result_data_obj and manifest_obj
-        result_data_obj = None
-        manifest_obj = None
-
-        result_data = sea_response.get("result", {})
-        if result_data:
-            # Convert external links
-            external_links = None
-            if "external_links" in result_data:
-                external_links = []
-                for link_data in result_data["external_links"]:
-                    external_links.append(
-                        ExternalLink(
-                            external_link=link_data.get("external_link", ""),
-                            expiration=link_data.get("expiration", ""),
-                            chunk_index=link_data.get("chunk_index", 0),
-                            byte_count=link_data.get("byte_count", 0),
-                            row_count=link_data.get("row_count", 0),
-                            row_offset=link_data.get("row_offset", 0),
-                            next_chunk_index=link_data.get("next_chunk_index"),
- next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers", {}), - ) - ) - - # Create the result data object - result_data_obj = ResultData( - data=result_data.get("data_array"), external_links=external_links - ) - - # Create the manifest object - manifest_obj = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) + lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME" execute_response = ExecuteResponse( command_id=command_id, - status=state, + status=status.state, description=description, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=False, arrow_schema_bytes=None, # to be extracted during fetch phase for ARROW - result_format=manifest_data.get("format"), + result_format=manifest_obj.format, ) return execute_response, result_data_obj, manifest_obj @@ -419,6 +376,7 @@ def execute_command( Returns: ResultSet: A SeaResultSet instance for the executed command """ + if session_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA session ID") @@ -506,6 +464,7 @@ def cancel_command(self, command_id: CommandId) -> None: Raises: ValueError: If the command ID is invalid """ + if command_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA command ID") @@ -528,6 +487,7 @@ def close_command(self, command_id: CommandId) -> None: Raises: ValueError: If the command ID is invalid """ + if command_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA command ID") @@ -553,6 +513,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: Raises: ValueError: If the command ID is invalid """ + if command_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA command ID") @@ -587,6 +548,7 @@ def get_execution_result( Raises: ValueError: If the command ID is invalid """ + if command_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA command ID") diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 0baf27ab2..dae37b1ae 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -18,7 +18,7 @@ ) -def _parse_status(data: Dict[str, Any]) -> StatementStatus: +def parse_status(data: Dict[str, Any]) -> StatementStatus: """Parse status from response data.""" status_data = data.get("status", {}) error = None @@ -40,7 +40,7 @@ def _parse_status(data: Dict[str, Any]) -> StatementStatus: ) -def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: +def parse_manifest(data: Dict[str, Any]) -> ResultManifest: """Parse manifest from response data.""" manifest_data = data.get("manifest", {}) @@ -69,7 +69,7 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: ) -def _parse_result(data: Dict[str, Any]) -> ResultData: +def parse_result(data: Dict[str, Any]) -> ResultData: """Parse result data from response data.""" result_data = data.get("result", {}) external_links = None @@ -118,9 +118,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": """Create an ExecuteStatementResponse 
from a dictionary.""" return cls( statement_id=data.get("statement_id", ""), - status=_parse_status(data), - manifest=_parse_manifest(data), - result=_parse_result(data), + status=parse_status(data), + manifest=parse_manifest(data), + result=parse_result(data), ) @@ -138,9 +138,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": """Create a GetStatementResponse from a dictionary.""" return cls( statement_id=data.get("statement_id", ""), - status=_parse_status(data), - manifest=_parse_manifest(data), - result=_parse_result(data), + status=parse_status(data), + manifest=parse_manifest(data), + result=parse_result(data), ) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 2fa362b8e..01424a4d2 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -536,7 +536,7 @@ def test_get_execution_result( print(result) # Verify basic properties of the result - assert result.statement_id == "test-statement-123" + assert result.command_id.to_sea_statement_id() == "test-statement-123" assert result.status == CommandState.SUCCEEDED # Verify the HTTP request From efe3881c1b4f7ff31305bcf64a7e39acfd72e590 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:46:53 +0000 Subject: [PATCH 071/105] remove unused imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 0e34d2470..03080bf5a 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -19,14 +19,9 @@ BackendType, ExecuteResponse, ) -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.types import SSLOptions -from databricks.sql.backend.sea.models.base import ( - ResultData, - ExternalLink, - ResultManifest, -) from databricks.sql.backend.sea.models import ( ExecuteStatementRequest, From 36ab59bbdb3e942ede39a2f32844bf3697d15a33 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:51:04 +0000 Subject: [PATCH 072/105] move description extraction to helper func Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 60 ++++++++++++++--------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 03080bf5a..014912c8f 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -282,6 +282,43 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) + def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]: + """ + Extract column description from a manifest object. 
+ + Args: + manifest_obj: The ResultManifest object containing schema information + + Returns: + Optional[List]: A list of column tuples or None if no columns are found + """ + + schema_data = manifest_obj.schema + columns_data = schema_data.get("columns", []) + + if not columns_data: + return None + + columns = [] + for col_data in columns_data: + if not isinstance(col_data, dict): + continue + + # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) + columns.append( + ( + col_data.get("name", ""), # name + col_data.get("type_name", ""), # type_code + None, # display_size (not provided by SEA) + None, # internal_size (not provided by SEA) + col_data.get("precision"), # precision + col_data.get("scale"), # scale + col_data.get("nullable", True), # null_ok + ) + ) + + return columns if columns else None + def _results_message_to_execute_response(self, sea_response, command_id): """ Convert a SEA response to an ExecuteResponse and extract result data. @@ -301,28 +338,7 @@ def _results_message_to_execute_response(self, sea_response, command_id): result_data_obj = parse_result(sea_response) # Extract description from manifest schema - description = None - schema_data = manifest_obj.schema - columns_data = schema_data.get("columns", []) - if columns_data: - columns = [] - for col_data in columns_data: - if not isinstance(col_data, dict): - continue - - # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) - columns.append( - ( - col_data.get("name", ""), # name - col_data.get("type_name", ""), # type_code - None, # display_size (not provided by SEA) - None, # internal_size (not provided by SEA) - col_data.get("precision"), # precision - col_data.get("scale"), # scale - col_data.get("nullable", True), # null_ok - ) - ) - description = columns if columns else None + description = self._extract_description_from_manifest(manifest_obj) # Check for compression lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME" From 1d57c996afff5727c1e66a36e9da82f75777d6f1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:52:06 +0000 Subject: [PATCH 073/105] formatting (black) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 014912c8f..1dde8e4dc 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -295,10 +295,10 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]: schema_data = manifest_obj.schema columns_data = schema_data.get("columns", []) - + if not columns_data: return None - + columns = [] for col_data in columns_data: if not isinstance(col_data, dict): @@ -316,7 +316,7 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]: col_data.get("nullable", True), # null_ok ) ) - + return columns if columns else None def _results_message_to_execute_response(self, sea_response, command_id): From df6dac2bd84b7e3e2b71f51469571396166a5b34 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:20:49 +0000 Subject: [PATCH 074/105] add more unit tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 299 ++++++++++++++++++++++++++++++++- 1 file changed, 296 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 01424a4d2..e6d293e5f 
100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -9,12 +9,15 @@ import pytest from unittest.mock import patch, MagicMock, Mock -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.backend import ( + SeaDatabricksClient, + _filter_session_configuration, +) from databricks.sql.result_set import SeaResultSet from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError class TestSeaBackend: @@ -305,6 +308,32 @@ def test_execute_command_async( assert isinstance(mock_cursor.active_command_id, CommandId) assert mock_cursor.active_command_id.guid == "test-statement-456" + def test_execute_command_async_missing_statement_id( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing an async command that returns no statement ID.""" + # Set up mock response with status but no statement_id + mock_http_client._make_request.return_value = {"status": {"state": "PENDING"}} + + # Call the method and expect an error + with pytest.raises(ServerOperationError) as excinfo: + sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, + ) + + assert "Failed to execute command: No statement ID returned" in str( + excinfo.value + ) + def test_execute_command_with_polling( self, sea_client, mock_http_client, mock_cursor, sea_session_id ): @@ -442,6 +471,32 @@ def test_execute_command_failure( assert "Statement execution did not succeed" in str(excinfo.value) + def test_execute_command_missing_statement_id( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that returns no statement ID.""" + # Set up mock response with status but no statement_id + mock_http_client._make_request.return_value = {"status": {"state": "SUCCEEDED"}} + + # Call the method and expect an error + with pytest.raises(ServerOperationError) as excinfo: + sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Failed to execute command: No statement ID returned" in str( + excinfo.value + ) + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): """Test canceling a command.""" # Set up mock response @@ -533,7 +588,6 @@ def test_get_execution_result( # Create a real result set to verify the implementation result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) # Verify basic properties of the result assert result.command_id.to_sea_statement_id() == "test-statement-123" @@ -546,3 +600,242 @@ def test_get_execution_result( assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( "test-statement-123" ) + + def test_get_execution_result_with_invalid_command_id( + self, sea_client, mock_cursor + ): + """Test getting execution result with an invalid command ID.""" + # Create a Thrift command ID (not SEA) + 
mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.get_execution_result(command_id, mock_cursor) + + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_max_download_threads_property(self, mock_http_client): + """Test the max_download_threads property.""" + # Test with default value + client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client.max_download_threads == 10 + + # Test with custom value + client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=5, + ) + assert client.max_download_threads == 5 + + def test_get_default_session_configuration_value(self): + """Test the get_default_session_configuration_value static method.""" + # Test with supported configuration parameter + value = SeaDatabricksClient.get_default_session_configuration_value("ANSI_MODE") + assert value == "true" + + # Test with unsupported configuration parameter + value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert value is None + + # Test with case-insensitive parameter name + value = SeaDatabricksClient.get_default_session_configuration_value("ansi_mode") + assert value == "true" + + def test_get_allowed_session_configurations(self): + """Test the get_allowed_session_configurations static method.""" + configs = SeaDatabricksClient.get_allowed_session_configurations() + assert isinstance(configs, list) + assert len(configs) > 0 + assert "ANSI_MODE" in configs + + def test_extract_description_from_manifest(self, sea_client): + """Test the _extract_description_from_manifest method.""" + # Test with valid manifest containing columns + manifest_obj = MagicMock() + manifest_obj.schema = { + "columns": [ + { + "name": "col1", + "type_name": "STRING", + "precision": 10, + "scale": 2, + "nullable": True, + }, + { + "name": "col2", + "type_name": "INT", + "nullable": False, + }, + ] + } + + description = sea_client._extract_description_from_manifest(manifest_obj) + assert description is not None + assert len(description) == 2 + + # Check first column + assert description[0][0] == "col1" # name + assert description[0][1] == "STRING" # type_code + assert description[0][4] == 10 # precision + assert description[0][5] == 2 # scale + assert description[0][6] is True # null_ok + + # Check second column + assert description[1][0] == "col2" # name + assert description[1][1] == "INT" # type_code + assert description[1][6] is False # null_ok + + # Test with manifest containing non-dict column + manifest_obj.schema = {"columns": ["not_a_dict"]} + description = sea_client._extract_description_from_manifest(manifest_obj) + assert ( + description is None + ) # Method returns None when no valid columns are found + + # Test with manifest without columns + manifest_obj.schema = {} + description = sea_client._extract_description_from_manifest(manifest_obj) + assert description is None + + def test_cancel_command_with_invalid_command_id(self, sea_client): + 
"""Test canceling a command with an invalid command ID.""" + # Create a Thrift command ID (not SEA) + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.cancel_command(command_id) + + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_close_command_with_invalid_command_id(self, sea_client): + """Test closing a command with an invalid command ID.""" + # Create a Thrift command ID (not SEA) + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.close_command(command_id) + + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_get_query_state_with_invalid_command_id(self, sea_client): + """Test getting query state with an invalid command ID.""" + # Create a Thrift command ID (not SEA) + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.get_query_state(command_id) + + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_unimplemented_metadata_methods( + self, sea_client, sea_session_id, mock_cursor + ): + """Test that metadata methods raise NotImplementedError.""" + # Test get_catalogs + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) + assert "get_catalogs is not implemented for SEA backend" in str(excinfo.value) + + # Test get_schemas + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) + assert "get_schemas is not implemented for SEA backend" in str(excinfo.value) + + # Test get_schemas with optional parameters + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas( + sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" + ) + assert "get_schemas is not implemented for SEA backend" in str(excinfo.value) + + # Test get_tables + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) + assert "get_tables is not implemented for SEA backend" in str(excinfo.value) + + # Test get_tables with optional parameters + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables( + sea_session_id, + 100, + 1000, + mock_cursor, + catalog_name="catalog", + schema_name="schema", + table_name="table", + table_types=["TABLE", "VIEW"], + ) + assert "get_tables is not implemented for SEA backend" in str(excinfo.value) + + # Test get_columns + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) + assert "get_columns is not implemented for SEA backend" in str(excinfo.value) + + # Test get_columns with optional parameters + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns( + 
sea_session_id, + 100, + 1000, + mock_cursor, + catalog_name="catalog", + schema_name="schema", + table_name="table", + column_name="column", + ) + assert "get_columns is not implemented for SEA backend" in str(excinfo.value) + + def test_execute_command_with_invalid_session_id(self, sea_client, mock_cursor): + """Test executing a command with an invalid session ID type.""" + # Create a Thrift session ID (not SEA) + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + session_id = SessionId.from_thrift_handle(mock_thrift_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.execute_command( + operation="SELECT 1", + session_id=session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Not a valid SEA session ID" in str(excinfo.value) From ad0e527c6a67ba5d8d89d63655c33f27d2acbe7a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:34:25 +0000 Subject: [PATCH 075/105] streamline unit tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 534 ++++++++++----------------------- 1 file changed, 166 insertions(+), 368 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index e6d293e5f..4b1ec55a3 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -5,7 +5,6 @@ the Databricks SQL connector's SEA backend functionality. """ -import json import pytest from unittest.mock import patch, MagicMock, Mock @@ -13,7 +12,6 @@ SeaDatabricksClient, _filter_session_configuration, ) -from databricks.sql.result_set import SeaResultSet from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider @@ -68,10 +66,28 @@ def mock_cursor(self): """Create a mock cursor.""" cursor = Mock() cursor.active_command_id = None + cursor.buffer_size_bytes = 1000 + cursor.arraysize = 100 return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): - """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" + @pytest.fixture + def thrift_session_id(self): + """Create a Thrift session ID (not SEA).""" + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + return SessionId.from_thrift_handle(mock_thrift_handle) + + @pytest.fixture + def thrift_command_id(self): + """Create a Thrift command ID (not SEA).""" + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + return CommandId.from_thrift_handle(mock_thrift_operation_handle) + + def test_initialization(self, mock_http_client): + """Test client initialization and warehouse ID extraction.""" # Test with warehouses format client1 = SeaDatabricksClient( server_hostname="test-server.databricks.com", @@ -82,6 +98,7 @@ def test_init_extracts_warehouse_id(self, mock_http_client): ssl_options=SSLOptions(), ) assert client1.warehouse_id == "abc123" + assert client1.max_download_threads == 10 # Default value # Test with endpoints format client2 = SeaDatabricksClient( @@ -94,8 +111,19 @@ def test_init_extracts_warehouse_id(self, 
mock_http_client): ) assert client2.warehouse_id == "def456" - def test_init_raises_error_for_invalid_http_path(self, mock_http_client): - """Test that the constructor raises an error for invalid HTTP paths.""" + # Test with custom max_download_threads + client3 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=5, + ) + assert client3.max_download_threads == 5 + + # Test with invalid HTTP path with pytest.raises(ValueError) as excinfo: SeaDatabricksClient( server_hostname="test-server.databricks.com", @@ -107,30 +135,21 @@ def test_init_raises_error_for_invalid_http_path(self, mock_http_client): ) assert "Could not extract warehouse ID" in str(excinfo.value) - def test_open_session_basic(self, sea_client, mock_http_client): - """Test the open_session method with minimal parameters.""" - # Set up mock response + def test_session_management(self, sea_client, mock_http_client, thrift_session_id): + """Test session management methods.""" + # Test open_session with minimal parameters mock_http_client._make_request.return_value = {"session_id": "test-session-123"} - - # Call the method session_id = sea_client.open_session(None, None, None) - - # Verify the result assert isinstance(session_id, SessionId) assert session_id.backend_type == BackendType.SEA assert session_id.guid == "test-session-123" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once_with( + mock_http_client._make_request.assert_called_with( method="POST", path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"} ) - def test_open_session_with_all_parameters(self, sea_client, mock_http_client): - """Test the open_session method with all parameters.""" - # Set up mock response + # Test open_session with all parameters + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {"session_id": "test-session-456"} - - # Call the method with all parameters, including both supported and unsupported configurations session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter @@ -138,16 +157,8 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client): } catalog = "test_catalog" schema = "test_schema" - session_id = sea_client.open_session(session_config, catalog, schema) - - # Verify the result - assert isinstance(session_id, SessionId) - assert session_id.backend_type == BackendType.SEA assert session_id.guid == "test-session-456" - - # Verify the HTTP request - only supported parameters should be included - # and keys should be in lowercase expected_data = { "warehouse_id": "abc123", "session_confs": { @@ -157,60 +168,37 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client): "catalog": catalog, "schema": schema, } - mock_http_client._make_request.assert_called_once_with( + mock_http_client._make_request.assert_called_with( method="POST", path=sea_client.SESSION_PATH, data=expected_data ) - def test_open_session_error_handling(self, sea_client, mock_http_client): - """Test error handling in the open_session method.""" - # Set up mock response without session_id + # Test open_session error handling + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {} - - # Call the method and expect an error with pytest.raises(Error) as excinfo: sea_client.open_session(None, None, None) - assert "Failed to 
create session" in str(excinfo.value) - def test_close_session_valid_id(self, sea_client, mock_http_client): - """Test closing a session with a valid session ID.""" - # Create a valid SEA session ID + # Test close_session with valid ID + mock_http_client.reset_mock() session_id = SessionId.from_sea_session_id("test-session-789") - - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method sea_client.close_session(session_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once_with( + mock_http_client._make_request.assert_called_with( method="DELETE", path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), data={"session_id": "test-session-789", "warehouse_id": "abc123"}, ) - def test_close_session_invalid_id_type(self, sea_client): - """Test closing a session with an invalid session ID type.""" - # Create a Thrift session ID (not SEA) - mock_thrift_handle = MagicMock() - mock_thrift_handle.sessionId.guid = b"guid" - mock_thrift_handle.sessionId.secret = b"secret" - session_id = SessionId.from_thrift_handle(mock_thrift_handle) - - # Call the method and expect an error + # Test close_session with invalid ID type with pytest.raises(ValueError) as excinfo: - sea_client.close_session(session_id) - + sea_client.close_session(thrift_session_id) assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( + def test_command_execution_sync( self, sea_client, mock_http_client, mock_cursor, sea_session_id ): - """Test executing a command synchronously.""" - # Set up mock responses + """Test synchronous command execution.""" + # Test synchronous execution execute_response = { "statement_id": "test-statement-123", "status": {"state": "SUCCEEDED"}, @@ -230,11 +218,9 @@ def test_execute_command_sync( } mock_http_client._make_request.return_value = execute_response - # Mock the get_execution_result method with patch.object( sea_client, "get_execution_result", return_value="mock_result_set" ) as mock_get_result: - # Call the method result = sea_client.execute_command( operation="SELECT 1", session_id=sea_session_id, @@ -247,38 +233,43 @@ def test_execute_command_sync( async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the result assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() cmd_id_arg = mock_get_result.call_args[0][0] assert isinstance(cmd_id_arg, CommandId) assert cmd_id_arg.guid == "test-statement-123" - def test_execute_command_async( + # Test with invalid session ID + with pytest.raises(ValueError) as excinfo: + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + thrift_session_id = SessionId.from_thrift_handle(mock_thrift_handle) + + sea_client.execute_command( + operation="SELECT 1", + session_id=thrift_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, 
+ enforce_embedded_schema_correctness=False, + ) + assert "Not a valid SEA session ID" in str(excinfo.value) + + def test_command_execution_async( self, sea_client, mock_http_client, mock_cursor, sea_session_id ): - """Test executing a command asynchronously.""" - # Set up mock response + """Test asynchronous command execution.""" + # Test asynchronous execution execute_response = { "statement_id": "test-statement-456", "status": {"state": "PENDING"}, } mock_http_client._make_request.return_value = execute_response - # Call the method result = sea_client.execute_command( operation="SELECT 1", session_id=sea_session_id, @@ -288,34 +279,16 @@ def test_execute_command_async( cursor=mock_cursor, use_cloud_fetch=False, parameters=[], - async_op=True, # Async mode + async_op=True, enforce_embedded_schema_correctness=False, ) - - # Verify the result is None for async operation assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") assert isinstance(mock_cursor.active_command_id, CommandId) assert mock_cursor.active_command_id.guid == "test-statement-456" - def test_execute_command_async_missing_statement_id( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing an async command that returns no statement ID.""" - # Set up mock response with status but no statement_id + # Test async with missing statement ID + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {"status": {"state": "PENDING"}} - - # Call the method and expect an error with pytest.raises(ServerOperationError) as excinfo: sea_client.execute_command( operation="SELECT 1", @@ -326,19 +299,18 @@ def test_execute_command_async_missing_statement_id( cursor=mock_cursor, use_cloud_fetch=False, parameters=[], - async_op=True, # Async mode + async_op=True, enforce_embedded_schema_correctness=False, ) - assert "Failed to execute command: No statement ID returned" in str( excinfo.value ) - def test_execute_command_with_polling( + def test_command_execution_advanced( self, sea_client, mock_http_client, mock_cursor, sea_session_id ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling + """Test advanced command execution scenarios.""" + # Test with polling initial_response = { "statement_id": "test-statement-789", "status": {"state": "RUNNING"}, @@ -349,17 +321,12 @@ def test_execute_command_with_polling( "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, "result": {"data": []}, } - - # Configure mock to return different responses on subsequent calls mock_http_client._make_request.side_effect = [initial_response, poll_response] - # Mock the get_execution_result method with patch.object( sea_client, "get_execution_result", return_value="mock_result_set" ) as mock_get_result: - # Mock time.sleep to avoid actual delays with patch("time.sleep"): - # Call the method result = sea_client.execute_command( operation="SELECT * FROM large_table", session_id=sea_session_id, @@ -372,39 +339,22 @@ def test_execute_command_with_polling( async_op=False, enforce_embedded_schema_correctness=False, ) - - # 
Verify the result assert result == "mock_result_set" - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response + # Test with parameters + mock_http_client.reset_mock() + mock_http_client._make_request.side_effect = None # Reset side_effect execute_response = { "statement_id": "test-statement-123", "status": {"state": "SUCCEEDED"}, } mock_http_client._make_request.return_value = execute_response - - # Create parameter mock param = MagicMock() param.name = "param1" param.value = "value1" param.type = "STRING" - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + with patch.object(sea_client, "get_execution_result"): sea_client.execute_command( operation="SELECT * FROM table WHERE col = :param1", session_id=sea_session_id, @@ -417,9 +367,6 @@ def test_execute_command_with_parameters( async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() args, kwargs = mock_http_client._make_request.call_args assert "parameters" in kwargs["data"] assert len(kwargs["data"]["parameters"]) == 1 @@ -427,11 +374,8 @@ def test_execute_command_with_parameters( assert kwargs["data"]["parameters"][0]["value"] == "value1" assert kwargs["data"]["parameters"][0]["type"] == "STRING" - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution + # Test execution failure + mock_http_client.reset_mock() error_response = { "statement_id": "test-statement-123", "status": { @@ -442,43 +386,30 @@ def test_execute_command_failure( }, }, } + mock_http_client._make_request.return_value = error_response - # Configure the mock to return the error response for the initial request - # and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_execute_command_missing_statement_id( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that returns no statement ID.""" - # Set up mock response with status but no statement_id + with patch.object( + sea_client, "get_query_state", return_value=CommandState.FAILED + ): + with 
pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert "Statement execution did not succeed" in str(excinfo.value) + + # Test missing statement ID + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {"status": {"state": "SUCCEEDED"}} - - # Call the method and expect an error with pytest.raises(ServerOperationError) as excinfo: sea_client.execute_command( operation="SELECT 1", @@ -492,70 +423,68 @@ def test_execute_command_missing_statement_id( async_op=False, enforce_embedded_schema_correctness=False, ) - assert "Failed to execute command: No statement ID returned" in str( excinfo.value ) - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response + def test_command_management( + self, + sea_client, + mock_http_client, + sea_command_id, + thrift_command_id, + mock_cursor, + ): + """Test command management methods.""" + # Test cancel_command mock_http_client._make_request.return_value = {} - - # Call the method sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + mock_http_client._make_request.assert_called_with( + method="POST", + path=sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format("test-statement-123"), + data={"statement_id": "test-statement-123"}, ) - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} + # Test cancel_command with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.cancel_command(thrift_command_id) + assert "Not a valid SEA command ID" in str(excinfo.value) - # Call the method + # Test close_command + mock_http_client.reset_mock() sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + mock_http_client._make_request.assert_called_with( + method="DELETE", + path=sea_client.STATEMENT_PATH_WITH_ID.format("test-statement-123"), + data={"statement_id": "test-statement-123"}, ) - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response + # Test close_command with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.close_command(thrift_command_id) + assert "Not a valid SEA command ID" in str(excinfo.value) + + # Test get_query_state + mock_http_client.reset_mock() mock_http_client._make_request.return_value = { "statement_id": "test-statement-123", "status": {"state": "RUNNING"}, } - - # Call the method state = sea_client.get_query_state(sea_command_id) - - # Verify the result assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = 
mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + mock_http_client._make_request.assert_called_with( + method="GET", + path=sea_client.STATEMENT_PATH_WITH_ID.format("test-statement-123"), + data={"statement_id": "test-statement-123"}, ) - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response + # Test get_query_state with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.get_query_state(thrift_command_id) + assert "Not a valid SEA command ID" in str(excinfo.value) + + # Test get_execution_result + mock_http_client.reset_mock() sea_response = { "statement_id": "test-statement-123", "status": {"state": "SUCCEEDED"}, @@ -585,66 +514,18 @@ def test_get_execution_result( }, } mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation result = sea_client.get_execution_result(sea_command_id, mock_cursor) - - # Verify basic properties of the result assert result.command_id.to_sea_statement_id() == "test-statement-123" assert result.status == CommandState.SUCCEEDED - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result_with_invalid_command_id( - self, sea_client, mock_cursor - ): - """Test getting execution result with an invalid command ID.""" - # Create a Thrift command ID (not SEA) - mock_thrift_operation_handle = MagicMock() - mock_thrift_operation_handle.operationId.guid = b"guid" - mock_thrift_operation_handle.operationId.secret = b"secret" - command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) - - # Call the method and expect an error + # Test get_execution_result with invalid ID with pytest.raises(ValueError) as excinfo: - sea_client.get_execution_result(command_id, mock_cursor) - + sea_client.get_execution_result(thrift_command_id, mock_cursor) assert "Not a valid SEA command ID" in str(excinfo.value) - def test_max_download_threads_property(self, mock_http_client): - """Test the max_download_threads property.""" - # Test with default value - client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) - assert client.max_download_threads == 10 - - # Test with custom value - client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=5, - ) - assert client.max_download_threads == 5 - - def test_get_default_session_configuration_value(self): - """Test the get_default_session_configuration_value static method.""" - # Test with supported configuration parameter + def test_utility_methods(self, sea_client): + """Test utility methods.""" + # Test get_default_session_configuration_value value = SeaDatabricksClient.get_default_session_configuration_value("ANSI_MODE") assert value == "true" @@ -658,16 +539,13 @@ def test_get_default_session_configuration_value(self): value = 
SeaDatabricksClient.get_default_session_configuration_value("ansi_mode") assert value == "true" - def test_get_allowed_session_configurations(self): - """Test the get_allowed_session_configurations static method.""" + # Test get_allowed_session_configurations configs = SeaDatabricksClient.get_allowed_session_configurations() assert isinstance(configs, list) assert len(configs) > 0 assert "ANSI_MODE" in configs - def test_extract_description_from_manifest(self, sea_client): - """Test the _extract_description_from_manifest method.""" - # Test with valid manifest containing columns + # Test _extract_description_from_manifest manifest_obj = MagicMock() manifest_obj.schema = { "columns": [ @@ -689,15 +567,11 @@ def test_extract_description_from_manifest(self, sea_client): description = sea_client._extract_description_from_manifest(manifest_obj) assert description is not None assert len(description) == 2 - - # Check first column assert description[0][0] == "col1" # name assert description[0][1] == "STRING" # type_code assert description[0][4] == 10 # precision assert description[0][5] == 2 # scale assert description[0][6] is True # null_ok - - # Check second column assert description[1][0] == "col2" # name assert description[1][1] == "INT" # type_code assert description[1][6] is False # null_ok @@ -705,85 +579,37 @@ def test_extract_description_from_manifest(self, sea_client): # Test with manifest containing non-dict column manifest_obj.schema = {"columns": ["not_a_dict"]} description = sea_client._extract_description_from_manifest(manifest_obj) - assert ( - description is None - ) # Method returns None when no valid columns are found + assert description is None # Test with manifest without columns manifest_obj.schema = {} description = sea_client._extract_description_from_manifest(manifest_obj) assert description is None - def test_cancel_command_with_invalid_command_id(self, sea_client): - """Test canceling a command with an invalid command ID.""" - # Create a Thrift command ID (not SEA) - mock_thrift_operation_handle = MagicMock() - mock_thrift_operation_handle.operationId.guid = b"guid" - mock_thrift_operation_handle.operationId.secret = b"secret" - command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) - - # Call the method and expect an error - with pytest.raises(ValueError) as excinfo: - sea_client.cancel_command(command_id) - - assert "Not a valid SEA command ID" in str(excinfo.value) - - def test_close_command_with_invalid_command_id(self, sea_client): - """Test closing a command with an invalid command ID.""" - # Create a Thrift command ID (not SEA) - mock_thrift_operation_handle = MagicMock() - mock_thrift_operation_handle.operationId.guid = b"guid" - mock_thrift_operation_handle.operationId.secret = b"secret" - command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) - - # Call the method and expect an error - with pytest.raises(ValueError) as excinfo: - sea_client.close_command(command_id) - - assert "Not a valid SEA command ID" in str(excinfo.value) - - def test_get_query_state_with_invalid_command_id(self, sea_client): - """Test getting query state with an invalid command ID.""" - # Create a Thrift command ID (not SEA) - mock_thrift_operation_handle = MagicMock() - mock_thrift_operation_handle.operationId.guid = b"guid" - mock_thrift_operation_handle.operationId.secret = b"secret" - command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) - - # Call the method and expect an error - with pytest.raises(ValueError) as excinfo: - 
sea_client.get_query_state(command_id) - - assert "Not a valid SEA command ID" in str(excinfo.value) - def test_unimplemented_metadata_methods( self, sea_client, sea_session_id, mock_cursor ): """Test that metadata methods raise NotImplementedError.""" # Test get_catalogs - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) - assert "get_catalogs is not implemented for SEA backend" in str(excinfo.value) # Test get_schemas - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) - assert "get_schemas is not implemented for SEA backend" in str(excinfo.value) # Test get_schemas with optional parameters - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_schemas( sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" ) - assert "get_schemas is not implemented for SEA backend" in str(excinfo.value) # Test get_tables - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) - assert "get_tables is not implemented for SEA backend" in str(excinfo.value) # Test get_tables with optional parameters - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_tables( sea_session_id, 100, @@ -794,15 +620,13 @@ def test_unimplemented_metadata_methods( table_name="table", table_types=["TABLE", "VIEW"], ) - assert "get_tables is not implemented for SEA backend" in str(excinfo.value) # Test get_columns - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) - assert "get_columns is not implemented for SEA backend" in str(excinfo.value) # Test get_columns with optional parameters - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_columns( sea_session_id, 100, @@ -813,29 +637,3 @@ def test_unimplemented_metadata_methods( table_name="table", column_name="column", ) - assert "get_columns is not implemented for SEA backend" in str(excinfo.value) - - def test_execute_command_with_invalid_session_id(self, sea_client, mock_cursor): - """Test executing a command with an invalid session ID type.""" - # Create a Thrift session ID (not SEA) - mock_thrift_handle = MagicMock() - mock_thrift_handle.sessionId.guid = b"guid" - mock_thrift_handle.sessionId.secret = b"secret" - session_id = SessionId.from_thrift_handle(mock_thrift_handle) - - # Call the method and expect an error - with pytest.raises(ValueError) as excinfo: - sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Not a valid SEA session ID" in str(excinfo.value) From ed446a0fe240d27626fa70657005f7f8ce065766 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:37:24 +0000 Subject: [PATCH 076/105] test getting the list of allowed configurations Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/unit/test_sea_backend.py 
b/tests/unit/test_sea_backend.py index 4b1ec55a3..1d16763be 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -545,6 +545,20 @@ def test_utility_methods(self, sea_client): assert len(configs) > 0 assert "ANSI_MODE" in configs + # Test getting the list of allowed configurations with specific keys + allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", + } + assert set(allowed_configs) == expected_keys + # Test _extract_description_from_manifest manifest_obj = MagicMock() manifest_obj.schema = { From 38e4b5c25517146acb90ae962a02cbb6a5c3b98e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:38:50 +0000 Subject: [PATCH 077/105] reduce diff Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 1dde8e4dc..cf10c904a 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -604,8 +604,8 @@ def get_catalogs( max_bytes: int, cursor: "Cursor", ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - raise NotImplementedError("get_catalogs is not implemented for SEA backend") + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -616,8 +616,8 @@ def get_schemas( catalog_name: Optional[str] = None, schema_name: Optional[str] = None, ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - raise NotImplementedError("get_schemas is not implemented for SEA backend") + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -630,8 +630,8 @@ def get_tables( table_name: Optional[str] = None, table_types: Optional[List[str]] = None, ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_tables is not implemented for SEA backend") + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -644,5 +644,5 @@ def get_columns( table_name: Optional[str] = None, column_name: Optional[str] = None, ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_columns is not implemented for SEA backend") + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") From 94879c017ce2db6e289c46c47b51a7296c0db678 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:39:28 +0000 Subject: [PATCH 078/105] reduce diff Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index cf10c904a..e892e10e7 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -603,7 +603,7 @@ def get_catalogs( max_rows: 
int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": + ): """Not implemented yet.""" raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") @@ -615,7 +615,7 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": + ): """Not implemented yet.""" raise NotImplementedError("get_schemas is not yet implemented for SEA backend") @@ -629,7 +629,7 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": + ): """Not implemented yet.""" raise NotImplementedError("get_tables is not yet implemented for SEA backend") @@ -643,6 +643,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": + ): """Not implemented yet.""" raise NotImplementedError("get_columns is not yet implemented for SEA backend") From 18099560157074870d83f1a43146c1687962a92d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 13 Jun 2025 03:38:43 +0000 Subject: [PATCH 079/105] house constants in enums for readability and immutability Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 20 ++++++++++--- .../sql/backend/sea/utils/constants.py | 29 +++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index e892e10e7..4602db3b7 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -5,6 +5,10 @@ from databricks.sql.backend.sea.utils.constants import ( ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, + ResultFormat, + ResultDisposition, + ResultCompression, + WaitTimeout, ) if TYPE_CHECKING: @@ -405,9 +409,17 @@ def execute_command( ) ) - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else None + format = ( + ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY + ).value + disposition = ( + ResultDisposition.EXTERNAL_LINKS + if use_cloud_fetch + else ResultDisposition.INLINE + ).value + result_compression = ( + ResultCompression.LZ4_FRAME if lz4_compression else ResultCompression.NONE + ).value request = ExecuteStatementRequest( warehouse_id=self.warehouse_id, @@ -415,7 +427,7 @@ def execute_command( statement=operation, disposition=disposition, format=format, - wait_timeout="0s" if async_op else "10s", + wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value, on_wait_timeout="CONTINUE", row_limit=max_rows, parameters=sea_parameters if sea_parameters else None, diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 9160ef6ad..cd5cc657d 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -3,6 +3,7 @@ """ from typing import Dict +from enum import Enum # from https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: Dict[str, str] = { @@ -15,3 +16,31 @@ "TIMEZONE": "UTC", "USE_CACHED_RESULT": "true", } + + +class ResultFormat(Enum): + """Enum for result format values.""" + + ARROW_STREAM = "ARROW_STREAM" + JSON_ARRAY = "JSON_ARRAY" + + +class ResultDisposition(Enum): + """Enum for result disposition 
values.""" + + EXTERNAL_LINKS = "EXTERNAL_LINKS" + INLINE = "INLINE" + + +class ResultCompression(Enum): + """Enum for result compression values.""" + + LZ4_FRAME = "LZ4_FRAME" + NONE = None + + +class WaitTimeout(Enum): + """Enum for wait timeout values.""" + + ASYNC = "0s" + SYNC = "10s" From da5260cd82ffcdd31ed6393d0d0101c41fc7fcc7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 13 Jun 2025 03:39:16 +0000 Subject: [PATCH 080/105] add note on hybrid disposition Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/utils/constants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index cd5cc657d..7481a90db 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -28,6 +28,7 @@ class ResultFormat(Enum): class ResultDisposition(Enum): """Enum for result disposition values.""" + # TODO: add support for hybrid disposition EXTERNAL_LINKS = "EXTERNAL_LINKS" INLINE = "INLINE" From 0385ffb03a3684d5a00f74eed32610cacbc34331 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 16 Jun 2025 06:31:29 +0000 Subject: [PATCH 081/105] remove redundant note on arrow_schema_bytes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 4602db3b7..b829f0644 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -354,7 +354,7 @@ def _results_message_to_execute_response(self, sea_response, command_id): has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=False, - arrow_schema_bytes=None, # to be extracted during fetch phase for ARROW + arrow_schema_bytes=None, result_format=manifest_obj.format, ) From 62298486dd4d4d20ee5503a32ab73ca70a609294 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 03:37:16 +0000 Subject: [PATCH 082/105] remove irrelevant changes Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/requests.py | 4 +- .../sql/backend/sea/models/responses.py | 2 +- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 130 ++++++++++-------- src/databricks/sql/result_set.py | 84 ++++++----- src/databricks/sql/utils.py | 6 +- 6 files changed, 131 insertions(+), 97 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 8524275d4..4c5071dba 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CreateSessionRequest: - """Request to create a new session.""" + """Representation of a request to create a new session.""" warehouse_id: str session_confs: Optional[Dict[str, str]] = None @@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DeleteSessionRequest: - """Request to delete a session.""" + """Representation of a request to delete a session.""" warehouse_id: str session_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 43283a8b0..dae37b1ae 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py 
@@ -146,7 +146,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Response from creating a new session.""" + """Representation of the response from creating a new session.""" session_id: str diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index f0b931ee4..fe292919c 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, Union, List, Tuple +from typing import Callable, Dict, Any, Optional, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 48e9a115f..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,24 +3,21 @@ import logging import math import time -import uuid import threading from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor -from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( CommandState, SessionId, CommandId, - BackendType, - guid_to_hex_id, ExecuteResponse, ) from databricks.sql.backend.utils import guid_to_hex_id + try: import pyarrow except ImportError: @@ -760,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - has_more_rows = ( + + is_direct_results = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) + description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -780,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - if direct_results and direct_results.resultSet: - assert direct_results.resultSet.results.startRowOffset == 0 - assert direct_results.resultSetMetadata - - arrow_queue_opt = ResultSetQueueFactory.build_queue( - row_set_type=t_result_set_metadata_resp.resultFormat, - t_row_set=direct_results.resultSet.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) - else: - arrow_queue_opt = None - command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) if status is None: raise ValueError(f"Unknown command state: {operation_state}") - return ( - ExecuteResponse( - command_id=command_id, - status=status, - description=description, - has_more_rows=has_more_rows, - results_queue=arrow_queue_opt, - has_been_closed_server_side=has_been_closed_server_side, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - ), - schema_bytes, + execute_response = ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + 
is_staging_operation=t_result_set_metadata_resp.isStagingOperation, + arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) + return execute_response, is_direct_results + def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -841,9 +822,6 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -858,25 +836,21 @@ def get_execution_result( else: schema_bytes = None - queue = ResultSetQueueFactory.build_queue( - row_set_type=resp.resultSetMetadata.resultFormat, - t_row_set=resp.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + is_direct_results = resp.hasMoreRows + + status = self.get_query_state(command_id) execute_response = ExecuteResponse( command_id=command_id, status=status, description=description, - has_more_rows=has_more_rows, - results_queue=queue, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, + arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -886,7 +860,10 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=schema_bytes, + t_row_set=resp.results, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -999,10 +976,14 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1010,7 +991,10 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_catalogs( @@ -1032,10 +1016,14 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1043,7 +1031,10 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + 
t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_schemas( @@ -1069,10 +1060,14 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1080,7 +1075,10 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_tables( @@ -1110,10 +1108,14 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1121,7 +1123,10 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_columns( @@ -1151,10 +1156,14 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1162,7 +1171,10 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index b2ecd00f0..38b8a3c2f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,7 +41,7 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - has_more_rows: bool = False, + is_direct_results: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, @@ -51,18 +51,18 @@ def __init__( """ A ResultSet manages the results of a single command. 
- Args: - connection: The parent connection - backend: The backend client - arraysize: The max number of rows to fetch at a time (PEP-249) - buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - command_id: The command ID - status: The command status - has_been_closed_server_side: Whether the command has been closed on the server - has_more_rows: Whether the command has more rows - results_queue: The results queue - description: column description of the results - is_staging_operation: Whether the command is a staging operation + Parameters: + :param connection: The parent connection + :param backend: The backend client + :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + :param command_id: The command ID + :param status: The command status + :param has_been_closed_server_side: Whether the command has been closed on the server + :param is_direct_results: Whether the command has more rows + :param results_queue: The results queue + :param description: column description of the results + :param is_staging_operation: Whether the command is a staging operation """ self.connection = connection @@ -74,7 +74,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results self.results = results_queue self._is_staging_operation = is_staging_operation self.lz4_compressed = lz4_compressed @@ -161,25 +161,47 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - arrow_schema_bytes: Optional[bytes] = None, + t_row_set=None, + max_download_threads: int = 10, + ssl_options=None, + is_direct_results: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
 
-        Args:
-            connection: The parent connection
-            execute_response: Response from the execute command
-            thrift_client: The ThriftDatabricksClient instance for direct access
-            buffer_size_bytes: Buffer size for fetching results
-            arraysize: Default number of rows to fetch
-            use_cloud_fetch: Whether to use cloud fetch for retrieving results
-            arrow_schema_bytes: Arrow schema bytes for the result set
+        Parameters:
+            :param connection: The parent connection
+            :param execute_response: Response from the execute command
+            :param thrift_client: The ThriftDatabricksClient instance for direct access
+            :param buffer_size_bytes: Buffer size for fetching results
+            :param arraysize: Default number of rows to fetch
+            :param use_cloud_fetch: Whether to use cloud fetch for retrieving results
+            :param t_row_set: The TRowSet containing result data (if available)
+            :param max_download_threads: Maximum number of download threads for cloud fetch
+            :param ssl_options: SSL options for cloud fetch
+            :param is_direct_results: Whether there are more rows to fetch
         """
         # Initialize ThriftResultSet-specific attributes
         self._use_cloud_fetch = use_cloud_fetch
         self.is_direct_results = is_direct_results
 
+        # Build the results queue if t_row_set is provided
+        results_queue = None
+        if t_row_set and execute_response.result_format is not None:
+            from databricks.sql.utils import ResultSetQueueFactory
+
+            # Create the results queue using the provided format
+            results_queue = ResultSetQueueFactory.build_queue(
+                row_set_type=execute_response.result_format,
+                t_row_set=t_row_set,
+                arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
+                max_download_threads=max_download_threads,
+                lz4_compressed=execute_response.lz4_compressed,
+                description=execute_response.description,
+                ssl_options=ssl_options,
+            )
+
         # Call parent constructor with common attributes
         super().__init__(
             connection=connection,
@@ -189,8 +211,8 @@ def __init__(
             command_id=execute_response.command_id,
             status=execute_response.status,
             has_been_closed_server_side=execute_response.has_been_closed_server_side,
-            has_more_rows=execute_response.has_more_rows,
-            results_queue=execute_response.results_queue,
+            is_direct_results=is_direct_results,
+            results_queue=results_queue,
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
@@ -202,7 +224,7 @@ def __init__(
             self._fill_results_buffer()
 
     def _fill_results_buffer(self):
-        results, has_more_rows = self.backend.fetch_results(
+        results, is_direct_results = self.backend.fetch_results(
             command_id=self.command_id,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
@@ -213,7 +235,7 @@ def _fill_results_buffer(self):
             use_cloud_fetch=self._use_cloud_fetch,
         )
         self.results = results
-        self.has_more_rows = has_more_rows
+        self.is_direct_results = is_direct_results
 
     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]
@@ -297,7 +319,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
             while (
                 n_remaining_rows > 0
                 and not self.has_been_closed_server_side
-                and self.has_more_rows
+                and self.is_direct_results
             ):
                 self._fill_results_buffer()
                 partial_results = self.results.next_n_rows(n_remaining_rows)
@@ -322,7 +344,7 @@ def fetchmany_columnar(self, size: int):
         while (
             n_remaining_rows > 0
             and not self.has_been_closed_server_side
-            and self.has_more_rows
+            and self.is_direct_results
         ):
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
@@ -337,7 +359,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
 
-        while not self.has_been_closed_server_side and self.has_more_rows:
+        while not self.has_been_closed_server_side and self.is_direct_results:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
             if isinstance(results, ColumnTable) and isinstance(
@@ -363,7 +385,7 @@ def fetchall_columnar(self):
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
 
-        while not self.has_been_closed_server_side and self.has_more_rows:
+        while not self.has_been_closed_server_side and self.is_direct_results:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
             results = self.merge_columnar(results, partial_results)
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index edb13ef6d..d7b1b74b4 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -8,7 +8,7 @@
 from collections.abc import Iterable
 from decimal import Decimal
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import re
 
 import lz4.frame
@@ -57,7 +57,7 @@ def build_queue(
         max_download_threads: int,
         ssl_options: SSLOptions,
         lz4_compressed: bool = True,
-        description: Optional[List[List[Any]]] = None,
+        description: Optional[List[Tuple]] = None,
     ) -> ResultSetQueue:
         """
         Factory method to build a result set queue.
@@ -206,7 +206,7 @@ def __init__(
         start_row_offset: int = 0,
         result_links: Optional[List[TSparkArrowResultLink]] = None,
         lz4_compressed: bool = True,
-        description: Optional[List[List[Any]]] = None,
+        description: Optional[List[Tuple]] = None,
     ):
         """
         A queue-like wrapper over CloudFetch arrow batches.

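Note: the following is a minimal, runnable sketch of the flag-to-field mapping that PATCH 079 above moves into enums. The enum values are copied from the constants.py diff; build_statement_fields is a hypothetical helper invented here for illustration only (it is not part of the patch series) and mirrors the selection logic shown in SeaDatabricksClient.execute_command.

from enum import Enum


class ResultFormat(Enum):
    ARROW_STREAM = "ARROW_STREAM"
    JSON_ARRAY = "JSON_ARRAY"


class ResultDisposition(Enum):
    EXTERNAL_LINKS = "EXTERNAL_LINKS"
    INLINE = "INLINE"


class ResultCompression(Enum):
    LZ4_FRAME = "LZ4_FRAME"
    NONE = None


class WaitTimeout(Enum):
    ASYNC = "0s"
    SYNC = "10s"


def build_statement_fields(use_cloud_fetch: bool, lz4_compression: bool, async_op: bool) -> dict:
    # Hypothetical helper: reproduces the enum-based request-field selection.
    return {
        "format": (ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY).value,
        "disposition": (ResultDisposition.EXTERNAL_LINKS if use_cloud_fetch else ResultDisposition.INLINE).value,
        "result_compression": (ResultCompression.LZ4_FRAME if lz4_compression else ResultCompression.NONE).value,
        "wait_timeout": (WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value,
    }


# Example: cloud fetch with LZ4 compression, executed synchronously.
assert build_statement_fields(True, True, False) == {
    "format": "ARROW_STREAM",
    "disposition": "EXTERNAL_LINKS",
    "result_compression": "LZ4_FRAME",
    "wait_timeout": "10s",
}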
From fd5235606bbf307432e375a57e760319fc78709e Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 17 Jun 2025 03:39:42 +0000
Subject: [PATCH 083/105] remove unnecessary test changes

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/types.py |  8 +++---
 tests/unit/test_client.py           | 11 +++++---
 tests/unit/test_fetches.py          | 39 ++++++++++++++++-------------
 tests/unit/test_fetches_bench.py    |  2 +-
 4 files changed, 34 insertions(+), 26 deletions(-)

diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py
index 3d24a09a1..93bd7d525 100644
--- a/src/databricks/sql/backend/types.py
+++ b/src/databricks/sql/backend/types.py
@@ -423,11 +423,9 @@ class ExecuteResponse:
 
     command_id: CommandId
     status: CommandState
-    description: Optional[
-        List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]]
-    ] = None
-    has_more_rows: bool = False
-    results_queue: Optional[Any] = None
+    description: Optional[List[Tuple]] = None
     has_been_closed_server_side: bool = False
     lz4_compressed: bool = True
     is_staging_operation: bool = False
+    arrow_schema_bytes: Optional[bytes] = None
+    result_format: Optional[Any] = None
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 090ec255e..2054d01d1 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -48,7 +48,7 @@ def new(cls):
             is_staging_operation=False,
             command_id=None,
             has_been_closed_server_side=True,
-            has_more_rows=True,
+            is_direct_results=True,
             lz4_compressed=True,
             arrow_schema_bytes=b"schema",
         )
@@ -104,6 +104,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
         # Mock the backend that will be used by the real ThriftResultSet
         mock_backend = Mock(spec=ThriftDatabricksClient)
         mock_backend.staging_allowed_local_path = None
+        mock_backend.fetch_results.return_value = (Mock(), False)
 
         # Configure the decorator's mock to return our specific mock_backend
         mock_thrift_client_class.return_value = mock_backend
@@ -184,6 +185,7 @@ def test_arraysize_buffer_size_passthrough(
     def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
         mock_connection = Mock()
         mock_backend = Mock()
+        mock_backend.fetch_results.return_value = (Mock(), False)
 
         result_set = ThriftResultSet(
             connection=mock_connection,
@@ -210,6 +212,7 @@ def test_closing_result_set_hard_closes_commands(self):
         mock_session.open = True
         type(mock_connection).session = PropertyMock(return_value=mock_session)
 
+        mock_thrift_backend.fetch_results.return_value = (Mock(), False)
         result_set = ThriftResultSet(
             mock_connection, mock_results_response, mock_thrift_backend
         )
@@ -254,7 +257,10 @@ def test_closed_cursor_doesnt_allow_operations(self):
         self.assertIn("closed", e.msg)
 
     def test_negative_fetch_throws_exception(self):
-        result_set = ThriftResultSet(Mock(), Mock(), Mock())
+        mock_backend = Mock()
+        mock_backend.fetch_results.return_value = (Mock(), False)
+
+        result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
 
         with self.assertRaises(ValueError) as e:
             result_set.fetchmany(-1)
@@ -472,7 +478,6 @@ def make_fake_row_slice(n_rows):
         mock_aq = Mock()
         mock_aq.next_n_rows.side_effect = make_fake_row_slice
         mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq
-        mock_thrift_backend.fetch_results.return_value = (mock_aq, True)
 
         cursor = client.Cursor(Mock(), mock_thrift_backend)
         cursor.execute("foo")
diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py
index 7249a59e6..a649941e1 100644
--- a/tests/unit/test_fetches.py
+++ b/tests/unit/test_fetches.py
@@ -40,25 +40,30 @@ def make_dummy_result_set_from_initial_results(initial_results):
         # If the initial results have been set, then we should never try and fetch more
         schema, arrow_table = FetchTests.make_arrow_table(initial_results)
         arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0)
+
+        # Create a mock backend that will return the queue when _fill_results_buffer is called
+        mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
+        mock_thrift_backend.fetch_results.return_value = (arrow_queue, False)
+
+        num_cols = len(initial_results[0]) if initial_results else 0
+        description = [
+            (f"col{col_id}", "integer", None, None, None, None, None)
+            for col_id in range(num_cols)
+        ]
+
         rs = ThriftResultSet(
             connection=Mock(),
             execute_response=ExecuteResponse(
                 command_id=None,
                 status=None,
                 has_been_closed_server_side=True,
-                has_more_rows=False,
-                description=Mock(),
-                lz4_compressed=Mock(),
-                results_queue=arrow_queue,
+                description=description,
+                lz4_compressed=True,
                 is_staging_operation=False,
             ),
-            thrift_client=None,
+            thrift_client=mock_thrift_backend,
+            t_row_set=None,
         )
-        num_cols = len(initial_results[0]) if initial_results else 0
-        rs.description = [
-            (f"col{col_id}", "integer", None, None, None, None, None)
-            for col_id in range(num_cols)
-        ]
         return rs
 
     @staticmethod
@@ -85,19 +90,19 @@ def fetch_results(
         mock_thrift_backend.fetch_results = fetch_results
         num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0
 
+        description = [
+            (f"col{col_id}", "integer", None, None, None, None, None)
+            for col_id in range(num_cols)
+        ]
+
         rs = ThriftResultSet(
             connection=Mock(),
             execute_response=ExecuteResponse(
                 command_id=None,
                 status=None,
                 has_been_closed_server_side=False,
-                has_more_rows=True,
-                description=[
-                    (f"col{col_id}", "integer", None, None, None, None, None)
-                    for col_id in range(num_cols)
-                ],
-                lz4_compressed=Mock(),
-                results_queue=None,
+                description=description,
+                lz4_compressed=True,
                 is_staging_operation=False,
             ),
             thrift_client=mock_thrift_backend,
diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py
index 7e025cf82..e4a9e5cdd 100644
--- a/tests/unit/test_fetches_bench.py
+++ b/tests/unit/test_fetches_bench.py
@@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table):
         execute_response=ExecuteResponse(
             status=None,
             has_been_closed_server_side=True,
-            has_more_rows=False,
+            is_direct_results=False,
             description=Mock(),
             command_id=None,
             arrow_queue=arrow_queue,

From 64e58b05415591a22feb4ab8ed52440c63be0d49 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Tue, 17 Jun 2025 03:41:51 +0000
Subject: [PATCH 084/105] remove unnecessary changes in thrift backend tests

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_thrift_backend.py | 106 ++++++++++++++++++++----------
 1 file changed, 71 insertions(+), 35 deletions(-)

diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 8274190fe..57b5e9b58 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -623,7 +623,10 @@ def test_handle_execute_response_sets_compression_in_direct_results(
             status=Mock(),
             operationHandle=Mock(),
             directResults=ttypes.TSparkDirectResults(
-                operationStatus=op_status,
+                operationStatus=ttypes.TGetOperationStatusResp(
+                    status=self.okay_status,
+                    operationState=ttypes.TOperationState.FINISHED_STATE,
+                ),
                 resultSetMetadata=ttypes.TGetResultSetMetadataResp(
                     status=self.okay_status,
                     resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET,
@@ -832,9 +835,10 @@ def
test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value @@ -878,10 +882,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) - + ( + execute_response, + _, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -947,8 +951,14 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) thrift_backend = self._make_fake_thrift_backend() - execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) @@ -973,8 +983,14 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) thrift_backend = self._make_fake_thrift_backend() - thrift_backend._handle_execute_response(t_execute_resp, Mock()) + _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -988,10 +1004,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1003,7 +1019,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, ), closeOperation=Mock(), @@ -1019,11 +1035,12 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) + ( + execute_response, + has_more_rows_result, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) - self.assertEqual(has_more_rows, execute_response.has_more_rows) + self.assertEqual(is_direct_results, has_more_rows_result) @patch( 
"databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1032,10 +1049,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1048,7 +1065,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1081,7 +1098,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( description=Mock(), ) - self.assertEqual(has_more_rows, has_more_rows_resp) + self.assertEqual(is_direct_results, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): @@ -1136,9 +1153,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1151,14 +1169,15 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1170,9 +1189,10 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1185,12 +1205,13 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, 
cursor_mock) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1201,9 +1222,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1216,7 +1238,8 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1228,7 +1251,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1241,9 +1264,10 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1256,7 +1280,8 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1270,7 +1295,7 @@ def test_get_tables_calls_client_and_handle_execute_response( table_types=["type1", "type2"], ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] @@ -1285,9 +1310,10 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1300,7 +1326,8 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1314,7 +1341,7 @@ def 
test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -2203,14 +2230,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class + self, mock_handle_execute_response, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value + # Set up the mock to return a tuple with two values + mock_execute_response = Mock() + mock_arrow_schema = Mock() + mock_handle_execute_response.return_value = ( + mock_execute_response, + mock_arrow_schema, + ) + # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] From 0a2cdfd7a08fcf48db3eb80b475315e56f876921 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 03:43:37 +0000 Subject: [PATCH 085/105] remove unimplemented methods test Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 52 ---------------------------------- 1 file changed, 52 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 1d16763be..1434ed831 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -599,55 +599,3 @@ def test_utility_methods(self, sea_client): manifest_obj.schema = {} description = sea_client._extract_description_from_manifest(manifest_obj) assert description is None - - def test_unimplemented_metadata_methods( - self, sea_client, sea_session_id, mock_cursor - ): - """Test that metadata methods raise NotImplementedError.""" - # Test get_catalogs - with pytest.raises(NotImplementedError): - sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas - with pytest.raises(NotImplementedError): - sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_schemas( - sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" - ) - - # Test get_tables - with pytest.raises(NotImplementedError): - sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) - - # Test get_tables with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_tables( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - table_types=["TABLE", "VIEW"], - ) - - # Test get_columns - with pytest.raises(NotImplementedError): - sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) - - # Test get_columns with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_columns( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - column_name="column", - ) From cd22389fcc12713ec0c24715001b9067f856242b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: 
Tue, 17 Jun 2025 05:16:36 +0000
Subject: [PATCH 086/105] remove invalid import

Signed-off-by: varun-edachali-dbx
---
 tests/unit/test_client.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 373a1b6d1..24a8880af 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -29,7 +29,6 @@
 from databricks.sql.backend.types import CommandId, CommandState
 from databricks.sql.backend.types import ExecuteResponse
-from databricks.sql.utils import ExecuteResponse
 from tests.unit.test_fetches import FetchTests
 from tests.unit.test_thrift_backend import ThriftBackendTestSuite
 from tests.unit.test_arrow_queue import ArrowQueueSuite

From 5ab9bbe4fff28a60eb35439130a589b83375789b Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 04:34:26 +0000
Subject: [PATCH 087/105] better align queries with JDBC impl

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/sea/backend.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 3b9d92151..49534ea16 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -645,7 +645,7 @@ def get_schemas(
         if not catalog_name:
             raise ValueError("Catalog name is required for get_schemas")
 
-        operation = f"SHOW SCHEMAS IN `{catalog_name}`"
+        operation = f"SHOW SCHEMAS IN {catalog_name}"
 
         if schema_name:
             operation += f" LIKE '{schema_name}'"
@@ -683,7 +683,7 @@ def get_tables(
         operation = "SHOW TABLES IN " + (
             "ALL CATALOGS"
             if catalog_name in [None, "*", "%"]
-            else f"CATALOG `{catalog_name}`"
+            else f"CATALOG {catalog_name}"
         )
 
         if schema_name:
@@ -706,7 +706,7 @@ def get_tables(
         )
         assert result is not None, "execute_command returned None in synchronous mode"
 
-        # Apply client-side filtering by table_types if specified
+        # Apply client-side filtering by table_types
         from databricks.sql.backend.filters import ResultSetFilter
 
         result = ResultSetFilter.filter_tables_by_type(result, table_types)
@@ -728,7 +728,7 @@ def get_columns(
         if not catalog_name:
             raise ValueError("Catalog name is required for get_columns")
 
-        operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`"
+        operation = f"SHOW COLUMNS IN CATALOG {catalog_name}"
 
         if schema_name:
             operation += f" SCHEMA LIKE '{schema_name}'"

From 1ab6e8793b04c3065fbe49f9a42d6a3ddb83feed Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 04:38:37 +0000
Subject: [PATCH 088/105] add line breaks after multi-line docstrings

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/filters.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py
index ec91d87da..2966f6797 100644
--- a/src/databricks/sql/backend/filters.py
+++ b/src/databricks/sql/backend/filters.py
@@ -49,6 +49,7 @@ def _filter_sea_result_set(
         Returns:
             A filtered SEA result set
         """
+
         # Get all remaining rows
         all_rows = result_set.results.remaining_rows()
 
@@ -108,6 +109,7 @@ def filter_by_column_values(
         Returns:
             A filtered result set
         """
+
         # Convert to uppercase for case-insensitive comparison if needed
         if not case_sensitive:
             allowed_values = [v.upper() for v in allowed_values]
@@ -154,6 +156,7 @@ def filter_tables_by_type(
         Returns:
             A filtered result set containing only tables of the specified types
         """
+
         # Default table types if none specified
         DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
         valid_types = (

From 
f469c24c09f82b8d747d4b93b73fdf8380e7c0a5 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 04:59:02 +0000
Subject: [PATCH 089/105] remove unused imports

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/filters.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py
index 2966f6797..f8abe26e0 100644
--- a/src/databricks/sql/backend/filters.py
+++ b/src/databricks/sql/backend/filters.py
@@ -9,16 +9,11 @@
     List,
     Optional,
     Any,
-    Dict,
     Callable,
-    TypeVar,
-    Generic,
     cast,
     TYPE_CHECKING,
 )
 
-from databricks.sql.backend.types import ExecuteResponse, CommandId
-from databricks.sql.backend.sea.models.base import ResultData
 from databricks.sql.backend.sea.backend import SeaDatabricksClient
 
 if TYPE_CHECKING:

From 68ec65f039695d4c98518d676b4ac0d53cf20600 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 05:03:04 +0000
Subject: [PATCH 090/105] fix: introduce ExecuteResponse import

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/filters.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py
index f8abe26e0..b97787889 100644
--- a/src/databricks/sql/backend/filters.py
+++ b/src/databricks/sql/backend/filters.py
@@ -15,6 +15,7 @@
 )
 
 from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.types import ExecuteResponse
 
 if TYPE_CHECKING:
     from databricks.sql.result_set import ResultSet, SeaResultSet

From f6d873dc68b6aa15ea53bdc9c54d6f5d4a7f0106 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Wed, 18 Jun 2025 07:58:15 +0000
Subject: [PATCH 091/105] remove unimplemented metadata methods test, unnecessary imports

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/sea/backend.py | 6 +--
 .../sql/backend/sea/models/responses.py | 18 +++----
 tests/unit/test_filters.py | 5 --
 tests/unit/test_sea_backend.py | 53 +------------------
 4 files changed, 13 insertions(+), 69 deletions(-)

diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index a48a97953..9d301d3bc 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -41,9 +41,9 @@
     CreateSessionResponse,
 )
 from databricks.sql.backend.sea.models.responses import (
-    parse_status,
-    parse_manifest,
-    parse_result,
+    _parse_status,
+    _parse_manifest,
+    _parse_result,
 )
 
 logger = logging.getLogger(__name__)
diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py
index dae37b1ae..0baf27ab2 100644
--- a/src/databricks/sql/backend/sea/models/responses.py
+++ b/src/databricks/sql/backend/sea/models/responses.py
@@ -18,7 +18,7 @@
 )
 
 
-def parse_status(data: Dict[str, Any]) -> StatementStatus:
+def _parse_status(data: Dict[str, Any]) -> StatementStatus:
     """Parse status from response data."""
     status_data = data.get("status", {})
     error = None
@@ -40,7 +40,7 @@ def parse_status(data: Dict[str, Any]) -> StatementStatus:
     )
 
 
-def parse_manifest(data: Dict[str, Any]) -> ResultManifest:
+def _parse_manifest(data: Dict[str, Any]) -> ResultManifest:
     """Parse manifest from response data."""
     manifest_data = data.get("manifest", {})
 
@@ -69,7 +69,7 @@ def parse_manifest(data: Dict[str, Any]) -> ResultManifest:
     )
 
 
-def parse_result(data: Dict[str, Any]) -> ResultData:
+def _parse_result(data: Dict[str, Any]) -> ResultData:
     """Parse result data from 
response data.""" result_data = data.get("result", {}) external_links = None @@ -118,9 +118,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": """Create an ExecuteStatementResponse from a dictionary.""" return cls( statement_id=data.get("statement_id", ""), - status=parse_status(data), - manifest=parse_manifest(data), - result=parse_result(data), + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) @@ -138,9 +138,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": """Create a GetStatementResponse from a dictionary.""" return cls( statement_id=data.get("statement_id", ""), - status=parse_status(data), - manifest=parse_manifest(data), - result=parse_result(data), + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 49bd1c328..d0b815b95 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -4,11 +4,6 @@ import unittest from unittest.mock import MagicMock, patch -import sys -from typing import List, Dict, Any - -# Add the necessary path to import the filter module -sys.path.append("/home/varun.edachali/conn/databricks-sql-python/src") from databricks.sql.backend.filters import ResultSetFilter diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f30c92ed0..af4742cb2 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -631,55 +631,4 @@ def test_utility_methods(self, sea_client): assert ( sea_client._extract_description_from_manifest(no_columns_manifest) is None ) - - def test_unimplemented_metadata_methods( - self, sea_client, sea_session_id, mock_cursor - ): - """Test that metadata methods raise NotImplementedError.""" - # Test get_catalogs - with pytest.raises(NotImplementedError): - sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas - with pytest.raises(NotImplementedError): - sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_schemas( - sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" - ) - - # Test get_tables - with pytest.raises(NotImplementedError): - sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) - - # Test get_tables with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_tables( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - table_types=["TABLE", "VIEW"], - ) - - # Test get_columns - with pytest.raises(NotImplementedError): - sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) - - # Test get_columns with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_columns( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - column_name="column", - ) + \ No newline at end of file From 28675f5c46c5233159d5b0456793ffa9a246d795 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 08:28:27 +0000 Subject: [PATCH 092/105] introduce unit tests for metadata methods Signed-off-by: varun-edachali-dbx --- tests/unit/test_filters.py | 133 +++++++++++------ tests/unit/test_sea_backend.py | 253 ++++++++++++++++++++++++++++++++- 2 files changed, 342 insertions(+), 44 deletions(-) diff --git 
a/tests/unit/test_filters.py b/tests/unit/test_filters.py index d0b815b95..bf8d30707 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -15,17 +15,31 @@ def setUp(self): """Set up test fixtures.""" # Create a mock SeaResultSet self.mock_sea_result_set = MagicMock() - self.mock_sea_result_set._response = { - "result": { - "data_array": [ - ["catalog1", "schema1", "table1", "TABLE", ""], - ["catalog1", "schema1", "table2", "VIEW", ""], - ["catalog1", "schema1", "table3", "SYSTEM TABLE", ""], - ["catalog1", "schema1", "table4", "EXTERNAL TABLE", ""], - ], - "row_count": 4, - } - } + + # Set up the remaining_rows method on the results attribute + self.mock_sea_result_set.results = MagicMock() + self.mock_sea_result_set.results.remaining_rows.return_value = [ + ["catalog1", "schema1", "table1", "owner1", "2023-01-01", "TABLE", ""], + ["catalog1", "schema1", "table2", "owner1", "2023-01-01", "VIEW", ""], + [ + "catalog1", + "schema1", + "table3", + "owner1", + "2023-01-01", + "SYSTEM TABLE", + "", + ], + [ + "catalog1", + "schema1", + "table4", + "owner1", + "2023-01-01", + "EXTERNAL TABLE", + "", + ], + ] # Set up the connection and other required attributes self.mock_sea_result_set.connection = MagicMock() @@ -33,6 +47,7 @@ def setUp(self): self.mock_sea_result_set.buffer_size_bytes = 1000 self.mock_sea_result_set.arraysize = 100 self.mock_sea_result_set.statement_id = "test-statement-id" + self.mock_sea_result_set.lz4_compressed = False # Create a mock CommandId from databricks.sql.backend.types import CommandId, BackendType @@ -45,70 +60,102 @@ def setUp(self): ("catalog_name", "string", None, None, None, None, True), ("schema_name", "string", None, None, None, None, True), ("table_name", "string", None, None, None, None, True), + ("owner", "string", None, None, None, None, True), + ("creation_time", "string", None, None, None, None, True), ("table_type", "string", None, None, None, None, True), ("remarks", "string", None, None, None, None, True), ] self.mock_sea_result_set.has_been_closed_server_side = False + self.mock_sea_result_set._arrow_schema_bytes = None - def test_filter_tables_by_type(self): - """Test filtering tables by type.""" - # Test with specific table types - table_types = ["TABLE", "VIEW"] + def test_filter_by_column_values(self): + """Test filtering by column values with various options.""" + # Case 1: Case-sensitive filtering + allowed_values = ["table1", "table3"] - # Make the mock_sea_result_set appear to be a SeaResultSet with patch("databricks.sql.backend.filters.isinstance", return_value=True): with patch( "databricks.sql.result_set.SeaResultSet" ) as mock_sea_result_set_class: - # Set up the mock to return a new mock when instantiated mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance - result = ResultSetFilter.filter_tables_by_type( - self.mock_sea_result_set, table_types + # Call filter_by_column_values on the table_name column (index 2) + result = ResultSetFilter.filter_by_column_values( + self.mock_sea_result_set, 2, allowed_values, case_sensitive=True ) # Verify the filter was applied correctly mock_sea_result_set_class.assert_called_once() - def test_filter_tables_by_type_case_insensitive(self): - """Test filtering tables by type with case insensitivity.""" - # Test with lowercase table types - table_types = ["table", "view"] + # Check the filtered data passed to the constructor + args, kwargs = mock_sea_result_set_class.call_args + result_data = kwargs.get("result_data") + 
self.assertIsNotNone(result_data) + self.assertEqual(len(result_data.data), 2) + self.assertIn(result_data.data[0][2], allowed_values) + self.assertIn(result_data.data[1][2], allowed_values) - # Make the mock_sea_result_set appear to be a SeaResultSet + # Case 2: Case-insensitive filtering + mock_sea_result_set_class.reset_mock() with patch("databricks.sql.backend.filters.isinstance", return_value=True): with patch( "databricks.sql.result_set.SeaResultSet" ) as mock_sea_result_set_class: - # Set up the mock to return a new mock when instantiated mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance - result = ResultSetFilter.filter_tables_by_type( - self.mock_sea_result_set, table_types + # Call filter_by_column_values with case-insensitive matching + result = ResultSetFilter.filter_by_column_values( + self.mock_sea_result_set, + 2, + ["TABLE1", "TABLE3"], + case_sensitive=False, ) - - # Verify the filter was applied correctly mock_sea_result_set_class.assert_called_once() - def test_filter_tables_by_type_default(self): - """Test filtering tables by type with default types.""" - # Make the mock_sea_result_set appear to be a SeaResultSet - with patch("databricks.sql.backend.filters.isinstance", return_value=True): - with patch( - "databricks.sql.result_set.SeaResultSet" - ) as mock_sea_result_set_class: - # Set up the mock to return a new mock when instantiated - mock_instance = MagicMock() - mock_sea_result_set_class.return_value = mock_instance + # Case 3: Unsupported result set type + mock_unsupported_result_set = MagicMock() + with patch("databricks.sql.backend.filters.isinstance", return_value=False): + with patch("databricks.sql.backend.filters.logger") as mock_logger: + result = ResultSetFilter.filter_by_column_values( + mock_unsupported_result_set, 0, ["value"], True + ) + mock_logger.warning.assert_called_once() + self.assertEqual(result, mock_unsupported_result_set) + + def test_filter_tables_by_type(self): + """Test filtering tables by type with various options.""" + # Case 1: Specific table types + table_types = ["TABLE", "VIEW"] - result = ResultSetFilter.filter_tables_by_type( - self.mock_sea_result_set, None + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch.object( + ResultSetFilter, "filter_by_column_values" + ) as mock_filter: + ResultSetFilter.filter_tables_by_type( + self.mock_sea_result_set, table_types ) + args, kwargs = mock_filter.call_args + self.assertEqual(args[0], self.mock_sea_result_set) + self.assertEqual(args[1], 5) # Table type column index + self.assertEqual(args[2], table_types) + self.assertEqual(kwargs.get("case_sensitive"), True) - # Verify the filter was applied correctly - mock_sea_result_set_class.assert_called_once() + # Case 2: Default table types (None or empty list) + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch.object( + ResultSetFilter, "filter_by_column_values" + ) as mock_filter: + # Test with None + ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None) + args, kwargs = mock_filter.call_args + self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"]) + + # Test with empty list + ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, []) + args, kwargs = mock_filter.call_args + self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"]) if __name__ == "__main__": diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index af4742cb2..d75359f2f 100644 --- 
a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -631,4 +631,255 @@ def test_utility_methods(self, sea_client): assert ( sea_client._extract_description_from_manifest(no_columns_manifest) is None ) - \ No newline at end of file + + def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor): + """Test the get_catalogs method.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call get_catalogs + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify execute_command was called with the correct parameters + mock_execute.assert_called_once_with( + operation="SHOW CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result is correct + assert result == mock_result_set + + def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): + """Test the get_schemas method with various parameter combinations.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Case 1: With catalog name only + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW SCHEMAS IN test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 2: With catalog and schema names + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + mock_execute.assert_called_with( + operation="SHOW SCHEMAS IN test_catalog LIKE 'test_schema'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 3: Without catalog name (should raise ValueError) + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables(self, sea_client, sea_session_id, mock_cursor): + """Test the get_tables method with various parameter combinations.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Mock the filter_tables_by_type method + with patch( + "databricks.sql.backend.filters.ResultSetFilter.filter_tables_by_type", + return_value=mock_result_set, + ) as mock_filter: + # Case 1: With catalog name only + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW TABLES IN CATALOG test_catalog", + session_id=sea_session_id, + 
max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + mock_filter.assert_called_with(mock_result_set, None) + + # Case 2: With all parameters + table_types = ["TABLE", "VIEW"] + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + table_types=table_types, + ) + + mock_execute.assert_called_with( + operation="SHOW TABLES IN CATALOG test_catalog SCHEMA LIKE 'test_schema' LIKE 'test_table'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + mock_filter.assert_called_with(mock_result_set, table_types) + + # Case 3: With wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + mock_execute.assert_called_with( + operation="SHOW TABLES IN ALL CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 4: Without catalog name (should raise ValueError) + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns(self, sea_client, sea_session_id, mock_cursor): + """Test the get_columns method with various parameter combinations.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Case 1: With catalog name only + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW COLUMNS IN CATALOG test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 2: With all parameters + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + mock_execute.assert_called_with( + operation="SHOW COLUMNS IN CATALOG test_catalog SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 3: Without catalog name (should raise ValueError) + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + assert "Catalog name is required for get_columns" in str(excinfo.value) From 3578659af87df515addf8632d88549df769106d2 Mon Sep 17 
00:00:00 2001
From: varun-edachali-dbx
Date: Fri, 20 Jun 2025 13:56:15 +0530
Subject: [PATCH 093/105] remove verbosity in ResultSetFilter docstring

Co-authored-by: jayant <167047871+jayantsing-db@users.noreply.github.com>
---
 src/databricks/sql/backend/filters.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py
index b97787889..30f36f25c 100644
--- a/src/databricks/sql/backend/filters.py
+++ b/src/databricks/sql/backend/filters.py
@@ -25,7 +25,7 @@
 
 class ResultSetFilter:
     """
-    A general-purpose filter for result sets that can be applied to any backend.
+    A general-purpose filter for result sets.
 
     This class provides methods to filter result sets based on various criteria,
     similar to the client-side filtering in the JDBC connector.

From 8713023df340c0f943ead5ba7578e6d686953e46 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Fri, 20 Jun 2025 08:28:37 +0000
Subject: [PATCH 094/105] remove unnecessary info in ResultSetFilter docstring

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/filters.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py
index 30f36f25c..17a426596 100644
--- a/src/databricks/sql/backend/filters.py
+++ b/src/databricks/sql/backend/filters.py
@@ -26,9 +26,6 @@
 class ResultSetFilter:
     """
     A general-purpose filter for result sets.
-
-    This class provides methods to filter result sets based on various criteria,
-    similar to the client-side filtering in the JDBC connector.
     """
 
     @staticmethod

From 22dc2522f0edfe43d5a7d2398ec487e229491526 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Fri, 20 Jun 2025 08:33:39 +0000
Subject: [PATCH 095/105] remove explicit type checking, string literals around forward annotations

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/filters.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py
index 17a426596..468fb4d4c 100644
--- a/src/databricks/sql/backend/filters.py
+++ b/src/databricks/sql/backend/filters.py
@@ -11,14 +11,12 @@
     Any,
     Callable,
     cast,
-    TYPE_CHECKING,
 )
 
 from databricks.sql.backend.sea.backend import SeaDatabricksClient
 from databricks.sql.backend.types import ExecuteResponse
 
-if TYPE_CHECKING:
-    from databricks.sql.result_set import ResultSet, SeaResultSet
+from databricks.sql.result_set import ResultSet, SeaResultSet
 
 logger = logging.getLogger(__name__)
 
@@ -30,8 +28,8 @@ class ResultSetFilter:
 
     @staticmethod
     def _filter_sea_result_set(
-        result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
-    ) -> "SeaResultSet":
+        result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
+    ) -> SeaResultSet:
         """
         Filter a SEA result set using the provided filter function.
 
@@ -49,9 +47,6 @@ def _filter_sea_result_set( # Filter rows filtered_rows = [row for row in all_rows if filter_func(row)] - # Import SeaResultSet here to avoid circular imports - from databricks.sql.result_set import SeaResultSet - # Reuse the command_id from the original result set command_id = result_set.command_id @@ -67,10 +62,13 @@ def _filter_sea_result_set( ) # Create a new ResultData object with filtered data + from databricks.sql.backend.sea.models.base import ResultData result_data = ResultData(data=filtered_rows, external_links=None) + from databricks.sql.result_set import SeaResultSet + # Create a new SeaResultSet with the filtered data filtered_result_set = SeaResultSet( connection=result_set.connection, @@ -85,11 +83,11 @@ def _filter_sea_result_set( @staticmethod def filter_by_column_values( - result_set: "ResultSet", + result_set: ResultSet, column_index: int, allowed_values: List[str], case_sensitive: bool = False, - ) -> "ResultSet": + ) -> ResultSet: """ Filter a result set by values in a specific column. @@ -133,8 +131,8 @@ def filter_by_column_values( @staticmethod def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": + result_set: ResultSet, table_types: Optional[List[str]] = None + ) -> ResultSet: """ Filter a result set of tables by the specified table types. From 390f5928aca9b16c5b30b8a7eb292c3b4cd405dd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 08:56:37 +0000 Subject: [PATCH 096/105] house SQL commands in constants Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 27 ++++++++++--------- .../sql/backend/sea/utils/constants.py | 20 ++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 9d301d3bc..ac3644b2f 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -10,6 +10,7 @@ ResultDisposition, ResultCompression, WaitTimeout, + MetadataCommands, ) if TYPE_CHECKING: @@ -635,7 +636,7 @@ def get_catalogs( ) -> "ResultSet": """Get available catalogs by executing 'SHOW CATALOGS'.""" result = self.execute_command( - operation="SHOW CATALOGS", + operation=MetadataCommands.SHOW_CATALOGS.value, session_id=session_id, max_rows=max_rows, max_bytes=max_bytes, @@ -662,10 +663,10 @@ def get_schemas( if not catalog_name: raise ValueError("Catalog name is required for get_schemas") - operation = f"SHOW SCHEMAS IN {catalog_name}" + operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name) if schema_name: - operation += f" LIKE '{schema_name}'" + operation += MetadataCommands.LIKE_PATTERN.value.format(schema_name) result = self.execute_command( operation=operation, @@ -697,17 +698,19 @@ def get_tables( if not catalog_name: raise ValueError("Catalog name is required for get_tables") - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" + operation = ( + MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value if catalog_name in [None, "*", "%"] - else f"CATALOG {catalog_name}" + else MetadataCommands.SHOW_TABLES.value.format( + MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name) + ) ) if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" + operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name) if table_name: - operation += f" LIKE '{table_name}'" + operation += MetadataCommands.LIKE_PATTERN.value.format(table_name) result = self.execute_command( 
operation=operation, @@ -745,16 +748,16 @@ def get_columns( if not catalog_name: raise ValueError("Catalog name is required for get_columns") - operation = f"SHOW COLUMNS IN CATALOG {catalog_name}" + operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name) if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" + operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name) if table_name: - operation += f" TABLE LIKE '{table_name}'" + operation += MetadataCommands.TABLE_LIKE_PATTERN.value.format(table_name) if column_name: - operation += f" LIKE '{column_name}'" + operation += MetadataCommands.LIKE_PATTERN.value.format(column_name) result = self.execute_command( operation=operation, diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 7481a90db..4912455c9 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -45,3 +45,23 @@ class WaitTimeout(Enum): ASYNC = "0s" SYNC = "10s" + + +class MetadataCommands(Enum): + """SQL commands used in the SEA backend. + + These constants are used for metadata operations and other SQL queries + to ensure consistency and avoid string literal duplication. + """ + + SHOW_CATALOGS = "SHOW CATALOGS" + SHOW_SCHEMAS = "SHOW SCHEMAS IN {}" + SHOW_TABLES = "SHOW TABLES IN {}" + SHOW_TABLES_ALL_CATALOGS = "SHOW TABLES IN ALL CATALOGS" + SHOW_COLUMNS = "SHOW COLUMNS IN CATALOG {}" + + SCHEMA_LIKE_PATTERN = " SCHEMA LIKE '{}'" + TABLE_LIKE_PATTERN = " TABLE LIKE '{}'" + LIKE_PATTERN = " LIKE '{}'" + + CATALOG_SPECIFIC = "CATALOG {}" From 35f1ef0eb40928d4c92b4b69312acf603c95dcd8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 01:56:46 +0000 Subject: [PATCH 097/105] remove catalog requirement in get_tables Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 3 --- src/databricks/sql/backend/sea/utils/constants.py | 4 ++-- tests/unit/test_sea_backend.py | 10 ---------- 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index ac3644b2f..53679d10e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -695,9 +695,6 @@ def get_tables( table_types: Optional[List[str]] = None, ) -> "ResultSet": """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - operation = ( MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value if catalog_name in [None, "*", "%"] diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 4912455c9..402da0de5 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -60,8 +60,8 @@ class MetadataCommands(Enum): SHOW_TABLES_ALL_CATALOGS = "SHOW TABLES IN ALL CATALOGS" SHOW_COLUMNS = "SHOW COLUMNS IN CATALOG {}" - SCHEMA_LIKE_PATTERN = " SCHEMA LIKE '{}'" - TABLE_LIKE_PATTERN = " TABLE LIKE '{}'" LIKE_PATTERN = " LIKE '{}'" + SCHEMA_LIKE_PATTERN = " SCHEMA" + LIKE_PATTERN + TABLE_LIKE_PATTERN = " TABLE" + LIKE_PATTERN CATALOG_SPECIFIC = "CATALOG {}" diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index d75359f2f..e6c8734d0 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -810,16 +810,6 @@ def 
test_get_tables(self, sea_client, sea_session_id, mock_cursor): enforce_embedded_schema_correctness=False, ) - # Case 4: Without catalog name (should raise ValueError) - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - assert "Catalog name is required for get_tables" in str(excinfo.value) - def test_get_columns(self, sea_client, sea_session_id, mock_cursor): """Test the get_columns method with various parameter combinations.""" # Mock the execute_command method From a515d260992b7902b017daf152b1c04c86c3d46d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 05:37:46 +0000 Subject: [PATCH 098/105] move filters.py to SEA utils Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 2 +- .../sql/backend/{ => sea/utils}/filters.py | 42 +++++++------------ tests/unit/test_filters.py | 28 ++++++------- tests/unit/test_sea_backend.py | 2 +- 4 files changed, 31 insertions(+), 43 deletions(-) rename src/databricks/sql/backend/{ => sea/utils}/filters.py (80%) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 53679d10e..e6d9a082e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -724,7 +724,7 @@ def get_tables( assert result is not None, "execute_command returned None in synchronous mode" # Apply client-side filtering by table_types - from databricks.sql.backend.filters import ResultSetFilter + from databricks.sql.backend.sea.utils.filters import ResultSetFilter result = ResultSetFilter.filter_tables_by_type(result, table_types) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/sea/utils/filters.py similarity index 80% rename from src/databricks/sql/backend/filters.py rename to src/databricks/sql/backend/sea/utils/filters.py index 468fb4d4c..493975433 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -83,11 +83,11 @@ def _filter_sea_result_set( @staticmethod def filter_by_column_values( - result_set: ResultSet, + result_set: SeaResultSet, column_index: int, allowed_values: List[str], case_sensitive: bool = False, - ) -> ResultSet: + ) -> SeaResultSet: """ Filter a result set by values in a specific column. 
@@ -105,34 +105,24 @@ def filter_by_column_values( if not case_sensitive: allowed_values = [v.upper() for v in allowed_values] - # Determine the type of result set and apply appropriate filtering - from databricks.sql.result_set import SeaResultSet - - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), ) - return result_set @staticmethod def filter_tables_by_type( - result_set: ResultSet, table_types: Optional[List[str]] = None - ) -> ResultSet: + result_set: SeaResultSet, table_types: Optional[List[str]] = None + ) -> SeaResultSet: """ Filter a result set of tables by the specified table types. diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index bf8d30707..975376e13 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -5,7 +5,7 @@ import unittest from unittest.mock import MagicMock, patch -from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.backend.sea.utils.filters import ResultSetFilter class TestResultSetFilter(unittest.TestCase): @@ -73,7 +73,9 @@ def test_filter_by_column_values(self): # Case 1: Case-sensitive filtering allowed_values = ["table1", "table3"] - with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): with patch( "databricks.sql.result_set.SeaResultSet" ) as mock_sea_result_set_class: @@ -98,7 +100,9 @@ def test_filter_by_column_values(self): # Case 2: Case-insensitive filtering mock_sea_result_set_class.reset_mock() - with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): with patch( "databricks.sql.result_set.SeaResultSet" ) as mock_sea_result_set_class: @@ -114,22 +118,14 @@ def test_filter_by_column_values(self): ) mock_sea_result_set_class.assert_called_once() - # Case 3: Unsupported result set type - mock_unsupported_result_set = MagicMock() - with patch("databricks.sql.backend.filters.isinstance", return_value=False): - with patch("databricks.sql.backend.filters.logger") as mock_logger: - result = ResultSetFilter.filter_by_column_values( - mock_unsupported_result_set, 0, ["value"], True - ) - mock_logger.warning.assert_called_once() - self.assertEqual(result, mock_unsupported_result_set) - def test_filter_tables_by_type(self): """Test filtering tables by type with various options.""" # Case 1: Specific table types table_types = ["TABLE", "VIEW"] - with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): with patch.object( ResultSetFilter, "filter_by_column_values" ) as mock_filter: @@ -143,7 +139,9 @@ def test_filter_tables_by_type(self): 
self.assertEqual(kwargs.get("case_sensitive"), True) # Case 2: Default table types (None or empty list) - with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): with patch.object( ResultSetFilter, "filter_by_column_values" ) as mock_filter: diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index e6c8734d0..2d45a1f49 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -735,7 +735,7 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor): ) as mock_execute: # Mock the filter_tables_by_type method with patch( - "databricks.sql.backend.filters.ResultSetFilter.filter_tables_by_type", + "databricks.sql.backend.sea.utils.filters.ResultSetFilter.filter_tables_by_type", return_value=mock_result_set, ) as mock_filter: # Case 1: With catalog name only From 59b1330f2db8e680bce7b17b0941e39699b93cf2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 05:40:23 +0000 Subject: [PATCH 099/105] ensure SeaResultSet Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index e6d9a082e..623979115 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -12,6 +12,7 @@ WaitTimeout, MetadataCommands, ) +from databricks.sql.result_set import SeaResultSet if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -722,6 +723,9 @@ def get_tables( enforce_embedded_schema_correctness=False, ) assert result is not None, "execute_command returned None in synchronous mode" + assert isinstance( + result, SeaResultSet + ), "SEA backend execute_command returned a non-SeaResultSet" # Apply client-side filtering by table_types from databricks.sql.backend.sea.utils.filters import ResultSetFilter From dd40bebff73442eedfd264192dc05376a7f86bed Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 06:13:43 +0000 Subject: [PATCH 100/105] prevent circular imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 8 +++----- src/databricks/sql/backend/sea/utils/filters.py | 11 +++++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 623979115..2af77ec45 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,7 +1,7 @@ import logging import time import re -from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set +from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set, cast from databricks.sql.backend.sea.models.base import ResultManifest from databricks.sql.backend.sea.utils.constants import ( @@ -12,7 +12,6 @@ WaitTimeout, MetadataCommands, ) -from databricks.sql.result_set import SeaResultSet if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -723,13 +722,12 @@ def get_tables( enforce_embedded_schema_correctness=False, ) assert result is not None, "execute_command returned None in synchronous mode" - assert isinstance( - result, SeaResultSet - ), "SEA backend execute_command returned a non-SeaResultSet" # Apply client-side filtering by table_types from databricks.sql.backend.sea.utils.filters import ResultSetFilter + from databricks.sql.result_set 
import SeaResultSet + result = cast(SeaResultSet, result) result = ResultSetFilter.filter_tables_by_type(result, table_types) return result diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 493975433..db6a12e16 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -4,6 +4,8 @@ This module provides filtering capabilities for result sets returned by different backends. """ +from __future__ import annotations + import logging from typing import ( List, @@ -11,12 +13,13 @@ Any, Callable, cast, + TYPE_CHECKING, ) -from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import ExecuteResponse +if TYPE_CHECKING: + from databricks.sql.result_set import SeaResultSet -from databricks.sql.result_set import ResultSet, SeaResultSet +from databricks.sql.backend.types import ExecuteResponse logger = logging.getLogger(__name__) @@ -62,11 +65,11 @@ def _filter_sea_result_set( ) # Create a new ResultData object with filtered data - from databricks.sql.backend.sea.models.base import ResultData result_data = ResultData(data=filtered_rows, external_links=None) + from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data From 14057acb8e3201574b8a2054eb63506d7d894800 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 06:46:16 +0000 Subject: [PATCH 101/105] remove unused imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 2af77ec45..b5385d5df 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -41,11 +41,6 @@ GetStatementResponse, CreateSessionResponse, ) -from databricks.sql.backend.sea.models.responses import ( - _parse_status, - _parse_manifest, - _parse_result, -) logger = logging.getLogger(__name__) From a4d5bdb726aee53bfa27b60e1b7baf78c01a67d3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 06:51:59 +0000 Subject: [PATCH 102/105] remove cast, throw error if not SeaResultSet Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 10 +++++++--- tests/unit/test_sea_backend.py | 5 ++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index b5385d5df..2cd1c98c2 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,7 +1,7 @@ import logging import time import re -from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set, cast +from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set from databricks.sql.backend.sea.models.base import ResultManifest from databricks.sql.backend.sea.utils.constants import ( @@ -718,11 +718,15 @@ def get_tables( ) assert result is not None, "execute_command returned None in synchronous mode" + from databricks.sql.result_set import SeaResultSet + + assert isinstance( + result, SeaResultSet + ), "execute_command returned a non-SeaResultSet" + # Apply client-side filtering by table_types from databricks.sql.backend.sea.utils.filters import ResultSetFilter - from databricks.sql.result_set import SeaResultSet - 
         result = cast(SeaResultSet, result)
         result = ResultSetFilter.filter_tables_by_type(result, table_types)

         return result

diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index 2d45a1f49..68dea3d81 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -729,7 +729,10 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
     def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
         """Test the get_tables method with various parameter combinations."""
         # Mock the execute_command method
-        mock_result_set = Mock()
+        from databricks.sql.result_set import SeaResultSet
+
+        mock_result_set = Mock(spec=SeaResultSet)
+
         with patch.object(
             sea_client, "execute_command", return_value=mock_result_set
         ) as mock_execute:

From e9b1314e28c2898f4d9c32defcf7042d4eb1fada Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Thu, 26 Jun 2025 09:33:54 +0000
Subject: [PATCH 103/105] make SEA backend methods return SeaResultSet

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/sea/backend.py | 36 ++++++++++-------------
 1 file changed, 16 insertions(+), 20 deletions(-)

diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 2cd1c98c2..83255f79b 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import time
 import re
@@ -15,7 +17,7 @@

 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
-    from databricks.sql.result_set import ResultSet
+    from databricks.sql.result_set import SeaResultSet

 from databricks.sql.backend.databricks_client import DatabricksClient
 from databricks.sql.backend.types import (
@@ -401,12 +403,12 @@ def execute_command(
         max_rows: int,
         max_bytes: int,
         lz4_compression: bool,
-        cursor: "Cursor",
+        cursor: Cursor,
         use_cloud_fetch: bool,
         parameters: List[Dict[str, Any]],
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
-    ) -> Union["ResultSet", None]:
+    ) -> Union[SeaResultSet, None]:
         """
         Execute a SQL command using the SEA backend.

@@ -573,8 +575,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
     def get_execution_result(
         self,
         command_id: CommandId,
-        cursor: "Cursor",
-    ) -> "ResultSet":
+        cursor: Cursor,
+    ) -> SeaResultSet:
         """
         Get the result of a command execution.
@@ -583,7 +585,7 @@ def get_execution_result(
             cursor: Cursor executing the command

         Returns:
-            ResultSet: A SeaResultSet instance with the execution results
+            SeaResultSet: A SeaResultSet instance with the execution results

         Raises:
             ValueError: If the command ID is invalid
@@ -627,8 +629,8 @@ def get_catalogs(
         session_id: SessionId,
         max_rows: int,
         max_bytes: int,
-        cursor: "Cursor",
-    ) -> "ResultSet":
+        cursor: Cursor,
+    ) -> SeaResultSet:
         """Get available catalogs by executing 'SHOW CATALOGS'."""
         result = self.execute_command(
             operation=MetadataCommands.SHOW_CATALOGS.value,
@@ -650,10 +652,10 @@ def get_schemas(
         session_id: SessionId,
         max_rows: int,
         max_bytes: int,
-        cursor: "Cursor",
+        cursor: Cursor,
         catalog_name: Optional[str] = None,
         schema_name: Optional[str] = None,
-    ) -> "ResultSet":
+    ) -> SeaResultSet:
         """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'."""
         if not catalog_name:
             raise ValueError("Catalog name is required for get_schemas")
@@ -683,12 +685,12 @@ def get_tables(
         session_id: SessionId,
         max_rows: int,
         max_bytes: int,
-        cursor: "Cursor",
+        cursor: Cursor,
         catalog_name: Optional[str] = None,
         schema_name: Optional[str] = None,
         table_name: Optional[str] = None,
         table_types: Optional[List[str]] = None,
-    ) -> "ResultSet":
+    ) -> SeaResultSet:
         """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'."""
         operation = (
             MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value
@@ -718,12 +720,6 @@ def get_tables(
         )
         assert result is not None, "execute_command returned None in synchronous mode"

-        from databricks.sql.result_set import SeaResultSet
-
-        assert isinstance(
-            result, SeaResultSet
-        ), "execute_command returned a non-SeaResultSet"
-
         # Apply client-side filtering by table_types
         from databricks.sql.backend.sea.utils.filters import ResultSetFilter

@@ -736,12 +732,12 @@ def get_columns(
         session_id: SessionId,
         max_rows: int,
         max_bytes: int,
-        cursor: "Cursor",
+        cursor: Cursor,
         catalog_name: Optional[str] = None,
         schema_name: Optional[str] = None,
         table_name: Optional[str] = None,
         column_name: Optional[str] = None,
-    ) -> "ResultSet":
+    ) -> SeaResultSet:
         """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'."""
         if not catalog_name:
             raise ValueError("Catalog name is required for get_columns")

From 8ede414f8ac485f4e9ed83b49af7087b106d0175 Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Thu, 26 Jun 2025 09:38:33 +0000
Subject: [PATCH 104/105] use spec-aligned Exceptions in SEA backend

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/sea/backend.py | 26 +++++++++++------------
 tests/unit/test_sea_backend.py            | 17 ++++++++-------
 2 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 83255f79b..bfc0c6c9e 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -27,7 +27,7 @@
     BackendType,
     ExecuteResponse,
 )
-from databricks.sql.exc import DatabaseError, ServerOperationError
+from databricks.sql.exc import DatabaseError, ProgrammingError, ServerOperationError
 from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
 from databricks.sql.types import SSLOptions

@@ -172,7 +172,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
                 f"Note: SEA only works for warehouses."
             )
             logger.error(error_message)
-            raise ValueError(error_message)
+            raise ProgrammingError(error_message)

     @property
     def max_download_threads(self) -> int:
@@ -244,14 +244,14 @@ def close_session(self, session_id: SessionId) -> None:
             session_id: The session identifier returned by open_session()

         Raises:
-            ValueError: If the session ID is invalid
+            ProgrammingError: If the session ID is invalid
             OperationalError: If there's an error closing the session
         """

        logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)

         if session_id.backend_type != BackendType.SEA:
-            raise ValueError("Not a valid SEA session ID")
+            raise ProgrammingError("Not a valid SEA session ID")

         sea_session_id = session_id.to_sea_session_id()
         request_data = DeleteSessionRequest(
@@ -429,7 +429,7 @@ def execute_command(
         """

         if session_id.backend_type != BackendType.SEA:
-            raise ValueError("Not a valid SEA session ID")
+            raise ProgrammingError("Not a valid SEA session ID")

         sea_session_id = session_id.to_sea_session_id()

@@ -504,11 +504,11 @@ def cancel_command(self, command_id: CommandId) -> None:
             command_id: Command identifier to cancel

         Raises:
-            ValueError: If the command ID is invalid
+            ProgrammingError: If the command ID is invalid
         """

         if command_id.backend_type != BackendType.SEA:
-            raise ValueError("Not a valid SEA command ID")
+            raise ProgrammingError("Not a valid SEA command ID")

         sea_statement_id = command_id.to_sea_statement_id()

@@ -527,11 +527,11 @@ def close_command(self, command_id: CommandId) -> None:
             command_id: Command identifier to close

         Raises:
-            ValueError: If the command ID is invalid
+            ProgrammingError: If the command ID is invalid
         """

         if command_id.backend_type != BackendType.SEA:
-            raise ValueError("Not a valid SEA command ID")
+            raise ProgrammingError("Not a valid SEA command ID")

         sea_statement_id = command_id.to_sea_statement_id()

@@ -553,7 +553,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
             CommandState: The current state of the command

         Raises:
-            ValueError: If the command ID is invalid
+            ProgrammingError: If the command ID is invalid
         """

         if command_id.backend_type != BackendType.SEA:
@@ -592,7 +592,7 @@ def get_execution_result(
         """

         if command_id.backend_type != BackendType.SEA:
-            raise ValueError("Not a valid SEA command ID")
+            raise ProgrammingError("Not a valid SEA command ID")

         sea_statement_id = command_id.to_sea_statement_id()

@@ -658,7 +658,7 @@ def get_schemas(
     ) -> SeaResultSet:
         """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'."""
         if not catalog_name:
-            raise ValueError("Catalog name is required for get_schemas")
+            raise DatabaseError("Catalog name is required for get_schemas")

         operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name)

@@ -740,7 +740,7 @@ def get_columns(
     ) -> SeaResultSet:
         """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'."""
         if not catalog_name:
-            raise ValueError("Catalog name is required for get_columns")
+            raise DatabaseError("Catalog name is required for get_columns")

         operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name)

diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index 68dea3d81..6847cded0 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -18,6 +18,7 @@
 from databricks.sql.exc import (
     Error,
     NotSupportedError,
+    ProgrammingError,
     ServerOperationError,
     DatabaseError,
 )
@@ -129,7 +130,7 @@ def test_initialization(self, mock_http_client):
         assert client3.max_download_threads == 5

         # Test with invalid HTTP path
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(ProgrammingError) as excinfo:
             SeaDatabricksClient(
                 server_hostname="test-server.databricks.com",
                 port=443,
@@ -195,7 +196,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
         )

         # Test close_session with invalid ID type
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(ProgrammingError) as excinfo:
             sea_client.close_session(thrift_session_id)
         assert "Not a valid SEA session ID" in str(excinfo.value)

@@ -244,7 +245,7 @@ def test_command_execution_sync(
         assert cmd_id_arg.guid == "test-statement-123"

         # Test with invalid session ID
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(ProgrammingError) as excinfo:
             mock_thrift_handle = MagicMock()
             mock_thrift_handle.sessionId.guid = b"guid"
             mock_thrift_handle.sessionId.secret = b"secret"
@@ -448,7 +449,7 @@ def test_command_management(
         )

         # Test cancel_command with invalid ID
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(ProgrammingError) as excinfo:
             sea_client.cancel_command(thrift_command_id)
         assert "Not a valid SEA command ID" in str(excinfo.value)

@@ -462,7 +463,7 @@ def test_command_management(
         )

         # Test close_command with invalid ID
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(ProgrammingError) as excinfo:
             sea_client.close_command(thrift_command_id)
         assert "Not a valid SEA command ID" in str(excinfo.value)

@@ -521,7 +522,7 @@ def test_command_management(
         assert result.status == CommandState.SUCCEEDED

         # Test get_execution_result with invalid ID
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(ProgrammingError) as excinfo:
             sea_client.get_execution_result(thrift_command_id, mock_cursor)
         assert "Not a valid SEA command ID" in str(excinfo.value)

@@ -717,7 +718,7 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
         )

         # Case 3: Without catalog name (should raise DatabaseError)
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(DatabaseError) as excinfo:
             sea_client.get_schemas(
                 session_id=sea_session_id,
                 max_rows=100,
@@ -868,7 +869,7 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
         )

         # Case 3: Without catalog name (should raise DatabaseError)
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(DatabaseError) as excinfo:
             sea_client.get_columns(
                 session_id=sea_session_id,
                 max_rows=100,

From 09a1b11865ef9bad7d0ae5e510aede2b375f1beb Mon Sep 17 00:00:00 2001
From: varun-edachali-dbx
Date: Thu, 26 Jun 2025 09:51:38 +0000
Subject: [PATCH 105/105] remove defensive row type check

Signed-off-by: varun-edachali-dbx
---
 src/databricks/sql/backend/sea/utils/filters.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py
index db6a12e16..1b7660829 100644
--- a/src/databricks/sql/backend/sea/utils/filters.py
+++ b/src/databricks/sql/backend/sea/utils/filters.py
@@ -112,7 +112,6 @@ def filter_by_column_values(
             result_set,
             lambda row: (
                 len(row) > column_index
-                and isinstance(row[column_index], str)
                 and (
                     row[column_index].upper()
                     if not case_sensitive
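
Note on the exception changes in PATCH 104: callers that previously caught ValueError around close_session(), execute_command(), or the metadata helpers should switch to the PEP 249 classes exported by databricks.sql.exc (the same Error, DatabaseError, and ProgrammingError the updated tests import). A minimal sketch of what this buys calling code; describe_failure is a hypothetical helper, not part of this patch series, and it assumes only PEP 249's rule that ProgrammingError is a subclass of DatabaseError, which is itself a subclass of Error:

from databricks.sql.exc import DatabaseError, Error, ProgrammingError

def describe_failure(exc: Error) -> str:
    # Check the more specific subclass first: under PEP 249,
    # ProgrammingError is itself a kind of DatabaseError.
    if isinstance(exc, ProgrammingError):
        return f"client-side usage error: {exc}"  # e.g. invalid SEA session/command ID
    if isinstance(exc, DatabaseError):
        return f"database error: {exc}"  # e.g. missing catalog name
    return f"driver error: {exc}"

A broad `except Error` handler keeps working unchanged, since both new exception types subclass it.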