From 72fbd32a47b3c0eefbbd795f0aad490a182ce676 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Mon, 18 Mar 2024 23:13:37 +0100 Subject: [PATCH] Fix update rewrites on types that are children of updated type (#7073) Closes #7048 Closes #6801 When type X is updated, but has child Y with an update rewrite, we were using type X as the subject where we should have been using Y. --- edb/edgeql/compiler/viewgen.py | 142 +++++++++++++++++++-------------- tests/test_edgeql_rewrites.py | 29 +++++++ 2 files changed, 110 insertions(+), 61 deletions(-) diff --git a/edb/edgeql/compiler/viewgen.py b/edb/edgeql/compiler/viewgen.py index bbefb09dd11..f47f346a98b 100644 --- a/edb/edgeql/compiler/viewgen.py +++ b/edb/edgeql/compiler/viewgen.py @@ -900,6 +900,15 @@ def _raise_on_missing( ) +@dataclasses.dataclass(kw_only=True, repr=False, eq=False) +class RewriteContext: + specified_ptrs: Set[sn.UnqualName] + kind: qltypes.RewriteKind + + base_type: s_objtypes.ObjectType + shape_type: s_objtypes.ObjectType + + def _compile_rewrites( specified_ptrs: Set[sn.UnqualName], kind: qltypes.RewriteKind, @@ -910,18 +919,24 @@ def _compile_rewrites( ctx: context.ContextLevel, ) -> Optional[irast.Rewrites]: # init - anchors = None + r_ctx = RewriteContext( + specified_ptrs=specified_ptrs, + kind=kind, + base_type=stype, + shape_type=view_scls, + ) # Computing anchors isn't cheap, so we want to only do it once, # and only do it when it is necessary. - def get_anchors() -> RewriteAnchors: - nonlocal anchors - if anchors is None: - anchors = prepare_rewrite_anchors(specified_ptrs, kind, stype, ctx) - return anchors + anchors: Dict[s_objtypes.ObjectType, RewriteAnchors] = {} + + def get_anchors(stype: s_objtypes.ObjectType) -> RewriteAnchors: + if stype not in anchors: + anchors[stype] = prepare_rewrite_anchors(stype, r_ctx, s_ctx, ctx) + return anchors[stype] rewrites = _compile_rewrites_for_stype( - stype, kind, view_scls, ir_set, get_anchors, s_ctx, ctx=ctx + stype, kind, ir_set, get_anchors, s_ctx, ctx=ctx ) if kind == qltypes.RewriteKind.Insert: @@ -937,7 +952,7 @@ def get_anchors() -> RewriteAnchors: # statement later. rewrites_by_type = _compile_rewrites_of_children( - stype, rewrites, kind, view_scls, ir_set, get_anchors, s_ctx, ctx + stype, rewrites, kind, ir_set, get_anchors, s_ctx, ctx ) else: @@ -979,12 +994,13 @@ def get_anchors() -> RewriteAnchors: by_type[ty][pn] = (ptr_set, ptrref.real_material_ptr) - if not anchors: + anc = next(iter(anchors.values()), None) + if not anc: return None return irast.Rewrites( - subject_path_id=anchors[0].path_id, - old_path_id=anchors[2].path_id if anchors[2] else None, + subject_path_id=anc.subject_set.path_id, + old_path_id=anc.old_set.path_id if anc.old_set else None, by_type=by_type, ) @@ -993,9 +1009,8 @@ def _compile_rewrites_of_children( stype: s_objtypes.ObjectType, parent_rewrites: Dict[sn.UnqualName, EarlyShapePtr], kind: qltypes.RewriteKind, - view_scls: s_objtypes.ObjectType, ir_set: irast.Set, - get_anchors: Callable[[], RewriteAnchors], + get_anchors: Callable[[s_objtypes.ObjectType], RewriteAnchors], s_ctx: ShapeContext, ctx: context.ContextLevel, ) -> Dict[irast.TypeRef, Dict[sn.UnqualName, EarlyShapePtr]]: @@ -1015,7 +1030,7 @@ def _compile_rewrites_of_children( child_rewrites = parent_rewrites.copy() # override with rewrites defined here rewrites_defined_here = _compile_rewrites_for_stype( - child, kind, view_scls, ir_set, get_anchors, s_ctx, + child, kind, ir_set, get_anchors, s_ctx, already_defined_rewrites=child_rewrites, ctx=ctx ) @@ -1027,7 +1042,6 @@ def _compile_rewrites_of_children( child, child_rewrites, kind, - view_scls, ir_set, get_anchors, s_ctx, @@ -1041,9 +1055,8 @@ def _compile_rewrites_of_children( def _compile_rewrites_for_stype( stype: s_objtypes.ObjectType, kind: qltypes.RewriteKind, - view_scls: s_objtypes.ObjectType, ir_set: irast.Set, - get_anchors: Callable[[], RewriteAnchors], + get_anchors: Callable[[s_objtypes.ObjectType], RewriteAnchors], s_ctx: ShapeContext, *, already_defined_rewrites: Optional[ @@ -1091,27 +1104,7 @@ def _compile_rewrites_for_stype( ): continue - subject_set, specified_set, old_set = get_anchors() - - rewrite_view = view_scls - if stype != view_scls.get_nearest_non_derived_parent(schema): - # FIXME: Caching? - rewrite_view = downcast( - s_objtypes.ObjectType, - schemactx.derive_view( - stype, - exprtype=s_ctx.exprtype, - ctx=ctx, - ) - ) - subject_set = setgen.class_set( - rewrite_view, path_id=subject_set.path_id, ctx=ctx) - if old_set: - old_set = setgen.class_set( - rewrite_view, path_id=old_set.path_id, ctx=ctx) - - ir_set = setgen.class_set( - rewrite_view, path_id=ir_set.path_id, ctx=ctx) + anchors = get_anchors(stype) rewrite_expr = rewrite.get_expr(ctx.env.schema) assert rewrite_expr @@ -1120,17 +1113,17 @@ def _compile_rewrites_for_stype( scopectx.active_rewrites |= {stype} # prepare context - scopectx.partial_path_prefix = subject_set + scopectx.partial_path_prefix = anchors.subject_set nanchors = {} - nanchors["__specified__"] = specified_set - nanchors["__subject__"] = subject_set - if old_set: - nanchors["__old__"] = old_set + nanchors["__specified__"] = anchors.specified_set + nanchors["__subject__"] = anchors.subject_set + if anchors.old_set: + nanchors["__old__"] = anchors.old_set for key, anchor in nanchors.items(): scopectx.path_scope.attach_path( anchor.path_id, context=None, - optional=(anchor is subject_set), + optional=(anchor is anchors.subject_set), ) scopectx.iterator_path_ids |= {anchor.path_id} scopectx.anchors[key] = anchor @@ -1138,9 +1131,9 @@ def _compile_rewrites_for_stype( # XXX: I am pretty sure this must be wrong, but we get # a failure without due to volatility issues in # test_edgeql_rewrites_16 - scopectx.env.singletons.append(subject_set.path_id) + scopectx.env.singletons.append(anchors.subject_set.path_id) - ctx.path_scope.factoring_allowlist.add(subject_set.path_id) + ctx.path_scope.factoring_allowlist.add(anchors.subject_set.path_id) # prepare expression ptrcls_sn = ptrcls.get_shortname(ctx.env.schema) @@ -1157,16 +1150,16 @@ def _compile_rewrites_for_stype( ) shape_ql_desc = _shape_el_ql_to_shape_el_desc( shape_ql, - source=rewrite_view, + source=anchors.rewrite_type, s_ctx=s_ctx, ctx=scopectx, ) # compile as normal shape element pointer, ptr_set = _normalize_view_ptr_expr( - subject_set, + anchors.subject_set, shape_ql_desc, - rewrite_view, + anchors.rewrite_type, path_id=path_id, from_default=True, s_ctx=s_ctx, @@ -1178,13 +1171,19 @@ def _compile_rewrites_for_stype( return res -RewriteAnchors = Tuple[irast.Set, irast.Set, Optional[irast.Set]] +@dataclasses.dataclass(kw_only=True, repr=False, eq=False) +class RewriteAnchors: + subject_set: irast.Set + specified_set: irast.Set + old_set: Optional[irast.Set] + + rewrite_type: s_objtypes.ObjectType def prepare_rewrite_anchors( - specified_ptrs: Set[sn.UnqualName], - rewrite_kind: qltypes.RewriteKind, stype: s_objtypes.ObjectType, + r_ctx: RewriteContext, + s_ctx: ShapeContext, ctx: context.ContextLevel, ) -> RewriteAnchors: schema = ctx.env.schema @@ -1193,11 +1192,11 @@ def prepare_rewrite_anchors( # TODO: Do we really need a separate path id for __subject__? subject_name = sn.QualName("__derived__", "__subject__") subject_path_id = irast.PathId.from_type( - schema, stype, typename=subject_name, namespace=ctx.path_id_namespace, - env=ctx.env, + schema, stype, typename=subject_name, + namespace=ctx.path_id_namespace, env=ctx.env, ) - subject_set = setgen.new_set( - stype=stype, path_id=subject_path_id, ctx=ctx + subject_set = setgen.class_set( + stype, path_id=subject_path_id, ctx=ctx ) # init reference to std::bool @@ -1225,7 +1224,7 @@ def prepare_rewrite_anchors( name=pn.name, val=setgen.ensure_set( irast.BooleanConstant( - value=str(pn in specified_ptrs), + value=str(pn in r_ctx.specified_ptrs), typeref=bool_path.target, ), ctx=ctx @@ -1238,11 +1237,11 @@ def prepare_rewrite_anchors( ) # init set for __old__ - if rewrite_kind == qltypes.RewriteKind.Update: + if r_ctx.kind == qltypes.RewriteKind.Update: old_name = sn.QualName("__derived__", "__old__") old_path_id = irast.PathId.from_type( - schema, stype, typename=old_name, namespace=ctx.path_id_namespace, - env=ctx.env, + schema, stype, typename=old_name, + namespace=ctx.path_id_namespace, env=ctx.env, ) old_set = setgen.new_set( stype=stype, path_id=old_path_id, ctx=ctx @@ -1251,7 +1250,28 @@ def prepare_rewrite_anchors( else: old_set = None - return (subject_set, specified_set, old_set) + rewrite_type = r_ctx.shape_type + if stype != r_ctx.shape_type.get_nearest_non_derived_parent(schema): + rewrite_type = downcast( + s_objtypes.ObjectType, + schemactx.derive_view( + stype, + exprtype=s_ctx.exprtype, + ctx=ctx, + ) + ) + subject_set = setgen.class_set( + rewrite_type, path_id=subject_set.path_id, ctx=ctx) + if old_set: + old_set = setgen.class_set( + rewrite_type, path_id=old_set.path_id, ctx=ctx) + + return RewriteAnchors( + subject_set=subject_set, + specified_set=specified_set, + old_set=old_set, + rewrite_type=rewrite_type, + ) def _maybe_fixup_lprop( diff --git a/tests/test_edgeql_rewrites.py b/tests/test_edgeql_rewrites.py index 9c2fe954d8c..d65a69fee16 100644 --- a/tests/test_edgeql_rewrites.py +++ b/tests/test_edgeql_rewrites.py @@ -49,6 +49,18 @@ class TestRewrites(tb.QueryTestCase): }; create type Project extending Resource; + + create type Document extending Resource { + create property text: str; + create required property textUpdatedAt: std::datetime { + set default := (std::datetime_of_statement()); + create rewrite update using (( + IF __specified__.text + THEN std::datetime_of_statement() + ELSE __old__.textUpdatedAt + )); + }; + }; """ ] @@ -1074,3 +1086,20 @@ async def test_edgeql_rewrites_28(self): } ] ) + + async def test_edgeql_rewrites_29(self): + # see https://github.com/edgedb/edgedb/issues/7048 + + # these tests check that subject of an update rewrite is the child + # object and not parent that is being updated + await self.con.execute( + ''' + update std::Object set { }; + ''' + ) + + await self.con.execute( + ''' + update Project set { name := '## redacted ##' } + ''' + )