Skip to content

Commit

Permalink
Microbatch: store model context var as dict, not ModelNode (#10917)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Oct 28, 2024
1 parent 4d4b05e commit b71ceb3
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 17 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20241028-132751.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: 'Fix ''model'' jinja context variable type to dict'
time: 2024-10-28T13:27:51.604093-04:00
custom:
Author: michelleark
Issue: "10927"
22 changes: 21 additions & 1 deletion core/dbt/materializations/incremental/microbatch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import List, Optional
from typing import Any, Dict, List, Optional

import pytz

Expand Down Expand Up @@ -99,6 +99,26 @@ def build_batches(self, start: datetime, end: datetime) -> List[BatchType]:

return batches

def build_batch_context(self, incremental_batch: bool) -> Dict[str, Any]:
    """Assemble the jinja context entries for one microbatch execution.

    The result always contains the model serialized through ``to_dict()``
    under ``"model"``, and the model's compiled code under both ``"sql"``
    and ``"compiled_code"``. This expects ``self.model`` to have already
    been (re)-compiled with the necessary batch filters applied.

    Args:
        incremental_batch: When true, the ``"is_incremental"`` and
            ``"should_full_refresh"`` callables are added as well; they
            return ``True`` and ``False`` respectively.

    Returns:
        A freshly built dict of context entries.
    """
    # Entries describing the microbatch model itself
    context_entries: Dict[str, Any] = {
        "model": self.model.to_dict(),
        "sql": self.model.compiled_code,
        "compiled_code": self.model.compiled_code,
    }

    # Non-incremental batches get no incremental overrides at all
    if not incremental_batch:
        return context_entries

    # Batches running incrementally must report themselves as incremental
    # and must never request a full refresh
    context_entries.update(
        is_incremental=lambda: True,
        should_full_refresh=lambda: False,
    )
    return context_entries

@staticmethod
def offset_timestamp(timestamp: datetime, batch_size: BatchSize, offset: int) -> datetime:
"""Truncates the passed in timestamp based on the batch_size and then applies the offset by the batch_size.
Expand Down
35 changes: 19 additions & 16 deletions core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,28 +499,29 @@ def _execute_microbatch_materialization(
materialization_macro: MacroProtocol,
) -> List[RunResult]:
batch_results: List[RunResult] = []
microbatch_builder = MicrobatchBuilder(
model=model,
is_incremental=self._is_incremental(model),
event_time_start=getattr(self.config.args, "EVENT_TIME_START", None),
event_time_end=getattr(self.config.args, "EVENT_TIME_END", None),
default_end_time=self.config.invoked_at,
)
# Indicates whether current batch should be run incrementally
incremental_batch = False

# Note currently (9/30/2024) model.batch_info is only ever _not_ `None`
# IFF `dbt retry` is being run and the microbatch model had batches which
# failed on the run of the model (which is being retried)
if model.batch_info is None:
microbatch_builder = MicrobatchBuilder(
model=model,
is_incremental=self._is_incremental(model),
event_time_start=getattr(self.config.args, "EVENT_TIME_START", None),
event_time_end=getattr(self.config.args, "EVENT_TIME_END", None),
default_end_time=self.config.invoked_at,
)
end = microbatch_builder.build_end_time()
start = microbatch_builder.build_start_time(end)
batches = microbatch_builder.build_batches(start, end)
else:
batches = model.batch_info.failed
# if there is batch info, then don't run as full_refresh and do force is_incremental
# If there is batch info, then don't run as full_refresh and do force is_incremental
# not doing this risks blowing away the work that has already been done
if self._has_relation(model=model):
context["is_incremental"] = lambda: True
context["should_full_refresh"] = lambda: False
incremental_batch = True

# iterate over each batch, calling materialization_macro to get a batch-level run result
for batch_idx, batch in enumerate(batches):
Expand All @@ -542,9 +543,11 @@ def _execute_microbatch_materialization(
batch[0], model.config.batch_size
),
)
context["model"] = model
context["sql"] = model.compiled_code
context["compiled_code"] = model.compiled_code
# Update jinja context with batch context members
batch_context = microbatch_builder.build_batch_context(
incremental_batch=incremental_batch
)
context.update(batch_context)

# Materialize batch and cache any materialized relations
result = MacroGenerator(
Expand All @@ -557,9 +560,9 @@ def _execute_microbatch_materialization(
batch_run_result = self._build_succesful_run_batch_result(
model, context, batch, time.perf_counter() - start_time
)
# Update context vars for future batches
context["is_incremental"] = lambda: True
context["should_full_refresh"] = lambda: False
# At least one batch has been inserted successfully!
incremental_batch = True

except Exception as e:
exception = e
batch_run_result = self._build_failed_run_batch_result(
Expand Down
46 changes: 46 additions & 0 deletions tests/functional/microbatch/test_microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,31 @@
select * from {{ ref('input_model') }}
"""

# Jinja macro used to validate the types of the batch context variables.
# It raises a compiler error when `model` is not a mapping, when a truthy
# `compiled_code` or `sql` is not a string, or when `is_incremental` /
# `should_full_refresh` is not callable.
invalid_batch_context_macro_sql = """
{% macro check_invalid_batch_context() %}
{% if model is not mapping %}
{{ exceptions.raise_compiler_error("`model` is invalid: expected mapping type") }}
{% elif compiled_code and compiled_code is not string %}
{{ exceptions.raise_compiler_error("`compiled_code` is invalid: expected string type") }}
{% elif sql and sql is not string %}
{{ exceptions.raise_compiler_error("`sql` is invalid: expected string type") }}
{% elif is_incremental is not callable %}
{{ exceptions.raise_compiler_error("`is_incremental()` is invalid: expected callable type") }}
{% elif should_full_refresh is not callable %}
{{ exceptions.raise_compiler_error("`should_full_refresh()` is invalid: expected callable type") }}
{% endif %}
{% endmacro %}
"""

# Microbatch incremental model (daily batches, begin 2020-01-01) that calls
# check_invalid_batch_context() twice: once as a pre_hook and once inline in
# the model body, before selecting everything from input_model.
microbatch_model_with_context_checks_sql = """
{{ config(pre_hook="{{ check_invalid_batch_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
{{ check_invalid_batch_context() }}
select * from {{ ref('input_model') }}
"""

microbatch_model_downstream_sql = """
{{ config(materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
select * from {{ ref('microbatch_model') }}
Expand Down Expand Up @@ -324,6 +349,27 @@ def test_run_with_event_time(self, project):
self.assert_row_count(project, "microbatch_model", 5)


# Functional test: runs a microbatch model whose pre_hook and body both call
# the check_invalid_batch_context() macro, so the run only succeeds if the
# batch context variables have the types that macro expects.
class TestMicrobatchJinjaContext(BaseMicrobatchTest):

# Registers the context-validation macro with the test project.
@pytest.fixture(scope="class")
def macros(self):
return {"check_batch_context.sql": invalid_batch_context_macro_sql}

# Project models: the input model plus the microbatch model that performs
# the context checks.
@pytest.fixture(scope="class")
def models(self):
return {
"input_model.sql": input_model_sql,
"microbatch_model.sql": microbatch_model_with_context_checks_sql,
}

# The experimental microbatch flag is set in the environment for this test.
@mock.patch.dict(os.environ, {"DBT_EXPERIMENTAL_MICROBATCH": "True"})
def test_run_with_event_time(self, project):
# initial run -- backfills all data
# The microbatch end time is patched to a fixed timestamp for the run.
with patch_microbatch_end_time("2020-01-03 13:57:00"):
run_dbt(["run"])
# Expects 3 rows in microbatch_model after the run.
# NOTE(review): presumably this assertion executes after the patched
# end-time block exits -- confirm the nesting against the real file.
self.assert_row_count(project, "microbatch_model", 3)


class TestMicrobatchWithInputWithoutEventTime(BaseMicrobatchTest):
@pytest.fixture(scope="class")
def models(self):
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/materializations/incremental/test_microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,33 @@ def test_build_batches(self, microbatch_model, start, end, batch_size, expected_
assert len(actual_batches) == len(expected_batches)
assert actual_batches == expected_batches

def test_build_batch_context_incremental_batch(self, microbatch_model):
    """An incremental batch context holds the model entries plus incremental overrides."""
    builder = MicrobatchBuilder(
        model=microbatch_model,
        is_incremental=True,
        event_time_start=None,
        event_time_end=None,
    )
    batch_context = builder.build_batch_context(incremental_batch=True)

    # Model-derived entries mirror the model the builder was given
    assert batch_context["model"] == microbatch_model.to_dict()
    for code_key in ("sql", "compiled_code"):
        assert batch_context[code_key] == microbatch_model.compiled_code

    # Incremental batches report as incremental and never full-refresh
    assert batch_context["is_incremental"]() is True
    assert batch_context["should_full_refresh"]() is False

def test_build_batch_context_incremental_batch_false(self, microbatch_model):
    """A non-incremental batch context holds only the model entries."""
    builder = MicrobatchBuilder(
        model=microbatch_model,
        is_incremental=True,
        event_time_start=None,
        event_time_end=None,
    )
    batch_context = builder.build_batch_context(incremental_batch=False)

    # Model-derived entries mirror the model the builder was given
    assert batch_context["model"] == microbatch_model.to_dict()
    for code_key in ("sql", "compiled_code"):
        assert batch_context[code_key] == microbatch_model.compiled_code

    # The incremental callables are only added when incremental_batch is true
    for absent_key in ("is_incremental", "should_full_refresh"):
        assert absent_key not in batch_context

@pytest.mark.parametrize(
"timestamp,batch_size,offset,expected_timestamp",
[
Expand Down

0 comments on commit b71ceb3

Please sign in to comment.