Refactor compile_constraint for control flow clarity. (#7601)
dnwpark authored Jul 31, 2024
1 parent 7450ecb commit 801d401
Showing 1 changed file with 137 additions and 135 deletions.
272 changes: 137 additions & 135 deletions edb/pgsql/schemamech.py
@@ -43,6 +43,7 @@
from edb.schema import schema as s_schema
from edb.schema import sources as s_sources
from edb.schema import expr as s_expr
from edb.schema import objects as s_obj

from edb.common import ast
from edb.common import parsing
@@ -211,40 +212,40 @@ def _edgeql_ref_to_pg_constr(
)


def compile_constraint(
subject: s_constraints.ConsistencySubject,
@dataclasses.dataclass(frozen=True)
class CompiledConstraintData:
subject: s_types.Type | s_pointers.Pointer
exclusive_expr_refs: Optional[Sequence[irast.Base]]
subject_db_name: Optional[Tuple[str, str]]
except_data: Optional[ExprDataSources]
ir: irast.Statement
subject_table_type: str


def _compile_constraint_data(
constraint: s_constraints.Constraint,
schema: s_schema.Schema,
span: Optional[parsing.Span],
) -> SchemaDomainConstraint | SchemaTableConstraint:
assert constraint.get_subject(schema) is not None
TypeOrPointer = s_types.Type | s_pointers.Pointer
is_optional: bool,
*,
span: Optional[parsing.Span] = None,
type_remaps: Optional[dict[s_obj.Object, s_obj.Object]] = None,
) -> CompiledConstraintData:
sub = constraint.get_subject(schema)
assert isinstance(
subject, (s_types.Type, s_pointers.Pointer, s_scalars.ScalarType)
sub, (s_types.Type, s_pointers.Pointer, s_scalars.ScalarType)
)
subject: s_types.Type | s_pointers.Pointer = sub

constraint_origins = constraint.get_constraint_origins(schema)
first_subject = constraint_origins[0].get_subject(schema)
path_prefix_anchor = '__subject__'
singletons = frozenset({(subject, is_optional)})

is_optional = isinstance(
first_subject, s_pointers.Pointer
) and not first_subject.get_required(schema)
singletons: Collection[Tuple[TypeOrPointer, bool]] = frozenset(
{(subject, is_optional)}
)
options = qlcompiler.CompilerOptions(
anchors={'__subject__': subject},
path_prefix_anchor='__subject__',
path_prefix_anchor=path_prefix_anchor,
apply_query_rewrites=False,
singletons=singletons,
schema_object_context=type(constraint),
# Remap the constraint origin to the subject, so that if
# we have B <: A, and the constraint references A.foo, it
# gets rewritten in the subtype to B.foo. It's OK to only
# look at one constraint origin, because if there were
# multiple different origins, they couldn't get away with
# referring to the type explicitly.
type_remaps={first_subject: subject},
type_remaps=type_remaps if type_remaps is not None else {},
)

final_expr: Optional[s_expr.Expression] = constraint.get_finalexpr(schema)
@@ -289,10 +290,11 @@ def compile_constraint(
)
elif ref_tables:
subject_db_name, info = next(iter(ref_tables.items()))
table_type = info[0][3].table_type
subject_table_type = info[0][3].table_type
else:
# the expression doesn't have any refs: default to the subject table

subject_table: Optional[s_obj.InheritingObject] | s_types.Type
if isinstance(subject, s_pointers.Pointer):
subject_table = subject.get_source(schema)
else:
@@ -302,141 +304,141 @@
subject_db_name = common.get_backend_name(
schema, subject_table, catenate=False,
)
table_type = 'ObjectType'
subject_table_type = 'ObjectType'

exclusive_expr_refs = _get_exclusive_refs(ir)

pg_constr_data = PGConstrData(
subject_db_name=subject_db_name,
expressions=[],
origin_expressions=[],
table_type=table_type,
except_data=except_data,
return CompiledConstraintData(
subject,
exclusive_expr_refs,
subject_db_name,
except_data,
ir,
subject_table_type,
)

different_origins = [
origin for origin in constraint_origins if origin != constraint
]

per_origin_parts = []
for constraint_origin in different_origins:
sub = constraint_origin.get_subject(schema)
assert isinstance(sub, (s_types.Type, s_pointers.Pointer))
origin_subject: s_types.Type | s_pointers.Pointer = sub

origin_path_prefix_anchor = '__subject__'
singletons = frozenset({(origin_subject, is_optional)})

origin_options = qlcompiler.CompilerOptions(
anchors={'__subject__': origin_subject},
path_prefix_anchor=origin_path_prefix_anchor,
apply_query_rewrites=False,
singletons=singletons,
schema_object_context=type(constraint),
)
def _get_compiled_constraint_expr_data(
primary_subject: s_constraints.ConsistencySubject,
constraint_data: CompiledConstraintData,
) -> list[ExprData]:
exprdatas: list[ExprData] = []

final_expr = constraint_origin.get_finalexpr(schema)
assert final_expr is not None and final_expr.parse() is not None
origin_ir = qlcompiler.compile_ast_to_ir(
final_expr.parse(),
schema,
options=origin_options,
)
constraint_subject = (
constraint_data.subject
if constraint_data.subject != primary_subject else
None
)

assert origin_ir.expr.expr
origin_terminal_refs = ir_utils.get_longest_paths(
origin_ir.expr.expr
)
origin_ref_tables = get_ref_storage_info(
origin_ir.schema, origin_terminal_refs
assert constraint_data.exclusive_expr_refs is not None
for ref in constraint_data.exclusive_expr_refs:
exprdata = _edgeql_ref_to_pg_constr(
primary_subject, constraint_subject, ref
)
exprdata.origin_subject_db_name = constraint_data.subject_db_name
exprdata.origin_except_data = constraint_data.except_data
exprdatas.append(exprdata)

if origin_ref_tables:
origin_subject_db_name, _ = next(iter(origin_ref_tables.items()))
else:
origin_subject_db_name = common.get_backend_name(
schema,
origin_subject,
catenate=False,
)
return exprdatas

origin_except_data = None
if except_expr := constraint_origin.get_except_expr(schema):
assert isinstance(except_expr, s_expr.Expression)
except_ir = qlcompiler.compile_ast_to_ir(
except_expr.parse(),
schema,
options=origin_options,
)
except_sql = compiler.compile_ir_to_sql_tree(
except_ir, singleton_mode=True)
origin_except_data = _edgeql_tree_to_expr_data(except_sql.ast)

origin_exclusive_expr_refs = _get_exclusive_refs(origin_ir)
per_origin_parts.append(
(
origin_subject,
origin_exclusive_expr_refs,
origin_subject_db_name,
origin_except_data,
)
)

if not per_origin_parts:
origin_subject = subject
origin_subject_db_name = subject_db_name
origin_except_data = except_data
per_origin_parts.append(
(
origin_subject,
None,
origin_subject_db_name,
origin_except_data,
)
)
def compile_constraint(
subject: s_constraints.ConsistencySubject,
constraint: s_constraints.Constraint,
schema: s_schema.Schema,
span: Optional[parsing.Span],
) -> SchemaDomainConstraint | SchemaTableConstraint:
assert constraint.get_subject(schema) is not None
assert isinstance(
subject, (s_types.Type, s_pointers.Pointer, s_scalars.ScalarType)
)

if exclusive_expr_refs:
exprdatas: List[ExprData] = []
for ref in exclusive_expr_refs:
exprdata = _edgeql_ref_to_pg_constr(subject, None, ref)
exprdata.origin_subject_db_name = subject_db_name
exprdata.origin_except_data = except_data
exprdatas.append(exprdata)
constraint_origins = constraint.get_constraint_origins(schema)
first_subject = constraint_origins[0].get_subject(schema)

pg_constr_data.expressions.extend(exprdatas)
is_optional = isinstance(
first_subject, s_pointers.Pointer
) and not first_subject.get_required(schema)

else:
assert len(constraint_origins) == 1
exprdata = _edgeql_ref_to_pg_constr(subject, origin_subject, ir)
exprdata.origin_subject_db_name = origin_subject_db_name
exprdata.origin_except_data = origin_except_data
constraint_data = _compile_constraint_data(
constraint,
schema,
is_optional,
span=span,
# Remap the constraint origin to the subject, so that if
# we have B <: A, and the constraint references A.foo, it
# gets rewritten in the subtype to B.foo. It's OK to only
# look at one constraint origin, because if there were
# multiple different origins, they couldn't get away with
# referring to the type explicitly.
type_remaps={first_subject: subject},
)
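# Hypothetical illustration of the remap described above (schema names
# invented for this note, not taken from the commit): given
#
#     type A { property foo -> str { constraint exclusive } }
#     type B extending A;
#
# compiling the constraint as inherited by B finds its origin on A, so
# type_remaps={A: B} causes references to A.foo to compile as B.foo.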

pg_constr_data.expressions.append(exprdata)
pg_constr_data = PGConstrData(
subject_db_name=constraint_data.subject_db_name,
expressions=[],
origin_expressions=[],
table_type=constraint_data.subject_table_type,
except_data=constraint_data.except_data,
)

for (
origin_subject,
origin_exclusive_expr_refs,
origin_subject_db_name,
origin_except_data,
) in per_origin_parts:
if not exclusive_expr_refs:
continue
if constraint_data.exclusive_expr_refs:
origin_expr_datas: dict[
s_constraints.Constraint, list[ExprData]
] = {}
for origin in constraint_origins:
if origin == constraint:
origin_data = constraint_data

if origin_exclusive_expr_refs:
for ref in origin_exclusive_expr_refs:
exprdata = _edgeql_ref_to_pg_constr(
subject, origin_subject, ref
else:
origin_data = _compile_constraint_data(
origin,
schema,
is_optional,
)
exprdata.origin_subject_db_name = origin_subject_db_name
exprdata.origin_except_data = origin_except_data
pg_constr_data.origin_expressions.append(exprdata)

origin_expr_datas[origin] = _get_compiled_constraint_expr_data(
subject, origin_data
)

expressions: list[ExprData]
if constraint in origin_expr_datas:
expressions = origin_expr_datas[constraint]
else:
pg_constr_data.origin_expressions.extend(exprdatas)
expressions = _get_compiled_constraint_expr_data(
subject, constraint_data
)

pg_constr_data.expressions.extend(expressions)

origin_expressions: list[ExprData] = []
for origin in constraint_origins:
origin_expressions.extend(origin_expr_datas[origin])

pg_constr_data.origin_expressions.extend(origin_expressions)

if exclusive_expr_refs:
pg_constr_data.scope = 'relation'
pg_constr_data.type = 'unique'

else:
assert len(constraint_origins) == 1
origin_data = (
_compile_constraint_data(
constraint_origins[0],
schema,
is_optional,
)
if constraint_origins[0] != constraint else
constraint_data
)
exprdata = _edgeql_ref_to_pg_constr(
subject, origin_data.subject, constraint_data.ir
)
exprdata.origin_subject_db_name = origin_data.subject_db_name
exprdata.origin_except_data = origin_data.except_data

pg_constr_data.expressions.append(exprdata)

pg_constr_data.scope = 'row'
pg_constr_data.type = 'check'
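
For orientation, a condensed sketch of the refactored control flow in compile_constraint, pieced together from the hunks above (helper bodies elided; a reading aid, not the verbatim implementation):

def compile_constraint(subject, constraint, schema, span):
    origins = constraint.get_constraint_origins(schema)
    first_subject = origins[0].get_subject(schema)
    is_optional = (
        isinstance(first_subject, s_pointers.Pointer)
        and not first_subject.get_required(schema)
    )
    # Compile the constraint once for the primary subject, remapping the
    # origin type so inherited references resolve against the subtype.
    data = _compile_constraint_data(
        constraint, schema, is_optional, span=span,
        type_remaps={first_subject: subject},
    )
    if data.exclusive_expr_refs:
        # Exclusive constraint: gather expression data per origin and emit
        # a relation-scoped 'unique' constraint.
        ...
    else:
        # Non-exclusive constraint: single origin, row-scoped 'check'.
        ...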
