diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 29f95ef4c742a..058eef6ab8922 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -140,10 +140,8 @@ def register_dataset_change( dags_to_reparse = dags_to_queue_from_dataset_alias - dags_to_queue_from_dataset if dags_to_reparse: - session.add_all( - DagPriorityParsingRequest(fileloc=fileloc) - for fileloc in {dag.fileloc for dag in dags_to_reparse} - ) + file_locs = {dag.fileloc for dag in dags_to_reparse} + cls._send_dag_priority_parsing_request(file_locs, session) session.flush() cls.notify_dataset_changed(dataset=dataset) @@ -208,6 +206,35 @@ def _postgres_queue_dagruns(cls, dataset_id: int, dags_to_queue: set[DagModel], stmt = insert(DatasetDagRunQueue).values(dataset_id=dataset_id).on_conflict_do_nothing() session.execute(stmt, values) + @classmethod + def _send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None: + if session.bind.dialect.name == "postgresql": + return cls._postgres_send_dag_priority_parsing_request(file_locs, session) + return cls._slow_path_send_dag_priority_parsing_request(file_locs, session) + + @classmethod + def _slow_path_send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None: + def _send_dag_priority_parsing_request_if_needed(fileloc: str) -> str | None: + # Don't error whole transaction when a single DagPriorityParsingRequest item conflicts. + # https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint + req = DagPriorityParsingRequest(fileloc=fileloc) + try: + with session.begin_nested(): + session.merge(req) + except exc.IntegrityError: + cls.logger().debug("Skipping request %s, already present", req, exc_info=True) + return None + return req.fileloc + + (_send_dag_priority_parsing_request_if_needed(fileloc) for fileloc in file_locs) + + @classmethod + def _postgres_send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None: + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(DagPriorityParsingRequest).on_conflict_do_nothing() + session.execute(stmt, {"fileloc": fileloc for fileloc in file_locs}) + def resolve_dataset_manager() -> DatasetManager: """Retrieve the dataset manager."""