diff --git a/edb/pgsql/compiler/__init__.py b/edb/pgsql/compiler/__init__.py
index 75375274d0f4..b4e633877dd3 100644
--- a/edb/pgsql/compiler/__init__.py
+++ b/edb/pgsql/compiler/__init__.py
@@ -39,6 +39,7 @@
 from . import dispatch
 from . import dml
 from . import pathctx
+from . import aliases
 
 from .context import OutputFormat as OutputFormat  # NOQA
 
@@ -75,6 +76,7 @@ def compile_ir_to_sql_tree(
     ] = None,
     backend_runtime_params: Optional[pgparams.BackendRuntimeParams]=None,
     detach_params: bool = False,
+    alias_generator: Optional[aliases.AliasGenerator] = None
 ) -> CompileResult:
     try:
         # Transform to sql tree
@@ -110,6 +112,7 @@ def compile_ir_to_sql_tree(
             backend_runtime_params = pgparams.get_default_runtime_params()
 
         env = context.Environment(
+            alias_generator=alias_generator,
             output_format=output_format,
             expected_cardinality_one=expected_cardinality_one,
             named_param_prefix=named_param_prefix,
diff --git a/edb/pgsql/compiler/context.py b/edb/pgsql/compiler/context.py
index 7b101fcfd0b6..112b0a609753 100644
--- a/edb/pgsql/compiler/context.py
+++ b/edb/pgsql/compiler/context.py
@@ -517,6 +517,7 @@ class Environment:
     def __init__(
         self,
         *,
+        alias_generator: Optional[aliases.AliasGenerator] = None,
         output_format: Optional[OutputFormat],
         named_param_prefix: Optional[tuple[str, ...]],
         expected_cardinality_one: bool,
@@ -534,7 +535,7 @@ def __init__(
         # XXX: TRAMPOLINE: THIS IS WRONG
         versioned_stdlib: bool = True,
     ) -> None:
-        self.aliases = aliases.AliasGenerator()
+        self.aliases = alias_generator or aliases.AliasGenerator()
         self.output_format = output_format
         self.named_param_prefix = named_param_prefix
         self.ptrref_source_visibility = {}
diff --git a/edb/pgsql/resolver/command.py b/edb/pgsql/resolver/command.py
index 9910e138a3d1..daa005f74bf3 100644
--- a/edb/pgsql/resolver/command.py
+++ b/edb/pgsql/resolver/command.py
@@ -42,6 +42,7 @@
 from . import context
 from . import expr as pg_res_expr
 from . import relation as pg_res_rel
+from . import range_var as pg_res_range_var
 
 Context = context.ResolverContextLevel
 
@@ -56,7 +57,7 @@ def resolve_CopyStmt(stmt: pgast.CopyStmt, *, ctx: Context) -> pgast.CopyStmt:
 
     elif stmt.relation:
         relation, table = dispatch.resolve_relation(stmt.relation, ctx=ctx)
-        table.reference_as = ctx.names.get('rel')
+        table.reference_as = ctx.alias_generator.get('rel')
 
         selected_columns = _pull_columns_from_table(
             table,
@@ -162,8 +163,10 @@ def resolve_InsertStmt(
     )
 
     val_rel, val_table = compile_insert_value(
-        stmt.select_stmt, expected_columns, ctx
+        stmt.select_stmt, stmt.ctes, expected_columns, ctx
     )
+    value_ctes = val_rel.ctes if val_rel.ctes else []
+    val_rel.ctes = None
 
     # if we are sure that we are inserting a single row,
     # we can skip for loops and the iterator, so we generate better SQL
@@ -234,7 +237,7 @@ def resolve_InsertStmt(
     # for link ids.
     assert isinstance(val_rel, pgast.Query)
     source_cte = pgast.CommonTableExpr(
-        name=ctx.names.get('ins_source'),
+        name=ctx.alias_generator.get('ins_source'),
         query=pgast.SelectStmt(
             from_clause=[pgast.RangeSubselect(subquery=val_rel)],
             target_list=pre_projection,
@@ -243,7 +246,7 @@
     )
 
     # source needs an identity column, so we need to invent one
-    source_identity = ctx.names.get('identity')
+    source_identity = ctx.alias_generator.get('identity')
     source_cte.query.target_list.append(
         pgast.ResTarget(
             name=source_identity,
@@ -317,6 +320,7 @@
             ir_stmt,
             external_rels={source_id: (source_cte, ('source', 'identity'))},
             output_format=pgcompiler.OutputFormat.NATIVE_INTERNAL,
+            alias_generator=ctx.alias_generator,
         )
     except errors.QueryError as e:
         raise errors.QueryError(
@@ -329,7 +333,7 @@
 
     assert isinstance(sql_result.ast, pgast.Query)
     assert sql_result.ast.ctes
-    ctes = [source_cte] + sql_result.ast.ctes
+    ctes = value_ctes + [source_cte] + sql_result.ast.ctes
     sql_result.ast.ctes.clear()
 
     if ctx.subquery_depth == 0:
@@ -363,6 +367,7 @@ def resolve_InsertStmt(
 
 def compile_insert_value(
     value_query: Optional[pgast.Query],
+    value_ctes: Optional[List[pgast.CommonTableExpr]],
     expected_columns: List[context.Column],
     ctx: context.ResolverContextLevel,
 ) -> Tuple[pgast.BaseRelation, context.Table]:
@@ -402,25 +407,25 @@ def is_default(e: pgast.BaseExpr) -> bool:
             value_query.values[r_index] = row.replace(args=cols)
 
     # INSERT INTO x DEFAULT VALUES
-    val_rel: pgast.BaseRelation
-    if value_query:
-        val_rel = value_query
-    else:
-        val_rel = pgast.SelectStmt(values=[])
-
+    value_query: pgast.BaseRelation
+    if not value_query:
+        value_query = pgast.SelectStmt(values=[])
         # edgeql compiler will provide default values
         # (and complain about missing ones)
        expected_columns = []
 
+    # compile these CTEs as if they were defined on the value relation
+    value_query.ctes = value_ctes
+
     # compile value that is to be inserted
-    val_rel, val_table = dispatch.resolve_relation(val_rel, ctx=ctx)
+    val_rel, val_table = dispatch.resolve_relation(value_query, ctx=ctx)
 
     if len(expected_columns) != len(val_table.columns):
         col_names = ', '.join(c.name for c in expected_columns)
         raise errors.QueryError(
             f'INSERT expected {len(expected_columns)} columns, '
             f'but got {len(val_table.columns)} (expecting {col_names})',
-            span=val_rel.span,
+            span=value_query.span,
         )
 
     return val_rel, val_table
diff --git a/edb/pgsql/resolver/context.py b/edb/pgsql/resolver/context.py
index 90181e55825f..f22bd2134a9c 100644
--- a/edb/pgsql/resolver/context.py
+++ b/edb/pgsql/resolver/context.py
@@ -149,7 +149,7 @@ class ContextSwitchMode(enum.Enum):
 
 class ResolverContextLevel(compiler.ContextLevel):
     schema: s_schema.Schema
-    names: compiler.AliasGenerator
+    alias_generator: compiler.AliasGenerator
 
     # Visible names in scope
     scope: Scope
@@ -186,14 +186,14 @@ def __init__(
             self.schema = schema
             self.options = options
             self.scope = Scope()
-            self.names = compiler.AliasGenerator()
+            self.alias_generator = compiler.AliasGenerator()
             self.subquery_depth = 0
             self.ctes_buffer = []
 
         else:
             self.schema = prevlevel.schema
             self.options = prevlevel.options
-            self.names = prevlevel.names
+            self.alias_generator = prevlevel.alias_generator
             self.subquery_depth = prevlevel.subquery_depth + 1
             self.ctes_buffer = prevlevel.ctes_buffer
 
diff --git a/edb/pgsql/resolver/expr.py b/edb/pgsql/resolver/expr.py
index 369b099a1bca..9f85c0c7fd35 100644
--- a/edb/pgsql/resolver/expr.py
+++ b/edb/pgsql/resolver/expr.py
@@ -101,7 +101,7 @@ def resolve_ResTarget(
     ):
         alias = static.name_in_pg_catalog(res_target.val.name)
-        name: str = alias or ctx.names.get('col')
+        name: str = alias or ctx.alias_generator.get('col')
         col = context.Column(
             name=name, kind=context.ColumnByName(reference_as=name)
         )
 
diff --git a/edb/pgsql/resolver/range_var.py b/edb/pgsql/resolver/range_var.py
index 74f0fdc3a236..9db657189482 100644
--- a/edb/pgsql/resolver/range_var.py
+++ b/edb/pgsql/resolver/range_var.py
@@ -45,7 +45,7 @@ def resolve_BaseRangeVar(
         return _resolve_JoinExpr(range_var, ctx=ctx)
 
     # generate internal alias
-    internal_alias = ctx.names.get('rel')
+    internal_alias = ctx.alias_generator.get('rel')
     alias = pgast.Alias(
         aliasname=internal_alias, colnames=range_var.alias.colnames
     )
@@ -232,7 +232,7 @@ def resolve_CommonTableExpr(
         aliascolnames = res
     if cte.recursive and aliascolnames:
-        reference_as = [subctx.names.get('col') for _ in aliascolnames]
+        reference_as = [subctx.alias_generator.get('col') for _ in aliascolnames]
         columns = [
             context.Column(
                 name=col, kind=context.ColumnByName(reference_as=ref_as)
             )
@@ -322,7 +322,7 @@ def _resolve_RangeFunction(
         context.Column(
             name=al or col.name,
             kind=context.ColumnByName(
-                reference_as=al or ctx.names.get('col')
+                reference_as=al or ctx.alias_generator.get('col')
             ),
         )
         for col, al in _zip_column_alias(
diff --git a/tests/test_sql_dml.py b/tests/test_sql_dml.py
index 487b867a4057..596e48b05e05 100644
--- a/tests/test_sql_dml.py
+++ b/tests/test_sql_dml.py
@@ -189,6 +189,35 @@ async def test_sql_dml_insert_08(self):
             '''
         )
 
+    async def test_sql_dml_insert_07(self):
+        # insert with a CTE
+        await self.scon.execute(
+            '''
+            WITH a AS (
+                SELECT 'Report' as t UNION ALL SELECT 'Briefing'
+            )
+            INSERT INTO "Document" (title) SELECT * FROM a
+            '''
+        )
+        res = await self.squery_values('SELECT title FROM "Document"')
+        self.assertEqual(res, tb.bag([['Report (new)'], ['Briefing (new)']]))
+
+    async def test_sql_dml_insert_07a(self):
+        # two inserts: a CTE that is itself an INSERT feeding another INSERT
+        await self.scon.execute(
+            '''
+            WITH a AS (
+                INSERT INTO "Document" (title) VALUES ('Report')
+                RETURNING title as t
+            )
+            INSERT INTO "Document" (title) SELECT t || ' - copy' FROM a
+            '''
+        )
+        res = await self.squery_values('SELECT title FROM "Document"')
+        self.assertEqual(
+            res, tb.bag([['Report (new)'], ['Report (new) - copy (new)']])
+        )
+
     async def test_sql_dml_insert_09(self):
         # returning
         await self.scon.execute('INSERT INTO "User" DEFAULT VALUES;')
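
The net effect of this patch is that the SQL resolver and the EdgeQL-to-SQL compiler draw aliases from one shared `AliasGenerator`, which is what makes it safe to splice the resolver's hoisted value CTEs and its `ins_source` CTE in front of the compiler-generated CTEs without name collisions. The snippet below is a minimal stand-alone sketch of that sharing pattern only; `AliasGenerator` and `Environment` here are simplified stand-ins for illustration, not the actual `edb.pgsql` classes.

```python
import itertools
from typing import Optional


class AliasGenerator:
    """Hands out unique names such as rel_1, ins_source_2, col_3, ..."""

    def __init__(self) -> None:
        self._counter = itertools.count(1)

    def get(self, hint: str) -> str:
        return f'{hint}_{next(self._counter)}'


class Environment:
    # Mirrors the change to compiler.context.Environment: reuse the caller's
    # generator when one is supplied, otherwise create a private one.
    def __init__(self, *, alias_generator: Optional[AliasGenerator] = None):
        self.aliases = alias_generator or AliasGenerator()


if __name__ == '__main__':
    shared = AliasGenerator()

    # The resolver allocates names first (e.g. for the INSERT source CTE)...
    resolver_rel = shared.get('rel')
    source_cte_name = shared.get('ins_source')

    # ...and the compiler then draws from the same sequence, so the CTEs it
    # emits can be concatenated with the resolver's without collisions.
    env = Environment(alias_generator=shared)
    compiler_rel = env.aliases.get('rel')

    assert len({resolver_rel, source_cte_name, compiler_rel}) == 3
    print(resolver_rel, source_cte_name, compiler_rel)  # rel_1 ins_source_2 rel_3
```

The `alias_generator or aliases.AliasGenerator()` fallback keeps every existing caller of `Environment` working unchanged while letting the resolver opt in to the shared sequence.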