diff --git a/edb/pgsql/resolver/__init__.py b/edb/pgsql/resolver/__init__.py index 0f81c9c91937..ef0e5940dd40 100644 --- a/edb/pgsql/resolver/__init__.py +++ b/edb/pgsql/resolver/__init__.py @@ -41,6 +41,14 @@ def resolve( _ = context.ResolverContext(initial=ctx) - command.compile_dml(query, ctx=ctx) + top_level_ctes = command.compile_dml(query, ctx=ctx) - return dispatch.resolve(query, ctx=ctx) + query = dispatch.resolve(query, ctx=ctx) + + if top_level_ctes: + assert isinstance(query, pgast.Query) + if not query.ctes: + query.ctes = [] + query.ctes.extend(top_level_ctes) + + return query diff --git a/edb/pgsql/resolver/command.py b/edb/pgsql/resolver/command.py index d67472b4d3eb..d2598fe63b3f 100644 --- a/edb/pgsql/resolver/command.py +++ b/edb/pgsql/resolver/command.py @@ -130,17 +130,21 @@ def _pull_columns_from_table( return res -def compile_dml(stmt: pgast.Base, *, ctx: Context) -> None: +def compile_dml( + stmt: pgast.Base, *, ctx: Context +) -> List[pgast.CommonTableExpr]: # extract all dml stmts dml_stmts_sql = _collect_dml_stmts(stmt) if len(dml_stmts_sql) == 0: - return + return [] # preprocess each SQL dml stmt into EdgeQL stmts = [_preprocess_insert_stmt(s, ctx=ctx) for s in dml_stmts_sql] # merge EdgeQL stmts & compile to SQL - ctx.compiled_dml = _compile_preprocessed_dml(stmts, ctx=ctx) + ctx.compiled_dml, ctes = _compile_preprocessed_dml(stmts, ctx=ctx) + + return ctes def _collect_dml_stmts(stmt: pgast.Base) -> List[pgast.InsertStmt]: @@ -400,7 +404,10 @@ def is_default(e: pgast.BaseExpr) -> bool: def _compile_preprocessed_dml( stmts: List[PreprocessedDML], ctx: context.ResolverContextLevel -) -> Mapping[pgast.Query, context.CompiledDML]: +) -> Tuple[ + Mapping[pgast.Query, context.CompiledDML], + List[pgast.CommonTableExpr], +]: """ Compiles *all* DML statements in the query. @@ -514,7 +521,11 @@ def _compile_preprocessed_dml( output_relation_name=stmt_ctes[-1].name, output_namespace=output_namespace, ) - return result + + # return remaining CTEs to be included at the end of the top-level query + # (they probably to triggers) + + return result, ctes def _merge_and_prepare_external_rels( diff --git a/edb/pgsql/resolver/context.py b/edb/pgsql/resolver/context.py index 8dace8e6a1c0..dadeb869ceab 100644 --- a/edb/pgsql/resolver/context.py +++ b/edb/pgsql/resolver/context.py @@ -186,7 +186,7 @@ class ResolverContextLevel(compiler.ContextLevel): subquery_depth: int # List of CTEs to add the top-level statement. - # This is currently only used by DML compilation to ensure that all DML is + # This is used, for example, by DML compilation to ensure that all DML is # in the top-level WITH binding. ctes_buffer: List[pgast.CommonTableExpr] diff --git a/tests/test_sql_dml.py b/tests/test_sql_dml.py index b19b836fa189..f6e1cff6f0f3 100644 --- a/tests/test_sql_dml.py +++ b/tests/test_sql_dml.py @@ -56,6 +56,21 @@ def tearDown(self): create property can_edit: bool; }; }; + + create type Log { + create property line: str; + }; + + create type Hello { + create property world: str; + + create trigger log_insert_each after insert for each do ( + insert Log { line := 'inserted each ' ++ __new__.world } + ); + create trigger log_insert_all after insert for all do ( + insert Log { line := 'inserted all' } + ); + }; """ ] @@ -338,3 +353,27 @@ async def test_sql_dml_insert_16(self): INSERT INTO "Document" (title) VALUES ('Report'), (DEFAULT); ''' ) + + async def test_sql_dml_insert_17(self): + res = await self.scon.fetch( + ''' + WITH + a as (INSERT INTO "Hello" (world) VALUES ('a')), + b as (INSERT INTO "Hello" (world) VALUES ('b_0'), ('b_1')) + SELECT line FROM "Log" ORDER BY line; + ''' + ) + # changes to the database are not visible in the same query + self.assert_shape(res, 0, 0) + + # so we need to re-select + res = await self.squery_values('SELECT line FROM "Log" ORDER BY line;') + self.assertEqual( + res, + [ + ["inserted all"], + ["inserted each a"], + ["inserted each b_0"], + ["inserted each b_1"], + ], + )