do not push stale update to related DagRun on TI update after task execution

Signed-off-by: Maciej Obuchowski <[email protected]>
mobuchowski committed Dec 29, 2024
1 parent a22faa5 commit 47a1a2c
Showing 4 changed files with 59 additions and 8 deletions.
2 changes: 1 addition & 1 deletion airflow/dag_processing/processor.py
@@ -111,7 +111,7 @@ def _execute_callbacks(
     dagbag: DagBag, callback_requests: list[CallbackRequest], log: FilteringBoundLogger
 ) -> None:
     for request in callback_requests:
-        log.debug("Processing Callback Request", request=request)
+        log.debug("Processing Callback Request", request=request.to_json())
         if isinstance(request, TaskCallbackRequest):
             raise NotImplementedError(
                 "Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!"
3 changes: 2 additions & 1 deletion airflow/models/taskinstance.py
@@ -1126,7 +1126,7 @@ def _handle_failure(
         )
 
     if not test_mode:
-        TaskInstance.save_to_db(failure_context["ti"], session)
+        TaskInstance.save_to_db(task_instance, session)
 
     with Trace.start_span_from_taskinstance(ti=task_instance) as span:
         span.set_attributes(
@@ -3146,6 +3146,7 @@ def fetch_handle_failure_context(
     @staticmethod
     @provide_session
     def save_to_db(ti: TaskInstance, session: Session = NEW_SESSION):
+        ti.get_dagrun().refresh_from_db()
         ti.updated_at = timezone.utcnow()
         session.merge(ti)
         session.flush()
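The save_to_db() change is the heart of the commit. Because Session.merge() cascades to related objects, merging a TaskInstance whose in-memory DagRun snapshot is stale can write that stale run state back to the database and overwrite a concurrent update (for example the scheduler marking the run successful); refreshing the DagRun from the database first keeps the cascade harmless. Below is a self-contained SQLAlchemy sketch of that race, not Airflow code: the Run/Task models and state strings are made up, and the manual re-read stands in for ti.get_dagrun().refresh_from_db().

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Run(Base):  # stands in for DagRun
    __tablename__ = "run"
    id: Mapped[int] = mapped_column(primary_key=True)
    state: Mapped[str]


class Task(Base):  # stands in for TaskInstance
    __tablename__ = "task"
    id: Mapped[int] = mapped_column(primary_key=True)
    state: Mapped[str]
    run_id: Mapped[int] = mapped_column(ForeignKey("run.id"))
    run: Mapped[Run] = relationship()  # default cascade includes "merge"


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as setup:
    setup.add(Task(id=1, state="running", run=Run(id=1, state="running")))
    setup.commit()

# A "worker" loads the task and its run, then detaches them; the in-memory
# run snapshot says "running".
with Session(engine) as loader:
    task = loader.get(Task, 1)
    _ = task.run.state

# Meanwhile a "scheduler" finishes the run in another session.
with Session(engine) as scheduler:
    scheduler.get(Run, 1).state = "success"
    scheduler.commit()

# The worker saves its task. Without the re-read on the first line below,
# merge() would cascade the stale "running" snapshot onto the run row and
# clobber the scheduler's update.
with Session(engine) as worker:
    task.run.state = worker.get(Run, 1).state  # refresh the stale snapshot first
    task.state = "failed"
    worker.merge(task)  # cascades to task.run, which now carries current state
    worker.commit()

with Session(engine) as check:
    assert check.get(Run, 1).state == "success"  # the concurrent update survived

With the commit's fix that re-read happens inside TaskInstance.save_to_db() itself, so a worker calling handle_failure() can no longer race the scheduler's DagRun state update.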
52 changes: 46 additions & 6 deletions tests/models/test_taskinstance.py
@@ -33,6 +33,7 @@
 from uuid import uuid4
 
 import pendulum
+import psutil
 import pytest
 import time_machine
 import uuid6
@@ -83,6 +84,7 @@
 from airflow.sdk.definitions.asset import Asset, AssetAlias
 from airflow.sensors.base import BaseSensorOperator
 from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+from airflow.settings import reconfigure_orm
 from airflow.stats import Stats
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
@@ -2298,11 +2300,11 @@ def test_outlet_assets(self, create_task_instance):
         session.flush()
 
         run_id = str(uuid4())
-        dr = DagRun(dag1.dag_id, run_id=run_id, run_type="anything")
-        session.merge(dr)
+        dr = DagRun(dag1.dag_id, run_id=run_id, run_type="anything", logical_date=timezone.utcnow())
         task = dag1.get_task("producing_task_1")
         task.bash_command = "echo 1"  # make it go faster
         ti = TaskInstance(task, run_id=run_id)
+        ti.dag_run = dr
         session.merge(ti)
         session.commit()
         ti._run_raw_task()
@@ -2357,10 +2359,12 @@ def test_outlet_assets_failed(self, create_task_instance):
         dagbag.collect_dags(only_if_updated=False, safe_mode=False)
         dagbag.sync_to_db(session=session)
         run_id = str(uuid4())
-        dr = DagRun(dag_with_fail_task.dag_id, run_id=run_id, run_type="anything")
-        session.merge(dr)
+        dr = DagRun(
+            dag_with_fail_task.dag_id, run_id=run_id, run_type="anything", logical_date=timezone.utcnow()
+        )
         task = dag_with_fail_task.get_task("fail_task")
         ti = TaskInstance(task, run_id=run_id)
+        ti.dag_run = dr
         session.merge(ti)
         session.commit()
         with pytest.raises(AirflowFailException):
@@ -2416,10 +2420,12 @@ def test_outlet_assets_skipped(self):
         session.flush()
 
         run_id = str(uuid4())
-        dr = DagRun(dag_with_skip_task.dag_id, run_id=run_id, run_type="anything")
-        session.merge(dr)
+        dr = DagRun(
+            dag_with_skip_task.dag_id, run_id=run_id, run_type="anything", logical_date=timezone.utcnow()
+        )
         task = dag_with_skip_task.get_task("skip_task")
         ti = TaskInstance(task, run_id=run_id)
+        ti.dag_run = dr
         session.merge(ti)
         session.commit()
         ti._run_raw_task()
@@ -3588,6 +3594,40 @@ def test_handle_failure(self, create_dummy_dag, session=None):
         assert "task_instance" in context_arg_3
         mock_on_retry_3.assert_not_called()
 
+    @provide_session
+    def test_handle_failure_does_not_push_stale_dagrun_model(self, dag_maker, create_dummy_dag, session=None):
+        session = settings.Session()
+        with dag_maker():
+
+            def method(): ...
+
+            task = PythonOperator(task_id="mytask", python_callable=method)
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance(task.task_id)
+        ti.state = State.RUNNING
+
+        assert dr.state == DagRunState.RUNNING
+
+        session.merge(ti)
+        session.flush()
+        session.commit()
+
+        pid = os.fork()
+        if pid:
+            process = psutil.Process(pid)
+            dr.state = DagRunState.SUCCESS
+            session.merge(dr)
+            session.flush()
+            session.commit()
+            process.wait(timeout=5)
+        else:
+            reconfigure_orm(disable_connection_pool=True)
+            ti.handle_failure("should not update related models")
+            os._exit(0)
+
+        dr.refresh_from_db()
+        assert dr.state == DagRunState.SUCCESS
+
     def test_handle_failure_updates_queued_task_updates_state(self, dag_maker):
         session = settings.Session()
         with dag_maker():
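The new test exercises the race with a real fork: the parent marks the DagRun successful while the child calls handle_failure() on its own database connections (hence reconfigure_orm(disable_connection_pool=True), presumably so the forked child does not reuse the parent's pooled connections), and the parent then asserts its SUCCESS state survived. A stripped-down sketch of the fork-and-wait pattern itself (POSIX-only; psutil assumed available; nothing Airflow-specific):

import os

import psutil

pid = os.fork()
if pid:
    # Parent: make the concurrent change here, then wait for the child to finish.
    psutil.Process(pid).wait(timeout=5)
    print("parent: child exited, now verify the change was not clobbered")
else:
    # Child: run the code under test, then exit hard so parent-side cleanup
    # (fixtures, connection pools) does not run a second time.
    os._exit(0)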
10 changes: 10 additions & 0 deletions tests/sensors/test_external_task_sensor.py
@@ -114,6 +114,16 @@ def setup_method(self):
         self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}
         self.dag = DAG(TEST_DAG_ID, schedule=None, default_args=self.args)
         self.dag_run_id = DagRunType.MANUAL.generate_run_id(DEFAULT_DATE)
+        self.dag_run = DagRun(
+            dag_id=self.dag.dag_id,
+            run_id=self.dag_run_id,
+            run_type=DagRunType.MANUAL,
+            logical_date=DEFAULT_DATE,
+        )
+        with create_session() as session:
+            session.merge(self.dag_run)
+            session.flush()
+            session.commit()
 
     def add_time_sensor(self, task_id=TEST_TASK_ID):
         op = TimeSensor(task_id=task_id, target_time=time(0), dag=self.dag)
