Skip to content

Commit

Permalink
Replace var names select, base_select, base_query with `stateme…
Browse files Browse the repository at this point in the history
…nt` (#44270)

Previously it was `base_select` then recently renamed to `select`.  `select` is not
a great choice because it collides with the sqlalchemy function.  I think `query` is the best name but, going with `statement` in response to review comments.  Which isn't so bad, and is consistent with arg to sqlalchemy's `execute` func so, there's that.p

Also best to remove the "base" part of it because we tend to mutate it, making it not a "base" of anything.
  • Loading branch information
dstandish authored Nov 25, 2024
1 parent 4404e64 commit 28ce656
Show file tree
Hide file tree
Showing 16 changed files with 65 additions and 67 deletions.
44 changes: 21 additions & 23 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,16 @@ def your_route(session: Annotated[Session, Depends(get_session)]):


def apply_filters_to_select(
*,
base_select: Select,
filters: Sequence[BaseParam | None] | None = None,
*, statement: Select, filters: Sequence[BaseParam | None] | None = None
) -> Select:
if filters is None:
return base_select
return statement
for f in filters:
if f is None:
continue
base_select = f.to_orm(base_select)
statement = f.to_orm(statement)

return base_select
return statement


async def get_async_session() -> AsyncSession:
Expand All @@ -89,7 +87,7 @@ def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]):
@overload
async def paginated_select_async(
*,
query: Select,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
Expand All @@ -102,7 +100,7 @@ async def paginated_select_async(
@overload
async def paginated_select_async(
*,
query: Select,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
Expand All @@ -114,40 +112,40 @@ async def paginated_select_async(

async def paginated_select_async(
*,
query: Select,
statement: Select,
filters: Sequence[BaseParam | None] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: AsyncSession,
return_total_entries: bool = True,
) -> tuple[Select, int | None]:
query = apply_filters_to_select(
base_select=query,
statement = apply_filters_to_select(
statement=statement,
filters=filters,
)

total_entries = None
if return_total_entries:
total_entries = await get_query_count_async(query, session=session)
total_entries = await get_query_count_async(statement, session=session)

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

query = apply_filters_to_select(
base_select=query,
statement = apply_filters_to_select(
statement=statement,
filters=[order_by, offset, limit],
)

return query, total_entries
return statement, total_entries


@overload
def paginated_select(
*,
select: Select,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
Expand All @@ -160,7 +158,7 @@ def paginated_select(
@overload
def paginated_select(
*,
select: Select,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
Expand All @@ -173,28 +171,28 @@ def paginated_select(
@provide_session
def paginated_select(
*,
select: Select,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: Session = NEW_SESSION,
return_total_entries: bool = True,
) -> tuple[Select, int | None]:
base_select = apply_filters_to_select(
base_select=select,
statement = apply_filters_to_select(
statement=statement,
filters=filters,
)

total_entries = None
if return_total_entries:
total_entries = get_query_count(base_select, session=session)
total_entries = get_query_count(statement, session=session)

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

base_select = apply_filters_to_select(base_select=base_select, filters=[order_by, offset, limit])
statement = apply_filters_to_select(statement=statement, filters=[order_by, offset, limit])

return base_select, total_entries
return statement, total_entries
8 changes: 4 additions & 4 deletions airflow/api_fastapi/core_api/routes/public/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_assets(
) -> AssetCollectionResponse:
"""Get assets."""
assets_select, total_entries = paginated_select(
select=select(AssetModel),
statement=select(AssetModel),
filters=[uri_pattern, dag_ids],
order_by=order_by,
offset=offset,
Expand Down Expand Up @@ -145,7 +145,7 @@ def get_asset_events(
) -> AssetEventCollectionResponse:
"""Get asset events."""
assets_event_select, total_entries = paginated_select(
select=select(AssetEvent),
statement=select(AssetEvent),
filters=[asset_id, source_dag_id, source_task_id, source_run_id, source_map_index],
order_by=order_by,
offset=offset,
Expand Down Expand Up @@ -210,7 +210,7 @@ def get_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(select=query)
dag_asset_queued_events_select, total_entries = paginated_select(statement=query)
adrqs = session.execute(dag_asset_queued_events_select).all()

if not adrqs:
Expand Down Expand Up @@ -269,7 +269,7 @@ def get_dag_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(select=query)
dag_asset_queued_events_select, total_entries = paginated_select(statement=query)
adrqs = session.execute(dag_asset_queued_events_select).all()
if not adrqs:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Queue event with dag_id: `{dag_id}` was not found")
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def list_backfills(
session: Annotated[AsyncSession, Depends(get_async_session)],
) -> BackfillCollectionResponse:
select_stmt, total_entries = await paginated_select_async(
query=select(Backfill).where(Backfill.dag_id == dag_id),
statement=select(Backfill).where(Backfill.dag_id == dag_id),
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/public/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_connections(
) -> ConnectionCollectionResponse:
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
select=select(Connection),
statement=select(Connection),
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,17 +274,17 @@ def get_dag_runs(
This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs.
"""
base_query = select(DagRun)
query = select(DagRun)

if dag_id != "~":
dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"The DAG with dag_id: `{dag_id}` was not found")

base_query = base_query.filter(DagRun.dag_id == dag_id)
query = query.filter(DagRun.dag_id == dag_id)

dag_run_select, total_entries = paginated_select(
select=base_query,
statement=query,
filters=[logical_date, start_date_range, end_date_range, update_at_range, state],
order_by=order_by,
offset=offset,
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/public/dag_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_dag_stats(
) -> DagStatsCollectionResponse:
"""Get Dag statistics."""
dagruns_select, _ = paginated_select(
select=dagruns_select_with_state_count,
statement=dagruns_select_with_state_count,
filters=[dag_ids],
session=session,
return_total_entries=False,
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/public/dag_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def list_dag_warnings(
) -> DAGWarningCollectionResponse:
"""Get a list of DAG warnings."""
dag_warnings_select, total_entries = paginated_select(
select=select(DagWarning),
statement=select(DagWarning),
filters=[warning_type, dag_id],
order_by=order_by,
offset=offset,
Expand Down
8 changes: 4 additions & 4 deletions airflow/api_fastapi/core_api/routes/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_dags(
) -> DAGCollectionResponse:
"""Get all DAGs."""
dags_select, total_entries = paginated_select(
select=dags_select_with_latest_dag_run,
statement=dags_select_with_latest_dag_run,
filters=[
only_active,
paused,
Expand Down Expand Up @@ -125,9 +125,9 @@ def get_dag_tags(
session: Annotated[Session, Depends(get_session)],
) -> DAGTagCollectionResponse:
"""Get all DAG tags."""
base_select = select(DagTag.name).group_by(DagTag.name)
query = select(DagTag.name).group_by(DagTag.name)
dag_tags_select, total_entries = paginated_select(
select=base_select,
statement=query,
filters=[tag_name_pattern],
order_by=order_by,
offset=offset,
Expand Down Expand Up @@ -263,7 +263,7 @@ def patch_dags(
update_mask = ["is_paused"]

dags_select, total_entries = paginated_select(
select=dags_select_with_latest_dag_run,
statement=dags_select_with_latest_dag_run,
filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state],
order_by=None,
offset=offset,
Expand Down
26 changes: 13 additions & 13 deletions airflow/api_fastapi/core_api/routes/public/event_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,32 +97,32 @@ def get_event_logs(
after: datetime | None = None,
) -> EventLogCollectionResponse:
"""Get all Event Logs."""
base_select = select(Log).group_by(Log.id)
query = select(Log).group_by(Log.id)
# TODO: Refactor using the `FilterParam` class in commit `574b72e41cc5ed175a2bbf4356522589b836bb11`
if dag_id is not None:
base_select = base_select.where(Log.dag_id == dag_id)
query = query.where(Log.dag_id == dag_id)
if task_id is not None:
base_select = base_select.where(Log.task_id == task_id)
query = query.where(Log.task_id == task_id)
if run_id is not None:
base_select = base_select.where(Log.run_id == run_id)
query = query.where(Log.run_id == run_id)
if map_index is not None:
base_select = base_select.where(Log.map_index == map_index)
query = query.where(Log.map_index == map_index)
if try_number is not None:
base_select = base_select.where(Log.try_number == try_number)
query = query.where(Log.try_number == try_number)
if owner is not None:
base_select = base_select.where(Log.owner == owner)
query = query.where(Log.owner == owner)
if event is not None:
base_select = base_select.where(Log.event == event)
query = query.where(Log.event == event)
if excluded_events is not None:
base_select = base_select.where(Log.event.notin_(excluded_events))
query = query.where(Log.event.notin_(excluded_events))
if included_events is not None:
base_select = base_select.where(Log.event.in_(included_events))
query = query.where(Log.event.in_(included_events))
if before is not None:
base_select = base_select.where(Log.dttm < before)
query = query.where(Log.dttm < before)
if after is not None:
base_select = base_select.where(Log.dttm > after)
query = query.where(Log.dttm > after)
event_logs_select, total_entries = paginated_select(
select=base_select,
statement=query,
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/public/import_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_import_errors(
) -> ImportErrorCollectionResponse:
"""Get all import errors."""
import_errors_select, total_entries = paginated_select(
select=select(ParseImportError),
statement=select(ParseImportError),
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/public/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_jobs(
# TODO: Refactor using the `FilterParam` class in commit `574b72e41cc5ed175a2bbf4356522589b836bb11`

jobs_select, total_entries = paginated_select(
select=base_select,
statement=base_select,
filters=[
start_date_range,
end_date_range,
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_pools(
) -> PoolCollectionResponse:
"""Get all pools entries."""
pools_select, total_entries = paginated_select(
select=select(Pool),
statement=select(Pool),
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
18 changes: 9 additions & 9 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,13 @@ def get_mapped_task_instances(
session: Annotated[Session, Depends(get_session)],
) -> TaskInstanceCollectionResponse:
"""Get list of mapped task instances."""
base_query = (
query = (
select(TI)
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index >= 0)
.join(TI.dag_run)
)
# 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404
unfiltered_total_count = get_query_count(base_query, session=session)
unfiltered_total_count = get_query_count(query, session=session)
if unfiltered_total_count == 0:
dag = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
Expand All @@ -157,7 +157,7 @@ def get_mapped_task_instances(
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)

task_instance_select, total_entries = paginated_select(
select=base_query,
statement=query,
filters=[
logical_date_range,
start_date_range,
Expand Down Expand Up @@ -341,13 +341,13 @@ def get_task_instances(
This endpoint allows specifying `~` as the dag_id, dag_run_id to retrieve Task Instances for all DAGs
and DAG runs.
"""
base_query = select(TI).join(TI.dag_run)
query = select(TI).join(TI.dag_run)

if dag_id != "~":
dag = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: `{dag_id}` was not found")
base_query = base_query.where(TI.dag_id == dag_id)
query = query.where(TI.dag_id == dag_id)

if dag_run_id != "~":
dag_run = session.scalar(select(DagRun).filter_by(run_id=dag_run_id))
Expand All @@ -356,10 +356,10 @@ def get_task_instances(
status.HTTP_404_NOT_FOUND,
f"DagRun with run_id: `{dag_run_id}` was not found",
)
base_query = base_query.where(TI.run_id == dag_run_id)
query = query.where(TI.run_id == dag_run_id)

task_instance_select, total_entries = paginated_select(
select=base_query,
statement=query,
filters=[
logical_date,
start_date_range,
Expand Down Expand Up @@ -426,9 +426,9 @@ def get_task_instances_batch(
TI,
).set_value(body.order_by)

base_query = select(TI).join(TI.dag_run)
query = select(TI).join(TI.dag_run)
task_instance_select, total_entries = paginated_select(
select=base_query,
statement=query,
filters=[
dag_ids,
dag_run_ids,
Expand Down
Loading

0 comments on commit 28ce656

Please sign in to comment.