diff --git a/edb/pgsql/resolver/__init__.py b/edb/pgsql/resolver/__init__.py index 2a15e457f6b..3008bc31bf2 100644 --- a/edb/pgsql/resolver/__init__.py +++ b/edb/pgsql/resolver/__init__.py @@ -86,6 +86,9 @@ def resolve( else: raise AssertionError() + if limit := ctx.options.implicit_limit: + resolved = apply_implicit_limit(resolved, limit, resolved_table, ctx) + command.fini_external_params(ctx) if top_level_ctes: @@ -179,3 +182,18 @@ def as_plain_select( for index, c in enumerate(table.columns) ], ) + + +def apply_implicit_limit( + expr: pgast.Base, + limit: int, + table: Optional[context.Table], + ctx: context.ResolverContextLevel, +) -> pgast.Base: + e = as_plain_select(expr, table, ctx) + if not e: + return expr + + if e.limit_count is None: + e.limit_count = pgast.NumericConstant(val=str(limit)) + return e diff --git a/edb/pgsql/resolver/context.py b/edb/pgsql/resolver/context.py index 149c83c473a..4e823c77ef9 100644 --- a/edb/pgsql/resolver/context.py +++ b/edb/pgsql/resolver/context.py @@ -64,6 +64,9 @@ class Options: # DisableNormalization to recompile the query without normalization. normalized_params: List[int] + # Apply a limit to the number of rows in the top-level query + implicit_limit: Optional[int] + @dataclass(kw_only=True) class Scope: diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 56e2fd6da53..1206bbcba82 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -2665,6 +2665,7 @@ def compile_sql_as_unit_group( disambiguate_column_names=True, backend_runtime_params=ctx.backend_runtime_params, protocol_version=ctx.protocol_version, + implicit_limit=ctx.implicit_limit, ) qug = dbstate.QueryUnitGroup( diff --git a/edb/server/compiler/sql.py b/edb/server/compiler/sql.py index 8787c8e87f0..3b265387234 100644 --- a/edb/server/compiler/sql.py +++ b/edb/server/compiler/sql.py @@ -80,6 +80,7 @@ def compile_sql( disambiguate_column_names: bool, backend_runtime_params: pg_params.BackendRuntimeParams, protocol_version: defines.ProtocolVersion, + implicit_limit: Optional[int] = None, ) -> List[dbstate.SQLQueryUnit]: def _try( q: str, normalized_params: List[int] @@ -101,6 +102,7 @@ def _try( backend_runtime_params=backend_runtime_params, protocol_version=protocol_version, normalized_params=normalized_params, + implicit_limit=implicit_limit, ) normalized_params = list(source.extra_type_oids()) @@ -207,6 +209,7 @@ def _compile_sql( backend_runtime_params: pg_params.BackendRuntimeParams, protocol_version: defines.ProtocolVersion, normalized_params: List[int], + implicit_limit: Optional[int] = None, ) -> List[dbstate.SQLQueryUnit]: opts = ResolverOptionsPartial( query_str=query_str, @@ -219,6 +222,7 @@ def _compile_sql( ), disambiguate_column_names=disambiguate_column_names, normalized_params=normalized_params, + implicit_limit=implicit_limit, ) # orig_stmts are the statements prior to constant extraction @@ -557,6 +561,7 @@ class ResolverOptionsPartial: include_edgeql_io_format_alternative: Optional[bool] disambiguate_column_names: bool normalized_params: List[int] + implicit_limit: Optional[int] def resolve_query( @@ -606,6 +611,7 @@ def resolve_query( ), disambiguate_column_names=opts.disambiguate_column_names, normalized_params=opts.normalized_params, + implicit_limit=opts.implicit_limit, ) resolved = pg_resolver.resolve(stmt, schema, options) source = pg_codegen.generate(resolved.ast, with_source_map=True) diff --git a/edb/testbase/server.py b/edb/testbase/server.py index c62b1d7b9c1..7b07d070c48 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -1248,6 +1248,7 @@ async def assert_sql_query_result( query, exp_result, *, + implicit_limit=0, msg=None, sort=None, variables=None, @@ -1263,6 +1264,7 @@ async def assert_sql_query_result( await self.assert_query_result( query, exp_result, + implicit_limit=implicit_limit, msg=msg, sort=sort, variables=variables, diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index 9187cb4f897..26cc82765ae 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -2850,6 +2850,14 @@ async def test_sql_native_query_24(self): 'COPY "Genre" TO STDOUT', [] ) + async def test_sql_native_query_25(self): + # implict limit + await self.assert_sql_query_result( + 'VALUES (1), (2), (3), (4), (5), (6), (7)', + [{'column1': 1}, {'column1': 2}, {'column1': 3}], + implicit_limit=3, + ) + class TestSQLQueryNonTransactional(tb.SQLQueryTestCase):