Skip to content

Commit

Permalink
Fix update rewrites on types that are children of updated type (#7073)
Browse files Browse the repository at this point in the history
Closes #7048
Closes #6801

When type X is updated, but has child Y with an update rewrite, we were using type X as the subject where we should have been using Y.
  • Loading branch information
aljazerzen committed Mar 19, 2024
1 parent 3793cf8 commit 72fbd32
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 61 deletions.
142 changes: 81 additions & 61 deletions edb/edgeql/compiler/viewgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,15 @@ def _raise_on_missing(
)


@dataclasses.dataclass(kw_only=True, repr=False, eq=False)
class RewriteContext:
specified_ptrs: Set[sn.UnqualName]
kind: qltypes.RewriteKind

base_type: s_objtypes.ObjectType
shape_type: s_objtypes.ObjectType


def _compile_rewrites(
specified_ptrs: Set[sn.UnqualName],
kind: qltypes.RewriteKind,
Expand All @@ -910,18 +919,24 @@ def _compile_rewrites(
ctx: context.ContextLevel,
) -> Optional[irast.Rewrites]:
# init
anchors = None
r_ctx = RewriteContext(
specified_ptrs=specified_ptrs,
kind=kind,
base_type=stype,
shape_type=view_scls,
)

# Computing anchors isn't cheap, so we want to only do it once,
# and only do it when it is necessary.
def get_anchors() -> RewriteAnchors:
nonlocal anchors
if anchors is None:
anchors = prepare_rewrite_anchors(specified_ptrs, kind, stype, ctx)
return anchors
anchors: Dict[s_objtypes.ObjectType, RewriteAnchors] = {}

def get_anchors(stype: s_objtypes.ObjectType) -> RewriteAnchors:
if stype not in anchors:
anchors[stype] = prepare_rewrite_anchors(stype, r_ctx, s_ctx, ctx)
return anchors[stype]

rewrites = _compile_rewrites_for_stype(
stype, kind, view_scls, ir_set, get_anchors, s_ctx, ctx=ctx
stype, kind, ir_set, get_anchors, s_ctx, ctx=ctx
)

if kind == qltypes.RewriteKind.Insert:
Expand All @@ -937,7 +952,7 @@ def get_anchors() -> RewriteAnchors:
# statement later.

rewrites_by_type = _compile_rewrites_of_children(
stype, rewrites, kind, view_scls, ir_set, get_anchors, s_ctx, ctx
stype, rewrites, kind, ir_set, get_anchors, s_ctx, ctx
)

else:
Expand Down Expand Up @@ -979,12 +994,13 @@ def get_anchors() -> RewriteAnchors:

by_type[ty][pn] = (ptr_set, ptrref.real_material_ptr)

if not anchors:
anc = next(iter(anchors.values()), None)
if not anc:
return None

return irast.Rewrites(
subject_path_id=anchors[0].path_id,
old_path_id=anchors[2].path_id if anchors[2] else None,
subject_path_id=anc.subject_set.path_id,
old_path_id=anc.old_set.path_id if anc.old_set else None,
by_type=by_type,
)

Expand All @@ -993,9 +1009,8 @@ def _compile_rewrites_of_children(
stype: s_objtypes.ObjectType,
parent_rewrites: Dict[sn.UnqualName, EarlyShapePtr],
kind: qltypes.RewriteKind,
view_scls: s_objtypes.ObjectType,
ir_set: irast.Set,
get_anchors: Callable[[], RewriteAnchors],
get_anchors: Callable[[s_objtypes.ObjectType], RewriteAnchors],
s_ctx: ShapeContext,
ctx: context.ContextLevel,
) -> Dict[irast.TypeRef, Dict[sn.UnqualName, EarlyShapePtr]]:
Expand All @@ -1015,7 +1030,7 @@ def _compile_rewrites_of_children(
child_rewrites = parent_rewrites.copy()
# override with rewrites defined here
rewrites_defined_here = _compile_rewrites_for_stype(
child, kind, view_scls, ir_set, get_anchors, s_ctx,
child, kind, ir_set, get_anchors, s_ctx,
already_defined_rewrites=child_rewrites,
ctx=ctx
)
Expand All @@ -1027,7 +1042,6 @@ def _compile_rewrites_of_children(
child,
child_rewrites,
kind,
view_scls,
ir_set,
get_anchors,
s_ctx,
Expand All @@ -1041,9 +1055,8 @@ def _compile_rewrites_of_children(
def _compile_rewrites_for_stype(
stype: s_objtypes.ObjectType,
kind: qltypes.RewriteKind,
view_scls: s_objtypes.ObjectType,
ir_set: irast.Set,
get_anchors: Callable[[], RewriteAnchors],
get_anchors: Callable[[s_objtypes.ObjectType], RewriteAnchors],
s_ctx: ShapeContext,
*,
already_defined_rewrites: Optional[
Expand Down Expand Up @@ -1091,27 +1104,7 @@ def _compile_rewrites_for_stype(
):
continue

subject_set, specified_set, old_set = get_anchors()

rewrite_view = view_scls
if stype != view_scls.get_nearest_non_derived_parent(schema):
# FIXME: Caching?
rewrite_view = downcast(
s_objtypes.ObjectType,
schemactx.derive_view(
stype,
exprtype=s_ctx.exprtype,
ctx=ctx,
)
)
subject_set = setgen.class_set(
rewrite_view, path_id=subject_set.path_id, ctx=ctx)
if old_set:
old_set = setgen.class_set(
rewrite_view, path_id=old_set.path_id, ctx=ctx)

ir_set = setgen.class_set(
rewrite_view, path_id=ir_set.path_id, ctx=ctx)
anchors = get_anchors(stype)

rewrite_expr = rewrite.get_expr(ctx.env.schema)
assert rewrite_expr
Expand All @@ -1120,27 +1113,27 @@ def _compile_rewrites_for_stype(
scopectx.active_rewrites |= {stype}

# prepare context
scopectx.partial_path_prefix = subject_set
scopectx.partial_path_prefix = anchors.subject_set
nanchors = {}
nanchors["__specified__"] = specified_set
nanchors["__subject__"] = subject_set
if old_set:
nanchors["__old__"] = old_set
nanchors["__specified__"] = anchors.specified_set
nanchors["__subject__"] = anchors.subject_set
if anchors.old_set:
nanchors["__old__"] = anchors.old_set

for key, anchor in nanchors.items():
scopectx.path_scope.attach_path(
anchor.path_id, context=None,
optional=(anchor is subject_set),
optional=(anchor is anchors.subject_set),
)
scopectx.iterator_path_ids |= {anchor.path_id}
scopectx.anchors[key] = anchor

# XXX: I am pretty sure this must be wrong, but we get
# a failure without due to volatility issues in
# test_edgeql_rewrites_16
scopectx.env.singletons.append(subject_set.path_id)
scopectx.env.singletons.append(anchors.subject_set.path_id)

ctx.path_scope.factoring_allowlist.add(subject_set.path_id)
ctx.path_scope.factoring_allowlist.add(anchors.subject_set.path_id)

# prepare expression
ptrcls_sn = ptrcls.get_shortname(ctx.env.schema)
Expand All @@ -1157,16 +1150,16 @@ def _compile_rewrites_for_stype(
)
shape_ql_desc = _shape_el_ql_to_shape_el_desc(
shape_ql,
source=rewrite_view,
source=anchors.rewrite_type,
s_ctx=s_ctx,
ctx=scopectx,
)

# compile as normal shape element
pointer, ptr_set = _normalize_view_ptr_expr(
subject_set,
anchors.subject_set,
shape_ql_desc,
rewrite_view,
anchors.rewrite_type,
path_id=path_id,
from_default=True,
s_ctx=s_ctx,
Expand All @@ -1178,13 +1171,19 @@ def _compile_rewrites_for_stype(
return res


RewriteAnchors = Tuple[irast.Set, irast.Set, Optional[irast.Set]]
@dataclasses.dataclass(kw_only=True, repr=False, eq=False)
class RewriteAnchors:
subject_set: irast.Set
specified_set: irast.Set
old_set: Optional[irast.Set]

rewrite_type: s_objtypes.ObjectType


def prepare_rewrite_anchors(
specified_ptrs: Set[sn.UnqualName],
rewrite_kind: qltypes.RewriteKind,
stype: s_objtypes.ObjectType,
r_ctx: RewriteContext,
s_ctx: ShapeContext,
ctx: context.ContextLevel,
) -> RewriteAnchors:
schema = ctx.env.schema
Expand All @@ -1193,11 +1192,11 @@ def prepare_rewrite_anchors(
# TODO: Do we really need a separate path id for __subject__?
subject_name = sn.QualName("__derived__", "__subject__")
subject_path_id = irast.PathId.from_type(
schema, stype, typename=subject_name, namespace=ctx.path_id_namespace,
env=ctx.env,
schema, stype, typename=subject_name,
namespace=ctx.path_id_namespace, env=ctx.env,
)
subject_set = setgen.new_set(
stype=stype, path_id=subject_path_id, ctx=ctx
subject_set = setgen.class_set(
stype, path_id=subject_path_id, ctx=ctx
)

# init reference to std::bool
Expand Down Expand Up @@ -1225,7 +1224,7 @@ def prepare_rewrite_anchors(
name=pn.name,
val=setgen.ensure_set(
irast.BooleanConstant(
value=str(pn in specified_ptrs),
value=str(pn in r_ctx.specified_ptrs),
typeref=bool_path.target,
),
ctx=ctx
Expand All @@ -1238,11 +1237,11 @@ def prepare_rewrite_anchors(
)

# init set for __old__
if rewrite_kind == qltypes.RewriteKind.Update:
if r_ctx.kind == qltypes.RewriteKind.Update:
old_name = sn.QualName("__derived__", "__old__")
old_path_id = irast.PathId.from_type(
schema, stype, typename=old_name, namespace=ctx.path_id_namespace,
env=ctx.env,
schema, stype, typename=old_name,
namespace=ctx.path_id_namespace, env=ctx.env,
)
old_set = setgen.new_set(
stype=stype, path_id=old_path_id, ctx=ctx
Expand All @@ -1251,7 +1250,28 @@ def prepare_rewrite_anchors(
else:
old_set = None

return (subject_set, specified_set, old_set)
rewrite_type = r_ctx.shape_type
if stype != r_ctx.shape_type.get_nearest_non_derived_parent(schema):
rewrite_type = downcast(
s_objtypes.ObjectType,
schemactx.derive_view(
stype,
exprtype=s_ctx.exprtype,
ctx=ctx,
)
)
subject_set = setgen.class_set(
rewrite_type, path_id=subject_set.path_id, ctx=ctx)
if old_set:
old_set = setgen.class_set(
rewrite_type, path_id=old_set.path_id, ctx=ctx)

return RewriteAnchors(
subject_set=subject_set,
specified_set=specified_set,
old_set=old_set,
rewrite_type=rewrite_type,
)


def _maybe_fixup_lprop(
Expand Down
29 changes: 29 additions & 0 deletions tests/test_edgeql_rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ class TestRewrites(tb.QueryTestCase):
};
create type Project extending Resource;
create type Document extending Resource {
create property text: str;
create required property textUpdatedAt: std::datetime {
set default := (std::datetime_of_statement());
create rewrite update using ((
IF __specified__.text
THEN std::datetime_of_statement()
ELSE __old__.textUpdatedAt
));
};
};
"""
]

Expand Down Expand Up @@ -1074,3 +1086,20 @@ async def test_edgeql_rewrites_28(self):
}
]
)

async def test_edgeql_rewrites_29(self):
# see https://github.com/edgedb/edgedb/issues/7048

# these tests check that subject of an update rewrite is the child
# object and not parent that is being updated
await self.con.execute(
'''
update std::Object set { };
'''
)

await self.con.execute(
'''
update Project set { name := '## redacted ##' }
'''
)

0 comments on commit 72fbd32

Please sign in to comment.