Skip to content

Commit

Permalink
Add missing merge clauses in DeltaTableWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
riacom_nike committed Jun 7, 2024
1 parent fc11f0e commit 39ccccf
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 23 deletions.
30 changes: 19 additions & 11 deletions src/koheesio/spark/writers/delta/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,19 +274,8 @@ def _merge_builder_from_args(self):
.merge(self.df.alias(source_alias), merge_cond)
)

valid_clauses = [
"whenMatchedUpdate",
"whenNotMatchedInsert",
"whenMatchedDelete",
"whenNotMatchedBySourceUpdate",
"whenNotMatchedBySourceDelete",
]

for merge_clause in merge_clauses:
clause_type = merge_clause.pop("clause", None)
if clause_type not in valid_clauses:
raise ValueError(f"Invalid merge clause '{clause_type}' provided")

method = getattr(builder, clause_type)
builder = method(**merge_clause)

Expand Down Expand Up @@ -314,6 +303,25 @@ def _validate_table(cls, table):
return DeltaTableStep(table=table)
return table

@field_validator("params")
def _validate_params(cls, params):
    """Validate the ``merge_builder`` entry of ``params``, if one is present.

    ``params["merge_builder"]`` may be either:

    * a list of merge-clause configurations, each providing a ``"clause"`` key
      whose value names a ``when*`` method of ``DeltaMergeBuilder``; or
    * a ``DeltaMergeBuilder`` instance.

    When ``"merge_builder"`` is absent, ``params`` is returned untouched.

    Parameters
    ----------
    params
        The params mapping being validated.

    Returns
    -------
    The same ``params`` object, unmodified.

    Raises
    ------
    ValueError
        If a clause configuration names a clause that is not a ``when*``
        attribute of ``DeltaMergeBuilder`` (a missing ``"clause"`` key is
        reported as clause ``None``), or if ``merge_builder`` is neither a
        list nor a ``DeltaMergeBuilder`` instance.
    """
    # Derive the accepted clause names from DeltaMergeBuilder itself instead of
    # a hard-coded list, so every ``when*`` method it exposes is accepted.
    valid_clauses = {m for m in dir(DeltaMergeBuilder) if m.startswith("when")}

    if "merge_builder" in params:
        merge_builder = params["merge_builder"]
        if isinstance(merge_builder, list):
            for merge_conf in merge_builder:
                # ``.get`` yields None for a missing key, which is never a
                # valid clause and therefore falls into the error branch.
                clause = merge_conf.get("clause")
                if clause not in valid_clauses:
                    raise ValueError(f"Invalid merge clause '{clause}' provided")
        elif not isinstance(merge_builder, DeltaMergeBuilder):
            raise ValueError("merge_builder must be a list of merge clauses or a DeltaMergeBuilder instance")

    return params

@classmethod
def get_output_mode(cls, choice: str, options: Set[Type]) -> Union[BatchOutputMode, StreamingOutputMode]:
"""Retrieve an OutputMode by validating `choice` against a set of option OutputModes.
Expand Down
25 changes: 13 additions & 12 deletions tests/spark/writers/delta/test_delta_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,24 +308,25 @@ def test_merge_from_args(spark, dummy_df):
)


def test_merge_from_args_raise_value_error(spark, dummy_df):
table_name = "test_table_merge_from_args_value_error"
dummy_df.write.format("delta").saveAsTable(table_name)

writer = DeltaTableWriter(
df=dummy_df,
table=table_name,
output_mode=BatchOutputMode.MERGE,
output_mode_params={
@pytest.mark.parametrize(
"output_mode_params",
[
{
"merge_builder": [
{"clause": "NOT-SUPPORTED-MERGE-CLAUSE", "set": {"id": "source.id"}, "condition": "source.id=target.id"}
],
"merge_cond": "source.id=target.id",
},
)

{"merge_builder": MagicMock()},
],
)
def test_merge_from_args_raise_value_error(spark, output_mode_params):
with pytest.raises(ValueError):
writer._merge_builder_from_args()
DeltaTableWriter(
table="test_table_merge",
output_mode=BatchOutputMode.MERGE,
output_mode_params=output_mode_params,
)


def test_merge_no_table(spark):
Expand Down

0 comments on commit 39ccccf

Please sign in to comment.