From c6a7040b0c950ff47c4ee5086990a5f61fe67f35 Mon Sep 17 00:00:00 2001
From: Elvis Pranskevichus
Date: Tue, 12 Nov 2024 11:34:06 -0800
Subject: [PATCH] Remove multiplicity from `QueryUnit.sql`

Currently, `QueryUnit.sql` is a tuple that can represent multiple SQL
statements corresponding to the original EdgeQL statement.  This
introduces significant complexity to consumers of `QueryUnit`, which are
mostly unprepared to handle more than one SQL statement anyway.

Originally the SQL tuple was used to represent multiple EdgeQL
statements, a task that is now handled by the `QueryUnitGroup`
machinery.  Another use case is the non-transactional commands
(`CREATE BRANCH` and friends).  Finally, since this facility was
available, more uses of it were added that did not actually _need_ to be
executed as multiple SQL statements: those are mostly maintenance
commands and the DDL type_id readback.

I think the above uses are no longer a sufficient reason to keep the
tuple complexity, so I'm ripping it out here and making `QueryUnit.sql`
a `bytes` property with the invariant of _always_ containing exactly one
SQL statement, thus not needing any special handling.

The users are fixed up as follows:

- Non-transactional branch units encode the extra SQL needed to
  represent them in the new `QueryUnit.db_op_trailer` property, which is
  a tuple of SQL statements.  Branch commands already have special
  handling, so this is not a nuisance.

- The newly-added type id mappings produced by DDL commands are now
  communicated via the new "indirect return" mechanism, whereby a DDL
  PLBlock can communicate a return value via a specially-formatted
  `NoticeResponse` message.

- All other unwitting users of multi-statement `QueryUnit` are converted
  to use a single statement instead (primarily by switching them to a
  `DO` block).
---
 edb/buildmeta.py                |   2 +-
 edb/edgeql/ast.py               |   7 +-
 edb/pgsql/dbops/base.py         |  29 ++
 edb/pgsql/delta.py              |   7 +-
 edb/pgsql/metaschema.py         |  49 ++-
 edb/server/bootstrap.py         |  18 +-
 edb/server/compiler/compiler.py | 553 +++++++++++++++++---------------
 edb/server/compiler/dbstate.py  | 157 ++++-----
 edb/server/compiler/ddl.py      | 135 ++++----
 edb/server/pgcluster.py         |   6 +-
 edb/server/pgcon/pgcon.pxd      |   2 +
 edb/server/pgcon/pgcon.pyi      |   2 +-
 edb/server/pgcon/pgcon.pyx      | 187 ++++++-----
 edb/server/protocol/binary.pyx  |   3 +-
 edb/server/protocol/execute.pyx |  90 +++---
 edb/server/server.py            |   5 +-
 16 files changed, 683 insertions(+), 569 deletions(-)

diff --git a/edb/buildmeta.py b/edb/buildmeta.py
index 07c602cd2d6..cfa5ff88228 100644
--- a/edb/buildmeta.py
+++ b/edb/buildmeta.py
@@ -60,7 +60,7 @@
 # The merge conflict there is a nice reminder that you probably need
 # to write a patch in edb/pgsql/patches.py, and then you should preserve
 # the old value.
-EDGEDB_CATALOG_VERSION = 2024_11_08_01_00 +EDGEDB_CATALOG_VERSION = 2024_11_12_00_00 EDGEDB_MAJOR_VERSION = 6 diff --git a/edb/edgeql/ast.py b/edb/edgeql/ast.py index eb74f85fe5d..ec919e1a6b1 100644 --- a/edb/edgeql/ast.py +++ b/edb/edgeql/ast.py @@ -674,6 +674,11 @@ class DDLCommand(Command, DDLOperation): __abstract_node__ = True +class NonTransactionalDDLCommand(DDLCommand): + __abstract_node__ = True + __rust_ignore__ = True + + class AlterAddInherit(DDLOperation): position: typing.Optional[Position] = None bases: typing.List[TypeName] @@ -868,7 +873,7 @@ class BranchType(s_enum.StrEnum): TEMPLATE = 'TEMPLATE' -class DatabaseCommand(ExternalObjectCommand): +class DatabaseCommand(ExternalObjectCommand, NonTransactionalDDLCommand): __abstract_node__ = True __rust_ignore__ = True diff --git a/edb/pgsql/dbops/base.py b/edb/pgsql/dbops/base.py index 9dde6621435..cba2c162e4d 100644 --- a/edb/pgsql/dbops/base.py +++ b/edb/pgsql/dbops/base.py @@ -504,6 +504,35 @@ def __repr__(self) -> str: return f'' +class PLQuery(Command): + def __init__( + self, + text: str, + *, + type: Optional[str | Tuple[str, str]] = None, + trampoline_fixup: bool = True, + ) -> None: + from ..import trampoline + + super().__init__() + if trampoline_fixup: + text = trampoline.fixup_query(text) + self.text = text + self.type = type + + def to_sql_expr(self) -> str: + if self.type: + return f'({self.text})::{qn(*self.type)}' + else: + return self.text + + def code_with_block(self, block: PLBlock) -> str: + return self.text + + def __repr__(self) -> str: + return f'' + + class DefaultMeta(type): def __bool__(cls): return False diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index 817054691ab..d5f22d51d4e 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -2953,7 +2953,7 @@ def _alter_finalize( return schema -def drop_dependant_func_cache(pg_type: Tuple[str, ...]) -> dbops.Query: +def drop_dependant_func_cache(pg_type: Tuple[str, ...]) -> dbops.PLQuery: if len(pg_type) == 1: types_cte = f''' SELECT @@ -2980,7 +2980,6 @@ def drop_dependant_func_cache(pg_type: Tuple[str, ...]) -> dbops.Query: )\ ''' drop_func_cache_sql = textwrap.dedent(f''' - DO $$ DECLARE qc RECORD; BEGIN @@ -3014,9 +3013,9 @@ class AS ( LOOP PERFORM edgedb_VER."_evict_query_cache"(qc.key); END LOOP; - END $$; + END; ''') - return dbops.Query(drop_func_cache_sql) + return dbops.PLQuery(drop_func_cache_sql) class DeleteScalarType(ScalarTypeMetaCommand, diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 0146429b28a..9630a3c6868 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -105,13 +105,13 @@ class PGConnection(Protocol): async def sql_execute( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, ) -> None: ... async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] 
| list[bytes] = (), ) -> list[tuple[bytes, ...]]: @@ -1497,6 +1497,50 @@ def __init__(self) -> None: ) +class RaiseNoticeFunction(trampoline.VersionedFunction): + text = ''' + BEGIN + RAISE NOTICE USING + MESSAGE = "msg", + DETAIL = COALESCE("detail", ''), + HINT = COALESCE("hint", ''), + COLUMN = COALESCE("column", ''), + CONSTRAINT = COALESCE("constraint", ''), + DATATYPE = COALESCE("datatype", ''), + TABLE = COALESCE("table", ''), + SCHEMA = COALESCE("schema", ''); + RETURN "rtype"; + END; + ''' + + def __init__(self) -> None: + super().__init__( + name=('edgedb', 'notice'), + args=[ + ('rtype', ('anyelement',)), + ('msg', ('text',), "''"), + ('detail', ('text',), "''"), + ('hint', ('text',), "''"), + ('column', ('text',), "''"), + ('constraint', ('text',), "''"), + ('datatype', ('text',), "''"), + ('table', ('text',), "''"), + ('schema', ('text',), "''"), + ], + returns=('anyelement',), + # NOTE: The main reason why we don't want this function to be + # immutable is that immutable functions can be + # pre-evaluated by the query planner once if they have + # constant arguments. This means that using this function + # as the second argument in a COALESCE will raise a + # notice regardless of whether the first argument is + # NULL or not. + volatility='stable', + language='plpgsql', + text=self.text, + ) + + class RaiseExceptionFunction(trampoline.VersionedFunction): text = ''' BEGIN @@ -4980,6 +5024,7 @@ def get_bootstrap_commands( dbops.CreateFunction(GetSharedObjectMetadata()), dbops.CreateFunction(GetDatabaseMetadataFunction()), dbops.CreateFunction(GetCurrentDatabaseFunction()), + dbops.CreateFunction(RaiseNoticeFunction()), dbops.CreateFunction(RaiseExceptionFunction()), dbops.CreateFunction(RaiseExceptionOnNullFunction()), dbops.CreateFunction(RaiseExceptionOnNotNullFunction()), diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index c829549bd79..153a225251f 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -173,7 +173,7 @@ async def _retry_conn_errors( return result - async def sql_execute(self, sql: bytes | tuple[bytes, ...]) -> None: + async def sql_execute(self, sql: bytes) -> None: async def _task() -> None: assert self._conn is not None await self._conn.sql_execute(sql) @@ -181,7 +181,7 @@ async def _task() -> None: async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] 
| list[bytes] = (), ) -> list[tuple[bytes, ...]]: @@ -634,8 +634,8 @@ def compile_single_query( ) -> str: ql_source = edgeql.Source.from_string(eql) units = edbcompiler.compile(ctx=compilerctx, source=ql_source).units - assert len(units) == 1 and len(units[0].sql) == 1 - return units[0].sql[0].decode() + assert len(units) == 1 + return units[0].sql.decode() def _get_all_subcommands( @@ -687,7 +687,7 @@ def prepare_repair_patch( schema_class_layout: s_refl.SchemaClassLayout, backend_params: params.BackendRuntimeParams, config: Any, -) -> tuple[bytes, ...]: +) -> bytes: compiler = edbcompiler.new_compiler( std_schema=stdschema, reflection_schema=reflschema, @@ -701,7 +701,7 @@ def prepare_repair_patch( ) res = edbcompiler.repair_schema(compilerctx) if not res: - return () + return b"" sql, _, _ = res return sql @@ -2111,10 +2111,10 @@ def compile_sys_queries( ), source=edgeql.Source.from_string(report_configs_query), ).units - assert len(units) == 1 and len(units[0].sql) == 1 + assert len(units) == 1 report_configs_typedesc_2_0 = units[0].out_type_id + units[0].out_type_data - queries['report_configs'] = units[0].sql[0].decode() + queries['report_configs'] = units[0].sql.decode() units = edbcompiler.compile( ctx=edbcompiler.new_compiler_context( @@ -2128,7 +2128,7 @@ def compile_sys_queries( ), source=edgeql.Source.from_string(report_configs_query), ).units - assert len(units) == 1 and len(units[0].sql) == 1 + assert len(units) == 1 report_configs_typedesc_1_0 = units[0].out_type_id + units[0].out_type_data return ( diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 5813181485c..59993239f40 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -429,7 +429,7 @@ def _try_compile_rollback( sql = b'ROLLBACK;' unit = dbstate.QueryUnit( status=b'ROLLBACK', - sql=(sql,), + sql=sql, tx_rollback=True, cacheable=False) @@ -437,7 +437,7 @@ def _try_compile_rollback( sql = f'ROLLBACK TO {pg_common.quote_ident(stmt.name)};'.encode() unit = dbstate.QueryUnit( status=b'ROLLBACK TO SAVEPOINT', - sql=(sql,), + sql=sql, tx_savepoint_rollback=True, sp_name=stmt.name, cacheable=False) @@ -1316,12 +1316,11 @@ def _compile_schema_storage_stmt( sql_stmts = [] for u in unit_group: - for stmt in u.sql: - stmt = stmt.strip() - if not stmt.endswith(b';'): - stmt += b';' + stmt = u.sql.strip() + if not stmt.endswith(b';'): + stmt += b';' - sql_stmts.append(stmt) + sql_stmts.append(stmt) if len(sql_stmts) > 1: raise errors.InternalServerError( @@ -1356,12 +1355,11 @@ def _compile_ql_script( sql_stmts = [] for u in unit_group: - for stmt in u.sql: - stmt = stmt.strip() - if not stmt.endswith(b';'): - stmt += b';' + stmt = u.sql.strip() + if not stmt.endswith(b';'): + stmt += b';' - sql_stmts.append(stmt) + sql_stmts.append(stmt) return b'\n'.join(sql_stmts).decode() @@ -1485,7 +1483,7 @@ def _compile_ql_explain( span=ql.span, ) - assert len(query.sql) == 1, query.sql + assert query.sql out_type_data, out_type_id = sertypes.describe( schema, @@ -1493,7 +1491,7 @@ def _compile_ql_explain( protocol_version=ctx.protocol_version, ) - sql_bytes = exp_command.encode('utf-8') + query.sql[0] + sql_bytes = exp_command.encode('utf-8') + query.sql sql_hash = _hash_sql( sql_bytes, mode=str(ctx.output_format).encode(), @@ -1503,9 +1501,9 @@ def _compile_ql_explain( return dataclasses.replace( query, is_explain=True, - append_rollback=args['execute'], + run_and_rollback=args['execute'], cacheable=False, - sql=(sql_bytes,), + sql=sql_bytes, sql_hash=sql_hash, 
cardinality=enums.Cardinality.ONE, out_type_data=out_type_data, @@ -1531,7 +1529,7 @@ def _compile_ql_administer( span=ql.expr.span, ) - return dbstate.MaintenanceQuery(sql=(b'ANALYZE',)) + return dbstate.MaintenanceQuery(sql=b'ANALYZE') elif ql.expr.func == 'schema_repair': return ddl.administer_repair_schema(ctx, ql) elif ql.expr.func == 'reindex': @@ -1692,7 +1690,7 @@ def _compile_ql_query( query_asts = None return dbstate.Query( - sql=(sql_bytes,), + sql=sql_bytes, sql_hash=sql_hash, cache_sql=cache_sql, cache_func_call=cache_func_call, @@ -1885,7 +1883,7 @@ def _compile_ql_transaction( if ql.deferrable is not None: sqls += f' {ql.deferrable.value}' sqls += ';' - sql = (sqls.encode(),) + sql = sqls.encode() action = dbstate.TxAction.START cacheable = False @@ -1901,7 +1899,7 @@ def _compile_ql_transaction( new_state = ctx.state.commit_tx() modaliases = new_state.modaliases - sql = (b'COMMIT',) + sql = b'COMMIT' cacheable = False action = dbstate.TxAction.COMMIT @@ -1909,7 +1907,7 @@ def _compile_ql_transaction( new_state = ctx.state.rollback_tx() modaliases = new_state.modaliases - sql = (b'ROLLBACK',) + sql = b'ROLLBACK' cacheable = False action = dbstate.TxAction.ROLLBACK @@ -1918,7 +1916,7 @@ def _compile_ql_transaction( sp_id = tx.declare_savepoint(ql.name) pgname = pg_common.quote_ident(ql.name) - sql = (f'SAVEPOINT {pgname}'.encode(),) + sql = f'SAVEPOINT {pgname}'.encode() cacheable = False action = dbstate.TxAction.DECLARE_SAVEPOINT @@ -1928,7 +1926,7 @@ def _compile_ql_transaction( elif isinstance(ql, qlast.ReleaseSavepoint): ctx.state.current_tx().release_savepoint(ql.name) pgname = pg_common.quote_ident(ql.name) - sql = (f'RELEASE SAVEPOINT {pgname}'.encode(),) + sql = f'RELEASE SAVEPOINT {pgname}'.encode() action = dbstate.TxAction.RELEASE_SAVEPOINT elif isinstance(ql, qlast.RollbackToSavepoint): @@ -1937,7 +1935,7 @@ def _compile_ql_transaction( modaliases = new_state.modaliases pgname = pg_common.quote_ident(ql.name) - sql = (f'ROLLBACK TO SAVEPOINT {pgname};'.encode(),) + sql = f'ROLLBACK TO SAVEPOINT {pgname};'.encode() cacheable = False action = dbstate.TxAction.ROLLBACK_TO_SAVEPOINT sp_name = ql.name @@ -1996,9 +1994,7 @@ def _compile_ql_sess_state( ctx.state.current_tx().update_modaliases(aliases) - return dbstate.SessionStateQuery( - sql=(), - ) + return dbstate.SessionStateQuery() def _get_config_spec( @@ -2170,9 +2166,7 @@ def _compile_ql_config_op( if pretty: debug.dump_code(sql_text, lexer='sql') - sql: tuple[bytes, ...] = ( - sql_text.encode(), - ) + sql = sql_text.encode() in_type_args, in_type_data, in_type_id = describe_params( ctx, ir, sql_res.argmap, None @@ -2363,7 +2357,6 @@ def _try_compile( if text.startswith(sentinel): time.sleep(float(text[len(sentinel):text.index("\n")])) - default_cardinality = enums.Cardinality.NO_RESULT statements = edgeql.parse_block(source) statements_len = len(statements) @@ -2399,15 +2392,6 @@ def _try_compile( _check_force_database_error(stmt_ctx, stmt) - # Initialize user_schema_version with the version this query is - # going to be compiled upon. This can be overwritten later by DDLs. 
- try: - schema_version = _get_schema_version( - stmt_ctx.state.current_tx().get_user_schema() - ) - except errors.InvalidReferenceError: - schema_version = None - comp, capabilities = _compile_dispatch_ql( stmt_ctx, stmt, @@ -2416,234 +2400,21 @@ def _try_compile( in_script=is_script, ) - unit = dbstate.QueryUnit( - sql=(), - status=status.get_status(stmt), - cardinality=default_cardinality, + unit, user_schema = _make_query_unit( + ctx=ctx, + stmt_ctx=stmt_ctx, + stmt=stmt, + is_script=is_script, + is_trailing_stmt=is_trailing_stmt, + comp=comp, capabilities=capabilities, - output_format=stmt_ctx.output_format, - cache_key=ctx.cache_key, - user_schema_version=schema_version, - warnings=comp.warnings, ) - if not comp.is_transactional: - if is_script: - raise errors.QueryError( - f'cannot execute {status.get_status(stmt).decode()} ' - f'with other commands in one block', - span=stmt.span, - ) - - if not ctx.state.current_tx().is_implicit(): - raise errors.QueryError( - f'cannot execute {status.get_status(stmt).decode()} ' - f'in a transaction', - span=stmt.span, - ) - - unit.is_transactional = False - - if isinstance(comp, dbstate.Query): - unit.sql = comp.sql - unit.cache_sql = comp.cache_sql - unit.cache_func_call = comp.cache_func_call - unit.globals = comp.globals - unit.in_type_args = comp.in_type_args - - unit.sql_hash = comp.sql_hash - - unit.out_type_data = comp.out_type_data - unit.out_type_id = comp.out_type_id - unit.in_type_data = comp.in_type_data - unit.in_type_id = comp.in_type_id - - unit.cacheable = comp.cacheable - - if comp.is_explain: - unit.is_explain = True - unit.query_asts = comp.query_asts - - if comp.append_rollback: - unit.append_rollback = True - - if is_trailing_stmt: - unit.cardinality = comp.cardinality - - elif isinstance(comp, dbstate.SimpleQuery): - unit.sql = comp.sql - unit.in_type_args = comp.in_type_args - - elif isinstance(comp, dbstate.DDLQuery): - unit.sql = comp.sql - unit.create_db = comp.create_db - unit.drop_db = comp.drop_db - unit.drop_db_reset_connections = comp.drop_db_reset_connections - unit.create_db_template = comp.create_db_template - unit.create_db_mode = comp.create_db_mode - unit.ddl_stmt_id = comp.ddl_stmt_id - if not ctx.dump_restore_mode: - if comp.user_schema is not None: - final_user_schema = comp.user_schema - unit.user_schema = pickle.dumps(comp.user_schema, -1) - unit.user_schema_version = ( - _get_schema_version(comp.user_schema) - ) - unit.extensions, unit.ext_config_settings = ( - _extract_extensions(ctx, comp.user_schema) - ) - unit.feature_used_metrics = comp.feature_used_metrics - if comp.cached_reflection is not None: - unit.cached_reflection = \ - pickle.dumps(comp.cached_reflection, -1) - if comp.global_schema is not None: - unit.global_schema = pickle.dumps(comp.global_schema, -1) - unit.roles = _extract_roles(comp.global_schema) - - unit.config_ops.extend(comp.config_ops) - - elif isinstance(comp, dbstate.TxControlQuery): - if is_script: - raise errors.QueryError( - "Explicit transaction control commands cannot be executed " - "in an implicit transaction block" - ) - unit.sql = comp.sql - unit.cacheable = comp.cacheable - - if not ctx.dump_restore_mode: - if comp.user_schema is not None: - final_user_schema = comp.user_schema - unit.user_schema = pickle.dumps(comp.user_schema, -1) - unit.user_schema_version = ( - _get_schema_version(comp.user_schema) - ) - unit.extensions, unit.ext_config_settings = ( - _extract_extensions(ctx, comp.user_schema) - ) - unit.feature_used_metrics = comp.feature_used_metrics - if 
comp.cached_reflection is not None: - unit.cached_reflection = \ - pickle.dumps(comp.cached_reflection, -1) - if comp.global_schema is not None: - unit.global_schema = pickle.dumps(comp.global_schema, -1) - unit.roles = _extract_roles(comp.global_schema) - - if comp.modaliases is not None: - unit.modaliases = comp.modaliases - - if comp.action == dbstate.TxAction.START: - if unit.tx_id is not None: - raise errors.InternalServerError( - 'already in transaction') - unit.tx_id = ctx.state.current_tx().id - elif comp.action == dbstate.TxAction.COMMIT: - unit.tx_commit = True - elif comp.action == dbstate.TxAction.ROLLBACK: - unit.tx_rollback = True - elif comp.action is dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: - unit.tx_savepoint_rollback = True - unit.sp_name = comp.sp_name - elif comp.action is dbstate.TxAction.DECLARE_SAVEPOINT: - unit.tx_savepoint_declare = True - unit.sp_name = comp.sp_name - unit.sp_id = comp.sp_id - - elif isinstance(comp, dbstate.MigrationControlQuery): - unit.sql = comp.sql - unit.cacheable = comp.cacheable - - if not ctx.dump_restore_mode: - if comp.user_schema is not None: - final_user_schema = comp.user_schema - unit.user_schema = pickle.dumps(comp.user_schema, -1) - unit.user_schema_version = ( - _get_schema_version(comp.user_schema) - ) - unit.extensions, unit.ext_config_settings = ( - _extract_extensions(ctx, comp.user_schema) - ) - if comp.cached_reflection is not None: - unit.cached_reflection = \ - pickle.dumps(comp.cached_reflection, -1) - unit.ddl_stmt_id = comp.ddl_stmt_id - - if comp.modaliases is not None: - unit.modaliases = comp.modaliases - - if comp.tx_action == dbstate.TxAction.START: - if unit.tx_id is not None: - raise errors.InternalServerError( - 'already in transaction') - unit.tx_id = ctx.state.current_tx().id - elif comp.tx_action == dbstate.TxAction.COMMIT: - unit.tx_commit = True - elif comp.tx_action == dbstate.TxAction.ROLLBACK: - unit.tx_rollback = True - elif comp.action == dbstate.MigrationAction.ABORT: - unit.tx_abort_migration = True - - elif isinstance(comp, dbstate.SessionStateQuery): - unit.sql = comp.sql - unit.globals = comp.globals - - if comp.config_scope is qltypes.ConfigScope.INSTANCE: - if (not ctx.state.current_tx().is_implicit() or - statements_len > 1): - raise errors.QueryError( - 'CONFIGURE INSTANCE cannot be executed in a ' - 'transaction block') - - unit.system_config = True - elif comp.config_scope is qltypes.ConfigScope.GLOBAL: - unit.needs_readback = True - - elif comp.config_scope is qltypes.ConfigScope.DATABASE: - unit.database_config = True - unit.needs_readback = True - - if comp.is_backend_setting: - unit.backend_config = True - if comp.requires_restart: - unit.config_requires_restart = True - if comp.is_system_config: - unit.is_system_config = True - - unit.modaliases = ctx.state.current_tx().get_modaliases() - - if comp.config_op is not None: - unit.config_ops.append(comp.config_op) - - if comp.in_type_args: - unit.in_type_args = comp.in_type_args - if comp.in_type_data: - unit.in_type_data = comp.in_type_data - if comp.in_type_id: - unit.in_type_id = comp.in_type_id - - unit.has_set = True - - elif isinstance(comp, dbstate.MaintenanceQuery): - unit.sql = comp.sql - - elif isinstance(comp, dbstate.NullQuery): - pass - - else: # pragma: no cover - raise errors.InternalServerError('unknown compile state') - - if unit.in_type_args: - unit.in_type_args_real_count = sum( - len(p.sub_params[0]) if p.sub_params else 1 - for p in unit.in_type_args - ) - - if unit.warnings: - for warning in unit.warnings: - 
warning.__traceback__ = None - rv.append(unit) + if user_schema is not None: + final_user_schema = user_schema + if script_info: if ctx.state.current_tx().is_implicit(): if ctx.state.current_tx().get_migration_state(): @@ -2690,7 +2461,6 @@ def _try_compile( f'QueryUnit {unit!r} is cacheable but has config/aliases') if not na_cardinality and ( - len(unit.sql) > 1 or unit.tx_commit or unit.tx_rollback or unit.tx_savepoint_rollback or @@ -2715,6 +2485,261 @@ def _try_compile( return rv +def _make_query_unit( + *, + ctx: CompileContext, + stmt_ctx: CompileContext, + stmt: qlast.Base, + is_script: bool, + is_trailing_stmt: bool, + comp: dbstate.BaseQuery, + capabilities: enums.Capability, +) -> tuple[dbstate.QueryUnit, Optional[s_schema.Schema]]: + + # Initialize user_schema_version with the version this query is + # going to be compiled upon. This can be overwritten later by DDLs. + try: + schema_version = _get_schema_version( + stmt_ctx.state.current_tx().get_user_schema() + ) + except errors.InvalidReferenceError: + schema_version = None + + unit = dbstate.QueryUnit( + sql=b"", + status=status.get_status(stmt), + cardinality=enums.Cardinality.NO_RESULT, + capabilities=capabilities, + output_format=stmt_ctx.output_format, + cache_key=ctx.cache_key, + user_schema_version=schema_version, + warnings=comp.warnings, + ) + + if not comp.is_transactional: + if is_script: + raise errors.QueryError( + f'cannot execute {status.get_status(stmt).decode()} ' + f'with other commands in one block', + span=stmt.span, + ) + + if not ctx.state.current_tx().is_implicit(): + raise errors.QueryError( + f'cannot execute {status.get_status(stmt).decode()} ' + f'in a transaction', + span=stmt.span, + ) + + unit.is_transactional = False + + final_user_schema: Optional[s_schema.Schema] = None + + if isinstance(comp, dbstate.Query): + unit.sql = comp.sql + unit.cache_sql = comp.cache_sql + unit.cache_func_call = comp.cache_func_call + unit.globals = comp.globals + unit.in_type_args = comp.in_type_args + + unit.sql_hash = comp.sql_hash + + unit.out_type_data = comp.out_type_data + unit.out_type_id = comp.out_type_id + unit.in_type_data = comp.in_type_data + unit.in_type_id = comp.in_type_id + + unit.cacheable = comp.cacheable + + if comp.is_explain: + unit.is_explain = True + unit.query_asts = comp.query_asts + + if comp.run_and_rollback: + unit.run_and_rollback = True + + if is_trailing_stmt: + unit.cardinality = comp.cardinality + + elif isinstance(comp, dbstate.SimpleQuery): + unit.sql = comp.sql + unit.in_type_args = comp.in_type_args + + elif isinstance(comp, dbstate.DDLQuery): + unit.sql = comp.sql + unit.db_op_trailer = comp.db_op_trailer + unit.create_db = comp.create_db + unit.drop_db = comp.drop_db + unit.drop_db_reset_connections = comp.drop_db_reset_connections + unit.create_db_template = comp.create_db_template + unit.create_db_mode = comp.create_db_mode + unit.ddl_stmt_id = comp.ddl_stmt_id + if not ctx.dump_restore_mode: + if comp.user_schema is not None: + final_user_schema = comp.user_schema + unit.user_schema = pickle.dumps(comp.user_schema, -1) + unit.user_schema_version = ( + _get_schema_version(comp.user_schema) + ) + unit.extensions, unit.ext_config_settings = ( + _extract_extensions(ctx, comp.user_schema) + ) + unit.feature_used_metrics = comp.feature_used_metrics + if comp.cached_reflection is not None: + unit.cached_reflection = \ + pickle.dumps(comp.cached_reflection, -1) + if comp.global_schema is not None: + unit.global_schema = pickle.dumps(comp.global_schema, -1) + unit.roles = 
_extract_roles(comp.global_schema) + + unit.config_ops.extend(comp.config_ops) + + elif isinstance(comp, dbstate.TxControlQuery): + if is_script: + raise errors.QueryError( + "Explicit transaction control commands cannot be executed " + "in an implicit transaction block" + ) + unit.sql = comp.sql + unit.cacheable = comp.cacheable + + if not ctx.dump_restore_mode: + if comp.user_schema is not None: + final_user_schema = comp.user_schema + unit.user_schema = pickle.dumps(comp.user_schema, -1) + unit.user_schema_version = ( + _get_schema_version(comp.user_schema) + ) + unit.extensions, unit.ext_config_settings = ( + _extract_extensions(ctx, comp.user_schema) + ) + unit.feature_used_metrics = comp.feature_used_metrics + if comp.cached_reflection is not None: + unit.cached_reflection = \ + pickle.dumps(comp.cached_reflection, -1) + if comp.global_schema is not None: + unit.global_schema = pickle.dumps(comp.global_schema, -1) + unit.roles = _extract_roles(comp.global_schema) + + if comp.modaliases is not None: + unit.modaliases = comp.modaliases + + if comp.action == dbstate.TxAction.START: + if unit.tx_id is not None: + raise errors.InternalServerError( + 'already in transaction') + unit.tx_id = ctx.state.current_tx().id + elif comp.action == dbstate.TxAction.COMMIT: + unit.tx_commit = True + elif comp.action == dbstate.TxAction.ROLLBACK: + unit.tx_rollback = True + elif comp.action is dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: + unit.tx_savepoint_rollback = True + unit.sp_name = comp.sp_name + elif comp.action is dbstate.TxAction.DECLARE_SAVEPOINT: + unit.tx_savepoint_declare = True + unit.sp_name = comp.sp_name + unit.sp_id = comp.sp_id + + elif isinstance(comp, dbstate.MigrationControlQuery): + unit.sql = comp.sql + unit.cacheable = comp.cacheable + + if not ctx.dump_restore_mode: + if comp.user_schema is not None: + final_user_schema = comp.user_schema + unit.user_schema = pickle.dumps(comp.user_schema, -1) + unit.user_schema_version = ( + _get_schema_version(comp.user_schema) + ) + unit.extensions, unit.ext_config_settings = ( + _extract_extensions(ctx, comp.user_schema) + ) + if comp.cached_reflection is not None: + unit.cached_reflection = \ + pickle.dumps(comp.cached_reflection, -1) + unit.ddl_stmt_id = comp.ddl_stmt_id + + if comp.modaliases is not None: + unit.modaliases = comp.modaliases + + if comp.tx_action == dbstate.TxAction.START: + # units[0:0] = _make_tx_units(ctx, qlast.StartTransaction()) + if unit.tx_id is not None: + raise errors.InternalServerError( + 'already in transaction') + unit.tx_id = ctx.state.current_tx().id + elif comp.tx_action == dbstate.TxAction.COMMIT: + unit.tx_commit = True + unit.append_tx_op = True + elif comp.tx_action == dbstate.TxAction.ROLLBACK: + unit.tx_rollback = True + unit.append_tx_op = True + elif comp.action == dbstate.MigrationAction.ABORT: + unit.tx_abort_migration = True + + elif isinstance(comp, dbstate.SessionStateQuery): + unit.sql = comp.sql + unit.globals = comp.globals + + if comp.config_scope is qltypes.ConfigScope.INSTANCE: + if not ctx.state.current_tx().is_implicit() or is_script: + raise errors.QueryError( + 'CONFIGURE INSTANCE cannot be executed in a ' + 'transaction block') + + unit.system_config = True + elif comp.config_scope is qltypes.ConfigScope.GLOBAL: + unit.needs_readback = True + + elif comp.config_scope is qltypes.ConfigScope.DATABASE: + unit.database_config = True + unit.needs_readback = True + + if comp.is_backend_setting: + unit.backend_config = True + if comp.requires_restart: + unit.config_requires_restart = True + 
if comp.is_system_config: + unit.is_system_config = True + + unit.modaliases = ctx.state.current_tx().get_modaliases() + + if comp.config_op is not None: + unit.config_ops.append(comp.config_op) + + if comp.in_type_args: + unit.in_type_args = comp.in_type_args + if comp.in_type_data: + unit.in_type_data = comp.in_type_data + if comp.in_type_id: + unit.in_type_id = comp.in_type_id + + unit.has_set = True + unit.output_format = enums.OutputFormat.NONE + + elif isinstance(comp, dbstate.MaintenanceQuery): + unit.sql = comp.sql + + elif isinstance(comp, dbstate.NullQuery): + pass + + else: # pragma: no cover + raise errors.InternalServerError('unknown compile state') + + if unit.in_type_args: + unit.in_type_args_real_count = sum( + len(p.sub_params[0]) if p.sub_params else 1 + for p in unit.in_type_args + ) + + if unit.warnings: + for warning in unit.warnings: + warning.__traceback__ = None + + return unit, final_user_schema + + def _extract_params( params: List[irast.Param], *, diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 73a1f753b70..ea44756af78 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -60,7 +60,6 @@ class TxAction(enum.IntEnum): - START = 1 COMMIT = 2 ROLLBACK = 3 @@ -71,7 +70,6 @@ class TxAction(enum.IntEnum): class MigrationAction(enum.IntEnum): - START = 1 POPULATE = 2 DESCRIBE = 3 @@ -80,10 +78,11 @@ class MigrationAction(enum.IntEnum): REJECT_PROPOSED = 6 -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class BaseQuery: - - sql: Tuple[bytes, ...] + sql: bytes + is_transactional: bool = True + has_dml: bool = False cache_sql: Optional[Tuple[bytes, bytes]] = dataclasses.field( kw_only=True, default=None ) # (persist, evict) @@ -94,22 +93,14 @@ class BaseQuery: kw_only=True, default=() ) - @property - def is_transactional(self) -> bool: - return True - -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class NullQuery(BaseQuery): - - sql: Tuple[bytes, ...] = tuple() - is_transactional: bool = True - has_dml: bool = False + sql: bytes = b"" -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class Query(BaseQuery): - sql_hash: bytes cardinality: enums.Cardinality @@ -122,27 +113,21 @@ class Query(BaseQuery): globals: Optional[list[tuple[str, bool]]] = None - is_transactional: bool = True - has_dml: bool = False cacheable: bool = True is_explain: bool = False query_asts: Any = None - append_rollback: bool = False + run_and_rollback: bool = False -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class SimpleQuery(BaseQuery): - - sql: Tuple[bytes, ...] 
- is_transactional: bool = True - has_dml: bool = False # XXX: Temporary hack, since SimpleQuery will die in_type_args: Optional[List[Param]] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class SessionStateQuery(BaseQuery): - + sql: bytes = b"" config_scope: Optional[qltypes.ConfigScope] = None is_backend_setting: bool = False requires_restart: bool = False @@ -156,9 +141,8 @@ class SessionStateQuery(BaseQuery): in_type_args: Optional[List[Param]] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class DDLQuery(BaseQuery): - user_schema: Optional[s_schema.FlatSchema] feature_used_metrics: Optional[dict[str, float]] global_schema: Optional[s_schema.FlatSchema] = None @@ -169,18 +153,17 @@ class DDLQuery(BaseQuery): drop_db_reset_connections: bool = False create_db_template: Optional[str] = None create_db_mode: Optional[qlast.BranchType] = None + db_op_trailer: tuple[bytes, ...] = () ddl_stmt_id: Optional[str] = None config_ops: List[config.Operation] = dataclasses.field(default_factory=list) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class TxControlQuery(BaseQuery): - action: TxAction cacheable: bool modaliases: Optional[immutables.Map[Optional[str], str]] - is_transactional: bool = True user_schema: Optional[s_schema.Schema] = None global_schema: Optional[s_schema.Schema] = None @@ -191,25 +174,22 @@ class TxControlQuery(BaseQuery): sp_id: Optional[int] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class MigrationControlQuery(BaseQuery): - action: MigrationAction tx_action: Optional[TxAction] cacheable: bool modaliases: Optional[immutables.Map[Optional[str], str]] - is_transactional: bool = True user_schema: Optional[s_schema.FlatSchema] = None cached_reflection: Any = None ddl_stmt_id: Optional[str] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class MaintenanceQuery(BaseQuery): - - is_transactional: bool = True + pass @dataclasses.dataclass(frozen=True) @@ -224,10 +204,9 @@ class Param: ############################# -@dataclasses.dataclass +@dataclasses.dataclass(kw_only=True) class QueryUnit: - - sql: Tuple[bytes, ...] + sql: bytes # Status-line for the compiled command; returned to front-end # in a CommandComplete protocol message if the command is @@ -244,9 +223,9 @@ class QueryUnit: # Set only for units that contain queries that can be cached # as prepared statements in Postgres. - sql_hash: bytes = b'' + sql_hash: bytes = b"" - # True if all statments in *sql* can be executed inside a transaction. + # True if all statements in *sql* can be executed inside a transaction. # If False, they will be executed separately. is_transactional: bool = True @@ -298,6 +277,10 @@ class QueryUnit: create_db_template: Optional[str] = None create_db_mode: Optional[str] = None + # If a branch command needs extra SQL commands to be performed, + # those would end up here. + db_op_trailer: tuple[bytes, ...] = () + # If non-None, the DDL statement will emit data packets marked # with the indicated ID. 
ddl_stmt_id: Optional[str] = None @@ -357,7 +340,8 @@ class QueryUnit: is_explain: bool = False query_asts: Any = None - append_rollback: bool = False + run_and_rollback: bool = False + append_tx_op: bool = False @property def has_ddl(self) -> bool: @@ -390,13 +374,12 @@ def deserialize(cls, data: bytes) -> Self: def maybe_use_func_cache(self) -> None: if self.cache_func_call is not None: sql, sql_hash = self.cache_func_call - self.sql = (sql,) + self.sql = sql self.sql_hash = sql_hash @dataclasses.dataclass class QueryUnitGroup: - # All capabilities used by any query units in this group capabilities: enums.Capability = enums.Capability(0) @@ -568,29 +551,29 @@ class SQLQueryUnit: class CommandCompleteTag: - '''Dictates the tag of CommandComplete message that concludes this query.''' + """Dictates the tag of CommandComplete message that concludes this query.""" @dataclasses.dataclass(kw_only=True) class TagPlain(CommandCompleteTag): - '''Set the tag verbatim''' + """Set the tag verbatim""" tag: bytes @dataclasses.dataclass(kw_only=True) class TagCountMessages(CommandCompleteTag): - '''Count DataRow messages in the response and set the tag to - f'{prefix} {count_of_messages}'.''' + """Count DataRow messages in the response and set the tag to + f'{prefix} {count_of_messages}'.""" prefix: str @dataclasses.dataclass(kw_only=True) class TagUnpackRow(CommandCompleteTag): - '''Intercept a single DataRow message with a single column which represents + """Intercept a single DataRow message with a single column which represents the number of modified rows. - Sets the CommandComplete tag to f'{prefix} {modified_rows}'.''' + Sets the CommandComplete tag to f'{prefix} {modified_rows}'.""" prefix: str @@ -634,13 +617,15 @@ class ParsedDatabase: SQLSetting = tuple[str | int | float, ...] SQLSettings = immutables.Map[Optional[str], Optional[SQLSetting]] DEFAULT_SQL_SETTINGS: SQLSettings = immutables.Map() -DEFAULT_SQL_FE_SETTINGS: SQLSettings = immutables.Map({ - "search_path": ("public",), - "server_version": cast(SQLSetting, (defines.PGEXT_POSTGRES_VERSION,)), - "server_version_num": cast( - SQLSetting, (defines.PGEXT_POSTGRES_VERSION_NUM,) - ), -}) +DEFAULT_SQL_FE_SETTINGS: SQLSettings = immutables.Map( + { + "search_path": ("public",), + "server_version": cast(SQLSetting, (defines.PGEXT_POSTGRES_VERSION,)), + "server_version_num": cast( + SQLSetting, (defines.PGEXT_POSTGRES_VERSION_NUM,) + ), + } +) @dataclasses.dataclass @@ -680,11 +665,16 @@ def apply(self, query_unit: SQLQueryUnit) -> None: self.in_tx_local_settings = None self.savepoints.clear() elif query_unit.tx_action == TxAction.DECLARE_SAVEPOINT: - self.savepoints.append(( - query_unit.sp_name, - self.in_tx_settings, - self.in_tx_local_settings, - )) # type: ignore + assert query_unit.sp_name is not None + assert self.in_tx_settings is not None + assert self.in_tx_local_settings is not None + self.savepoints.append( + ( + query_unit.sp_name, + self.in_tx_settings, + self.in_tx_local_settings, + ) + ) elif query_unit.tx_action == TxAction.ROLLBACK_TO_SAVEPOINT: while self.savepoints: sp_name, settings, local_settings = self.savepoints[-1] @@ -735,7 +725,6 @@ def _set(attr_name: str) -> None: class ProposedMigrationStep(NamedTuple): - statements: Tuple[str, ...] 
confidence: float prompt: str @@ -748,17 +737,16 @@ class ProposedMigrationStep(NamedTuple): def to_json(self) -> Dict[str, Any]: return { - 'statements': [{'text': stmt} for stmt in self.statements], - 'confidence': self.confidence, - 'prompt': self.prompt, - 'prompt_id': self.prompt_id, - 'data_safe': self.data_safe, - 'required_user_input': list(self.required_user_input), + "statements": [{"text": stmt} for stmt in self.statements], + "confidence": self.confidence, + "prompt": self.prompt, + "prompt_id": self.prompt_id, + "data_safe": self.data_safe, + "required_user_input": list(self.required_user_input), } class MigrationState(NamedTuple): - parent_migration: Optional[s_migrations.Migration] initial_schema: s_schema.Schema initial_savepoint: Optional[str] @@ -769,14 +757,12 @@ class MigrationState(NamedTuple): class MigrationRewriteState(NamedTuple): - initial_savepoint: Optional[str] target_schema: s_schema.Schema accepted_migrations: Tuple[qlast.CreateMigration, ...] class TransactionState(NamedTuple): - id: int name: Optional[str] local_user_schema: s_schema.FlatSchema | None @@ -799,7 +785,6 @@ def user_schema(self) -> s_schema.FlatSchema: class Transaction: - _savepoints: Dict[int, TransactionState] _constate: CompilerConnectionState @@ -816,7 +801,6 @@ def __init__( cached_reflection: immutables.Map[str, Tuple[str, ...]], implicit: bool = True, ) -> None: - assert not isinstance(user_schema, s_schema.ChainedSchema) self._constate = constate @@ -857,12 +841,12 @@ def make_explicit(self) -> None: if self._implicit: self._implicit = False else: - raise errors.TransactionError('already in explicit transaction') + raise errors.TransactionError("already in explicit transaction") def declare_savepoint(self, name: str) -> int: if self.is_implicit(): raise errors.TransactionError( - 'savepoints can only be used in transaction blocks' + "savepoints can only be used in transaction blocks" ) return self._declare_savepoint(name) @@ -882,7 +866,7 @@ def _declare_savepoint(self, name: str) -> int: def rollback_to_savepoint(self, name: str) -> TransactionState: if self.is_implicit(): raise errors.TransactionError( - 'savepoints can only be used in transaction blocks' + "savepoints can only be used in transaction blocks" ) return self._rollback_to_savepoint(name) @@ -899,7 +883,7 @@ def _rollback_to_savepoint(self, name: str) -> TransactionState: sp_ids_to_erase.append(sp.id) else: - raise errors.TransactionError(f'there is no {name!r} savepoint') + raise errors.TransactionError(f"there is no {name!r} savepoint") for sp_id in sp_ids_to_erase: self._savepoints.pop(sp_id) @@ -909,7 +893,7 @@ def _rollback_to_savepoint(self, name: str) -> TransactionState: def release_savepoint(self, name: str) -> None: if self.is_implicit(): raise errors.TransactionError( - 'savepoints can only be used in transaction blocks' + "savepoints can only be used in transaction blocks" ) self._release_savepoint(name) @@ -925,7 +909,7 @@ def _release_savepoint(self, name: str) -> None: if sp.name == name: break else: - raise errors.TransactionError(f'there is no {name!r} savepoint') + raise errors.TransactionError(f"there is no {name!r} savepoint") for sp_id in sp_ids_to_erase: self._savepoints.pop(sp_id) @@ -1030,8 +1014,7 @@ def update_migration_rewrite_state( class CompilerConnectionState: - - __slots__ = ('_savepoints_log', '_current_tx', '_tx_count', '_user_schema') + __slots__ = ("_savepoints_log", "_current_tx", "_tx_count", "_user_schema") _savepoints_log: Dict[int, TransactionState] _user_schema: 
Optional[s_schema.FlatSchema] @@ -1111,7 +1094,7 @@ def sync_to_savepoint(self, spid: int) -> None: """Synchronize the compiler state with the current DB state.""" if not self.can_sync_to_savepoint(spid): - raise RuntimeError(f'failed to lookup savepoint with id={spid}') + raise RuntimeError(f"failed to lookup savepoint with id={spid}") sp = self._savepoints_log[spid] self._current_tx = sp.tx @@ -1137,7 +1120,7 @@ def start_tx(self) -> None: if self._current_tx.is_implicit(): self._current_tx.make_explicit() else: - raise errors.TransactionError('already in transaction') + raise errors.TransactionError("already in transaction") def rollback_tx(self) -> TransactionState: # Note that we might not be in a transaction as we allow @@ -1159,7 +1142,7 @@ def rollback_tx(self) -> TransactionState: def commit_tx(self) -> TransactionState: if self._current_tx.is_implicit(): - raise errors.TransactionError('cannot commit: not in transaction') + raise errors.TransactionError("cannot commit: not in transaction") latest_state = self._current_tx._current @@ -1184,5 +1167,5 @@ def sync_tx(self, txid: int) -> None: return raise errors.InternalServerError( - f'failed to lookup transaction or savepoint with id={txid}' + f"failed to lookup transaction or savepoint with id={txid}" ) # pragma: no cover diff --git a/edb/server/compiler/ddl.py b/edb/server/compiler/ddl.py index 9754e9eeeb8..f9c87a705ba 100644 --- a/edb/server/compiler/ddl.py +++ b/edb/server/compiler/ddl.py @@ -64,17 +64,28 @@ from edb.pgsql import common as pg_common from edb.pgsql import delta as pg_delta from edb.pgsql import dbops as pg_dbops -from edb.pgsql import trampoline from . import dbstate from . import compiler +NIL_QUERY = b"SELECT LIMIT 0" + + def compile_and_apply_ddl_stmt( ctx: compiler.CompileContext, - stmt: qlast.DDLOperation, + stmt: qlast.DDLCommand, source: Optional[edgeql.Source] = None, ) -> dbstate.DDLQuery: + query, _ = _compile_and_apply_ddl_stmt(ctx, stmt, source) + return query + + +def _compile_and_apply_ddl_stmt( + ctx: compiler.CompileContext, + stmt: qlast.DDLCommand, + source: Optional[edgeql.Source] = None, +) -> tuple[dbstate.DDLQuery, Optional[pg_dbops.SQLBlock]]: if isinstance(stmt, qlast.GlobalObjectCommand): ctx._assert_not_in_migration_block(stmt) @@ -127,7 +138,7 @@ def compile_and_apply_ddl_stmt( ) ], ) - return compile_and_apply_ddl_stmt(ctx, cm) + return _compile_and_apply_ddl_stmt(ctx, cm) assert isinstance(stmt, qlast.DDLCommand) new_schema, delta = s_ddl.delta_and_schema_from_ddl( @@ -175,14 +186,16 @@ def compile_and_apply_ddl_stmt( current_tx.update_migration_state(mstate) current_tx.update_schema(new_schema) - return dbstate.DDLQuery( - sql=(b'SELECT LIMIT 0',), + query = dbstate.DDLQuery( + sql=NIL_QUERY, user_schema=current_tx.get_user_schema(), is_transactional=True, warnings=tuple(delta.warnings), feature_used_metrics=None, ) + return query, None + store_migration_sdl = compiler._get_config_val(ctx, 'store_migration_sdl') if ( isinstance(stmt, qlast.CreateMigration) @@ -207,26 +220,30 @@ def compile_and_apply_ddl_stmt( current_tx.update_schema(new_schema) - return dbstate.DDLQuery( - sql=(b'SELECT LIMIT 0',), + query = dbstate.DDLQuery( + sql=NIL_QUERY, user_schema=current_tx.get_user_schema(), is_transactional=True, warnings=tuple(delta.warnings), feature_used_metrics=None, ) + return query, None + # Apply and adapt delta, build native delta plan, which # will also update the schema. 
block, new_types, config_ops = _process_delta(ctx, delta) ddl_stmt_id: Optional[str] = None - is_transactional = block.is_transactional() if not is_transactional: - sql = tuple(stmt.encode('utf-8') for stmt in block.get_statements()) + if not isinstance(stmt, qlast.DatabaseCommand): + raise AssertionError( + f"unexpected non-transaction DDL command type: {stmt}") + sql_stmts = block.get_statements() + sql = sql_stmts[0].encode("utf-8") + db_op_trailer = tuple(stmt.encode("utf-8") for stmt in sql_stmts[1:]) else: - sql = (block.to_string().encode('utf-8'),) - if new_types: # Inject a query returning backend OIDs for the newly # created types. @@ -234,11 +251,11 @@ def compile_and_apply_ddl_stmt( new_type_ids = [ f'{pg_common.quote_literal(tid)}::uuid' for tid in new_types ] - sql = sql + ( - trampoline.fixup_query(textwrap.dedent( - f'''\ - SELECT - json_build_object( + new_types_sql = textwrap.dedent(f"""\ + PERFORM edgedb.notice( + NULL::text, + msg => 'edb:notice:indirect_return', + detail => json_build_object( 'ddl_stmt_id', {pg_common.quote_literal(ddl_stmt_id)}, 'new_types', @@ -254,11 +271,15 @@ def compile_and_apply_ddl_stmt( {', '.join(new_type_ids)} ]) ) - )::text; - ''' - )).encode('utf-8'), + )::text + )""" ) + block.add_command(pg_dbops.Query(text=new_types_sql).code()) + + sql = block.to_string().encode('utf-8') + db_op_trailer = () + create_db = None drop_db = None drop_db_reset_connections = False @@ -286,10 +307,10 @@ def compile_and_apply_ddl_stmt( debug.dump_code(code, lexer='sql') if debug.flags.delta_execute: debug.header('Delta Script') - debug.dump_code(b'\n'.join(sql), lexer='sql') + debug.dump_code(sql + b"\n".join(db_op_trailer), lexer='sql') new_user_schema = current_tx.get_user_schema_if_updated() - return dbstate.DDLQuery( + query = dbstate.DDLQuery( sql=sql, is_transactional=is_transactional, create_db=create_db, @@ -297,6 +318,7 @@ def compile_and_apply_ddl_stmt( drop_db_reset_connections=drop_db_reset_connections, create_db_template=create_db_template, create_db_mode=create_db_mode, + db_op_trailer=db_op_trailer, ddl_stmt_id=ddl_stmt_id, user_schema=new_user_schema, cached_reflection=current_tx.get_cached_reflection_if_updated(), @@ -309,6 +331,8 @@ def compile_and_apply_ddl_stmt( ), ) + return query, block + def _new_delta_context( ctx: compiler.CompileContext, args: Any = None @@ -464,7 +488,7 @@ def _start_migration( else: savepoint_name = current_tx.start_migration() query = dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, action=dbstate.MigrationAction.START, tx_action=None, cacheable=False, @@ -573,7 +597,7 @@ def _populate_migration( current_tx.update_schema(schema) return dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, tx_action=None, action=dbstate.MigrationAction.POPULATE, cacheable=False, @@ -801,7 +825,7 @@ def _alter_current_migration_reject_proposed( current_tx.update_migration_state(mstate) return dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, tx_action=None, action=dbstate.MigrationAction.REJECT_PROPOSED, cacheable=False, @@ -876,7 +900,7 @@ def _commit_migration( current_tx.update_migration_rewrite_state(mrstate) return dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, action=dbstate.MigrationAction.COMMIT, tx_action=None, cacheable=False, @@ -893,16 +917,12 @@ def _commit_migration( if mstate.initial_savepoint: current_tx.commit_migration(mstate.initial_savepoint) - sql = ddl_query.sql tx_action = None else: - tx_cmd = 
qlast.CommitTransaction() - tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) - sql = ddl_query.sql + tx_query.sql - tx_action = tx_query.action + tx_action = dbstate.TxAction.COMMIT return dbstate.MigrationControlQuery( - sql=sql, + sql=ddl_query.sql, ddl_stmt_id=ddl_query.ddl_stmt_id, action=dbstate.MigrationAction.COMMIT, tx_action=tx_action, @@ -923,7 +943,7 @@ def _abort_migration( if mstate.initial_savepoint: current_tx.abort_migration(mstate.initial_savepoint) - sql: Tuple[bytes, ...] = (b'SELECT LIMIT 0',) + sql = NIL_QUERY tx_action = None else: tx_cmd = qlast.RollbackTransaction() @@ -967,7 +987,7 @@ def _start_migration_rewrite( else: savepoint_name = current_tx.start_migration() query = dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, action=dbstate.MigrationAction.START, tx_action=None, cacheable=False, @@ -1052,25 +1072,24 @@ def _commit_migration_rewrite( for cm in cmds: cm.dump_edgeql() - sqls: List[bytes] = [] + block = pg_dbops.PLTopBlock() for cmd in cmds: - ddl_query = compile_and_apply_ddl_stmt(ctx, cmd) + _, ddl_block = _compile_and_apply_ddl_stmt(ctx, cmd) + assert isinstance(ddl_block, pg_dbops.PLBlock) # We know nothing serious can be in that query # except for the SQL, so it's fine to just discard # it all. - sqls.extend(ddl_query.sql) + for stmt in ddl_block.get_statements(): + block.add_command(stmt) if mrstate.initial_savepoint: current_tx.commit_migration(mrstate.initial_savepoint) tx_action = None else: - tx_cmd = qlast.CommitTransaction() - tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) - sqls.extend(tx_query.sql) - tx_action = tx_query.action + tx_action = dbstate.TxAction.COMMIT return dbstate.MigrationControlQuery( - sql=tuple(sqls), + sql=block.to_string().encode("utf-8"), action=dbstate.MigrationAction.COMMIT, tx_action=tx_action, cacheable=False, @@ -1090,7 +1109,7 @@ def _abort_migration_rewrite( if mrstate.initial_savepoint: current_tx.abort_migration(mrstate.initial_savepoint) - sql: Tuple[bytes, ...] = (b'SELECT LIMIT 0',) + sql = NIL_QUERY tx_action = None else: tx_cmd = qlast.RollbackTransaction() @@ -1146,8 +1165,6 @@ def _reset_schema( current_schema=empty_schema, ) - sqls: List[bytes] = [] - # diff and create migration that drops all objects diff = s_ddl.delta_schemas(schema, empty_schema) new_ddl: Tuple[qlast.DDLCommand, ...] 
= tuple( @@ -1156,8 +1173,8 @@ def _reset_schema( create_mig = qlast.CreateMigration( # type: ignore body=qlast.NestedQLBlock(commands=tuple(new_ddl)), # type: ignore ) - ddl_query = compile_and_apply_ddl_stmt(ctx, create_mig) - sqls.extend(ddl_query.sql) + ddl_query, ddl_block = _compile_and_apply_ddl_stmt(ctx, create_mig) + assert ddl_block is not None # delete all migrations schema = current_tx.get_schema(ctx.compiler_state.std_schema) @@ -1170,11 +1187,13 @@ def _reset_schema( drop_mig = qlast.DropMigration( # type: ignore name=qlast.ObjectRef(name=mig.get_name(schema).name), ) - ddl_query = compile_and_apply_ddl_stmt(ctx, drop_mig) - sqls.extend(ddl_query.sql) + _, mig_block = _compile_and_apply_ddl_stmt(ctx, drop_mig) + assert isinstance(mig_block, pg_dbops.PLBlock) + for stmt in mig_block.get_statements(): + ddl_block.add_command(stmt) return dbstate.MigrationControlQuery( - sql=tuple(sqls), + sql=ddl_block.to_string().encode("utf-8"), ddl_stmt_id=ddl_query.ddl_stmt_id, action=dbstate.MigrationAction.COMMIT, tx_action=None, @@ -1278,7 +1297,7 @@ def _track(key: str) -> None: def repair_schema( ctx: compiler.CompileContext, -) -> Optional[tuple[tuple[bytes, ...], s_schema.Schema, Any]]: +) -> Optional[tuple[bytes, s_schema.Schema, Any]]: """Repair inconsistencies in the schema caused by bug fixes Works by comparing the actual current schema to the schema we get @@ -1340,11 +1359,11 @@ def repair_schema( is_transactional = block.is_transactional() assert not new_types assert is_transactional - sql = (block.to_string().encode('utf-8'),) + sql = block.to_string().encode('utf-8') if debug.flags.delta_execute: debug.header('Repair Delta Script') - debug.dump_code(b'\n'.join(sql), lexer='sql') + debug.dump_code(sql, lexer='sql') return sql, reloaded_schema, config_ops @@ -1363,7 +1382,7 @@ def administer_repair_schema( res = repair_schema(ctx) if not res: - return dbstate.MaintenanceQuery(sql=(b'',)) + return dbstate.MaintenanceQuery(sql=b"") sql, new_schema, config_ops = res current_tx.update_schema(new_schema) @@ -1511,9 +1530,11 @@ def administer_reindex( for pindex in pindexes ] - return dbstate.MaintenanceQuery( - sql=tuple(q.encode('utf-8') for q in commands) - ) + block = pg_dbops.PLTopBlock() + for command in commands: + block.add_command(command) + + return dbstate.MaintenanceQuery(sql=block.to_string().encode("utf-8")) def administer_vacuum( @@ -1663,7 +1684,7 @@ def administer_vacuum( command = f'VACUUM {options} ' + ', '.join(tables_and_columns) return dbstate.MaintenanceQuery( - sql=(command.encode('utf-8'),), + sql=command.encode('utf-8'), is_transactional=False, ) diff --git a/edb/server/pgcluster.py b/edb/server/pgcluster.py index 03272dfe382..ca3ab04c9ae 100644 --- a/edb/server/pgcluster.py +++ b/edb/server/pgcluster.py @@ -79,6 +79,10 @@ EDGEDB_SERVER_SETTINGS = { 'client_encoding': 'utf-8', + # DO NOT raise client_min_messages above NOTICE level + # because server indirect block return machinery relies + # on NoticeResponse as the data channel. 
+ 'client_min_messages': 'NOTICE', 'search_path': 'edgedb', 'timezone': 'UTC', 'intervalstyle': 'iso_8601', @@ -568,7 +572,7 @@ async def start( else: log_level_map = { 'd': 'INFO', - 'i': 'NOTICE', + 'i': 'WARNING', # NOTICE in Postgres is quite noisy 'w': 'WARNING', 'e': 'ERROR', 's': 'PANIC', diff --git a/edb/server/pgcon/pgcon.pxd b/edb/server/pgcon/pgcon.pxd index 348b4010b40..4b6a0ed0608 100644 --- a/edb/server/pgcon/pgcon.pxd +++ b/edb/server/pgcon/pgcon.pxd @@ -139,6 +139,8 @@ cdef class PGConnection: object last_state + str last_indirect_return + cdef before_command(self) cdef write(self, buf) diff --git a/edb/server/pgcon/pgcon.pyi b/edb/server/pgcon/pgcon.pyi index ee35a9fad88..e6c00c287b9 100644 --- a/edb/server/pgcon/pgcon.pyi +++ b/edb/server/pgcon/pgcon.pyi @@ -53,7 +53,7 @@ class PGConnection(asyncio.Protocol): async def sql_execute(self, sql: bytes | tuple[bytes, ...]) -> None: ... async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), use_prep_stmt: bool = False, diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 2d77944cb0d..0c66aa3e1c3 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -100,6 +100,7 @@ cdef dict POSTGRES_SHUTDOWN_ERR_CODES = { } cdef object EMPTY_SQL_STATE = b"{}" +cdef WriteBuffer NO_ARGS = args_ser.combine_raw_args() cdef object logger = logging.getLogger('edb.server') @@ -239,6 +240,8 @@ cdef class PGConnection: self.last_parse_prep_stmts = [] self.debug = debug.flags.server_proto + self.last_indirect_return = None + self.log_listeners = [] self.server = None @@ -590,6 +593,8 @@ cdef class PGConnection: WriteBuffer bind_data bytes stmt_name ssize_t idx = start + bytes sql + tuple sqls out = WriteBuffer.new() parsed = set() @@ -607,7 +612,6 @@ cdef class PGConnection: ) stmt_name = query_unit.sql_hash if stmt_name: - assert len(query_unit.sql) == 1 # The same EdgeQL query may show up twice in the same script. 
# We just need to know and skip if we've already parsed the # same query within current send batch, because self.prep_stmts @@ -624,15 +628,16 @@ cdef class PGConnection: for query_unit, bind_data in zip( query_unit_group.units[start:end], bind_datas): stmt_name = query_unit.sql_hash + sql = query_unit.sql if stmt_name: if parse_array[idx]: buf = WriteBuffer.new_message(b'P') buf.write_bytestring(stmt_name) - buf.write_bytestring(query_unit.sql[0]) + buf.write_bytestring(sql) buf.write_int16(0) out.write_buffer(buf.end_message()) metrics.query_size.observe( - len(query_unit.sql[0]), + len(sql), self.get_tenant_label(), 'compiled', ) @@ -649,26 +654,25 @@ cdef class PGConnection: out.write_buffer(buf.end_message()) else: - for sql in query_unit.sql: - buf = WriteBuffer.new_message(b'P') - buf.write_bytestring(b'') # statement name - buf.write_bytestring(sql) - buf.write_int16(0) - out.write_buffer(buf.end_message()) - metrics.query_size.observe( - len(sql), self.get_tenant_label(), 'compiled' - ) + buf = WriteBuffer.new_message(b'P') + buf.write_bytestring(b'') # statement name + buf.write_bytestring(sql) + buf.write_int16(0) + out.write_buffer(buf.end_message()) + metrics.query_size.observe( + len(sql), self.get_tenant_label(), 'compiled' + ) - buf = WriteBuffer.new_message(b'B') - buf.write_bytestring(b'') # portal name - buf.write_bytestring(b'') # statement name - buf.write_buffer(bind_data) - out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'B') + buf.write_bytestring(b'') # portal name + buf.write_bytestring(b'') # statement name + buf.write_buffer(bind_data) + out.write_buffer(buf.end_message()) - buf = WriteBuffer.new_message(b'E') - buf.write_bytestring(b'') # portal name - buf.write_int32(0) # limit: 0 - return all rows - out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'E') + buf.write_bytestring(b'') # portal name + buf.write_int32(0) # limit: 0 - return all rows + out.write_buffer(buf.end_message()) idx += 1 @@ -686,7 +690,7 @@ cdef class PGConnection: out = WriteBuffer.new() buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') - buf.write_bytestring(b'???') + buf.write_bytestring(b'') buf.write_int16(0) # Then do a sync to get everything executed and lined back up @@ -796,7 +800,6 @@ cdef class PGConnection: elif mtype == b'I': ## result # EmptyQueryResponse self.buffer.discard_message() - return result else: self.fallthrough() @@ -819,6 +822,10 @@ cdef class PGConnection: WriteBuffer out WriteBuffer buf bytes stmt_name + bytes sql + tuple sqls + bytes prologue_sql + bytes epilogue_sql int32_t dat_len @@ -833,14 +840,6 @@ cdef class PGConnection: uint64_t msgs_executed = 0 uint64_t i - if use_pending_func_cache and query.cache_func_call: - sql, stmt_name = query.cache_func_call - sqls = (sql,) - else: - sqls = query.sql - stmt_name = query.sql_hash - msgs_num = (len(sqls)) - out = WriteBuffer.new() if state is not None: @@ -848,7 +847,7 @@ cdef class PGConnection: if ( query.tx_id or not query.is_transactional - or query.append_rollback + or query.run_and_rollback or tx_isolation is not None ): # This query has START TRANSACTION or non-transactional command @@ -860,22 +859,22 @@ cdef class PGConnection: state_sync = 1 self.write_sync(out) - if query.append_rollback or tx_isolation is not None: + if query.run_and_rollback or tx_isolation is not None: if self.in_tx(): sp_name = f'_edb_{time.monotonic_ns()}' - sql = f'SAVEPOINT {sp_name}'.encode('utf-8') + prologue_sql = f'SAVEPOINT {sp_name}'.encode('utf-8') else: sp_name = None - 
sql = b'START TRANSACTION' + prologue_sql = b'START TRANSACTION' if tx_isolation is not None: - sql += ( + prologue_sql += ( f' ISOLATION LEVEL {tx_isolation._value_}' .encode('utf-8') ) buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') - buf.write_bytestring(sql) + buf.write_bytestring(prologue_sql) buf.write_int16(0) out.write_buffer(buf.end_message()) @@ -895,9 +894,17 @@ cdef class PGConnection: # Insert a SYNC as a boundary of the parsing logic later self.write_sync(out) + if use_pending_func_cache and query.cache_func_call: + sql, stmt_name = query.cache_func_call + sqls = (sql,) + else: + sqls = (query.sql,) + query.db_op_trailer + stmt_name = query.sql_hash + + msgs_num = (len(sqls)) + if use_prep_stmt: - parse = self.before_prepare( - stmt_name, dbver, out) + parse = self.before_prepare(stmt_name, dbver, out) else: stmt_name = b'' @@ -962,8 +969,8 @@ cdef class PGConnection: buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) - if query.append_rollback or tx_isolation is not None: - if query.append_rollback: + if query.run_and_rollback or tx_isolation is not None: + if query.run_and_rollback: if sp_name: sql = f'ROLLBACK TO SAVEPOINT {sp_name}'.encode('utf-8') else: @@ -985,6 +992,35 @@ cdef class PGConnection: buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'E') + buf.write_bytestring(b'') # portal name + buf.write_int32(0) # limit: 0 - return all rows + out.write_buffer(buf.end_message()) + elif query.append_tx_op: + if query.tx_commit: + sql = b'COMMIT' + elif query.tx_rollback: + sql = b'ROLLBACK' + else: + raise errors.InternalServerError( + "QueryUnit.append_tx_op is set but none of the " + "Query.tx_ properties are" + ) + + buf = WriteBuffer.new_message(b'P') + buf.write_bytestring(b'') + buf.write_bytestring(sql) + buf.write_int16(0) + out.write_buffer(buf.end_message()) + + buf = WriteBuffer.new_message(b'B') + buf.write_bytestring(b'') # portal name + buf.write_bytestring(b'') # statement name + buf.write_int16(0) # number of format codes + buf.write_int16(0) # number of parameters + buf.write_int16(0) # number of result columns + out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows @@ -999,7 +1035,7 @@ cdef class PGConnection: if state is not None: await self.wait_for_state_resp(state, state_sync) - if query.append_rollback or tx_isolation is not None: + if query.run_and_rollback or tx_isolation is not None: await self.wait_for_sync() buf = None @@ -1098,7 +1134,7 @@ cdef class PGConnection: self, *, query, - WriteBuffer bind_data, + WriteBuffer bind_data = NO_ARGS, frontend.AbstractFrontendConnection fe_conn = None, bint use_prep_stmt = False, bytes state = None, @@ -1127,29 +1163,21 @@ cdef class PGConnection: async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] 
| list[bytes] = (), use_prep_stmt: bool = False, state: Optional[bytes] = None, ) -> list[tuple[bytes, ...]]: - cdef tuple sql_tuple - - if not isinstance(sql, tuple): - sql_tuple = (sql,) - else: - sql_tuple = sql - if use_prep_stmt: sql_digest = hashlib.sha1() - for stmt in sql_tuple: - sql_digest.update(stmt) + sql_digest.update(sql) sql_hash = sql_digest.hexdigest().encode('latin1') else: sql_hash = None query = compiler.QueryUnit( - sql=sql_tuple, + sql=sql, sql_hash=sql_hash, status=b"", ) @@ -2067,42 +2095,24 @@ cdef class PGConnection: msg_buf.write_bytes(data) buf.write_buffer(msg_buf.end_message()) - async def run_ddl( - self, - object query_unit, - bytes state=None - ): - data = await self.sql_fetch(query_unit.sql, state=state) - if query_unit.ddl_stmt_id is None: - return - else: - return self.load_ddl_return(query_unit, data) - - def load_ddl_return(self, object query_unit, data): + def load_last_ddl_return(self, object query_unit): if query_unit.ddl_stmt_id: + data = self.last_indirect_return if data: - ret = json.loads(data[0][0]) + ret = json.loads(data) if ret['ddl_stmt_id'] != query_unit.ddl_stmt_id: raise RuntimeError( - 'unrecognized data packet after a DDL command: ' - 'data_stmt_id do not match' + 'unrecognized data notice after a DDL command: ' + 'data_stmt_id do not match: expected ' + f'{query_unit.ddl_stmt_id!r}, got ' + f'{ret["ddl_stmt_id"]!r}' ) return ret else: raise RuntimeError( - 'missing the required data packet after a DDL command' + 'missing the required data notice after a DDL command' ) - async def handle_ddl_in_script( - self, object query_unit, bint parse, int dbver - ): - data = None - for sql in query_unit.sql: - data = await self.wait_for_command( - query_unit, parse, dbver, ignore_data=bool(data) - ) or data - return self.load_ddl_return(query_unit, data) - async def _dump(self, block, output_queue, fragment_suggested_size): cdef: WriteBuffer buf @@ -2527,6 +2537,7 @@ cdef class PGConnection: 'previous one') self.idle = False + self.last_indirect_return = None async def after_command(self): if self.idle: @@ -2675,14 +2686,18 @@ cdef class PGConnection: elif mtype == b'N': # NoticeResponse - if self.log_listeners: - _, fields = self.parse_error_message() - severity = fields.get('V') - message = fields.get('M') + _, fields = self.parse_error_message() + severity = fields.get('V') + message = fields.get('M') + detail = fields.get('D') + if ( + severity == "NOTICE" + and message.startswith("edb:notice:indirect_return") + ): + self.last_indirect_return = detail + elif self.log_listeners: for listener in self.log_listeners: self.loop.call_soon(listener, severity, message) - else: - self.buffer.discard_message() return True return False diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 72a8f265d9b..d35d5664dcf 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -1582,7 +1582,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): if query_unit.sql: if query_unit.ddl_stmt_id: - ddl_ret = await pgcon.run_ddl(query_unit) + await pgcon.parse_execute(query=query_unit) + ddl_ret = pgcon.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] else: diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 28730acc3fa..bb63859a204 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -94,19 +94,14 @@ cdef class ExecutionGroup: if state is not None: await be_conn.wait_for_state_resp(state, 
state_sync=0) for i, unit in enumerate(self.group): - if unit.output_format == FMT_NONE and unit.ddl_stmt_id is None: - for sql in unit.sql: - await be_conn.wait_for_command( - unit, parse_array[i], dbver, ignore_data=True - ) - rv = None - else: - for sql in unit.sql: - rv = await be_conn.wait_for_command( - unit, parse_array[i], dbver, - ignore_data=False, - fe_conn=fe_conn, - ) + ignore_data = unit.output_format == FMT_NONE + rv = await be_conn.wait_for_command( + unit, + parse_array[i], + dbver, + ignore_data=ignore_data, + fe_conn=None if ignore_data else fe_conn, + ) return rv @@ -135,13 +130,11 @@ cpdef ExecutionGroup build_cache_persistence_units( assert serialized_result is not None if evict: - group.append(compiler.QueryUnit(sql=(evict,), status=b'')) + group.append(compiler.QueryUnit(sql=evict, status=b'')) if persist: - group.append(compiler.QueryUnit(sql=(persist,), status=b'')) + group.append(compiler.QueryUnit(sql=persist, status=b'')) group.append( - compiler.QueryUnit( - sql=(insert_sql,), sql_hash=sql_hash, status=b'', - ), + compiler.QueryUnit(sql=insert_sql, sql_hash=sql_hash, status=b''), args_ser.combine_raw_args(( query_unit.cache_key.bytes, query_unit.user_schema_version.bytes, @@ -276,9 +269,11 @@ async def execute( if query_unit.sql: if query_unit.user_schema: - ddl_ret = await be_conn.run_ddl(query_unit, state) - if ddl_ret and ddl_ret['new_types']: - new_types = ddl_ret['new_types'] + await be_conn.parse_execute(query=query_unit, state=state) + if query_unit.ddl_stmt_id is not None: + ddl_ret = be_conn.load_last_ddl_return(query_unit) + if ddl_ret and ddl_ret['new_types']: + new_types = ddl_ret['new_types'] else: bound_args_buf = args_ser.recode_bind_args( dbv, compiled, bind_args) @@ -519,35 +514,29 @@ async def execute_script( if query_unit.sql: parse = parse_array[idx] + fe_output = query_unit.output_format != FMT_NONE + ignore_data = ( + not fe_output + and not query_unit.needs_readback + ) + data = await conn.wait_for_command( + query_unit, + parse, + dbver, + ignore_data=ignore_data, + fe_conn=fe_conn if fe_output else None, + ) + if query_unit.ddl_stmt_id: - ddl_ret = await conn.handle_ddl_in_script( - query_unit, parse, dbver - ) + ddl_ret = conn.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] - elif query_unit.needs_readback: - config_data = [] - for sql in query_unit.sql: - config_data = await conn.wait_for_command( - query_unit, parse, dbver, ignore_data=False - ) - if config_data: - config_ops = [ - config.Operation.from_json(r[0][1:]) - for r in config_data - ] - elif query_unit.output_format == FMT_NONE: - for sql in query_unit.sql: - await conn.wait_for_command( - query_unit, parse, dbver, ignore_data=True - ) - else: - for sql in query_unit.sql: - data = await conn.wait_for_command( - query_unit, parse, dbver, - ignore_data=False, - fe_conn=fe_conn, - ) + + if query_unit.needs_readback and data: + config_ops = [ + config.Operation.from_json(r[0][1:]) + for r in data + ] if config_ops: await dbv.apply_config_ops(conn, config_ops) @@ -642,12 +631,7 @@ async def execute_system_config( await conn.sql_fetch(b'select 1', state=state) if query_unit.sql: - if len(query_unit.sql) > 1: - raise errors.InternalServerError( - "unexpected multiple SQL statements in CONFIGURE INSTANCE " - "compilation product" - ) - data = await conn.sql_fetch_col(query_unit.sql[0]) + data = await conn.sql_fetch_col(query_unit.sql) else: data = None diff --git a/edb/server/server.py b/edb/server/server.py index 
557fa39eaf0..59caa6311f9 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -1433,7 +1433,7 @@ async def _maybe_apply_patches( db_config = self._parse_db_config(config_json, user_schema) try: logger.info("repairing database '%s'", dbname) - sql += bootstrap.prepare_repair_patch( + rep_sql = bootstrap.prepare_repair_patch( self._std_schema, self._refl_schema, user_schema, @@ -1442,6 +1442,7 @@ async def _maybe_apply_patches( self._tenant.get_backend_runtime_params(), db_config, ) + sql += (rep_sql,) except errors.EdgeDBError as e: if isinstance(e, errors.InternalServerError): raise @@ -1454,7 +1455,7 @@ async def _maybe_apply_patches( ) from e if sql: - await conn.sql_fetch(sql) + await conn.sql_execute(sql) logger.info( "finished applying patch %d to database '%s'", num, dbname)
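
For reference, the notice-based "indirect return" handshake used by the DDL readback path above can be approximated in plain Python as follows. This is an illustrative sketch only, not code from the patch: the function names are made up, while the 'V'/'M'/'D' notice field codes, the "edb:notice:indirect_return" message prefix, and the ddl_stmt_id consistency check mirror the pgcon.pyx changes in this diff.

    import json
    from typing import Optional

    INDIRECT_RETURN_PREFIX = "edb:notice:indirect_return"

    def extract_indirect_return(fields: dict[str, str]) -> Optional[str]:
        # A NoticeResponse whose message starts with the indirect-return
        # prefix carries the DDL readback payload (JSON) in its detail
        # ('D') field; any other notice goes to the ordinary log listeners.
        if (
            fields.get("V") == "NOTICE"
            and fields.get("M", "").startswith(INDIRECT_RETURN_PREFIX)
        ):
            return fields.get("D")
        return None

    def check_ddl_return(
        last_indirect_return: Optional[str],
        ddl_stmt_id: str,
    ) -> dict:
        # As in load_last_ddl_return(): the readback must be present and
        # must echo the ddl_stmt_id of the query unit that produced it.
        if not last_indirect_return:
            raise RuntimeError(
                "missing the required data notice after a DDL command")
        ret = json.loads(last_indirect_return)
        if ret["ddl_stmt_id"] != ddl_stmt_id:
            raise RuntimeError(
                "unrecognized data notice after a DDL command: "
                "ddl_stmt_id does not match")
        return ret

In the actual connection code the first step happens while parsing incoming NoticeResponse messages (stashing the detail into last_indirect_return, which is reset in before_command), and the second step is what callers such as binary.pyx and execute.pyx invoke after parse_execute() to recover the new type-id mappings.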