diff --git a/dbt_dry_run/models/manifest.py b/dbt_dry_run/models/manifest.py
index bb4c83a..35b5464 100644
--- a/dbt_dry_run/models/manifest.py
+++ b/dbt_dry_run/models/manifest.py
@@ -29,6 +29,7 @@ class PartitionBy(BaseModel):
     field: str
     data_type: Literal["timestamp", "date", "datetime", "int64"]
     range: Optional[IntPartitionRange]
+    time_ingestion_partitioning: Optional[bool] = None
 
     @root_validator(pre=True)
     def lower_data_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/dbt_dry_run/node_runner/incremental_runner.py b/dbt_dry_run/node_runner/incremental_runner.py
index da6c6cb..40159d4 100644
--- a/dbt_dry_run/node_runner/incremental_runner.py
+++ b/dbt_dry_run/node_runner/incremental_runner.py
@@ -4,7 +4,7 @@
 from dbt_dry_run import flags
 from dbt_dry_run.exception import SchemaChangeException, UpstreamFailedException
 from dbt_dry_run.literals import insert_dependant_sql_literals
-from dbt_dry_run.models import Table
+from dbt_dry_run.models import BigQueryFieldMode, BigQueryFieldType, Table, TableField
 from dbt_dry_run.models.manifest import Node, OnSchemaChange
 from dbt_dry_run.node_runner import NodeRunner
 from dbt_dry_run.results import DryRunResult, DryRunStatus
@@ -175,6 +175,32 @@ def _get_full_refresh_config(self, node: Node) -> bool:
             return node.config.full_refresh
         return flags.FULL_REFRESH
 
+    def _is_time_ingestion_partitioned(self, node: Node) -> bool:
+        """Return True when the node is configured for time ingestion partitioning."""
+        partition_by = node.config.partition_by
+        return bool(partition_by and partition_by.time_ingestion_partitioning)
+
+    def _replace_partition_with_time_ingestion_column(
+        self, dry_run_result: DryRunResult
+    ) -> DryRunResult:
+        """Append the `_PARTITIONTIME` pseudo column to the predicted table schema."""
+        if not dry_run_result.table:
+            return dry_run_result
+
+        if not dry_run_result.node.config.partition_by:
+            return dry_run_result
+
+        new_partition_field = TableField(
+            name="_PARTITIONTIME",
+            type=BigQueryFieldType.TIMESTAMP,
+            mode=BigQueryFieldMode.NULLABLE,
+        )
+
+        final_fields = list(dry_run_result.table.fields)
+        final_fields.append(new_partition_field)
+
+        return dry_run_result.replace_table(Table(fields=final_fields))
+
     def run(self, node: Node) -> DryRunResult:
         try:
             sql_with_literals = insert_dependant_sql_literals(node, self._results)
@@ -202,4 +228,10 @@
         if result.status == DryRunStatus.SUCCESS:
             result = handler(result, target_table)
 
+        if (
+            result.status == DryRunStatus.SUCCESS
+            and self._is_time_ingestion_partitioned(node)
+        ):
+            result = self._replace_partition_with_time_ingestion_column(result)
+
         return result
diff --git a/integration/projects/test_incremental/models/partition_by_time_ingestion.sql b/integration/projects/test_incremental/models/partition_by_time_ingestion.sql
new file mode 100644
index 0000000..68f7f7b
--- /dev/null
+++ b/integration/projects/test_incremental/models/partition_by_time_ingestion.sql
@@ -0,0 +1,16 @@
+{{
+    config(
+        materialized="incremental",
+        partition_by={
+            "field": "executed_at",
+            "data_type": "date",
+            "time_ingestion_partitioning": true
+        }
+    )
+}}
+
+SELECT
+    executed_at,
+    col_1,
+    col_2
+FROM (SELECT DATE('2024-06-06') as executed_at, "foo" as col_1, "bar" as col_2)
diff --git a/integration/projects/test_incremental/test_incremental.py b/integration/projects/test_incremental/test_incremental.py
index 551cbc6..b0d0c93 100644
--- a/integration/projects/test_incremental/test_incremental.py
+++ b/integration/projects/test_incremental/test_incremental.py
@@ -210,3 +210,21 @@ def test_sql_header_and_max_partition(
     assert_report_node_has_columns_in_order(
         report_node, ["snapshot_date", "my_string", "my_func_output"]
     )
+
+
+def test_partition_by_time_ingestion(
+    compiled_project: ProjectContext,
+):
+    node_id = "model.test_incremental.partition_by_time_ingestion"
+    manifest_node = compiled_project.manifest.nodes[node_id]
+    columns = ["executed_at", "col_1 STRING", "col_2 STRING"]
+    with compiled_project.create_state(manifest_node, columns, "_PARTITIONTIME", False):
+        run_result = compiled_project.dry_run()
+    assert_report_produced(run_result)
+    report_node = get_report_node_by_id(
+        run_result.report,
+        node_id,
+    )
+    assert_report_node_has_columns_in_order(
+        report_node, ["executed_at", "col_1", "col_2", "_PARTITIONTIME"]
+    )