diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index a776377c3..3c0e325fe 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -51,12 +51,20 @@ def test_sea_async_query_with_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query asynchronously + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 5000 cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + logger.info( - "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + f"Executing asynchronous query with cloud fetch to generate {requested_row_count} rows" ) - cursor.execute_async("SELECT 1 as test_value") + cursor.execute_async(query) logger.info( "Asynchronous query submitted successfully with cloud fetch enabled" ) @@ -69,8 +77,25 @@ def test_sea_async_query_with_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}" + ) + return False + logger.info( - "Successfully retrieved asynchronous query results with cloud fetch enabled" + "PASS: Received correct number of rows with cloud fetch and all fetch methods work correctly" ) # Close resources @@ -130,12 +155,20 @@ def test_sea_async_query_without_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query asynchronously + # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits + requested_row_count = 100 cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 100)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + logger.info( - "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + f"Executing asynchronous query without cloud fetch to generate {requested_row_count} rows" ) - cursor.execute_async("SELECT 1 as test_value") + cursor.execute_async(query) logger.info( "Asynchronous query submitted successfully with cloud fetch disabled" ) @@ -148,8 +181,24 @@ def test_sea_async_query_without_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. 
Expected {requested_row_count}, got {actual_row_count}" + ) + return False + logger.info( - "Successfully retrieved asynchronous query results with cloud fetch disabled" + "PASS: Received correct number of rows without cloud fetch and all fetch methods work correctly" ) # Close resources diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 07be8aafc..76941e2d2 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -49,13 +49,27 @@ def test_sea_sync_query_with_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 10000 cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + + logger.info( + f"Executing synchronous query with cloud fetch to generate {requested_row_count} rows" + ) + cursor.execute(query) + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) logger.info( - "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + f"{actual_row_count} rows retrieved against {requested_row_count} requested" ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch enabled") # Close resources cursor.close() @@ -114,13 +128,18 @@ def test_sea_sync_query_without_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query + # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits + requested_row_count = 100 cursor = connection.cursor() - logger.info( - "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + logger.info("Executing synchronous query without cloud fetch: SELECT 100 rows") + cursor.execute( + "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch disabled") + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") # Close resources cursor.close() diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 0c0400ae2..2ed248c3d 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -616,10 +616,10 @@ def get_execution_result( connection=cursor.connection, execute_response=execute_response, sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, result_data=response.result, manifest=response.manifest, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == diff --git a/src/databricks/sql/backend/sea/conversion.py b/src/databricks/sql/backend/sea/conversion.py new file mode 100644 index 000000000..a3edd6dcc --- /dev/null +++ b/src/databricks/sql/backend/sea/conversion.py @@ -0,0 +1,135 @@ +""" +Type conversion utilities for the Databricks SQL Connector. + +This module provides functionality to convert string values from SEA Inline results +to appropriate Python types based on column metadata. 
+""" + +import datetime +import decimal +import logging +from dateutil import parser +from typing import Any, Callable, Dict, Optional, Union + +logger = logging.getLogger(__name__) + + +class SqlType: + """ + SQL type constants + + The list of types can be found in the SEA REST API Reference: + https://docs.databricks.com/api/workspace/statementexecution/executestatement + """ + + # Numeric types + BYTE = "byte" + SHORT = "short" + INT = "int" + LONG = "long" + FLOAT = "float" + DOUBLE = "double" + DECIMAL = "decimal" + + # Boolean type + BOOLEAN = "boolean" + + # Date/Time types + DATE = "date" + TIMESTAMP = "timestamp" + INTERVAL = "interval" + + # String types + CHAR = "char" + STRING = "string" + + # Binary type + BINARY = "binary" + + # Complex types + ARRAY = "array" + MAP = "map" + STRUCT = "struct" + + # Other types + NULL = "null" + USER_DEFINED_TYPE = "user_defined_type" + + +class SqlTypeConverter: + """ + Utility class for converting SQL types to Python types. + Based on the types supported by the Databricks SDK. + """ + + # SQL type to conversion function mapping + # TODO: complex types + TYPE_MAPPING: Dict[str, Callable] = { + # Numeric types + SqlType.BYTE: lambda v: int(v), + SqlType.SHORT: lambda v: int(v), + SqlType.INT: lambda v: int(v), + SqlType.LONG: lambda v: int(v), + SqlType.FLOAT: lambda v: float(v), + SqlType.DOUBLE: lambda v: float(v), + SqlType.DECIMAL: lambda v, p=None, s=None: ( + decimal.Decimal(v).quantize( + decimal.Decimal(f'0.{"0" * s}'), context=decimal.Context(prec=p) + ) + if p is not None and s is not None + else decimal.Decimal(v) + ), + # Boolean type + SqlType.BOOLEAN: lambda v: v.lower() in ("true", "t", "1", "yes", "y"), + # Date/Time types + SqlType.DATE: lambda v: datetime.date.fromisoformat(v), + SqlType.TIMESTAMP: lambda v: parser.parse(v), + SqlType.INTERVAL: lambda v: v, # Keep as string for now + # String types - no conversion needed + SqlType.CHAR: lambda v: v, + SqlType.STRING: lambda v: v, + # Binary type + SqlType.BINARY: lambda v: bytes.fromhex(v), + # Other types + SqlType.NULL: lambda v: None, + # Complex types and user-defined types return as-is + SqlType.USER_DEFINED_TYPE: lambda v: v, + } + + @staticmethod + def convert_value( + value: Any, + sql_type: str, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> Any: + """ + Convert a string value to the appropriate Python type based on SQL type. 
+ + Args: + value: The string value to convert + sql_type: The SQL type (e.g., 'int', 'decimal') + precision: Optional precision for decimal types + scale: Optional scale for decimal types + + Returns: + The converted value in the appropriate Python type + """ + + if value is None: + return None + + sql_type = sql_type.lower().strip() + + if sql_type not in SqlTypeConverter.TYPE_MAPPING: + return value + + converter_func = SqlTypeConverter.TYPE_MAPPING[sql_type] + try: + if sql_type == SqlType.DECIMAL: + return converter_func(value, precision, scale) + else: + return converter_func(value) + except (ValueError, TypeError, decimal.InvalidOperation) as e: + logger.warning(f"Error converting value '{value}' to {sql_type}: {e}") + return value diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 1b7660829..f3bf4669a 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -77,9 +77,9 @@ def _filter_sea_result_set( connection=result_set.connection, execute_response=execute_response, sea_client=cast(SeaDatabricksClient, result_set.backend), + result_data=result_data, buffer_size_bytes=result_set.buffer_size_bytes, arraysize=result_set.arraysize, - result_data=result_data, ) return filtered_result_set diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e824de1c2..02d335aa4 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -40,11 +40,11 @@ ) from databricks.sql.utils import ( - ResultSetQueueFactory, + ThriftResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, - ResultSetQueueFactory, + ThriftResultSetQueueFactory, convert_arrow_based_set_to_arrow_table, convert_decimals_in_arrow_table, convert_column_based_set_to_arrow_table, @@ -1232,7 +1232,7 @@ def fetch_results( ) ) - queue = ResultSetQueueFactory.build_queue( + queue = ThriftResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, arrow_schema_bytes=arrow_schema_bytes, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 38b8a3c2f..a4814db57 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,11 +1,14 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING import logging -import time import pandas from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest +from databricks.sql.backend.sea.conversion import SqlTypeConverter try: import pyarrow @@ -16,10 +19,14 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row -from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.exc import RequestError, CursorAlreadyClosedError +from databricks.sql.utils import ( + ColumnTable, + ColumnQueue, + JsonQueue, + SeaResultSetQueueFactory, +) from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = 
logging.getLogger(__name__) @@ -88,6 +95,44 @@ def __iter__(self): else: break + def _convert_arrow_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + + if self.connection.disable_pandas is True: + return [ + ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) + ] + + # Need to use nullable types, as otherwise type can change when there are missing values. + # See https://arrow.apache.org/docs/python/pandas.html#nullable-types + # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + # Need to rename columns, as the to_pandas function cannot handle duplicate column names + table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + df = table_renamed.to_pandas( + types_mapper=dtype_mapping.get, + date_as_object=True, + timestamp_as_object=True, + ) + + res = df.to_numpy(na_value=None, dtype="object") + return [ResultRow(*v) for v in res] + @property def rownumber(self): return self._next_row_index @@ -97,12 +142,6 @@ def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" return self._is_staging_operation - # Define abstract methods that concrete implementations must implement - @abstractmethod - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - pass - @abstractmethod def fetchone(self) -> Optional[Row]: """Fetch the next row of a query result set.""" @@ -189,10 +228,10 @@ def __init__( # Build the results queue if t_row_set is provided results_queue = None if t_row_set and execute_response.result_format is not None: - from databricks.sql.utils import ResultSetQueueFactory + from databricks.sql.utils import ThriftResultSetQueueFactory # Create the results queue using the provided format - results_queue = ResultSetQueueFactory.build_queue( + results_queue = ThriftResultSetQueueFactory.build_queue( row_set_type=execute_response.result_format, t_row_set=t_row_set, arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", @@ -249,44 +288,6 @@ def _convert_columnar_table(self, table): return result - def _convert_arrow_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - - if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] - - # Need to use nullable types, as otherwise type can change when there are missing values. 
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result @@ -451,13 +452,13 @@ class SeaResultSet(ResultSet): def __init__( self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", + connection: Connection, + execute_response: ExecuteResponse, + sea_client: SeaDatabricksClient, + result_data: ResultData, + manifest: Optional[ResultManifest] = None, buffer_size_bytes: int = 104857600, arraysize: int = 10000, - result_data=None, - manifest=None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -468,10 +469,21 @@ def __init__( sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch - result_data: Result data from SEA response (optional) - manifest: Manifest from SEA response (optional) + result_data: Result data from SEA response + manifest: Manifest from SEA response """ + results_queue = SeaResultSetQueueFactory.build_queue( + result_data, + manifest, + str(execute_response.command_id.to_sea_statement_id()), + description=execute_response.description, + max_download_threads=sea_client.max_download_threads, + sea_client=sea_client, + lz4_compressed=execute_response.lz4_compressed, + ) + + # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, @@ -480,46 +492,201 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, + results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", ) - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError( - "_fill_results_buffer is not implemented for SEA backend" - ) + def _convert_json_to_arrow(self, rows: List) -> "pyarrow.Table": + """ + Convert raw data rows to Arrow table. 
+ """ + if not rows: + return pyarrow.Table.from_pydict({}) + + columns = [] + num_cols = len(rows[0]) + for i in range(num_cols): + columns.append([row[i] for row in rows]) + names = [col[0] for col in self.description] + return pyarrow.Table.from_arrays(columns, names=names) + + def _convert_json_types(self, rows: List) -> List: + """ + Convert raw data rows to Row objects with named columns based on description. + Also converts string values to appropriate Python types based on column metadata. + """ + + if not self.description or not rows: + return rows + + # JSON + INLINE gives us string values, so we convert them to appropriate + # types based on column metadata + converted_rows = [] + for row in rows: + converted_row = [] + + for i, value in enumerate(row): + column_type = self.description[i][1] + precision = self.description[i][4] + scale = self.description[i][5] + + try: + converted_value = SqlTypeConverter.convert_value( + value, column_type, precision=precision, scale=scale + ) + converted_row.append(converted_value) + except Exception as e: + logger.warning( + f"Error converting value '{value}' to {column_type}: {e}" + ) + converted_row.append(value) + + converted_rows.append(converted_row) + + return converted_rows + + def _create_json_table(self, rows: List) -> List[Row]: + """ + Convert raw data rows to Row objects with named columns based on description. + Also converts string values to appropriate Python types based on column metadata. + + Args: + rows: List of raw data rows + Returns: + List of Row objects with named columns and converted values + """ + + if not self.description or not rows: + return rows + + ResultRow = Row(*[col[0] for col in self.description]) + rows = self._convert_json_types(rows) + + return [ResultRow(*row) for row in rows] + + def fetchmany_json(self, size: int) -> List: + """ + Fetch the next set of rows as a columnar table. + + Args: + size: Number of rows to fetch + + Returns: + Columnar table containing the fetched rows + + Raises: + ValueError: If size is negative + """ + + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + + results = self.results.next_n_rows(size) + self._next_row_index += len(results) + + return results + + def fetchall_json(self) -> List: + """ + Fetch all remaining rows as a columnar table. + + Returns: + Columnar table containing all remaining rows + """ + + results = self.results.remaining_rows() + self._next_row_index += len(results) + + return results + + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """ + Fetch the next set of rows as an Arrow table. + + Args: + size: Number of rows to fetch + + Returns: + PyArrow Table containing the fetched rows + + Raises: + ImportError: If PyArrow is not installed + ValueError: If size is negative + """ + + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + + if not isinstance(self.results, JsonQueue): + raise NotImplementedError("fetchmany_arrow only supported for JSON data") + + rows = self._convert_json_types(self.results.next_n_rows(size)) + results = self._convert_json_to_arrow(rows) + self._next_row_index += results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": + """ + Fetch all remaining rows as an Arrow table. 
+ """ + + if not isinstance(self.results, JsonQueue): + raise NotImplementedError("fetchall_arrow only supported for JSON data") + + rows = self._convert_json_types(self.results.remaining_rows()) + results = self._convert_json_to_arrow(rows) + self._next_row_index += results.num_rows + + return results def fetchone(self) -> Optional[Row]: """ Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. + + Returns: + A single Row object or None if no more rows are available """ - raise NotImplementedError("fetchone is not implemented for SEA backend") + if isinstance(self.results, JsonQueue): + res = self._create_json_table(self.fetchmany_json(1)) + else: + raise NotImplementedError("fetchone only supported for JSON data") + + return res[0] if res else None - def fetchmany(self, size: Optional[int] = None) -> List[Row]: + def fetchmany(self, size: int) -> List[Row]: """ Fetch the next set of rows of a query result, returning a list of rows. - An empty sequence is returned when no more rows are available. + Args: + size: Number of rows to fetch (defaults to arraysize if None) + + Returns: + List of Row objects + + Raises: + ValueError: If size is negative """ - raise NotImplementedError("fetchmany is not implemented for SEA backend") + if isinstance(self.results, JsonQueue): + return self._create_json_table(self.fetchmany_json(size)) + else: + raise NotImplementedError("fetchmany only supported for JSON data") def fetchall(self) -> List[Row]: """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ + Fetch all remaining rows of a query result, returning them as a list of rows. - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + Returns: + List of Row objects containing all remaining rows + """ - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") + if isinstance(self.results, JsonQueue): + return self._create_json_table(self.fetchall_json()) + else: + raise NotImplementedError("fetchall only supported for JSON data") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 8997bda22..fcf39df33 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -13,12 +13,16 @@ import lz4.frame +from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest + try: import pyarrow except ImportError: pyarrow = None -from databricks.sql import OperationalError, exc +from databricks.sql import OperationalError +from databricks.sql.exc import ProgrammingError from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager from databricks.sql.thrift_api.TCLIService.ttypes import ( TRowSet, @@ -48,7 +52,7 @@ def remaining_rows(self): pass -class ResultSetQueueFactory(ABC): +class ThriftResultSetQueueFactory(ABC): @staticmethod def build_queue( row_set_type: TSparkRowSetType, @@ -107,6 +111,70 @@ def build_queue( raise AssertionError("Row set type is not valid") +class SeaResultSetQueueFactory(ABC): + @staticmethod + def build_queue( + sea_result_data: ResultData, + manifest: Optional[ResultManifest], + statement_id: str, + description: Optional[List[Tuple[Any, ...]]] = 
None, + schema_bytes: Optional[bytes] = None, + max_download_threads: Optional[int] = None, + sea_client: Optional[SeaDatabricksClient] = None, + lz4_compressed: bool = False, + ) -> ResultSetQueue: + """ + Factory method to build a result set queue for SEA backend. + + Args: + sea_result_data (ResultData): Result data from SEA response + manifest (ResultManifest): Manifest from SEA response + statement_id (str): Statement ID for the query + description (List[List[Any]]): Column descriptions + schema_bytes (bytes): Arrow schema bytes + max_download_threads (int): Maximum number of download threads + ssl_options (SSLOptions): SSL options for downloads + sea_client (SeaDatabricksClient): SEA client for fetching additional links + lz4_compressed (bool): Whether the data is LZ4 compressed + + Returns: + ResultSetQueue: The appropriate queue for the result data + """ + + if sea_result_data.data is not None: + # INLINE disposition with JSON_ARRAY format + return JsonQueue(sea_result_data.data) + elif sea_result_data.external_links is not None: + # EXTERNAL_LINKS disposition + raise NotImplementedError( + "EXTERNAL_LINKS disposition is not implemented for SEA backend" + ) + return JsonQueue([]) + + +class JsonQueue(ResultSetQueue): + """Queue implementation for JSON_ARRAY format data.""" + + def __init__(self, data_array): + """Initialize with JSON array data.""" + self.data_array = data_array + self.cur_row_index = 0 + self.n_valid_rows = len(data_array) + + def next_n_rows(self, num_rows): + """Get the next n rows from the data array.""" + length = min(num_rows, self.n_valid_rows - self.cur_row_index) + slice = self.data_array[self.cur_row_index : self.cur_row_index + length] + self.cur_row_index += length + return slice + + def remaining_rows(self): + """Get all remaining rows from the data array.""" + slice = self.data_array[self.cur_row_index :] + self.cur_row_index += len(slice) + return slice + + class ColumnTable: def __init__(self, column_table, column_names): self.column_table = column_table diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index a3f9b1af8..5848d780b 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -196,10 +196,21 @@ def test_execute_async__long_running(self): assert result[0].asDict() == {"count(1)": 0} - def test_execute_async__small_result(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_execute_async__small_result(self, extra_params): small_result_query = "SELECT 1" - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: cursor.execute_async(small_result_query) ## Fake sleep for 5 secs @@ -328,8 +339,19 @@ def test_incorrect_query_throws_exception(self): cursor.execute("CREATE TABLE IF NOT EXISTS TABLE table_234234234") assert "table_234234234" in str(cm.value) - def test_create_table_will_return_empty_result_set(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_create_table_will_return_empty_result_set(self, extra_params): + with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) try: cursor.execute( @@ -527,10 +549,21 @@ def test_get_catalogs(self): ] @skipUnless(pysql_supports_arrow(), "arrow test need arrow support") - def test_get_arrow(self): 
+ @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_get_arrow(self, extra_params): # These tests are quite light weight as the arrow fetch methods are used internally # by everything else - with self.cursor({}) as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT * FROM range(10)") table_1 = cursor.fetchmany_arrow(1).to_pydict() assert table_1 == OrderedDict([("id", [0])]) @@ -538,9 +571,20 @@ def test_get_arrow(self): table_2 = cursor.fetchall_arrow().to_pydict() assert table_2 == OrderedDict([("id", [1, 2, 3, 4, 5, 6, 7, 8, 9])]) - def test_unicode(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_unicode(self, extra_params): unicode_str = "数据砖" - with self.cursor({}) as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT '{}'".format(unicode_str)) results = cursor.fetchall() assert len(results) == 1 and len(results[0]) == 1 @@ -578,8 +622,19 @@ def execute_really_long_query(): assert len(cursor.fetchall()) == 3 @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_can_execute_command_after_failure(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_can_execute_command_after_failure(self, extra_params): + with self.cursor(extra_params) as cursor: with pytest.raises(DatabaseError): cursor.execute("this is a sytnax error") @@ -589,8 +644,19 @@ def test_can_execute_command_after_failure(self): self.assertEqualRowValues(res, [[1]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_can_execute_command_after_success(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_can_execute_command_after_success(self, extra_params): + with self.cursor(extra_params) as cursor: cursor.execute("SELECT 1;") cursor.execute("SELECT 2;") @@ -602,8 +668,19 @@ def generate_multi_row_query(self): return query @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchone(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_fetchone(self, extra_params): + with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() cursor.execute(query) @@ -614,8 +691,19 @@ def test_fetchone(self): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchall(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_fetchall(self, extra_params): + with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() cursor.execute(query) @@ -624,8 +712,19 @@ def test_fetchall(self): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchmany_when_stride_fits(self): - with self.cursor({}) as 
cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_fetchmany_when_stride_fits(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -633,8 +732,19 @@ def test_fetchmany_when_stride_fits(self): self.assertEqualRowValues(cursor.fetchmany(2), [[2], [3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchmany_in_excess(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_fetchmany_in_excess(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -642,8 +752,19 @@ def test_fetchmany_in_excess(self): self.assertEqualRowValues(cursor.fetchmany(3), [[3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_iterator_api(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_iterator_api(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -716,8 +837,21 @@ def test_timestamps_arrow(self): ), "timestamp {} did not match {}".format(timestamp, expected) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - def test_multi_timestamps_arrow(self): - with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_multi_timestamps_arrow(self, extra_params): + with self.cursor( + {"session_configuration": {"ansi_mode": False}, **extra_params} + ) as cursor: query, expected = self.multi_query() expected = [ [self.maybe_add_timezone_to_timestamp(ts) for ts in row] diff --git a/tests/unit/test_sea_conversion.py b/tests/unit/test_sea_conversion.py new file mode 100644 index 000000000..738889975 --- /dev/null +++ b/tests/unit/test_sea_conversion.py @@ -0,0 +1,136 @@ +""" +Tests for the conversion module in the SEA backend. + +This module contains tests for the SqlType and SqlTypeConverter classes. 
+""" + +import pytest +import datetime +import decimal +from unittest.mock import Mock, patch + +from databricks.sql.backend.sea.conversion import SqlType, SqlTypeConverter + + +class TestSqlTypeConverter: + """Test suite for the SqlTypeConverter class.""" + + def test_convert_value_null(self): + """Test converting null values.""" + assert SqlTypeConverter.convert_value(None, SqlType.INT) is None + assert SqlTypeConverter.convert_value(None, SqlType.STRING) is None + assert SqlTypeConverter.convert_value(None, SqlType.BOOLEAN) is None + + def test_convert_numeric_types(self): + """Test converting numeric types.""" + # Test integer types + assert SqlTypeConverter.convert_value("123", SqlType.BYTE) == 123 + assert SqlTypeConverter.convert_value("456", SqlType.SHORT) == 456 + assert SqlTypeConverter.convert_value("789", SqlType.INT) == 789 + assert SqlTypeConverter.convert_value("1234567890", SqlType.LONG) == 1234567890 + + # Test floating point types + assert SqlTypeConverter.convert_value("123.45", SqlType.FLOAT) == 123.45 + assert SqlTypeConverter.convert_value("678.90", SqlType.DOUBLE) == 678.90 + + # Test decimal type + decimal_value = SqlTypeConverter.convert_value("123.45", SqlType.DECIMAL) + assert isinstance(decimal_value, decimal.Decimal) + assert decimal_value == decimal.Decimal("123.45") + + # Test decimal with precision and scale + decimal_value = SqlTypeConverter.convert_value( + "123.45", SqlType.DECIMAL, precision=5, scale=2 + ) + assert isinstance(decimal_value, decimal.Decimal) + assert decimal_value == decimal.Decimal("123.45") + + # Test invalid numeric input + result = SqlTypeConverter.convert_value("not_a_number", SqlType.INT) + assert result == "not_a_number" # Returns original value on error + + def test_convert_boolean_type(self): + """Test converting boolean types.""" + # True values + assert SqlTypeConverter.convert_value("true", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("True", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("t", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("1", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("yes", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("y", SqlType.BOOLEAN) is True + + # False values + assert SqlTypeConverter.convert_value("false", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("False", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("f", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("0", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("no", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("n", SqlType.BOOLEAN) is False + + def test_convert_datetime_types(self): + """Test converting datetime types.""" + # Test date type + date_value = SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE) + assert isinstance(date_value, datetime.date) + assert date_value == datetime.date(2023, 1, 15) + + # Test timestamp type + timestamp_value = SqlTypeConverter.convert_value( + "2023-01-15T12:30:45", SqlType.TIMESTAMP + ) + assert isinstance(timestamp_value, datetime.datetime) + assert timestamp_value.year == 2023 + assert timestamp_value.month == 1 + assert timestamp_value.day == 15 + assert timestamp_value.hour == 12 + assert timestamp_value.minute == 30 + assert timestamp_value.second == 45 + + # Test interval type (currently returns as string) + interval_value = SqlTypeConverter.convert_value( + "1 day 2 hours", SqlType.INTERVAL + ) + assert 
interval_value == "1 day 2 hours" + + # Test invalid date input + result = SqlTypeConverter.convert_value("not_a_date", SqlType.DATE) + assert result == "not_a_date" # Returns original value on error + + def test_convert_string_types(self): + """Test converting string types.""" + # String types don't need conversion, they should be returned as-is + assert ( + SqlTypeConverter.convert_value("test string", SqlType.STRING) + == "test string" + ) + assert SqlTypeConverter.convert_value("test char", SqlType.CHAR) == "test char" + + def test_convert_binary_type(self): + """Test converting binary type.""" + # Test valid hex string + binary_value = SqlTypeConverter.convert_value("48656C6C6F", SqlType.BINARY) + assert isinstance(binary_value, bytes) + assert binary_value == b"Hello" + + # Test invalid binary input + result = SqlTypeConverter.convert_value("not_hex", SqlType.BINARY) + assert result == "not_hex" # Returns original value on error + + def test_convert_unsupported_type(self): + """Test converting an unsupported type.""" + # Should return the original value + assert SqlTypeConverter.convert_value("test", "unsupported_type") == "test" + + # Complex types should return as-is + assert ( + SqlTypeConverter.convert_value("complex_value", SqlType.ARRAY) + == "complex_value" + ) + assert ( + SqlTypeConverter.convert_value("complex_value", SqlType.MAP) + == "complex_value" + ) + assert ( + SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT) + == "complex_value" + ) diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py new file mode 100644 index 000000000..4a4dee8f5 --- /dev/null +++ b/tests/unit/test_sea_queue.py @@ -0,0 +1,172 @@ +""" +Tests for SEA-related queue classes in utils.py. + +This module contains tests for the JsonQueue and SeaResultSetQueueFactory classes. 
+""" + +import pytest +from unittest.mock import Mock, MagicMock, patch + +from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest + + +class TestJsonQueue: + """Test suite for the JsonQueue class.""" + + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + return [ + ["value1", 1, True], + ["value2", 2, False], + ["value3", 3, True], + ["value4", 4, False], + ["value5", 5, True], + ] + + def test_init(self, sample_data): + """Test initialization of JsonQueue.""" + queue = JsonQueue(sample_data) + assert queue.data_array == sample_data + assert queue.cur_row_index == 0 + assert queue.n_valid_rows == len(sample_data) + + def test_next_n_rows_partial(self, sample_data): + """Test fetching a subset of rows.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(2) + assert result == sample_data[:2] + assert queue.cur_row_index == 2 + + def test_next_n_rows_all(self, sample_data): + """Test fetching all rows.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(len(sample_data)) + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_next_n_rows_more_than_available(self, sample_data): + """Test fetching more rows than available.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(len(sample_data) + 10) + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_next_n_rows_after_partial(self, sample_data): + """Test fetching rows after a partial fetch.""" + queue = JsonQueue(sample_data) + queue.next_n_rows(2) # Fetch first 2 rows + result = queue.next_n_rows(2) # Fetch next 2 rows + assert result == sample_data[2:4] + assert queue.cur_row_index == 4 + + def test_remaining_rows_all(self, sample_data): + """Test fetching all remaining rows at once.""" + queue = JsonQueue(sample_data) + result = queue.remaining_rows() + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_remaining_rows_after_partial(self, sample_data): + """Test fetching remaining rows after a partial fetch.""" + queue = JsonQueue(sample_data) + queue.next_n_rows(2) # Fetch first 2 rows + result = queue.remaining_rows() # Fetch remaining rows + assert result == sample_data[2:] + assert queue.cur_row_index == len(sample_data) + + def test_empty_data(self): + """Test with empty data array.""" + queue = JsonQueue([]) + assert queue.next_n_rows(10) == [] + assert queue.remaining_rows() == [] + assert queue.cur_row_index == 0 + assert queue.n_valid_rows == 0 + + +class TestSeaResultSetQueueFactory: + """Test suite for the SeaResultSetQueueFactory class.""" + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + client = Mock() + client.max_download_threads = 10 + return client + + @pytest.fixture + def mock_description(self): + """Create a mock column description.""" + return [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), + ] + + def test_build_queue_with_inline_data(self, mock_sea_client, mock_description): + """Test building a queue with inline JSON data.""" + # Create sample data for inline JSON result + data = [ + ["value1", "1", "true"], + ["value2", "2", "false"], + ] + + # Create a ResultData object with inline data + result_data = ResultData(data=data, external_links=None, row_count=len(data)) + + # Create a manifest (not used 
for inline data) + manifest = None + + # Build the queue + queue = SeaResultSetQueueFactory.build_queue( + result_data, + manifest, + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) + + # Verify the queue is a JsonQueue with the correct data + assert isinstance(queue, JsonQueue) + assert queue.data_array == data + assert queue.n_valid_rows == len(data) + + def test_build_queue_with_empty_data(self, mock_sea_client, mock_description): + """Test building a queue with empty data.""" + # Create a ResultData object with no data + result_data = ResultData(data=[], external_links=None, row_count=0) + + # Build the queue + queue = SeaResultSetQueueFactory.build_queue( + result_data, + None, + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) + + # Verify the queue is a JsonQueue with empty data + assert isinstance(queue, JsonQueue) + assert queue.data_array == [] + assert queue.n_valid_rows == 0 + + def test_build_queue_with_external_links(self, mock_sea_client, mock_description): + """Test building a queue with external links raises NotImplementedError.""" + # Create a ResultData object with external links + result_data = ResultData( + data=None, external_links=["link1", "link2"], row_count=10 + ) + + # Verify that NotImplementedError is raised + with pytest.raises( + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", + ): + SeaResultSetQueueFactory.build_queue( + result_data, + None, + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index c596dbc14..775b42d13 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -8,8 +8,10 @@ import pytest from unittest.mock import patch, MagicMock, Mock -from databricks.sql.result_set import SeaResultSet +from databricks.sql.result_set import SeaResultSet, Row +from databricks.sql.utils import JsonQueue from databricks.sql.backend.types import CommandId, CommandState, BackendType +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest class TestSeaResultSet: @@ -37,11 +39,55 @@ def execute_response(self): mock_response.is_direct_results = False mock_response.results_queue = None mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), ] mock_response.is_staging_operation = False + mock_response.lz4_compressed = False + mock_response.arrow_schema_bytes = None return mock_response + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + return [ + ["value1", "1", "true"], + ["value2", "2", "false"], + ["value3", "3", "true"], + ["value4", "4", "false"], + ["value5", "5", "true"], + ] + + @pytest.fixture + def result_set_with_data( + self, mock_connection, mock_sea_client, execute_response, sample_data + ): + """Create a SeaResultSet with sample data.""" + # Create ResultData with inline data + result_data = ResultData( + data=sample_data, external_links=None, row_count=len(sample_data) + ) + + # Initialize SeaResultSet with result data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=result_data, + manifest=None, + buffer_size_bytes=1000, + arraysize=100, + ) + 
result_set.results = JsonQueue(sample_data) + + return result_set + + @pytest.fixture + def json_queue(self, sample_data): + """Create a JsonQueue with sample data.""" + return JsonQueue(sample_data) + def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -50,6 +96,7 @@ def test_init_with_execute_response( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) @@ -69,6 +116,7 @@ def test_close(self, mock_connection, mock_sea_client, execute_response): connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) @@ -89,6 +137,7 @@ def test_close_when_already_closed_server_side( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) @@ -111,6 +160,7 @@ def test_close_when_connection_closed( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) @@ -123,79 +173,189 @@ def test_close_when_connection_closed( assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) + def test_init_with_result_data(self, result_set_with_data, sample_data): + """Test initializing SeaResultSet with result data.""" + # Verify the results queue was created correctly + assert isinstance(result_set_with_data.results, JsonQueue) + assert result_set_with_data.results.data_array == sample_data + assert result_set_with_data.results.n_valid_rows == len(sample_data) - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() + def test_convert_json_types(self, result_set_with_data, sample_data): + """Test the _convert_json_types method.""" + # Call _convert_json_types + converted_rows = result_set_with_data._convert_json_types(sample_data) - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) + # Verify the conversion + assert len(converted_rows) == len(sample_data) + assert converted_rows[0][0] == "value1" # string stays as string + assert converted_rows[0][1] == 1 # "1" converted to int + assert converted_rows[0][2] is True # "true" converted to boolean - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() + def test_create_json_table(self, result_set_with_data, sample_data): + """Test the _create_json_table method.""" + # Call _create_json_table + result_rows = result_set_with_data._create_json_table(sample_data) - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() + # Verify the result + assert len(result_rows) == len(sample_data) + assert isinstance(result_rows[0], Row) 
+ assert result_rows[0].col1 == "value1" + assert result_rows[0].col2 == 1 + assert result_rows[0].col3 is True - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) + def test_fetchmany_json(self, result_set_with_data): + """Test the fetchmany_json method.""" + # Test fetching a subset of rows + result = result_set_with_data.fetchmany_json(2) + assert len(result) == 2 + assert result_set_with_data._next_row_index == 2 + + # Test fetching the next subset + result = result_set_with_data.fetchmany_json(2) + assert len(result) == 2 + assert result_set_with_data._next_row_index == 4 + + # Test fetching more than available + result = result_set_with_data.fetchmany_json(10) + assert len(result) == 1 # Only one row left + assert result_set_with_data._next_row_index == 5 + + def test_fetchall_json(self, result_set_with_data, sample_data): + """Test the fetchall_json method.""" + # Test fetching all rows + result = result_set_with_data.fetchall_json() + assert result == sample_data + assert result_set_with_data._next_row_index == len(sample_data) + + # Test fetching again (should return empty) + result = result_set_with_data.fetchall_json() + assert result == [] + assert result_set_with_data._next_row_index == len(sample_data) + + def test_fetchone(self, result_set_with_data): + """Test the fetchone method.""" + # Test fetching one row at a time + row1 = result_set_with_data.fetchone() + assert isinstance(row1, Row) + assert row1.col1 == "value1" + assert row1.col2 == 1 + assert row1.col3 is True + assert result_set_with_data._next_row_index == 1 + + row2 = result_set_with_data.fetchone() + assert isinstance(row2, Row) + assert row2.col1 == "value2" + assert row2.col2 == 2 + assert row2.col3 is False + assert result_set_with_data._next_row_index == 2 + + # Fetch the rest + result_set_with_data.fetchall() + # Test fetching when no more rows + row_none = result_set_with_data.fetchone() + assert row_none is None + + def test_fetchmany(self, result_set_with_data): + """Test the fetchmany method.""" + # Test fetching multiple rows + rows = result_set_with_data.fetchmany(2) + assert len(rows) == 2 + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + assert rows[1].col1 == "value2" + assert rows[1].col2 == 2 + assert rows[1].col3 is False + assert result_set_with_data._next_row_index == 2 + + # Test with invalid size with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", + ValueError, match="size argument for fetchmany is -1 but must be >= 0" ): - result_set.fetchall_arrow() + result_set_with_data.fetchmany(-1) + + def test_fetchall(self, result_set_with_data, sample_data): + """Test the fetchall method.""" + # Test fetching all rows + rows = result_set_with_data.fetchall() + assert len(rows) == len(sample_data) + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + assert result_set_with_data._next_row_index == len(sample_data) + + # Test fetching again (should return empty) + rows = result_set_with_data.fetchall() + assert len(rows) == 0 + + def test_iteration(self, result_set_with_data, sample_data): + """Test iterating over the result set.""" + # Test iteration + rows = list(result_set_with_data) + assert len(rows) == len(sample_data) + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert 
rows[0].col2 == 1 + assert rows[0].col3 is True + + def test_fetchmany_arrow_not_implemented( + self, mock_connection, mock_sea_client, execute_response, sample_data + ): + """Test that fetchmany_arrow raises NotImplementedError for non-JSON data.""" + # Test that NotImplementedError is raised with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) + # Create a result set without JSON data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=None, external_links=[]), + buffer_size_bytes=1000, + arraysize=100, + ) + def test_fetchall_arrow_not_implemented( + self, mock_connection, mock_sea_client, execute_response, sample_data + ): + """Test that fetchall_arrow raises NotImplementedError for non-JSON data.""" + # Test that NotImplementedError is raised with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", ): - # Test using the result set in a for loop - for row in result_set: - pass + # Create a result set without JSON data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=None, external_links=[]), + buffer_size_bytes=1000, + arraysize=100, + ) - def test_fill_results_buffer_not_implemented( + def test_is_staging_operation( self, mock_connection, mock_sea_client, execute_response ): - """Test that _fill_results_buffer raises NotImplementedError.""" + """Test the is_staging_operation property.""" + # Set is_staging_operation to True + execute_response.is_staging_operation = True + + # Create a result set result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) - with pytest.raises( - NotImplementedError, - match="_fill_results_buffer is not implemented for SEA backend", - ): - result_set._fill_results_buffer() + # Test the property + assert result_set.is_staging_operation is True diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57b5e9b58..4a4295e11 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -610,7 +610,8 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): self.assertIn("some information about the error", str(cm.exception)) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) def test_handle_execute_response_sets_compression_in_direct_results( self, build_queue @@ -998,7 +999,8 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( @@ -1043,7 +1045,8 @@ def 
test_handle_execute_response_reads_has_more_rows_in_direct_results( self.assertEqual(is_direct_results, has_more_rows_result) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response(