Skip to content

Commit

Permalink
fix: Wrong SQL query build for GQL Relay node (#2128)
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa authored Jun 12, 2024
1 parent 5cfd9df commit 14996f2
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 39 deletions.
1 change: 1 addition & 0 deletions changes/2128.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix wrong SQL query build for GQL Relay node
74 changes: 51 additions & 23 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@

if TYPE_CHECKING:
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.selectable import ScalarSelect

from .gql import GraphQueryContext
from .user import UserRole
Expand Down Expand Up @@ -1323,14 +1325,13 @@ def _build_sql_stmt_from_connection_args(
order_expr: OrderExprArg | None = None,
*,
connection_args: ConnectionArgs,
) -> tuple[sa.sql.Select, list[WhereClauseType]]:
) -> tuple[sa.sql.Select, sa.sql.Select, list[WhereClauseType]]:
stmt = sa.select(orm_class)
count_stmt = sa.select(sa.func.count()).select_from(orm_class)
conditions: list[WhereClauseType] = []

cursor_id, pagination_order, requested_page_size = connection_args

# Default ordering by id column
id_ordering_item: OrderingItem = OrderingItem(id_column, OrderDirection.ASC)
ordering_item_list: list[OrderingItem] = []
if order_expr is not None:
parser = order_expr.parser
Expand All @@ -1339,10 +1340,14 @@ def _build_sql_stmt_from_connection_args(
# Apply SQL order_by
match pagination_order:
case ConnectionPaginationOrder.FORWARD | None:
# Default ordering by id column
id_ordering_item = OrderingItem(id_column, OrderDirection.ASC)
set_ordering = lambda col, direction: (
col.asc() if direction == OrderDirection.ASC else col.desc()
)
case ConnectionPaginationOrder.BACKWARD:
# Default ordering by id column
id_ordering_item = OrderingItem(id_column, OrderDirection.DESC)
set_ordering = lambda col, direction: (
col.desc() if direction == OrderDirection.ASC else col.asc()
)
Expand All @@ -1352,21 +1357,39 @@ def _build_sql_stmt_from_connection_args(

# Set cursor by comparing scalar values of subquery that queried by cursor id
if cursor_id is not None:
_, _id = AsyncNode.resolve_global_id(info, cursor_id)
match pagination_order:
case ConnectionPaginationOrder.FORWARD | None:
conditions.append(id_column > _id)
set_subquery = lambda col, subquery, direction: (
col >= subquery if direction == OrderDirection.ASC else col <= subquery
)
case ConnectionPaginationOrder.BACKWARD:
conditions.append(id_column < _id)
set_subquery = lambda col, subquery, direction: (
col <= subquery if direction == OrderDirection.ASC else col >= subquery
)
_, cursor_row_id = AsyncNode.resolve_global_id(info, cursor_id)

def subq_to_condition(
column_to_be_compared: InstrumentedAttribute,
subquery: ScalarSelect,
direction: OrderDirection,
) -> WhereClauseType:
match pagination_order:
case ConnectionPaginationOrder.FORWARD | None:
if direction == OrderDirection.ASC:
cond = column_to_be_compared > subquery
else:
cond = column_to_be_compared < subquery

# Comparing ID field - The direction of inequality sign - is not effected by `direction` argument here
# because the ordering direction of ID field is always determined by `pagination_order` only.
condition_when_same_with_subq = (column_to_be_compared == subquery) & (
id_column > cursor_row_id
)
case ConnectionPaginationOrder.BACKWARD:
if direction == OrderDirection.ASC:
cond = column_to_be_compared < subquery
else:
cond = column_to_be_compared > subquery
condition_when_same_with_subq = (column_to_be_compared == subquery) & (
id_column < cursor_row_id
)

return cond | condition_when_same_with_subq

for col, direction in ordering_item_list:
subq = sa.select(col).where(id_column == _id).scalar_subquery()
stmt = stmt.where(set_subquery(col, subq, direction))
subq = sa.select(col).where(id_column == cursor_row_id).scalar_subquery()
conditions.append(subq_to_condition(col, subq, direction))

if requested_page_size is not None:
# Add 1 to determine has_next_page or has_previous_page
Expand All @@ -1378,7 +1401,8 @@ def _build_sql_stmt_from_connection_args(

for cond in conditions:
stmt = stmt.where(cond)
return stmt, conditions
count_stmt = count_stmt.where(cond)
return stmt, count_stmt, conditions


def _build_sql_stmt_from_sql_arg(
Expand All @@ -1390,8 +1414,9 @@ def _build_sql_stmt_from_sql_arg(
*,
limit: int | None = None,
offset: int | None = None,
) -> tuple[sa.sql.Select, list[WhereClauseType]]:
) -> tuple[sa.sql.Select, sa.sql.Select, list[WhereClauseType]]:
stmt = sa.select(orm_class)
count_stmt = sa.select(sa.func.count()).select_from(orm_class)
conditions: list[WhereClauseType] = []

if order_expr is not None:
Expand All @@ -1412,11 +1437,13 @@ def _build_sql_stmt_from_sql_arg(
stmt = stmt.offset(offset)
for cond in conditions:
stmt = stmt.where(cond)
return stmt, conditions
count_stmt = count_stmt.where(cond)
return stmt, count_stmt, conditions


class GraphQLConnectionSQLInfo(NamedTuple):
sql_stmt: sa.sql.Select
sql_count_stmt: sa.sql.Select
sql_conditions: list[WhereClauseType]
cursor: str | None
pagination_order: ConnectionPaginationOrder | None
Expand Down Expand Up @@ -1455,7 +1482,7 @@ def generate_sql_info_for_gql_connection(
connection_args = validate_connection_args(
after=after, first=first, before=before, last=last
)
stmt, conditions = _build_sql_stmt_from_connection_args(
stmt, count_stmt, conditions = _build_sql_stmt_from_connection_args(
info,
orm_class,
id_column,
Expand All @@ -1465,14 +1492,15 @@ def generate_sql_info_for_gql_connection(
)
return GraphQLConnectionSQLInfo(
stmt,
count_stmt,
conditions,
connection_args.cursor,
connection_args.pagination_order,
connection_args.requested_page_size,
)
else:
page_size = first
stmt, conditions = _build_sql_stmt_from_sql_arg(
stmt, count_stmt, conditions = _build_sql_stmt_from_sql_arg(
info,
orm_class,
id_column,
Expand All @@ -1481,4 +1509,4 @@ def generate_sql_info_for_gql_connection(
limit=page_size,
offset=offset,
)
return GraphQLConnectionSQLInfo(stmt, conditions, None, None, page_size)
return GraphQLConnectionSQLInfo(stmt, count_stmt, conditions, None, None, page_size)
7 changes: 3 additions & 4 deletions src/ai/backend/manager/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ async def resolve_user_nodes(
)
(
query,
_,
conditions,
cursor,
pagination_order,
Expand Down Expand Up @@ -936,7 +937,8 @@ async def get_connection(
)
(
query,
conditions,
cnt_query,
_,
cursor,
pagination_order,
page_size,
Expand All @@ -952,9 +954,6 @@ async def get_connection(
before=before,
last=last,
)
cnt_query = sa.select(sa.func.count()).select_from(GroupRow)
for cond in conditions:
cnt_query = cnt_query.where(cond)
async with graph_ctx.db.begin_readonly_session() as db_session:
group_rows = (await db_session.scalars(query)).all()
result = [cls.from_row(row) for row in group_rows]
Expand Down
6 changes: 2 additions & 4 deletions src/ai/backend/manager/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,8 @@ async def get_connection(
)
(
query,
conditions,
cnt_query,
_,
cursor,
pagination_order,
page_size,
Expand All @@ -1464,9 +1465,6 @@ async def get_connection(
before=before,
last=last,
)
cnt_query = sa.select(sa.func.count()).select_from(UserRow)
for cond in conditions:
cnt_query = cnt_query.where(cond)
async with graph_ctx.db.begin_readonly_session() as db_session:
user_rows = (await db_session.scalars(query)).all()
result = [cls.from_row(row) for row in user_rows]
Expand Down
12 changes: 4 additions & 8 deletions src/ai/backend/manager/models/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,7 +1886,8 @@ async def get_connection(
)
(
query,
conditions,
cnt_query,
_,
cursor,
pagination_order,
page_size,
Expand All @@ -1902,9 +1903,6 @@ async def get_connection(
before=before,
last=last,
)
cnt_query = sa.select(sa.func.count()).select_from(VFolderRow)
for cond in conditions:
cnt_query = cnt_query.where(cond)

async with graph_ctx.db.begin_readonly_session() as db_session:
vfolder_rows = (await db_session.scalars(query)).all()
Expand Down Expand Up @@ -2482,7 +2480,8 @@ async def get_connection(
)
(
query,
conditions,
cnt_query,
_,
cursor,
pagination_order,
page_size,
Expand All @@ -2498,9 +2497,6 @@ async def get_connection(
before=before,
last=last,
)
cnt_query = sa.select(sa.func.count()).select_from(VFolderRow)
for cond in conditions:
cnt_query = cnt_query.where(cond)
async with graph_ctx.db.begin_readonly_session() as db_session:
model_store_project_gids = (
(
Expand Down

0 comments on commit 14996f2

Please sign in to comment.