Skip to content

Commit

Permalink
Refactor QL AST (#7157)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Apr 8, 2024
1 parent 1d2d4a1 commit e9d0286
Show file tree
Hide file tree
Showing 35 changed files with 352 additions and 313 deletions.
111 changes: 51 additions & 60 deletions edb/edgeql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# AST classes that name-clash with classes from the typing module.

import typing
import enum

from edb.common import enum as s_enum
from edb.common import ast, span
Expand Down Expand Up @@ -173,11 +174,15 @@ class ObjectRef(BaseObjectRef):


class PseudoObjectRef(BaseObjectRef):
# anytype, anytuple or anyobject
'''anytype, anytuple or anyobject'''
name: str


class Anchor(Expr):
'''Identifier that resolves to some pre-compiled expression.
For example in shapes, the anchor __subject__ refers to object that the
shape is defined on.
'''
__abstract_node__ = True
name: str

Expand All @@ -190,14 +195,6 @@ class SpecialAnchor(Anchor):
pass


class Source(SpecialAnchor): # __source__
name: str = '__source__'


class Subject(SpecialAnchor): # __subject__
name: str = '__subject__'


class DetachedExpr(Expr): # DETACHED Expr
expr: Expr
preserve_path_prefix: bool = False
Expand Down Expand Up @@ -225,11 +222,9 @@ class BinOp(Expr):
left: Expr
op: str
right: Expr
rebalanced: bool = False


class SetConstructorOp(BinOp):
op: str = 'UNION'
rebalanced: bool = False
set_constructor: bool = False


class WindowSpec(Base):
Expand All @@ -245,54 +240,43 @@ class FunctionCall(Expr):


class BaseConstant(Expr):
"""Constant (a literal value)."""
__abstract_node__ = True
value: str

@classmethod
def from_python(cls, val: typing.Any) -> BaseConstant:
raise NotImplementedError

class Constant(BaseConstant):
"""Constant whose value we can store in a string."""
kind: ConstantKind
value: str

class StringConstant(BaseConstant):
@classmethod
def from_python(cls, s: str) -> StringConstant:
return cls(value=s)


class BaseRealConstant(BaseConstant):
__abstract_node__ = True
is_negative: bool = False


class IntegerConstant(BaseRealConstant):
pass


class FloatConstant(BaseRealConstant):
pass


class BigintConstant(BaseRealConstant):
pass
def string(cls, value: str) -> Constant:
return Constant(kind=ConstantKind.STRING, value=value)

@classmethod
def boolean(cls, b: bool) -> Constant:
return Constant(kind=ConstantKind.BOOLEAN, value=str(b).lower())

class DecimalConstant(BaseRealConstant):
pass
@classmethod
def integer(cls, i: int) -> Constant:
return Constant(kind=ConstantKind.INTEGER, value=str(i))


class BooleanConstant(BaseConstant):
@classmethod
def from_python(cls, b: bool) -> BooleanConstant:
return cls(value=str(b).lower())
class ConstantKind(enum.IntEnum):
STRING = 0
BOOLEAN = 1
INTEGER = 2
FLOAT = 3
BIGINT = 4
DECIMAL = 5


class BytesConstant(BaseConstant):
# This should really just be str to match, though
value: bytes # type: ignore[assignment]
value: bytes

@classmethod
def from_python(cls, s: bytes) -> BytesConstant:
return cls(value=s)
return BytesConstant(value=s)


class Parameter(Expr):
Expand All @@ -305,6 +289,8 @@ class UnaryOp(Expr):


class TypeExpr(Base):
__abstract_node__ = True

name: typing.Optional[str] = None # name is used for types in named tuples


Expand All @@ -314,7 +300,7 @@ class TypeOf(TypeExpr):

class TypeExprLiteral(TypeExpr):
# Literal type exprs are used in enum declarations.
val: BaseConstant
val: Constant


class TypeName(TypeExpr):
Expand Down Expand Up @@ -515,9 +501,9 @@ class Query(Expr):
Statement = Query | Command


class PipelinedQuery(Query):
__abstract_node__ = True
implicit: bool = False
class SelectQuery(Query):
result_alias: typing.Optional[str] = None
result: Expr

where: typing.Optional[Expr] = None

Expand All @@ -531,11 +517,7 @@ class PipelinedQuery(Query):
# not interfere with linkprops.
rptr_passthrough: bool = False


class SelectQuery(PipelinedQuery):

result_alias: typing.Optional[str] = None
result: Expr
implicit: bool = False


class GroupingIdentList(Base):
Expand Down Expand Up @@ -599,9 +581,16 @@ class UpdateQuery(Query):
where: typing.Optional[Expr] = None


class DeleteQuery(PipelinedQuery):
class DeleteQuery(Query):
subject: Expr

where: typing.Optional[Expr] = None

orderby: typing.Optional[typing.List[SortExpr]] = None

offset: typing.Optional[Expr] = None
limit: typing.Optional[Expr] = None


class ForQuery(Query):
from_desugaring: bool = False
Expand Down Expand Up @@ -899,7 +888,7 @@ class ExtensionPackageCommand(GlobalObjectCommand):
object_class: qltypes.SchemaObjectClass = (
qltypes.SchemaObjectClass.EXTENSION_PACKAGE
)
version: StringConstant
version: Constant


class CreateExtensionPackage(CreateObject, ExtensionPackageCommand):
Expand All @@ -918,7 +907,7 @@ class ExtensionCommand(UnqualifiedObjectCommand):
object_class: qltypes.SchemaObjectClass = (
qltypes.SchemaObjectClass.EXTENSION
)
version: typing.Optional[StringConstant] = None
version: typing.Optional[Constant] = None


class CreateExtension(CreateObject, ExtensionCommand):
Expand Down Expand Up @@ -1623,13 +1612,15 @@ def has_ddl_subcommand(
ReturningQuery = SelectQuery | ForQuery | InternalGroupQuery


FilteringQuery = PipelinedQuery | ShapeElement | UpdateQuery | ConfigReset
FilteringQuery = (
SelectQuery | DeleteQuery | ShapeElement | UpdateQuery | ConfigReset
)


SubjectQuery = DeleteQuery | UpdateQuery | GroupQuery


OffsetLimitQuery = PipelinedQuery | ShapeElement
OffsetLimitQuery = SelectQuery | DeleteQuery | ShapeElement


BasedOn = (
Expand Down
73 changes: 23 additions & 50 deletions edb/edgeql/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,41 +687,21 @@ def visit_Placeholder(self, node: qlast.Placeholder) -> None:
self.write(node.name)
self.write(')')

def visit_StringConstant(self, node: qlast.StringConstant) -> None:
if not _NON_PRINTABLE_RE.search(node.value):
for d in ("'", '"', '$$'):
if d not in node.value:
if '\\' in node.value and d != '$$':
self.write('r', d, node.value, d)
else:
self.write(d, node.value, d)
return
self.write(edgeql_quote.dollar_quote_literal(node.value))
return
self.write(repr(node.value))

def visit_IntegerConstant(self, node: qlast.IntegerConstant) -> None:
if node.is_negative:
self.write('-')
self.write(node.value)

def visit_FloatConstant(self, node: qlast.FloatConstant) -> None:
if node.is_negative:
self.write('-')
self.write(node.value)

def visit_DecimalConstant(self, node: qlast.DecimalConstant) -> None:
if node.is_negative:
self.write('-')
self.write(node.value)

def visit_BigintConstant(self, node: qlast.BigintConstant) -> None:
if node.is_negative:
self.write('-')
self.write(node.value)

def visit_BooleanConstant(self, node: qlast.BooleanConstant) -> None:
self.write(node.value)
def visit_Constant(self, node: qlast.Constant) -> None:
if node.kind == qlast.ConstantKind.STRING:
if not _NON_PRINTABLE_RE.search(node.value):
for d in ("'", '"', '$$'):
if d not in node.value:
if '\\' in node.value and d != '$$':
self.write('r', d, node.value, d)
else:
self.write(d, node.value, d)
return
self.write(edgeql_quote.dollar_quote_literal(node.value))
return
self.write(repr(node.value))
else:
self.write(node.value)

def visit_BytesConstant(self, node: qlast.BytesConstant) -> None:
val = _BYTES_ESCAPE_RE.sub(_bytes_escape, node.value)
Expand Down Expand Up @@ -811,19 +791,7 @@ def visit_ObjectRef(self, node: qlast.ObjectRef) -> None:
self.write('::')
self.write(ident_to_str(node.name))

def visit_Anchor(self, node: qlast.Anchor) -> None:
self.write(node.name)

def visit_IRAnchor(self, node: qlast.IRAnchor) -> None:
self.write(node.name)

def visit_SpecialAnchor(self, node: qlast.SpecialAnchor) -> None:
self.write(node.name)

def visit_Subject(self, node: qlast.Subject) -> None:
self.write(node.name)

def visit_Source(self, node: qlast.Source) -> None:
def visit_SpecialAnchor(self, node: qlast.Anchor) -> None:
self.write(node.name)

def visit_TypeExprLiteral(self, node: qlast.TypeExprLiteral) -> None:
Expand Down Expand Up @@ -1361,7 +1329,9 @@ def _eval_bool_expr(
self,
expr: Union[qlast.Expr, qlast.TypeExpr],
) -> bool:
if not isinstance(expr, qlast.BooleanConstant):
if (not isinstance(expr, qlast.Constant)
or expr.kind != qlast.ConstantKind.BOOLEAN
):
raise AssertionError(f'expected BooleanConstant, got {expr!r}')
return expr.value == 'true'

Expand All @@ -1370,7 +1340,10 @@ def _eval_enum_expr(
expr: Union[qlast.Expr, qlast.TypeExpr],
enum_type: Type[Enum_T],
) -> Enum_T:
if not isinstance(expr, qlast.StringConstant):
if (
not isinstance(expr, qlast.Constant)
or expr.kind != qlast.ConstantKind.STRING
):
raise AssertionError(f'expected StringConstant, got {expr!r}')
return enum_type(expr.value)

Expand Down
13 changes: 0 additions & 13 deletions edb/edgeql/compiler/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,6 @@ def is_ql_empty_array(expr: qlast.Expr) -> bool:
return isinstance(expr, qlast.Array) and len(expr.elements) == 0


def is_ql_path(qlexpr: qlast.Expr) -> bool:
if isinstance(qlexpr, qlast.Shape):
if qlexpr.expr:
qlexpr = qlexpr.expr

if not isinstance(qlexpr, qlast.Path):
return False

start = qlexpr.steps[0]

return isinstance(start, (qlast.Source, qlast.ObjectRef, qlast.Ptr))


def is_nontrivial_shape_element(shape_el: qlast.ShapeElement) -> bool:
return bool(
shape_el.where
Expand Down
Loading

0 comments on commit e9d0286

Please sign in to comment.