diff --git a/edb/pgsql/compiler/context.py b/edb/pgsql/compiler/context.py index efc32cc2554..95f1e2d2baa 100644 --- a/edb/pgsql/compiler/context.py +++ b/edb/pgsql/compiler/context.py @@ -229,6 +229,15 @@ class CompilerContextLevel(compiler.ContextLevel): dml_stmts: Dict[Union[irast.MutatingStmt, irast.Set], pgast.CommonTableExpr] + #: Inline DML functions may require additional CTEs. + #: Record such CTEs as well as the path used by their iterators. + #: This ensures CTEs are created only once, and that the correct + #: iterator bonds are applied. + inline_dml_ctes: dict[ + irast.PathId, + tuple[irast.PathId, pgast.CommonTableExpr], + ] + #: SQL statement corresponding to the IR statement #: currently being compiled. stmt: pgast.SelectStmt @@ -363,6 +372,7 @@ def __init__( self.type_inheritance_ctes = {} self.ordered_type_ctes = [] self.dml_stmts = {} + self.inline_dml_ctes = {} self.parent_rel = None self.pending_query = None self.materializing = frozenset() @@ -405,6 +415,7 @@ def __init__( self.type_inheritance_ctes = prevlevel.type_inheritance_ctes self.ordered_type_ctes = prevlevel.ordered_type_ctes self.dml_stmts = prevlevel.dml_stmts + self.inline_dml_ctes = prevlevel.inline_dml_ctes self.parent_rel = prevlevel.parent_rel self.pending_query = prevlevel.pending_query self.materializing = prevlevel.materializing diff --git a/edb/pgsql/compiler/dml.py b/edb/pgsql/compiler/dml.py index c443cbe433a..c65afc59cb7 100644 --- a/edb/pgsql/compiler/dml.py +++ b/edb/pgsql/compiler/dml.py @@ -1820,6 +1820,7 @@ def process_update_body( pathctx.put_path_source_rvar( update_stmt, subject_path_id, table_relation ) + put_iterator_bond(ctx.enclosing_cte_iterator, update_stmt) update_cte.query = update_stmt @@ -3105,6 +3106,8 @@ def process_delete_body( """ ctx.toplevel_stmt.append_cte(delete_cte) + put_iterator_bond(ctx.enclosing_cte_iterator, delete_cte.query) + pointers = ir_stmt.links_to_delete[typeref.id] for ptrref in pointers: diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index 361de222023..ff1474b7fe1 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -3509,42 +3509,57 @@ def _compile_inlined_call_args( if irutils.contains_dml(expr.body): last_iterator = ctx.enclosing_cte_iterator - # Compile args into an iterator CTE - with ctx.newrel() as arg_ctx: - dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx) - clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx) + # If this function call has already been compiled to a CTE, don't + # recompile the arguments. + # (This will happen when a DML-containing funcion in a FOR loop is + # WITH bound, for example.) + if ir_set.path_id in ctx.inline_dml_ctes: + args_pathid, arg_cte = ctx.inline_dml_ctes[ir_set.path_id] - _compile_call_args(ir_set, ctx=arg_ctx) + else: + # Compile args into an iterator CTE + with ctx.newrel() as arg_ctx: + dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx) + clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx) - # Add iterator identity - args_pathid = irast.PathId.new_dummy(ctx.env.aliases.get('args')) - with arg_ctx.subrel() as args_pathid_ctx: - relctx.create_iterator_identity_for_path( - args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx - ) - args_id_rvar = relctx.rvar_for_rel( - args_pathid_ctx.rel, lateral=True, ctx=arg_ctx - ) - relctx.include_rvar( - arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx - ) + _compile_call_args(ir_set, ctx=arg_ctx) - for ir_arg in expr.args.values(): - arg_path_id = ir_arg.expr.path_id - # Ensure args appear in arg CTE - pathctx.get_path_output( - arg_ctx.rel, - arg_path_id, - aspect=pgce.PathAspect.VALUE, - env=arg_ctx.env, + # Add iterator identity + args_pathid = irast.PathId.new_dummy( + ctx.env.aliases.get('args') + ) + with arg_ctx.subrel() as args_pathid_ctx: + relctx.create_iterator_identity_for_path( + args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx + ) + args_id_rvar = relctx.rvar_for_rel( + args_pathid_ctx.rel, lateral=True, ctx=arg_ctx + ) + relctx.include_rvar( + arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx ) - pathctx.put_path_bond(arg_ctx.rel, arg_path_id, iterator=True) - arg_cte = pgast.CommonTableExpr( - name=ctx.env.aliases.get('args'), - query=arg_ctx.rel, - materialized=False, - ) + for ir_arg in expr.args.values(): + arg_path_id = ir_arg.expr.path_id + # Ensure args appear in arg CTE + pathctx.get_path_output( + arg_ctx.rel, + arg_path_id, + aspect=pgce.PathAspect.VALUE, + env=arg_ctx.env, + ) + pathctx.put_path_bond( + arg_ctx.rel, arg_path_id, iterator=True + ) + + arg_cte = pgast.CommonTableExpr( + name=ctx.env.aliases.get('args'), + query=arg_ctx.rel, + materialized=False, + ) + ctx.toplevel_stmt.append_cte(arg_cte) + + ctx.inline_dml_ctes[ir_set.path_id] = (args_pathid, arg_cte) arg_iterator = pgast.IteratorCTE( path_id=args_pathid, @@ -3556,7 +3571,6 @@ def _compile_inlined_call_args( ), iterator_bond=True, ) - ctx.toplevel_stmt.append_cte(arg_cte) # Merge the new iterator ctx.path_scope = ctx.path_scope.new_child() diff --git a/tests/test_edgeql_functions_inline.py b/tests/test_edgeql_functions_inline.py index 9487ca4a87f..7e2d9a313c4 100644 --- a/tests/test_edgeql_functions_inline.py +++ b/tests/test_edgeql_functions_inline.py @@ -5592,6 +5592,155 @@ async def test_edgeql_functions_inline_insert_basic_10(self): [{'a': 1, 'b': 10}], ) + async def test_edgeql_functions_inline_insert_basic_11(self): + # Check dml function in with clause has effect + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + ''') + + await self.assert_query_result( + 'with temp := foo(1)' + 'select temp.a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + await self.assert_query_result( + 'with temp := (for x in {2, 3, 4} union (select foo(x)))' + 'select temp.a', + [2, 3, 4], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4], + sort=True, + ) + + await self.assert_query_result( + 'with temp := (if true then foo(5) else {})' + 'select temp.a', + [5], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if false then foo(6) else {})' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if true then {} else foo(7))' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if false then {} else foo(8))' + 'select temp.a', + [8], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5, 8], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_basic_12(self): + # Check dml function in with clause has effect but is not used + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + ''') + + await self.assert_query_result( + 'with temp := foo(1)' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + await self.assert_query_result( + 'with temp := (for x in {2, 3, 4} union (select foo(x)))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4], + sort=True, + ) + + await self.assert_query_result( + 'with temp := (if true then foo(5) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if false then foo(6) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if true then {} else foo(7))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if false then {} else foo(8))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5, 8], + sort=True, + ) + async def test_edgeql_functions_inline_insert_iterator_01(self): await self.con.execute(''' create type Bar { @@ -8042,6 +8191,192 @@ async def reset_data(): sort=True, ) + async def test_edgeql_functions_inline_update_basic_06(self): + # Check dml function in with clause has effect + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64, y: int64) -> set of Bar { + set is_inlined := true; + using ((update Bar filter .a <= y set { a := x })); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + insert Bar{a := 5}; + ''') + + await reset_data() + await self.assert_query_result( + 'with temp := foo(0, 2)' + 'select temp.a', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (for x in {1, 2, 3} union (select foo(x-1, x)))' + 'select temp.a', + [0, 1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 2, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (if true then foo(0, 2) else {})' + 'select temp.a', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then foo(0, 2) else {})' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if true then {} else foo(0, 2))' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then {} else foo(0, 2))' + 'select temp.a', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + + async def test_edgeql_functions_inline_update_basic_07(self): + # Check dml function in with clause has effect but is not used + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64, y: int64) -> set of Bar { + set is_inlined := true; + using ((update Bar filter .a <= y set { a := x })); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + insert Bar{a := 5}; + ''') + + await reset_data() + await self.assert_query_result( + 'with temp := foo(0, 2)' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (for x in {1, 2, 3} union (select foo(x-1, x)))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 2, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (if true then foo(0, 2) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then foo(0, 2) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if true then {} else foo(0, 2))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then {} else foo(0, 2))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + async def test_edgeql_functions_inline_update_iterator_01(self): await self.con.execute(''' create type Bar { @@ -9973,6 +10308,192 @@ async def reset_data(): [], ) + async def test_edgeql_functions_inline_delete_basic_06(self): + # Check dml function in with clause has effect + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using ((delete Bar filter .a <= x)); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + insert Bar{a := 5}; + ''') + + await reset_data() + await self.assert_query_result( + 'with temp := foo(2)' + 'select temp.a', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (for x in {1, 2, 3} union (select foo(x)))' + 'select temp.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (if true then foo(2) else {})' + 'select temp.a', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then foo(2) else {})' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if true then {} else foo(2))' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then {} else foo(2))' + 'select temp.a', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + + async def test_edgeql_functions_inline_delete_basic_07(self): + # Check dml function in with clause has effect but is not used + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using ((delete Bar filter .a <= x)); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + insert Bar{a := 5}; + ''') + + await reset_data() + await self.assert_query_result( + 'with temp := foo(2)' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (for x in {1, 2, 3} union (select foo(x)))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (if true then foo(2) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then foo(2) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if true then {} else foo(2))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then {} else foo(2))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + async def test_edgeql_functions_inline_delete_iterator_01(self): await self.con.execute(''' create type Bar {