Swap Dag Parsing to use the TaskSDK machinery. #44972

Merged · 2 commits · Dec 19, 2024
105 changes: 25 additions & 80 deletions airflow/callbacks/callback_requests.py
@@ -16,49 +16,38 @@
# under the License.
from __future__ import annotations

import json
from typing import TYPE_CHECKING

from pydantic import BaseModel

from airflow.api_fastapi.execution_api.datamodels import taskinstance as ti_datamodel # noqa: TC001
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.typing_compat import Self


class CallbackRequest:
class CallbackRequest(BaseModel):
"""
Base Class with information about the callback to be executed.

:param full_filepath: File Path to use to run the callback
:param msg: Additional Message that can be used for logging
:param processor_subdir: Directory used by the Dag Processor when it parsed the dag.
"""

def __init__(
self,
full_filepath: str,
processor_subdir: str | None = None,
msg: str | None = None,
):
self.full_filepath = full_filepath
self.processor_subdir = processor_subdir
self.msg = msg

def __eq__(self, other):
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return NotImplemented

def __repr__(self):
return str(self.__dict__)

def to_json(self) -> str:
return json.dumps(self.__dict__)
full_filepath: str
"""File Path to use to run the callback"""
processor_subdir: str | None = None
"""Directory used by Dag Processor when parsed the dag"""
msg: str | None = None
"""Additional Message that can be used for logging to determine failure/zombie"""

@classmethod
def from_json(cls, json_str: str):
json_object = json.loads(json_str)
return cls(**json_object)
def from_json(cls, data: str | bytes | bytearray) -> Self:
return cls.model_validate_json(data)

def to_json(self, **kwargs) -> str:
return self.model_dump_json(**kwargs)


class TaskCallbackRequest(CallbackRequest):
@@ -67,25 +56,12 @@ class TaskCallbackRequest(CallbackRequest):

A Class with information about the success/failure TI callback to be executed. Currently, only failure
callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess.

:param full_filepath: File Path to use to run the callback
:param simple_task_instance: Simplified Task Instance representation
:param msg: Additional Message that can be used for logging to determine failure/zombie
:param processor_subdir: Directory used by the Dag Processor when it parsed the dag.
:param task_callback_type: e.g. whether on success, on failure, on retry.
"""

def __init__(
self,
full_filepath: str,
simple_task_instance: SimpleTaskInstance,
processor_subdir: str | None = None,
msg: str | None = None,
task_callback_type: TaskInstanceState | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.simple_task_instance = simple_task_instance
self.task_callback_type = task_callback_type
ti: ti_datamodel.TaskInstance
"""Simplified Task Instance representation"""
task_callback_type: TaskInstanceState | None = None
"""Whether on success, on failure, on retry"""

@property
def is_failure_callback(self) -> bool:
@@ -98,42 +74,11 @@ def is_failure_callback(self) -> bool:
TaskInstanceState.UPSTREAM_FAILED,
}

def to_json(self) -> str:
from airflow.serialization.serialized_objects import BaseSerialization

val = BaseSerialization.serialize(self.__dict__, strict=True)
return json.dumps(val)

@classmethod
def from_json(cls, json_str: str):
from airflow.serialization.serialized_objects import BaseSerialization

val = json.loads(json_str)
return cls(**BaseSerialization.deserialize(val))


class DagCallbackRequest(CallbackRequest):
"""
A Class with information about the success/failure DAG callback to be executed.

:param full_filepath: File Path to use to run the callback
:param dag_id: DAG ID
:param run_id: Run ID for the DagRun
:param processor_subdir: Directory used by the Dag Processor when it parsed the dag.
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging
"""
"""A Class with information about the success/failure DAG callback to be executed."""

def __init__(
self,
full_filepath: str,
dag_id: str,
run_id: str,
processor_subdir: str | None,
is_failure_callback: bool | None = True,
msg: str | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.dag_id = dag_id
self.run_id = run_id
self.is_failure_callback = is_failure_callback
dag_id: str
run_id: str
is_failure_callback: bool | None = True
"""Flag to determine whether it is a Failure Callback or Success Callback"""
8 changes: 1 addition & 7 deletions airflow/cli/commands/local_commands/dag_processor_command.py
@@ -19,7 +19,6 @@
from __future__ import annotations

import logging
from datetime import timedelta
from typing import Any

from airflow.cli.commands.local_commands.daemon_utils import run_command_with_daemon_option
@@ -36,11 +35,10 @@
def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner:
"""Create DagFileProcessorProcess instance."""
processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout")
processor_timeout = timedelta(seconds=processor_timeout_seconds)
return DagProcessorJobRunner(
job=Job(),
processor=DagFileProcessorManager(
processor_timeout=processor_timeout,
processor_timeout=processor_timeout_seconds,
dag_directory=args.subdir,
max_runs=args.num_runs,
),
@@ -54,10 +52,6 @@ def dag_processor(args):
if not conf.getboolean("scheduler", "standalone_dag_processor"):
raise SystemExit("The option [scheduler/standalone_dag_processor] must be True.")

sql_conn: str = conf.get("database", "sql_alchemy_conn").lower()
if sql_conn.startswith("sqlite"):
raise SystemExit("Standalone DagProcessor is not supported when using sqlite.")

job_runner = _create_dag_processor_job_runner(args)

reload_configuration_for_dag_processing()
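
The manager now receives its timeout as a plain number of seconds rather than a timedelta, and the sqlite-specific guard is gone. A minimal sketch of the adjusted construction; the dag directory and run count are assumed values standing in for args.subdir and args.num_runs, and the import paths are taken from the existing Airflow module layout rather than from this diff.

from airflow.configuration import conf
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner
from airflow.jobs.job import Job

# The timeout is read and passed straight through as an int of seconds.
timeout_seconds = conf.getint("core", "dag_file_processor_timeout")
runner = DagProcessorJobRunner(
    job=Job(),
    processor=DagFileProcessorManager(
        processor_timeout=timeout_seconds,
        dag_directory="/dags",  # assumed; the CLI passes args.subdir
        max_runs=1,             # assumed; the CLI passes args.num_runs
    ),
)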
84 changes: 49 additions & 35 deletions airflow/dag_processing/collection.py
@@ -27,7 +27,6 @@

from __future__ import annotations

import itertools
import logging
import traceback
from typing import TYPE_CHECKING, NamedTuple
@@ -64,12 +63,13 @@
from sqlalchemy.sql import Select

from airflow.models.dagwarning import DagWarning
from airflow.serialization.serialized_objects import MaybeSerializedDAG
from airflow.typing_compat import Self

log = logging.getLogger(__name__)


def _create_orm_dags(dags: Iterable[DAG], *, session: Session) -> Iterator[DagModel]:
def _create_orm_dags(dags: Iterable[MaybeSerializedDAG], *, session: Session) -> Iterator[DagModel]:
for dag in dags:
orm_dag = DagModel(dag_id=dag.dag_id)
if dag.is_paused_upon_creation is not None:
@@ -124,7 +124,7 @@ class _RunInfo(NamedTuple):
num_active_runs: dict[str, int]

@classmethod
def calculate(cls, dags: dict[str, DAG], *, session: Session) -> Self:
def calculate(cls, dags: dict[str, MaybeSerializedDAG], *, session: Session) -> Self:
"""
Query the run counts from the db.

@@ -169,7 +169,7 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se
)


def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir: str | None):
def _serialize_dag_capturing_errors(dag: MaybeSerializedDAG, session: Session, processor_subdir: str | None):
"""
Try to serialize the dag to the DB, but make a note of any errors.

Expand All @@ -192,7 +192,7 @@ def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir
_sync_dag_perms(dag, session=session)
else:
# Check and update DagCode
DagCode.update_source_code(dag)
DagCode.update_source_code(dag.dag_id, dag.fileloc)
return []
except OperationalError:
raise
Expand All @@ -202,7 +202,7 @@ def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir
return [(dag.fileloc, traceback.format_exc(limit=-dagbag_import_error_traceback_depth))]


def _sync_dag_perms(dag: DAG, session: Session):
def _sync_dag_perms(dag: MaybeSerializedDAG, session: Session):
"""Sync DAG specific permissions."""
dag_id = dag.dag_id

@@ -270,7 +270,7 @@ def _update_import_errors(


def update_dag_parsing_results_in_db(
dags: Collection[DAG],
dags: Collection[MaybeSerializedDAG],
import_errors: dict[str, str],
processor_subdir: str | None,
warnings: set[DagWarning],
@@ -347,7 +347,7 @@ class DagModelOperation(NamedTuple):
class DagModelOperation(NamedTuple):
"""Collect DAG objects and perform database operations for them."""

dags: dict[str, DAG]
dags: dict[str, MaybeSerializedDAG]

def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
"""Find existing DagModel objects from DAG objects."""
@@ -380,6 +380,8 @@ def update_dags(
processor_subdir: str | None = None,
session: Session,
) -> None:
from airflow.configuration import conf

# we exclude backfill from active run counts since their concurrency is separate
run_info = _RunInfo.calculate(
dags=self.dags,
@@ -393,19 +395,41 @@
dm.is_active = True
dm.has_import_errors = False
dm.last_parsed_time = utcnow()
dm.default_view = dag.default_view
dm.default_view = dag.default_view or conf.get("webserver", "dag_default_view").lower()
if hasattr(dag, "_dag_display_property_value"):
dm._dag_display_property_value = dag._dag_display_property_value
elif dag.dag_display_name != dag.dag_id:
dm._dag_display_property_value = dag.dag_display_name
dm.description = dag.description
dm.max_active_tasks = dag.max_active_tasks
dm.max_active_runs = dag.max_active_runs
dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
dm.has_task_concurrency_limits = any(
t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
for t in dag.tasks
)

# These "is not None" checks are because with a LazySerializedDag object where the user hasn't
# specified an explicit value, we don't get the default values from the config in the lazy
# serialized ver
# we just
if dag.max_active_tasks is not None:
dm.max_active_tasks = dag.max_active_tasks
elif dag.max_active_tasks is None and dm.max_active_tasks is None:
dm.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag")

if dag.max_active_runs is not None:
dm.max_active_runs = dag.max_active_runs
elif dag.max_active_runs is None and dm.max_active_runs is None:
dm.max_active_runs = conf.getint("core", "max_active_runs_per_dag")

if dag.max_consecutive_failed_dag_runs is not None:
dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
elif dag.max_consecutive_failed_dag_runs is None and dm.max_consecutive_failed_dag_runs is None:
dm.max_consecutive_failed_dag_runs = conf.getint(
"core", "max_consecutive_failed_dag_runs_per_dag"
)

if hasattr(dag, "has_task_concurrency_limits"):
dm.has_task_concurrency_limits = dag.has_task_concurrency_limits
else:
dm.has_task_concurrency_limits = any(
t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
for t in dag.tasks
)
dm.timetable_summary = dag.timetable.summary
dm.timetable_description = dag.timetable.description
dm.asset_expression = dag.timetable.asset_condition.as_expression()
@@ -419,7 +443,7 @@ def update_dags(
if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
dm.next_dagrun_create_after = None
else:
dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)
dm.calculate_dagrun_date_fields(dag, last_automated_data_interval) # type: ignore[arg-type]

if not dag.timetable.asset_condition:
dm.schedule_asset_references = []
@@ -436,24 +460,20 @@ def update_dags(
dm.dag_owner_links = []


def _find_all_assets(dags: Iterable[DAG]) -> Iterator[Asset]:
def _find_all_assets(dags: Iterable[MaybeSerializedDAG]) -> Iterator[Asset]:
for dag in dags:
for _, asset in dag.timetable.asset_condition.iter_assets():
yield asset
for task in dag.task_dict.values():
for obj in itertools.chain(task.inlets, task.outlets):
if isinstance(obj, Asset):
yield obj
for _, alias in dag.get_task_assets(of_type=Asset):
yield alias


def _find_all_asset_aliases(dags: Iterable[DAG]) -> Iterator[AssetAlias]:
def _find_all_asset_aliases(dags: Iterable[MaybeSerializedDAG]) -> Iterator[AssetAlias]:
for dag in dags:
for _, alias in dag.timetable.asset_condition.iter_asset_aliases():
yield alias
for task in dag.task_dict.values():
for obj in itertools.chain(task.inlets, task.outlets):
if isinstance(obj, AssetAlias):
yield obj
for _, alias in dag.get_task_assets(of_type=AssetAlias):
yield alias


def _find_active_assets(name_uri_assets, session: Session):
@@ -500,7 +520,7 @@ class AssetModelOperation(NamedTuple):
asset_aliases: dict[str, AssetAlias]

@classmethod
def collect(cls, dags: dict[str, DAG]) -> Self:
def collect(cls, dags: dict[str, MaybeSerializedDAG]) -> Self:
coll = cls(
schedule_asset_references={
dag_id: [asset for _, asset in dag.timetable.asset_condition.iter_assets()]
@@ -511,13 +531,7 @@ def collect(cls, dags: dict[str, DAG]) -> Self:
for dag_id, dag in dags.items()
},
outlet_references={
dag_id: [
(task_id, outlet)
for task_id, task in dag.task_dict.items()
for outlet in task.outlets
if isinstance(outlet, Asset)
]
for dag_id, dag in dags.items()
dag_id: list(dag.get_task_assets(inlets=False, outlets=True)) for dag_id, dag in dags.items()
},
assets={(asset.name, asset.uri): asset for asset in _find_all_assets(dags.values())},
asset_aliases={alias.name: alias for alias in _find_all_asset_aliases(dags.values())},
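
The new handling of max_active_tasks, max_active_runs and max_consecutive_failed_dag_runs in update_dags() follows one precedence rule: an explicit value on the (possibly lazily serialized) DAG wins, otherwise the value already stored on the DagModel is kept, otherwise the config default is used. A minimal sketch of that rule as a hypothetical helper; the real code inlines the checks per attribute.

from airflow.configuration import conf

def _resolve_int_setting(dag_value, db_value, section, key):
    # Hypothetical helper, not part of this PR: explicit DAG value first,
    # then the value already stored on the DagModel, then the config default.
    if dag_value is not None:
        return dag_value
    if db_value is not None:
        return db_value
    return conf.getint(section, key)

# Equivalent to the max_active_runs branch above:
# dm.max_active_runs = _resolve_int_setting(
#     dag.max_active_runs, dm.max_active_runs, "core", "max_active_runs_per_dag"
# )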