Make it possible to set an explicit schema instead of letting AutoLoader infer the schema #40

Merged
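In short, this PR adds an optional `schema` parameter to `AutoLoader`, so a caller can supply an explicit schema rather than relying on Auto Loader's schema inference. A minimal usage sketch based on the diff below; the S3 paths are placeholders taken from the docstring example, `schema=` resolves to the new `schema_` field via its pydantic alias, and `.read()` is the usual koheesio `Reader` entry point:

```python
from pyspark.sql.types import ArrayType, LongType, StringType, StructField, StructType

from koheesio.spark.readers.databricks import AutoLoader, AutoLoaderFormat

# Explicit schema for the incoming JSON files; AutoLoader applies it
# instead of inferring one from the data.
schema = StructType(
    [
        StructField("string", StringType(), True),
        StructField("int", LongType(), True),
        StructField("array", ArrayType(LongType()), True),
    ]
)

result_df = AutoLoader(
    format=AutoLoaderFormat.JSON,
    location="some_s3_path",  # placeholder
    schema_location="other_s3_path",  # placeholder; used for schema tracking/evolution
    schema=schema,  # new in this PR: passed through the `schema_` field's alias
    options={"multiLine": "true"},
).read()
```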
23 changes: 17 additions & 6 deletions src/koheesio/spark/readers/databricks/autoloader.py
@@ -3,9 +3,12 @@
 Autoloader can ingest JSON, CSV, PARQUET, AVRO, ORC, TEXT, and BINARYFILE file formats.
 """
 
-from typing import Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 from enum import Enum
 
+from pyspark.sql.streaming import DataStreamReader
+from pyspark.sql.types import AtomicType, StructType
+
 from koheesio.models import Field, field_validator
 from koheesio.spark.readers import Reader
 
@@ -53,7 +56,7 @@ class AutoLoader(Reader):
     Example
     -------
     ```python
-    from koheesio.spark.readers.databricks import AutoLoader, AutoLoaderFormat
+    from koheesio.steps.readers.databricks import AutoLoader, AutoLoaderFormat
 
     result_df = AutoLoader(
         format=AutoLoaderFormat.JSON,
@@ -82,11 +85,16 @@ class AutoLoader(Reader):
         description="The location for storing inferred schema and supporting schema evolution, "
         "used in `cloudFiles.schemaLocation`.",
     )
-    options: Optional[Dict[str, str]] = Field(
+    options: Optional[Dict[str, Any]] = Field(
         default_factory=dict,
         description="Extra inputs to provide to the autoloader. For a full list of inputs, "
        "see https://docs.databricks.com/ingestion/auto-loader/options.html",
     )
+    schema_: Optional[Union[str, StructType, List[str], Tuple[str, ...], AtomicType]] = Field(
+        default=None,
+        description="Explicit schema to apply to the input files.",
+        alias="schema",
+    )
 
     @field_validator("format")
     def validate_format(cls, format_specified):
@@ -107,9 +115,12 @@ def get_options(self):
         return self.options
 
     # @property
-    def reader(self):
-        """Return the reader for the autoloader"""
-        return self.spark.readStream.format("cloudFiles").options(**self.get_options())
+    def reader(self) -> DataStreamReader:
+        reader = self.spark.readStream.format("cloudFiles")
+        if self.schema_ is not None:
+            reader = reader.schema(self.schema_)
+        reader = reader.options(**self.get_options())
+        return reader
 
     def execute(self):
         """Reads from the given location with the given options using Autoloader"""
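The reworked `reader()` only deviates from a plain PySpark stream read by conditionally injecting the schema before the options are applied. A hand-written sketch of roughly what it builds; the `cloudFiles.*` keys and the final `load()` call are assumptions about what `get_options()` and `execute()` contribute (their bodies are not shown in this diff), and all paths are hypothetical:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# DataStreamReader.schema() accepts either a StructType or a DDL-formatted
# string, which is why the new `schema_` field is typed as a Union.
schema = "string STRING, int BIGINT, array ARRAY<BIGINT>"  # or a StructType

reader = spark.readStream.format("cloudFiles")
if schema is not None:  # mirrors the `if self.schema_ is not None` branch
    reader = reader.schema(schema)
df = reader.options(
    **{
        "cloudFiles.format": "json",  # assumed to come from `format`
        "cloudFiles.schemaLocation": "s3://bucket/_schema",  # assumed from `schema_location`
        "multiLine": "true",  # passed through from `options`
    }
).load("s3://bucket/landing/")  # assumed: `execute()` loads from `location`
```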
55 changes: 53 additions & 2 deletions tests/spark/readers/test_auto_loader.py
@@ -21,10 +21,13 @@ def test_invalid_format(bad_format):
 
 
 def mock_reader(self):
-    return self.spark.read.format("json").options(**self.options)
+    reader = self.spark.read.format("json")
+    if self.schema_ is not None:
+        reader = reader.schema(self.schema_)
+    return reader.options(**self.options)
 
 
-def test_read_json(spark, mocker, data_path):
+def test_read_json_infer_schema(spark, mocker, data_path):
     mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader)
 
     options = {"multiLine": "true"}
@@ -49,3 +52,51 @@ def test_read_json(spark, mocker, data_path):
     ]
     expected_df = spark.createDataFrame(data_expected, schema_expected)
     assert_df_equality(result, expected_df, ignore_column_order=True)
+
+
+def test_read_json_exact_explicit_schema_struct(spark, mocker, data_path):
+    mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader)
+
+    schema = StructType(
+        [
+            StructField("string", StringType(), True),
+            StructField("int", LongType(), True),
+            StructField("array", ArrayType(LongType()), True),
+        ]
+    )
+    options = {"multiLine": "true"}
+    json_file_path_str = f"{data_path}/readers/json_file/dummy.json"
+    auto_loader = AutoLoader(
+        format="json", location=json_file_path_str, schema_location="dummy_value", options=options, schema=schema
+    )
+
+    auto_loader.execute()
+    result = auto_loader.output.df
+
+    data_expected = [
+        {"string": "string1", "int": 1, "array": [1, 11, 111]},
+        {"string": "string2", "int": 2, "array": [2, 22, 222]},
+    ]
+    expected_df = spark.createDataFrame(data_expected, schema)
+    assert_df_equality(result, expected_df, ignore_column_order=True)
+
+
+def test_read_json_different_explicit_schema_string(spark, mocker, data_path):
+    mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader)
+
+    schema = "string STRING,array ARRAY<BIGINT>"
+    options = {"multiLine": "true"}
+    json_file_path_str = f"{data_path}/readers/json_file/dummy.json"
+    auto_loader = AutoLoader(
+        format="json", location=json_file_path_str, schema_location="dummy_value", options=options, schema=schema
+    )
+
+    auto_loader.execute()
+    result = auto_loader.output.df
+
+    data_expected = [
+        {"string": "string1", "array": [1, 11, 111]},
+        {"string": "string2", "array": [2, 22, 222]},
+    ]
+    expected_df = spark.createDataFrame(data_expected, schema)
+    assert_df_equality(result, expected_df, ignore_column_order=True)
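As the second test demonstrates, the explicit schema does not have to match the files exactly: a DDL string listing only some of the JSON fields (here `string` and `array`, dropping `int`) yields a DataFrame restricted to those columns. A construction sketch mirroring that test, with a placeholder path in place of the test fixture:

```python
auto_loader = AutoLoader(
    format="json",
    location="path/to/json_file/dummy.json",  # placeholder path
    schema_location="dummy_value",
    options={"multiLine": "true"},
    schema="string STRING, array ARRAY<BIGINT>",  # DDL string; omits the `int` field
)
auto_loader.execute()
df = auto_loader.output.df  # contains only the `string` and `array` columns
```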