Skip to content

Commit

Permalink
Add AnyAspect.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed Jul 5, 2024
1 parent 3156730 commit da9b46b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 55 deletions.
83 changes: 39 additions & 44 deletions edb/pgsql/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ class TupleAspect(str):
pass


AnyAspect = (
RelAspect
| ScalarAspect
| OperatorAspect
| OperatorAspect
| CastAspect
| ConstraintAspect
| IndexAspect
| TupleAspect
)


def quote_e_literal(string: str) -> str:
def escape_sq(s):
split = re.split(r"(\n|\\\\|\\')", s)
Expand Down Expand Up @@ -605,10 +617,7 @@ def get_index_table_backend_name(
subject = index.get_subject(schema)
assert isinstance(subject, s_types.Type)
return get_backend_name(
schema,
subject,
aspect=(str(aspect) if aspect is not None else None),
catenate=False,
schema, subject, aspect=aspect, catenate=False,
)


Expand All @@ -630,7 +639,7 @@ def get_backend_name(
catenate: Literal[True]=True,
*,
versioned: bool=True,
aspect: Optional[str]=None
aspect: Optional[AnyAspect | Literal["index"]]=None
) -> str:
...

Expand All @@ -642,7 +651,7 @@ def get_backend_name(
catenate: Literal[False],
*,
versioned: bool=True,
aspect: Optional[str]=None
aspect: Optional[AnyAspect | Literal["index"]]=None
) -> tuple[str, str]:
...

Expand All @@ -652,66 +661,57 @@ def get_backend_name(
obj: so.Object,
catenate: bool=True,
*,
aspect: Optional[str]=None,
aspect: Optional[AnyAspect | Literal["index"]]=None,
versioned: bool=True,
) -> Union[str, tuple[str, str]]:
name: Union[s_name.QualName, s_name.Name]
if isinstance(obj, s_objtypes.ObjectType):
name = obj.get_name(schema)
assert aspect is None or isinstance(aspect, RelAspect)
return get_objtype_backend_name(
obj.id,
name.module,
catenate=catenate,
aspect=(
RelAspect.from_str(aspect) if aspect is not None else None
),
aspect=aspect,
versioned=versioned,
)

elif isinstance(obj, s_pointers.Pointer):
name = obj.get_name(schema)
if aspect == "index":
aspect = RelAspect.INDEX
assert aspect is None or isinstance(aspect, RelAspect)
return get_pointer_backend_name(
obj.id,
name.module,
catenate=catenate,
versioned=versioned,
aspect=(
RelAspect.from_str(aspect) if aspect is not None else None
),
aspect=aspect,
)

elif isinstance(obj, s_scalars.ScalarType):
name = obj.get_name(schema)
assert aspect is None or isinstance(aspect, ScalarAspect)
return get_scalar_backend_name(
obj.id,
name.module,
catenate=catenate,
versioned=versioned,
aspect=(
ScalarAspect.from_str(aspect) if aspect is not None else None
),
aspect=aspect,
)

elif isinstance(obj, s_opers.Operator):
name = obj.get_shortname(schema)
assert aspect is None or isinstance(aspect, OperatorAspect)
return get_operator_backend_name(
name,
catenate,
versioned=versioned,
aspect=(
OperatorAspect.from_str(aspect) if aspect is not None else None
)
name, catenate, versioned=versioned, aspect=aspect,
)

elif isinstance(obj, s_casts.Cast):
name = obj.get_name(schema)
assert aspect is None or isinstance(aspect, CastAspect)
return get_cast_backend_name(
name,
catenate,
versioned=versioned,
aspect=(
CastAspect.from_str(aspect) if aspect is not None else None
)
name, catenate, versioned=versioned, aspect=aspect,
)

elif isinstance(obj, s_func.Function):
Expand All @@ -722,32 +722,27 @@ def get_backend_name(

elif isinstance(obj, s_constr.Constraint):
name = obj.get_name(schema)
if aspect == "index":
aspect = ConstraintAspect.INDEX
assert aspect is None or isinstance(aspect, ConstraintAspect)
return get_constraint_backend_name(
obj.id, name.module, catenate, aspect=(
ConstraintAspect.from_str(aspect)
if aspect is not None else
None
))
obj.id, name.module, catenate, aspect=aspect
)

elif isinstance(obj, s_indexes.Index):
name = obj.get_name(schema)
if aspect == "index":
aspect = IndexAspect.INDEX
assert aspect is None or isinstance(aspect, IndexAspect)
return get_index_backend_name(
obj.id,
name.module,
catenate,
aspect=(
IndexAspect(aspect) if aspect is not None else None
)
obj.id, name.module, catenate, aspect=aspect
)

elif isinstance(obj, s_types.Tuple):
# XXX: TRAMPOLINE: VERSIONED?
assert aspect is None or isinstance(aspect, TupleAspect)
return get_tuple_backend_name(
obj.id,
catenate,
aspect=(
TupleAspect(aspect) if aspect is not None else None
)
obj.id, catenate, aspect=aspect
)

else:
Expand Down
14 changes: 7 additions & 7 deletions edb/pgsql/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,7 +1671,7 @@ def make_operator_function(self, oper: s_opers.Operator, schema):
oper,
catenate=False,
versioned=False,
aspect=str(common.OperatorAspect.FUNCTION),
aspect=common.OperatorAspect.FUNCTION,
)
return self.get_function_type(name)(
name=name,
Expand Down Expand Up @@ -1820,7 +1820,7 @@ def make_cast_function(self, cast: s_casts.Cast, schema):
cast,
catenate=False,
versioned=False,
aspect=str(common.CastAspect.FUNCTION),
aspect=common.CastAspect.FUNCTION,
)

args: Sequence[dbops.FunctionArg] = [
Expand Down Expand Up @@ -3111,7 +3111,7 @@ def attach_alter_table(self, context):
@staticmethod
def _get_table_name(obj, schema) -> tuple[str, str]:
is_internal_view = types.is_cfg_view(obj, schema)
aspect = str(common.RelAspect.DUMMY) if is_internal_view else None
aspect = common.RelAspect.DUMMY if is_internal_view else None
return common.get_backend_name(
schema, obj, catenate=False, aspect=aspect)

Expand Down Expand Up @@ -3284,7 +3284,7 @@ def _get_select_from(
schema,
obj,
catenate=False,
aspect=str(common.RelAspect.TABLE),
aspect=common.RelAspect.TABLE,
)

talias = qi(tabname[1])
Expand Down Expand Up @@ -3315,7 +3315,7 @@ def get_inhview(
schema,
obj,
catenate=False,
aspect=str(common.RelAspect.INHVIEW),
aspect=common.RelAspect.INHVIEW,
)

ptrs: Dict[sn.UnqualName, Tuple[str, Tuple[str, ...]]] = {}
Expand Down Expand Up @@ -6224,7 +6224,7 @@ def __init__(self, **kwargs):

def _get_link_table_union(self, schema, links, include_children) -> str:
selects = []
aspect = str(common.RelAspect.INHVIEW) if include_children else None
aspect = common.RelAspect.INHVIEW if include_children else None
for link in links:
selects.append(textwrap.dedent('''\
(SELECT
Expand All @@ -6249,7 +6249,7 @@ def _get_inline_link_table_union(
self, schema, links, include_children
) -> str:
selects = []
aspect = str(common.RelAspect.INHVIEW) if include_children else None
aspect = common.RelAspect.INHVIEW if include_children else None
for link in links:
link_psi = types.get_pointer_storage_info(link, schema=schema)
link_col = link_psi.column_name
Expand Down
6 changes: 3 additions & 3 deletions edb/pgsql/metaschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4869,7 +4869,7 @@ def tabname(
return common.get_backend_name(
schema,
obj,
aspect=str(common.RelAspect.TABLE),
aspect=common.RelAspect.TABLE,
catenate=False,
versioned=True,
)
Expand All @@ -4881,7 +4881,7 @@ def inhviewname(
return common.get_backend_name(
schema,
obj,
aspect=str(common.RelAspect.INHVIEW),
aspect=common.RelAspect.INHVIEW,
catenate=False,
versioned=True,
)
Expand Down Expand Up @@ -5578,7 +5578,7 @@ def _generate_schema_alias_view(
bn = common.get_backend_name(
schema,
obj,
aspect=str(common.RelAspect.INHVIEW),
aspect=common.RelAspect.INHVIEW,
catenate=False,
versioned=True,
)
Expand Down
2 changes: 1 addition & 1 deletion edb/pgsql/resolver/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def column_order_key(c: context.Column) -> Tuple[int, str]:
schemaname, dbname = pgcommon.get_backend_name(
ctx.schema,
obj,
aspect=str(pgcommon.RelAspect.TABLE),
aspect=pgcommon.RelAspect.TABLE,
catenate=False,
)
relation = pgast.Relation(name=dbname, schemaname=schemaname)
Expand Down

0 comments on commit da9b46b

Please sign in to comment.