Skip to content

Commit

Permalink
Fix dml function in for loops in with blocks.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed Feb 5, 2025
1 parent 9653f62 commit 9846296
Show file tree
Hide file tree
Showing 4 changed files with 581 additions and 32 deletions.
11 changes: 11 additions & 0 deletions edb/pgsql/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ class CompilerContextLevel(compiler.ContextLevel):
dml_stmts: Dict[Union[irast.MutatingStmt, irast.Set],
pgast.CommonTableExpr]

#: Inline DML functions may require additional CTEs.
#: Record such CTEs as well as the path used by their iterators.
#: This ensures CTEs are created only once, and that the correct
#: iterator bonds are applied.
inline_dml_ctes: dict[
irast.PathId,
tuple[irast.PathId, pgast.CommonTableExpr],
]

#: SQL statement corresponding to the IR statement
#: currently being compiled.
stmt: pgast.SelectStmt
Expand Down Expand Up @@ -363,6 +372,7 @@ def __init__(
self.type_inheritance_ctes = {}
self.ordered_type_ctes = []
self.dml_stmts = {}
self.inline_dml_ctes = {}
self.parent_rel = None
self.pending_query = None
self.materializing = frozenset()
Expand Down Expand Up @@ -405,6 +415,7 @@ def __init__(
self.type_inheritance_ctes = prevlevel.type_inheritance_ctes
self.ordered_type_ctes = prevlevel.ordered_type_ctes
self.dml_stmts = prevlevel.dml_stmts
self.inline_dml_ctes = prevlevel.inline_dml_ctes
self.parent_rel = prevlevel.parent_rel
self.pending_query = prevlevel.pending_query
self.materializing = prevlevel.materializing
Expand Down
3 changes: 3 additions & 0 deletions edb/pgsql/compiler/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,7 @@ def process_update_body(
pathctx.put_path_source_rvar(
update_stmt, subject_path_id, table_relation
)
put_iterator_bond(ctx.enclosing_cte_iterator, update_stmt)

update_cte.query = update_stmt

Expand Down Expand Up @@ -3105,6 +3106,8 @@ def process_delete_body(
"""
ctx.toplevel_stmt.append_cte(delete_cte)

put_iterator_bond(ctx.enclosing_cte_iterator, delete_cte.query)

pointers = ir_stmt.links_to_delete[typeref.id]

for ptrref in pointers:
Expand Down
78 changes: 46 additions & 32 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3509,42 +3509,57 @@ def _compile_inlined_call_args(
if irutils.contains_dml(expr.body):
last_iterator = ctx.enclosing_cte_iterator

# Compile args into an iterator CTE
with ctx.newrel() as arg_ctx:
dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx)
clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx)
# If this function call has already been compiled to a CTE, don't
# recompile the arguments.
# (This will happen when a DML-containing funcion in a FOR loop is
# WITH bound, for example.)
if ir_set.path_id in ctx.inline_dml_ctes:
args_pathid, arg_cte = ctx.inline_dml_ctes[ir_set.path_id]

_compile_call_args(ir_set, ctx=arg_ctx)
else:
# Compile args into an iterator CTE
with ctx.newrel() as arg_ctx:
dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx)
clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx)

# Add iterator identity
args_pathid = irast.PathId.new_dummy(ctx.env.aliases.get('args'))
with arg_ctx.subrel() as args_pathid_ctx:
relctx.create_iterator_identity_for_path(
args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx
)
args_id_rvar = relctx.rvar_for_rel(
args_pathid_ctx.rel, lateral=True, ctx=arg_ctx
)
relctx.include_rvar(
arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx
)
_compile_call_args(ir_set, ctx=arg_ctx)

for ir_arg in expr.args.values():
arg_path_id = ir_arg.expr.path_id
# Ensure args appear in arg CTE
pathctx.get_path_output(
arg_ctx.rel,
arg_path_id,
aspect=pgce.PathAspect.VALUE,
env=arg_ctx.env,
# Add iterator identity
args_pathid = irast.PathId.new_dummy(
ctx.env.aliases.get('args')
)
with arg_ctx.subrel() as args_pathid_ctx:
relctx.create_iterator_identity_for_path(
args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx
)
args_id_rvar = relctx.rvar_for_rel(
args_pathid_ctx.rel, lateral=True, ctx=arg_ctx
)
relctx.include_rvar(
arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx
)
pathctx.put_path_bond(arg_ctx.rel, arg_path_id, iterator=True)

arg_cte = pgast.CommonTableExpr(
name=ctx.env.aliases.get('args'),
query=arg_ctx.rel,
materialized=False,
)
for ir_arg in expr.args.values():
arg_path_id = ir_arg.expr.path_id
# Ensure args appear in arg CTE
pathctx.get_path_output(
arg_ctx.rel,
arg_path_id,
aspect=pgce.PathAspect.VALUE,
env=arg_ctx.env,
)
pathctx.put_path_bond(
arg_ctx.rel, arg_path_id, iterator=True
)

arg_cte = pgast.CommonTableExpr(
name=ctx.env.aliases.get('args'),
query=arg_ctx.rel,
materialized=False,
)
ctx.toplevel_stmt.append_cte(arg_cte)

ctx.inline_dml_ctes[ir_set.path_id] = (args_pathid, arg_cte)

arg_iterator = pgast.IteratorCTE(
path_id=args_pathid,
Expand All @@ -3556,7 +3571,6 @@ def _compile_inlined_call_args(
),
iterator_bond=True,
)
ctx.toplevel_stmt.append_cte(arg_cte)

# Merge the new iterator
ctx.path_scope = ctx.path_scope.new_child()
Expand Down
Loading

0 comments on commit 9846296

Please sign in to comment.