Skip to content

Commit

Permalink
[dagster-airlift] Handle run retries
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Nov 6, 2024
1 parent d4fa3bc commit 0db4861
Show file tree
Hide file tree
Showing 10 changed files with 326 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@

from dagster_airlift.constants import DAG_ID_TAG_KEY, DAG_RUN_ID_TAG_KEY, TASK_ID_TAG_KEY

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION, VERIFICATION_QUERY
from .dagster_run_utils import PARENT_RUN_ID_TAG, RETRY_NUMBER_TAG, DagsterRunResult
from .gql_queries import (
ASSET_NODES_QUERY,
RUNS_BY_TAG_QUERY,
RUNS_QUERY,
TRIGGER_ASSETS_MUTATION,
VERIFICATION_QUERY,
)
from .partition_utils import (
PARTITION_NAME_TAG,
PartitioningInformation,
Expand Down Expand Up @@ -122,16 +129,16 @@ def launch_dagster_run(
launch_data = self.get_valid_graphql_response(response, "launchPipelineExecution")
return launch_data["run"]["id"]

def get_dagster_run_status(
def get_dagster_run_obj(
self, session: requests.Session, dagster_url: str, run_id: str
) -> str:
) -> Mapping[str, Any]:
response = session.post(
f"{dagster_url}/graphql",
json={"query": RUNS_QUERY, "variables": {"runId": run_id}},
# Timeout in seconds
timeout=3,
)
return self.get_valid_graphql_response(response, "runOrError")["status"]
return self.get_valid_graphql_response(response, "runOrError")

def get_attribute_from_airflow_context(self, context: Context, attribute: str) -> Any:
if attribute not in context or context[attribute] is None:
Expand Down Expand Up @@ -203,15 +210,77 @@ def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> N
),
)
logger.info("Waiting for dagster run completion...")
while status := self.get_dagster_run_status(session, dagster_url, run_id):
self.wait_for_run_and_retries(session=session, dagster_url=dagster_url, run_id=run_id)
logger.info("All runs completed successfully.")
return None

def wait_for_run(
self, session: requests.Session, dagster_url: str, run_id: str
) -> DagsterRunResult:
while response := self.get_dagster_run_obj(session, dagster_url, run_id):
status = response["status"]
if status in ["SUCCESS", "FAILURE", "CANCELED"]:
break
time.sleep(self.dagster_run_status_poll_interval)
if status != "SUCCESS":
raise Exception(f"Dagster run {run_id} did not complete successfully.")
logger.info("All runs completed successfully.")
tags = {tag["key"]: tag["value"] for tag in response["tags"]}
return DagsterRunResult(status=response["status"], tags=tags)

def wait_for_run_and_retries(
self, session: requests.Session, dagster_url: str, run_id: str
) -> None:
run_id_to_check = run_id
while result := self.wait_for_run(
session=session, dagster_url=dagster_url, run_id=run_id_to_check
):
if result.success:
break
elif result.run_retries_configured and result.has_remaining_retries:
logger.info(
f"Run {run_id} completed with {result.status} status ({result.retry_number}/{result.max_retries}). Waiting for retried run..."
)
run_id_to_check = self.search_for_retried_run(
parent_run_id=run_id_to_check,
expected_retry_number=result.retry_number + 1,
session=session,
dagster_url=dagster_url,
)
logger.info(f"Found retry {run_id_to_check}. Waiting for completion...")
continue
else:
raise Exception(
f"Run {run_id_to_check} failed, and there are no remaining retries."
)
return None

def make_runs_query_with_filter(
self, runs_filter: Mapping[str, Any], session: requests.Session, dagster_url: str
) -> Sequence[Mapping[str, Any]]:
response = session.post(
f"{dagster_url}/graphql",
json={"query": RUNS_BY_TAG_QUERY, "variables": {"filter": runs_filter}},
)
return self.get_valid_graphql_response(response, "runsOrError")["results"]

def search_for_retried_run(
self,
parent_run_id: str,
expected_retry_number: int,
session: requests.Session,
dagster_url: str,
) -> str:
runs_filter = _build_runs_filter_param(
tags={RETRY_NUMBER_TAG: str(expected_retry_number), PARENT_RUN_ID_TAG: parent_run_id}
)
while runs := self.make_runs_query_with_filter(
runs_filter=runs_filter, session=session, dagster_url=dagster_url
):
if len(runs) == 0:
# Maybe use a new var here
time.sleep(self.dagster_run_status_poll_interval)
continue
return next(iter(runs))["id"]
raise Exception("Should never get here")

def execute(self, context: Context) -> Any:
# https://github.com/apache/airflow/discussions/24463
os.environ["NO_PROXY"] = "*"
Expand Down Expand Up @@ -266,3 +335,7 @@ def _build_dagster_run_execution_params(

def _is_asset_node_executable(asset_node: Mapping[str, Any]) -> bool:
return bool(asset_node["jobs"])


def _build_runs_filter_param(tags: Mapping[str, Any]) -> Mapping[str, Any]:
return {"tags": [{"key": key, "value": value} for key, value in tags.items()]}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Any, Mapping, NamedTuple, Optional

TERMINAL_STATI = ["SUCCESS", "FAILURE", "CANCELED"]
SYSTEM_TAG_PREFIX = "dagster/"
MAX_RETRIES_TAG = f"{SYSTEM_TAG_PREFIX}max_retries"
RETRY_NUMBER_TAG = f"{SYSTEM_TAG_PREFIX}retry_number"
PARENT_RUN_ID_TAG = f"{SYSTEM_TAG_PREFIX}parent_run_id"
SUCCESS_STATUS = "SUCCESS"
RETRY_ON_ASSET_OR_OP_FAILURE_TAG = f"{SYSTEM_TAG_PREFIX}retry_on_asset_or_op_failure"
RUN_FAILURE_REASON_TAG = f"{SYSTEM_TAG_PREFIX}failure_reason"
STEP_FAILURE_REASON = "STEP_FAILURE"


class DagsterRunResult(NamedTuple):
status: str
tags: Mapping[str, Any]

@property
def run_retries_configured(self) -> bool:
return MAX_RETRIES_TAG in self.tags

@property
def has_remaining_retries(self) -> bool:
if MAX_RETRIES_TAG not in self.tags:
raise Exception(
"Tried to retrieve tags from run, but run retries "
"were either not set or not properly configured. Found tags: {self.tags}"
)
return self.max_retries - self.retry_number > 0

@property
def run_will_automatically_retry(self) -> bool:
if not self.run_retries_configured:
return False
if (
not self.should_retry_on_asset_or_op_failure
and self.failure_reason == STEP_FAILURE_REASON
):
return False
return self.has_remaining_retries

@property
def should_retry_on_asset_or_op_failure(self) -> bool:
return get_boolean_tag_value(self.tags.get(RETRY_ON_ASSET_OR_OP_FAILURE_TAG), True)

@property
def failure_reason(self) -> Optional[str]:
return self.tags.get(RUN_FAILURE_REASON_TAG)

@property
def retry_number(self) -> int:
# this is sketchy
return int(self.tags.get(RETRY_NUMBER_TAG, 0))

@property
def success(self) -> bool:
return self.status == SUCCESS_STATUS

@property
def max_retries(self) -> int:
if MAX_RETRIES_TAG not in self.tags:
raise Exception("Could not determine max retries by tag because tag is not set.")
return int(self.tags[MAX_RETRIES_TAG])


def get_boolean_tag_value(tag_value: Optional[str], default_value: bool = False) -> bool:
if tag_value is None:
return default_value

return tag_value.lower() not in {"false", "none", "0", ""}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@
... on Run {
id
status
tags {
key
value
}
__typename
}
}
Expand All @@ -125,3 +129,20 @@
}
}
"""

RUNS_BY_TAG_QUERY = """
query RunsByTagQuery($filter: RunsFilter!) {
runsOrError(filter: $filter) {
... on Runs {
results {
id
status
tags {
key
value
}
}
}
}
}
"""
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def setup_dagster(
time.sleep(20)
with stand_up_dagster(dagster_dev_cmd) as process:
yield process
time.sleep(5)


@contextmanager
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from pathlib import Path
from typing import Any, Mapping

from airflow import DAG
from airflow.operators.python import PythonOperator
from dagster._time import get_current_datetime_midnight
from dagster_airlift.in_airflow import DefaultProxyTaskToDagsterOperator, proxying_to_dagster
from dagster_airlift.in_airflow.dagster_run_utils import (
MAX_RETRIES_TAG,
RETRY_ON_ASSET_OR_OP_FAILURE_TAG,
)
from dagster_airlift.in_airflow.proxied_state import load_proxied_state_from_yaml


def print_hello() -> None:
print("Hello") # noqa: T201


default_args = {
"owner": "airflow",
"depends_on_past": False,
"retries": 0,
}


# Normally this isn't needed, but we're trying to get away with not using a multi-process-safe run storage
# to test behavior here.
class SetDagsterRetryInfoOperator(DefaultProxyTaskToDagsterOperator):
def default_dagster_run_tags(self, context) -> Mapping[str, Any]:
tags = {**super().default_dagster_run_tags(context), MAX_RETRIES_TAG: "3"}
if self.get_airflow_dag_id(context).endswith("not_step_failure"):
tags[RETRY_ON_ASSET_OR_OP_FAILURE_TAG] = "false"
return tags


with DAG(
dag_id="migrated_asset_has_retries",
default_args=default_args,
schedule=None,
start_date=get_current_datetime_midnight(),
# We pause this dag upon creation to avoid running it immediately
is_paused_upon_creation=False,
) as minute_dag:
PythonOperator(task_id="my_task", python_callable=print_hello)


with DAG(
dag_id="migrated_asset_has_retries_not_step_failure",
default_args=default_args,
schedule=None,
start_date=get_current_datetime_midnight(),
# We pause this dag upon creation to avoid running it immediately
is_paused_upon_creation=False,
) as minute_dag:
PythonOperator(task_id="my_task", python_callable=print_hello)


proxying_to_dagster(
proxied_state=load_proxied_state_from_yaml(Path(__file__).parent / "proxied_state"),
global_vars=globals(),
build_from_task_fn=SetDagsterRetryInfoOperator.build_from_task,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tasks:
- id: my_task
proxied: True
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tasks:
- id: my_task
proxied: True
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dagster_airlift.core.multiple_tasks import targeted_by_multiple_tasks

from ..airflow_instance import local_airflow_instance
from .retries_configured import just_fails, succeeds_on_final_retry


def make_print_asset(key: str) -> AssetsDefinition:
Expand Down Expand Up @@ -147,6 +148,18 @@ def build_mapped_defs() -> Definitions:
task_mappings={"my_task": [migrated_daily_interval_dag__partitioned]},
),
),
Definitions(
assets=assets_with_task_mappings(
dag_id="migrated_asset_has_retries",
task_mappings={"my_task": [succeeds_on_final_retry]},
)
),
Definitions(
assets=assets_with_task_mappings(
dag_id="migrated_asset_has_retries_not_step_failure",
task_mappings={"my_task": [just_fails]},
)
),
),
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from dagster import AssetExecutionContext, asset, materialize
from dagster._core.storage.tags import MAX_RETRIES_TAG, PARENT_RUN_ID_TAG, RETRY_NUMBER_TAG


# Asset that simulates having run retries activated (so that we don't have to stand up non-sqlite-storage)
@asset
def succeeds_on_final_retry(context: AssetExecutionContext):
if RETRY_NUMBER_TAG not in context.run_tags or int(context.run_tags[RETRY_NUMBER_TAG]) < 2:
# Launch a run of the "next retry"
current_retry = int(context.run_tags.get(RETRY_NUMBER_TAG, 0))
materialize(
[succeeds_on_final_retry],
instance=context.instance,
tags={
**context.run_tags,
RETRY_NUMBER_TAG: str(current_retry + 1),
PARENT_RUN_ID_TAG: context.run_id,
MAX_RETRIES_TAG: "3",
},
)
raise Exception("oops i failed")
return None


@asset
def just_fails():
raise Exception("I fail every time")
Loading

0 comments on commit 0db4861

Please sign in to comment.