From 23284839d07a155eb35199ffb7efe401222e0ae6 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Thu, 27 Jun 2024 14:25:51 -0400 Subject: [PATCH] Add OverlayOp enum. (#7501) --- edb/pgsql/compiler/config.py | 2 +- edb/pgsql/compiler/context.py | 12 +++++- edb/pgsql/compiler/dml.py | 70 +++++++++++++++++++++++---------- edb/pgsql/compiler/relctx.py | 74 ++++++++++++++++++----------------- 4 files changed, 99 insertions(+), 59 deletions(-) diff --git a/edb/pgsql/compiler/config.py b/edb/pgsql/compiler/config.py index bfea9053abc..44cc9f39048 100644 --- a/edb/pgsql/compiler/config.py +++ b/edb/pgsql/compiler/config.py @@ -582,7 +582,7 @@ def _rewrite_config_insert( relctx.add_type_rel_overlay( ir_set.typeref, - 'replace', + context.OverlayOp.REPLACE, overwrite_query, path_id=ir_set.path_id, ctx=ctx, diff --git a/edb/pgsql/compiler/context.py b/edb/pgsql/compiler/context.py index 7b101fcfd0b..310e26d88e9 100644 --- a/edb/pgsql/compiler/context.py +++ b/edb/pgsql/compiler/context.py @@ -44,6 +44,7 @@ import immutables as immu from edb.common import compiler +from edb.common import enum as s_enum from edb.pgsql import ast as pgast from edb.pgsql import params as pgparams @@ -87,8 +88,15 @@ class OutputFormat(enum.Enum): NO_STMT = pgast.SelectStmt() +class OverlayOp(s_enum.StrEnum): + UNION = 'union' + REPLACE = 'replace' + FILTER = 'filter' + EXCEPT = 'except' + + OverlayEntry = tuple[ - str, + OverlayOp, Union[pgast.BaseRelation, pgast.CommonTableExpr], 'irast.PathId', ] @@ -186,7 +194,7 @@ class RelOverlays: Tuple[uuid.UUID, str], Tuple[ Tuple[ - str, + OverlayOp, Union[pgast.BaseRelation, pgast.CommonTableExpr], irast.PathId, ], ... diff --git a/edb/pgsql/compiler/dml.py b/edb/pgsql/compiler/dml.py index 43ed18bd20d..896d82ec5c4 100644 --- a/edb/pgsql/compiler/dml.py +++ b/edb/pgsql/compiler/dml.py @@ -436,7 +436,7 @@ def fini_dml_stmt( assert len(parts.dml_ctes) == 1 cte = next(iter(parts.dml_ctes.values()))[0] relctx.add_type_rel_overlay( - ir_stmt.subject.typeref, 'union', cte, + ir_stmt.subject.typeref, context.OverlayOp.UNION, cte, dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx) elif isinstance(ir_stmt, irast.UpdateStmt): base_typeref = ir_stmt.subject.typeref.real_material_type @@ -461,11 +461,11 @@ def fini_dml_stmt( # First, filter out objects that have been updated, then union them # back in. (If we just did union, we'd see the old values also.) relctx.add_type_rel_overlay( - typeref, 'filter', cte, + typeref, context.OverlayOp.FILTER, cte, stop_ref=stop_ref, dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx) relctx.add_type_rel_overlay( - typeref, 'union', cte, + typeref, context.OverlayOp.UNION, cte, stop_ref=stop_ref, dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx) @@ -482,7 +482,7 @@ def fini_dml_stmt( stop_ref = base_typeref relctx.add_type_rel_overlay( - typeref, 'except', cte, + typeref, context.OverlayOp.EXCEPT, cte, stop_ref=stop_ref, dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx) @@ -1549,10 +1549,17 @@ def compile_insert_else_body_failure_check( overlays_map = ctx.rel_overlays.type.get(None, immu.Map()) for k, overlays in overlays_map.items(): # Strip out filters, which we don't care about in this context - overlays = tuple([(k, r, p) for k, r, p in overlays if k != 'filter']) + overlays = tuple([ + (k, r, p) + for k, r, p in overlays + if k != context.OverlayOp.FILTER + ]) # Drop the initial set - if overlays and overlays[0][0] == 'union': - overlays = (('replace', *overlays[0][1:]), *overlays[1:]) + if overlays and overlays[0][0] == context.OverlayOp.UNION: + overlays = ( + (context.OverlayOp.REPLACE, *overlays[0][1:]), + *overlays[1:] + ) overlays_map = overlays_map.set(k, overlays) ctx.rel_overlays.type = ctx.rel_overlays.type.set(None, overlays_map) @@ -2606,8 +2613,13 @@ def process_link_update( # context to ensure that references to the link in the result # of this DML statement yield the expected results. relctx.add_ptr_rel_overlay( - mptrref, 'except', delcte, path_id=path_id.ptr_path(), - dml_stmts=ctx.dml_stmt_stack, ctx=ctx) + mptrref, + context.OverlayOp.EXCEPT, + delcte, + path_id=path_id.ptr_path(), + dml_stmts=ctx.dml_stmt_stack, + ctx=ctx + ) toplevel.append_cte(delcte) else: delqry = None @@ -2645,8 +2657,12 @@ def process_link_update( # to work without it subctx.rel_overlays = subctx.rel_overlays.copy() relctx.add_ptr_rel_overlay( - ptrref, 'except', delcte, path_id=path_id.ptr_path(), - ctx=subctx) + ptrref, + context.OverlayOp.EXCEPT, + delcte, + path_id=path_id.ptr_path(), + ctx=subctx + ) check_cte, _ = process_link_values( ir_stmt=ir_stmt, @@ -2790,14 +2806,22 @@ def register_overlays( # based filter to filter out links that were already present # and have been re-added. relctx.add_ptr_rel_overlay( - mptrref, 'filter', overlay_cte, dml_stmts=ctx.dml_stmt_stack, + mptrref, + context.OverlayOp.FILTER, + overlay_cte, + dml_stmts=ctx.dml_stmt_stack, path_id=path_id.ptr_path(), - ctx=octx) + ctx=octx + ) relctx.add_ptr_rel_overlay( - mptrref, 'union', overlay_cte, dml_stmts=ctx.dml_stmt_stack, + mptrref, + context.OverlayOp.UNION, + overlay_cte, + dml_stmts=ctx.dml_stmt_stack, path_id=path_id.ptr_path(), - ctx=octx) + ctx=octx + ) if policy_ctx: policy_ctx.rel_overlays = policy_ctx.rel_overlays.copy() @@ -3138,16 +3162,20 @@ def compile_trigger( overlays.extend(ov) # Handle deletions by turning except into union - # Drop 'filter', which is included by update but doesn't help us here + # Drop FILTER, which is included by update but doesn't help us here overlays = [ - ('union', *x[1:]) if x[0] == 'except' else x + ( + (context.OverlayOp.UNION, *x[1:]) + if x[0] == context.OverlayOp.EXCEPT + else x + ) for x in overlays - if x[0] != 'filter' + if x[0] != context.OverlayOp.FILTER ] - # Replace an initial union with 'replace', since we *don't* want whatever + # Replace an initial union with REPLACE, since we *don't* want whatever # already existed - assert overlays and overlays[0][0] == 'union' - overlays[0] = ('replace', *overlays[0][1:]) + assert overlays and overlays[0][0] == context.OverlayOp.UNION + overlays[0] = (context.OverlayOp.REPLACE, *overlays[0][1:]) # Produce a CTE containing all of the affected objects for this trigger with ctx.newrel() as ictx: diff --git a/edb/pgsql/compiler/relctx.py b/edb/pgsql/compiler/relctx.py index 53455e8c3f6..f693a371c0c 100644 --- a/edb/pgsql/compiler/relctx.py +++ b/edb/pgsql/compiler/relctx.py @@ -1566,7 +1566,7 @@ def range_for_material_objtype( pathctx.put_path_value_rvar(qry, sub_path_id, rvar) pathctx.put_path_source_rvar(qry, sub_path_id, rvar) - ops.append(('union', qry)) + ops.append((context.OverlayOp.UNION, qry)) rvar = range_from_queryset( ops, @@ -1623,7 +1623,7 @@ def range_for_material_objtype( pathctx.put_path_source_rvar(qry, path_id, rvar) pathctx.put_path_bond(qry, path_id) - set_ops.append(('union', qry)) + set_ops.append((context.OverlayOp.UNION, qry)) for op, cte, cte_path_id in overlays: rvar = rvar_for_rel(cte, typeref=typeref, ctx=ctx) @@ -1653,8 +1653,8 @@ def range_for_material_objtype( pathctx.put_path_source_rvar(qry2, path_id, qry_rvar) pathctx.put_path_bond(qry2, path_id) - if op == 'replace': - op = 'union' + if op == context.OverlayOp.REPLACE: + op = context.OverlayOp.UNION set_ops = [] set_ops.append((op, qry2)) @@ -1720,7 +1720,7 @@ def range_for_typeref( pathctx.put_path_bond(qry, path_id) - set_ops.append(('union', qry)) + set_ops.append((context.OverlayOp.UNION, qry)) rvar = range_from_queryset( set_ops, @@ -1796,7 +1796,7 @@ def anti_join( def range_from_queryset( - set_ops: Sequence[Tuple[str, pgast.SelectStmt]], + set_ops: Sequence[Tuple[context.OverlayOp, pgast.SelectStmt]], objname: sn.Name, *, prep_filter: Callable[ @@ -1815,7 +1815,7 @@ def range_from_queryset( qry = set_ops[0][1] for op, rarg in set_ops[1:]: - if op == 'filter': + if op == context.OverlayOp.FILTER: qry = wrap_set_op_query(qry, ctx=ctx) prep_filter(qry, rarg) anti_join(qry, rarg, path_id, ctx=ctx) @@ -2049,7 +2049,7 @@ def prep_filter(larg: pgast.SelectStmt, rarg: pgast.SelectStmt) -> None: qry.target_list.append( pgast.ResTarget(val=selexpr, name=output_colname)) - sub_set_ops.append(('union', qry)) + sub_set_ops.append((context.OverlayOp.UNION, qry)) # We need the identity var for semi_join to work and # the source rvar so that linkprops can be found here. @@ -2080,7 +2080,7 @@ def prep_filter(larg: pgast.SelectStmt, rarg: pgast.SelectStmt) -> None: pathctx.put_path_identity_var(sub_qry, path_id, var=target_ref) pathctx.put_path_source_rvar(sub_qry, path_id, sub_rvar) - set_ops.append(('union', sub_qry)) + set_ops.append((context.OverlayOp.UNION, sub_qry)) # Only fire off the overlays at the end of each expanded inhview. # This only matters when we are doing expand_inhviews, and prevents @@ -2182,12 +2182,13 @@ def rvar_for_rel( def _add_type_rel_overlay( - typeid: uuid.UUID, - op: str, - rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *, - dml_stmts: Iterable[irast.MutatingLikeStmt] = (), - path_id: irast.PathId, - ctx: context.CompilerContextLevel) -> None: + typeid: uuid.UUID, + op: context.OverlayOp, + rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *, + dml_stmts: Iterable[irast.MutatingLikeStmt] = (), + path_id: irast.PathId, + ctx: context.CompilerContextLevel +) -> None: entry = (op, rel, path_id) dml_stmts2 = dml_stmts if dml_stmts else (None,) # If there is a "global" overlay, and there is none for the @@ -2204,13 +2205,14 @@ def _add_type_rel_overlay( def add_type_rel_overlay( - typeref: irast.TypeRef, - op: str, - rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *, - stop_ref: Optional[irast.TypeRef]=None, - dml_stmts: Iterable[irast.MutatingLikeStmt] = (), - path_id: irast.PathId, - ctx: context.CompilerContextLevel) -> None: + typeref: irast.TypeRef, + op: context.OverlayOp, + rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *, + stop_ref: Optional[irast.TypeRef]=None, + dml_stmts: Iterable[irast.MutatingLikeStmt] = (), + path_id: irast.PathId, + ctx: context.CompilerContextLevel +) -> None: typeref = typeref.real_material_type objs = [typeref] if typeref.ancestors: @@ -2273,13 +2275,14 @@ def reuse_type_rel_overlays( def _add_ptr_rel_overlay( - typeid: uuid.UUID, - ptrref_name: str, - op: str, - rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *, - dml_stmts: Iterable[irast.MutatingLikeStmt] = (), - path_id: irast.PathId, - ctx: context.CompilerContextLevel) -> None: + typeid: uuid.UUID, + ptrref_name: str, + op: context.OverlayOp, + rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *, + dml_stmts: Iterable[irast.MutatingLikeStmt] = (), + path_id: irast.PathId, + ctx: context.CompilerContextLevel +) -> None: entry = (op, rel, path_id) dml_stmts2 = dml_stmts if dml_stmts else (None,) @@ -2298,12 +2301,13 @@ def _add_ptr_rel_overlay( def add_ptr_rel_overlay( - ptrref: irast.PointerRef, - op: str, - rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *, - dml_stmts: Iterable[irast.MutatingLikeStmt] = (), - path_id: irast.PathId, - ctx: context.CompilerContextLevel) -> None: + ptrref: irast.PointerRef, + op: context.OverlayOp, + rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *, + dml_stmts: Iterable[irast.MutatingLikeStmt] = (), + path_id: irast.PathId, + ctx: context.CompilerContextLevel +) -> None: typeref = ptrref.out_source.real_material_type objs = [typeref]