From ab3ef0c9d016ed6cb6f37285ba9a6344b93190ea Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 5 Feb 2025 16:50:15 -0500 Subject: [PATCH] Fix dml function in for loops in with blocks. (#8306) Given the schema: ``` type Bar { required property a -> int64; }; function foo(x: int64) -> Bar { using ((insert Bar{ a := x })) }; ``` The query `with temp := (for x in {1,2,3} union(foo(x))) select temp.a` would produce: ``` {1, 1, 1, 2, 2, 2, 3, 3, 3} ``` A similar result would appear if the dml function contained `update` or `delete`. This was caused by two related bugs: 1. Sets can be compiled twice if they are DML and contained within a for loop within a with block. For normal DML statements, `CompilerContextLevel.dml_stmts` is used to track such CTEs and ensure that they are not duplicated. However, for inlined function arguments, this CTE was not tracked and was therefore duplicated and joined multiple times. The fix was to track the inlined argument CTEs using a newly added `CompilerContextLevel.inline_dml_ctes`. Unlike the iterator CTEs, these need to be tracked by `PathId` instead of the `Set` directly. If the CTE is used in a view, then it may create a copy of the `Set` with the same path. 2. Update and Delete did not apply iterator path bond to the statement body. This was not an issue with explicit DML, as the dml range could be relied on to merge in any iterators. However, this did not join in the new inlined argument CTE. --- edb/pgsql/compiler/context.py | 11 + edb/pgsql/compiler/dml.py | 3 + edb/pgsql/compiler/relgen.py | 78 ++-- tests/test_edgeql_functions_inline.py | 521 ++++++++++++++++++++++++++ 4 files changed, 581 insertions(+), 32 deletions(-) diff --git a/edb/pgsql/compiler/context.py b/edb/pgsql/compiler/context.py index efc32cc2554..95f1e2d2baa 100644 --- a/edb/pgsql/compiler/context.py +++ b/edb/pgsql/compiler/context.py @@ -229,6 +229,15 @@ class CompilerContextLevel(compiler.ContextLevel): dml_stmts: Dict[Union[irast.MutatingStmt, irast.Set], pgast.CommonTableExpr] + #: Inline DML functions may require additional CTEs. + #: Record such CTEs as well as the path used by their iterators. + #: This ensures CTEs are created only once, and that the correct + #: iterator bonds are applied. + inline_dml_ctes: dict[ + irast.PathId, + tuple[irast.PathId, pgast.CommonTableExpr], + ] + #: SQL statement corresponding to the IR statement #: currently being compiled. stmt: pgast.SelectStmt @@ -363,6 +372,7 @@ def __init__( self.type_inheritance_ctes = {} self.ordered_type_ctes = [] self.dml_stmts = {} + self.inline_dml_ctes = {} self.parent_rel = None self.pending_query = None self.materializing = frozenset() @@ -405,6 +415,7 @@ def __init__( self.type_inheritance_ctes = prevlevel.type_inheritance_ctes self.ordered_type_ctes = prevlevel.ordered_type_ctes self.dml_stmts = prevlevel.dml_stmts + self.inline_dml_ctes = prevlevel.inline_dml_ctes self.parent_rel = prevlevel.parent_rel self.pending_query = prevlevel.pending_query self.materializing = prevlevel.materializing diff --git a/edb/pgsql/compiler/dml.py b/edb/pgsql/compiler/dml.py index c443cbe433a..c65afc59cb7 100644 --- a/edb/pgsql/compiler/dml.py +++ b/edb/pgsql/compiler/dml.py @@ -1820,6 +1820,7 @@ def process_update_body( pathctx.put_path_source_rvar( update_stmt, subject_path_id, table_relation ) + put_iterator_bond(ctx.enclosing_cte_iterator, update_stmt) update_cte.query = update_stmt @@ -3105,6 +3106,8 @@ def process_delete_body( """ ctx.toplevel_stmt.append_cte(delete_cte) + put_iterator_bond(ctx.enclosing_cte_iterator, delete_cte.query) + pointers = ir_stmt.links_to_delete[typeref.id] for ptrref in pointers: diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index 361de222023..ff1474b7fe1 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -3509,42 +3509,57 @@ def _compile_inlined_call_args( if irutils.contains_dml(expr.body): last_iterator = ctx.enclosing_cte_iterator - # Compile args into an iterator CTE - with ctx.newrel() as arg_ctx: - dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx) - clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx) + # If this function call has already been compiled to a CTE, don't + # recompile the arguments. + # (This will happen when a DML-containing funcion in a FOR loop is + # WITH bound, for example.) + if ir_set.path_id in ctx.inline_dml_ctes: + args_pathid, arg_cte = ctx.inline_dml_ctes[ir_set.path_id] - _compile_call_args(ir_set, ctx=arg_ctx) + else: + # Compile args into an iterator CTE + with ctx.newrel() as arg_ctx: + dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx) + clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx) - # Add iterator identity - args_pathid = irast.PathId.new_dummy(ctx.env.aliases.get('args')) - with arg_ctx.subrel() as args_pathid_ctx: - relctx.create_iterator_identity_for_path( - args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx - ) - args_id_rvar = relctx.rvar_for_rel( - args_pathid_ctx.rel, lateral=True, ctx=arg_ctx - ) - relctx.include_rvar( - arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx - ) + _compile_call_args(ir_set, ctx=arg_ctx) - for ir_arg in expr.args.values(): - arg_path_id = ir_arg.expr.path_id - # Ensure args appear in arg CTE - pathctx.get_path_output( - arg_ctx.rel, - arg_path_id, - aspect=pgce.PathAspect.VALUE, - env=arg_ctx.env, + # Add iterator identity + args_pathid = irast.PathId.new_dummy( + ctx.env.aliases.get('args') + ) + with arg_ctx.subrel() as args_pathid_ctx: + relctx.create_iterator_identity_for_path( + args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx + ) + args_id_rvar = relctx.rvar_for_rel( + args_pathid_ctx.rel, lateral=True, ctx=arg_ctx + ) + relctx.include_rvar( + arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx ) - pathctx.put_path_bond(arg_ctx.rel, arg_path_id, iterator=True) - arg_cte = pgast.CommonTableExpr( - name=ctx.env.aliases.get('args'), - query=arg_ctx.rel, - materialized=False, - ) + for ir_arg in expr.args.values(): + arg_path_id = ir_arg.expr.path_id + # Ensure args appear in arg CTE + pathctx.get_path_output( + arg_ctx.rel, + arg_path_id, + aspect=pgce.PathAspect.VALUE, + env=arg_ctx.env, + ) + pathctx.put_path_bond( + arg_ctx.rel, arg_path_id, iterator=True + ) + + arg_cte = pgast.CommonTableExpr( + name=ctx.env.aliases.get('args'), + query=arg_ctx.rel, + materialized=False, + ) + ctx.toplevel_stmt.append_cte(arg_cte) + + ctx.inline_dml_ctes[ir_set.path_id] = (args_pathid, arg_cte) arg_iterator = pgast.IteratorCTE( path_id=args_pathid, @@ -3556,7 +3571,6 @@ def _compile_inlined_call_args( ), iterator_bond=True, ) - ctx.toplevel_stmt.append_cte(arg_cte) # Merge the new iterator ctx.path_scope = ctx.path_scope.new_child() diff --git a/tests/test_edgeql_functions_inline.py b/tests/test_edgeql_functions_inline.py index 9487ca4a87f..7e2d9a313c4 100644 --- a/tests/test_edgeql_functions_inline.py +++ b/tests/test_edgeql_functions_inline.py @@ -5592,6 +5592,155 @@ async def test_edgeql_functions_inline_insert_basic_10(self): [{'a': 1, 'b': 10}], ) + async def test_edgeql_functions_inline_insert_basic_11(self): + # Check dml function in with clause has effect + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + ''') + + await self.assert_query_result( + 'with temp := foo(1)' + 'select temp.a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + await self.assert_query_result( + 'with temp := (for x in {2, 3, 4} union (select foo(x)))' + 'select temp.a', + [2, 3, 4], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4], + sort=True, + ) + + await self.assert_query_result( + 'with temp := (if true then foo(5) else {})' + 'select temp.a', + [5], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if false then foo(6) else {})' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if true then {} else foo(7))' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if false then {} else foo(8))' + 'select temp.a', + [8], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5, 8], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_basic_12(self): + # Check dml function in with clause has effect but is not used + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + ''') + + await self.assert_query_result( + 'with temp := foo(1)' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + await self.assert_query_result( + 'with temp := (for x in {2, 3, 4} union (select foo(x)))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4], + sort=True, + ) + + await self.assert_query_result( + 'with temp := (if true then foo(5) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if false then foo(6) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if true then {} else foo(7))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'with temp := (if false then {} else foo(8))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5, 8], + sort=True, + ) + async def test_edgeql_functions_inline_insert_iterator_01(self): await self.con.execute(''' create type Bar { @@ -8042,6 +8191,192 @@ async def reset_data(): sort=True, ) + async def test_edgeql_functions_inline_update_basic_06(self): + # Check dml function in with clause has effect + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64, y: int64) -> set of Bar { + set is_inlined := true; + using ((update Bar filter .a <= y set { a := x })); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + insert Bar{a := 5}; + ''') + + await reset_data() + await self.assert_query_result( + 'with temp := foo(0, 2)' + 'select temp.a', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (for x in {1, 2, 3} union (select foo(x-1, x)))' + 'select temp.a', + [0, 1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 2, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (if true then foo(0, 2) else {})' + 'select temp.a', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then foo(0, 2) else {})' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if true then {} else foo(0, 2))' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then {} else foo(0, 2))' + 'select temp.a', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + + async def test_edgeql_functions_inline_update_basic_07(self): + # Check dml function in with clause has effect but is not used + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64, y: int64) -> set of Bar { + set is_inlined := true; + using ((update Bar filter .a <= y set { a := x })); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + insert Bar{a := 5}; + ''') + + await reset_data() + await self.assert_query_result( + 'with temp := foo(0, 2)' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (for x in {1, 2, 3} union (select foo(x-1, x)))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 2, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (if true then foo(0, 2) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then foo(0, 2) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if true then {} else foo(0, 2))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then {} else foo(0, 2))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3, 4, 5], + sort=True, + ) + async def test_edgeql_functions_inline_update_iterator_01(self): await self.con.execute(''' create type Bar { @@ -9973,6 +10308,192 @@ async def reset_data(): [], ) + async def test_edgeql_functions_inline_delete_basic_06(self): + # Check dml function in with clause has effect + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using ((delete Bar filter .a <= x)); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + insert Bar{a := 5}; + ''') + + await reset_data() + await self.assert_query_result( + 'with temp := foo(2)' + 'select temp.a', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (for x in {1, 2, 3} union (select foo(x)))' + 'select temp.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (if true then foo(2) else {})' + 'select temp.a', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then foo(2) else {})' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if true then {} else foo(2))' + 'select temp.a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then {} else foo(2))' + 'select temp.a', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + + async def test_edgeql_functions_inline_delete_basic_07(self): + # Check dml function in with clause has effect but is not used + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using ((delete Bar filter .a <= x)); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + insert Bar{a := 5}; + ''') + + await reset_data() + await self.assert_query_result( + 'with temp := foo(2)' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (for x in {1, 2, 3} union (select foo(x)))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [4, 5], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'with temp := (if true then foo(2) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then foo(2) else {})' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if true then {} else foo(2))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'with temp := (if false then {} else foo(2))' + 'select 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [3, 4, 5], + sort=True, + ) + async def test_edgeql_functions_inline_delete_iterator_01(self): await self.con.execute(''' create type Bar {