Skip to content

Commit

Permalink
Add OverlayOp enum. (#7501)
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark authored Jun 27, 2024
1 parent cfd88cd commit 2328483
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 59 deletions.
2 changes: 1 addition & 1 deletion edb/pgsql/compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def _rewrite_config_insert(

relctx.add_type_rel_overlay(
ir_set.typeref,
'replace',
context.OverlayOp.REPLACE,
overwrite_query,
path_id=ir_set.path_id,
ctx=ctx,
Expand Down
12 changes: 10 additions & 2 deletions edb/pgsql/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import immutables as immu

from edb.common import compiler
from edb.common import enum as s_enum

from edb.pgsql import ast as pgast
from edb.pgsql import params as pgparams
Expand Down Expand Up @@ -87,8 +88,15 @@ class OutputFormat(enum.Enum):
NO_STMT = pgast.SelectStmt()


class OverlayOp(s_enum.StrEnum):
UNION = 'union'
REPLACE = 'replace'
FILTER = 'filter'
EXCEPT = 'except'


OverlayEntry = tuple[
str,
OverlayOp,
Union[pgast.BaseRelation, pgast.CommonTableExpr],
'irast.PathId',
]
Expand Down Expand Up @@ -186,7 +194,7 @@ class RelOverlays:
Tuple[uuid.UUID, str],
Tuple[
Tuple[
str,
OverlayOp,
Union[pgast.BaseRelation, pgast.CommonTableExpr],
irast.PathId,
], ...
Expand Down
70 changes: 49 additions & 21 deletions edb/pgsql/compiler/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def fini_dml_stmt(
assert len(parts.dml_ctes) == 1
cte = next(iter(parts.dml_ctes.values()))[0]
relctx.add_type_rel_overlay(
ir_stmt.subject.typeref, 'union', cte,
ir_stmt.subject.typeref, context.OverlayOp.UNION, cte,
dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx)
elif isinstance(ir_stmt, irast.UpdateStmt):
base_typeref = ir_stmt.subject.typeref.real_material_type
Expand All @@ -461,11 +461,11 @@ def fini_dml_stmt(
# First, filter out objects that have been updated, then union them
# back in. (If we just did union, we'd see the old values also.)
relctx.add_type_rel_overlay(
typeref, 'filter', cte,
typeref, context.OverlayOp.FILTER, cte,
stop_ref=stop_ref,
dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx)
relctx.add_type_rel_overlay(
typeref, 'union', cte,
typeref, context.OverlayOp.UNION, cte,
stop_ref=stop_ref,
dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx)

Expand All @@ -482,7 +482,7 @@ def fini_dml_stmt(
stop_ref = base_typeref

relctx.add_type_rel_overlay(
typeref, 'except', cte,
typeref, context.OverlayOp.EXCEPT, cte,
stop_ref=stop_ref,
dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx)

Expand Down Expand Up @@ -1549,10 +1549,17 @@ def compile_insert_else_body_failure_check(
overlays_map = ctx.rel_overlays.type.get(None, immu.Map())
for k, overlays in overlays_map.items():
# Strip out filters, which we don't care about in this context
overlays = tuple([(k, r, p) for k, r, p in overlays if k != 'filter'])
overlays = tuple([
(k, r, p)
for k, r, p in overlays
if k != context.OverlayOp.FILTER
])
# Drop the initial set
if overlays and overlays[0][0] == 'union':
overlays = (('replace', *overlays[0][1:]), *overlays[1:])
if overlays and overlays[0][0] == context.OverlayOp.UNION:
overlays = (
(context.OverlayOp.REPLACE, *overlays[0][1:]),
*overlays[1:]
)
overlays_map = overlays_map.set(k, overlays)

ctx.rel_overlays.type = ctx.rel_overlays.type.set(None, overlays_map)
Expand Down Expand Up @@ -2606,8 +2613,13 @@ def process_link_update(
# context to ensure that references to the link in the result
# of this DML statement yield the expected results.
relctx.add_ptr_rel_overlay(
mptrref, 'except', delcte, path_id=path_id.ptr_path(),
dml_stmts=ctx.dml_stmt_stack, ctx=ctx)
mptrref,
context.OverlayOp.EXCEPT,
delcte,
path_id=path_id.ptr_path(),
dml_stmts=ctx.dml_stmt_stack,
ctx=ctx
)
toplevel.append_cte(delcte)
else:
delqry = None
Expand Down Expand Up @@ -2645,8 +2657,12 @@ def process_link_update(
# to work without it
subctx.rel_overlays = subctx.rel_overlays.copy()
relctx.add_ptr_rel_overlay(
ptrref, 'except', delcte, path_id=path_id.ptr_path(),
ctx=subctx)
ptrref,
context.OverlayOp.EXCEPT,
delcte,
path_id=path_id.ptr_path(),
ctx=subctx
)

check_cte, _ = process_link_values(
ir_stmt=ir_stmt,
Expand Down Expand Up @@ -2790,14 +2806,22 @@ def register_overlays(
# based filter to filter out links that were already present
# and have been re-added.
relctx.add_ptr_rel_overlay(
mptrref, 'filter', overlay_cte, dml_stmts=ctx.dml_stmt_stack,
mptrref,
context.OverlayOp.FILTER,
overlay_cte,
dml_stmts=ctx.dml_stmt_stack,
path_id=path_id.ptr_path(),
ctx=octx)
ctx=octx
)

relctx.add_ptr_rel_overlay(
mptrref, 'union', overlay_cte, dml_stmts=ctx.dml_stmt_stack,
mptrref,
context.OverlayOp.UNION,
overlay_cte,
dml_stmts=ctx.dml_stmt_stack,
path_id=path_id.ptr_path(),
ctx=octx)
ctx=octx
)

if policy_ctx:
policy_ctx.rel_overlays = policy_ctx.rel_overlays.copy()
Expand Down Expand Up @@ -3138,16 +3162,20 @@ def compile_trigger(
overlays.extend(ov)

# Handle deletions by turning except into union
# Drop 'filter', which is included by update but doesn't help us here
# Drop FILTER, which is included by update but doesn't help us here
overlays = [
('union', *x[1:]) if x[0] == 'except' else x
(
(context.OverlayOp.UNION, *x[1:])
if x[0] == context.OverlayOp.EXCEPT
else x
)
for x in overlays
if x[0] != 'filter'
if x[0] != context.OverlayOp.FILTER
]
# Replace an initial union with 'replace', since we *don't* want whatever
# Replace an initial union with REPLACE, since we *don't* want whatever
# already existed
assert overlays and overlays[0][0] == 'union'
overlays[0] = ('replace', *overlays[0][1:])
assert overlays and overlays[0][0] == context.OverlayOp.UNION
overlays[0] = (context.OverlayOp.REPLACE, *overlays[0][1:])

# Produce a CTE containing all of the affected objects for this trigger
with ctx.newrel() as ictx:
Expand Down
74 changes: 39 additions & 35 deletions edb/pgsql/compiler/relctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,7 @@ def range_for_material_objtype(
pathctx.put_path_value_rvar(qry, sub_path_id, rvar)
pathctx.put_path_source_rvar(qry, sub_path_id, rvar)

ops.append(('union', qry))
ops.append((context.OverlayOp.UNION, qry))

rvar = range_from_queryset(
ops,
Expand Down Expand Up @@ -1623,7 +1623,7 @@ def range_for_material_objtype(
pathctx.put_path_source_rvar(qry, path_id, rvar)
pathctx.put_path_bond(qry, path_id)

set_ops.append(('union', qry))
set_ops.append((context.OverlayOp.UNION, qry))

for op, cte, cte_path_id in overlays:
rvar = rvar_for_rel(cte, typeref=typeref, ctx=ctx)
Expand Down Expand Up @@ -1653,8 +1653,8 @@ def range_for_material_objtype(
pathctx.put_path_source_rvar(qry2, path_id, qry_rvar)
pathctx.put_path_bond(qry2, path_id)

if op == 'replace':
op = 'union'
if op == context.OverlayOp.REPLACE:
op = context.OverlayOp.UNION
set_ops = []
set_ops.append((op, qry2))

Expand Down Expand Up @@ -1720,7 +1720,7 @@ def range_for_typeref(

pathctx.put_path_bond(qry, path_id)

set_ops.append(('union', qry))
set_ops.append((context.OverlayOp.UNION, qry))

rvar = range_from_queryset(
set_ops,
Expand Down Expand Up @@ -1796,7 +1796,7 @@ def anti_join(


def range_from_queryset(
set_ops: Sequence[Tuple[str, pgast.SelectStmt]],
set_ops: Sequence[Tuple[context.OverlayOp, pgast.SelectStmt]],
objname: sn.Name,
*,
prep_filter: Callable[
Expand All @@ -1815,7 +1815,7 @@ def range_from_queryset(
qry = set_ops[0][1]

for op, rarg in set_ops[1:]:
if op == 'filter':
if op == context.OverlayOp.FILTER:
qry = wrap_set_op_query(qry, ctx=ctx)
prep_filter(qry, rarg)
anti_join(qry, rarg, path_id, ctx=ctx)
Expand Down Expand Up @@ -2049,7 +2049,7 @@ def prep_filter(larg: pgast.SelectStmt, rarg: pgast.SelectStmt) -> None:
qry.target_list.append(
pgast.ResTarget(val=selexpr, name=output_colname))

sub_set_ops.append(('union', qry))
sub_set_ops.append((context.OverlayOp.UNION, qry))

# We need the identity var for semi_join to work and
# the source rvar so that linkprops can be found here.
Expand Down Expand Up @@ -2080,7 +2080,7 @@ def prep_filter(larg: pgast.SelectStmt, rarg: pgast.SelectStmt) -> None:
pathctx.put_path_identity_var(sub_qry, path_id, var=target_ref)
pathctx.put_path_source_rvar(sub_qry, path_id, sub_rvar)

set_ops.append(('union', sub_qry))
set_ops.append((context.OverlayOp.UNION, sub_qry))

# Only fire off the overlays at the end of each expanded inhview.
# This only matters when we are doing expand_inhviews, and prevents
Expand Down Expand Up @@ -2182,12 +2182,13 @@ def rvar_for_rel(


def _add_type_rel_overlay(
typeid: uuid.UUID,
op: str,
rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
dml_stmts: Iterable[irast.MutatingLikeStmt] = (),
path_id: irast.PathId,
ctx: context.CompilerContextLevel) -> None:
typeid: uuid.UUID,
op: context.OverlayOp,
rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
dml_stmts: Iterable[irast.MutatingLikeStmt] = (),
path_id: irast.PathId,
ctx: context.CompilerContextLevel
) -> None:
entry = (op, rel, path_id)
dml_stmts2 = dml_stmts if dml_stmts else (None,)
# If there is a "global" overlay, and there is none for the
Expand All @@ -2204,13 +2205,14 @@ def _add_type_rel_overlay(


def add_type_rel_overlay(
typeref: irast.TypeRef,
op: str,
rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
stop_ref: Optional[irast.TypeRef]=None,
dml_stmts: Iterable[irast.MutatingLikeStmt] = (),
path_id: irast.PathId,
ctx: context.CompilerContextLevel) -> None:
typeref: irast.TypeRef,
op: context.OverlayOp,
rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
stop_ref: Optional[irast.TypeRef]=None,
dml_stmts: Iterable[irast.MutatingLikeStmt] = (),
path_id: irast.PathId,
ctx: context.CompilerContextLevel
) -> None:
typeref = typeref.real_material_type
objs = [typeref]
if typeref.ancestors:
Expand Down Expand Up @@ -2273,13 +2275,14 @@ def reuse_type_rel_overlays(


def _add_ptr_rel_overlay(
typeid: uuid.UUID,
ptrref_name: str,
op: str,
rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
dml_stmts: Iterable[irast.MutatingLikeStmt] = (),
path_id: irast.PathId,
ctx: context.CompilerContextLevel) -> None:
typeid: uuid.UUID,
ptrref_name: str,
op: context.OverlayOp,
rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
dml_stmts: Iterable[irast.MutatingLikeStmt] = (),
path_id: irast.PathId,
ctx: context.CompilerContextLevel
) -> None:

entry = (op, rel, path_id)
dml_stmts2 = dml_stmts if dml_stmts else (None,)
Expand All @@ -2298,12 +2301,13 @@ def _add_ptr_rel_overlay(


def add_ptr_rel_overlay(
ptrref: irast.PointerRef,
op: str,
rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
dml_stmts: Iterable[irast.MutatingLikeStmt] = (),
path_id: irast.PathId,
ctx: context.CompilerContextLevel) -> None:
ptrref: irast.PointerRef,
op: context.OverlayOp,
rel: Union[pgast.BaseRelation, pgast.CommonTableExpr], *,
dml_stmts: Iterable[irast.MutatingLikeStmt] = (),
path_id: irast.PathId,
ctx: context.CompilerContextLevel
) -> None:

typeref = ptrref.out_source.real_material_type
objs = [typeref]
Expand Down

0 comments on commit 2328483

Please sign in to comment.