Make it possible to set an explicit schema instead of letting AutoLoader infer the schema (#40)

Make it possible to set an explicit schema instead of letting AutoLoader infer it. An optional `schema` argument has been added.

## Related Issue
[koheesio-39](#39)

## Motivation and Context
The previous implementation of AutoLoader in koheesio always inferred the schema from the files it read. In many cases this is unnecessary, and it can even cause issues when the input data does not contain all of the required fields.
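
As a rough sketch of the new usage (the import path follows the docstring in this diff; the locations are hypothetical, and `.read()` is assumed to behave as in other koheesio readers):

```python
from pyspark.sql.types import LongType, StringType, StructField, StructType

from koheesio.spark.readers.databricks import AutoLoader, AutoLoaderFormat

# Explicit schema: applied as-is, skipping AutoLoader's schema inference.
schema = StructType(
    [
        StructField("string", StringType(), True),
        StructField("int", LongType(), True),
    ]
)

result_df = AutoLoader(
    format=AutoLoaderFormat.JSON,
    location="s3://some-bucket/input/",  # hypothetical location
    schema_location="s3://some-bucket/schema/",  # hypothetical location
    schema=schema,  # the new optional argument
).read()
```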

## How Has This Been Tested?
- Through unit tests
- On Databricks (DBX)
BrendBraeckmans committed Jun 10, 2024
1 parent 4c4d062 commit 2f040f0
Showing 2 changed files with 70 additions and 8 deletions.
23 changes: 17 additions & 6 deletions src/koheesio/spark/readers/databricks/autoloader.py
@@ -3,9 +3,12 @@
 Autoloader can ingest JSON, CSV, PARQUET, AVRO, ORC, TEXT, and BINARYFILE file formats.
 """
 
-from typing import Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 from enum import Enum
 
+from pyspark.sql.streaming import DataStreamReader
+from pyspark.sql.types import AtomicType, StructType
+
 from koheesio.models import Field, field_validator
 from koheesio.spark.readers import Reader
 
@@ -53,7 +56,7 @@ class AutoLoader(Reader):
     Example
     -------
     ```python
-    from koheesio.spark.readers.databricks import AutoLoader, AutoLoaderFormat
+    from koheesio.steps.readers.databricks import AutoLoader, AutoLoaderFormat
     result_df = AutoLoader(
         format=AutoLoaderFormat.JSON,
@@ -82,11 +85,16 @@ class AutoLoader(Reader):
         description="The location for storing inferred schema and supporting schema evolution, "
         "used in `cloudFiles.schemaLocation`.",
     )
-    options: Optional[Dict[str, str]] = Field(
+    options: Optional[Dict[str, Any]] = Field(
         default_factory=dict,
         description="Extra inputs to provide to the autoloader. For a full list of inputs, "
         "see https://docs.databricks.com/ingestion/auto-loader/options.html",
     )
+    schema_: Optional[Union[str, StructType, List[str], Tuple[str, ...], AtomicType]] = Field(
+        default=None,
+        description="Explicit schema to apply to the input files.",
+        alias="schema",
+    )
 
     @field_validator("format")
     def validate_format(cls, format_specified):
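
A note on the `schema_` / `alias="schema"` pair: a field literally named `schema` would shadow an attribute Pydantic's `BaseModel` already defines, so the field is stored as `schema_` while the alias keeps `schema=...` working at the call site. A minimal standalone sketch of that pattern (hypothetical model, assuming population by name is enabled as in koheesio's models):

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class ReaderConfig(BaseModel):
    # Accept the field via its alias ("schema") as well as its real name.
    model_config = ConfigDict(populate_by_name=True)

    schema_: Optional[str] = Field(default=None, alias="schema")


cfg = ReaderConfig(schema="string STRING, int BIGINT")
assert cfg.schema_ == "string STRING, int BIGINT"
```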
@@ -107,9 +115,12 @@ def get_options(self):
         return self.options
 
     # @property
-    def reader(self):
-        """Return the reader for the autoloader"""
-        return self.spark.readStream.format("cloudFiles").options(**self.get_options())
+    def reader(self) -> DataStreamReader:
+        reader = self.spark.readStream.format("cloudFiles")
+        if self.schema_ is not None:
+            reader = reader.schema(self.schema_)
+        reader = reader.options(**self.get_options())
+        return reader
 
     def execute(self):
         """Reads from the given location with the given options using Autoloader"""
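
For context, the rewritten `reader()` builds the same stream you would compose by hand with PySpark's Auto Loader source; roughly (a sketch with hypothetical paths, assuming a Databricks session where `spark` and the `cloudFiles` format are available):

```python
schema = "string STRING, int BIGINT"  # or a StructType; None means "infer"

reader = spark.readStream.format("cloudFiles")
if schema is not None:
    # Only pin a schema when one was given explicitly.
    reader = reader.schema(schema)

df = reader.options(
    **{
        "cloudFiles.format": "json",
        "cloudFiles.schemaLocation": "dbfs:/tmp/schema",  # hypothetical
        "multiLine": "true",
    }
).load("dbfs:/tmp/input")  # hypothetical
```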
55 changes: 53 additions & 2 deletions tests/spark/readers/test_auto_loader.py
@@ -21,10 +21,13 @@ def test_invalid_format(bad_format):
 
 
 def mock_reader(self):
-    return self.spark.read.format("json").options(**self.options)
+    reader = self.spark.read.format("json")
+    if self.schema_ is not None:
+        reader = reader.schema(self.schema_)
+    return reader.options(**self.options)
 
 
-def test_read_json(spark, mocker, data_path):
+def test_read_json_infer_schema(spark, mocker, data_path):
    mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader)
 
     options = {"multiLine": "true"}
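
The `cloudFiles` source only exists on Databricks, so the tests swap `AutoLoader.reader` for a plain batch JSON reader that honors the same `schema_`. Patching with a bare function works because the replacement is installed on the class, so `self` binds as usual. A self-contained sketch of the same idea (hypothetical class, using `unittest.mock` instead of pytest-mock):

```python
from unittest.mock import patch


class Loader:
    def reader(self):
        return "real streaming reader"


def fake_reader(self):
    # Same signature as the real method; `self` is the Loader instance.
    return "batch reader for tests"


with patch.object(Loader, "reader", fake_reader):
    assert Loader().reader() == "batch reader for tests"
```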
@@ -49,3 +52,51 @@ def test_read_json(spark, mocker, data_path):
     ]
     expected_df = spark.createDataFrame(data_expected, schema_expected)
     assert_df_equality(result, expected_df, ignore_column_order=True)
+
+
+def test_read_json_exact_explicit_schema_struct(spark, mocker, data_path):
+    mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader)
+
+    schema = StructType(
+        [
+            StructField("string", StringType(), True),
+            StructField("int", LongType(), True),
+            StructField("array", ArrayType(LongType()), True),
+        ]
+    )
+    options = {"multiLine": "true"}
+    json_file_path_str = f"{data_path}/readers/json_file/dummy.json"
+    auto_loader = AutoLoader(
+        format="json", location=json_file_path_str, schema_location="dummy_value", options=options, schema=schema
+    )
+
+    auto_loader.execute()
+    result = auto_loader.output.df
+
+    data_expected = [
+        {"string": "string1", "int": 1, "array": [1, 11, 111]},
+        {"string": "string2", "int": 2, "array": [2, 22, 222]},
+    ]
+    expected_df = spark.createDataFrame(data_expected, schema)
+    assert_df_equality(result, expected_df, ignore_column_order=True)
+
+
+def test_read_json_different_explicit_schema_string(spark, mocker, data_path):
+    mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader)
+
+    schema = "string STRING,array ARRAY<BIGINT>"
+    options = {"multiLine": "true"}
+    json_file_path_str = f"{data_path}/readers/json_file/dummy.json"
+    auto_loader = AutoLoader(
+        format="json", location=json_file_path_str, schema_location="dummy_value", options=options, schema=schema
+    )
+
+    auto_loader.execute()
+    result = auto_loader.output.df
+
+    data_expected = [
+        {"string": "string1", "array": [1, 11, 111]},
+        {"string": "string2", "array": [2, 22, 222]},
+    ]
+    expected_df = spark.createDataFrame(data_expected, schema)
+    assert_df_equality(result, expected_df, ignore_column_order=True)
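
The second new test passes the schema as a DDL string rather than a `StructType`; Spark accepts either form, which is why the test can feed the same string to both `reader.schema(...)` (via the mock) and `createDataFrame`. A small sketch of that equivalence (assuming a local `SparkSession`):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# A DDL string works anywhere PySpark expects a schema, e.g. createDataFrame:
ddl = "string STRING,array ARRAY<BIGINT>"
df = spark.createDataFrame([("string1", [1, 11, 111])], ddl)
df.printSchema()
# root
#  |-- string: string (nullable = true)
#  |-- array: array (nullable = true)
#  |    |-- element: long (containsNull = true)
```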
