Skip to content

Commit

Permalink
Implict limit for SQL over native (#8206) (#8213)
Browse files Browse the repository at this point in the history
Co-authored-by: Aljaž Mur Eržen <[email protected]>
  • Loading branch information
msullivan and aljazerzen committed Jan 14, 2025
1 parent ac13854 commit 6377c58
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 0 deletions.
18 changes: 18 additions & 0 deletions edb/pgsql/resolver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def resolve(
else:
raise AssertionError()

if limit := ctx.options.implicit_limit:
resolved = apply_implicit_limit(resolved, limit, resolved_table, ctx)

command.fini_external_params(ctx)

if top_level_ctes:
Expand Down Expand Up @@ -179,3 +182,18 @@ def as_plain_select(
for index, c in enumerate(table.columns)
],
)


def apply_implicit_limit(
expr: pgast.Base,
limit: int,
table: Optional[context.Table],
ctx: context.ResolverContextLevel,
) -> pgast.Base:
e = as_plain_select(expr, table, ctx)
if not e:
return expr

if e.limit_count is None:
e.limit_count = pgast.NumericConstant(val=str(limit))
return e
3 changes: 3 additions & 0 deletions edb/pgsql/resolver/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class Options:
# DisableNormalization to recompile the query without normalization.
normalized_params: List[int]

# Apply a limit to the number of rows in the top-level query
implicit_limit: Optional[int]


@dataclass(kw_only=True)
class Scope:
Expand Down
1 change: 1 addition & 0 deletions edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2665,6 +2665,7 @@ def compile_sql_as_unit_group(
disambiguate_column_names=True,
backend_runtime_params=ctx.backend_runtime_params,
protocol_version=ctx.protocol_version,
implicit_limit=ctx.implicit_limit,
)

qug = dbstate.QueryUnitGroup(
Expand Down
6 changes: 6 additions & 0 deletions edb/server/compiler/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def compile_sql(
disambiguate_column_names: bool,
backend_runtime_params: pg_params.BackendRuntimeParams,
protocol_version: defines.ProtocolVersion,
implicit_limit: Optional[int] = None,
) -> List[dbstate.SQLQueryUnit]:
def _try(
q: str, normalized_params: List[int]
Expand All @@ -101,6 +102,7 @@ def _try(
backend_runtime_params=backend_runtime_params,
protocol_version=protocol_version,
normalized_params=normalized_params,
implicit_limit=implicit_limit,
)

normalized_params = list(source.extra_type_oids())
Expand Down Expand Up @@ -207,6 +209,7 @@ def _compile_sql(
backend_runtime_params: pg_params.BackendRuntimeParams,
protocol_version: defines.ProtocolVersion,
normalized_params: List[int],
implicit_limit: Optional[int] = None,
) -> List[dbstate.SQLQueryUnit]:
opts = ResolverOptionsPartial(
query_str=query_str,
Expand All @@ -219,6 +222,7 @@ def _compile_sql(
),
disambiguate_column_names=disambiguate_column_names,
normalized_params=normalized_params,
implicit_limit=implicit_limit,
)

# orig_stmts are the statements prior to constant extraction
Expand Down Expand Up @@ -557,6 +561,7 @@ class ResolverOptionsPartial:
include_edgeql_io_format_alternative: Optional[bool]
disambiguate_column_names: bool
normalized_params: List[int]
implicit_limit: Optional[int]


def resolve_query(
Expand Down Expand Up @@ -606,6 +611,7 @@ def resolve_query(
),
disambiguate_column_names=opts.disambiguate_column_names,
normalized_params=opts.normalized_params,
implicit_limit=opts.implicit_limit,
)
resolved = pg_resolver.resolve(stmt, schema, options)
source = pg_codegen.generate(resolved.ast, with_source_map=True)
Expand Down
2 changes: 2 additions & 0 deletions edb/testbase/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,7 @@ async def assert_sql_query_result(
query,
exp_result,
*,
implicit_limit=0,
msg=None,
sort=None,
variables=None,
Expand All @@ -1263,6 +1264,7 @@ async def assert_sql_query_result(
await self.assert_query_result(
query,
exp_result,
implicit_limit=implicit_limit,
msg=msg,
sort=sort,
variables=variables,
Expand Down
8 changes: 8 additions & 0 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2850,6 +2850,14 @@ async def test_sql_native_query_24(self):
'COPY "Genre" TO STDOUT', []
)

async def test_sql_native_query_25(self):
# implict limit
await self.assert_sql_query_result(
'VALUES (1), (2), (3), (4), (5), (6), (7)',
[{'column1': 1}, {'column1': 2}, {'column1': 3}],
implicit_limit=3,
)


class TestSQLQueryNonTransactional(tb.SQLQueryTestCase):

Expand Down

0 comments on commit 6377c58

Please sign in to comment.