diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py index 07ad648e9630f..d5603d71a6581 100644 --- a/airflow/callbacks/callback_requests.py +++ b/airflow/callbacks/callback_requests.py @@ -16,49 +16,38 @@ # under the License. from __future__ import annotations -import json from typing import TYPE_CHECKING +from pydantic import BaseModel + +from airflow.api_fastapi.execution_api.datamodels import taskinstance as ti_datamodel # noqa: TC001 from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - from airflow.models.taskinstance import SimpleTaskInstance + from airflow.typing_compat import Self -class CallbackRequest: +class CallbackRequest(BaseModel): """ Base Class with information about the callback to be executed. - :param full_filepath: File Path to use to run the callback :param msg: Additional Message that can be used for logging :param processor_subdir: Directory used by Dag Processor when parsed the dag. """ - def __init__( - self, - full_filepath: str, - processor_subdir: str | None = None, - msg: str | None = None, - ): - self.full_filepath = full_filepath - self.processor_subdir = processor_subdir - self.msg = msg - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - return NotImplemented - - def __repr__(self): - return str(self.__dict__) - - def to_json(self) -> str: - return json.dumps(self.__dict__) + full_filepath: str + """File Path to use to run the callback""" + processor_subdir: str | None = None + """Directory used by Dag Processor when parsed the dag""" + msg: str | None = None + """Additional Message that can be used for logging to determine failure/zombie""" @classmethod - def from_json(cls, json_str: str): - json_object = json.loads(json_str) - return cls(**json_object) + def from_json(cls, data: str | bytes | bytearray) -> Self: + return cls.model_validate_json(data) + + def to_json(self, **kwargs) -> str: + return self.model_dump_json(**kwargs) class TaskCallbackRequest(CallbackRequest): @@ -67,25 +56,12 @@ class TaskCallbackRequest(CallbackRequest): A Class with information about the success/failure TI callback to be executed. Currently, only failure callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess. - - :param full_filepath: File Path to use to run the callback - :param simple_task_instance: Simplified Task Instance representation - :param msg: Additional Message that can be used for logging to determine failure/zombie - :param processor_subdir: Directory used by Dag Processor when parsed the dag. - :param task_callback_type: e.g. whether on success, on failure, on retry. """ - def __init__( - self, - full_filepath: str, - simple_task_instance: SimpleTaskInstance, - processor_subdir: str | None = None, - msg: str | None = None, - task_callback_type: TaskInstanceState | None = None, - ): - super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg) - self.simple_task_instance = simple_task_instance - self.task_callback_type = task_callback_type + ti: ti_datamodel.TaskInstance + """Simplified Task Instance representation""" + task_callback_type: TaskInstanceState | None = None + """Whether on success, on failure, on retry""" @property def is_failure_callback(self) -> bool: @@ -98,42 +74,11 @@ def is_failure_callback(self) -> bool: TaskInstanceState.UPSTREAM_FAILED, } - def to_json(self) -> str: - from airflow.serialization.serialized_objects import BaseSerialization - - val = BaseSerialization.serialize(self.__dict__, strict=True) - return json.dumps(val) - - @classmethod - def from_json(cls, json_str: str): - from airflow.serialization.serialized_objects import BaseSerialization - - val = json.loads(json_str) - return cls(**BaseSerialization.deserialize(val)) - class DagCallbackRequest(CallbackRequest): - """ - A Class with information about the success/failure DAG callback to be executed. - - :param full_filepath: File Path to use to run the callback - :param dag_id: DAG ID - :param run_id: Run ID for the DagRun - :param processor_subdir: Directory used by Dag Processor when parsed the dag. - :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback - :param msg: Additional Message that can be used for logging - """ + """A Class with information about the success/failure DAG callback to be executed.""" - def __init__( - self, - full_filepath: str, - dag_id: str, - run_id: str, - processor_subdir: str | None, - is_failure_callback: bool | None = True, - msg: str | None = None, - ): - super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg) - self.dag_id = dag_id - self.run_id = run_id - self.is_failure_callback = is_failure_callback + dag_id: str + run_id: str + is_failure_callback: bool | None = True + """Flag to determine whether it is a Failure Callback or Success Callback""" diff --git a/airflow/cli/commands/local_commands/dag_processor_command.py b/airflow/cli/commands/local_commands/dag_processor_command.py index f0c3bc5060ca1..653c5f6bf577f 100644 --- a/airflow/cli/commands/local_commands/dag_processor_command.py +++ b/airflow/cli/commands/local_commands/dag_processor_command.py @@ -19,7 +19,6 @@ from __future__ import annotations import logging -from datetime import timedelta from typing import Any from airflow.cli.commands.local_commands.daemon_utils import run_command_with_daemon_option @@ -36,11 +35,10 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner: """Create DagFileProcessorProcess instance.""" processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout") - processor_timeout = timedelta(seconds=processor_timeout_seconds) return DagProcessorJobRunner( job=Job(), processor=DagFileProcessorManager( - processor_timeout=processor_timeout, + processor_timeout=processor_timeout_seconds, dag_directory=args.subdir, max_runs=args.num_runs, ), @@ -54,10 +52,6 @@ def dag_processor(args): if not conf.getboolean("scheduler", "standalone_dag_processor"): raise SystemExit("The option [scheduler/standalone_dag_processor] must be True.") - sql_conn: str = conf.get("database", "sql_alchemy_conn").lower() - if sql_conn.startswith("sqlite"): - raise SystemExit("Standalone DagProcessor is not supported when using sqlite.") - job_runner = _create_dag_processor_job_runner(args) reload_configuration_for_dag_processing() diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index fff1cc64adfae..753c6a2fe8a76 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -64,6 +64,7 @@ from sqlalchemy.sql import Select from airflow.models.dagwarning import DagWarning + from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.typing_compat import Self log = logging.getLogger(__name__) @@ -169,7 +170,7 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se ) -def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir: str | None): +def _serialize_dag_capturing_errors(dag: LazyDeserializedDAG, session: Session, processor_subdir: str | None): """ Try to serialize the dag to the DB, but make a note of any errors. @@ -192,7 +193,7 @@ def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir _sync_dag_perms(dag, session=session) else: # Check and update DagCode - DagCode.update_source_code(dag) + DagCode.update_source_code(dag.dag_id, dag.fileloc) return [] except OperationalError: raise @@ -202,7 +203,7 @@ def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir return [(dag.fileloc, traceback.format_exc(limit=-dagbag_import_error_traceback_depth))] -def _sync_dag_perms(dag: DAG, session: Session): +def _sync_dag_perms(dag: LazyDeserializedDAG, session: Session): """Sync DAG specific permissions.""" dag_id = dag.dag_id @@ -270,7 +271,7 @@ def _update_import_errors( def update_dag_parsing_results_in_db( - dags: Collection[DAG], + dags: Collection[LazyDeserializedDAG], import_errors: dict[str, str], processor_subdir: str | None, warnings: set[DagWarning], @@ -393,19 +394,26 @@ def update_dags( dm.is_active = True dm.has_import_errors = False dm.last_parsed_time = utcnow() - dm.default_view = dag.default_view if hasattr(dag, "_dag_display_property_value"): dm._dag_display_property_value = dag._dag_display_property_value elif dag.dag_display_name != dag.dag_id: dm._dag_display_property_value = dag.dag_display_name dm.description = dag.description - dm.max_active_tasks = dag.max_active_tasks - dm.max_active_runs = dag.max_active_runs - dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs - dm.has_task_concurrency_limits = any( - t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None - for t in dag.tasks - ) + # TODO: this `if is not None` is maybe not the best. It's convient though + if dag.max_active_tasks is not None: + dm.max_active_tasks = dag.max_active_tasks + if dag.max_active_runs is not None: + dm.max_active_runs = dag.max_active_runs + if dag.max_consecutive_failed_dag_runs is not None: + dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs + + if hasattr(dag, "has_task_concurrency_limits"): + dm.has_task_concurrency_limits = dag.has_task_concurrency_limits + else: + dm.has_task_concurrency_limits = any( + t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None + for t in dag.tasks + ) dm.timetable_summary = dag.timetable.summary dm.timetable_description = dag.timetable.description dm.asset_expression = dag.timetable.asset_condition.as_expression() diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 57c69238a1f7e..0b586c4dc0cd0 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -19,33 +19,36 @@ from __future__ import annotations -import enum +import functools import importlib import inspect import logging import multiprocessing import os import random +import selectors import signal import sys import time import zipfile from collections import defaultdict, deque -from collections.abc import Iterator, MutableMapping +from collections.abc import Callable, Iterator from datetime import datetime, timedelta from importlib import import_module from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, NamedTuple, cast +from typing import TYPE_CHECKING, NamedTuple import attrs from setproctitle import setproctitle from sqlalchemy import delete, select, update from tabulate import tabulate +from uuid6 import uuid7 import airflow.models from airflow.callbacks.callback_requests import CallbackRequest from airflow.configuration import conf -from airflow.dag_processing.processor import DagFileProcessorProcess +from airflow.dag_processing.collection import update_dag_parsing_results_in_db +from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess from airflow.models.dag import DagModel from airflow.models.dagbag import DagPriorityParsingRequest from airflow.models.dagwarning import DagWarning @@ -53,9 +56,8 @@ from airflow.models.errors import ParseImportError from airflow.secrets.cache import SecretCache from airflow.stats import Stats -from airflow.traces.tracer import Trace, add_span +from airflow.traces.tracer import Trace from airflow.utils import timezone -from airflow.utils.dates import datetime_to_nano from airflow.utils.file import list_py_file_paths, might_contain_dag from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.mixins import MultiprocessingStartMethodMixin @@ -94,11 +96,7 @@ class DagFileStat: last_num_of_db_queries: int = 0 -class DagParsingSignal(enum.Enum): - """All signals sent to parser.""" - - TERMINATE_MANAGER = "terminate_manager" - END_MANAGER = "end_manager" +log = logging.getLogger("airflow.processor_manager") class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): @@ -109,7 +107,7 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): Mainly it can spin up DagFileProcessorManager in a subprocess, collect DAG parsing results from it and communicate signal/DAG parsing stat with it. - This class runs in the main `airflow scheduler` process. + This class runs in the main `airflow scheduler` process when standalone_dag_processor is not enabled. :param dag_directory: Directory where DAG definitions are kept. All files in file_paths should be under this directory @@ -128,14 +126,12 @@ def __init__( self._dag_directory: os.PathLike = dag_directory self._max_runs = max_runs self._processor_timeout = processor_timeout - # Map from file path to the processor - self._processors: dict[str, DagFileProcessorProcess] = {} - # Pipe for communicating signals - self._process: multiprocessing.process.BaseProcess | None = None + self._process: multiprocessing.Process | None = None self._done: bool = False # Initialized as true so we do not deactivate w/o any actual DAG parsing. self._all_files_processed = True + # Pipe for communicating signals self._parent_signal_conn: MultiprocessingConnection | None = None self._last_parsing_stat_received_at: float = time.monotonic() @@ -145,7 +141,7 @@ def start(self) -> None: context = self._get_multiprocessing_context() self._last_parsing_stat_received_at = time.monotonic() - self._parent_signal_conn, child_signal_conn = context.Pipe() + parent_signal_conn, child_signal_conn = context.Pipe() process = context.Process( target=type(self)._run_processor_manager, args=( @@ -155,9 +151,14 @@ def start(self) -> None: child_signal_conn, ), ) + self._process = process + self._parent_signal_conn = parent_signal_conn + process.start() + # We don't want this end anymore + child_signal_conn.close() self.log.info("Launched DagFileProcessorManager with pid: %s", process.pid) @@ -177,18 +178,17 @@ def _run_processor_manager( # Make this process start as a new process group - that makes it easy # to kill all sub-process of this at the OS-level, rather than having # to iterate the child processes + set_new_process_group() - span = Trace.get_current_span() - span.set_attribute("dag_directory", str(dag_directory)) setproctitle("airflow scheduler -- DagFileProcessorManager") reload_configuration_for_dag_processing() processor_manager = DagFileProcessorManager( dag_directory=dag_directory, max_runs=max_runs, - processor_timeout=processor_timeout, + processor_timeout=processor_timeout.total_seconds(), signal_conn=signal_conn, ) - processor_manager.start() + processor_manager.run() def heartbeat(self) -> None: """Check if the DagFileProcessorManager process is alive, and process any pending messages.""" @@ -206,18 +206,14 @@ def heartbeat(self) -> None: self._heartbeat_manager() def _process_message(self, message): - span = Trace.get_current_span() self.log.debug("Received message of type %s", type(message).__name__) if isinstance(message, DagParsingStat): - span.set_attribute("all_files_processed", str(message.all_files_processed)) self._sync_metadata(message) else: raise RuntimeError(f"Unexpected message received of type {type(message).__name__}") def _heartbeat_manager(self): """Heartbeat DAG file processor and restart it if we are not done.""" - if not self._parent_signal_conn: - raise ValueError("Process not started.") if self._process and not self._process.is_alive(): self._process.join(timeout=0) if not self.done: @@ -262,10 +258,8 @@ def terminate(self): """Send termination signal to DAG parsing processor manager to terminate all DAG file processors.""" if self._process and self._process.is_alive(): self.log.info("Sending termination message to manager.") - try: - self._parent_signal_conn.send(DagParsingSignal.TERMINATE_MANAGER) - except ConnectionError: - pass + self._parent_signal_conn.send(None) + self._parent_signal_conn.close() def end(self): """Terminate (and then kill) the manager process launched.""" @@ -280,7 +274,16 @@ def end(self): self._parent_signal_conn.close() -class DagFileProcessorManager(LoggingMixin): +def _config_int_factory(section: str, key: str): + return functools.partial(conf.getint, section, key) + + +def _config_bool_factory(section: str, key: str): + return functools.partial(conf.getboolean, section, key) + + +@attrs.define +class DagFileProcessorManager: """ Manage processes responsible for parsing DAGs. @@ -298,74 +301,61 @@ class DagFileProcessorManager(LoggingMixin): :param signal_conn: connection to communicate signal with processor agent. """ - def __init__( - self, - dag_directory: os.PathLike[str], - max_runs: int, - processor_timeout: timedelta, - signal_conn: MultiprocessingConnection | None = None, - ): - super().__init__() - # known files; this will be updated every `dag_dir_list_interval` and stuff added/removed accordingly - self._file_paths: list[str] = [] - self._file_path_queue: deque[str] = deque() - self._max_runs = max_runs - # signal_conn is None for dag_processor_standalone mode. - self._direct_scheduler_conn = signal_conn - self._parsing_start_time: float | None = None - self._dag_directory = dag_directory - # Set the signal conn in to non-blocking mode, so that attempting to - # send when the buffer is full errors, rather than hangs for-ever - # attempting to send (this is to avoid deadlocks!) - if self._direct_scheduler_conn: + _dag_directory: os.PathLike[str] + max_runs: int + processor_timeout: float = attrs.field(factory=_config_int_factory("core", "dag_file_processor_timeout")) + selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector) + _direct_scheduler_conn: MultiprocessingConnection | None = attrs.field(alias="signal_conn", default=None) + + _parallelism: int = attrs.field(factory=_config_int_factory("scheduler", "parsing_processes")) + + dag_dir_list_interval: int = attrs.field( + factory=_config_int_factory("scheduler", "dag_dir_list_interval") + ) + parsing_cleanup_interval: float = attrs.field( + factory=_config_int_factory("scheduler", "parsing_cleanup_interval") + ) + _file_process_interval: float = attrs.field( + factory=_config_int_factory("scheduler", "min_file_process_interval") + ) + stale_dag_threshold: float = attrs.field(factory=_config_int_factory("scheduler", "stale_dag_threshold")) + last_dag_dir_refresh_time: float = attrs.field(default=0, init=False) + + log: logging.Logger = log + + _last_deactivate_stale_dags_time: float = attrs.field(default=0, init=False) + print_stats_interval: float = attrs.field( + factory=_config_int_factory("scheduler", "print_stats_interval") + ) + last_stat_print_time: float = attrs.field(default=0, init=False) + + heartbeat: Callable[[], None] = attrs.field(default=lambda: None) + """An overridable heartbeat called once every time around the loop""" + + _file_paths: list[str] = attrs.field(factory=list, init=False) + _file_path_queue: deque[str] = attrs.field(factory=deque, init=False) + _file_stats: dict[str, DagFileStat] = attrs.field(factory=lambda: defaultdict(DagFileStat), init=False) + + _processors: dict[str, DagFileProcessorProcess] = attrs.field(factory=dict, init=False) + + _parsing_start_time: float = attrs.field(init=False) + _num_run: int = attrs.field(default=0, init=False) + + _callback_to_execute: dict[str, list[CallbackRequest]] = attrs.field( + factory=lambda: defaultdict(list), init=False + ) + + standalone_dag_processor: bool = attrs.field( + factory=_config_bool_factory("scheduler", "standalone_dag_processor") + ) + max_callbacks_per_loop: int = attrs.field( + factory=_config_int_factory("scheduler", "max_callbacks_per_loop") + ) + + def __attrs_post_init__(self): + if self._direct_scheduler_conn is not None: os.set_blocking(self._direct_scheduler_conn.fileno(), False) - self.standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor") - self._parallelism = conf.getint("scheduler", "parsing_processes") - - # Parse and schedule each file no faster than this interval. - self._file_process_interval = conf.getint("scheduler", "min_file_process_interval") - # How often to print out DAG file processing stats to the log. Default to - # 30 seconds. - self.print_stats_interval = conf.getint("scheduler", "print_stats_interval") - - # Map from file path to the processor - self._processors: dict[str, DagFileProcessorProcess] = {} - - self._num_run = 0 - - # Map from file path to stats about the file - self._file_stats: MutableMapping[str, DagFileStat] = defaultdict(DagFileStat) - - # Last time that the DAG dir was traversed to look for files - self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0)) - # Last time stats were printed - self.last_stat_print_time = 0 - # Last time we cleaned up DAGs which are no longer in files - self.last_deactivate_stale_dags_time = timezone.make_aware(datetime.fromtimestamp(0)) - # How often to check for DAGs which are no longer in files - self.parsing_cleanup_interval = conf.getint("scheduler", "parsing_cleanup_interval") - # How long to wait for a DAG to be reparsed after its file has been parsed before disabling - self.stale_dag_threshold = conf.getint("scheduler", "stale_dag_threshold") - # How long to wait before timing out a process to parse a DAG file - self._processor_timeout = processor_timeout - # How often to scan the DAGs directory for new files. Default to 5 minutes. - self.dag_dir_list_interval = conf.getint("scheduler", "dag_dir_list_interval") - - # Mapping file name and callbacks requests - self._callback_to_execute: dict[str, list[CallbackRequest]] = defaultdict(list) - - self._log = logging.getLogger("airflow.processor_manager") - - self.waitables: dict[Any, MultiprocessingConnection | DagFileProcessorProcess] = ( - { - self._direct_scheduler_conn: self._direct_scheduler_conn, - } - if self._direct_scheduler_conn is not None - else {} - ) - self.heartbeat: Callable[[], None] = lambda: None - def register_exit_signals(self): """Register signals that stop child processes.""" signal.signal(signal.SIGINT, self._exit_gracefully) @@ -382,7 +372,7 @@ def _exit_gracefully(self, signum, frame): self.log.debug("Finished terminating DAG processors.") sys.exit(os.EX_OK) - def start(self): + def run(self): """ Use multiple processes to parse and generate tasks for the DAGs in parallel. @@ -391,8 +381,6 @@ def start(self): """ self.register_exit_signals() - set_new_process_group() - self.log.info("Processing files using up to %s processes at a time ", self._parallelism) self.log.info("Process each file at most once every %s seconds", self._file_process_interval) self.log.info( @@ -402,25 +390,24 @@ def start(self): return self._run_parsing_loop() def _scan_stale_dags(self): - """Scan at fix internal DAGs which are no longer present in files.""" - now = timezone.utcnow() - elapsed_time_since_refresh = (now - self.last_deactivate_stale_dags_time).total_seconds() + """Scan and deactivate DAGs which are no longer present in files.""" + now = time.monotonic() + elapsed_time_since_refresh = now - self._last_deactivate_stale_dags_time if elapsed_time_since_refresh > self.parsing_cleanup_interval: last_parsed = { fp: stat.last_finish_time for fp, stat in self._file_stats.items() if stat.last_finish_time } - DagFileProcessorManager.deactivate_stale_dags( + self.deactivate_stale_dags( last_parsed=last_parsed, dag_directory=self.get_dag_directory(), stale_dag_threshold=self.stale_dag_threshold, ) - self.last_deactivate_stale_dags_time = timezone.utcnow() + self._last_deactivate_stale_dags_time = time.monotonic() - @classmethod @provide_session def deactivate_stale_dags( - cls, - last_parsed: dict[str, datetime], + self, + last_parsed: dict[str, datetime | None], dag_directory: str, stale_dag_threshold: int, session: Session = NEW_SESSION, @@ -428,8 +415,7 @@ def deactivate_stale_dags( """Detect and deactivate DAGs which are no longer present in files.""" to_deactivate = set() query = select(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).where(DagModel.is_active) - standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor") - if standalone_dag_processor: + if self.standalone_dag_processor: query = query.where(DagModel.processor_subdir == dag_directory) dags_parsed = session.execute(query) @@ -442,7 +428,7 @@ def deactivate_stale_dags( dag.fileloc in last_parsed and (dag.last_parsed_time + timedelta(seconds=stale_dag_threshold)) < last_parsed[dag.fileloc] ): - cls.logger().info("DAG %s is missing and will be deactivated.", dag.dag_id) + self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id) to_deactivate.add(dag.dag_id) if to_deactivate: @@ -454,104 +440,69 @@ def deactivate_stale_dags( ) deactivated = deactivated_dagmodel.rowcount if deactivated: - cls.logger().info("Deactivated %i DAGs which are no longer present in file.", deactivated) + self.log.info("Deactivated %i DAGs which are no longer present in file.", deactivated) def _run_parsing_loop(self): - poll_time = 0.0 + # initialize cache to mutualize calls to Variable.get in DAGs + # needs to be done before this process is forked to create the DAG parsing processes. + SecretCache.init() - self._refresh_dag_dir() - self.prepare_file_path_queue() - max_callbacks_per_loop = conf.getint("scheduler", "max_callbacks_per_loop") + if self._direct_scheduler_conn is not None: + self.selector.register( + self._direct_scheduler_conn, selectors.EVENT_READ, self._read_from_direct_scheduler_conn + ) + + poll_time = 0.0 - self.start_new_processes() while True: - with Trace.start_span(span_name="dag_parsing_loop", component="DagFileProcessorManager") as span: - loop_start_time = time.monotonic() - ready = multiprocessing.connection.wait(self.waitables.keys(), timeout=poll_time) - if span.is_recording(): - span.add_event(name="heartbeat") - self.heartbeat() - if self._direct_scheduler_conn is not None and self._direct_scheduler_conn in ready: - agent_signal = self._direct_scheduler_conn.recv() - - self.log.debug("Received %s signal from DagFileProcessorAgent", agent_signal) - if agent_signal == DagParsingSignal.TERMINATE_MANAGER: - self.terminate() - break - elif agent_signal == DagParsingSignal.END_MANAGER: - self.end() - sys.exit(os.EX_OK) - elif isinstance(agent_signal, CallbackRequest): - self._add_callback_to_queue(agent_signal) - else: - raise ValueError(f"Invalid message {type(agent_signal)}") - - for sentinel in ready: - if sentinel is not self._direct_scheduler_conn: - processor = self.waitables.get(sentinel) - if processor: - self._collect_results_from_processor(processor) - self.waitables.pop(sentinel) - self._processors.pop(processor.file_path) - - if self.standalone_dag_processor: - for callback in DagFileProcessorManager._fetch_callbacks( - max_callbacks_per_loop, self.standalone_dag_processor, self.get_dag_directory() - ): - self._add_callback_to_queue(callback) - self._scan_stale_dags() - DagWarning.purge_inactive_dag_warnings() - refreshed_dag_dir = self._refresh_dag_dir() - - if span.is_recording(): - span.add_event(name="_kill_timed_out_processors") - self._kill_timed_out_processors() - - # Generate more file paths to process if we processed all the files already. Note for this - # to clear down, we must have cleared all files found from scanning the dags dir _and_ have - # cleared all files added as a result of callbacks - if not self._file_path_queue: - self.emit_metrics() - if span.is_recording(): - span.add_event(name="prepare_file_path_queue") - self.prepare_file_path_queue() + loop_start_time = time.monotonic() + + self.heartbeat() + + refreshed_dag_dir = self._refresh_dag_dir() + self._kill_timed_out_processors() + + if not self._file_path_queue: + # Generate more file paths to process if we processed all the files already. Note for this to + # clear down, we must have cleared all files found from scanning the dags dir _and_ have + # cleared all files added as a result of callbacks + self.prepare_file_path_queue() + self.emit_metrics() + elif refreshed_dag_dir: # if new files found in dag dir, add them - elif refreshed_dag_dir: - if span.is_recording(): - span.add_event(name="add_new_file_path_to_queue") - self.add_new_file_path_to_queue() + self.add_new_file_path_to_queue() + + self._refresh_requested_filelocs() - self._refresh_requested_filelocs() - if span.is_recording(): - span.add_event(name="start_new_processes") - self.start_new_processes() + self._start_new_processes() - # Update number of loop iteration. - self._num_run += 1 + self._service_processor_sockets(timeout=poll_time) - # Collect anything else that has finished, but don't kick off any more processors - if span.is_recording(): - span.add_event(name="collect_results") - self.collect_results() + self._collect_results() - if span.is_recording(): - span.add_event(name="print_stat") - self._print_stat() + if self.standalone_dag_processor: + for callback in self._fetch_callbacks(): + self._add_callback_to_queue(callback) + self._scan_stale_dags() + DagWarning.purge_inactive_dag_warnings() + # Update number of loop iteration. + self._num_run += 1 + + self._print_stat() + + if self._direct_scheduler_conn: all_files_processed = all( - self._file_stats[x].last_finish_time is not None for x in self.file_paths + self._file_stats[x].last_finish_time is not None for x in self._file_paths ) - max_runs_reached = self.max_runs_reached() - try: - if self._direct_scheduler_conn: - self._direct_scheduler_conn.send( - DagParsingStat( - max_runs_reached, - all_files_processed, - ) + self._direct_scheduler_conn.send( + DagParsingStat( + self.max_runs_reached(), + all_files_processed, ) + ) except BlockingIOError: # Try again next time around the loop! @@ -560,46 +511,84 @@ def _run_parsing_loop(self): # only drives "max runs") self.log.debug("BlockingIOError received trying to send DagParsingStat, ignoring") - if max_runs_reached: - self.log.info( - "Exiting dag parsing loop as all files have been processed %s times", self._max_runs - ) - if span.is_recording(): - span.add_event( - name="info", - attributes={ - "message": "Exiting dag parsing loop as all files have been processed {self._max_runs} times" - }, - ) - break + if self.max_runs_reached(): + self.log.info( + "Exiting dag parsing loop as all files have been processed %s times", self.max_runs + ) + break - loop_duration = time.monotonic() - loop_start_time - if loop_duration < 1: - poll_time = 1 - loop_duration - else: - poll_time = 0.0 + loop_duration = time.monotonic() - loop_start_time + if loop_duration < 1: + poll_time = 1 - loop_duration + else: + poll_time = 0.0 + + def _service_processor_sockets(self, timeout: float | None = 1.0): + """ + Service subprocess events by polling sockets for activity. + + This runs `select` (or a platform equivalent) to look for activity on the sockets connected to the + parsing subprocesses, and calls the registered handler function for each socket. + + All the parsing processes socket handlers are registered into a single Selector + """ + events = self.selector.select(timeout=timeout) + for key, _ in events: + socket_handler = key.data + need_more = socket_handler(key.fileobj) + + if not need_more: + self.selector.unregister(key.fileobj) + key.fileobj.close() # type: ignore[union-attr] + + def _read_from_direct_scheduler_conn(self, conn: MultiprocessingConnection) -> bool: + try: + agent_signal = conn.recv() + except (EOFError, ConnectionError): + self.terminate() + sys.exit(os.EX_OK) + + self.log.debug("Received %s signal from DagFileProcessorAgent", agent_signal) + if isinstance(agent_signal, CallbackRequest): + self._add_callback_to_queue(agent_signal) + elif agent_signal is None: + self.terminate() + sys.exit(os.EX_OK) + else: + raise ValueError(f"Invalid message {type(agent_signal)}") + return True + + def _refresh_requested_filelocs(self) -> None: + """Refresh filepaths from dag dir as requested by users via APIs.""" + # Get values from DB table + filelocs = self._get_priority_filelocs() + for fileloc in filelocs: + # Try removing the fileloc if already present + try: + self._file_path_queue.remove(fileloc) + except ValueError: + pass + # enqueue fileloc to the start of the queue. + self._file_path_queue.appendleft(fileloc) - @classmethod @provide_session @retry_db_transaction def _fetch_callbacks( - cls, - max_callbacks: int, - standalone_dag_processor: bool, - dag_directory: str, + self, session: Session = NEW_SESSION, ) -> list[CallbackRequest]: """Fetch callbacks from database and add them to the internal queue for execution.""" - cls.logger().debug("Fetching callbacks from the database.") + self.log.debug("Fetching callbacks from the database.") callback_queue: list[CallbackRequest] = [] + dag_directory = self.get_dag_directory() with prohibit_commit(session) as guard: query = select(DbCallbackRequest) - if standalone_dag_processor: + if self.standalone_dag_processor: query = query.where( DbCallbackRequest.processor_subdir == dag_directory, ) - query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks) + query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(self.max_callbacks_per_loop) query = with_row_locks(query, of=DbCallbackRequest, session=session, skip_locked=True) callbacks = session.scalars(query) for callback in callbacks: @@ -607,7 +596,7 @@ def _fetch_callbacks( callback_queue.append(callback.get_callback_request()) session.delete(callback) except Exception as e: - cls.logger().warning("Error adding callback for execution: %s, %s", callback, e) + self.log.warning("Error adding callback for execution: %s, %s", callback, e) guard.commit() return callback_queue @@ -624,19 +613,6 @@ def _add_callback_to_queue(self, request: CallbackRequest): self._add_paths_to_queue([request.full_filepath], True) Stats.incr("dag_processing.other_callback_count") - def _refresh_requested_filelocs(self) -> None: - """Refresh filepaths from dag dir as requested by users via APIs.""" - # Get values from DB table - filelocs = DagFileProcessorManager._get_priority_filelocs() - for fileloc in filelocs: - # Try removing the fileloc if already present - try: - self._file_path_queue.remove(fileloc) - except ValueError: - pass - # enqueue fileloc to the start of the queue. - self._file_path_queue.appendleft(fileloc) - @classmethod @provide_session def _get_priority_filelocs(cls, session: Session = NEW_SESSION): @@ -650,47 +626,49 @@ def _get_priority_filelocs(cls, session: Session = NEW_SESSION): def _refresh_dag_dir(self) -> bool: """Refresh file paths from dag dir if we haven't done it for too long.""" - now = timezone.utcnow() - elapsed_time_since_refresh = (now - self.last_dag_dir_refresh_time).total_seconds() - if elapsed_time_since_refresh > self.dag_dir_list_interval: - # Build up a list of Python files that could contain DAGs - self.log.info("Searching for files in %s", self._dag_directory) - self._file_paths = list_py_file_paths(self._dag_directory) - self.last_dag_dir_refresh_time = now - self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory) - self.set_file_paths(self._file_paths) + now = time.monotonic() + elapsed_time_since_refresh = now - self.last_dag_dir_refresh_time + if elapsed_time_since_refresh <= self.dag_dir_list_interval: + return False + # Build up a list of Python files that could contain DAGs + self.log.info("Searching for files in %s", self._dag_directory) + self._file_paths = list_py_file_paths(self._dag_directory) + self.last_dag_dir_refresh_time = now + self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory) + self.set_file_paths(self._file_paths) + + try: + self.log.debug("Removing old import errors") + self.clear_nonexistent_import_errors() + except Exception: + self.log.exception("Error removing old import errors") + + def _iter_dag_filelocs(fileloc: str) -> Iterator[str]: + """ + Get "full" paths to DAGs if inside ZIP files. + + This is the format used by the remove/delete functions. + """ + if fileloc.endswith(".py") or not zipfile.is_zipfile(fileloc): + yield fileloc + return try: - self.clear_nonexistent_import_errors() - except Exception: - self.log.exception("Error removing old import errors") - - def _iter_dag_filelocs(fileloc: str) -> Iterator[str]: - """ - Get "full" paths to DAGs if inside ZIP files. - - This is the format used by the remove/delete functions. - """ - if fileloc.endswith(".py") or not zipfile.is_zipfile(fileloc): - yield fileloc - return - try: - with zipfile.ZipFile(fileloc) as z: - for info in z.infolist(): - if might_contain_dag(info.filename, True, z): - yield os.path.join(fileloc, info.filename) - except zipfile.BadZipFile: - self.log.exception("There was an error accessing ZIP file %s %s", fileloc) - - dag_filelocs = {full_loc for path in self._file_paths for full_loc in _iter_dag_filelocs(path)} - - DagModel.deactivate_deleted_dags( - dag_filelocs, - processor_subdir=self.get_dag_directory(), - ) + with zipfile.ZipFile(fileloc) as z: + for info in z.infolist(): + if might_contain_dag(info.filename, True, z): + yield os.path.join(fileloc, info.filename) + except zipfile.BadZipFile: + self.log.exception("There was an error accessing ZIP file %s %s", fileloc) + + dag_filelocs = {full_loc for path in self._file_paths for full_loc in _iter_dag_filelocs(path)} + + DagModel.deactivate_deleted_dags( + dag_filelocs, + processor_subdir=self.get_dag_directory(), + ) - return True - return False + return True def _print_stat(self): """Occasionally print out stats about how fast the files are getting processed.""" @@ -711,7 +689,10 @@ def clear_nonexistent_import_errors(self, session=NEW_SESSION): query = delete(ParseImportError).where(ParseImportError.processor_subdir == self.get_dag_directory()) if self._file_paths: - query = query.where(ParseImportError.filename.notin_(self._file_paths)) + query = query.where( + ParseImportError.filename.notin_(self._file_paths), + ParseImportError.processor_subdir == self.get_dag_directory(), + ) session.execute(query.execution_options(synchronize_session="fetch")) session.commit() @@ -737,38 +718,39 @@ def _log_file_processing_stats(self, known_file_paths): headers = [ "File Path", "PID", - "Runtime", + "Current Duration", "# DAGs", "# Errors", - "Last Runtime", - "Last Run", - "Last # of DB Queries", + "Last Duration", + "Last Run At", ] rows = [] - now = timezone.utcnow() + utcnow = timezone.utcnow() + now = time.monotonic() for file_path in known_file_paths: stat = self._file_stats[file_path] + proc = self._processors.get(file_path) + num_dags = stat.num_dags + num_errors = stat.import_errors file_name = Path(file_path).stem - processor_pid = self.get_pid(file_path) - processor_start_time = self.get_start_time(file_path) + processor_pid = proc.pid if proc else None + processor_start_time = proc.start_time if proc else None runtime = (now - processor_start_time) if processor_start_time else None last_run = stat.last_finish_time if last_run: - seconds_ago = (now - last_run).total_seconds() + seconds_ago = (utcnow - last_run).total_seconds() Stats.gauge(f"dag_processing.last_run.seconds_ago.{file_name}", seconds_ago) - Stats.gauge(f"dag_processing.last_num_of_db_queries.{file_name}", stat.last_num_of_db_queries) rows.append( ( file_path, processor_pid, runtime, - stat.num_dags, - stat.import_errors, + num_dags, + num_errors, stat.last_duration, last_run, - stat.last_num_of_db_queries, ) ) @@ -784,18 +766,16 @@ def _log_file_processing_stats(self, known_file_paths): num_errors, last_runtime, last_run, - last_num_of_db_queries, ) in rows: formatted_rows.append( ( file_path, pid, - f"{runtime.total_seconds():.2f}s" if runtime else None, + f"{runtime:.2f}s" if runtime else None, num_dags, num_errors, f"{last_runtime:.2f}s" if last_runtime else None, last_run.strftime("%Y-%m-%dT%H:%M:%S") if last_run else None, - last_num_of_db_queries, ) ) log_str = ( @@ -810,36 +790,6 @@ def _log_file_processing_stats(self, known_file_paths): self.log.info(log_str) - def get_pid(self, file_path) -> int | None: - """ - Retrieve the PID of the process processing the given file or None if the file is not being processed. - - :param file_path: the path to the file that's being processed. - """ - if file_path in self._processors: - return self._processors[file_path].pid - return None - - def get_all_pids(self) -> list[int]: - """ - Get all pids. - - :return: a list of the PIDs for the processors that are running - """ - return [x.pid for x in self._processors.values()] - - def get_start_time(self, file_path) -> datetime | None: - """ - Retrieve the last start time for processing a specific path. - - :param file_path: the path to the file that's being processed - :return: the start time of the process that's processing the - specified file or None if the file is not currently being processed. - """ - if file_path in self._processors: - return self._processors[file_path].start_time - return None - def get_dag_directory(self) -> str: """Return the dag_director as a string.""" if isinstance(self._dag_directory, Path): @@ -872,7 +822,7 @@ def set_file_paths(self, new_file_paths): else: self.log.warning("Stopping processor for %s", file_path) Stats.decr("dag_processing.processes", tags={"file_path": file_path, "action": "stop"}) - processor.terminate() + processor.kill(signal.SIGKILL) self._file_stats.pop(file_path) to_remove = set(self._file_stats).difference(self._file_paths) @@ -882,159 +832,62 @@ def set_file_paths(self, new_file_paths): self._processors = filtered_processors - def wait_until_finished(self): - """Sleeps until all the processors are done.""" - for processor in self._processors.values(): - while not processor.done: - time.sleep(0.1) - @provide_session - def _collect_results_from_processor(self, processor, session: Session = NEW_SESSION) -> None: - self.log.debug("Processor for %s finished", processor.file_path) - Stats.decr("dag_processing.processes", tags={"file_path": processor.file_path, "action": "finish"}) - last_finish_time = timezone.utcnow() - - if processor.result is not None: - num_dags, count_import_errors, last_num_of_db_queries = processor.result - else: - self.log.error( - "Processor for %s exited with return code %s.", processor.file_path, processor.exit_code - ) - count_import_errors = -1 - num_dags = 0 - last_num_of_db_queries = 0 - - last_duration = (last_finish_time - processor.start_time).total_seconds() - stat = DagFileStat( - num_dags=num_dags, - import_errors=count_import_errors, - last_finish_time=last_finish_time, - last_duration=last_duration, - run_count=self._file_stats[processor.file_path].run_count + 1, - last_num_of_db_queries=last_num_of_db_queries, - ) - self._file_stats[processor.file_path] = stat - file_name = Path(processor.file_path).stem - """crude exposure of instrumentation code which may need to be furnished""" - span = Trace.get_tracer("DagFileProcessorManager").start_span( - "dag_processing", start_time=datetime_to_nano(processor.start_time) - ) - span.set_attributes( - { - "file_path": processor.file_path, - "run_count": stat.run_count, - } - ) - - if processor.result is None: - span.set_attributes( - { - "error": True, - "processor.exit_code": processor.exit_code, - } - ) - else: - span.set_attributes( - { - "num_dags": num_dags, - "import_errors": count_import_errors, - } + def _collect_results(self, session: Session = NEW_SESSION): + # TODO: Use an explicit session in this fn + finished = [] + for path, proc in self._processors.items(): + if proc.exit_code is None: + # This processor hasn't finished yet + continue + finished.append(path) + + # Collect the DAGS and import errors into the DB, emit metrics etc. + self._file_stats[path] = process_parse_results( + run_duration=time.time() - proc.start_time, + finish_time=timezone.utcnow(), + run_count=self._file_stats[path].run_count, + parsing_result=proc.parsing_result, + path=path, + processor_subdir=self.get_dag_directory(), + session=session, ) - if count_import_errors > 0: - span.set_attribute("error", True) - import_errors = session.scalars( - select(ParseImportError).where(ParseImportError.filename == processor.file_path) - ).all() - for import_error in import_errors: - span.add_event( - name="exception", - attributes={ - "filename": import_error.filename, - "exception.type": "ParseImportError", - "exception.name": "Import error when processing DAG file", - "exception.stacktrace": import_error.stacktrace, - }, - ) - - span.end(end_time=datetime_to_nano(last_finish_time)) - - Stats.timing(f"dag_processing.last_duration.{file_name}", last_duration * 1000.0) - Stats.timing("dag_processing.last_duration", last_duration * 1000.0, tags={"file_name": file_name}) - - def collect_results(self) -> None: - """Collect the result from any finished DAG processors.""" - ready = multiprocessing.connection.wait( - self.waitables.keys() - [self._direct_scheduler_conn], timeout=0 - ) - for sentinel in ready: - if sentinel is not self._direct_scheduler_conn: - processor = cast(DagFileProcessorProcess, self.waitables[sentinel]) - self.waitables.pop(processor.waitable_handle) - self._processors.pop(processor.file_path) - self._collect_results_from_processor(processor) + for path in finished: + self._processors.pop(path) - self.log.debug("%s/%s DAG parsing processes running", len(self._processors), self._parallelism) + def _create_process(self, file_path): + id = uuid7() - self.log.debug("%s file paths queued for processing", len(self._file_path_queue)) + callback_to_execute_for_file = self._callback_to_execute.pop(file_path, []) - @staticmethod - def _create_process(file_path, dag_directory, callback_requests): - """Create DagFileProcessorProcess instance.""" - return DagFileProcessorProcess( - file_path=file_path, - dag_directory=dag_directory, - callback_requests=callback_requests, + return DagFileProcessorProcess.start( + id=id, + path=file_path, + callbacks=callback_to_execute_for_file, + selector=self.selector, ) - @add_span - def start_new_processes(self): + def _start_new_processes(self): """Start more processors if we have enough slots and files to process.""" - # initialize cache to mutualize calls to Variable.get in DAGs - # needs to be done before this process is forked to create the DAG parsing processes. - SecretCache.init() - while self._parallelism > len(self._processors) and self._file_path_queue: file_path = self._file_path_queue.popleft() # Stop creating duplicate processor i.e. processor with the same filepath if file_path in self._processors: continue - callback_to_execute_for_file = self._callback_to_execute[file_path] - processor = self._create_process( - file_path, - self.get_dag_directory(), - callback_to_execute_for_file, - ) - - del self._callback_to_execute[file_path] + processor = self._create_process(file_path) Stats.incr("dag_processing.processes", tags={"file_path": file_path, "action": "start"}) - span = Trace.get_current_span() - span.set_attribute("category", "processing") - processor.start() - self.log.debug("Started a process (PID: %s) to generate tasks for %s", processor.pid, file_path) - if span.is_recording(): - span.add_event( - name="dag_processing processor started", - attributes={"file_path": file_path, "pid": processor.pid}, - ) - self._processors[file_path] = processor - self.waitables[processor.waitable_handle] = processor + self._processors[file_path] = processor Stats.gauge("dag_processing.file_path_queue_size", len(self._file_path_queue)) - @add_span def add_new_file_path_to_queue(self): - for file_path in self.file_paths: + for file_path in self._file_paths: if file_path not in self._file_stats: # We found new file after refreshing dir. add to parsing queue at start self.log.info("Adding new file %s to parsing queue", file_path) self._file_path_queue.appendleft(file_path) - span = Trace.get_current_span() - if span.is_recording(): - span.add_event( - name="adding new file to parsing queue", attributes={"file_path": file_path} - ) def prepare_file_path_queue(self): """ @@ -1051,7 +904,7 @@ def prepare_file_path_queue(self): # Sort the file paths by the parsing order mode list_mode = conf.get("scheduler", "file_parsing_sort_mode") - files_with_mtime = {} + files_with_mtime: dict[str, datetime] = {} file_paths = [] is_mtime_mode = list_mode == "modified_time" @@ -1075,9 +928,8 @@ def prepare_file_path_queue(self): # from being added to file_path_queue # unless they were modified recently and parsing mode is "modified_time" # in which case we don't honor "self._file_process_interval" (min_file_process_interval) - last_finish_time = self._file_stats[file_path].last_finish_time if ( - last_finish_time is not None + (last_finish_time := self._file_stats[file_path].last_finish_time) is not None and (now - last_finish_time).total_seconds() < self._file_process_interval and not (is_mtime_mode and file_modified_time and (file_modified_time > last_finish_time)) ): @@ -1099,7 +951,7 @@ def prepare_file_path_queue(self): ) files_paths_at_run_limit = [ - file_path for file_path, stat in self._file_stats.items() if stat.run_count == self._max_runs + file_path for file_path, stat in self._file_stats.items() if stat.run_count == self.max_runs ] file_paths_to_exclude = file_paths_in_progress.union( @@ -1114,59 +966,45 @@ def prepare_file_path_queue(self): ] if self.log.isEnabledFor(logging.DEBUG): - for processor in self._processors.values(): + for path, processor in self._processors.items(): self.log.debug( - "File path %s is still being processed (started: %s)", - processor.file_path, - processor.start_time.isoformat(), + "File path %s is still being processed (started: %s)", path, processor.start_time ) self.log.debug( "Queuing the following files for processing:\n\t%s", "\n\t".join(files_paths_to_queue) ) - self._add_paths_to_queue(files_paths_to_queue, False) Stats.incr("dag_processing.file_path_queue_update_count") def _kill_timed_out_processors(self): """Kill any file processors that timeout to defend against process hangs.""" - now = timezone.utcnow() + now = time.time() processors_to_remove = [] for file_path, processor in self._processors.items(): duration = now - processor.start_time - if duration > self._processor_timeout: + if duration > self.processor_timeout: self.log.error( - "Processor for %s with PID %s started at %s has timed out, killing it.", + "Processor for %s with PID %s started %d ago killing it.", file_path, processor.pid, - processor.start_time.isoformat(), + duration, ) Stats.decr("dag_processing.processes", tags={"file_path": file_path, "action": "timeout"}) Stats.incr("dag_processing.processor_timeouts", tags={"file_path": file_path}) - # Deprecated; may be removed in a future Airflow release. - Stats.incr("dag_file_processor_timeouts") - processor.kill() - span = Trace.get_current_span() - span.set_attribute("category", "processing") - if span.is_recording(): - span.add_event( - name="dag processing killed processor", - attributes={"file_path": file_path, "action": "timeout"}, - ) + processor.kill(signal.SIGKILL) - # Clean up processor references - self.waitables.pop(processor.waitable_handle) processors_to_remove.append(file_path) stat = DagFileStat( num_dags=0, import_errors=1, - last_finish_time=now, - last_duration=duration.total_seconds(), - run_count=self._file_stats[processor.file_path].run_count + 1, + last_finish_time=timezone.utcnow(), + last_duration=duration, + run_count=self._file_stats[file_path].run_count + 1, last_num_of_db_queries=0, ) - self._file_stats[processor.file_path] = stat + self._file_stats[file_path] = stat # Clean up `self._processors` after iterating over it for proc in processors_to_remove: @@ -1183,26 +1021,22 @@ def _add_paths_to_queue(self, file_paths_to_enqueue: list[str], add_at_front: bo def max_runs_reached(self): """:return: whether all file paths have been processed max_runs times.""" - if self._max_runs == -1: # Unlimited runs. + if self.max_runs == -1: # Unlimited runs. return False - for stat in self._file_stats.values(): - if stat.run_count < self._max_runs: - return False - if self._num_run < self._max_runs: + if self._num_run < self.max_runs: return False - return True + return all(stat.run_count >= self.max_runs for stat in self._file_stats.values()) def terminate(self): """Stop all running processors.""" - for processor in self._processors.values(): - Stats.decr( - "dag_processing.processes", tags={"file_path": processor.file_path, "action": "terminate"} - ) - processor.terminate() + for file_path, processor in self._processors.items(): + Stats.decr("dag_processing.processes", tags={"file_path": file_path, "action": "terminate"}) + # SIGTERM, wait 5s, SIGKILL if still alive + processor.kill(signal.SIGTERM, escalation_delay=5.0) def end(self): """Kill all child processes on exit since we don't want to leave them as orphaned.""" - pids_to_kill = self.get_all_pids() + pids_to_kill = [p.pid for p in self._processors.values()] if pids_to_kill: kill_child_processes_by_pids(pids_to_kill) @@ -1228,10 +1062,6 @@ def emit_metrics(self): } ) - @property - def file_paths(self): - return self._file_paths - def reload_configuration_for_dag_processing(): # Reload configurations and settings to avoid collision with parent process. @@ -1254,3 +1084,40 @@ def reload_configuration_for_dag_processing(): importlib.reload(airflow.settings) airflow.settings.initialize() del os.environ["CONFIG_PROCESSOR_MANAGER_LOGGER"] + + +def process_parse_results( + run_duration: float, + finish_time: datetime, + run_count: int, + path: str, + parsing_result: DagFileParsingResult | None, + processor_subdir: str | None, + session: Session, +) -> DagFileStat: + """Take the parsing result and stats about the parser process and convert it into a DagFileState.""" + stat = DagFileStat( + last_finish_time=finish_time, + last_duration=run_duration, + run_count=run_count + 1, + ) + + file_name = Path(path).stem + Stats.timing(f"dag_processing.last_duration.{file_name}", stat.last_duration) + Stats.timing("dag_processing.last_duration", stat.last_duration, tags={"file_name": file_name}) + + if parsing_result is None: + stat.import_errors = 1 + else: + # record DAGs and import errors to database + update_dag_parsing_results_in_db( + dags=parsing_result.serialized_dags, + import_errors=parsing_result.import_errors or {}, + warnings=set(parsing_result.warnings or []), + processor_subdir=processor_subdir, + session=session, + ) + stat.num_dags = len(parsing_result.serialized_dags) + if parsing_result.import_errors: + stat.import_errors = len(parsing_result.import_errors) + return stat diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index a2b4474402588..b5b5a4a0ba821 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -16,610 +16,228 @@ # under the License. from __future__ import annotations -import importlib -import logging import os -import signal -import threading -import time -import zipfile -from collections.abc import Generator, Iterable -from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress -from dataclasses import dataclass -from typing import TYPE_CHECKING - -from setproctitle import setproctitle -from sqlalchemy import event - -from airflow import settings +import sys +import traceback +from collections.abc import Generator +from typing import TYPE_CHECKING, Annotated, Callable, Literal, Union + +import attrs +import pydantic + from airflow.callbacks.callback_requests import ( + CallbackRequest, DagCallbackRequest, TaskCallbackRequest, ) from airflow.configuration import conf -from airflow.exceptions import AirflowException -from airflow.models.dag import DAG from airflow.models.dagbag import DagBag -from airflow.models.pool import Pool -from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import TaskInstance, _run_finished_callback +from airflow.sdk.execution_time.comms import GetConnection, GetVariable +from airflow.sdk.execution_time.supervisor import WatchedSubprocess +from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.stats import Stats -from airflow.utils import timezone -from airflow.utils.file import iter_airflow_imports, might_contain_dag -from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context -from airflow.utils.mixins import MultiprocessingStartMethodMixin -from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - import multiprocessing - from datetime import datetime - from multiprocessing.connection import Connection as MultiprocessingConnection - - from sqlalchemy.orm.session import Session - - from airflow.callbacks.callback_requests import CallbackRequest - from airflow.models.operator import Operator + from airflow.typing_compat import Self + from airflow.utils.context import Context + + +def _parse_file_entrypoint(): + import os + + import structlog + + from airflow.sdk.execution_time import task_runner + # Parse DAG file, send JSON back up! + + comms_decoder = task_runner.CommsDecoder[DagFileParseRequest, DagFileParsingResult]( + input=sys.stdin, + decoder=pydantic.TypeAdapter[DagFileParseRequest](DagFileParseRequest), + ) + msg = comms_decoder.get_message() + comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) + + log = structlog.get_logger(logger_name="task") + + result = _parse_file(msg, log) + comms_decoder.send_request(log, result) + + +def _parse_file(msg: DagFileParseRequest, log): + # TODO: Set known_pool names on DagBag! + bag = DagBag( + dag_folder=msg.file, + include_examples=False, + safe_mode=True, + load_op_links=False, + ) + serialized_dags, serialization_import_errors = _serialize_dags(bag, log) + bag.import_errors.update(serialization_import_errors) + dags = [LazyDeserializedDAG(data=serdag) for serdag in serialized_dags] + result = DagFileParsingResult( + fileloc=msg.file, + serialized_dags=dags, + import_errors=bag.import_errors, + # TODO: Make `bag.dag_warnings` not return SQLA model objects + warnings=[], + ) + + if msg.callback_requests: + _execute_callbacks(bag, msg.callback_requests, log) + return result + + +def _serialize_dags(bag, log): + serialization_import_errors = {} + serialized_dags = [] + for dag in bag.dags.values(): + try: + serialized_dag = SerializedDAG.to_dict(dag) + serialized_dags.append(serialized_dag) + except Exception: + log.exception("Failed to serialize DAG: %s", dag.fileloc) + dagbag_import_error_traceback_depth = conf.getint( + "core", "dagbag_import_error_traceback_depth", fallback=None + ) + serialization_import_errors[dag.fileloc] = traceback.format_exc( + limit=-dagbag_import_error_traceback_depth + ) + return serialized_dags, serialization_import_errors -@dataclass -class _QueryCounter: - queries_number: int = 0 +def _execute_callbacks(dagbag: DagBag, callback_requests: list[CallbackRequest], log): + for request in callback_requests: + log.debug("Processing Callback Request", request=request) + if isinstance(request, TaskCallbackRequest): + raise NotImplementedError("Haven't coded Task callback yet!") + # _execute_task_callbacks(dagbag, request) + elif isinstance(request, DagCallbackRequest): + _execute_dag_callbacks(dagbag, request, log) - def inc(self): - self.queries_number += 1 +def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log): + dag = dagbag.dags[request.dag_id] -@contextmanager -def count_queries(session: Session) -> Generator[_QueryCounter, None, None]: - # using list allows to read the updated counter from what context manager returns - counter: _QueryCounter = _QueryCounter() + callbacks = dag.on_failure_callback if request.is_failure_callback else dag.on_success_callback + if not callbacks: + log.warning("Callback requested, but dag didn't have any", dag_id=request.dag_id) + return - @event.listens_for(session, "do_orm_execute") - def _count_db_queries(orm_execute_state): - nonlocal counter - counter.inc() + callbacks = callbacks if isinstance(callbacks, list) else [callbacks] + # TODO:We need a proper context object! + context: Context = {} - yield counter - event.remove(session, "do_orm_execute", _count_db_queries) + for callback in callbacks: + log.info( + "Executing on_%s dag callback", + "failure" if request.is_failure_callback else "success", + fn=callback, + dag_id=request.dag_id, + ) + try: + callback(context) + except Exception: + log.exception("Callback failed", dag_id=request.dag_id) + Stats.incr("dag.callback_exceptions", tags={"dag_id": request.dag_id}) -class DagFileProcessorProcess(LoggingMixin, MultiprocessingStartMethodMixin): +class DagFileParseRequest(pydantic.BaseModel): """ - Runs DAG processing in a separate process using DagFileProcessor. + Request for DAG File Parsing. - :param file_path: a Python file containing Airflow DAG definitions - :param callback_requests: failure callback to execute + This is the request that the manager will send to the DAG parser with the dag file and + any other necessary metadata. """ - # Counter that increments every time an instance of this class is created - class_creation_counter = 0 - - def __init__( - self, - file_path: str, - dag_directory: str, - callback_requests: list[CallbackRequest], - ): - super().__init__() - self._file_path = file_path - self._dag_directory = dag_directory - self._callback_requests = callback_requests - - # The process that was launched to process the given DAG file. - self._process: multiprocessing.process.BaseProcess | None = None - # The result of DagFileProcessor.process_file(file_path). - self._result: tuple[int, int, int] | None = None - # Whether the process is done running. - self._done = False - # When the process started. - self._start_time: datetime | None = None - # This ID is used to uniquely name the process / thread that's launched - # by this processor instance - self._instance_id = DagFileProcessorProcess.class_creation_counter - - self._parent_channel: MultiprocessingConnection | None = None - DagFileProcessorProcess.class_creation_counter += 1 - - @property - def file_path(self) -> str: - return self._file_path - - @staticmethod - def _run_file_processor( - result_channel: MultiprocessingConnection, - parent_channel: MultiprocessingConnection, - file_path: str, - thread_name: str, - dag_directory: str, - callback_requests: list[CallbackRequest], - known_pools: set[str] | None = None, - ) -> None: - """ - Process the given file. - - :param result_channel: the connection to use for passing back the result - :param parent_channel: the parent end of the channel to close in the child - :param file_path: the file to process - :param thread_name: the name to use for the process that is launched - :param callback_requests: failure callback to execute - :return: the process that was launched - """ - # This helper runs in the newly created process - log: logging.Logger = logging.getLogger("airflow.processor") - - # Since we share all open FDs from the parent, we need to close the parent side of the pipe here in - # the child, else it won't get closed properly until we exit. - parent_channel.close() - del parent_channel - - set_context(log, file_path) - setproctitle(f"airflow scheduler - DagFileProcessor {file_path}") - - def _handle_dag_file_processing(): - # Re-configure the ORM engine as there are issues with multiple processes - settings.configure_orm() - - # Change the thread name to differentiate log lines. This is - # really a separate process, but changing the name of the - # process doesn't work, so changing the thread name instead. - threading.current_thread().name = thread_name - - log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) - dag_file_processor = DagFileProcessor(dag_directory=dag_directory, log=log) - result: tuple[int, int, int] = dag_file_processor.process_file( - file_path=file_path, - callback_requests=callback_requests, - known_pools=known_pools, - ) - result_channel.send(result) - - try: - DAG_PROCESSOR_LOG_TARGET = conf.get_mandatory_value("logging", "DAG_PROCESSOR_LOG_TARGET") - if DAG_PROCESSOR_LOG_TARGET == "stdout": - with Stats.timer() as timer: - _handle_dag_file_processing() - else: - # The following line ensures that stdout goes to the same destination as the logs. If stdout - # gets sent to logs and logs are sent to stdout, this leads to an infinite loop. This - # necessitates this conditional based on the value of DAG_PROCESSOR_LOG_TARGET. - with ( - redirect_stdout(StreamLogWriter(log, logging.INFO)), - redirect_stderr(StreamLogWriter(log, logging.WARNING)), - Stats.timer() as timer, - ): - _handle_dag_file_processing() - log.info("Processing %s took %.3f seconds", file_path, timer.duration) - except Exception: - # Log exceptions through the logging framework. - log.exception("Got an exception! Propagating...") - raise - finally: - # We re-initialized the ORM within this Process above so we need to - # tear it down manually here - settings.dispose_orm() - - result_channel.close() - - def start(self) -> None: - """Launch the process and start processing the DAG.""" - if conf.getboolean("scheduler", "parsing_pre_import_modules", fallback=True): - # Read the file to pre-import airflow modules used. - # This prevents them from being re-imported from zero in each "processing" process - # and saves CPU time and memory. - zip_file_paths = [] - if zipfile.is_zipfile(self.file_path): - try: - with zipfile.ZipFile(self.file_path) as z: - zip_file_paths.extend( - [ - os.path.join(self.file_path, info.filename) - for info in z.infolist() - if might_contain_dag(info.filename, True, z) - ] - ) - except zipfile.BadZipFile as err: - self.log.error("There was an err accessing %s, %s", self.file_path, err) - if zip_file_paths: - self.import_modules(zip_file_paths) - else: - self.import_modules(self.file_path) - - context = self._get_multiprocessing_context() - - pool_names = {p.pool for p in Pool.get_pools()} - - _parent_channel, _child_channel = context.Pipe(duplex=False) - process = context.Process( - target=type(self)._run_file_processor, - args=( - _child_channel, - _parent_channel, - self.file_path, - f"DagFileProcessor{self._instance_id}", - self._dag_directory, - self._callback_requests, - pool_names, - ), - name=f"DagFileProcessor{self._instance_id}-Process", - ) - self._process = process - self._start_time = timezone.utcnow() - process.start() - - # Close the child side of the pipe now the subprocess has started -- otherwise this would prevent it - # from closing in some cases - _child_channel.close() - del _child_channel - - # Don't store it on self until after we've started the child process - we don't want to keep it from - # getting GCd/closed - self._parent_channel = _parent_channel - - def kill(self) -> None: - """Kill the process launched to process the file, and ensure consistent state.""" - if self._process is None: - raise AirflowException("Tried to kill before starting!") - self._kill_process() - - def terminate(self, sigkill: bool = False) -> None: - """ - Terminate (and then kill) the process launched to process the file. - - :param sigkill: whether to issue a SIGKILL if SIGTERM doesn't work. - """ - if self._process is None or self._parent_channel is None: - raise AirflowException("Tried to call terminate before starting!") - - self._process.terminate() - # Arbitrarily wait 5s for the process to die - with suppress(TimeoutError): - self._process._popen.wait(5) # type: ignore - if sigkill: - self._kill_process() - self._parent_channel.close() - - def _kill_process(self) -> None: - if self._process is None: - raise AirflowException("Tried to kill process before starting!") - - if self._process.is_alive() and self._process.pid: - self.log.warning("Killing DAGFileProcessorProcess (PID=%d)", self._process.pid) - os.kill(self._process.pid, signal.SIGKILL) - - # Reap the spawned zombie. We active wait, because in Python 3.9 `waitpid` might lead to an - # exception, due to change in Python standard library and possibility of race condition - # see https://bugs.python.org/issue42558 - while self._process._popen.poll() is None: # type: ignore - time.sleep(0.001) - if self._parent_channel: - self._parent_channel.close() + file: str + requests_fd: int + callback_requests: list[CallbackRequest] = pydantic.Field(default_factory=list) + type: Literal["DagFileParseRequest"] = "DagFileParseRequest" - @property - def pid(self) -> int: - """PID of the process launched to process the given file.""" - if self._process is None or self._process.pid is None: - raise AirflowException("Tried to get PID before starting!") - return self._process.pid - @property - def exit_code(self) -> int | None: - """ - After the process is finished, this can be called to get the return code. +class DagFileParsingResult(pydantic.BaseModel): + """ + Result of DAG File Parsing. - :return: the exit code of the process - """ - if self._process is None: - raise AirflowException("Tried to get exit code before starting!") - if not self._done: - raise AirflowException("Tried to call retcode before process was finished!") - return self._process.exitcode + This is the result of a successful DAG parse, in this class, we gather all serialized DAGs, + import errors and warnings to send back to the scheduler to store in the DB. + """ - @property - def done(self) -> bool: - """ - Check if the process launched to process this file is done. + fileloc: str + serialized_dags: list[LazyDeserializedDAG] + warnings: list | None = None + import_errors: dict[str, str] | None = None + type: Literal["DagFileParsingResult"] = "DagFileParsingResult" - :return: whether the process is finished running - """ - if self._process is None or self._parent_channel is None: - raise AirflowException("Tried to see if it's done before starting!") - if self._done: - return True +ToParent = Annotated[ + Union[DagFileParsingResult, GetConnection, GetVariable], + pydantic.Field(discriminator="type"), +] - if self._parent_channel.poll(): - try: - self._result = self._parent_channel.recv() - self._done = True - self.log.debug("Waiting for %s", self._process) - self._process.join() - self._parent_channel.close() - return True - except EOFError: - # If we get an EOFError, it means the child end of the pipe has been closed. This only happens - # in the finally block. But due to a possible race condition, the process may have not yet - # terminated (it could be doing cleanup/python shutdown still). So we kill it here after a - # "suitable" timeout. - self._done = True - # Arbitrary timeout -- error/race condition only, so this doesn't need to be tunable. - self._process.join(timeout=5) - if self._process.is_alive(): - # Didn't shut down cleanly - kill it - self._kill_process() - - if not self._process.is_alive(): - self._done = True - self.log.debug("Waiting for %s", self._process) - self._process.join() - self._parent_channel.close() - return True - - return False - @property - def result(self) -> tuple[int, int, int] | None: - """Result of running ``DagFileProcessor.process_file()``.""" - if not self.done: - raise AirflowException("Tried to get the result before it's done!") - return self._result - - @property - def start_time(self) -> datetime: - """Time when this started to process the file.""" - if self._start_time is None: - raise AirflowException("Tried to get start time before it started!") - return self._start_time - - @property - def waitable_handle(self): - return self._process.sentinel - - def import_modules(self, file_path: str | Iterable[str]): - def _import_modules(filepath): - for module in iter_airflow_imports(filepath): - try: - importlib.import_module(module) - except Exception as e: - # only log as warning because an error here is not preventing anything from working, and - # if it's serious, it's going to be surfaced to the user when the dag is actually parsed. - self.log.warning( - "Error when trying to pre-import module '%s' found in %s: %s", - module, - file_path, - e, - ) - - if isinstance(file_path, str): - _import_modules(file_path) - elif isinstance(file_path, Iterable): - for path in file_path: - _import_modules(path) - - -class DagFileProcessor(LoggingMixin): +@attrs.define() +class DagFileProcessorProcess(WatchedSubprocess): """ - Process a Python file containing Airflow DAGs. - - This includes: - - 1. Execute the file and look for DAG objects in the namespace. - 2. Execute any Callbacks if passed to DagFileProcessor.process_file - 3. Serialize the DAGs and save it to DB (or update existing record in the DB). - 4. Record any errors importing the file into ORM + Parses dags with Task SDK API. - Returns a tuple of 'number of dags found' and 'the count of import errors' - - :param log: Logger to save the processing process + Since DAGs are written with the Task SDK, we need to parse them in a task SDK process such that + we can use the Task SDK definitions when serializing. This prevents potential conflicts with classes + in core Airflow. """ - UNIT_TEST_MODE: bool = conf.getboolean("core", "UNIT_TEST_MODE") - - def __init__(self, dag_directory: str, log: logging.Logger): - super().__init__() - self._log = log - self._dag_directory = dag_directory - self._last_num_of_db_queries = 0 + parsing_result: DagFileParsingResult | None = None @classmethod - @provide_session - def execute_callbacks( + def start( # type: ignore[override] cls, - dagbag: DagBag, - callback_requests: list[CallbackRequest], - unit_test_mode: bool, - session: Session = NEW_SESSION, - ) -> None: - """ - Execute on failure callbacks. - - These objects can come from SchedulerJobRunner or from DagProcessorJobRunner. - - :param dagbag: Dag Bag of dags - :param callback_requests: failure callbacks to execute - :param session: DB session. - - :return: number of queries executed - """ - for request in callback_requests: - cls.logger().debug("Processing Callback Request: %s", request) - try: - if isinstance(request, TaskCallbackRequest): - cls._execute_task_callbacks(dagbag, request, unit_test_mode, session=session) - elif isinstance(request, DagCallbackRequest): - cls._execute_dag_callbacks(dagbag, request, session=session) - except Exception: - cls.logger().exception( - "Error executing %s callback for file: %s", - request.__class__.__name__, - request.full_filepath, - ) - session.flush() - session.commit() + path: str | os.PathLike[str], + callbacks: list[CallbackRequest], + target: Callable[[], None] = _parse_file_entrypoint, + **kwargs, + ) -> Self: + return super().start(path, callbacks, target=target, client=None, **kwargs) # type:ignore[arg-type] + + def _on_child_started( # type: ignore[override] + self, callbacks: list[CallbackRequest], path: str | os.PathLike[str], child_comms_fd: int + ): + msg = DagFileParseRequest( + file=os.fspath(path), + requests_fd=child_comms_fd, + callback_requests=callbacks, + ) + self.stdin.write(msg.model_dump_json().encode() + b"\n") - @classmethod - @provide_session - def execute_callbacks_without_dag( - cls, callback_requests: list[CallbackRequest], unit_test_mode: bool, session: Session = NEW_SESSION - ) -> None: - """ - Execute what callbacks we can as "best effort" when the dag cannot be found/had parse errors. - - This is so important so that tasks that failed when there is a parse - error don't get stuck in queued state. - """ - for request in callback_requests: - cls.logger().debug("Processing Callback Request: %s", request) - if isinstance(request, TaskCallbackRequest): - cls._execute_task_callbacks(None, request, unit_test_mode, session) - else: - cls.logger().info( - "Not executing %s callback for file %s as there was a dag parse error", - request.__class__.__name__, - request.full_filepath, - ) - session.flush() - session.commit() + def handle_requests(self, log) -> Generator[None, bytes, None]: + # TODO: Make decoder an instance variable, then this can live in the base class + decoder = pydantic.TypeAdapter[ToParent](ToParent) - @classmethod - def _execute_dag_callbacks(cls, dagbag: DagBag, request: DagCallbackRequest, session: Session): - dag = dagbag.dags[request.dag_id] - callbacks, context = DAG.fetch_callback( - dag=dag, - run_id=request.run_id, - success=not request.is_failure_callback, - reason=request.msg, - session=session, - ) or (None, None) - - if callbacks and context: - DAG.execute_callback(callbacks, context, dag.dag_id) - - @classmethod - def _execute_task_callbacks( - cls, dagbag: DagBag | None, request: TaskCallbackRequest, unit_test_mode: bool, session: Session - ) -> None: - """ - Execute the task callbacks. - - :param dagbag: the DagBag to use to get the task instance - :param request: the task callback request - :param session: the session to use - """ - try: - callback_type = TaskInstanceState(request.task_callback_type) - except ValueError: - callback_type = None - is_remote = callback_type in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED) + while True: + line = yield - # previously we ignored any request besides failures. now if given callback type directly, - # then we respect it and execute it. additionally because in this scenario the callback - # is submitted remotely, we assume there is no need to mess with state; we simply run - # the callback + try: + msg = decoder.validate_json(line) + except Exception: + log.exception("Unable to decode message", line=line) + continue - if not is_remote and not request.is_failure_callback: - return + self._handle_request(msg, log) # type: ignore[arg-type] - simple_ti = request.simple_task_instance - ti = TaskInstance.get_task_instance( - dag_id=simple_ti.dag_id, - run_id=simple_ti.run_id, - task_id=simple_ti.task_id, - map_index=simple_ti.map_index, - session=session, - ) - - if not ti: + def _handle_request(self, msg: ToParent, log): # type: ignore[override] + if isinstance(msg, DagFileParsingResult): + self.parsing_result = msg return + # GetVariable etc -- parsing a dag can run top level code that asks for an Airflow Variable + super()._handle_request(msg, log) - task: Operator | None = None - - if dagbag and simple_ti.dag_id in dagbag.dags: - dag = dagbag.dags[simple_ti.dag_id] - if simple_ti.task_id in dag.task_ids: - task = dag.get_task(simple_ti.task_id) - else: - # We don't have the _real_ dag here (perhaps it had a parse error?) but we still want to run - # `handle_failure` so that the state of the TI gets progressed. - # - # Since handle_failure _really_ wants a task, we do our best effort to give it one - task = SerializedDagModel.get_serialized_dag( - dag_id=simple_ti.dag_id, task_id=simple_ti.task_id, session=session - ) + @property + def exit_code(self) -> int | None: + self._check_subprocess_exit() + return self._exit_code - if task: - ti.refresh_from_task(task) - - if callback_type is TaskInstanceState.SUCCESS: - context = ti.get_template_context(session=session) - if TYPE_CHECKING: - assert ti.task - callbacks = ti.task.on_success_callback - _run_finished_callback(callbacks=callbacks, context=context) - cls.logger().info("Executed callback for %s in state %s", ti, ti.state) - elif not is_remote or callback_type is TaskInstanceState.FAILED: - ti.handle_failure(error=request.msg, test_mode=unit_test_mode, session=session) - cls.logger().info("Executed callback for %s in state %s", ti, ti.state) - session.flush() + @property + def start_time(self) -> float: + return self._process.create_time() - @classmethod - def _get_dagbag(cls, file_path: str, known_pools: set[str] | None): - try: - return DagBag(file_path, include_examples=False, known_pools=known_pools) - except Exception: - cls.logger().exception("Failed at reloading the DAG file %s", file_path) - Stats.incr("dag_file_refresh_error", tags={"file_path": file_path}) - raise - - @provide_session - def process_file( - self, - file_path: str, - callback_requests: list[CallbackRequest], - known_pools: set[str] | None = None, - session: Session = NEW_SESSION, - ) -> tuple[int, int, int]: - """ - Process a Python file containing Airflow DAGs. - - This includes: - - 1. Execute the file and look for DAG objects in the namespace. - 2. Execute any Callbacks if passed to this method. - 3. Serialize the DAGs and save it to DB (or update existing record in the DB). - 4. Mark any DAGs which are no longer present as inactive - 5. Record any errors importing the file into ORM - - :param file_path: the path to the Python file that should be executed - :param callback_requests: failure callback to execute - :return: number of dags found, count of import errors, last number of db queries - """ - self.log.info("Processing file %s for tasks to queue", file_path) - - with count_queries(session) as query_counter: - try: - dagbag = DagFileProcessor._get_dagbag(file_path, known_pools) - except Exception: - self.log.exception("Failed at reloading the DAG file %s", file_path) - Stats.incr("dag_file_refresh_error", 1, 1, tags={"file_path": file_path}) - return 0, 0, self._cache_last_num_of_db_queries(query_counter) - - if dagbag.dags: - self.log.info("DAG(s) %s retrieved from %s", ", ".join(map(repr, dagbag.dags)), file_path) - self.execute_callbacks(dagbag, callback_requests, self.UNIT_TEST_MODE) - else: - self.log.warning("No viable dags retrieved from %s", file_path) - if callback_requests: - # If there were callback requests for this file but there was a - # parse error we still need to progress the state of TIs, - # otherwise they might be stuck in queued/running for ever! - DagFileProcessor.execute_callbacks_without_dag(callback_requests, self.UNIT_TEST_MODE) - - dagbag.sync_to_db(self._dag_directory, session=session) - - return len(dagbag.dags), len(dagbag.import_errors), self._cache_last_num_of_db_queries(query_counter) - - def _cache_last_num_of_db_queries(self, query_counter: _QueryCounter | None = None): - if query_counter: - self._last_num_of_db_queries = query_counter.queries_number - return self._last_num_of_db_queries + def wait(self) -> int: + raise NotImplementedError(f"Don't call wait on {type(self).__name__} objects") diff --git a/airflow/jobs/dag_processor_job_runner.py b/airflow/jobs/dag_processor_job_runner.py index 28128efba474b..787b6bbab96e1 100644 --- a/airflow/jobs/dag_processor_job_runner.py +++ b/airflow/jobs/dag_processor_job_runner.py @@ -23,7 +23,6 @@ from airflow.jobs.job import Job, perform_heartbeat from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -59,7 +58,7 @@ def __init__( def _execute(self) -> int | None: self.log.info("Starting the Dag Processor Job") try: - self.processor.start() + self.processor.run() except Exception: self.log.exception("Exception when executing DagProcessorJob") raise @@ -68,6 +67,5 @@ def _execute(self) -> int | None: self.processor.end() return None - @provide_session - def heartbeat_callback(self, session: Session = NEW_SESSION) -> None: + def heartbeat_callback(self, session: Session | None = None) -> None: Stats.incr("dag_processor_heartbeat", 1, 1) diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 73b5538c22239..3d99c05ed4d68 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -64,7 +64,7 @@ from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun from airflow.models.dagwarning import DagWarning, DagWarningType -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance +from airflow.models.taskinstance import TaskInstance from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason from airflow.stats import Stats from airflow.ti_deps.dependencies_states import EXECUTION_STATES @@ -876,7 +876,7 @@ def process_executor_events( if task.on_retry_callback or task.on_failure_callback: request = TaskCallbackRequest( full_filepath=ti.dag_model.fileloc, - simple_task_instance=SimpleTaskInstance.from_ti(ti), + ti=ti, msg=msg, processor_subdir=ti.dag_model.processor_subdir, ) @@ -2089,7 +2089,7 @@ def _purge_zombies(self, zombies: list[tuple[TI, str, str]], *, session: Session request = TaskCallbackRequest( full_filepath=file_loc, processor_subdir=processor_subdir, - simple_task_instance=SimpleTaskInstance.from_ti(ti), + ti=ti, msg=str(zombie_message_details), ) session.add( diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 6163580870c4e..43e57035cda22 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -119,6 +119,7 @@ from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.dagbag import DagBag from airflow.models.operator import Operator + from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.typing_compat import Literal log = logging.getLogger(__name__) @@ -1829,7 +1830,7 @@ def create_dagrun( @provide_session def bulk_write_to_db( cls, - dags: Collection[DAG], + dags: Collection[DAG] | Collection[LazyDeserializedDAG], processor_subdir: str | None = None, session: Session = NEW_SESSION, ): @@ -1845,22 +1846,24 @@ def bulk_write_to_db( from airflow.dag_processing.collection import AssetModelOperation, DagModelOperation log.info("Sync %s DAGs", len(dags)) - dag_op = DagModelOperation({dag.dag_id: dag for dag in dags}) + dag_op = DagModelOperation({dag.dag_id: dag for dag in dags}) # type: ignore[misc] orm_dags = dag_op.add_dags(session=session) dag_op.update_dags(orm_dags, processor_subdir=processor_subdir, session=session) - asset_op = AssetModelOperation.collect(dag_op.dags) + if all(isinstance(dag, DAG) for dag in dags): + # TODO: This is a hack for moving the dag processing + asset_op = AssetModelOperation.collect(dag_op.dags) - orm_assets = asset_op.add_assets(session=session) - orm_asset_aliases = asset_op.add_asset_aliases(session=session) - session.flush() # This populates id so we can create fks in later calls. + orm_assets = asset_op.add_assets(session=session) + orm_asset_aliases = asset_op.add_asset_aliases(session=session) + session.flush() # This populates id so we can create fks in later calls. - orm_dags = dag_op.find_orm_dags(session=session) # Refetch so relationship is up to date. - asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) - asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) - asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) - asset_op.add_asset_trigger_references(orm_assets, session=session) + orm_dags = dag_op.find_orm_dags(session=session) # Refetch so relationship is up to date. + asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) + asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) + asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) + asset_op.add_asset_trigger_references(orm_assets, session=session) session.flush() @provide_session diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py index 503aeba2b5e19..feca2a931c784 100644 --- a/airflow/models/dagcode.py +++ b/airflow/models/dagcode.py @@ -39,7 +39,6 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import Select - from airflow.models.dag import DAG from airflow.models.dag_version import DagVersion log = logging.getLogger(__name__) @@ -170,18 +169,19 @@ def get_latest_dagcode(cls, dag_id: str, session: Session = NEW_SESSION) -> DagC @classmethod @provide_session - def update_source_code(cls, dag: DAG, session: Session = NEW_SESSION) -> None: + def update_source_code(cls, dag_id: str, fileloc: str, session: Session = NEW_SESSION) -> None: """ Check if the source code of the DAG has changed and update it if needed. - :param dag: The DAG object. + :param dag_id: Dag ID + :param fileloc: The path of code file to read the code from :param session: The database session. :return: None """ - latest_dagcode = cls.get_latest_dagcode(dag.dag_id, session) + latest_dagcode = cls.get_latest_dagcode(dag_id, session) if not latest_dagcode: return - new_source_code = cls.get_code_from_file(dag.fileloc) + new_source_code = cls.get_code_from_file(fileloc) new_source_code_hash = cls.dag_source_hash(new_source_code) if new_source_code_hash != latest_dagcode.source_code_hash: latest_dagcode.source_code = new_source_code diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 25f42ca8e8f71..7761ca72122d7 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -52,6 +52,7 @@ from airflow.models import Operator from airflow.models.dag import DAG + from airflow.serialization.serialized_objects import LazyDeserializedDAG log = logging.getLogger(__name__) @@ -107,11 +108,17 @@ class SerializedDagModel(Base): load_op_links = True - def __init__(self, dag: DAG, processor_subdir: str | None = None) -> None: + def __init__(self, dag: DAG | LazyDeserializedDAG, processor_subdir: str | None = None) -> None: + from airflow.models.dag import DAG + self.dag_id = dag.dag_id self.processor_subdir = processor_subdir + dag_data = {} + if isinstance(dag, DAG): + dag_data = SerializedDAG.to_dict(dag) + else: + dag_data = dag.data - dag_data = SerializedDAG.to_dict(dag) self.dag_hash = SerializedDagModel.hash(dag_data) # partially ordered json data @@ -159,7 +166,7 @@ def _sort_serialized_dag_dict(cls, serialized_dag: Any): @provide_session def write_dag( cls, - dag: DAG, + dag: DAG | LazyDeserializedDAG, min_update_interval: int | None = None, processor_subdir: str | None = None, session: Session = NEW_SESSION, @@ -342,7 +349,7 @@ def get(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDagModel @staticmethod @provide_session def bulk_sync_to_db( - dags: list[DAG], + dags: list[DAG] | list[LazyDeserializedDAG], processor_subdir: str | None = None, session: Session = NEW_SESSION, ) -> None: diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index fa0984f6d5f1e..9b3ab7fdbc63a 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -28,10 +28,11 @@ from functools import cache from inspect import signature from textwrap import dedent -from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Union, cast import attrs import lazy_object_proxy +import pydantic from dateutil import relativedelta from pendulum.tz.timezone import FixedTimezone, Timezone @@ -1870,3 +1871,55 @@ def _has_kubernetes() -> bool: except ImportError: HAS_KUBERNETES = False return HAS_KUBERNETES + + +class LazyDeserializedDAG(pydantic.BaseModel): + """ + Lazily build information from the serialized DAG structure. + + An object that will present "enough" of the DAG like interface to update DAG db models etc, without having + to deserialize the full DAG and Task hierarchy. + """ + + data: dict + + NULLABLE_PROPERTIES: ClassVar[set[str]] = { + "is_paused_upon_creation", + "owner", + "dag_display_name", + "description", + "max_active_tasks", + "max_active_runs", + "max_consecutive_failed_dag_runs", + "owner_links", + } + + @property + def hash(self) -> str: + from airflow.models.serialized_dag import SerializedDagModel + + return SerializedDagModel.hash(self.data) + + def next_dagrun_info(self, last): + return None + + def __getattr__(self, name: str, /) -> Any: + if name in self.NULLABLE_PROPERTIES: + return self.data["dag"].get(name) + try: + return self.data["dag"][name] + except KeyError: + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") from None + + @property + def timetable(self): + return decode_timetable(self.data["dag"]["timetable"]) + + @property + def has_task_concurrency_limits(self): + return any(task.get("max_active_tis_per_dag") is not None for task in self.data["dag"]["tasks"]) + + if TYPE_CHECKING: + access_control: Mapping[str, Mapping[str, Collection[str]] | Collection[str]] | None = pydantic.Field( + init=False, default=None + ) diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 5eb3883a0fcd8..64dfb54010372 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -32,7 +32,17 @@ from datetime import datetime, timezone from http import HTTPStatus from socket import socket, socketpair -from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, TextIO, cast, overload +from typing import ( + TYPE_CHECKING, + BinaryIO, + Callable, + ClassVar, + Literal, + NoReturn, + TextIO, + cast, + overload, +) from uuid import UUID import attrs @@ -334,7 +344,7 @@ def start( cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs) proc = cls( - id=constructor_kwargs.get("id") or getattr(what, "id"), + id=constructor_kwargs.pop("id") or getattr(what, "id"), pid=pid, stdin=feed_stdin, process=psutil.Process(pid), @@ -681,7 +691,7 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N self._handle_request(msg, log) - def _handle_request(self, msg, log): + def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): resp = None if isinstance(msg, TaskState): self._terminal_state = msg.state diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 92f400d46e2bb..fce8bc185d9ce 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -303,7 +303,7 @@ def main(): # TODO: add an exception here, it causes an oof of a stack trace! global SUPERVISOR_COMMS - SUPERVISOR_COMMS = CommsDecoder(input=sys.stdin) + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin) try: ti, log = startup() run(ti, log) diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py index 0873233f5368c..05c220cc6b2d8 100644 --- a/task_sdk/src/airflow/sdk/log.py +++ b/task_sdk/src/airflow/sdk/log.py @@ -324,19 +324,15 @@ def configure_logging( }, }, "loggers": { + # Set Airflow logging to the level requested, but most everything else at "INFO" "": { "handlers": ["to_supervisor" if output else "default"], - "level": log_level.upper(), + "level": "INFO", "propagate": True, }, - # Some modules we _never_ want at debug level - "alembic": {"level": "INFO"}, - "asyncio": {"level": "INFO"}, - "cron_descriptor.GetText": {"level": "INFO"}, - "httpcore": {"level": "INFO"}, - "httpx": {"level": "WARN"}, - "psycopg.pq": {"level": "INFO"}, + "airflow": {"level": log_level.upper()}, # These ones are too chatty even at info + "httpx": {"level": "WARN"}, "sqlalchemy.engine": {"level": "WARN"}, }, } diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py index ea26dc13bf3db..c5a697f9467da 100644 --- a/tests/callbacks/test_callback_requests.py +++ b/tests/callbacks/test_callback_requests.py @@ -26,7 +26,7 @@ TaskCallbackRequest, ) from airflow.models.dag import DAG -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance +from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.bash import BashOperator from airflow.utils import timezone from airflow.utils.state import State @@ -71,11 +71,11 @@ def test_from_json(self, input, request_class): input = TaskCallbackRequest( full_filepath="filepath", - simple_task_instance=SimpleTaskInstance.from_ti(ti=ti), + ti=ti, processor_subdir="/test_dir", ) json_str = input.to_json() - result = request_class.from_json(json_str=json_str) + result = request_class.from_json(json_str) assert result == input def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create_task_instance): @@ -86,7 +86,7 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create session.flush() input = TaskCallbackRequest( full_filepath="filepath", - simple_task_instance=SimpleTaskInstance.from_ti(ti), + ti=ti, processor_subdir="/test_dir", ) json_str = input.to_json() @@ -99,22 +99,19 @@ def test_simple_ti_roundtrip_exec_config_pod(self): from airflow.callbacks.callback_requests import TaskCallbackRequest from airflow.models import TaskInstance - from airflow.models.taskinstance import SimpleTaskInstance from airflow.providers.standard.operators.bash import BashOperator test_pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="hello", namespace="ns")) op = BashOperator(task_id="hi", executor_config={"pod_override": test_pod}, bash_command="hi") - ti = TaskInstance(task=op) - s = SimpleTaskInstance.from_ti(ti) - data = TaskCallbackRequest("hi", s).to_json() - actual = TaskCallbackRequest.from_json(data).simple_task_instance.executor_config["pod_override"] + ti = TaskInstance(task=op, run_id="run1") + data = TaskCallbackRequest(full_filepath="hi", ti=ti).to_json() + actual = TaskCallbackRequest.from_json(data).ti.executor_config["pod_override"] assert actual == test_pod def test_simple_ti_roundtrip_dates(self, dag_maker): """A callback request including a TI with an exec config with a V1Pod should safely roundtrip.""" from airflow.callbacks.callback_requests import TaskCallbackRequest from airflow.models import TaskInstance - from airflow.models.taskinstance import SimpleTaskInstance from airflow.providers.standard.operators.bash import BashOperator with dag_maker(schedule=timedelta(weeks=1), serialized=True): @@ -128,7 +125,6 @@ def test_simple_ti_roundtrip_dates(self, dag_maker): ti.set_state("SUCCESS") start_date = ti.start_date end_date = ti.end_date - s = SimpleTaskInstance.from_ti(ti) - data = TaskCallbackRequest("hi", s).to_json() - assert TaskCallbackRequest.from_json(data).simple_task_instance.start_date == start_date - assert TaskCallbackRequest.from_json(data).simple_task_instance.end_date == end_date + data = TaskCallbackRequest(full_filepath="hi", ti=ti).to_json() + assert TaskCallbackRequest.from_json(data).ti.start_date == start_date + assert TaskCallbackRequest.from_json(data).ti.end_date == end_date diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py index bbd75361e7970..3373e43363959 100644 --- a/tests/dag_processing/test_collection.py +++ b/tests/dag_processing/test_collection.py @@ -38,6 +38,7 @@ update_dag_parsing_results_in_db, ) from airflow.exceptions import SerializationError +from airflow.listeners.listener import get_listener_manager from airflow.models import DagModel, Trigger from airflow.models.asset import ( AssetActive, @@ -167,6 +168,15 @@ def clean_db(self, session): clear_db_dags() clear_db_import_errors() + @pytest.fixture(name="dag_import_error_listener") + def _dag_import_error_listener(self): + from tests.listeners import dag_import_error_listener + + get_listener_manager().add_listener(dag_import_error_listener) + yield dag_import_error_listener + get_listener_manager().clear() + dag_import_error_listener.clear() + @pytest.mark.usefixtures("clean_db") # sync_perms in fab has bad session commit hygiene def test_sync_perms_syncs_dag_specific_perms_on_update( self, monkeypatch, spy_agency: SpyAgency, session, time_machine @@ -272,7 +282,9 @@ def test_serialized_dags_are_written_to_db_on_sync(self, session): assert new_serialized_dags_count == 1 @patch.object(SerializedDagModel, "write_dag") - def test_serialized_dag_errors_are_import_errors(self, mock_serialize, caplog, session): + def test_serialized_dag_errors_are_import_errors( + self, mock_serialize, caplog, session, dag_import_error_listener + ): """ Test that errors serializing a DAG are recorded as import_errors in the DB """ @@ -287,7 +299,7 @@ def test_serialized_dag_errors_are_import_errors(self, mock_serialize, caplog, s update_dag_parsing_results_in_db([dag], import_errors, None, set(), session) assert "SerializationError" in caplog.text - # Should have been edited in places + # Should have been edited in place err = import_errors.get(dag.fileloc) assert "SerializationError" in err @@ -301,7 +313,12 @@ def test_serialized_dag_errors_are_import_errors(self, mock_serialize, caplog, s assert import_error.filename == dag.fileloc assert "SerializationError" in import_error.stacktrace - def test_new_import_error_replaces_old(self, session): + # Ensure the listener was notified + assert len(dag_import_error_listener.new) == 1 + assert len(dag_import_error_listener.existing) == 0 + assert dag_import_error_listener.new["abc.py"] == import_error.stacktrace + + def test_new_import_error_replaces_old(self, session, dag_import_error_listener): """ Test that existing import error is updated and new record not created for a dag with the same filename @@ -331,6 +348,11 @@ def test_new_import_error_replaces_old(self, session): assert import_error.id == prev_error_id assert import_error.stacktrace == "New error" + # Ensure the listener was notified + assert len(dag_import_error_listener.new) == 0 + assert len(dag_import_error_listener.existing) == 1 + assert dag_import_error_listener.existing["abc.py"] == prev_error.stacktrace + def test_remove_error_clears_import_error(self, session): # Pre-condition: there is an import error for the dag file filename = "abc.py" diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index fedc15a7437fb..339bea7871e9a 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -17,37 +17,36 @@ # under the License. from __future__ import annotations -import contextlib +import io import itertools import logging import multiprocessing import os import pathlib import random +import signal import socket -import sys import textwrap import threading import time from collections import deque +from contextlib import suppress from datetime import datetime, timedelta from logging.config import dictConfig from unittest import mock -from unittest.mock import MagicMock, Mock, PropertyMock +from unittest.mock import MagicMock, Mock import pytest import time_machine from sqlalchemy import func +from uuid6 import uuid7 from airflow.callbacks.callback_requests import CallbackRequest, DagCallbackRequest from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.configuration import conf from airflow.dag_processing.manager import ( DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, - DagParsingSignal, - DagParsingStat, ) from airflow.dag_processing.processor import DagFileProcessorProcess from airflow.models import DagBag, DagModel, DbCallbackRequest @@ -56,9 +55,9 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.utils import timezone from airflow.utils.net import get_hostname +from airflow.utils.process_utils import reap_process_group from airflow.utils.session import create_session -from tests.core.test_logging_config import SETTINGS_FILE_VALID, settings_context from tests.models import TEST_DAGS_FOLDER from tests_common.test_utils.compat import ParseImportError from tests_common.test_utils.config import conf_vars @@ -77,77 +76,36 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1) -class FakeDagFileProcessorRunner(DagFileProcessorProcess): - # This fake processor will return the zombies it received in constructor - # as its processing result w/o actually parsing anything. - def __init__(self, file_path, dag_directory, callbacks): - super().__init__(file_path, dag_directory, callbacks) - # We need a "real" selectable handle for waitable_handle to work - readable, writable = multiprocessing.Pipe(duplex=False) - writable.send("abc") - writable.close() - self._waitable_handle = readable - self._result = 0, 0, 0 - - def start(self): - pass - - @property - def start_time(self): - return DEFAULT_DATE - - @property - def pid(self): - return 1234 - - @property - def done(self): - return True - - @property - def result(self): - return self._result - - @staticmethod - def _create_process(file_path, callback_requests, dag_directory): - return FakeDagFileProcessorRunner( - file_path, - dag_directory, - callback_requests, - ) - - @property - def waitable_handle(self): - return self._waitable_handle - - -class TestDagProcessorJobRunner: +class TestDagFileProcessorManager: def setup_method(self): dictConfig(DEFAULT_LOGGING_CONFIG) clear_db_runs() clear_db_serialized_dags() clear_db_dags() clear_db_callbacks() + clear_db_import_errors() def teardown_class(self): clear_db_runs() clear_db_serialized_dags() clear_db_dags() clear_db_callbacks() + clear_db_import_errors() - def run_processor_manager_one_loop(self, processor, parent_pipe): - results = [] - - while True: - processor._run_parsing_loop() + def run_processor_manager_one_loop(self, manager: DagFileProcessorManager) -> None: + manager._run_parsing_loop() - while parent_pipe.poll(timeout=0.01): - obj = parent_pipe.recv() - if not isinstance(obj, DagParsingStat): - results.append(obj) - elif obj.done: - return results - raise RuntimeError("Shouldn't get here - nothing to read, but manager not finished!") + def mock_processor(self) -> DagFileProcessorProcess: + proc = MagicMock() + proc.create_time.return_value = time.time() + proc.wait.return_value = 0 + return DagFileProcessorProcess( + id=uuid7(), + pid=1234, + process=proc, + client=Mock(), + stdin=io.BytesIO(), + ) @pytest.fixture def clear_parse_import_errors(self): @@ -161,17 +119,14 @@ def test_remove_file_clears_import_error(self, tmp_path): # Generate original import error path_to_parse.write_text("an invalid airflow DAG") - child_pipe, parent_pipe = multiprocessing.Pipe() - manager = DagFileProcessorManager( dag_directory=path_to_parse.parent, max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, + processor_timeout=365 * 86_400, ) with create_session() as session: - self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager) import_errors = session.query(ParseImportError).all() assert len(import_errors) == 1 @@ -179,42 +134,24 @@ def test_remove_file_clears_import_error(self, tmp_path): path_to_parse.unlink() # Rerun the parser once the dag file has been removed - self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager) import_errors = session.query(ParseImportError).all() assert len(import_errors) == 0 session.rollback() - child_pipe.close() - parent_pipe.close() - @conf_vars({("core", "load_examples"): "False"}) def test_max_runs_when_no_files(self, tmp_path): - child_pipe, parent_pipe = multiprocessing.Pipe() + manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) - manager = DagFileProcessorManager( - dag_directory=os.fspath(tmp_path), - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - ) - - self.run_processor_manager_one_loop(manager, parent_pipe) - child_pipe.close() - parent_pipe.close() + self.run_processor_manager_one_loop(manager) - @mock.patch("airflow.dag_processing.processor.iter_airflow_imports") - def test_start_new_processes_with_same_filepath(self, _): + def test_start_new_processes_with_same_filepath(self): """ Test that when a processor already exist with a filepath, a new processor won't be created with that filepath. The filepath will just be removed from the list. """ - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) file_1 = "file_1.py" file_2 = "file_2.py" @@ -224,7 +161,7 @@ def test_start_new_processes_with_same_filepath(self, _): # Mock that only one processor exists. This processor runs with 'file_1' manager._processors[file_1] = MagicMock() # Start New Processes - manager.start_new_processes() + manager._start_new_processes() # Because of the config: '[scheduler] parsing_processes = 2' # verify that only one extra process is created @@ -237,12 +174,7 @@ def test_start_new_processes_with_same_filepath(self, _): assert deque([file_3]) == manager._file_path_queue def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) mock_processor = MagicMock() mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") @@ -256,12 +188,7 @@ def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): assert "missing_file.txt" not in manager._file_stats def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) mock_processor = MagicMock() mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") @@ -284,12 +211,7 @@ def test_file_paths_in_queue_sorted_alphabetically( dag_files = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() @@ -308,12 +230,7 @@ def test_file_paths_in_queue_sorted_random_seeded_by_host( dag_files = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() @@ -367,12 +284,7 @@ def test_file_paths_in_queue_sorted_by_modified_time( mock_getmtime.side_effect = list(paths_with_mtime.values()) mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() @@ -399,12 +311,7 @@ def test_file_paths_in_queue_excludes_missing_file( mock_getmtime.side_effect = [1.0, 2.0, FileNotFoundError()] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) manager.prepare_file_path_queue() @@ -430,12 +337,7 @@ def test_add_new_file_to_parsing_queue( mock_getmtime.side_effect = [1.0, 2.0, 3.0] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) manager.prepare_file_path_queue() @@ -470,17 +372,12 @@ def test_recently_modified_file_is_parsed_with_mtime_mode( mock_getmtime.side_effect = [initial_file_1_mtime] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=3, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=3) # let's say the DAG was just parsed 10 seconds before the Freezed time last_finish_time = freezed_base_time - timedelta(seconds=10) manager._file_stats = { - "file_1.py": DagFileStat(1, 0, last_finish_time, timedelta(seconds=1.0), 1, 1), + "file_1.py": DagFileStat(1, 0, last_finish_time, 1.0, 1, 1), } with time_machine.travel(freezed_base_time): manager.set_file_paths(dag_files) @@ -525,12 +422,7 @@ def test_file_paths_in_queue_sorted_by_priority( dag_files = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] mock_find_path.return_value = dag_files - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) manager._file_path_queue = deque(["file_2.py", "file_3.py", "file_4.py", "file_1.py"]) @@ -548,8 +440,7 @@ def test_scan_stale_dags(self): manager = DagFileProcessorManager( dag_directory="directory", max_runs=1, - processor_timeout=timedelta(minutes=10), - signal_conn=MagicMock(), + processor_timeout=10 * 60, ) test_dag_path = str(TEST_DAG_FOLDER / "test_example_bash_operator.py") @@ -611,12 +502,7 @@ def test_scan_stale_dags_standalone_mode(self): Ensure only dags from current dag_directory are updated """ dag_directory = "directory" - manager = DagFileProcessorManager( - dag_directory=dag_directory, - max_runs=1, - processor_timeout=timedelta(minutes=10), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory=dag_directory, max_runs=1) test_dag_path = str(TEST_DAG_FOLDER / "test_example_bash_operator.py") dagbag = DagBag(test_dag_path, read_dags_from_db=False) @@ -654,55 +540,30 @@ def test_scan_stale_dags_standalone_mode(self): active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() assert active_dag_count == 1 - @mock.patch( - "airflow.dag_processing.processor.DagFileProcessorProcess.waitable_handle", new_callable=PropertyMock - ) - @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock) - @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.kill") - def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid, mock_waitable_handle): - mock_pid.return_value = 1234 - mock_waitable_handle.return_value = 3 - manager = DagFileProcessorManager( - dag_directory="directory", - max_runs=1, - processor_timeout=timedelta(seconds=5), - signal_conn=MagicMock(), - ) + def test_kill_timed_out_processors_kill(self): + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1, processor_timeout=5) - processor = DagFileProcessorProcess( - file_path="abc.txt", - dag_directory=TEST_DAG_FOLDER, - callback_requests=[], - ) - processor._start_time = timezone.make_aware(datetime.min) + processor = self.mock_processor() + processor._process.create_time.return_value = timezone.make_aware(datetime.min).timestamp() manager._processors = {"abc.txt": processor} - manager.waitables[3] = processor - initial_waitables = len(manager.waitables) - manager._kill_timed_out_processors() - mock_kill.assert_called_once_with() + with mock.patch.object(type(processor), "kill") as mock_kill: + manager._kill_timed_out_processors() + mock_kill.assert_called_once_with(signal.SIGKILL) assert len(manager._processors) == 0 - assert len(manager.waitables) == initial_waitables - 1 - @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock) - @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess") - def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_pid): - mock_pid.return_value = 1234 + def test_kill_timed_out_processors_no_kill(self): manager = DagFileProcessorManager( dag_directory=TEST_DAG_FOLDER, max_runs=1, - processor_timeout=timedelta(seconds=5), - signal_conn=MagicMock(), + processor_timeout=5, ) - processor = DagFileProcessorProcess( - file_path="abc.txt", - dag_directory=str(TEST_DAG_FOLDER), - callback_requests=[], - ) - processor._start_time = timezone.make_aware(datetime.max) + processor = self.mock_processor() + processor._process.create_time.return_value = timezone.make_aware(datetime.max).timestamp() manager._processors = {"abc.txt": processor} - manager._kill_timed_out_processors() - mock_dag_file_processor.kill.assert_not_called() + with mock.patch.object(type(processor), "kill") as mock_kill: + manager._kill_timed_out_processors() + mock_kill.assert_not_called() @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(10) @@ -718,23 +579,10 @@ def test_dag_with_system_exit(self): clear_db_dags() clear_db_serialized_dags() - child_pipe, parent_pipe = multiprocessing.Pipe() - - manager = DagFileProcessorManager( - dag_directory=dag_directory, - max_runs=1, - processor_timeout=timedelta(seconds=5), - signal_conn=child_pipe, - ) + manager = DagFileProcessorManager(dag_directory=dag_directory, max_runs=1) manager._run_parsing_loop() - result = None - while parent_pipe.poll(timeout=None): - result = parent_pipe.recv() - if isinstance(result, DagParsingStat) and result.done: - break - # Three files in folder should be processed assert sum(stat.run_count for stat in manager._file_stats.values()) == 3 @@ -758,31 +606,17 @@ def test_import_error_with_dag_directory(self, tmp_path): f.write("an invalid airflow DAG") with create_session() as session: - child_pipe, parent_pipe = multiprocessing.Pipe() + manager = DagFileProcessorManager(dag_directory=processor_dir_1, max_runs=1) - manager = DagFileProcessorManager( - dag_directory=processor_dir_1, - max_runs=1, - signal_conn=child_pipe, - processor_timeout=timedelta(seconds=5), - ) - - self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager) import_errors = session.query(ParseImportError).order_by("id").all() assert len(import_errors) == 1 assert import_errors[0].processor_subdir == str(processor_dir_1) - child_pipe, parent_pipe = multiprocessing.Pipe() + manager = DagFileProcessorManager(dag_directory=processor_dir_2, max_runs=1) - manager = DagFileProcessorManager( - dag_directory=processor_dir_2, - max_runs=1, - signal_conn=child_pipe, - processor_timeout=timedelta(seconds=5), - ) - - self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager) import_errors = session.query(ParseImportError).order_by("id").all() assert len(import_errors) == 2 @@ -793,8 +627,7 @@ def test_import_error_with_dag_directory(self, tmp_path): @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(30) - @mock.patch("airflow.dag_processing.manager.DagFileProcessorProcess") - def test_pipe_full_deadlock(self, mock_processor): + def test_pipe_full_deadlock(self): dag_filepath = TEST_DAG_FOLDER / "test_scheduler_dags.py" child_pipe, parent_pipe = multiprocessing.Pipe() @@ -815,7 +648,7 @@ def keep_pipe_full(pipe, exit_event): if exit_event.is_set(): break - req = CallbackRequest(str(dag_filepath)) + req = CallbackRequest(full_filepath=dag_filepath.as_posix()) logger.info("Sending CallbackRequests %d", n) try: pipe.send(req) @@ -829,29 +662,25 @@ def keep_pipe_full(pipe, exit_event): thread = threading.Thread(target=keep_pipe_full, args=(parent_pipe, exit_event)) - fake_processors = [] - - def fake_processor_(*args, **kwargs): - nonlocal fake_processors - processor = FakeDagFileProcessorRunner._create_process(*args, **kwargs) - fake_processors.append(processor) - return processor - - mock_processor.side_effect = fake_processor_ - manager = DagFileProcessorManager( dag_directory=dag_filepath, # A reasonable large number to ensure that we trigger the deadlock max_runs=100, - processor_timeout=timedelta(seconds=5), + processor_timeout=5, signal_conn=child_pipe, + # Make it loop sub-processes quickly. Need to be non-zero to exercise the bug, else it finishes + # too quickly + file_process_interval=0.01, ) try: thread.start() # If this completes without hanging, then the test is good! - manager._run_parsing_loop() + with mock.patch.object( + DagFileProcessorProcess, "start", side_effect=lambda *args, **kwargs: self.mock_processor() + ): + manager.run() exit_event.set() finally: logger.info("Closing pipes") @@ -868,55 +697,35 @@ def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmp_path): path_to_parse = tmp_path / "temp_dag.py" dag_code = textwrap.dedent( """ - from airflow import DAG - dag = DAG(dag_id='temp_dag', schedule='0 0 * * *') - """ + from airflow import DAG + dag = DAG(dag_id='temp_dag') + """ ) path_to_parse.write_text(dag_code) - child_pipe, parent_pipe = multiprocessing.Pipe() + manager = DagFileProcessorManager(dag_directory=path_to_parse.parent, max_runs=1) - manager = DagFileProcessorManager( - dag_directory=path_to_parse.parent, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - ) - - self.run_processor_manager_one_loop(manager, parent_pipe) - last_runtime = manager._file_stats[manager.file_paths[0]].last_duration - assert last_runtime is not None - - child_pipe.close() - parent_pipe.close() + self.run_processor_manager_one_loop(manager) + last_runtime = manager._file_stats[os.fspath(path_to_parse)].last_duration statsd_timing_mock.assert_has_calls( [ - mock.call("dag_processing.last_duration.temp_dag", last_runtime * 1000.0), - mock.call( - "dag_processing.last_duration", - last_runtime * 1000.0, - tags={"file_name": "temp_dag"}, - ), + mock.call("dag_processing.last_duration.temp_dag", last_runtime), + mock.call("dag_processing.last_duration", last_runtime, tags={"file_name": "temp_dag"}), ], any_order=True, ) def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path): """Test DagFileProcessorManager._refresh_dag_dir method""" - manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory=TEST_DAG_FOLDER, max_runs=1) dagbag = DagBag(dag_folder=tmp_path, include_examples=False) zipped_dag_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip") dagbag.process_file(zipped_dag_path) dag = dagbag.get_dag("test_zip_dag") dag.sync_to_db() SerializedDagModel.write_dag(dag) - manager.last_dag_dir_refresh_time = timezone.utcnow() - timedelta(minutes=10) + manager.last_dag_dir_refresh_time = time.monotonic() - 10 * 60 manager._refresh_dag_dir() # Assert dag not deleted in SDM assert SerializedDagModel.has_dag("test_zip_dag") @@ -927,19 +736,14 @@ def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path): def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path): """Test DagFileProcessorManager._refresh_dag_dir method""" - manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) + manager = DagFileProcessorManager(dag_directory=TEST_DAG_FOLDER, max_runs=1) dagbag = DagBag(dag_folder=tmp_path, include_examples=False) zipped_dag_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip") dagbag.process_file(zipped_dag_path) dag = dagbag.get_dag("test_zip_dag") dag.sync_to_db() SerializedDagModel.write_dag(dag) - manager.last_dag_dir_refresh_time = timezone.utcnow() - timedelta(minutes=10) + manager.last_dag_dir_refresh_time = time.monotonic() - 10 * 60 # Mock might_contain_dag to mimic deleting the python file from the zip with mock.patch("airflow.dag_processing.manager.might_contain_dag", return_value=False): @@ -968,13 +772,8 @@ def test_refresh_dags_dir_does_not_interfer_with_dags_outside_its_subdir(self, t assert dag.get_is_active() assert DagCode.has_dag(dag.dag_id) - manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER / "subdir2" / "subdir3", - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), - ) - manager.last_dag_dir_refresh_time = timezone.utcnow() - timedelta(minutes=10) + manager = DagFileProcessorManager(dag_directory=TEST_DAG_FOLDER / "subdir2" / "subdir3", max_runs=1) + manager.last_dag_dir_refresh_time = time.monotonic() - 10 * 60 manager._refresh_dag_dir() @@ -989,7 +788,6 @@ def test_refresh_dags_dir_does_not_interfer_with_dags_outside_its_subdir(self, t } ) def test_fetch_callbacks_from_database(self, tmp_path): - """Test DagFileProcessorManager._fetch_callbacks method""" dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" callback1 = DagCallbackRequest( @@ -1011,16 +809,12 @@ def test_fetch_callbacks_from_database(self, tmp_path): session.add(DbCallbackRequest(callback=callback1, priority_weight=11)) session.add(DbCallbackRequest(callback=callback2, priority_weight=10)) - child_pipe, parent_pipe = multiprocessing.Pipe() manager = DagFileProcessorManager( - dag_directory=os.fspath(tmp_path), - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, + dag_directory=os.fspath(tmp_path), max_runs=1, standalone_dag_processor=True ) with create_session() as session: - self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager) assert session.query(DbCallbackRequest).count() == 0 @conf_vars( @@ -1052,16 +846,10 @@ def test_fetch_callbacks_for_current_dag_directory_only(self, tmp_path): session.add(DbCallbackRequest(callback=callback1, priority_weight=11)) session.add(DbCallbackRequest(callback=callback2, priority_weight=10)) - child_pipe, parent_pipe = multiprocessing.Pipe() - manager = DagFileProcessorManager( - dag_directory=tmp_path, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - ) + manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) with create_session() as session: - self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager) assert session.query(DbCallbackRequest).count() == 1 @conf_vars( @@ -1086,20 +874,14 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path): ) session.add(DbCallbackRequest(callback=callback, priority_weight=i)) - child_pipe, parent_pipe = multiprocessing.Pipe() - manager = DagFileProcessorManager( - dag_directory=str(tmp_path), - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - ) + manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) with create_session() as session: - self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager) assert session.query(DbCallbackRequest).count() == 3 with create_session() as session: - self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager) assert session.query(DbCallbackRequest).count() == 1 @conf_vars( @@ -1121,18 +903,10 @@ def test_fetch_callbacks_from_database_not_standalone(self, tmp_path): ) session.add(DbCallbackRequest(callback=callback, priority_weight=10)) - child_pipe, parent_pipe = multiprocessing.Pipe() - manager = DagFileProcessorManager( - dag_directory=tmp_path, - max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=child_pipe, - ) + manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) - with create_session() as session: - results = self.run_processor_manager_one_loop(manager, parent_pipe) + self.run_processor_manager_one_loop(manager) - assert (len(results)) == 0 # Verify no callbacks removed from database. with create_session() as session: assert session.query(DbCallbackRequest).count() == 1 @@ -1142,8 +916,7 @@ def test_callback_queue(self, tmp_path): manager = DagFileProcessorManager( dag_directory=TEST_DAG_FOLDER, max_runs=1, - processor_timeout=timedelta(days=365), - signal_conn=MagicMock(), + processor_timeout=365 * 86_400, ) dag1_req1 = DagCallbackRequest( @@ -1151,7 +924,7 @@ def test_callback_queue(self, tmp_path): dag_id="dag1", run_id="run1", is_failure_callback=False, - processor_subdir=tmp_path, + processor_subdir=tmp_path.as_posix(), msg=None, ) dag1_req2 = DagCallbackRequest( @@ -1159,7 +932,7 @@ def test_callback_queue(self, tmp_path): dag_id="dag1", run_id="run1", is_failure_callback=False, - processor_subdir=tmp_path, + processor_subdir=tmp_path.as_posix(), msg=None, ) @@ -1168,7 +941,7 @@ def test_callback_queue(self, tmp_path): dag_id="dag2", run_id="run1", is_failure_callback=False, - processor_subdir=tmp_path, + processor_subdir=tmp_path.as_posix(), msg=None, ) @@ -1197,85 +970,31 @@ def test_callback_queue(self, tmp_path): dag1_req2, ] - -def _wait_for_processor_agent_to_complete(processor_agent: DagFileProcessorAgent): - start_timer = time.monotonic() - while time.monotonic() - start_timer < 10: - if processor_agent.done and all( - [processor.done for processor in processor_agent._processors.values()] - ): - break - processor_agent.heartbeat() - time.sleep(0.1) + with mock.patch.object( + DagFileProcessorProcess, "start", side_effect=lambda *args, **kwargs: self.mock_processor() + ) as start: + manager._start_new_processes() + # Callbacks passed to process ctor + start.assert_any_call( + id=mock.ANY, path=dag1_req1.full_filepath, callbacks=[dag1_req1, dag1_req2], selector=mock.ANY + ) + # And removed from the queue + assert dag1_req1.full_filepath not in manager._callback_to_execute class TestDagFileProcessorAgent: - def setup_method(self): - # Make sure that the configure_logging is not cached - self.old_modules = dict(sys.modules) - - def teardown_method(self): - # Remove any new modules imported during the test run. This lets us - # import the same source files for more than one test. - remove_list = [] - for mod in sys.modules: - if mod not in self.old_modules: - remove_list.append(mod) - - for mod in remove_list: - del sys.modules[mod] - - def test_reload_module(self): - """ - Configure the context to have logging.logging_config_class set to a fake logging - class path, thus when reloading logging module the airflow.processor_manager - logger should not be configured. - """ - with settings_context(SETTINGS_FILE_VALID): - # Launch a process through DagFileProcessorAgent, which will try - # reload the logging module. - test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" - log_file_loc = conf.get("logging", "DAG_PROCESSOR_MANAGER_LOG_LOCATION") - - with contextlib.suppress(OSError): - os.remove(log_file_loc) - - # Starting dag processing with 0 max_runs to avoid redundant operations. - processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) - processor_agent.start() - - processor_agent._process.join() - # Since we are reloading logging config not creating this file, - # we should expect it to be nonexistent. - - assert not os.path.isfile(log_file_loc) - - @conf_vars({("core", "load_examples"): "False"}) - def test_parse_once(self): - clear_db_serialized_dags() - clear_db_dags() - - test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" - processor_agent = DagFileProcessorAgent(test_dag_path, 1, timedelta(days=365)) - processor_agent.start() - while not processor_agent.done: - processor_agent.heartbeat() - - assert processor_agent.all_files_processed - assert processor_agent.done - - with create_session() as session: - dag_ids = session.query(DagModel.dag_id).order_by("dag_id").all() - assert dag_ids == [("test_start_date_scheduling",), ("test_task_start_date_scheduling",)] - - dag_ids = session.query(SerializedDagModel.dag_id).order_by("dag_id").all() - assert dag_ids == [("test_start_date_scheduling",), ("test_task_start_date_scheduling",)] + @pytest.fixture(autouse=True) + def _disable_examples(self): + with conf_vars({("core", "load_examples"): "False"}): + yield def test_launch_process(self): + from airflow.configuration import conf + test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" log_file_loc = conf.get("logging", "DAG_PROCESSOR_MANAGER_LOG_LOCATION") - with contextlib.suppress(OSError): + with suppress(OSError): os.remove(log_file_loc) # Starting dag processing with 0 max_runs to avoid redundant operations. @@ -1338,17 +1057,13 @@ def test_process_message_invalid_type(self): with pytest.raises(RuntimeError, match="Unexpected message received of type str"): processor_agent._process_message(message) - def test_heartbeat_manager(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) - processor_agent._parent_signal_conn = None - with pytest.raises(ValueError, match="Process not started"): - processor_agent._heartbeat_manager() - @mock.patch("airflow.utils.process_utils.reap_process_group") - def test_heartbeat_manager_process_restart(self, mock_pg): + def test_heartbeat_manager_process_restart(self, mock_pg, monkeypatch): processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._process = MagicMock() + monkeypatch.setattr(processor_agent._process, "pid", 1234) + monkeypatch.setattr(processor_agent._process, "exitcode", 1) processor_agent.start = Mock() processor_agent._process.is_alive.return_value = False with mock.patch.object(processor_agent._process, "join"): @@ -1380,27 +1095,6 @@ def test_heartbeat_manager_process_reap(self, mock_pg, mock_time_monotonic, mock processor_agent.log.error.assert_called() processor_agent.start.assert_called() - def test_heartbeat_manager_terminate(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) - processor_agent._parent_signal_conn = Mock() - processor_agent._process = Mock() - processor_agent._process.is_alive.return_value = True - processor_agent.log.info = Mock() - - processor_agent.terminate() - processor_agent._parent_signal_conn.send.assert_called_with(DagParsingSignal.TERMINATE_MANAGER) - - def test_heartbeat_manager_terminate_conn_err(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) - processor_agent._process = Mock() - processor_agent._process.is_alive.return_value = True - processor_agent._parent_signal_conn = Mock() - processor_agent._parent_signal_conn.send.side_effect = ConnectionError - processor_agent.log.info = Mock() - - processor_agent.terminate() - processor_agent._parent_signal_conn.send.assert_called_with(DagParsingSignal.TERMINATE_MANAGER) - def test_heartbeat_manager_end_no_process(self): processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._process = Mock() @@ -1412,6 +1106,21 @@ def test_heartbeat_manager_end_no_process(self): processor_agent.log.warning.assert_called_with("Ending without manager process.") processor_agent._process.join.assert_not_called() + @pytest.mark.execution_timeout(5) + def test_terminate(self, tmp_path): + processor_agent = DagFileProcessorAgent(tmp_path, -1, timedelta(days=365)) + + processor_agent.start() + try: + processor_agent.terminate() + + processor_agent._process.join(timeout=1) + assert processor_agent._process.is_alive() is False + assert processor_agent._process.exitcode == 0 + except Exception: + reap_process_group(processor_agent._process.pid, logger=logger) + raise + @conf_vars({("logging", "dag_processor_manager_log_stdout"): "True"}) def test_log_to_stdout(self, capfd): test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" @@ -1421,7 +1130,6 @@ def test_log_to_stdout(self, capfd): processor_agent.start() processor_agent._process.join() - _wait_for_processor_agent_to_complete(processor_agent) # Capture the stdout and stderr out, _ = capfd.readouterr() @@ -1436,7 +1144,6 @@ def test_not_log_to_stdout(self, capfd): processor_agent.start() processor_agent._process.join() - _wait_for_processor_agent_to_complete(processor_agent) # Capture the stdout and stderr out, _ = capfd.readouterr() diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index d3ecd98b91680..f90475a2e96ff 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -19,36 +19,31 @@ import pathlib import sys -from unittest import mock -from unittest.mock import MagicMock, patch -from zipfile import ZipFile +from typing import TYPE_CHECKING +from unittest.mock import patch import pytest - -from airflow.callbacks.callback_requests import TaskCallbackRequest -from airflow.configuration import TEST_DAGS_FOLDER, conf -from airflow.dag_processing.processor import DagFileProcessor, DagFileProcessorProcess +import structlog + +from airflow.callbacks.callback_requests import CallbackRequest, DagCallbackRequest, TaskCallbackRequest +from airflow.configuration import conf +from airflow.dag_processing.processor import ( + DagFileParseRequest, + DagFileParsingResult, + _parse_file, +) from airflow.models import DagBag, TaskInstance +from airflow.models.baseoperator import BaseOperator from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils import timezone from airflow.utils.session import create_session -from airflow.utils.state import State -from airflow.utils.types import DagRunType +from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.types import DagRunTriggeredByType, DagRunType -from tests_common.test_utils.asserts import assert_queries_count from tests_common.test_utils.config import conf_vars, env_vars -from tests_common.test_utils.db import ( - clear_db_dags, - clear_db_jobs, - clear_db_runs, - clear_db_serialized_dags, -) -from tests_common.test_utils.mock_executor import MockExecutor -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS -if AIRFLOW_V_3_0_PLUS: - from airflow.utils.types import DagRunTriggeredByType +if TYPE_CHECKING: + from kgb import SpyAgency pytestmark = pytest.mark.db_test @@ -75,63 +70,15 @@ def disable_load_example(): @pytest.mark.usefixtures("disable_load_example") class TestDagFileProcessor: - @staticmethod - def clean_db(): - clear_db_runs() - clear_db_dags() - clear_db_jobs() - clear_db_serialized_dags() - - def setup_class(self): - self.clean_db() - - def setup_method(self): - # Speed up some tests by not running the tasks, just look at what we - # enqueue! - self.null_exec = MockExecutor() - self.scheduler_job = None - - def teardown_method(self) -> None: - if self.scheduler_job and self.scheduler_job.job_runner.processor_agent: - self.scheduler_job.job_runner.processor_agent.end() - self.scheduler_job = None - self.clean_db() - - def _process_file(self, file_path, dag_directory, session): - dag_file_processor = DagFileProcessor(dag_directory=str(dag_directory), log=mock.MagicMock()) - - dag_file_processor.process_file(file_path, []) - - @patch.object(TaskInstance, "handle_failure") - def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): - dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) - dag_file_processor = DagFileProcessor(dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()) - with create_session() as session: - session.query(TaskInstance).delete() - dag = dagbag.get_dag("example_branch_operator") - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dagrun = dag.create_dagrun( - state=State.RUNNING, - logical_date=DEFAULT_DATE, - run_type=DagRunType.SCHEDULED, - data_interval=dag.infer_automated_data_interval(DEFAULT_DATE), - session=session, - **triggered_by_kwargs, - ) - task = dag.get_task(task_id="run_this_first") - ti = TaskInstance(task, run_id=dagrun.run_id, state=State.RUNNING) - session.add(ti) - - requests = [ - TaskCallbackRequest( - full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message" - ) - ] - dag_file_processor.execute_callbacks(dagbag, requests, dag_file_processor.UNIT_TEST_MODE, session) - mock_ti_handle_failure.assert_called_once_with( - error="Message", test_mode=conf.getboolean("core", "unit_test_mode"), session=session + def _process_file( + self, file_path, callback_requests: list[CallbackRequest] | None = None + ) -> DagFileParsingResult: + return _parse_file( + DagFileParseRequest(file=file_path, requests_fd=1, callback_requests=callback_requests or []), + log=structlog.get_logger(), ) + @pytest.mark.xfail(reason="TODO: AIP-72") @pytest.mark.parametrize( ["has_serialized_dag"], [pytest.param(True, id="dag_in_db"), pytest.param(False, id="no_dag_found")], @@ -139,179 +86,142 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): @patch.object(TaskInstance, "handle_failure") def test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, has_serialized_dag): dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) - dag_file_processor = DagFileProcessor(dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()) with create_session() as session: session.query(TaskInstance).delete() dag = dagbag.get_dag("example_branch_operator") + assert dag is not None dag.sync_to_db() - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dagrun = dag.create_dagrun( - state=State.RUNNING, + state=DagRunState.RUNNING, logical_date=DEFAULT_DATE, run_type=DagRunType.SCHEDULED, data_interval=dag.infer_automated_data_interval(DEFAULT_DATE), + triggered_by=DagRunTriggeredByType.TEST, session=session, - **triggered_by_kwargs, ) task = dag.get_task(task_id="run_this_first") - ti = TaskInstance(task, run_id=dagrun.run_id, state=State.QUEUED) + ti = TaskInstance(task, run_id=dagrun.run_id, state=TaskInstanceState.QUEUED) session.add(ti) if has_serialized_dag: assert SerializedDagModel.write_dag(dag, session=session) is True session.flush() - requests = [ - TaskCallbackRequest( - full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message" - ) - ] - dag_file_processor.execute_callbacks_without_dag(requests, True, session) + requests = [TaskCallbackRequest(full_filepath="A", ti=ti, msg="Message")] + self._process_file(dag.fileloc, requests) mock_ti_handle_failure.assert_called_once_with( error="Message", test_mode=conf.getboolean("core", "unit_test_mode"), session=session ) - def test_failure_callbacks_should_not_drop_hostname(self): - dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) - dag_file_processor = DagFileProcessor(dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()) - dag_file_processor.UNIT_TEST_MODE = False + def test_dagbag_import_errors_captured(self, spy_agency: SpyAgency): + @spy_agency.spy_for(DagBag.collect_dags, owner=DagBag) + def fake_collect_dags(dagbag: DagBag, *args, **kwargs): + dagbag.import_errors["a.py"] = "Import error" + + resp = self._process_file("a.py") + + assert not resp.serialized_dags + assert resp.import_errors is not None + assert "a.py" in resp.import_errors + + +# @conf_vars({("logging", "dag_processor_log_target"): "stdout"}) +# @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) +# @mock.patch("airflow.dag_processing.processor.redirect_stdout") +# def test_dag_parser_output_when_logging_to_stdout(self, mock_redirect_stdout_for_file): +# processor = DagFileProcessorProcess( +# file_path="abc.txt", +# dag_directory=[], +# callback_requests=[], +# ) +# processor._run_file_processor( +# result_channel=MagicMock(), +# parent_channel=MagicMock(), +# file_path="fake_file_path", +# thread_name="fake_thread_name", +# callback_requests=[], +# dag_directory=[], +# ) +# mock_redirect_stdout_for_file.assert_not_called() +# +# @conf_vars({("logging", "dag_processor_log_target"): "file"}) +# @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) +# @mock.patch("airflow.dag_processing.processor.redirect_stdout") +# def test_dag_parser_output_when_logging_to_file(self, mock_redirect_stdout_for_file): +# processor = DagFileProcessorProcess( +# file_path="abc.txt", +# dag_directory=[], +# callback_requests=[], +# ) +# processor._run_file_processor( +# result_channel=MagicMock(), +# parent_channel=MagicMock(), +# file_path="fake_file_path", +# thread_name="fake_thread_name", +# callback_requests=[], +# dag_directory=[], +# ) +# mock_redirect_stdout_for_file.assert_called_once() + + +def test_parse_file_with_dag_callbacks(spy_agency): + from airflow import DAG + + called = False + + def on_failure(context): + nonlocal called + called = True + + dag = DAG(dag_id="a", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + requests = [ + DagCallbackRequest( + full_filepath="A", + msg="Message", + dag_id="a", + run_id="b", + ) + ] + _parse_file( + DagFileParseRequest(file="A", requests_fd=1, callback_requests=requests), log=structlog.get_logger() + ) - with create_session() as session: - dag = dagbag.get_dag("example_branch_operator") - task = dag.get_task(task_id="run_this_first") - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dagrun = dag.create_dagrun( - state=State.RUNNING, - logical_date=DEFAULT_DATE, - run_type=DagRunType.SCHEDULED, - data_interval=dag.infer_automated_data_interval(DEFAULT_DATE), - session=session, - **triggered_by_kwargs, - ) - ti = TaskInstance(task, run_id=dagrun.run_id, state=State.RUNNING) - ti.hostname = "test_hostname" - session.add(ti) + assert called is True - requests = [ - TaskCallbackRequest( - full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message" - ) - ] - dag_file_processor.execute_callbacks(dagbag, requests, False) - with create_session() as session: - tis = session.query(TaskInstance) - assert tis[0].hostname == "test_hostname" +@pytest.mark.xfail(reason="TODO: AIP-72: Task level callbacks not yet supported") +def test_parse_file_with_task_callbacks(spy_agency): + from airflow import DAG - def test_process_file_should_failure_callback(self, monkeypatch, tmp_path, get_test_dag): - callback_file = tmp_path.joinpath("callback.txt") - callback_file.touch() - monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file)) - dag_file_processor = DagFileProcessor(dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()) + called = False - dag = get_test_dag("test_on_failure_callback") - task = dag.get_task(task_id="test_on_failure_callback_task") - with create_session() as session: - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dagrun = dag.create_dagrun( - state=State.RUNNING, - logical_date=DEFAULT_DATE, - run_type=DagRunType.SCHEDULED, - data_interval=dag.infer_automated_data_interval(DEFAULT_DATE), - session=session, - **triggered_by_kwargs, - ) - ti = dagrun.get_task_instance(task.task_id) - ti.refresh_from_task(task) - - requests = [ - TaskCallbackRequest( - full_filepath=dag.fileloc, - simple_task_instance=SimpleTaskInstance.from_ti(ti), - msg="Message", - ) - ] - dag_file_processor.process_file(dag.fileloc, requests) - - ti.refresh_from_db() - msg = " ".join([str(k) for k in ti.key.primary]) + " fired callback" - assert msg in callback_file.read_text() - - @conf_vars({("logging", "dag_processor_log_target"): "stdout"}) - @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) - @mock.patch("airflow.dag_processing.processor.redirect_stdout") - def test_dag_parser_output_when_logging_to_stdout(self, mock_redirect_stdout_for_file): - processor = DagFileProcessorProcess( - file_path="abc.txt", - dag_directory=[], - callback_requests=[], - ) - processor._run_file_processor( - result_channel=MagicMock(), - parent_channel=MagicMock(), - file_path="fake_file_path", - thread_name="fake_thread_name", - callback_requests=[], - dag_directory=[], - ) - mock_redirect_stdout_for_file.assert_not_called() - - @conf_vars({("logging", "dag_processor_log_target"): "file"}) - @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) - @mock.patch("airflow.dag_processing.processor.redirect_stdout") - def test_dag_parser_output_when_logging_to_file(self, mock_redirect_stdout_for_file): - processor = DagFileProcessorProcess( - file_path="abc.txt", - dag_directory=[], - callback_requests=[], - ) - processor._run_file_processor( - result_channel=MagicMock(), - parent_channel=MagicMock(), - file_path="fake_file_path", - thread_name="fake_thread_name", - callback_requests=[], - dag_directory=[], - ) - mock_redirect_stdout_for_file.assert_called_once() - - @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) - @mock.patch.object(DagFileProcessorProcess, "_get_multiprocessing_context") - def test_no_valueerror_with_parseable_dag_in_zip(self, mock_context, tmp_path): - mock_context.return_value.Pipe.return_value = (MagicMock(), MagicMock()) - zip_filename = (tmp_path / "test_zip.zip").as_posix() - with ZipFile(zip_filename, "w") as zip_file: - zip_file.writestr(TEMP_DAG_FILENAME, PARSEABLE_DAG_FILE_CONTENTS) - - processor = DagFileProcessorProcess( - file_path=zip_filename, - dag_directory=[], - callback_requests=[], - ) - processor.start() - - @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) - @mock.patch.object(DagFileProcessorProcess, "_get_multiprocessing_context") - def test_nullbyte_exception_handling_when_preimporting_airflow(self, mock_context, tmp_path): - mock_context.return_value.Pipe.return_value = (MagicMock(), MagicMock()) - dag_filename = (tmp_path / "test_dag.py").as_posix() - with open(dag_filename, "wb") as file: - file.write(b"hello\x00world") - - processor = DagFileProcessorProcess( - file_path=dag_filename, - dag_directory=[], - callback_requests=[], - ) - processor.start() + def on_failure(context): + nonlocal called + called = True - def test_counter_for_last_num_of_db_queries(self): - dag_filepath = TEST_DAG_FOLDER / "test_dag_for_db_queries_counter.py" + with DAG(dag_id="a", on_failure_callback=on_failure) as dag: + BaseOperator(task_id="b", on_failure_callback=on_failure) - with create_session() as session: - with assert_queries_count( - expected_count=154, - margin=10, - session=session, - ): - self._process_file(dag_filepath, TEST_DAG_FOLDER, session) + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + requests = [ + TaskCallbackRequest.model_construct( + full_filepath="A", + msg="Message", + ti=None, + ) + ] + _parse_file( + DagFileParseRequest(file="A", requests_fd=1, callback_requests=requests), log=structlog.get_logger() + ) + + assert called is True diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py index 1c144eb844540..bfec538c24cb4 100644 --- a/tests/executors/test_base_executor.py +++ b/tests/executors/test_base_executor.py @@ -26,7 +26,6 @@ import pytest import time_machine -from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import DefaultHelpParser, GroupCommand from airflow.cli.cli_parser import AirflowHelpFormatter from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType @@ -362,10 +361,9 @@ def test_debug_dump(caplog): def test_base_executor_cannot_send_callback(): - cbr = CallbackRequest("some_file_path_for_callback") executor = BaseExecutor() with pytest.raises(ValueError): - executor.send_callback(cbr) + executor.send_callback(mock.Mock()) def test_parser_and_formatter_class(): diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 5be6e62af7310..089dbf6e1db05 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -62,7 +62,7 @@ from airflow.models.log import Log from airflow.models.pool import Pool from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance +from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.bash import BashOperator from airflow.sdk.definitions.asset import Asset @@ -418,7 +418,7 @@ def test_process_executor_events_with_callback(self, mock_stats_incr, mock_task_ assert ti1.state == State.QUEUED mock_task_callback.assert_called_once_with( full_filepath=dag.fileloc, - simple_task_instance=mock.ANY, + ti=mock.ANY, processor_subdir=None, msg=f"Executor {executor} reported that the task instance " " " @@ -5712,14 +5712,13 @@ def test_find_and_purge_zombies(self, load_examples, session): callback_requests = executor.callback_sink.send.call_args.args assert len(callback_requests) == 1 callback_request = callback_requests[0] - assert isinstance(callback_request.simple_task_instance, SimpleTaskInstance) assert callback_request.full_filepath == dag.fileloc assert callback_request.msg == str(self.job_runner._generate_zombie_message_details(ti)) assert callback_request.is_failure_callback is True - assert callback_request.simple_task_instance.dag_id == ti.dag_id - assert callback_request.simple_task_instance.task_id == ti.task_id - assert callback_request.simple_task_instance.run_id == ti.run_id - assert callback_request.simple_task_instance.map_index == ti.map_index + assert callback_request.ti.dag_id == ti.dag_id + assert callback_request.ti.task_id == ti.task_id + assert callback_request.ti.run_id == ti.run_id + assert callback_request.ti.map_index == ti.map_index def test_zombie_message(self, load_examples): """ @@ -5829,18 +5828,18 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce expected_failure_callback_requests = [ TaskCallbackRequest( full_filepath=dag.fileloc, - simple_task_instance=SimpleTaskInstance.from_ti(ti), + ti=ti, processor_subdir=TEST_DAG_FOLDER, msg=str(self.job_runner._generate_zombie_message_details(ti)), ) ] callback_requests = scheduler_job.executor.callback_sink.send.call_args.args assert len(callback_requests) == 1 - assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == { - result.simple_task_instance.key for result in callback_requests + assert {zombie.ti.id for zombie in expected_failure_callback_requests} == { + result.ti.id for result in callback_requests } - expected_failure_callback_requests[0].simple_task_instance = None - callback_requests[0].simple_task_instance = None + expected_failure_callback_requests[0].ti = None + callback_requests[0].ti = None assert expected_failure_callback_requests[0] == callback_requests[0] def test_cleanup_stale_dags(self): diff --git a/tests/listeners/dag_import_error_listener.py b/tests/listeners/dag_import_error_listener.py index a4426c1324059..b56c93bafedaa 100644 --- a/tests/listeners/dag_import_error_listener.py +++ b/tests/listeners/dag_import_error_listener.py @@ -26,17 +26,13 @@ @hookimpl def on_new_dag_import_error(filename, stacktrace): """Execute when new dag import error appears""" - new["filename"] = stacktrace - print("new error>> filename:" + str(filename)) - print("new error>> stacktrace:" + str(stacktrace)) + new[filename] = stacktrace @hookimpl def on_existing_dag_import_error(filename, stacktrace): """Execute when existing dag import error appears""" - existing["filename"] = stacktrace - print("existing error>> filename:" + str(filename)) - print("existing error>> stacktrace:" + str(stacktrace)) + existing[filename] = stacktrace def clear(): diff --git a/tests/listeners/test_dag_import_error_listener.py b/tests/listeners/test_dag_import_error_listener.py deleted file mode 100644 index aa085d3cfd7f3..0000000000000 --- a/tests/listeners/test_dag_import_error_listener.py +++ /dev/null @@ -1,163 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -import os -import sys -from unittest import mock - -import pytest - -from airflow import settings -from airflow.configuration import TEST_DAGS_FOLDER -from airflow.dag_processing.processor import DagFileProcessor -from airflow.listeners.listener import get_listener_manager -from airflow.models import DagModel -from airflow.models.errors import ParseImportError -from airflow.utils import timezone - -from tests.listeners import dag_import_error_listener -from tests_common.test_utils.config import conf_vars, env_vars -from tests_common.test_utils.db import ( - clear_db_dags, - clear_db_import_errors, - clear_db_jobs, - clear_db_pools, - clear_db_runs, - clear_db_serialized_dags, -) -from tests_common.test_utils.mock_executor import MockExecutor - -pytestmark = pytest.mark.db_test - -DEFAULT_DATE = timezone.datetime(2016, 1, 1) -PY311 = sys.version_info >= (3, 11) - -# Include the words "airflow" and "dag" in the file contents, -# tricking airflow into thinking these -# files contain a DAG (otherwise Airflow will skip them) -PARSEABLE_DAG_FILE_CONTENTS = '"airflow DAG"' -UNPARSEABLE_DAG_FILE_CONTENTS = "airflow DAG" -INVALID_DAG_WITH_DEPTH_FILE_CONTENTS = "def something():\n return airflow_DAG\nsomething()" - -# Filename to be used for dags that are created in an ad-hoc manner and can be removed/ -# created at runtime -TEMP_DAG_FILENAME = "temp_dag.py" - - -@pytest.fixture(scope="class") -def disable_load_example(): - with conf_vars({("core", "load_examples"): "false"}): - with env_vars({"AIRFLOW__CORE__LOAD_EXAMPLES": "false"}): - yield - - -@pytest.mark.usefixtures("disable_load_example") -class TestDagFileProcessor: - @staticmethod - def clean_db(): - clear_db_runs() - clear_db_pools() - clear_db_dags() - clear_db_import_errors() - clear_db_jobs() - clear_db_serialized_dags() - - def setup_class(self): - self.clean_db() - - def setup_method(self): - # Speed up some tests by not running the tasks, just look at what we - # enqueue! - self.null_exec = MockExecutor() - self.scheduler_job = None - - def teardown_method(self) -> None: - if self.scheduler_job and self.scheduler_job.job_runner.processor_agent: - self.scheduler_job.job_runner.processor_agent.end() - self.scheduler_job = None - self.clean_db() - - def _process_file(self, file_path, dag_directory, session): - dag_file_processor = DagFileProcessor(dag_directory=str(dag_directory), log=mock.MagicMock()) - - dag_file_processor.process_file(file_path, []) - - def test_newly_added_import_error(self, tmp_path, session): - dag_import_error_listener.clear() - get_listener_manager().add_listener(dag_import_error_listener) - - dag_file = os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py") - temp_dagfile = tmp_path.joinpath(TEMP_DAG_FILENAME).as_posix() - with open(dag_file) as main_dag, open(temp_dagfile, "w") as next_dag: - for line in main_dag: - next_dag.write(line) - # first we parse the dag - self._process_file(temp_dagfile, dag_directory=tmp_path, session=session) - # assert DagModel.has_import_errors is false - dm = session.query(DagModel).filter(DagModel.fileloc == temp_dagfile).first() - assert not dm.has_import_errors - # corrupt the file - with open(temp_dagfile, "a") as file: - file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) - - self._process_file(temp_dagfile, dag_directory=tmp_path, session=session) - import_errors = session.query(ParseImportError).all() - - assert len(import_errors) == 1 - import_error = import_errors[0] - assert import_error.filename == temp_dagfile - assert import_error.stacktrace - dm = session.query(DagModel).filter(DagModel.fileloc == temp_dagfile).first() - assert dm.has_import_errors - - # Ensure the listener was notified - assert len(dag_import_error_listener.new) == 1 - assert dag_import_error_listener.new["filename"] == import_error.stacktrace - - def test_already_existing_import_error(self, tmp_path): - dag_import_error_listener.clear() - get_listener_manager().add_listener(dag_import_error_listener) - - filename_to_parse = tmp_path.joinpath(TEMP_DAG_FILENAME).as_posix() - # Generate original import error - with open(filename_to_parse, "w") as file_to_parse: - file_to_parse.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) - session = settings.Session() - self._process_file(filename_to_parse, dag_directory=tmp_path, session=session) - - import_error_1 = ( - session.query(ParseImportError).filter(ParseImportError.filename == filename_to_parse).one() - ) - - # process the file multiple times - for _ in range(10): - self._process_file(filename_to_parse, dag_directory=tmp_path, session=session) - - import_error_2 = ( - session.query(ParseImportError).filter(ParseImportError.filename == filename_to_parse).one() - ) - - # assert that the ID of the import error did not change - assert import_error_1.id == import_error_2.id - - # Ensure the listener was notified - assert len(dag_import_error_listener.existing) == 1 - assert dag_import_error_listener.existing["filename"] == import_error_1.stacktrace - assert dag_import_error_listener.existing["filename"] == import_error_2.stacktrace diff --git a/tests/models/test_dagcode.py b/tests/models/test_dagcode.py index bbce98badfe9c..d036f5f12776d 100644 --- a/tests/models/test_dagcode.py +++ b/tests/models/test_dagcode.py @@ -192,6 +192,6 @@ def mytask(): session.commit() dagcode2 = DagCode.get_latest_dagcode(dag.dag_id) assert dagcode2.source_code_hash == 2 - DagCode.update_source_code(dag) + DagCode.update_source_code(dag.dag_id, dag.fileloc) dag_code3 = DagCode.get_latest_dagcode(dag.dag_id) assert dag_code3.source_code_hash != 2