Skip to content

Commit

Permalink
Prevent queries if possible
Browse files (browse the repository at this point in the history)
  • Loading branch information
uranusjr committed Dec 24, 2024
1 parent 8dd043e commit 6e82d46
Showing 1 changed file with 30 additions and 45 deletions.
75 changes: 30 additions & 45 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,56 +638,41 @@ def add_dag_asset_alias_references(
if alias_id not in orm_refs
)

def add_dag_asset_name_uri_references(self, *, session: Session) -> None:
orm_name_refs = set(
@staticmethod
def _add_dag_asset_references(
references: set[tuple[str, str]],
model: type[DagScheduleAssetNameReference] | type[DagScheduleAssetUriReference],
attr: str,
*,
session: Session,
) -> None:
if not references:
return
orm_refs = set(
session.scalars(
select(DagScheduleAssetNameReference.dag_id, DagScheduleAssetNameReference.name).where(
DagScheduleAssetNameReference.dag_id.in_(
dag_id for dag_id, _ in self.schedule_asset_name_references
)
)
select(model.dag_id, model.name).where(model.dag_id.in_(dag_id for dag_id, _ in references))
)
)
new_name_refs = self.schedule_asset_name_references - orm_name_refs
old_name_refs = orm_name_refs - self.schedule_asset_name_references
if old_name_refs:
session.execute(
delete(DagScheduleAssetNameReference).where(
tuple_(DagScheduleAssetNameReference.dag_id, DagScheduleAssetNameReference.name).in_(
old_name_refs
)
)
)
if new_name_refs:
session.execute(
insert(DagScheduleAssetNameReference),
[{"dag_id": d, "name": n} for d, n in new_name_refs],
)
new_refs = references - orm_refs
old_refs = orm_refs - references
if old_refs:
session.execute(delete(model).where(tuple_(model.dag_id, getattr(model, attr)).in_(old_refs)))
if new_refs:
session.execute(insert(model), [{"dag_id": d, attr: r} for d, r in new_refs])

orm_uri_refs = set(
session.scalars(
select(DagScheduleAssetUriReference.dag_id, DagScheduleAssetUriReference.uri).where(
DagScheduleAssetUriReference.dag_id.in_(
dag_id for dag_id, _ in self.schedule_asset_uri_references
)
)
)
def add_dag_asset_name_uri_references(self, *, session: Session) -> None:
self._add_dag_asset_references(
self.schedule_asset_name_references,
DagScheduleAssetNameReference,
"name",
session=session,
)
self._add_dag_asset_references(
self.schedule_asset_uri_references,
DagScheduleAssetUriReference,
"uri",
session=session,
)
new_uri_refs = self.schedule_asset_uri_references - orm_uri_refs
old_uri_refs = orm_uri_refs - self.schedule_asset_uri_references
if old_uri_refs:
session.execute(
delete(DagScheduleAssetUriReference).where(
tuple_(DagScheduleAssetUriReference.dag_id, DagScheduleAssetUriReference.uri).in_(
old_uri_refs
)
)
)
if new_uri_refs:
session.execute(
insert(DagScheduleAssetUriReference),
[{"dag_id": d, "uri": u} for d, u in new_uri_refs],
)

def add_task_asset_references(
self,
Expand Down

0 comments on commit 6e82d46

Please sign in to comment.