Commit: Incorporate review comments
BrendBraeckmans committed Jun 7, 2024
1 parent a48d0f4, commit aa8d4e9
Showing 2 changed files with 35 additions and 28 deletions.
33 changes: 8 additions & 25 deletions src/koheesio/spark/readers/databricks/autoloader.py
@@ -3,13 +3,11 @@
 Autoloader can ingest JSON, CSV, PARQUET, AVRO, ORC, TEXT, and BINARYFILE file formats.
 """
 
-import json
-from typing import Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 from enum import Enum
-from pathlib import Path
 
 from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import StructType
+from pyspark.sql.types import AtomicType, StructType
 
 from koheesio.models import Field, field_validator
 from koheesio.spark.readers import Reader
@@ -87,14 +85,15 @@ class AutoLoader(Reader):
         description="The location for storing inferred schema and supporting schema evolution, "
         "used in `cloudFiles.schemaLocation`.",
     )
-    options: Optional[Dict[str, str]] = Field(
+    options: Optional[Dict[str, Any]] = Field(
         default_factory=dict,
         description="Extra inputs to provide to the autoloader. For a full list of inputs, "
         "see https://docs.databricks.com/ingestion/auto-loader/options.html",
     )
-    schema: Optional[Union[Path, str, StructType]] = Field(
+    schema_: Optional[Union[str, StructType, List[str], Tuple[str, ...], AtomicType]] = Field(
        default=None,
-        description="Explicit schema to infer the schema of the input files.",
+        description="Explicit schema to apply to the input files.",
+        alias="schema",
     )
 
     @field_validator("format")
@@ -105,22 +104,6 @@ def validate_format(cls, format_specified):
         format_specified = getattr(AutoLoaderFormat, format_specified.upper())
         return str(format_specified.value)
 
-    @field_validator("schema")
-    def validate_schema(cls, schema_specified):
-        """Validate `schema` value"""
-        schema = schema_specified
-        if isinstance(schema, StructType):
-            return schema
-
-        if schema is not None:
-            schema_path = Path(schema)
-            if schema_path.exists():
-                schema = schema_path.read_text()
-            else:
-                raise FileNotFoundError(f"Schema file not found at path {schema}")
-            schema = StructType.fromJson(json.loads(schema))
-        return schema
-
     def get_options(self):
         """Get the options for the autoloader"""
         self.options.update(
@@ -134,8 +117,8 @@ def get_options(self):
 
     # @property
     def reader(self) -> DataStreamReader:
         reader = self.spark.readStream.format("cloudFiles")
-        if self.schema is not None:
-            reader = reader.schema(self.schema)
+        if self.schema_ is not None:
+            reader = reader.schema(self.schema_)
         reader = reader.options(**self.get_options())
         return reader
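With the JSON-file schema validator gone, whatever arrives via the `schema` alias (stored on the model as `schema_`) is forwarded unchanged to `DataStreamReader.schema()`, which accepts a `StructType` or a DDL-formatted string. A minimal usage sketch, assuming the koheesio package is installed; the storage paths below are hypothetical placeholders, not from this commit:

    from koheesio.spark.readers.databricks.autoloader import AutoLoader

    loader = AutoLoader(
        format="json",
        location="s3://example-bucket/raw/",             # hypothetical input path
        schema_location="s3://example-bucket/_schemas",  # hypothetical schema store
        schema="id BIGINT, payload STRING",              # DDL string; a StructType works too
    )
    loader.execute()
    df = loader.output.df  # streaming DataFrame from the cloudFiles reader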
30 changes: 27 additions & 3 deletions tests/spark/readers/test_auto_loader.py
@@ -21,10 +21,13 @@ def test_invalid_format(bad_format):
 
 
 def mock_reader(self):
-    return self.spark.read.format("json").options(**self.options)
+    reader = self.spark.read.format("json")
+    if self.schema_ is not None:
+        reader = reader.schema(self.schema_)
+    return reader.options(**self.options)
 
 
-def test_read_json(spark, mocker, data_path):
+def test_read_json_infer_schema(spark, mocker, data_path):
     mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader)
 
     options = {"multiLine": "true"}
@@ -51,7 +54,7 @@
     assert_df_equality(result, expected_df, ignore_column_order=True)
 
 
-def test_read_json_schema_defined(spark, mocker, data_path):
+def test_read_json_exact_explicit_schema_struct(spark, mocker, data_path):
     mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader)
 
     schema = StructType(
@@ -76,3 +79,24 @@
     ]
     expected_df = spark.createDataFrame(data_expected, schema)
     assert_df_equality(result, expected_df, ignore_column_order=True)
+
+
+def test_read_json_different_explicit_schema_string(spark, mocker, data_path):
+    mocker.patch("koheesio.spark.readers.databricks.autoloader.AutoLoader.reader", mock_reader)
+
+    schema = "string STRING,array ARRAY<BIGINT>"
+    options = {"multiLine": "true"}
+    json_file_path_str = f"{data_path}/readers/json_file/dummy.json"
+    auto_loader = AutoLoader(
+        format="json", location=json_file_path_str, schema_location="dummy_value", options=options, schema=schema
+    )
+
+    auto_loader.execute()
+    result = auto_loader.output.df
+
+    data_expected = [
+        {"string": "string1", "array": [1, 11, 111]},
+        {"string": "string2", "array": [2, 22, 222]},
+    ]
+    expected_df = spark.createDataFrame(data_expected, schema)
+    assert_df_equality(result, expected_df, ignore_column_order=True)
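The new test leans on PySpark treating a DDL string and the equivalent `StructType` interchangeably, both in `reader.schema()` and in `spark.createDataFrame()`. A standalone sketch of that equivalence, assuming PySpark 3.5+ for `StructType.fromDDL`:

    from pyspark.sql.types import ArrayType, LongType, StringType, StructField, StructType

    ddl = "string STRING,array ARRAY<BIGINT>"  # same DDL string as in the test
    expected = StructType(
        [
            StructField("string", StringType(), True),
            StructField("array", ArrayType(LongType(), True), True),
        ]
    )
    # BIGINT maps to LongType; nullability defaults to True in DDL parsing
    assert StructType.fromDDL(ddl) == expected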
