From 06705afd7fcb6f1e03289f354c7d5293c521a1e1 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 19 Jun 2024 16:58:22 -0700 Subject: [PATCH] Propagate type creation/deletion to functions that depend on ancestors --- edb/pgsql/delta.py | 13 ++++++++ edb/schema/delta.py | 7 +++- edb/schema/objtypes.py | 46 +++++++++++++++++++++++++ tests/test_edgeql_ddl.py | 72 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 137 insertions(+), 1 deletion(-) diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index 8259ddd7056d..4aff35071ccd 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -1168,6 +1168,19 @@ def _compile_edgeql_function( ) -> s_expr.CompiledExpression: if isinstance(body, s_expr.CompiledExpression): return body + + # HACK: When an object type selected by a function (via + # inheritance) is dropped, the function gets + # recompiled. Unfortunately, 'caused' subcommands run *before* + # the object is actually deleted, and so we would ordinarily + # still try to select from the deleted object. To avoid + # needing to add *another* type of subcommand, we work around + # this by temporarily stripping all objects that are about to + # be deleted from the schema. + for ctx in context.stack: + if isinstance(ctx.op, s_objtypes.DeleteObjectType): + schema = schema.delete(ctx.op.scls) + return s_funcs.compile_function( schema, context, diff --git a/edb/schema/delta.py b/edb/schema/delta.py index aa345080334c..69f5e711ce77 100644 --- a/edb/schema/delta.py +++ b/edb/schema/delta.py @@ -2088,6 +2088,7 @@ def _propagate_if_expr_refs( context: CommandContext, *, action: str, + include_self: bool=True, include_ancestors: bool=False, extra_refs: Optional[Dict[so.Object, List[str]]]=None, filter: Type[so.Object] | Tuple[Type[so.Object], ...] | None = None, @@ -2104,7 +2105,9 @@ def _propagate_if_expr_refs( fixer = None scls = self.scls - expr_refs = s_expr.get_expr_referrers(schema, scls) + expr_refs: dict[so.Object, list[str]] = {} + if include_self: + expr_refs.update(s_expr.get_expr_referrers(schema, scls)) if include_ancestors and isinstance(scls, so.InheritingObject): for anc in scls.get_ancestors(schema).objects(schema): expr_refs.update(s_expr.get_expr_referrers(schema, anc)) @@ -3203,6 +3206,8 @@ def _create_finalize( context: CommandContext, ) -> s_schema.Schema: if not context.canonical: + # This is rarely triggered. + schema = self._finalize_affected_refs(schema, context) self.validate_object(schema, context) return schema diff --git a/edb/schema/objtypes.py b/edb/schema/objtypes.py index 0e7ea829ee3f..0d567e3589e8 100644 --- a/edb/schema/objtypes.py +++ b/edb/schema/objtypes.py @@ -32,6 +32,7 @@ from . import annos as s_anno from . import constraints from . import delta as sd +from . import functions as s_func from . import inheriting from . import links from . import properties @@ -560,6 +561,28 @@ def _get_ast_node( else: return super()._get_ast_node(schema, context) + def _create_finalize( + self, + schema: s_schema.Schema, + context: sd.CommandContext, + ) -> s_schema.Schema: + if ( + not context.canonical + and self.scls.is_material_object_type(schema) + ): + # Propagate changes to any functions that depend on + # ancestor types in order to recompute the inheritance + # situation. + schema = self._propagate_if_expr_refs( + schema, + context, + action='creating an object type', + include_ancestors=True, + filter=s_func.Function, + ) + + return super()._create_finalize(schema, context) + class RenameObjectType( ObjectTypeCommand, @@ -701,3 +724,26 @@ def _get_ast( return None else: return super()._get_ast(schema, context, parent_node=parent_node) + + def _delete_finalize( + self, + schema: s_schema.Schema, + context: sd.CommandContext, + ) -> s_schema.Schema: + if ( + not context.canonical + and self.scls.is_material_object_type(schema) + ): + # Propagate changes to any functions that depend on + # ancestor types in order to recompute the inheritance + # situation. + schema = self._propagate_if_expr_refs( + schema, + context, + action='deleting an object type', + include_self=False, + include_ancestors=True, + filter=s_func.Function, + ) + + return super()._delete_finalize(schema, context) diff --git a/tests/test_edgeql_ddl.py b/tests/test_edgeql_ddl.py index 868c89c81b53..d122fc303dfb 100644 --- a/tests/test_edgeql_ddl.py +++ b/tests/test_edgeql_ddl.py @@ -5139,6 +5139,78 @@ async def test_edgeql_ddl_function_38(self): ); ''') + async def test_edgeql_ddl_function_inh_01(self): + await self.con.execute(""" + create abstract type T; + create function countall() -> int64 USING (count(T)); + """) + + await self.assert_query_result( + """SELECT countall()""", + [0], + ) + await self.con.execute(""" + create type S1 extending T; + insert S1; + """) + await self.assert_query_result( + """SELECT countall()""", + [1], + ) + await self.con.execute(""" + create type S2 extending T; + insert S2; + insert S2; + """) + await self.assert_query_result( + """SELECT countall()""", + [3], + ) + await self.con.execute(""" + drop type S2; + """) + + await self.assert_query_result( + """SELECT countall()""", + [1], + ) + + async def test_edgeql_ddl_function_inh_02(self): + await self.con.execute(""" + create abstract type T { create multi property n -> int64 }; + create function countall() -> int64 USING (sum(T.n)); + """) + + await self.assert_query_result( + """SELECT countall()""", + [0], + ) + await self.con.execute(""" + create type S1 extending T; + insert S1 { n := {3, 4} }; + """) + await self.assert_query_result( + """SELECT countall()""", + [7], + ) + await self.con.execute(""" + create type S2 extending T; + insert S2 { n := 1 }; + insert S2 { n := {2, 2, 2} }; + """) + await self.assert_query_result( + """SELECT countall()""", + [14], + ) + await self.con.execute(""" + drop type S2; + """) + + await self.assert_query_result( + """SELECT countall()""", + [7], + ) + async def test_edgeql_ddl_function_rename_01(self): await self.con.execute(""" CREATE FUNCTION foo(s: str) -> str {