Skip to content

Commit

Permalink
Fix dml function in for loops in with blocks. (#8306)
Browse files Browse the repository at this point in the history
Given the schema:
```
type Bar {
    required property a -> int64;
};
function foo(x: int64) -> Bar {
    using ((insert Bar{ a := x }))
};
```

The query `with temp := (for x in {1,2,3} union(foo(x))) select temp.a` would produce:
```
{1, 1, 1, 2, 2, 2, 3, 3, 3}
```

A similar result would appear if the dml function contained `update` or `delete`.

This was caused by two related bugs:

1. Sets can be compiled twice if they are DML and contained within a for loop within a with block.

For normal DML statements, `CompilerContextLevel.dml_stmts` is used to track such CTEs and ensure that they are not duplicated. However, for inlined function arguments, this CTE was not tracked and was therefore duplicated and joined multiple times.

The fix was to track the inlined argument CTEs using a newly added `CompilerContextLevel.inline_dml_ctes`. Unlike the iterator CTEs, these need to be tracked by `PathId` instead of the `Set` directly. If the CTE is used in a view, then it may create a copy of the `Set` with the same path.

2. Update and Delete did not apply iterator path bond to the statement body.

This was not an issue with explicit DML, as the dml range could be relied on to merge in any iterators. However, this did not join in the new inlined argument CTE.
  • Loading branch information
dnwpark authored Feb 5, 2025
1 parent 9653f62 commit ab3ef0c
Show file tree
Hide file tree
Showing 4 changed files with 581 additions and 32 deletions.
11 changes: 11 additions & 0 deletions edb/pgsql/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ class CompilerContextLevel(compiler.ContextLevel):
dml_stmts: Dict[Union[irast.MutatingStmt, irast.Set],
pgast.CommonTableExpr]

#: Inline DML functions may require additional CTEs.
#: Record such CTEs as well as the path used by their iterators.
#: This ensures CTEs are created only once, and that the correct
#: iterator bonds are applied.
inline_dml_ctes: dict[
irast.PathId,
tuple[irast.PathId, pgast.CommonTableExpr],
]

#: SQL statement corresponding to the IR statement
#: currently being compiled.
stmt: pgast.SelectStmt
Expand Down Expand Up @@ -363,6 +372,7 @@ def __init__(
self.type_inheritance_ctes = {}
self.ordered_type_ctes = []
self.dml_stmts = {}
self.inline_dml_ctes = {}
self.parent_rel = None
self.pending_query = None
self.materializing = frozenset()
Expand Down Expand Up @@ -405,6 +415,7 @@ def __init__(
self.type_inheritance_ctes = prevlevel.type_inheritance_ctes
self.ordered_type_ctes = prevlevel.ordered_type_ctes
self.dml_stmts = prevlevel.dml_stmts
self.inline_dml_ctes = prevlevel.inline_dml_ctes
self.parent_rel = prevlevel.parent_rel
self.pending_query = prevlevel.pending_query
self.materializing = prevlevel.materializing
Expand Down
3 changes: 3 additions & 0 deletions edb/pgsql/compiler/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,7 @@ def process_update_body(
pathctx.put_path_source_rvar(
update_stmt, subject_path_id, table_relation
)
put_iterator_bond(ctx.enclosing_cte_iterator, update_stmt)

update_cte.query = update_stmt

Expand Down Expand Up @@ -3105,6 +3106,8 @@ def process_delete_body(
"""
ctx.toplevel_stmt.append_cte(delete_cte)

put_iterator_bond(ctx.enclosing_cte_iterator, delete_cte.query)

pointers = ir_stmt.links_to_delete[typeref.id]

for ptrref in pointers:
Expand Down
78 changes: 46 additions & 32 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3509,42 +3509,57 @@ def _compile_inlined_call_args(
if irutils.contains_dml(expr.body):
last_iterator = ctx.enclosing_cte_iterator

# Compile args into an iterator CTE
with ctx.newrel() as arg_ctx:
dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx)
clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx)
# If this function call has already been compiled to a CTE, don't
# recompile the arguments.
# (This will happen when a DML-containing funcion in a FOR loop is
# WITH bound, for example.)
if ir_set.path_id in ctx.inline_dml_ctes:
args_pathid, arg_cte = ctx.inline_dml_ctes[ir_set.path_id]

_compile_call_args(ir_set, ctx=arg_ctx)
else:
# Compile args into an iterator CTE
with ctx.newrel() as arg_ctx:
dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx)
clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx)

# Add iterator identity
args_pathid = irast.PathId.new_dummy(ctx.env.aliases.get('args'))
with arg_ctx.subrel() as args_pathid_ctx:
relctx.create_iterator_identity_for_path(
args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx
)
args_id_rvar = relctx.rvar_for_rel(
args_pathid_ctx.rel, lateral=True, ctx=arg_ctx
)
relctx.include_rvar(
arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx
)
_compile_call_args(ir_set, ctx=arg_ctx)

for ir_arg in expr.args.values():
arg_path_id = ir_arg.expr.path_id
# Ensure args appear in arg CTE
pathctx.get_path_output(
arg_ctx.rel,
arg_path_id,
aspect=pgce.PathAspect.VALUE,
env=arg_ctx.env,
# Add iterator identity
args_pathid = irast.PathId.new_dummy(
ctx.env.aliases.get('args')
)
with arg_ctx.subrel() as args_pathid_ctx:
relctx.create_iterator_identity_for_path(
args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx
)
args_id_rvar = relctx.rvar_for_rel(
args_pathid_ctx.rel, lateral=True, ctx=arg_ctx
)
relctx.include_rvar(
arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx
)
pathctx.put_path_bond(arg_ctx.rel, arg_path_id, iterator=True)

arg_cte = pgast.CommonTableExpr(
name=ctx.env.aliases.get('args'),
query=arg_ctx.rel,
materialized=False,
)
for ir_arg in expr.args.values():
arg_path_id = ir_arg.expr.path_id
# Ensure args appear in arg CTE
pathctx.get_path_output(
arg_ctx.rel,
arg_path_id,
aspect=pgce.PathAspect.VALUE,
env=arg_ctx.env,
)
pathctx.put_path_bond(
arg_ctx.rel, arg_path_id, iterator=True
)

arg_cte = pgast.CommonTableExpr(
name=ctx.env.aliases.get('args'),
query=arg_ctx.rel,
materialized=False,
)
ctx.toplevel_stmt.append_cte(arg_cte)

ctx.inline_dml_ctes[ir_set.path_id] = (args_pathid, arg_cte)

arg_iterator = pgast.IteratorCTE(
path_id=args_pathid,
Expand All @@ -3556,7 +3571,6 @@ def _compile_inlined_call_args(
),
iterator_bond=True,
)
ctx.toplevel_stmt.append_cte(arg_cte)

# Merge the new iterator
ctx.path_scope = ctx.path_scope.new_child()
Expand Down
Loading

0 comments on commit ab3ef0c

Please sign in to comment.