Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing merge clauses in DeltaTableWriter #44

Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.head_ref }}
ref: ${{ github.event.pull_request.head.ref }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
- name: Fetch main branch
run: git fetch origin main:main
- name: Check changes
Expand Down
30 changes: 19 additions & 11 deletions src/koheesio/spark/writers/delta/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,19 +274,8 @@ def _merge_builder_from_args(self):
.merge(self.df.alias(source_alias), merge_cond)
)

valid_clauses = [
"whenMatchedUpdate",
"whenNotMatchedInsert",
"whenMatchedDelete",
"whenNotMatchedBySourceUpdate",
"whenNotMatchedBySourceDelete",
]

for merge_clause in merge_clauses:
clause_type = merge_clause.pop("clause", None)
if clause_type not in valid_clauses:
raise ValueError(f"Invalid merge clause '{clause_type}' provided")

method = getattr(builder, clause_type)
builder = method(**merge_clause)

Expand Down Expand Up @@ -314,6 +303,25 @@ def _validate_table(cls, table):
return DeltaTableStep(table=table)
return table

@field_validator("params")
def _validate_params(cls, params):
    """Validate the ``merge_builder`` entry of ``params``, when one is given.

    ``merge_builder`` may be either a list of merge-clause configurations or a
    ``DeltaMergeBuilder`` instance. For a list, each entry's ``"clause"`` value
    must name one of the ``when*`` attributes exposed by ``DeltaMergeBuilder``.

    Parameters
    ----------
    params :
        The params mapping to validate. It is returned unchanged.

    Returns
    -------
    The same ``params`` object that was passed in.

    Raises
    ------
    ValueError
        If ``merge_builder`` is neither a list nor a ``DeltaMergeBuilder``, if a
        list entry does not support ``.get("clause")``, or if an entry names a
        clause that ``DeltaMergeBuilder`` does not provide.
    """
    # Nothing to check when no merge builder was supplied.
    if "merge_builder" not in params:
        return params

    merge_builder = params["merge_builder"]

    if isinstance(merge_builder, list):
        # Derive the accepted clause names from DeltaMergeBuilder itself so the check
        # follows whatever `when*` methods the available class exposes.
        valid_clauses = {m for m in dir(DeltaMergeBuilder) if m.startswith("when")}

        for merge_conf in merge_builder:
            try:
                clause = merge_conf.get("clause")
            except AttributeError as err:
                # An entry such as a plain string has no `.get`; report it as a
                # validation problem instead of leaking an AttributeError.
                raise ValueError(
                    f"Invalid merge clause configuration {merge_conf!r} provided; expected a mapping"
                ) from err

            if clause not in valid_clauses:
                raise ValueError(f"Invalid merge clause '{clause}' provided")
    elif not isinstance(merge_builder, DeltaMergeBuilder):
        raise ValueError("merge_builder must be a list of merge clauses or a DeltaMergeBuilder instance")

    return params

@classmethod
def get_output_mode(cls, choice: str, options: Set[Type]) -> Union[BatchOutputMode, StreamingOutputMode]:
"""Retrieve an OutputMode by validating `choice` against a set of option OutputModes.
Expand Down
25 changes: 13 additions & 12 deletions tests/spark/writers/delta/test_delta_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,24 +308,25 @@ def test_merge_from_args(spark, dummy_df):
)


def test_merge_from_args_raise_value_error(spark, dummy_df):
table_name = "test_table_merge_from_args_value_error"
dummy_df.write.format("delta").saveAsTable(table_name)

writer = DeltaTableWriter(
df=dummy_df,
table=table_name,
output_mode=BatchOutputMode.MERGE,
output_mode_params={
@pytest.mark.parametrize(
"output_mode_params",
[
{
"merge_builder": [
{"clause": "NOT-SUPPORTED-MERGE-CLAUSE", "set": {"id": "source.id"}, "condition": "source.id=target.id"}
],
"merge_cond": "source.id=target.id",
},
)

{"merge_builder": MagicMock()},
],
)
def test_merge_from_args_raise_value_error(spark, output_mode_params):
with pytest.raises(ValueError):
writer._merge_builder_from_args()
DeltaTableWriter(
table="test_table_merge",
output_mode=BatchOutputMode.MERGE,
output_mode_params=output_mode_params,
)


def test_merge_no_table(spark):
Expand Down