Skip to content

Commit

Permalink
Minor refactor, change pg_ast to pgast (#6672)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Jan 5, 2024
1 parent 23bb395 commit 1930838
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 45 deletions.
54 changes: 28 additions & 26 deletions edb/pgsql/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
from edb.server import config
from edb.server.config import ops as config_ops

from . import ast as pg_ast
from . import ast as pgast
from .common import qname as q
from .common import quote_literal as ql
from .common import quote_ident as qi
Expand Down Expand Up @@ -3640,7 +3640,7 @@ def create_index(
# casts, strip them as they mess with the requirement that
# index expressions are IMMUTABLE (also indexes expect the
# usage of literals and will do their own implicit casts).
if isinstance(kw_sql_tree, pg_ast.TypeCast):
if isinstance(kw_sql_tree, pgast.TypeCast):
kw_sql_tree = kw_sql_tree.arg
sql = codegen.generate_source(kw_sql_tree)
sql_kwarg_exprs[name] = sql
Expand Down Expand Up @@ -4972,7 +4972,7 @@ def _compile_conversion_expr(
backend_runtime_params=context.backend_runtime_params,
)
sql_tree = sql_res.ast
assert isinstance(sql_tree, pg_ast.SelectStmt)
assert isinstance(sql_tree, pgast.SelectStmt)

if produce_ctes:
# ensure the result contains the object id in the second column
Expand All @@ -4989,47 +4989,47 @@ def _compile_conversion_expr(
if check_non_null:
# wrap into raise_on_null
pointer_name = 'link' if is_link else 'property'
msg = pg_ast.StringConstant(
msg = pgast.StringConstant(
val=f"missing value for required {pointer_name}"
)
# Concat to string which is a JSON. Great. Equivalent to SQL:
# '{"object_id": "' || {obj_id_ref} || '"}'
detail = pg_ast.Expr(
detail = pgast.Expr(
name='||',
lexpr=pg_ast.StringConstant(val='{"object_id": "'),
rexpr=pg_ast.Expr(
lexpr=pgast.StringConstant(val='{"object_id": "'),
rexpr=pgast.Expr(
name='||',
lexpr=pg_ast.ColumnRef(name=('id', )),
rexpr=pg_ast.StringConstant(val='"}'),
lexpr=pgast.ColumnRef(name=('id',)),
rexpr=pgast.StringConstant(val='"}'),
)
)
column = pg_ast.StringConstant(val=str(pointer.id))
column = pgast.StringConstant(val=str(pointer.id))

null_check = pg_ast.FuncCall(
null_check = pgast.FuncCall(
name=("edgedb", "raise_on_null"),
args=[
pg_ast.ColumnRef(name=("val", )),
pg_ast.StringConstant(val="not_null_violation"),
pg_ast.NamedFuncArg(name="msg", val=msg),
pg_ast.NamedFuncArg(name="detail", val=detail),
pg_ast.NamedFuncArg(name="column", val=column),
pgast.ColumnRef(name=("val",)),
pgast.StringConstant(val="not_null_violation"),
pgast.NamedFuncArg(name="msg", val=msg),
pgast.NamedFuncArg(name="detail", val=detail),
pgast.NamedFuncArg(name="column", val=column),
],
)

inner_colnames = ["val"]
target_list = [pg_ast.ResTarget(val=null_check)]
target_list = [pgast.ResTarget(val=null_check)]
if produce_ctes:
inner_colnames.append("id")
target_list.append(
pg_ast.ResTarget(val=pg_ast.ColumnRef(name=("id", )))
pgast.ResTarget(val=pgast.ColumnRef(name=("id",)))
)

sql_tree = pg_ast.SelectStmt(
sql_tree = pgast.SelectStmt(
target_list=target_list,
from_clause=[
pg_ast.RangeSubselect(
pgast.RangeSubselect(
subquery=sql_tree,
alias=pg_ast.Alias(
alias=pgast.Alias(
aliasname="_inner", colnames=inner_colnames
)
)
Expand All @@ -5040,11 +5040,13 @@ def _compile_conversion_expr(

if produce_ctes:
# convert root query into last CTE
ctes.append(pg_ast.CommonTableExpr(
name="_conv_rel",
aliascolnames=["val", "id"],
query=sql_tree
))
ctes.append(
pgast.CommonTableExpr(
name="_conv_rel",
aliascolnames=["val", "id"],
query=sql_tree,
)
)
# compile to SQL
ctes_sql = codegen.generate_ctes_source(ctes)

Expand Down
39 changes: 20 additions & 19 deletions edb/pgsql/schemamech.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from edb.common import ast
from edb.common import parsing

from . import ast as pg_ast
from . import ast as pgast
from . import dbops
from . import deltadbops
from . import common
Expand Down Expand Up @@ -106,34 +106,37 @@ class ExprDataSources:
plain_chunks: Sequence[str]


def _to_source(sql_expr: pg_ast.Base) -> str:
def _to_source(sql_expr: pgast.Base) -> str:
src = codegen.generate_source(sql_expr)
# ColumnRefs are the most common thing, and they should be safe to
# skip parenthesizing, for deuglification purposes. anything else
# we put parens around, to be sure.
if not isinstance(sql_expr, pg_ast.ColumnRef):
if not isinstance(sql_expr, pgast.ColumnRef):
src = f'({src})'
return src


def _edgeql_tree_to_expr_data(
sql_expr: pg_ast.Base, refs: Optional[Set[pg_ast.ColumnRef]] = None
sql_expr: pgast.Base, refs: Optional[Set[pgast.ColumnRef]] = None
) -> ExprDataSources:
if refs is None:
refs = set(ast.find_children(
sql_expr, pg_ast.ColumnRef, lambda n: len(n.name) == 1))
refs = set(
ast.find_children(
sql_expr, pgast.ColumnRef, lambda n: len(n.name) == 1
)
)

plain_expr = _to_source(sql_expr)

if isinstance(sql_expr, (pg_ast.RowExpr, pg_ast.ImplicitRowExpr)):
if isinstance(sql_expr, (pgast.RowExpr, pgast.ImplicitRowExpr)):
chunks = []

for elem in sql_expr.args:
chunks.append(_to_source(elem))
else:
chunks = [plain_expr]

if isinstance(sql_expr, pg_ast.ColumnRef):
if isinstance(sql_expr, pgast.ColumnRef):
refs.add(sql_expr)

for ref in refs:
Expand All @@ -158,12 +161,12 @@ def _edgeql_ref_to_pg_constr(
) -> ExprData:
sql_res = compiler.compile_ir_to_sql_tree(tree, singleton_mode=True)

sql_expr: pg_ast.Base
if isinstance(sql_res.ast, pg_ast.SelectStmt):
sql_expr: pgast.Base
if isinstance(sql_res.ast, pgast.SelectStmt):
# XXX: use ast pattern matcher for this
from_clause = sql_res.ast.from_clause[0]
assert isinstance(from_clause, pg_ast.RelRangeVar)
assert isinstance(from_clause.relation, pg_ast.CommonTableExpr)
assert isinstance(from_clause, pgast.RelRangeVar)
assert isinstance(from_clause.relation, pgast.CommonTableExpr)
sql_expr = from_clause.relation.query.target_list[0].val
else:
sql_expr = sql_res.ast
Expand All @@ -174,22 +177,20 @@ def _edgeql_ref_to_pg_constr(
if isinstance(tree, irast.Set) and isinstance(tree.expr, irast.SelectStmt):
tree = tree.expr.result

is_multicol = isinstance(sql_expr, (pg_ast.RowExpr, pg_ast.ImplicitRowExpr))
is_multicol = isinstance(sql_expr, (pgast.RowExpr, pgast.ImplicitRowExpr))

# Determine if the sequence of references are all simple refs, not
# expressions. This influences the type of Postgres constraint used.
#
is_trivial = isinstance(sql_expr, pg_ast.ColumnRef) or (
isinstance(sql_expr, (pg_ast.RowExpr, pg_ast.ImplicitRowExpr))
and all(isinstance(el, pg_ast.ColumnRef) for el in sql_expr.args)
is_trivial = isinstance(sql_expr, pgast.ColumnRef) or (
isinstance(sql_expr, (pgast.RowExpr, pgast.ImplicitRowExpr))
and all(isinstance(el, pgast.ColumnRef) for el in sql_expr.args)
)

# Find all field references
#
refs = set(
ast.find_children(
sql_expr, pg_ast.ColumnRef, lambda n: len(n.name) == 1
)
ast.find_children(sql_expr, pgast.ColumnRef, lambda n: len(n.name) == 1)
)

if isinstance(subject, s_scalars.ScalarType):
Expand Down

0 comments on commit 1930838

Please sign in to comment.