Skip to content

Commit

Permalink
Propagate type creation/deletion to functions that depend on ancestors
Browse files Browse the repository at this point in the history
  • Loading branch information
msullivan authored and dnwpark committed Jul 2, 2024
1 parent 9bcc6e4 commit 06705af
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 1 deletion.
13 changes: 13 additions & 0 deletions edb/pgsql/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,19 @@ def _compile_edgeql_function(
) -> s_expr.CompiledExpression:
if isinstance(body, s_expr.CompiledExpression):
return body

# HACK: When an object type selected by a function (via
# inheritance) is dropped, the function gets
# recompiled. Unfortunately, 'caused' subcommands run *before*
# the object is actually deleted, and so we would ordinarily
# still try to select from the deleted object. To avoid
# needing to add *another* type of subcommand, we work around
# this by temporarily stripping all objects that are about to
# be deleted from the schema.
for ctx in context.stack:
if isinstance(ctx.op, s_objtypes.DeleteObjectType):
schema = schema.delete(ctx.op.scls)

return s_funcs.compile_function(
schema,
context,
Expand Down
7 changes: 6 additions & 1 deletion edb/schema/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2088,6 +2088,7 @@ def _propagate_if_expr_refs(
context: CommandContext,
*,
action: str,
include_self: bool=True,
include_ancestors: bool=False,
extra_refs: Optional[Dict[so.Object, List[str]]]=None,
filter: Type[so.Object] | Tuple[Type[so.Object], ...] | None = None,
Expand All @@ -2104,7 +2105,9 @@ def _propagate_if_expr_refs(
fixer = None

scls = self.scls
expr_refs = s_expr.get_expr_referrers(schema, scls)
expr_refs: dict[so.Object, list[str]] = {}
if include_self:
expr_refs.update(s_expr.get_expr_referrers(schema, scls))
if include_ancestors and isinstance(scls, so.InheritingObject):
for anc in scls.get_ancestors(schema).objects(schema):
expr_refs.update(s_expr.get_expr_referrers(schema, anc))
Expand Down Expand Up @@ -3203,6 +3206,8 @@ def _create_finalize(
context: CommandContext,
) -> s_schema.Schema:
if not context.canonical:
# This is rarely triggered.
schema = self._finalize_affected_refs(schema, context)
self.validate_object(schema, context)
return schema

Expand Down
46 changes: 46 additions & 0 deletions edb/schema/objtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from . import annos as s_anno
from . import constraints
from . import delta as sd
from . import functions as s_func
from . import inheriting
from . import links
from . import properties
Expand Down Expand Up @@ -560,6 +561,28 @@ def _get_ast_node(
else:
return super()._get_ast_node(schema, context)

def _create_finalize(
self,
schema: s_schema.Schema,
context: sd.CommandContext,
) -> s_schema.Schema:
if (
not context.canonical
and self.scls.is_material_object_type(schema)
):
# Propagate changes to any functions that depend on
# ancestor types in order to recompute the inheritance
# situation.
schema = self._propagate_if_expr_refs(
schema,
context,
action='creating an object type',
include_ancestors=True,
filter=s_func.Function,
)

return super()._create_finalize(schema, context)


class RenameObjectType(
ObjectTypeCommand,
Expand Down Expand Up @@ -701,3 +724,26 @@ def _get_ast(
return None
else:
return super()._get_ast(schema, context, parent_node=parent_node)

def _delete_finalize(
self,
schema: s_schema.Schema,
context: sd.CommandContext,
) -> s_schema.Schema:
if (
not context.canonical
and self.scls.is_material_object_type(schema)
):
# Propagate changes to any functions that depend on
# ancestor types in order to recompute the inheritance
# situation.
schema = self._propagate_if_expr_refs(
schema,
context,
action='deleting an object type',
include_self=False,
include_ancestors=True,
filter=s_func.Function,
)

return super()._delete_finalize(schema, context)
72 changes: 72 additions & 0 deletions tests/test_edgeql_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5139,6 +5139,78 @@ async def test_edgeql_ddl_function_38(self):
);
''')

async def test_edgeql_ddl_function_inh_01(self):
await self.con.execute("""
create abstract type T;
create function countall() -> int64 USING (count(T));
""")

await self.assert_query_result(
"""SELECT countall()""",
[0],
)
await self.con.execute("""
create type S1 extending T;
insert S1;
""")
await self.assert_query_result(
"""SELECT countall()""",
[1],
)
await self.con.execute("""
create type S2 extending T;
insert S2;
insert S2;
""")
await self.assert_query_result(
"""SELECT countall()""",
[3],
)
await self.con.execute("""
drop type S2;
""")

await self.assert_query_result(
"""SELECT countall()""",
[1],
)

async def test_edgeql_ddl_function_inh_02(self):
await self.con.execute("""
create abstract type T { create multi property n -> int64 };
create function countall() -> int64 USING (sum(T.n));
""")

await self.assert_query_result(
"""SELECT countall()""",
[0],
)
await self.con.execute("""
create type S1 extending T;
insert S1 { n := {3, 4} };
""")
await self.assert_query_result(
"""SELECT countall()""",
[7],
)
await self.con.execute("""
create type S2 extending T;
insert S2 { n := 1 };
insert S2 { n := {2, 2, 2} };
""")
await self.assert_query_result(
"""SELECT countall()""",
[14],
)
await self.con.execute("""
drop type S2;
""")

await self.assert_query_result(
"""SELECT countall()""",
[7],
)

async def test_edgeql_ddl_function_rename_01(self):
await self.con.execute("""
CREATE FUNCTION foo(s: str) -> str {
Expand Down

0 comments on commit 06705af

Please sign in to comment.