Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dml function in for loops in with blocks. #8306

Merged
merged 1 commit into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions edb/pgsql/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ class CompilerContextLevel(compiler.ContextLevel):
dml_stmts: Dict[Union[irast.MutatingStmt, irast.Set],
pgast.CommonTableExpr]

#: Inline DML functions may require additional CTEs.
#: Record such CTEs as well as the path used by their iterators.
#: This ensures CTEs are created only once, and that the correct
#: iterator bonds are applied.
inline_dml_ctes: dict[
irast.PathId,
tuple[irast.PathId, pgast.CommonTableExpr],
]

#: SQL statement corresponding to the IR statement
#: currently being compiled.
stmt: pgast.SelectStmt
Expand Down Expand Up @@ -363,6 +372,7 @@ def __init__(
self.type_inheritance_ctes = {}
self.ordered_type_ctes = []
self.dml_stmts = {}
self.inline_dml_ctes = {}
self.parent_rel = None
self.pending_query = None
self.materializing = frozenset()
Expand Down Expand Up @@ -405,6 +415,7 @@ def __init__(
self.type_inheritance_ctes = prevlevel.type_inheritance_ctes
self.ordered_type_ctes = prevlevel.ordered_type_ctes
self.dml_stmts = prevlevel.dml_stmts
self.inline_dml_ctes = prevlevel.inline_dml_ctes
self.parent_rel = prevlevel.parent_rel
self.pending_query = prevlevel.pending_query
self.materializing = prevlevel.materializing
Expand Down
3 changes: 3 additions & 0 deletions edb/pgsql/compiler/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,7 @@ def process_update_body(
pathctx.put_path_source_rvar(
update_stmt, subject_path_id, table_relation
)
put_iterator_bond(ctx.enclosing_cte_iterator, update_stmt)

update_cte.query = update_stmt

Expand Down Expand Up @@ -3105,6 +3106,8 @@ def process_delete_body(
"""
ctx.toplevel_stmt.append_cte(delete_cte)

put_iterator_bond(ctx.enclosing_cte_iterator, delete_cte.query)

pointers = ir_stmt.links_to_delete[typeref.id]

for ptrref in pointers:
Expand Down
78 changes: 46 additions & 32 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3509,42 +3509,57 @@ def _compile_inlined_call_args(
if irutils.contains_dml(expr.body):
last_iterator = ctx.enclosing_cte_iterator

# Compile args into an iterator CTE
with ctx.newrel() as arg_ctx:
dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx)
clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx)
# If this function call has already been compiled to a CTE, don't
# recompile the arguments.
# (This will happen when a DML-containing funcion in a FOR loop is
# WITH bound, for example.)
if ir_set.path_id in ctx.inline_dml_ctes:
args_pathid, arg_cte = ctx.inline_dml_ctes[ir_set.path_id]

_compile_call_args(ir_set, ctx=arg_ctx)
else:
# Compile args into an iterator CTE
with ctx.newrel() as arg_ctx:
dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx)
clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx)

# Add iterator identity
args_pathid = irast.PathId.new_dummy(ctx.env.aliases.get('args'))
with arg_ctx.subrel() as args_pathid_ctx:
relctx.create_iterator_identity_for_path(
args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx
)
args_id_rvar = relctx.rvar_for_rel(
args_pathid_ctx.rel, lateral=True, ctx=arg_ctx
)
relctx.include_rvar(
arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx
)
_compile_call_args(ir_set, ctx=arg_ctx)

for ir_arg in expr.args.values():
arg_path_id = ir_arg.expr.path_id
# Ensure args appear in arg CTE
pathctx.get_path_output(
arg_ctx.rel,
arg_path_id,
aspect=pgce.PathAspect.VALUE,
env=arg_ctx.env,
# Add iterator identity
args_pathid = irast.PathId.new_dummy(
ctx.env.aliases.get('args')
)
with arg_ctx.subrel() as args_pathid_ctx:
relctx.create_iterator_identity_for_path(
args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx
)
args_id_rvar = relctx.rvar_for_rel(
args_pathid_ctx.rel, lateral=True, ctx=arg_ctx
)
relctx.include_rvar(
arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx
)
pathctx.put_path_bond(arg_ctx.rel, arg_path_id, iterator=True)

arg_cte = pgast.CommonTableExpr(
name=ctx.env.aliases.get('args'),
query=arg_ctx.rel,
materialized=False,
)
for ir_arg in expr.args.values():
arg_path_id = ir_arg.expr.path_id
# Ensure args appear in arg CTE
pathctx.get_path_output(
arg_ctx.rel,
arg_path_id,
aspect=pgce.PathAspect.VALUE,
env=arg_ctx.env,
)
pathctx.put_path_bond(
arg_ctx.rel, arg_path_id, iterator=True
)

arg_cte = pgast.CommonTableExpr(
name=ctx.env.aliases.get('args'),
query=arg_ctx.rel,
materialized=False,
)
ctx.toplevel_stmt.append_cte(arg_cte)

ctx.inline_dml_ctes[ir_set.path_id] = (args_pathid, arg_cte)

arg_iterator = pgast.IteratorCTE(
path_id=args_pathid,
Expand All @@ -3556,7 +3571,6 @@ def _compile_inlined_call_args(
),
iterator_bond=True,
)
ctx.toplevel_stmt.append_cte(arg_cte)

# Merge the new iterator
ctx.path_scope = ctx.path_scope.new_child()
Expand Down
Loading