Skip to content

Commit

Permalink
share alias generator between resolver and the main compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Jun 24, 2024
1 parent cf7eb2a commit 5365dbd
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 21 deletions.
3 changes: 3 additions & 0 deletions edb/pgsql/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from . import dispatch
from . import dml
from . import pathctx
from . import aliases

from .context import OutputFormat as OutputFormat # NOQA

Expand Down Expand Up @@ -75,6 +76,7 @@ def compile_ir_to_sql_tree(
] = None,
backend_runtime_params: Optional[pgparams.BackendRuntimeParams]=None,
detach_params: bool = False,
alias_generator: Optional[aliases.AliasGenerator] = None
) -> CompileResult:
try:
# Transform to sql tree
Expand Down Expand Up @@ -110,6 +112,7 @@ def compile_ir_to_sql_tree(
backend_runtime_params = pgparams.get_default_runtime_params()

env = context.Environment(
alias_generator=alias_generator,
output_format=output_format,
expected_cardinality_one=expected_cardinality_one,
named_param_prefix=named_param_prefix,
Expand Down
3 changes: 2 additions & 1 deletion edb/pgsql/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ class Environment:
def __init__(
self,
*,
alias_generator: Optional[aliases.AliasGenerator] = None,
output_format: Optional[OutputFormat],
named_param_prefix: Optional[tuple[str, ...]],
expected_cardinality_one: bool,
Expand All @@ -534,7 +535,7 @@ def __init__(
# XXX: TRAMPOLINE: THIS IS WRONG
versioned_stdlib: bool = True,
) -> None:
self.aliases = aliases.AliasGenerator()
self.aliases = alias_generator or aliases.AliasGenerator()
self.output_format = output_format
self.named_param_prefix = named_param_prefix
self.ptrref_source_visibility = {}
Expand Down
31 changes: 18 additions & 13 deletions edb/pgsql/resolver/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from . import context
from . import expr as pg_res_expr
from . import relation as pg_res_rel
from . import range_var as pg_res_range_var

Context = context.ResolverContextLevel

Expand All @@ -56,7 +57,7 @@ def resolve_CopyStmt(stmt: pgast.CopyStmt, *, ctx: Context) -> pgast.CopyStmt:

elif stmt.relation:
relation, table = dispatch.resolve_relation(stmt.relation, ctx=ctx)
table.reference_as = ctx.names.get('rel')
table.reference_as = ctx.alias_generator.get('rel')

selected_columns = _pull_columns_from_table(
table,
Expand Down Expand Up @@ -162,8 +163,10 @@ def resolve_InsertStmt(
)

val_rel, val_table = compile_insert_value(
stmt.select_stmt, expected_columns, ctx
stmt.select_stmt, stmt.ctes, expected_columns, ctx
)
value_ctes = val_rel.ctes if val_rel.ctes else []
val_rel.ctes = None

# if we are sure that we are inserting a single row,
# we can skip for loops and the iterator, so we generate better SQL
Expand Down Expand Up @@ -234,7 +237,7 @@ def resolve_InsertStmt(
# for link ids.
assert isinstance(val_rel, pgast.Query)
source_cte = pgast.CommonTableExpr(
name=ctx.names.get('ins_source'),
name=ctx.alias_generator.get('ins_source'),
query=pgast.SelectStmt(
from_clause=[pgast.RangeSubselect(subquery=val_rel)],
target_list=pre_projection,
Expand All @@ -243,7 +246,7 @@ def resolve_InsertStmt(
)

# source needs an identity column, so we need to invent one
source_identity = ctx.names.get('identity')
source_identity = ctx.alias_generator.get('identity')
source_cte.query.target_list.append(
pgast.ResTarget(
name=source_identity,
Expand Down Expand Up @@ -317,6 +320,7 @@ def resolve_InsertStmt(
ir_stmt,
external_rels={source_id: (source_cte, ('source', 'identity'))},
output_format=pgcompiler.OutputFormat.NATIVE_INTERNAL,
alias_generator=ctx.alias_generator,
)
except errors.QueryError as e:
raise errors.QueryError(
Expand All @@ -329,7 +333,7 @@ def resolve_InsertStmt(

assert isinstance(sql_result.ast, pgast.Query)
assert sql_result.ast.ctes
ctes = [source_cte] + sql_result.ast.ctes
ctes = value_ctes + [source_cte] + sql_result.ast.ctes
sql_result.ast.ctes.clear()

if ctx.subquery_depth == 0:
Expand Down Expand Up @@ -363,6 +367,7 @@ def resolve_InsertStmt(

def compile_insert_value(
value_query: Optional[pgast.Query],
value_ctes: Optional[List[pgast.CommonTableExpr]],
expected_columns: List[context.Column],
ctx: context.ResolverContextLevel,
) -> Tuple[pgast.BaseRelation, context.Table]:
Expand Down Expand Up @@ -402,25 +407,25 @@ def is_default(e: pgast.BaseExpr) -> bool:
value_query.values[r_index] = row.replace(args=cols)

# INSERT INTO x DEFAULT VALUES
val_rel: pgast.BaseRelation
if value_query:
val_rel = value_query
else:
val_rel = pgast.SelectStmt(values=[])

value_query: pgast.BaseRelation
if not value_query:
value_query = pgast.SelectStmt(values=[])
# edgeql compiler will provide default values
# (and complain about missing ones)
expected_columns = []

# compile these CTEs as they were defined on value relation
value_query.ctes = value_ctes

# compile value that is to be inserted
val_rel, val_table = dispatch.resolve_relation(val_rel, ctx=ctx)
val_rel, val_table = dispatch.resolve_relation(value_query, ctx=ctx)

if len(expected_columns) != len(val_table.columns):
col_names = ', '.join(c.name for c in expected_columns)
raise errors.QueryError(
f'INSERT expected {len(expected_columns)} columns, '
f'but got {len(val_table.columns)} (expecting {col_names})',
span=val_rel.span,
span=value_query.span,
)

return val_rel, val_table
Expand Down
6 changes: 3 additions & 3 deletions edb/pgsql/resolver/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class ContextSwitchMode(enum.Enum):

class ResolverContextLevel(compiler.ContextLevel):
schema: s_schema.Schema
names: compiler.AliasGenerator
alias_generator: compiler.AliasGenerator

# Visible names in scope
scope: Scope
Expand Down Expand Up @@ -186,14 +186,14 @@ def __init__(
self.schema = schema
self.options = options
self.scope = Scope()
self.names = compiler.AliasGenerator()
self.alias_generator = compiler.AliasGenerator()
self.subquery_depth = 0
self.ctes_buffer = []

else:
self.schema = prevlevel.schema
self.options = prevlevel.options
self.names = prevlevel.names
self.alias_generator = prevlevel.alias_generator

self.subquery_depth = prevlevel.subquery_depth + 1
self.ctes_buffer = prevlevel.ctes_buffer
Expand Down
2 changes: 1 addition & 1 deletion edb/pgsql/resolver/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def resolve_ResTarget(
):
alias = static.name_in_pg_catalog(res_target.val.name)

name: str = alias or ctx.names.get('col')
name: str = alias or ctx.alias_generator.get('col')
col = context.Column(
name=name, kind=context.ColumnByName(reference_as=name)
)
Expand Down
6 changes: 3 additions & 3 deletions edb/pgsql/resolver/range_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def resolve_BaseRangeVar(
return _resolve_JoinExpr(range_var, ctx=ctx)

# generate internal alias
internal_alias = ctx.names.get('rel')
internal_alias = ctx.alias_generator.get('rel')
alias = pgast.Alias(
aliasname=internal_alias, colnames=range_var.alias.colnames
)
Expand Down Expand Up @@ -232,7 +232,7 @@ def resolve_CommonTableExpr(
aliascolnames = res

if cte.recursive and aliascolnames:
reference_as = [subctx.names.get('col') for _ in aliascolnames]
reference_as = [subctx.alias_generator.get('col') for _ in aliascolnames]
columns = [
context.Column(
name=col, kind=context.ColumnByName(reference_as=ref_as)
Expand Down Expand Up @@ -322,7 +322,7 @@ def _resolve_RangeFunction(
context.Column(
name=al or col.name,
kind=context.ColumnByName(
reference_as=al or ctx.names.get('col')
reference_as=al or ctx.alias_generator.get('col')
),
)
for col, al in _zip_column_alias(
Expand Down
29 changes: 29 additions & 0 deletions tests/test_sql_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,35 @@ async def test_sql_dml_insert_08(self):
'''
)

async def test_sql_dml_insert_07(self):
# insert with a CTE
await self.scon.execute(
'''
WITH a AS (
SELECT 'Report' as t UNION ALL SELECT 'Briefing'
)
INSERT INTO "Document" (title) SELECT * FROM a
'''
)
res = await self.squery_values('SELECT title FROM "Document"')
self.assertEqual(res, tb.bag([['Report (new)'], ['Briefing (new)']]))

async def test_sql_dml_insert_07(self):
# two inserts
await self.scon.execute(
'''
WITH a AS (
INSERT INTO "Document" (title) VALUES ('Report')
RETURNING title as t
)
INSERT INTO "Document" (title) SELECT t || ' - copy' FROM a
'''
)
res = await self.squery_values('SELECT title FROM "Document"')
self.assertEqual(
res, tb.bag([['Report (new)'], ['Report (new) - copy (new)']])
)

async def test_sql_dml_insert_09(self):
# returning
await self.scon.execute('INSERT INTO "User" DEFAULT VALUES;')
Expand Down

0 comments on commit 5365dbd

Please sign in to comment.