diff --git a/src/koheesio/spark/writers/delta/batch.py b/src/koheesio/spark/writers/delta/batch.py index dc04d7c..48998d7 100644 --- a/src/koheesio/spark/writers/delta/batch.py +++ b/src/koheesio/spark/writers/delta/batch.py @@ -274,19 +274,8 @@ def _merge_builder_from_args(self): .merge(self.df.alias(source_alias), merge_cond) ) - valid_clauses = [ - "whenMatchedUpdate", - "whenNotMatchedInsert", - "whenMatchedDelete", - "whenNotMatchedBySourceUpdate", - "whenNotMatchedBySourceDelete", - ] - for merge_clause in merge_clauses: clause_type = merge_clause.pop("clause", None) - if clause_type not in valid_clauses: - raise ValueError(f"Invalid merge clause '{clause_type}' provided") - method = getattr(builder, clause_type) builder = method(**merge_clause) @@ -314,6 +303,25 @@ def _validate_table(cls, table): return DeltaTableStep(table=table) return table + @field_validator("params") + def _validate_params(cls, params): + """Validates params. If an array of merge clauses is provided, they will be validated against the available + ones in DeltaMergeBuilder""" + + valid_clauses = {m for m in dir(DeltaMergeBuilder) if m.startswith("when")} + + if "merge_builder" in params: + merge_builder = params["merge_builder"] + if isinstance(merge_builder, list): + for merge_conf in merge_builder: + clause = merge_conf.get("clause") + if clause not in valid_clauses: + raise ValueError(f"Invalid merge clause '{clause}' provided") + elif not isinstance(merge_builder, DeltaMergeBuilder): + raise ValueError("merge_builder must be a list of merge clauses or a DeltaMergeBuilder instance") + + return params + @classmethod def get_output_mode(cls, choice: str, options: Set[Type]) -> Union[BatchOutputMode, StreamingOutputMode]: """Retrieve an OutputMode by validating `choice` against a set of option OutputModes. 
diff --git a/tests/spark/writers/delta/test_delta_writer.py b/tests/spark/writers/delta/test_delta_writer.py index 4a36069..66306de 100644 --- a/tests/spark/writers/delta/test_delta_writer.py +++ b/tests/spark/writers/delta/test_delta_writer.py @@ -308,24 +308,25 @@ def test_merge_from_args(spark, dummy_df): ) -def test_merge_from_args_raise_value_error(spark, dummy_df): - table_name = "test_table_merge_from_args_value_error" - dummy_df.write.format("delta").saveAsTable(table_name) - - writer = DeltaTableWriter( - df=dummy_df, - table=table_name, - output_mode=BatchOutputMode.MERGE, - output_mode_params={ +@pytest.mark.parametrize( + "output_mode_params", + [ + { "merge_builder": [ {"clause": "NOT-SUPPORTED-MERGE-CLAUSE", "set": {"id": "source.id"}, "condition": "source.id=target.id"} ], "merge_cond": "source.id=target.id", }, - ) - + {"merge_builder": MagicMock()}, + ], +) +def test_merge_from_args_raise_value_error(spark, output_mode_params): with pytest.raises(ValueError): - writer._merge_builder_from_args() + DeltaTableWriter( + table="test_table_merge", + output_mode=BatchOutputMode.MERGE, + output_mode_params=output_mode_params, + ) def test_merge_no_table(spark):