From 28ce656772c845a5fb4962bada982d2f01b106b0 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:49:33 -0800 Subject: [PATCH] Replace var names `select`, `base_select`, `base_query` with `statement` (#44270) Previously it was `base_select` then recently renamed to `select`. `select` is not a great choice because it collides with the sqlalchemy function. I think `query` is the best name but, going with `statement` in response to review comments. Which isn't so bad, and is consistent with arg to sqlalchemy's `execute` func so, there's that. Also best to remove the "base" part of it because we tend to mutate it, making it not a "base" of anything. --- airflow/api_fastapi/common/db/common.py | 44 +++++++++---------- .../core_api/routes/public/assets.py | 8 ++-- .../core_api/routes/public/backfills.py | 2 +- .../core_api/routes/public/connections.py | 2 +- .../core_api/routes/public/dag_run.py | 6 +-- .../core_api/routes/public/dag_stats.py | 2 +- .../core_api/routes/public/dag_warning.py | 2 +- .../core_api/routes/public/dags.py | 8 ++-- .../core_api/routes/public/event_logs.py | 26 +++++------ .../core_api/routes/public/import_error.py | 2 +- .../api_fastapi/core_api/routes/public/job.py | 2 +- .../core_api/routes/public/pools.py | 2 +- .../core_api/routes/public/task_instances.py | 18 ++++---- .../core_api/routes/public/variables.py | 2 +- .../api_fastapi/core_api/routes/ui/dags.py | 2 +- airflow/utils/db.py | 4 +- 16 files changed, 65 insertions(+), 67 deletions(-) diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 1621ee6f9bab2..fc7907e5bd25a 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -56,18 +56,16 @@ def your_route(session: Annotated[Session, Depends(get_session)]): def apply_filters_to_select( - *, - base_select: Select, - filters: Sequence[BaseParam | None] | None = None, + *, statement: 
Select, filters: Sequence[BaseParam | None] | None = None ) -> Select: if filters is None: - return base_select + return statement for f in filters: if f is None: continue - base_select = f.to_orm(base_select) + statement = f.to_orm(statement) - return base_select + return statement async def get_async_session() -> AsyncSession: @@ -89,7 +87,7 @@ def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]): @overload async def paginated_select_async( *, - query: Select, + statement: Select, filters: Sequence[BaseParam] | None = None, order_by: BaseParam | None = None, offset: BaseParam | None = None, @@ -102,7 +100,7 @@ async def paginated_select_async( @overload async def paginated_select_async( *, - query: Select, + statement: Select, filters: Sequence[BaseParam] | None = None, order_by: BaseParam | None = None, offset: BaseParam | None = None, @@ -114,7 +112,7 @@ async def paginated_select_async( async def paginated_select_async( *, - query: Select, + statement: Select, filters: Sequence[BaseParam | None] | None = None, order_by: BaseParam | None = None, offset: BaseParam | None = None, @@ -122,32 +120,32 @@ async def paginated_select_async( session: AsyncSession, return_total_entries: bool = True, ) -> tuple[Select, int | None]: - query = apply_filters_to_select( - base_select=query, + statement = apply_filters_to_select( + statement=statement, filters=filters, ) total_entries = None if return_total_entries: - total_entries = await get_query_count_async(query, session=session) + total_entries = await get_query_count_async(statement, session=session) # TODO: Re-enable when permissions are handled. 
Readable / writable entities, # for instance: # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) - query = apply_filters_to_select( - base_select=query, + statement = apply_filters_to_select( + statement=statement, filters=[order_by, offset, limit], ) - return query, total_entries + return statement, total_entries @overload def paginated_select( *, - select: Select, + statement: Select, filters: Sequence[BaseParam] | None = None, order_by: BaseParam | None = None, offset: BaseParam | None = None, @@ -160,7 +158,7 @@ def paginated_select( @overload def paginated_select( *, - select: Select, + statement: Select, filters: Sequence[BaseParam] | None = None, order_by: BaseParam | None = None, offset: BaseParam | None = None, @@ -173,7 +171,7 @@ def paginated_select( @provide_session def paginated_select( *, - select: Select, + statement: Select, filters: Sequence[BaseParam] | None = None, order_by: BaseParam | None = None, offset: BaseParam | None = None, @@ -181,20 +179,20 @@ def paginated_select( session: Session = NEW_SESSION, return_total_entries: bool = True, ) -> tuple[Select, int | None]: - base_select = apply_filters_to_select( - base_select=select, + statement = apply_filters_to_select( + statement=statement, filters=filters, ) total_entries = None if return_total_entries: - total_entries = get_query_count(base_select, session=session) + total_entries = get_query_count(statement, session=session) # TODO: Re-enable when permissions are handled. 
Readable / writable entities, # for instance: # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) - base_select = apply_filters_to_select(base_select=base_select, filters=[order_by, offset, limit]) + statement = apply_filters_to_select(statement=statement, filters=[order_by, offset, limit]) - return base_select, total_entries + return statement, total_entries diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow/api_fastapi/core_api/routes/public/assets.py index 5aa37c7a6f9f8..b7cc9140e973b 100644 --- a/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow/api_fastapi/core_api/routes/public/assets.py @@ -95,7 +95,7 @@ def get_assets( ) -> AssetCollectionResponse: """Get assets.""" assets_select, total_entries = paginated_select( - select=select(AssetModel), + statement=select(AssetModel), filters=[uri_pattern, dag_ids], order_by=order_by, offset=offset, @@ -145,7 +145,7 @@ def get_asset_events( ) -> AssetEventCollectionResponse: """Get asset events.""" assets_event_select, total_entries = paginated_select( - select=select(AssetEvent), + statement=select(AssetEvent), filters=[asset_id, source_dag_id, source_task_id, source_run_id, source_map_index], order_by=order_by, offset=offset, @@ -210,7 +210,7 @@ def get_asset_queued_events( .where(*where_clause) ) - dag_asset_queued_events_select, total_entries = paginated_select(select=query) + dag_asset_queued_events_select, total_entries = paginated_select(statement=query) adrqs = session.execute(dag_asset_queued_events_select).all() if not adrqs: @@ -269,7 +269,7 @@ def get_dag_asset_queued_events( .where(*where_clause) ) - dag_asset_queued_events_select, total_entries = paginated_select(select=query) + dag_asset_queued_events_select, total_entries = paginated_select(statement=query) adrqs = session.execute(dag_asset_queued_events_select).all() if not adrqs: raise 
HTTPException(status.HTTP_404_NOT_FOUND, f"Queue event with dag_id: `{dag_id}` was not found") diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index 78b2beb558895..94d0fd1ed48b8 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -61,7 +61,7 @@ async def list_backfills( session: Annotated[AsyncSession, Depends(get_async_session)], ) -> BackfillCollectionResponse: select_stmt, total_entries = await paginated_select_async( - query=select(Backfill).where(Backfill.dag_id == dag_id), + statement=select(Backfill).where(Backfill.dag_id == dag_id), order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow/api_fastapi/core_api/routes/public/connections.py index 46ebcfcf98ca1..edfece1333bbb 100644 --- a/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow/api_fastapi/core_api/routes/public/connections.py @@ -100,7 +100,7 @@ def get_connections( ) -> ConnectionCollectionResponse: """Get all connection entries.""" connection_select, total_entries = paginated_select( - select=select(Connection), + statement=select(Connection), order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index c26650767c98a..e9905cbc83a94 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -274,17 +274,17 @@ def get_dag_runs( This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs. 
""" - base_query = select(DagRun) + query = select(DagRun) if dag_id != "~": dag: DAG = request.app.state.dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"The DAG with dag_id: `{dag_id}` was not found") - base_query = base_query.filter(DagRun.dag_id == dag_id) + query = query.filter(DagRun.dag_id == dag_id) dag_run_select, total_entries = paginated_select( - select=base_query, + statement=query, filters=[logical_date, start_date_range, end_date_range, update_at_range, state], order_by=order_by, offset=offset, diff --git a/airflow/api_fastapi/core_api/routes/public/dag_stats.py b/airflow/api_fastapi/core_api/routes/public/dag_stats.py index 119961f8c5f36..89a22f7face6e 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_stats.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_stats.py @@ -55,7 +55,7 @@ def get_dag_stats( ) -> DagStatsCollectionResponse: """Get Dag statistics.""" dagruns_select, _ = paginated_select( - select=dagruns_select_with_state_count, + statement=dagruns_select_with_state_count, filters=[dag_ids], session=session, return_total_entries=False, diff --git a/airflow/api_fastapi/core_api/routes/public/dag_warning.py b/airflow/api_fastapi/core_api/routes/public/dag_warning.py index e933710bc6903..df1e636faa50f 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_warning.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_warning.py @@ -59,7 +59,7 @@ def list_dag_warnings( ) -> DAGWarningCollectionResponse: """Get a list of DAG warnings.""" dag_warnings_select, total_entries = paginated_select( - select=select(DagWarning), + statement=select(DagWarning), filters=[warning_type, dag_id], order_by=order_by, offset=offset, diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index 99a86508edad4..6a099b9b614b4 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ 
-82,7 +82,7 @@ def get_dags( ) -> DAGCollectionResponse: """Get all DAGs.""" dags_select, total_entries = paginated_select( - select=dags_select_with_latest_dag_run, + statement=dags_select_with_latest_dag_run, filters=[ only_active, paused, @@ -125,9 +125,9 @@ def get_dag_tags( session: Annotated[Session, Depends(get_session)], ) -> DAGTagCollectionResponse: """Get all DAG tags.""" - base_select = select(DagTag.name).group_by(DagTag.name) + query = select(DagTag.name).group_by(DagTag.name) dag_tags_select, total_entries = paginated_select( - select=base_select, + statement=query, filters=[tag_name_pattern], order_by=order_by, offset=offset, @@ -263,7 +263,7 @@ def patch_dags( update_mask = ["is_paused"] dags_select, total_entries = paginated_select( - select=dags_select_with_latest_dag_run, + statement=dags_select_with_latest_dag_run, filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state], order_by=None, offset=offset, diff --git a/airflow/api_fastapi/core_api/routes/public/event_logs.py b/airflow/api_fastapi/core_api/routes/public/event_logs.py index 51feb7e22cfb2..aa1504a51f391 100644 --- a/airflow/api_fastapi/core_api/routes/public/event_logs.py +++ b/airflow/api_fastapi/core_api/routes/public/event_logs.py @@ -97,32 +97,32 @@ def get_event_logs( after: datetime | None = None, ) -> EventLogCollectionResponse: """Get all Event Logs.""" - base_select = select(Log).group_by(Log.id) + query = select(Log).group_by(Log.id) # TODO: Refactor using the `FilterParam` class in commit `574b72e41cc5ed175a2bbf4356522589b836bb11` if dag_id is not None: - base_select = base_select.where(Log.dag_id == dag_id) + query = query.where(Log.dag_id == dag_id) if task_id is not None: - base_select = base_select.where(Log.task_id == task_id) + query = query.where(Log.task_id == task_id) if run_id is not None: - base_select = base_select.where(Log.run_id == run_id) + query = query.where(Log.run_id == run_id) if map_index is not None: - base_select = 
base_select.where(Log.map_index == map_index) + query = query.where(Log.map_index == map_index) if try_number is not None: - base_select = base_select.where(Log.try_number == try_number) + query = query.where(Log.try_number == try_number) if owner is not None: - base_select = base_select.where(Log.owner == owner) + query = query.where(Log.owner == owner) if event is not None: - base_select = base_select.where(Log.event == event) + query = query.where(Log.event == event) if excluded_events is not None: - base_select = base_select.where(Log.event.notin_(excluded_events)) + query = query.where(Log.event.notin_(excluded_events)) if included_events is not None: - base_select = base_select.where(Log.event.in_(included_events)) + query = query.where(Log.event.in_(included_events)) if before is not None: - base_select = base_select.where(Log.dttm < before) + query = query.where(Log.dttm < before) if after is not None: - base_select = base_select.where(Log.dttm > after) + query = query.where(Log.dttm > after) event_logs_select, total_entries = paginated_select( - select=base_select, + statement=query, order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/import_error.py b/airflow/api_fastapi/core_api/routes/public/import_error.py index 233f94df3102d..6090676b5601f 100644 --- a/airflow/api_fastapi/core_api/routes/public/import_error.py +++ b/airflow/api_fastapi/core_api/routes/public/import_error.py @@ -86,7 +86,7 @@ def get_import_errors( ) -> ImportErrorCollectionResponse: """Get all import errors.""" import_errors_select, total_entries = paginated_select( - select=select(ParseImportError), + statement=select(ParseImportError), order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/job.py b/airflow/api_fastapi/core_api/routes/public/job.py index 0619ec8e66648..1f8808980cb89 100644 --- a/airflow/api_fastapi/core_api/routes/public/job.py +++ 
b/airflow/api_fastapi/core_api/routes/public/job.py @@ -96,7 +96,7 @@ def get_jobs( # TODO: Refactor using the `FilterParam` class in commit `574b72e41cc5ed175a2bbf4356522589b836bb11` jobs_select, total_entries = paginated_select( - select=base_select, + statement=base_select, filters=[ start_date_range, end_date_range, diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 6fe1cb3a312b3..d80bf75d9e8ab 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -96,7 +96,7 @@ def get_pools( ) -> PoolCollectionResponse: """Get all pools entries.""" pools_select, total_entries = paginated_select( - select=select(Pool), + statement=select(Pool), order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index f2bab41d06117..8d54e372c4b40 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -135,13 +135,13 @@ def get_mapped_task_instances( session: Annotated[Session, Depends(get_session)], ) -> TaskInstanceCollectionResponse: """Get list of mapped task instances.""" - base_query = ( + query = ( select(TI) .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index >= 0) .join(TI.dag_run) ) # 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404 - unfiltered_total_count = get_query_count(base_query, session=session) + unfiltered_total_count = get_query_count(query, session=session) if unfiltered_total_count == 0: dag = request.app.state.dag_bag.get_dag(dag_id) if not dag: @@ -157,7 +157,7 @@ def get_mapped_task_instances( raise HTTPException(status.HTTP_404_NOT_FOUND, error_message) task_instance_select, total_entries = paginated_select( - select=base_query, + 
statement=query, filters=[ logical_date_range, start_date_range, @@ -341,13 +341,13 @@ def get_task_instances( This endpoint allows specifying `~` as the dag_id, dag_run_id to retrieve Task Instances for all DAGs and DAG runs. """ - base_query = select(TI).join(TI.dag_run) + query = select(TI).join(TI.dag_run) if dag_id != "~": dag = request.app.state.dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: `{dag_id}` was not found") - base_query = base_query.where(TI.dag_id == dag_id) + query = query.where(TI.dag_id == dag_id) if dag_run_id != "~": dag_run = session.scalar(select(DagRun).filter_by(run_id=dag_run_id)) @@ -356,10 +356,10 @@ def get_task_instances( status.HTTP_404_NOT_FOUND, f"DagRun with run_id: `{dag_run_id}` was not found", ) - base_query = base_query.where(TI.run_id == dag_run_id) + query = query.where(TI.run_id == dag_run_id) task_instance_select, total_entries = paginated_select( - select=base_query, + statement=query, filters=[ logical_date, start_date_range, @@ -426,9 +426,9 @@ def get_task_instances_batch( TI, ).set_value(body.order_by) - base_query = select(TI).join(TI.dag_run) + query = select(TI).join(TI.dag_run) task_instance_select, total_entries = paginated_select( - select=base_query, + statement=query, filters=[ dag_ids, dag_run_ids, diff --git a/airflow/api_fastapi/core_api/routes/public/variables.py b/airflow/api_fastapi/core_api/routes/public/variables.py index a96aa51b5dd64..bd91e9403c152 100644 --- a/airflow/api_fastapi/core_api/routes/public/variables.py +++ b/airflow/api_fastapi/core_api/routes/public/variables.py @@ -90,7 +90,7 @@ def get_variables( ) -> VariableCollectionResponse: """Get all Variables entries.""" variable_select, total_entries = paginated_select( - select=select(Variable), + statement=select(Variable), order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/ui/dags.py b/airflow/api_fastapi/core_api/routes/ui/dags.py index 
017ef3c165701..002d11e488943 100644 --- a/airflow/api_fastapi/core_api/routes/ui/dags.py +++ b/airflow/api_fastapi/core_api/routes/ui/dags.py @@ -103,7 +103,7 @@ def recent_dag_runs( .order_by(recent_runs_subquery.c.logical_date.desc()) ) dags_with_recent_dag_runs_select_filter, _ = paginated_select( - select=dags_with_recent_dag_runs_select, + statement=dags_with_recent_dag_runs_select, filters=[ only_active, paused, diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 3c657fd447fd1..1c17e1d00f5fe 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -1448,7 +1448,7 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int: return session.scalar(count_stmt) -async def get_query_count_async(query: Select, *, session: AsyncSession) -> int: +async def get_query_count_async(statement: Select, *, session: AsyncSession) -> int: """ Get count of a query. @@ -1459,7 +1459,7 @@ async def get_query_count_async(query: Select, *, session: AsyncSession) -> int: :meta private: """ - count_stmt = select(func.count()).select_from(query.order_by(None).subquery()) + count_stmt = select(func.count()).select_from(statement.order_by(None).subquery()) return await session.scalar(count_stmt)