diff --git a/src/koheesio/spark/readers/databricks/autoloader.py b/src/koheesio/spark/readers/databricks/autoloader.py index 6b26e20..8444a54 100644 --- a/src/koheesio/spark/readers/databricks/autoloader.py +++ b/src/koheesio/spark/readers/databricks/autoloader.py @@ -3,9 +3,12 @@ Autoloader can ingest JSON, CSV, PARQUET, AVRO, ORC, TEXT, and BINARYFILE file formats. """ -from typing import Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from enum import Enum +from pyspark.sql.streaming import DataStreamReader +from pyspark.sql.types import AtomicType, StructType + from koheesio.models import Field, field_validator from koheesio.spark.readers import Reader @@ -53,7 +56,7 @@ class AutoLoader(Reader): Example ------- ```python - from koheesio.spark.readers.databricks import AutoLoader, AutoLoaderFormat + from koheesio.spark.readers.databricks import AutoLoader, AutoLoaderFormat result_df = AutoLoader( format=AutoLoaderFormat.JSON, @@ -82,11 +85,16 @@ class AutoLoader(Reader): description="The location for storing inferred schema and supporting schema evolution, " "used in `cloudFiles.schemaLocation`.", ) - options: Optional[Dict[str, str]] = Field( + options: Optional[Dict[str, Any]] = Field( default_factory=dict, description="Extra inputs to provide to the autoloader. 
For a full list of inputs, " "see https://docs.databricks.com/ingestion/auto-loader/options.html", ) + schema_: Optional[Union[str, StructType, List[str], Tuple[str, ...], AtomicType]] = Field( + default=None, + description="Explicit schema to apply to the input files.", + alias="schema", + ) @field_validator("format") def validate_format(cls, format_specified): @@ -107,9 +115,12 @@ def get_options(self): return self.options # @property - def reader(self): - """Return the reader for the autoloader""" - return self.spark.readStream.format("cloudFiles").options(**self.get_options()) + def reader(self) -> DataStreamReader: + reader = self.spark.readStream.format("cloudFiles") + if self.schema_ is not None: + reader = reader.schema(self.schema_) + reader = reader.options(**self.get_options()) + return reader def execute(self): """Reads from the given location with the given options using Autoloader""" diff --git a/tests/spark/readers/test_auto_loader.py b/tests/spark/readers/test_auto_loader.py index ca07f55..8f2b168 100644 --- a/tests/spark/readers/test_auto_loader.py +++ b/tests/spark/readers/test_auto_loader.py @@ -21,10 +21,13 @@ def test_invalid_format(bad_format): def mock_reader(self): - return self.spark.read.format("json").options(**self.options) + reader = self.spark.read.format("json") + if self.schema_ is not None: + reader = reader.schema(self.schema_) + return reader.options(**self.options) -def test_read_json(spark, mocker, data_path): +def test_read_json_infer_schema(spark, mocker, data_path): mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader) options = {"multiLine": "true"} @@ -49,3 +52,51 @@ def test_read_json(spark, mocker, data_path): ] expected_df = spark.createDataFrame(data_expected, schema_expected) assert_df_equality(result, expected_df, ignore_column_order=True) + + +def test_read_json_exact_explicit_schema_struct(spark, mocker, data_path): + 
mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader) + + schema = StructType( + [ + StructField("string", StringType(), True), + StructField("int", LongType(), True), + StructField("array", ArrayType(LongType()), True), + ] + ) + options = {"multiLine": "true"} + json_file_path_str = f"{data_path}/readers/json_file/dummy.json" + auto_loader = AutoLoader( + format="json", location=json_file_path_str, schema_location="dummy_value", options=options, schema=schema + ) + + auto_loader.execute() + result = auto_loader.output.df + + data_expected = [ + {"string": "string1", "int": 1, "array": [1, 11, 111]}, + {"string": "string2", "int": 2, "array": [2, 22, 222]}, + ] + expected_df = spark.createDataFrame(data_expected, schema) + assert_df_equality(result, expected_df, ignore_column_order=True) + + +def test_read_json_different_explicit_schema_string(spark, mocker, data_path): + mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader) + + schema = "string STRING,array ARRAY<BIGINT>" + options = {"multiLine": "true"} + json_file_path_str = f"{data_path}/readers/json_file/dummy.json" + auto_loader = AutoLoader( + format="json", location=json_file_path_str, schema_location="dummy_value", options=options, schema=schema + ) + + auto_loader.execute() + result = auto_loader.output.df + + data_expected = [ + {"string": "string1", "array": [1, 11, 111]}, + {"string": "string2", "array": [2, 22, 222]}, + ] + expected_df = spark.createDataFrame(data_expected, schema) + assert_df_equality(result, expected_df, ignore_column_order=True)