From 1c16dbd7a47120262624e4f20ded5a7f7e943024 Mon Sep 17 00:00:00 2001
From: Elvis Pranskevichus
Date: Tue, 12 Nov 2024 13:19:00 -0800
Subject: [PATCH 01/16] Remove multiplicity from `QueryUnit.sql` (#7985)

Currently, `QueryUnit.sql` is a tuple representing, possibly, multiple
SQL statements corresponding to the original EdgeQL statement.  This
introduces significant complexity to consumers of `QueryUnit`, which are
mostly unprepared to handle more than one SQL statement anyway.

Originally the SQL tuple was used to represent multiple EdgeQL
statements, a task which is now handled by the `QueryUnitGroup`
machinery.  Another use case is the non-transactional commands
(`CREATE BRANCH` and friends).  Finally, since this facility was
available, more uses of it were added without actually _needing_ to be
executed as multiple SQL statements: those are mostly maintenance
commands and the DDL type_id readback.

I think the above uses are no longer a sufficient reason to keep the
tuple complexity, so I'm ripping it out here, making `QueryUnit.sql` a
`bytes` property with the invariant of _always_ containing exactly one
SQL statement and thus not needing any special handling.

The users are fixed up as follows:

- Non-transactional branch units encode the extra SQL needed to
  represent them in the new `QueryUnit.db_op_trailer` property, which
  is a tuple of SQL.  Branch commands have special handling for them
  already, so this is not a nuisance.

- The newly-added type id mappings produced by DDL commands are now
  communicated via the new "indirect return" mechanism, whereby a DDL
  PLBlock can communicate a return value via a specially-formatted
  `NoticeResponse` message.

- All other unwitting users of multi-statement `QueryUnit` are
  converted to use a single statement instead (primarily by converting
  them to use a `DO` block).
---
 edb/buildmeta.py                |   2 +-
 edb/edgeql/ast.py               |   7 +-
 edb/pgsql/dbops/base.py         |   4 +
 edb/pgsql/delta.py              |   7 +-
 edb/pgsql/metaschema.py         |  83 ++++-
 edb/server/bootstrap.py         |  18 +-
 edb/server/compiler/compiler.py | 552 +++++++++++++++++---------------
 edb/server/compiler/dbstate.py  | 157 ++++-----
 edb/server/compiler/ddl.py      | 133 ++++----
 edb/server/pgcluster.py         |   6 +-
 edb/server/pgcon/pgcon.pxd      |   2 +
 edb/server/pgcon/pgcon.pyi      |   2 +-
 edb/server/pgcon/pgcon.pyx      | 187 ++++++-----
 edb/server/protocol/binary.pyx  |   3 +-
 edb/server/protocol/execute.pyx |  90 +++---
 edb/server/server.py            |   5 +-
 16 files changed, 690 insertions(+), 568 deletions(-)

diff --git a/edb/buildmeta.py b/edb/buildmeta.py
index 07c602cd2d6..cfa5ff88228 100644
--- a/edb/buildmeta.py
+++ b/edb/buildmeta.py
@@ -60,7 +60,7 @@
 # The merge conflict there is a nice reminder that you probably need
 # to write a patch in edb/pgsql/patches.py, and then you should preserve
 # the old value.
-EDGEDB_CATALOG_VERSION = 2024_11_08_01_00 +EDGEDB_CATALOG_VERSION = 2024_11_12_00_00 EDGEDB_MAJOR_VERSION = 6 diff --git a/edb/edgeql/ast.py b/edb/edgeql/ast.py index eb74f85fe5d..ec919e1a6b1 100644 --- a/edb/edgeql/ast.py +++ b/edb/edgeql/ast.py @@ -674,6 +674,11 @@ class DDLCommand(Command, DDLOperation): __abstract_node__ = True +class NonTransactionalDDLCommand(DDLCommand): + __abstract_node__ = True + __rust_ignore__ = True + + class AlterAddInherit(DDLOperation): position: typing.Optional[Position] = None bases: typing.List[TypeName] @@ -868,7 +873,7 @@ class BranchType(s_enum.StrEnum): TEMPLATE = 'TEMPLATE' -class DatabaseCommand(ExternalObjectCommand): +class DatabaseCommand(ExternalObjectCommand, NonTransactionalDDLCommand): __abstract_node__ = True __rust_ignore__ = True diff --git a/edb/pgsql/dbops/base.py b/edb/pgsql/dbops/base.py index 9dde6621435..04b9051bce8 100644 --- a/edb/pgsql/dbops/base.py +++ b/edb/pgsql/dbops/base.py @@ -504,6 +504,10 @@ def __repr__(self) -> str: return f'' +class PLQuery(Query): + pass + + class DefaultMeta(type): def __bool__(cls): return False diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index 817054691ab..d5f22d51d4e 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -2953,7 +2953,7 @@ def _alter_finalize( return schema -def drop_dependant_func_cache(pg_type: Tuple[str, ...]) -> dbops.Query: +def drop_dependant_func_cache(pg_type: Tuple[str, ...]) -> dbops.PLQuery: if len(pg_type) == 1: types_cte = f''' SELECT @@ -2980,7 +2980,6 @@ def drop_dependant_func_cache(pg_type: Tuple[str, ...]) -> dbops.Query: )\ ''' drop_func_cache_sql = textwrap.dedent(f''' - DO $$ DECLARE qc RECORD; BEGIN @@ -3014,9 +3013,9 @@ class AS ( LOOP PERFORM edgedb_VER."_evict_query_cache"(qc.key); END LOOP; - END $$; + END; ''') - return dbops.Query(drop_func_cache_sql) + return dbops.PLQuery(drop_func_cache_sql) class DeleteScalarType(ScalarTypeMetaCommand, diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 0146429b28a..de712a72552 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -105,13 +105,13 @@ class PGConnection(Protocol): async def sql_execute( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, ) -> None: ... async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), ) -> list[tuple[bytes, ...]]: @@ -1497,6 +1497,83 @@ def __init__(self) -> None: ) +class RaiseNoticeFunction(trampoline.VersionedFunction): + text = ''' + BEGIN + RAISE NOTICE USING + MESSAGE = "msg", + DETAIL = COALESCE("detail", ''), + HINT = COALESCE("hint", ''), + COLUMN = COALESCE("column", ''), + CONSTRAINT = COALESCE("constraint", ''), + DATATYPE = COALESCE("datatype", ''), + TABLE = COALESCE("table", ''), + SCHEMA = COALESCE("schema", ''); + RETURN "rtype"; + END; + ''' + + def __init__(self) -> None: + super().__init__( + name=('edgedb', 'notice'), + args=[ + ('rtype', ('anyelement',)), + ('msg', ('text',), "''"), + ('detail', ('text',), "''"), + ('hint', ('text',), "''"), + ('column', ('text',), "''"), + ('constraint', ('text',), "''"), + ('datatype', ('text',), "''"), + ('table', ('text',), "''"), + ('schema', ('text',), "''"), + ], + returns=('anyelement',), + # NOTE: The main reason why we don't want this function to be + # immutable is that immutable functions can be + # pre-evaluated by the query planner once if they have + # constant arguments. 
This means that using this function + # as the second argument in a COALESCE will raise a + # notice regardless of whether the first argument is + # NULL or not. + volatility='stable', + language='plpgsql', + text=self.text, + ) + + +# edgedb.indirect_return() to be used to return values from +# anonymous code blocks or other contexts that have no return +# data channel. +class IndirectReturnFunction(trampoline.VersionedFunction): + text = """ + SELECT + edgedb_VER.notice( + NULL::text, + msg => 'edb:notice:indirect_return', + detail => "value" + ) + """ + + def __init__(self) -> None: + super().__init__( + name=('edgedb', 'indirect_return'), + args=[ + ('value', ('text',)), + ], + returns=('text',), + # NOTE: The main reason why we don't want this function to be + # immutable is that immutable functions can be + # pre-evaluated by the query planner once if they have + # constant arguments. This means that using this function + # as the second argument in a COALESCE will raise a + # notice regardless of whether the first argument is + # NULL or not. + volatility='stable', + language='sql', + text=self.text, + ) + + class RaiseExceptionFunction(trampoline.VersionedFunction): text = ''' BEGIN @@ -4980,6 +5057,8 @@ def get_bootstrap_commands( dbops.CreateFunction(GetSharedObjectMetadata()), dbops.CreateFunction(GetDatabaseMetadataFunction()), dbops.CreateFunction(GetCurrentDatabaseFunction()), + dbops.CreateFunction(RaiseNoticeFunction()), + dbops.CreateFunction(IndirectReturnFunction()), dbops.CreateFunction(RaiseExceptionFunction()), dbops.CreateFunction(RaiseExceptionOnNullFunction()), dbops.CreateFunction(RaiseExceptionOnNotNullFunction()), diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index c829549bd79..153a225251f 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -173,7 +173,7 @@ async def _retry_conn_errors( return result - async def sql_execute(self, sql: bytes | tuple[bytes, ...]) -> None: + async def sql_execute(self, sql: bytes) -> None: async def _task() -> None: assert self._conn is not None await self._conn.sql_execute(sql) @@ -181,7 +181,7 @@ async def _task() -> None: async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] 
| list[bytes] = (), ) -> list[tuple[bytes, ...]]: @@ -634,8 +634,8 @@ def compile_single_query( ) -> str: ql_source = edgeql.Source.from_string(eql) units = edbcompiler.compile(ctx=compilerctx, source=ql_source).units - assert len(units) == 1 and len(units[0].sql) == 1 - return units[0].sql[0].decode() + assert len(units) == 1 + return units[0].sql.decode() def _get_all_subcommands( @@ -687,7 +687,7 @@ def prepare_repair_patch( schema_class_layout: s_refl.SchemaClassLayout, backend_params: params.BackendRuntimeParams, config: Any, -) -> tuple[bytes, ...]: +) -> bytes: compiler = edbcompiler.new_compiler( std_schema=stdschema, reflection_schema=reflschema, @@ -701,7 +701,7 @@ def prepare_repair_patch( ) res = edbcompiler.repair_schema(compilerctx) if not res: - return () + return b"" sql, _, _ = res return sql @@ -2111,10 +2111,10 @@ def compile_sys_queries( ), source=edgeql.Source.from_string(report_configs_query), ).units - assert len(units) == 1 and len(units[0].sql) == 1 + assert len(units) == 1 report_configs_typedesc_2_0 = units[0].out_type_id + units[0].out_type_data - queries['report_configs'] = units[0].sql[0].decode() + queries['report_configs'] = units[0].sql.decode() units = edbcompiler.compile( ctx=edbcompiler.new_compiler_context( @@ -2128,7 +2128,7 @@ def compile_sys_queries( ), source=edgeql.Source.from_string(report_configs_query), ).units - assert len(units) == 1 and len(units[0].sql) == 1 + assert len(units) == 1 report_configs_typedesc_1_0 = units[0].out_type_id + units[0].out_type_data return ( diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 5813181485c..eb33544119e 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -429,7 +429,7 @@ def _try_compile_rollback( sql = b'ROLLBACK;' unit = dbstate.QueryUnit( status=b'ROLLBACK', - sql=(sql,), + sql=sql, tx_rollback=True, cacheable=False) @@ -437,7 +437,7 @@ def _try_compile_rollback( sql = f'ROLLBACK TO {pg_common.quote_ident(stmt.name)};'.encode() unit = dbstate.QueryUnit( status=b'ROLLBACK TO SAVEPOINT', - sql=(sql,), + sql=sql, tx_savepoint_rollback=True, sp_name=stmt.name, cacheable=False) @@ -1316,12 +1316,11 @@ def _compile_schema_storage_stmt( sql_stmts = [] for u in unit_group: - for stmt in u.sql: - stmt = stmt.strip() - if not stmt.endswith(b';'): - stmt += b';' + stmt = u.sql.strip() + if not stmt.endswith(b';'): + stmt += b';' - sql_stmts.append(stmt) + sql_stmts.append(stmt) if len(sql_stmts) > 1: raise errors.InternalServerError( @@ -1356,12 +1355,11 @@ def _compile_ql_script( sql_stmts = [] for u in unit_group: - for stmt in u.sql: - stmt = stmt.strip() - if not stmt.endswith(b';'): - stmt += b';' + stmt = u.sql.strip() + if not stmt.endswith(b';'): + stmt += b';' - sql_stmts.append(stmt) + sql_stmts.append(stmt) return b'\n'.join(sql_stmts).decode() @@ -1485,7 +1483,7 @@ def _compile_ql_explain( span=ql.span, ) - assert len(query.sql) == 1, query.sql + assert query.sql out_type_data, out_type_id = sertypes.describe( schema, @@ -1493,7 +1491,7 @@ def _compile_ql_explain( protocol_version=ctx.protocol_version, ) - sql_bytes = exp_command.encode('utf-8') + query.sql[0] + sql_bytes = exp_command.encode('utf-8') + query.sql sql_hash = _hash_sql( sql_bytes, mode=str(ctx.output_format).encode(), @@ -1503,9 +1501,9 @@ def _compile_ql_explain( return dataclasses.replace( query, is_explain=True, - append_rollback=args['execute'], + run_and_rollback=args['execute'], cacheable=False, - sql=(sql_bytes,), + sql=sql_bytes, sql_hash=sql_hash, 
cardinality=enums.Cardinality.ONE, out_type_data=out_type_data, @@ -1531,7 +1529,7 @@ def _compile_ql_administer( span=ql.expr.span, ) - return dbstate.MaintenanceQuery(sql=(b'ANALYZE',)) + return dbstate.MaintenanceQuery(sql=b'ANALYZE') elif ql.expr.func == 'schema_repair': return ddl.administer_repair_schema(ctx, ql) elif ql.expr.func == 'reindex': @@ -1692,7 +1690,7 @@ def _compile_ql_query( query_asts = None return dbstate.Query( - sql=(sql_bytes,), + sql=sql_bytes, sql_hash=sql_hash, cache_sql=cache_sql, cache_func_call=cache_func_call, @@ -1885,7 +1883,7 @@ def _compile_ql_transaction( if ql.deferrable is not None: sqls += f' {ql.deferrable.value}' sqls += ';' - sql = (sqls.encode(),) + sql = sqls.encode() action = dbstate.TxAction.START cacheable = False @@ -1901,7 +1899,7 @@ def _compile_ql_transaction( new_state = ctx.state.commit_tx() modaliases = new_state.modaliases - sql = (b'COMMIT',) + sql = b'COMMIT' cacheable = False action = dbstate.TxAction.COMMIT @@ -1909,7 +1907,7 @@ def _compile_ql_transaction( new_state = ctx.state.rollback_tx() modaliases = new_state.modaliases - sql = (b'ROLLBACK',) + sql = b'ROLLBACK' cacheable = False action = dbstate.TxAction.ROLLBACK @@ -1918,7 +1916,7 @@ def _compile_ql_transaction( sp_id = tx.declare_savepoint(ql.name) pgname = pg_common.quote_ident(ql.name) - sql = (f'SAVEPOINT {pgname}'.encode(),) + sql = f'SAVEPOINT {pgname}'.encode() cacheable = False action = dbstate.TxAction.DECLARE_SAVEPOINT @@ -1928,7 +1926,7 @@ def _compile_ql_transaction( elif isinstance(ql, qlast.ReleaseSavepoint): ctx.state.current_tx().release_savepoint(ql.name) pgname = pg_common.quote_ident(ql.name) - sql = (f'RELEASE SAVEPOINT {pgname}'.encode(),) + sql = f'RELEASE SAVEPOINT {pgname}'.encode() action = dbstate.TxAction.RELEASE_SAVEPOINT elif isinstance(ql, qlast.RollbackToSavepoint): @@ -1937,7 +1935,7 @@ def _compile_ql_transaction( modaliases = new_state.modaliases pgname = pg_common.quote_ident(ql.name) - sql = (f'ROLLBACK TO SAVEPOINT {pgname};'.encode(),) + sql = f'ROLLBACK TO SAVEPOINT {pgname};'.encode() cacheable = False action = dbstate.TxAction.ROLLBACK_TO_SAVEPOINT sp_name = ql.name @@ -1996,9 +1994,7 @@ def _compile_ql_sess_state( ctx.state.current_tx().update_modaliases(aliases) - return dbstate.SessionStateQuery( - sql=(), - ) + return dbstate.SessionStateQuery() def _get_config_spec( @@ -2170,9 +2166,7 @@ def _compile_ql_config_op( if pretty: debug.dump_code(sql_text, lexer='sql') - sql: tuple[bytes, ...] = ( - sql_text.encode(), - ) + sql = sql_text.encode() in_type_args, in_type_data, in_type_id = describe_params( ctx, ir, sql_res.argmap, None @@ -2363,7 +2357,6 @@ def _try_compile( if text.startswith(sentinel): time.sleep(float(text[len(sentinel):text.index("\n")])) - default_cardinality = enums.Cardinality.NO_RESULT statements = edgeql.parse_block(source) statements_len = len(statements) @@ -2399,15 +2392,6 @@ def _try_compile( _check_force_database_error(stmt_ctx, stmt) - # Initialize user_schema_version with the version this query is - # going to be compiled upon. This can be overwritten later by DDLs. 
- try: - schema_version = _get_schema_version( - stmt_ctx.state.current_tx().get_user_schema() - ) - except errors.InvalidReferenceError: - schema_version = None - comp, capabilities = _compile_dispatch_ql( stmt_ctx, stmt, @@ -2416,234 +2400,21 @@ def _try_compile( in_script=is_script, ) - unit = dbstate.QueryUnit( - sql=(), - status=status.get_status(stmt), - cardinality=default_cardinality, + unit, user_schema = _make_query_unit( + ctx=ctx, + stmt_ctx=stmt_ctx, + stmt=stmt, + is_script=is_script, + is_trailing_stmt=is_trailing_stmt, + comp=comp, capabilities=capabilities, - output_format=stmt_ctx.output_format, - cache_key=ctx.cache_key, - user_schema_version=schema_version, - warnings=comp.warnings, ) - if not comp.is_transactional: - if is_script: - raise errors.QueryError( - f'cannot execute {status.get_status(stmt).decode()} ' - f'with other commands in one block', - span=stmt.span, - ) - - if not ctx.state.current_tx().is_implicit(): - raise errors.QueryError( - f'cannot execute {status.get_status(stmt).decode()} ' - f'in a transaction', - span=stmt.span, - ) - - unit.is_transactional = False - - if isinstance(comp, dbstate.Query): - unit.sql = comp.sql - unit.cache_sql = comp.cache_sql - unit.cache_func_call = comp.cache_func_call - unit.globals = comp.globals - unit.in_type_args = comp.in_type_args - - unit.sql_hash = comp.sql_hash - - unit.out_type_data = comp.out_type_data - unit.out_type_id = comp.out_type_id - unit.in_type_data = comp.in_type_data - unit.in_type_id = comp.in_type_id - - unit.cacheable = comp.cacheable - - if comp.is_explain: - unit.is_explain = True - unit.query_asts = comp.query_asts - - if comp.append_rollback: - unit.append_rollback = True - - if is_trailing_stmt: - unit.cardinality = comp.cardinality - - elif isinstance(comp, dbstate.SimpleQuery): - unit.sql = comp.sql - unit.in_type_args = comp.in_type_args - - elif isinstance(comp, dbstate.DDLQuery): - unit.sql = comp.sql - unit.create_db = comp.create_db - unit.drop_db = comp.drop_db - unit.drop_db_reset_connections = comp.drop_db_reset_connections - unit.create_db_template = comp.create_db_template - unit.create_db_mode = comp.create_db_mode - unit.ddl_stmt_id = comp.ddl_stmt_id - if not ctx.dump_restore_mode: - if comp.user_schema is not None: - final_user_schema = comp.user_schema - unit.user_schema = pickle.dumps(comp.user_schema, -1) - unit.user_schema_version = ( - _get_schema_version(comp.user_schema) - ) - unit.extensions, unit.ext_config_settings = ( - _extract_extensions(ctx, comp.user_schema) - ) - unit.feature_used_metrics = comp.feature_used_metrics - if comp.cached_reflection is not None: - unit.cached_reflection = \ - pickle.dumps(comp.cached_reflection, -1) - if comp.global_schema is not None: - unit.global_schema = pickle.dumps(comp.global_schema, -1) - unit.roles = _extract_roles(comp.global_schema) - - unit.config_ops.extend(comp.config_ops) - - elif isinstance(comp, dbstate.TxControlQuery): - if is_script: - raise errors.QueryError( - "Explicit transaction control commands cannot be executed " - "in an implicit transaction block" - ) - unit.sql = comp.sql - unit.cacheable = comp.cacheable - - if not ctx.dump_restore_mode: - if comp.user_schema is not None: - final_user_schema = comp.user_schema - unit.user_schema = pickle.dumps(comp.user_schema, -1) - unit.user_schema_version = ( - _get_schema_version(comp.user_schema) - ) - unit.extensions, unit.ext_config_settings = ( - _extract_extensions(ctx, comp.user_schema) - ) - unit.feature_used_metrics = comp.feature_used_metrics - if 
comp.cached_reflection is not None: - unit.cached_reflection = \ - pickle.dumps(comp.cached_reflection, -1) - if comp.global_schema is not None: - unit.global_schema = pickle.dumps(comp.global_schema, -1) - unit.roles = _extract_roles(comp.global_schema) - - if comp.modaliases is not None: - unit.modaliases = comp.modaliases - - if comp.action == dbstate.TxAction.START: - if unit.tx_id is not None: - raise errors.InternalServerError( - 'already in transaction') - unit.tx_id = ctx.state.current_tx().id - elif comp.action == dbstate.TxAction.COMMIT: - unit.tx_commit = True - elif comp.action == dbstate.TxAction.ROLLBACK: - unit.tx_rollback = True - elif comp.action is dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: - unit.tx_savepoint_rollback = True - unit.sp_name = comp.sp_name - elif comp.action is dbstate.TxAction.DECLARE_SAVEPOINT: - unit.tx_savepoint_declare = True - unit.sp_name = comp.sp_name - unit.sp_id = comp.sp_id - - elif isinstance(comp, dbstate.MigrationControlQuery): - unit.sql = comp.sql - unit.cacheable = comp.cacheable - - if not ctx.dump_restore_mode: - if comp.user_schema is not None: - final_user_schema = comp.user_schema - unit.user_schema = pickle.dumps(comp.user_schema, -1) - unit.user_schema_version = ( - _get_schema_version(comp.user_schema) - ) - unit.extensions, unit.ext_config_settings = ( - _extract_extensions(ctx, comp.user_schema) - ) - if comp.cached_reflection is not None: - unit.cached_reflection = \ - pickle.dumps(comp.cached_reflection, -1) - unit.ddl_stmt_id = comp.ddl_stmt_id - - if comp.modaliases is not None: - unit.modaliases = comp.modaliases - - if comp.tx_action == dbstate.TxAction.START: - if unit.tx_id is not None: - raise errors.InternalServerError( - 'already in transaction') - unit.tx_id = ctx.state.current_tx().id - elif comp.tx_action == dbstate.TxAction.COMMIT: - unit.tx_commit = True - elif comp.tx_action == dbstate.TxAction.ROLLBACK: - unit.tx_rollback = True - elif comp.action == dbstate.MigrationAction.ABORT: - unit.tx_abort_migration = True - - elif isinstance(comp, dbstate.SessionStateQuery): - unit.sql = comp.sql - unit.globals = comp.globals - - if comp.config_scope is qltypes.ConfigScope.INSTANCE: - if (not ctx.state.current_tx().is_implicit() or - statements_len > 1): - raise errors.QueryError( - 'CONFIGURE INSTANCE cannot be executed in a ' - 'transaction block') - - unit.system_config = True - elif comp.config_scope is qltypes.ConfigScope.GLOBAL: - unit.needs_readback = True - - elif comp.config_scope is qltypes.ConfigScope.DATABASE: - unit.database_config = True - unit.needs_readback = True - - if comp.is_backend_setting: - unit.backend_config = True - if comp.requires_restart: - unit.config_requires_restart = True - if comp.is_system_config: - unit.is_system_config = True - - unit.modaliases = ctx.state.current_tx().get_modaliases() - - if comp.config_op is not None: - unit.config_ops.append(comp.config_op) - - if comp.in_type_args: - unit.in_type_args = comp.in_type_args - if comp.in_type_data: - unit.in_type_data = comp.in_type_data - if comp.in_type_id: - unit.in_type_id = comp.in_type_id - - unit.has_set = True - - elif isinstance(comp, dbstate.MaintenanceQuery): - unit.sql = comp.sql - - elif isinstance(comp, dbstate.NullQuery): - pass - - else: # pragma: no cover - raise errors.InternalServerError('unknown compile state') - - if unit.in_type_args: - unit.in_type_args_real_count = sum( - len(p.sub_params[0]) if p.sub_params else 1 - for p in unit.in_type_args - ) - - if unit.warnings: - for warning in unit.warnings: - 
warning.__traceback__ = None - rv.append(unit) + if user_schema is not None: + final_user_schema = user_schema + if script_info: if ctx.state.current_tx().is_implicit(): if ctx.state.current_tx().get_migration_state(): @@ -2690,7 +2461,6 @@ def _try_compile( f'QueryUnit {unit!r} is cacheable but has config/aliases') if not na_cardinality and ( - len(unit.sql) > 1 or unit.tx_commit or unit.tx_rollback or unit.tx_savepoint_rollback or @@ -2715,6 +2485,260 @@ def _try_compile( return rv +def _make_query_unit( + *, + ctx: CompileContext, + stmt_ctx: CompileContext, + stmt: qlast.Base, + is_script: bool, + is_trailing_stmt: bool, + comp: dbstate.BaseQuery, + capabilities: enums.Capability, +) -> tuple[dbstate.QueryUnit, Optional[s_schema.Schema]]: + + # Initialize user_schema_version with the version this query is + # going to be compiled upon. This can be overwritten later by DDLs. + try: + schema_version = _get_schema_version( + stmt_ctx.state.current_tx().get_user_schema() + ) + except errors.InvalidReferenceError: + schema_version = None + + unit = dbstate.QueryUnit( + sql=b"", + status=status.get_status(stmt), + cardinality=enums.Cardinality.NO_RESULT, + capabilities=capabilities, + output_format=stmt_ctx.output_format, + cache_key=ctx.cache_key, + user_schema_version=schema_version, + warnings=comp.warnings, + ) + + if not comp.is_transactional: + if is_script: + raise errors.QueryError( + f'cannot execute {status.get_status(stmt).decode()} ' + f'with other commands in one block', + span=stmt.span, + ) + + if not ctx.state.current_tx().is_implicit(): + raise errors.QueryError( + f'cannot execute {status.get_status(stmt).decode()} ' + f'in a transaction', + span=stmt.span, + ) + + unit.is_transactional = False + + final_user_schema: Optional[s_schema.Schema] = None + + if isinstance(comp, dbstate.Query): + unit.sql = comp.sql + unit.cache_sql = comp.cache_sql + unit.cache_func_call = comp.cache_func_call + unit.globals = comp.globals + unit.in_type_args = comp.in_type_args + + unit.sql_hash = comp.sql_hash + + unit.out_type_data = comp.out_type_data + unit.out_type_id = comp.out_type_id + unit.in_type_data = comp.in_type_data + unit.in_type_id = comp.in_type_id + + unit.cacheable = comp.cacheable + + if comp.is_explain: + unit.is_explain = True + unit.query_asts = comp.query_asts + + if comp.run_and_rollback: + unit.run_and_rollback = True + + if is_trailing_stmt: + unit.cardinality = comp.cardinality + + elif isinstance(comp, dbstate.SimpleQuery): + unit.sql = comp.sql + unit.in_type_args = comp.in_type_args + + elif isinstance(comp, dbstate.DDLQuery): + unit.sql = comp.sql + unit.db_op_trailer = comp.db_op_trailer + unit.create_db = comp.create_db + unit.drop_db = comp.drop_db + unit.drop_db_reset_connections = comp.drop_db_reset_connections + unit.create_db_template = comp.create_db_template + unit.create_db_mode = comp.create_db_mode + unit.ddl_stmt_id = comp.ddl_stmt_id + if not ctx.dump_restore_mode: + if comp.user_schema is not None: + final_user_schema = comp.user_schema + unit.user_schema = pickle.dumps(comp.user_schema, -1) + unit.user_schema_version = ( + _get_schema_version(comp.user_schema) + ) + unit.extensions, unit.ext_config_settings = ( + _extract_extensions(ctx, comp.user_schema) + ) + unit.feature_used_metrics = comp.feature_used_metrics + if comp.cached_reflection is not None: + unit.cached_reflection = \ + pickle.dumps(comp.cached_reflection, -1) + if comp.global_schema is not None: + unit.global_schema = pickle.dumps(comp.global_schema, -1) + unit.roles = 
_extract_roles(comp.global_schema) + + unit.config_ops.extend(comp.config_ops) + + elif isinstance(comp, dbstate.TxControlQuery): + if is_script: + raise errors.QueryError( + "Explicit transaction control commands cannot be executed " + "in an implicit transaction block" + ) + unit.sql = comp.sql + unit.cacheable = comp.cacheable + + if not ctx.dump_restore_mode: + if comp.user_schema is not None: + final_user_schema = comp.user_schema + unit.user_schema = pickle.dumps(comp.user_schema, -1) + unit.user_schema_version = ( + _get_schema_version(comp.user_schema) + ) + unit.extensions, unit.ext_config_settings = ( + _extract_extensions(ctx, comp.user_schema) + ) + unit.feature_used_metrics = comp.feature_used_metrics + if comp.cached_reflection is not None: + unit.cached_reflection = \ + pickle.dumps(comp.cached_reflection, -1) + if comp.global_schema is not None: + unit.global_schema = pickle.dumps(comp.global_schema, -1) + unit.roles = _extract_roles(comp.global_schema) + + if comp.modaliases is not None: + unit.modaliases = comp.modaliases + + if comp.action == dbstate.TxAction.START: + if unit.tx_id is not None: + raise errors.InternalServerError( + 'already in transaction') + unit.tx_id = ctx.state.current_tx().id + elif comp.action == dbstate.TxAction.COMMIT: + unit.tx_commit = True + elif comp.action == dbstate.TxAction.ROLLBACK: + unit.tx_rollback = True + elif comp.action is dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: + unit.tx_savepoint_rollback = True + unit.sp_name = comp.sp_name + elif comp.action is dbstate.TxAction.DECLARE_SAVEPOINT: + unit.tx_savepoint_declare = True + unit.sp_name = comp.sp_name + unit.sp_id = comp.sp_id + + elif isinstance(comp, dbstate.MigrationControlQuery): + unit.sql = comp.sql + unit.cacheable = comp.cacheable + + if not ctx.dump_restore_mode: + if comp.user_schema is not None: + final_user_schema = comp.user_schema + unit.user_schema = pickle.dumps(comp.user_schema, -1) + unit.user_schema_version = ( + _get_schema_version(comp.user_schema) + ) + unit.extensions, unit.ext_config_settings = ( + _extract_extensions(ctx, comp.user_schema) + ) + if comp.cached_reflection is not None: + unit.cached_reflection = \ + pickle.dumps(comp.cached_reflection, -1) + unit.ddl_stmt_id = comp.ddl_stmt_id + + if comp.modaliases is not None: + unit.modaliases = comp.modaliases + + if comp.tx_action == dbstate.TxAction.START: + if unit.tx_id is not None: + raise errors.InternalServerError( + 'already in transaction') + unit.tx_id = ctx.state.current_tx().id + elif comp.tx_action == dbstate.TxAction.COMMIT: + unit.tx_commit = True + unit.append_tx_op = True + elif comp.tx_action == dbstate.TxAction.ROLLBACK: + unit.tx_rollback = True + unit.append_tx_op = True + elif comp.action == dbstate.MigrationAction.ABORT: + unit.tx_abort_migration = True + + elif isinstance(comp, dbstate.SessionStateQuery): + unit.sql = comp.sql + unit.globals = comp.globals + + if comp.config_scope is qltypes.ConfigScope.INSTANCE: + if not ctx.state.current_tx().is_implicit() or is_script: + raise errors.QueryError( + 'CONFIGURE INSTANCE cannot be executed in a ' + 'transaction block') + + unit.system_config = True + elif comp.config_scope is qltypes.ConfigScope.GLOBAL: + unit.needs_readback = True + + elif comp.config_scope is qltypes.ConfigScope.DATABASE: + unit.database_config = True + unit.needs_readback = True + + if comp.is_backend_setting: + unit.backend_config = True + if comp.requires_restart: + unit.config_requires_restart = True + if comp.is_system_config: + unit.is_system_config = True + + 
unit.modaliases = ctx.state.current_tx().get_modaliases() + + if comp.config_op is not None: + unit.config_ops.append(comp.config_op) + + if comp.in_type_args: + unit.in_type_args = comp.in_type_args + if comp.in_type_data: + unit.in_type_data = comp.in_type_data + if comp.in_type_id: + unit.in_type_id = comp.in_type_id + + unit.has_set = True + unit.output_format = enums.OutputFormat.NONE + + elif isinstance(comp, dbstate.MaintenanceQuery): + unit.sql = comp.sql + + elif isinstance(comp, dbstate.NullQuery): + pass + + else: # pragma: no cover + raise errors.InternalServerError('unknown compile state') + + if unit.in_type_args: + unit.in_type_args_real_count = sum( + len(p.sub_params[0]) if p.sub_params else 1 + for p in unit.in_type_args + ) + + if unit.warnings: + for warning in unit.warnings: + warning.__traceback__ = None + + return unit, final_user_schema + + def _extract_params( params: List[irast.Param], *, diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 73a1f753b70..ea44756af78 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -60,7 +60,6 @@ class TxAction(enum.IntEnum): - START = 1 COMMIT = 2 ROLLBACK = 3 @@ -71,7 +70,6 @@ class TxAction(enum.IntEnum): class MigrationAction(enum.IntEnum): - START = 1 POPULATE = 2 DESCRIBE = 3 @@ -80,10 +78,11 @@ class MigrationAction(enum.IntEnum): REJECT_PROPOSED = 6 -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class BaseQuery: - - sql: Tuple[bytes, ...] + sql: bytes + is_transactional: bool = True + has_dml: bool = False cache_sql: Optional[Tuple[bytes, bytes]] = dataclasses.field( kw_only=True, default=None ) # (persist, evict) @@ -94,22 +93,14 @@ class BaseQuery: kw_only=True, default=() ) - @property - def is_transactional(self) -> bool: - return True - -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class NullQuery(BaseQuery): - - sql: Tuple[bytes, ...] = tuple() - is_transactional: bool = True - has_dml: bool = False + sql: bytes = b"" -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class Query(BaseQuery): - sql_hash: bytes cardinality: enums.Cardinality @@ -122,27 +113,21 @@ class Query(BaseQuery): globals: Optional[list[tuple[str, bool]]] = None - is_transactional: bool = True - has_dml: bool = False cacheable: bool = True is_explain: bool = False query_asts: Any = None - append_rollback: bool = False + run_and_rollback: bool = False -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class SimpleQuery(BaseQuery): - - sql: Tuple[bytes, ...] 
- is_transactional: bool = True - has_dml: bool = False # XXX: Temporary hack, since SimpleQuery will die in_type_args: Optional[List[Param]] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class SessionStateQuery(BaseQuery): - + sql: bytes = b"" config_scope: Optional[qltypes.ConfigScope] = None is_backend_setting: bool = False requires_restart: bool = False @@ -156,9 +141,8 @@ class SessionStateQuery(BaseQuery): in_type_args: Optional[List[Param]] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class DDLQuery(BaseQuery): - user_schema: Optional[s_schema.FlatSchema] feature_used_metrics: Optional[dict[str, float]] global_schema: Optional[s_schema.FlatSchema] = None @@ -169,18 +153,17 @@ class DDLQuery(BaseQuery): drop_db_reset_connections: bool = False create_db_template: Optional[str] = None create_db_mode: Optional[qlast.BranchType] = None + db_op_trailer: tuple[bytes, ...] = () ddl_stmt_id: Optional[str] = None config_ops: List[config.Operation] = dataclasses.field(default_factory=list) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class TxControlQuery(BaseQuery): - action: TxAction cacheable: bool modaliases: Optional[immutables.Map[Optional[str], str]] - is_transactional: bool = True user_schema: Optional[s_schema.Schema] = None global_schema: Optional[s_schema.Schema] = None @@ -191,25 +174,22 @@ class TxControlQuery(BaseQuery): sp_id: Optional[int] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class MigrationControlQuery(BaseQuery): - action: MigrationAction tx_action: Optional[TxAction] cacheable: bool modaliases: Optional[immutables.Map[Optional[str], str]] - is_transactional: bool = True user_schema: Optional[s_schema.FlatSchema] = None cached_reflection: Any = None ddl_stmt_id: Optional[str] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class MaintenanceQuery(BaseQuery): - - is_transactional: bool = True + pass @dataclasses.dataclass(frozen=True) @@ -224,10 +204,9 @@ class Param: ############################# -@dataclasses.dataclass +@dataclasses.dataclass(kw_only=True) class QueryUnit: - - sql: Tuple[bytes, ...] + sql: bytes # Status-line for the compiled command; returned to front-end # in a CommandComplete protocol message if the command is @@ -244,9 +223,9 @@ class QueryUnit: # Set only for units that contain queries that can be cached # as prepared statements in Postgres. - sql_hash: bytes = b'' + sql_hash: bytes = b"" - # True if all statments in *sql* can be executed inside a transaction. + # True if all statements in *sql* can be executed inside a transaction. # If False, they will be executed separately. is_transactional: bool = True @@ -298,6 +277,10 @@ class QueryUnit: create_db_template: Optional[str] = None create_db_mode: Optional[str] = None + # If a branch command needs extra SQL commands to be performed, + # those would end up here. + db_op_trailer: tuple[bytes, ...] = () + # If non-None, the DDL statement will emit data packets marked # with the indicated ID. 
ddl_stmt_id: Optional[str] = None @@ -357,7 +340,8 @@ class QueryUnit: is_explain: bool = False query_asts: Any = None - append_rollback: bool = False + run_and_rollback: bool = False + append_tx_op: bool = False @property def has_ddl(self) -> bool: @@ -390,13 +374,12 @@ def deserialize(cls, data: bytes) -> Self: def maybe_use_func_cache(self) -> None: if self.cache_func_call is not None: sql, sql_hash = self.cache_func_call - self.sql = (sql,) + self.sql = sql self.sql_hash = sql_hash @dataclasses.dataclass class QueryUnitGroup: - # All capabilities used by any query units in this group capabilities: enums.Capability = enums.Capability(0) @@ -568,29 +551,29 @@ class SQLQueryUnit: class CommandCompleteTag: - '''Dictates the tag of CommandComplete message that concludes this query.''' + """Dictates the tag of CommandComplete message that concludes this query.""" @dataclasses.dataclass(kw_only=True) class TagPlain(CommandCompleteTag): - '''Set the tag verbatim''' + """Set the tag verbatim""" tag: bytes @dataclasses.dataclass(kw_only=True) class TagCountMessages(CommandCompleteTag): - '''Count DataRow messages in the response and set the tag to - f'{prefix} {count_of_messages}'.''' + """Count DataRow messages in the response and set the tag to + f'{prefix} {count_of_messages}'.""" prefix: str @dataclasses.dataclass(kw_only=True) class TagUnpackRow(CommandCompleteTag): - '''Intercept a single DataRow message with a single column which represents + """Intercept a single DataRow message with a single column which represents the number of modified rows. - Sets the CommandComplete tag to f'{prefix} {modified_rows}'.''' + Sets the CommandComplete tag to f'{prefix} {modified_rows}'.""" prefix: str @@ -634,13 +617,15 @@ class ParsedDatabase: SQLSetting = tuple[str | int | float, ...] SQLSettings = immutables.Map[Optional[str], Optional[SQLSetting]] DEFAULT_SQL_SETTINGS: SQLSettings = immutables.Map() -DEFAULT_SQL_FE_SETTINGS: SQLSettings = immutables.Map({ - "search_path": ("public",), - "server_version": cast(SQLSetting, (defines.PGEXT_POSTGRES_VERSION,)), - "server_version_num": cast( - SQLSetting, (defines.PGEXT_POSTGRES_VERSION_NUM,) - ), -}) +DEFAULT_SQL_FE_SETTINGS: SQLSettings = immutables.Map( + { + "search_path": ("public",), + "server_version": cast(SQLSetting, (defines.PGEXT_POSTGRES_VERSION,)), + "server_version_num": cast( + SQLSetting, (defines.PGEXT_POSTGRES_VERSION_NUM,) + ), + } +) @dataclasses.dataclass @@ -680,11 +665,16 @@ def apply(self, query_unit: SQLQueryUnit) -> None: self.in_tx_local_settings = None self.savepoints.clear() elif query_unit.tx_action == TxAction.DECLARE_SAVEPOINT: - self.savepoints.append(( - query_unit.sp_name, - self.in_tx_settings, - self.in_tx_local_settings, - )) # type: ignore + assert query_unit.sp_name is not None + assert self.in_tx_settings is not None + assert self.in_tx_local_settings is not None + self.savepoints.append( + ( + query_unit.sp_name, + self.in_tx_settings, + self.in_tx_local_settings, + ) + ) elif query_unit.tx_action == TxAction.ROLLBACK_TO_SAVEPOINT: while self.savepoints: sp_name, settings, local_settings = self.savepoints[-1] @@ -735,7 +725,6 @@ def _set(attr_name: str) -> None: class ProposedMigrationStep(NamedTuple): - statements: Tuple[str, ...] 
confidence: float prompt: str @@ -748,17 +737,16 @@ class ProposedMigrationStep(NamedTuple): def to_json(self) -> Dict[str, Any]: return { - 'statements': [{'text': stmt} for stmt in self.statements], - 'confidence': self.confidence, - 'prompt': self.prompt, - 'prompt_id': self.prompt_id, - 'data_safe': self.data_safe, - 'required_user_input': list(self.required_user_input), + "statements": [{"text": stmt} for stmt in self.statements], + "confidence": self.confidence, + "prompt": self.prompt, + "prompt_id": self.prompt_id, + "data_safe": self.data_safe, + "required_user_input": list(self.required_user_input), } class MigrationState(NamedTuple): - parent_migration: Optional[s_migrations.Migration] initial_schema: s_schema.Schema initial_savepoint: Optional[str] @@ -769,14 +757,12 @@ class MigrationState(NamedTuple): class MigrationRewriteState(NamedTuple): - initial_savepoint: Optional[str] target_schema: s_schema.Schema accepted_migrations: Tuple[qlast.CreateMigration, ...] class TransactionState(NamedTuple): - id: int name: Optional[str] local_user_schema: s_schema.FlatSchema | None @@ -799,7 +785,6 @@ def user_schema(self) -> s_schema.FlatSchema: class Transaction: - _savepoints: Dict[int, TransactionState] _constate: CompilerConnectionState @@ -816,7 +801,6 @@ def __init__( cached_reflection: immutables.Map[str, Tuple[str, ...]], implicit: bool = True, ) -> None: - assert not isinstance(user_schema, s_schema.ChainedSchema) self._constate = constate @@ -857,12 +841,12 @@ def make_explicit(self) -> None: if self._implicit: self._implicit = False else: - raise errors.TransactionError('already in explicit transaction') + raise errors.TransactionError("already in explicit transaction") def declare_savepoint(self, name: str) -> int: if self.is_implicit(): raise errors.TransactionError( - 'savepoints can only be used in transaction blocks' + "savepoints can only be used in transaction blocks" ) return self._declare_savepoint(name) @@ -882,7 +866,7 @@ def _declare_savepoint(self, name: str) -> int: def rollback_to_savepoint(self, name: str) -> TransactionState: if self.is_implicit(): raise errors.TransactionError( - 'savepoints can only be used in transaction blocks' + "savepoints can only be used in transaction blocks" ) return self._rollback_to_savepoint(name) @@ -899,7 +883,7 @@ def _rollback_to_savepoint(self, name: str) -> TransactionState: sp_ids_to_erase.append(sp.id) else: - raise errors.TransactionError(f'there is no {name!r} savepoint') + raise errors.TransactionError(f"there is no {name!r} savepoint") for sp_id in sp_ids_to_erase: self._savepoints.pop(sp_id) @@ -909,7 +893,7 @@ def _rollback_to_savepoint(self, name: str) -> TransactionState: def release_savepoint(self, name: str) -> None: if self.is_implicit(): raise errors.TransactionError( - 'savepoints can only be used in transaction blocks' + "savepoints can only be used in transaction blocks" ) self._release_savepoint(name) @@ -925,7 +909,7 @@ def _release_savepoint(self, name: str) -> None: if sp.name == name: break else: - raise errors.TransactionError(f'there is no {name!r} savepoint') + raise errors.TransactionError(f"there is no {name!r} savepoint") for sp_id in sp_ids_to_erase: self._savepoints.pop(sp_id) @@ -1030,8 +1014,7 @@ def update_migration_rewrite_state( class CompilerConnectionState: - - __slots__ = ('_savepoints_log', '_current_tx', '_tx_count', '_user_schema') + __slots__ = ("_savepoints_log", "_current_tx", "_tx_count", "_user_schema") _savepoints_log: Dict[int, TransactionState] _user_schema: 
Optional[s_schema.FlatSchema] @@ -1111,7 +1094,7 @@ def sync_to_savepoint(self, spid: int) -> None: """Synchronize the compiler state with the current DB state.""" if not self.can_sync_to_savepoint(spid): - raise RuntimeError(f'failed to lookup savepoint with id={spid}') + raise RuntimeError(f"failed to lookup savepoint with id={spid}") sp = self._savepoints_log[spid] self._current_tx = sp.tx @@ -1137,7 +1120,7 @@ def start_tx(self) -> None: if self._current_tx.is_implicit(): self._current_tx.make_explicit() else: - raise errors.TransactionError('already in transaction') + raise errors.TransactionError("already in transaction") def rollback_tx(self) -> TransactionState: # Note that we might not be in a transaction as we allow @@ -1159,7 +1142,7 @@ def rollback_tx(self) -> TransactionState: def commit_tx(self) -> TransactionState: if self._current_tx.is_implicit(): - raise errors.TransactionError('cannot commit: not in transaction') + raise errors.TransactionError("cannot commit: not in transaction") latest_state = self._current_tx._current @@ -1184,5 +1167,5 @@ def sync_tx(self, txid: int) -> None: return raise errors.InternalServerError( - f'failed to lookup transaction or savepoint with id={txid}' + f"failed to lookup transaction or savepoint with id={txid}" ) # pragma: no cover diff --git a/edb/server/compiler/ddl.py b/edb/server/compiler/ddl.py index 9754e9eeeb8..fe1dd279839 100644 --- a/edb/server/compiler/ddl.py +++ b/edb/server/compiler/ddl.py @@ -64,17 +64,28 @@ from edb.pgsql import common as pg_common from edb.pgsql import delta as pg_delta from edb.pgsql import dbops as pg_dbops -from edb.pgsql import trampoline from . import dbstate from . import compiler +NIL_QUERY = b"SELECT LIMIT 0" + + def compile_and_apply_ddl_stmt( ctx: compiler.CompileContext, - stmt: qlast.DDLOperation, + stmt: qlast.DDLCommand, source: Optional[edgeql.Source] = None, ) -> dbstate.DDLQuery: + query, _ = _compile_and_apply_ddl_stmt(ctx, stmt, source) + return query + + +def _compile_and_apply_ddl_stmt( + ctx: compiler.CompileContext, + stmt: qlast.DDLCommand, + source: Optional[edgeql.Source] = None, +) -> tuple[dbstate.DDLQuery, Optional[pg_dbops.SQLBlock]]: if isinstance(stmt, qlast.GlobalObjectCommand): ctx._assert_not_in_migration_block(stmt) @@ -127,7 +138,7 @@ def compile_and_apply_ddl_stmt( ) ], ) - return compile_and_apply_ddl_stmt(ctx, cm) + return _compile_and_apply_ddl_stmt(ctx, cm) assert isinstance(stmt, qlast.DDLCommand) new_schema, delta = s_ddl.delta_and_schema_from_ddl( @@ -175,14 +186,16 @@ def compile_and_apply_ddl_stmt( current_tx.update_migration_state(mstate) current_tx.update_schema(new_schema) - return dbstate.DDLQuery( - sql=(b'SELECT LIMIT 0',), + query = dbstate.DDLQuery( + sql=NIL_QUERY, user_schema=current_tx.get_user_schema(), is_transactional=True, warnings=tuple(delta.warnings), feature_used_metrics=None, ) + return query, None + store_migration_sdl = compiler._get_config_val(ctx, 'store_migration_sdl') if ( isinstance(stmt, qlast.CreateMigration) @@ -207,26 +220,30 @@ def compile_and_apply_ddl_stmt( current_tx.update_schema(new_schema) - return dbstate.DDLQuery( - sql=(b'SELECT LIMIT 0',), + query = dbstate.DDLQuery( + sql=NIL_QUERY, user_schema=current_tx.get_user_schema(), is_transactional=True, warnings=tuple(delta.warnings), feature_used_metrics=None, ) + return query, None + # Apply and adapt delta, build native delta plan, which # will also update the schema. 
block, new_types, config_ops = _process_delta(ctx, delta) ddl_stmt_id: Optional[str] = None - is_transactional = block.is_transactional() if not is_transactional: - sql = tuple(stmt.encode('utf-8') for stmt in block.get_statements()) + if not isinstance(stmt, qlast.DatabaseCommand): + raise AssertionError( + f"unexpected non-transaction DDL command type: {stmt}") + sql_stmts = block.get_statements() + sql = sql_stmts[0].encode("utf-8") + db_op_trailer = tuple(stmt.encode("utf-8") for stmt in sql_stmts[1:]) else: - sql = (block.to_string().encode('utf-8'),) - if new_types: # Inject a query returning backend OIDs for the newly # created types. @@ -234,10 +251,10 @@ def compile_and_apply_ddl_stmt( new_type_ids = [ f'{pg_common.quote_literal(tid)}::uuid' for tid in new_types ] - sql = sql + ( - trampoline.fixup_query(textwrap.dedent( - f'''\ - SELECT + # Return newly-added type id mapping via the indirect + # return channel (see PGConnection.last_indirect_return) + new_types_sql = textwrap.dedent(f"""\ + PERFORM edgedb.indirect_return( json_build_object( 'ddl_stmt_id', {pg_common.quote_literal(ddl_stmt_id)}, @@ -254,11 +271,15 @@ def compile_and_apply_ddl_stmt( {', '.join(new_type_ids)} ]) ) - )::text; - ''' - )).encode('utf-8'), + )::text + )""" ) + block.add_command(pg_dbops.Query(text=new_types_sql).code()) + + sql = block.to_string().encode('utf-8') + db_op_trailer = () + create_db = None drop_db = None drop_db_reset_connections = False @@ -286,10 +307,10 @@ def compile_and_apply_ddl_stmt( debug.dump_code(code, lexer='sql') if debug.flags.delta_execute: debug.header('Delta Script') - debug.dump_code(b'\n'.join(sql), lexer='sql') + debug.dump_code(sql + b"\n".join(db_op_trailer), lexer='sql') new_user_schema = current_tx.get_user_schema_if_updated() - return dbstate.DDLQuery( + query = dbstate.DDLQuery( sql=sql, is_transactional=is_transactional, create_db=create_db, @@ -297,6 +318,7 @@ def compile_and_apply_ddl_stmt( drop_db_reset_connections=drop_db_reset_connections, create_db_template=create_db_template, create_db_mode=create_db_mode, + db_op_trailer=db_op_trailer, ddl_stmt_id=ddl_stmt_id, user_schema=new_user_schema, cached_reflection=current_tx.get_cached_reflection_if_updated(), @@ -309,6 +331,8 @@ def compile_and_apply_ddl_stmt( ), ) + return query, block + def _new_delta_context( ctx: compiler.CompileContext, args: Any = None @@ -464,7 +488,7 @@ def _start_migration( else: savepoint_name = current_tx.start_migration() query = dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, action=dbstate.MigrationAction.START, tx_action=None, cacheable=False, @@ -573,7 +597,7 @@ def _populate_migration( current_tx.update_schema(schema) return dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, tx_action=None, action=dbstate.MigrationAction.POPULATE, cacheable=False, @@ -801,7 +825,7 @@ def _alter_current_migration_reject_proposed( current_tx.update_migration_state(mstate) return dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, tx_action=None, action=dbstate.MigrationAction.REJECT_PROPOSED, cacheable=False, @@ -876,7 +900,7 @@ def _commit_migration( current_tx.update_migration_rewrite_state(mrstate) return dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, action=dbstate.MigrationAction.COMMIT, tx_action=None, cacheable=False, @@ -893,16 +917,12 @@ def _commit_migration( if mstate.initial_savepoint: current_tx.commit_migration(mstate.initial_savepoint) - sql = ddl_query.sql tx_action = None else: 
- tx_cmd = qlast.CommitTransaction() - tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) - sql = ddl_query.sql + tx_query.sql - tx_action = tx_query.action + tx_action = dbstate.TxAction.COMMIT return dbstate.MigrationControlQuery( - sql=sql, + sql=ddl_query.sql, ddl_stmt_id=ddl_query.ddl_stmt_id, action=dbstate.MigrationAction.COMMIT, tx_action=tx_action, @@ -923,7 +943,7 @@ def _abort_migration( if mstate.initial_savepoint: current_tx.abort_migration(mstate.initial_savepoint) - sql: Tuple[bytes, ...] = (b'SELECT LIMIT 0',) + sql = NIL_QUERY tx_action = None else: tx_cmd = qlast.RollbackTransaction() @@ -967,7 +987,7 @@ def _start_migration_rewrite( else: savepoint_name = current_tx.start_migration() query = dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, action=dbstate.MigrationAction.START, tx_action=None, cacheable=False, @@ -1052,25 +1072,24 @@ def _commit_migration_rewrite( for cm in cmds: cm.dump_edgeql() - sqls: List[bytes] = [] + block = pg_dbops.PLTopBlock() for cmd in cmds: - ddl_query = compile_and_apply_ddl_stmt(ctx, cmd) + _, ddl_block = _compile_and_apply_ddl_stmt(ctx, cmd) + assert isinstance(ddl_block, pg_dbops.PLBlock) # We know nothing serious can be in that query # except for the SQL, so it's fine to just discard # it all. - sqls.extend(ddl_query.sql) + for stmt in ddl_block.get_statements(): + block.add_command(stmt) if mrstate.initial_savepoint: current_tx.commit_migration(mrstate.initial_savepoint) tx_action = None else: - tx_cmd = qlast.CommitTransaction() - tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) - sqls.extend(tx_query.sql) - tx_action = tx_query.action + tx_action = dbstate.TxAction.COMMIT return dbstate.MigrationControlQuery( - sql=tuple(sqls), + sql=block.to_string().encode("utf-8"), action=dbstate.MigrationAction.COMMIT, tx_action=tx_action, cacheable=False, @@ -1090,7 +1109,7 @@ def _abort_migration_rewrite( if mrstate.initial_savepoint: current_tx.abort_migration(mrstate.initial_savepoint) - sql: Tuple[bytes, ...] = (b'SELECT LIMIT 0',) + sql = NIL_QUERY tx_action = None else: tx_cmd = qlast.RollbackTransaction() @@ -1146,8 +1165,6 @@ def _reset_schema( current_schema=empty_schema, ) - sqls: List[bytes] = [] - # diff and create migration that drops all objects diff = s_ddl.delta_schemas(schema, empty_schema) new_ddl: Tuple[qlast.DDLCommand, ...] 
= tuple( @@ -1156,8 +1173,8 @@ def _reset_schema( create_mig = qlast.CreateMigration( # type: ignore body=qlast.NestedQLBlock(commands=tuple(new_ddl)), # type: ignore ) - ddl_query = compile_and_apply_ddl_stmt(ctx, create_mig) - sqls.extend(ddl_query.sql) + ddl_query, ddl_block = _compile_and_apply_ddl_stmt(ctx, create_mig) + assert ddl_block is not None # delete all migrations schema = current_tx.get_schema(ctx.compiler_state.std_schema) @@ -1170,11 +1187,13 @@ def _reset_schema( drop_mig = qlast.DropMigration( # type: ignore name=qlast.ObjectRef(name=mig.get_name(schema).name), ) - ddl_query = compile_and_apply_ddl_stmt(ctx, drop_mig) - sqls.extend(ddl_query.sql) + _, mig_block = _compile_and_apply_ddl_stmt(ctx, drop_mig) + assert isinstance(mig_block, pg_dbops.PLBlock) + for stmt in mig_block.get_statements(): + ddl_block.add_command(stmt) return dbstate.MigrationControlQuery( - sql=tuple(sqls), + sql=ddl_block.to_string().encode("utf-8"), ddl_stmt_id=ddl_query.ddl_stmt_id, action=dbstate.MigrationAction.COMMIT, tx_action=None, @@ -1278,7 +1297,7 @@ def _track(key: str) -> None: def repair_schema( ctx: compiler.CompileContext, -) -> Optional[tuple[tuple[bytes, ...], s_schema.Schema, Any]]: +) -> Optional[tuple[bytes, s_schema.Schema, Any]]: """Repair inconsistencies in the schema caused by bug fixes Works by comparing the actual current schema to the schema we get @@ -1340,11 +1359,11 @@ def repair_schema( is_transactional = block.is_transactional() assert not new_types assert is_transactional - sql = (block.to_string().encode('utf-8'),) + sql = block.to_string().encode('utf-8') if debug.flags.delta_execute: debug.header('Repair Delta Script') - debug.dump_code(b'\n'.join(sql), lexer='sql') + debug.dump_code(sql, lexer='sql') return sql, reloaded_schema, config_ops @@ -1363,7 +1382,7 @@ def administer_repair_schema( res = repair_schema(ctx) if not res: - return dbstate.MaintenanceQuery(sql=(b'',)) + return dbstate.MaintenanceQuery(sql=b"") sql, new_schema, config_ops = res current_tx.update_schema(new_schema) @@ -1511,9 +1530,11 @@ def administer_reindex( for pindex in pindexes ] - return dbstate.MaintenanceQuery( - sql=tuple(q.encode('utf-8') for q in commands) - ) + block = pg_dbops.PLTopBlock() + for command in commands: + block.add_command(command) + + return dbstate.MaintenanceQuery(sql=block.to_string().encode("utf-8")) def administer_vacuum( @@ -1663,7 +1684,7 @@ def administer_vacuum( command = f'VACUUM {options} ' + ', '.join(tables_and_columns) return dbstate.MaintenanceQuery( - sql=(command.encode('utf-8'),), + sql=command.encode('utf-8'), is_transactional=False, ) diff --git a/edb/server/pgcluster.py b/edb/server/pgcluster.py index 03272dfe382..ca3ab04c9ae 100644 --- a/edb/server/pgcluster.py +++ b/edb/server/pgcluster.py @@ -79,6 +79,10 @@ EDGEDB_SERVER_SETTINGS = { 'client_encoding': 'utf-8', + # DO NOT raise client_min_messages above NOTICE level + # because server indirect block return machinery relies + # on NoticeResponse as the data channel. 
+ 'client_min_messages': 'NOTICE', 'search_path': 'edgedb', 'timezone': 'UTC', 'intervalstyle': 'iso_8601', @@ -568,7 +572,7 @@ async def start( else: log_level_map = { 'd': 'INFO', - 'i': 'NOTICE', + 'i': 'WARNING', # NOTICE in Postgres is quite noisy 'w': 'WARNING', 'e': 'ERROR', 's': 'PANIC', diff --git a/edb/server/pgcon/pgcon.pxd b/edb/server/pgcon/pgcon.pxd index 348b4010b40..4b6a0ed0608 100644 --- a/edb/server/pgcon/pgcon.pxd +++ b/edb/server/pgcon/pgcon.pxd @@ -139,6 +139,8 @@ cdef class PGConnection: object last_state + str last_indirect_return + cdef before_command(self) cdef write(self, buf) diff --git a/edb/server/pgcon/pgcon.pyi b/edb/server/pgcon/pgcon.pyi index ee35a9fad88..e6c00c287b9 100644 --- a/edb/server/pgcon/pgcon.pyi +++ b/edb/server/pgcon/pgcon.pyi @@ -53,7 +53,7 @@ class PGConnection(asyncio.Protocol): async def sql_execute(self, sql: bytes | tuple[bytes, ...]) -> None: ... async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), use_prep_stmt: bool = False, diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 2d77944cb0d..0c66aa3e1c3 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -100,6 +100,7 @@ cdef dict POSTGRES_SHUTDOWN_ERR_CODES = { } cdef object EMPTY_SQL_STATE = b"{}" +cdef WriteBuffer NO_ARGS = args_ser.combine_raw_args() cdef object logger = logging.getLogger('edb.server') @@ -239,6 +240,8 @@ cdef class PGConnection: self.last_parse_prep_stmts = [] self.debug = debug.flags.server_proto + self.last_indirect_return = None + self.log_listeners = [] self.server = None @@ -590,6 +593,8 @@ cdef class PGConnection: WriteBuffer bind_data bytes stmt_name ssize_t idx = start + bytes sql + tuple sqls out = WriteBuffer.new() parsed = set() @@ -607,7 +612,6 @@ cdef class PGConnection: ) stmt_name = query_unit.sql_hash if stmt_name: - assert len(query_unit.sql) == 1 # The same EdgeQL query may show up twice in the same script. 
# We just need to know and skip if we've already parsed the # same query within current send batch, because self.prep_stmts @@ -624,15 +628,16 @@ cdef class PGConnection: for query_unit, bind_data in zip( query_unit_group.units[start:end], bind_datas): stmt_name = query_unit.sql_hash + sql = query_unit.sql if stmt_name: if parse_array[idx]: buf = WriteBuffer.new_message(b'P') buf.write_bytestring(stmt_name) - buf.write_bytestring(query_unit.sql[0]) + buf.write_bytestring(sql) buf.write_int16(0) out.write_buffer(buf.end_message()) metrics.query_size.observe( - len(query_unit.sql[0]), + len(sql), self.get_tenant_label(), 'compiled', ) @@ -649,26 +654,25 @@ cdef class PGConnection: out.write_buffer(buf.end_message()) else: - for sql in query_unit.sql: - buf = WriteBuffer.new_message(b'P') - buf.write_bytestring(b'') # statement name - buf.write_bytestring(sql) - buf.write_int16(0) - out.write_buffer(buf.end_message()) - metrics.query_size.observe( - len(sql), self.get_tenant_label(), 'compiled' - ) + buf = WriteBuffer.new_message(b'P') + buf.write_bytestring(b'') # statement name + buf.write_bytestring(sql) + buf.write_int16(0) + out.write_buffer(buf.end_message()) + metrics.query_size.observe( + len(sql), self.get_tenant_label(), 'compiled' + ) - buf = WriteBuffer.new_message(b'B') - buf.write_bytestring(b'') # portal name - buf.write_bytestring(b'') # statement name - buf.write_buffer(bind_data) - out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'B') + buf.write_bytestring(b'') # portal name + buf.write_bytestring(b'') # statement name + buf.write_buffer(bind_data) + out.write_buffer(buf.end_message()) - buf = WriteBuffer.new_message(b'E') - buf.write_bytestring(b'') # portal name - buf.write_int32(0) # limit: 0 - return all rows - out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'E') + buf.write_bytestring(b'') # portal name + buf.write_int32(0) # limit: 0 - return all rows + out.write_buffer(buf.end_message()) idx += 1 @@ -686,7 +690,7 @@ cdef class PGConnection: out = WriteBuffer.new() buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') - buf.write_bytestring(b'???') + buf.write_bytestring(b'') buf.write_int16(0) # Then do a sync to get everything executed and lined back up @@ -796,7 +800,6 @@ cdef class PGConnection: elif mtype == b'I': ## result # EmptyQueryResponse self.buffer.discard_message() - return result else: self.fallthrough() @@ -819,6 +822,10 @@ cdef class PGConnection: WriteBuffer out WriteBuffer buf bytes stmt_name + bytes sql + tuple sqls + bytes prologue_sql + bytes epilogue_sql int32_t dat_len @@ -833,14 +840,6 @@ cdef class PGConnection: uint64_t msgs_executed = 0 uint64_t i - if use_pending_func_cache and query.cache_func_call: - sql, stmt_name = query.cache_func_call - sqls = (sql,) - else: - sqls = query.sql - stmt_name = query.sql_hash - msgs_num = (len(sqls)) - out = WriteBuffer.new() if state is not None: @@ -848,7 +847,7 @@ cdef class PGConnection: if ( query.tx_id or not query.is_transactional - or query.append_rollback + or query.run_and_rollback or tx_isolation is not None ): # This query has START TRANSACTION or non-transactional command @@ -860,22 +859,22 @@ cdef class PGConnection: state_sync = 1 self.write_sync(out) - if query.append_rollback or tx_isolation is not None: + if query.run_and_rollback or tx_isolation is not None: if self.in_tx(): sp_name = f'_edb_{time.monotonic_ns()}' - sql = f'SAVEPOINT {sp_name}'.encode('utf-8') + prologue_sql = f'SAVEPOINT {sp_name}'.encode('utf-8') else: sp_name = None - 
sql = b'START TRANSACTION' + prologue_sql = b'START TRANSACTION' if tx_isolation is not None: - sql += ( + prologue_sql += ( f' ISOLATION LEVEL {tx_isolation._value_}' .encode('utf-8') ) buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') - buf.write_bytestring(sql) + buf.write_bytestring(prologue_sql) buf.write_int16(0) out.write_buffer(buf.end_message()) @@ -895,9 +894,17 @@ cdef class PGConnection: # Insert a SYNC as a boundary of the parsing logic later self.write_sync(out) + if use_pending_func_cache and query.cache_func_call: + sql, stmt_name = query.cache_func_call + sqls = (sql,) + else: + sqls = (query.sql,) + query.db_op_trailer + stmt_name = query.sql_hash + + msgs_num = (len(sqls)) + if use_prep_stmt: - parse = self.before_prepare( - stmt_name, dbver, out) + parse = self.before_prepare(stmt_name, dbver, out) else: stmt_name = b'' @@ -962,8 +969,8 @@ cdef class PGConnection: buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) - if query.append_rollback or tx_isolation is not None: - if query.append_rollback: + if query.run_and_rollback or tx_isolation is not None: + if query.run_and_rollback: if sp_name: sql = f'ROLLBACK TO SAVEPOINT {sp_name}'.encode('utf-8') else: @@ -985,6 +992,35 @@ cdef class PGConnection: buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'E') + buf.write_bytestring(b'') # portal name + buf.write_int32(0) # limit: 0 - return all rows + out.write_buffer(buf.end_message()) + elif query.append_tx_op: + if query.tx_commit: + sql = b'COMMIT' + elif query.tx_rollback: + sql = b'ROLLBACK' + else: + raise errors.InternalServerError( + "QueryUnit.append_tx_op is set but none of the " + "Query.tx_ properties are" + ) + + buf = WriteBuffer.new_message(b'P') + buf.write_bytestring(b'') + buf.write_bytestring(sql) + buf.write_int16(0) + out.write_buffer(buf.end_message()) + + buf = WriteBuffer.new_message(b'B') + buf.write_bytestring(b'') # portal name + buf.write_bytestring(b'') # statement name + buf.write_int16(0) # number of format codes + buf.write_int16(0) # number of parameters + buf.write_int16(0) # number of result columns + out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows @@ -999,7 +1035,7 @@ cdef class PGConnection: if state is not None: await self.wait_for_state_resp(state, state_sync) - if query.append_rollback or tx_isolation is not None: + if query.run_and_rollback or tx_isolation is not None: await self.wait_for_sync() buf = None @@ -1098,7 +1134,7 @@ cdef class PGConnection: self, *, query, - WriteBuffer bind_data, + WriteBuffer bind_data = NO_ARGS, frontend.AbstractFrontendConnection fe_conn = None, bint use_prep_stmt = False, bytes state = None, @@ -1127,29 +1163,21 @@ cdef class PGConnection: async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] 
| list[bytes] = (), use_prep_stmt: bool = False, state: Optional[bytes] = None, ) -> list[tuple[bytes, ...]]: - cdef tuple sql_tuple - - if not isinstance(sql, tuple): - sql_tuple = (sql,) - else: - sql_tuple = sql - if use_prep_stmt: sql_digest = hashlib.sha1() - for stmt in sql_tuple: - sql_digest.update(stmt) + sql_digest.update(sql) sql_hash = sql_digest.hexdigest().encode('latin1') else: sql_hash = None query = compiler.QueryUnit( - sql=sql_tuple, + sql=sql, sql_hash=sql_hash, status=b"", ) @@ -2067,42 +2095,24 @@ cdef class PGConnection: msg_buf.write_bytes(data) buf.write_buffer(msg_buf.end_message()) - async def run_ddl( - self, - object query_unit, - bytes state=None - ): - data = await self.sql_fetch(query_unit.sql, state=state) - if query_unit.ddl_stmt_id is None: - return - else: - return self.load_ddl_return(query_unit, data) - - def load_ddl_return(self, object query_unit, data): + def load_last_ddl_return(self, object query_unit): if query_unit.ddl_stmt_id: + data = self.last_indirect_return if data: - ret = json.loads(data[0][0]) + ret = json.loads(data) if ret['ddl_stmt_id'] != query_unit.ddl_stmt_id: raise RuntimeError( - 'unrecognized data packet after a DDL command: ' - 'data_stmt_id do not match' + 'unrecognized data notice after a DDL command: ' + 'data_stmt_id do not match: expected ' + f'{query_unit.ddl_stmt_id!r}, got ' + f'{ret["ddl_stmt_id"]!r}' ) return ret else: raise RuntimeError( - 'missing the required data packet after a DDL command' + 'missing the required data notice after a DDL command' ) - async def handle_ddl_in_script( - self, object query_unit, bint parse, int dbver - ): - data = None - for sql in query_unit.sql: - data = await self.wait_for_command( - query_unit, parse, dbver, ignore_data=bool(data) - ) or data - return self.load_ddl_return(query_unit, data) - async def _dump(self, block, output_queue, fragment_suggested_size): cdef: WriteBuffer buf @@ -2527,6 +2537,7 @@ cdef class PGConnection: 'previous one') self.idle = False + self.last_indirect_return = None async def after_command(self): if self.idle: @@ -2675,14 +2686,18 @@ cdef class PGConnection: elif mtype == b'N': # NoticeResponse - if self.log_listeners: - _, fields = self.parse_error_message() - severity = fields.get('V') - message = fields.get('M') + _, fields = self.parse_error_message() + severity = fields.get('V') + message = fields.get('M') + detail = fields.get('D') + if ( + severity == "NOTICE" + and message.startswith("edb:notice:indirect_return") + ): + self.last_indirect_return = detail + elif self.log_listeners: for listener in self.log_listeners: self.loop.call_soon(listener, severity, message) - else: - self.buffer.discard_message() return True return False diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 72a8f265d9b..d35d5664dcf 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -1582,7 +1582,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): if query_unit.sql: if query_unit.ddl_stmt_id: - ddl_ret = await pgcon.run_ddl(query_unit) + await pgcon.parse_execute(query=query_unit) + ddl_ret = pgcon.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] else: diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 28730acc3fa..bb63859a204 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -94,19 +94,14 @@ cdef class ExecutionGroup: if state is not None: await be_conn.wait_for_state_resp(state, 
state_sync=0) for i, unit in enumerate(self.group): - if unit.output_format == FMT_NONE and unit.ddl_stmt_id is None: - for sql in unit.sql: - await be_conn.wait_for_command( - unit, parse_array[i], dbver, ignore_data=True - ) - rv = None - else: - for sql in unit.sql: - rv = await be_conn.wait_for_command( - unit, parse_array[i], dbver, - ignore_data=False, - fe_conn=fe_conn, - ) + ignore_data = unit.output_format == FMT_NONE + rv = await be_conn.wait_for_command( + unit, + parse_array[i], + dbver, + ignore_data=ignore_data, + fe_conn=None if ignore_data else fe_conn, + ) return rv @@ -135,13 +130,11 @@ cpdef ExecutionGroup build_cache_persistence_units( assert serialized_result is not None if evict: - group.append(compiler.QueryUnit(sql=(evict,), status=b'')) + group.append(compiler.QueryUnit(sql=evict, status=b'')) if persist: - group.append(compiler.QueryUnit(sql=(persist,), status=b'')) + group.append(compiler.QueryUnit(sql=persist, status=b'')) group.append( - compiler.QueryUnit( - sql=(insert_sql,), sql_hash=sql_hash, status=b'', - ), + compiler.QueryUnit(sql=insert_sql, sql_hash=sql_hash, status=b''), args_ser.combine_raw_args(( query_unit.cache_key.bytes, query_unit.user_schema_version.bytes, @@ -276,9 +269,11 @@ async def execute( if query_unit.sql: if query_unit.user_schema: - ddl_ret = await be_conn.run_ddl(query_unit, state) - if ddl_ret and ddl_ret['new_types']: - new_types = ddl_ret['new_types'] + await be_conn.parse_execute(query=query_unit, state=state) + if query_unit.ddl_stmt_id is not None: + ddl_ret = be_conn.load_last_ddl_return(query_unit) + if ddl_ret and ddl_ret['new_types']: + new_types = ddl_ret['new_types'] else: bound_args_buf = args_ser.recode_bind_args( dbv, compiled, bind_args) @@ -519,35 +514,29 @@ async def execute_script( if query_unit.sql: parse = parse_array[idx] + fe_output = query_unit.output_format != FMT_NONE + ignore_data = ( + not fe_output + and not query_unit.needs_readback + ) + data = await conn.wait_for_command( + query_unit, + parse, + dbver, + ignore_data=ignore_data, + fe_conn=fe_conn if fe_output else None, + ) + if query_unit.ddl_stmt_id: - ddl_ret = await conn.handle_ddl_in_script( - query_unit, parse, dbver - ) + ddl_ret = conn.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] - elif query_unit.needs_readback: - config_data = [] - for sql in query_unit.sql: - config_data = await conn.wait_for_command( - query_unit, parse, dbver, ignore_data=False - ) - if config_data: - config_ops = [ - config.Operation.from_json(r[0][1:]) - for r in config_data - ] - elif query_unit.output_format == FMT_NONE: - for sql in query_unit.sql: - await conn.wait_for_command( - query_unit, parse, dbver, ignore_data=True - ) - else: - for sql in query_unit.sql: - data = await conn.wait_for_command( - query_unit, parse, dbver, - ignore_data=False, - fe_conn=fe_conn, - ) + + if query_unit.needs_readback and data: + config_ops = [ + config.Operation.from_json(r[0][1:]) + for r in data + ] if config_ops: await dbv.apply_config_ops(conn, config_ops) @@ -642,12 +631,7 @@ async def execute_system_config( await conn.sql_fetch(b'select 1', state=state) if query_unit.sql: - if len(query_unit.sql) > 1: - raise errors.InternalServerError( - "unexpected multiple SQL statements in CONFIGURE INSTANCE " - "compilation product" - ) - data = await conn.sql_fetch_col(query_unit.sql[0]) + data = await conn.sql_fetch_col(query_unit.sql) else: data = None diff --git a/edb/server/server.py b/edb/server/server.py index 
557fa39eaf0..59caa6311f9 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -1433,7 +1433,7 @@ async def _maybe_apply_patches( db_config = self._parse_db_config(config_json, user_schema) try: logger.info("repairing database '%s'", dbname) - sql += bootstrap.prepare_repair_patch( + rep_sql = bootstrap.prepare_repair_patch( self._std_schema, self._refl_schema, user_schema, @@ -1442,6 +1442,7 @@ async def _maybe_apply_patches( self._tenant.get_backend_runtime_params(), db_config, ) + sql += (rep_sql,) except errors.EdgeDBError as e: if isinstance(e, errors.InternalServerError): raise @@ -1454,7 +1455,7 @@ async def _maybe_apply_patches( ) from e if sql: - await conn.sql_fetch(sql) + await conn.sql_execute(sql) logger.info( "finished applying patch %d to database '%s'", num, dbname) From 81a15efe20793ff416f2fe5cad9c5c722df6f865 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Wed, 13 Nov 2024 02:35:34 +0100 Subject: [PATCH 02/16] Emulate foreign keys for single links in introspection over SQL adapter (#7946) --- edb/pgsql/metaschema.py | 166 +++++++++++++++++++++++++++++++++---- edb/pgsql/resolver/expr.py | 8 ++ 2 files changed, 156 insertions(+), 18 deletions(-) diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index de712a72552..6c1077261c7 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -6362,20 +6362,6 @@ def _generate_sql_information_schema( ) ), ), - # TODO: Should we try to filter here, and fix up some stuff - # elsewhere, instead of overriding pg_get_constraintdef? - trampoline.VersionedView( - name=("edgedbsql", "pg_constraint"), - query=""" - SELECT - pc.*, - pc.tableoid, pc.xmin, pc.cmin, pc.xmax, pc.cmax, pc.ctid - FROM pg_constraint pc - JOIN pg_namespace pn ON pc.connamespace = pn.oid - WHERE NOT (pn.nspname = 'edgedbpub' AND pc.conbin IS NOT NULL) - """ - ), - # pg_class that contains classes only for tables # This is needed so we can use it to filter pg_index to indexes only on # visible tables. 
@@ -6457,10 +6443,16 @@ def _generate_sql_information_schema( pi.indisclustered, pi.indisvalid, pi.indcheckxmin, - pi.indisready, + CASE + WHEN COALESCE(is_id.t, FALSE) THEN TRUE + ELSE FALSE -- override so pg_dump won't try to recreate them + END AS indisready, pi.indislive, pi.indisreplident, - pi.indkey, + CASE + WHEN COALESCE(is_id.t, FALSE) THEN ARRAY[1]::int2vector -- id: 1 + ELSE pi.indkey + END AS indkey, pi.indcollation, pi.indclass, pi.indoption, @@ -6817,6 +6809,88 @@ def _generate_sql_information_schema( WHERE FALSE """, ), + trampoline.VersionedView( + name=("edgedbsql", "pg_constraint"), + query=r""" + -- primary keys + SELECT + pc.oid, + vt.table_name || '_pk' AS conname, + pc.connamespace, + 'p'::"char" AS contype, + pc.condeferrable, + pc.condeferred, + pc.convalidated, + pc.conrelid, + pc.contypid, + pc.conindid, + pc.conparentid, + pc.confrelid, + pc.confupdtype, + pc.confdeltype, + pc.confmatchtype, + pc.conislocal, + pc.coninhcount, + pc.connoinherit, + pc.conkey, + pc.confkey, + pc.conpfeqop, + pc.conppeqop, + pc.conffeqop, + pc.confdelsetcols, + pc.conexclop, + pc.conbin, + pc.tableoid, pc.xmin, pc.cmin, pc.xmax, pc.cmax, pc.ctid + FROM pg_constraint pc + JOIN edgedbsql_VER.pg_class_tables pct ON pct.oid = pc.conrelid + JOIN edgedbsql_VER.virtual_tables vt ON vt.pg_type_id = pct.reltype + JOIN pg_attribute pa ON (pa.attname = 'id' AND pa.attrelid = pct.oid) + WHERE contype = 'u' -- our ids and all links will have unique constraint + AND attnum = ANY(conkey) + + UNION ALL + + -- foreign keys + SELECT + edgedbsql_VER.uuid_to_oid(sl.id) as oid, + vt.table_name || '_fk_' || sl.name AS conname, + edgedbsql_VER.uuid_to_oid(vt.module_id) AS connamespace, + 'f'::"char" AS contype, + FALSE AS condeferrable, + FALSE AS condeferred, + TRUE AS convalidated, + pc.oid AS conrelid, + 0::oid AS contypid, + 0::oid AS conindid, -- let's hope this is not needed + 0::oid AS conparentid, + pc_target.oid AS confrelid, + 'a'::"char" AS confupdtype, + 'a'::"char" AS confdeltype, + 's'::"char" AS confmatchtype, + TRUE AS conislocal, + 0::int2 AS coninhcount, + TRUE AS connoinherit, + ARRAY[pa.attnum]::int2[] AS conkey, + ARRAY[1]::int2[] AS confkey, -- id will always have attnum 1 + ARRAY[2972]::oid[] AS conpfeqop, -- 2972 is eq comparison for uuids + ARRAY[2972]::oid[] AS conppeqop, -- 2972 is eq comparison for uuids + ARRAY[2972]::oid[] AS conffeqop, -- 2972 is eq comparison for uuids + NULL::int2[] AS confdelsetcols, + NULL::oid[] AS conexclop, + NULL::pg_node_tree AS conbin, + pa.tableoid, pa.xmin, pa.cmin, pa.xmax, pa.cmax, pa.ctid + FROM edgedbsql_VER.virtual_tables vt + JOIN pg_class pc ON pc.reltype = vt.pg_type_id + JOIN edgedb_VER."_SchemaLink" sl + ON sl.source = vt.id -- AND COALESCE(sl.cardinality = 'One', TRUE) + JOIN edgedbsql_VER.virtual_tables vt_target + ON sl.target = vt_target.id + JOIN pg_class pc_target ON pc_target.reltype = vt_target.pg_type_id + JOIN edgedbsql_VER.pg_attribute pa + ON pa.attrelid = pc.oid + AND pa.attname = sl.name || '_id' + """ + ), trampoline.VersionedView( name=("edgedbsql", "pg_statistic"), query=""" @@ -6993,6 +7067,18 @@ def _generate_sql_information_schema( WHERE c.relkind = 'v'::"char" """, ), + # Omit all descriptions (comments), becase all non-system comments + # are our internal implementation details. 
+ trampoline.VersionedView( + name=("edgedbsql", "pg_description"), + query=""" + SELECT + *, + tableoid, xmin, cmin, xmax, cmax, ctid + FROM pg_description + WHERE FALSE + """, + ), ] # We expose most of the views as empty tables, just to prevent errors when @@ -7041,6 +7127,7 @@ def _generate_sql_information_schema( 'pg_subscription', 'pg_tables', 'pg_views', + 'pg_description', } PG_TABLES_WITH_SYSTEM_COLS = { @@ -7061,7 +7148,6 @@ def _generate_sql_information_schema( 'pg_db_role_setting', 'pg_default_acl', 'pg_depend', - 'pg_description', 'pg_enum', 'pg_event_trigger', 'pg_extension', @@ -7352,7 +7438,51 @@ def construct_pg_view( WHERE t.oid = typeoid ''', - ) + ), + trampoline.VersionedFunction( + name=("edgedbsql", "pg_get_constraintdef"), + args=[ + ('conid', ('oid',)), + ], + returns=('text',), + volatility='stable', + text=r""" + SELECT CASE + WHEN contype = 'p' THEN + 'PRIMARY KEY(' || ( + SELECT string_agg('"' || attname || '"', ', ') + FROM edgedbsql_VER.pg_attribute + WHERE attrelid = conrelid AND attnum = ANY(conkey) + ) || ')' + WHEN contype = 'f' THEN + 'FOREIGN KEY ("' || ( + SELECT attname + FROM edgedbsql_VER.pg_attribute + WHERE attrelid = conrelid AND attnum = ANY(conkey) + LIMIT 1 + ) || '")' || ' REFERENCES "' + || pn.nspname || '"."' || pc.relname || '"(id)' + ELSE '' + END + FROM edgedbsql_VER.pg_constraint con + LEFT JOIN edgedbsql_VER.pg_class_tables pc ON pc.oid = confrelid + LEFT JOIN edgedbsql_VER.pg_namespace pn + ON pc.relnamespace = pn.oid + WHERE con.oid = conid + """ + ), + trampoline.VersionedFunction( + name=("edgedbsql", "pg_get_constraintdef"), + args=[ + ('conid', ('oid',)), + ('pretty', ('bool',)), + ], + returns=('text',), + volatility='stable', + text=r""" + SELECT pg_get_constraintdef(conid) + """ + ), ] return ( diff --git a/edb/pgsql/resolver/expr.py b/edb/pgsql/resolver/expr.py index 8edb6f96bd7..16a92d8638c 100644 --- a/edb/pgsql/resolver/expr.py +++ b/edb/pgsql/resolver/expr.py @@ -456,6 +456,14 @@ def resolve_SortBy( common.versioned_schema('edgedbsql'), '_format_type', ), + ('pg_catalog', 'pg_get_constraintdef'): ( + common.versioned_schema('edgedbsql'), + 'pg_get_constraintdef', + ), + ('pg_get_constraintdef',): ( + common.versioned_schema('edgedbsql'), + 'pg_get_constraintdef', + ), } From aff8c6295bafe01885941ad75e595fed8738b7ca Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 12 Nov 2024 20:22:19 -0700 Subject: [PATCH 03/16] HTTP streaming fixes (#7984) This was failing to send the final SSEEnd message because it was not actually running the future. The issue was masked because we ignore the return type from `rpc_pipe.write`. ``` defer!(_ = rpc_pipe.write(RustToPythonMessage::SSEEnd(id))); ``` In addition, fix a problem where ARM64 Linux tests would occasionally fail due to an incorrect error message. --- edb/server/http/src/python.rs | 15 ++++- tests/test_http.py | 121 +++++++++++++++++++++++++++------- 2 files changed, 110 insertions(+), 26 deletions(-) diff --git a/edb/server/http/src/python.rs b/edb/server/http/src/python.rs index e48be467b0d..d42f38bf4c5 100644 --- a/edb/server/http/src/python.rs +++ b/edb/server/http/src/python.rs @@ -218,10 +218,15 @@ async fn request_sse( return Err(format!("Failed to read response body: {e:?}")); } }; + + // Note that we use semaphores here in a strange way, but basically we + // want to have per-stream backpressure to avoid buffering messages + // indefinitely. 
let Ok(permit) = backpressure.acquire().await else { break; }; permit.forget(); + if rpc_pipe .write(RustToPythonMessage::SSEEvent(id, chunk)) .await @@ -464,8 +469,14 @@ async fn execute( drop(permit); } RequestSse(id, url, method, body, headers) => { - // Ensure we send the end message whenever this block exits - defer!(_ = rpc_pipe.write(RustToPythonMessage::SSEEnd(id))); + // Ensure we send the end message whenever this block exits (though + // we need to spawn a task to do so) + defer!({ + let rpc_pipe = rpc_pipe.clone(); + tokio::task::spawn_local(async move { + _ = rpc_pipe.write(RustToPythonMessage::SSEEnd(id)).await; + }); + }); let Ok(permit) = permit_manager.acquire().await else { return; }; diff --git a/tests/test_http.py b/tests/test_http.py index f1438d48dd6..1afe0ffc34c 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -122,9 +122,11 @@ async def test_immediate_connection_drop(self): server""" async def mock_drop_server( - _reader: asyncio.StreamReader, writer: asyncio.StreamWriter + reader: asyncio.StreamReader, writer: asyncio.StreamWriter ): - # Close connection immediately without sending any response + # Close connection immediately after reading a byte without sending + # any response + await reader.read(1) writer.close() await writer.wait_closed() @@ -135,21 +137,43 @@ async def mock_drop_server( try: with http.HttpClient(100) as client: with self.assertRaisesRegex( - Exception, "Connection reset by peer" + Exception, "Connection reset by peer|IncompleteMessage" ): await client.get(url) finally: server.close() await server.wait_closed() + async def test_streaming_get_with_no_sse(self): + with http.HttpClient(100) as client: + example_request = ( + 'GET', + self.base_url, + '/test-get-with-sse', + ) + url = f"{example_request[1]}{example_request[2]}" + self.mock_server.register_route_handler(*example_request)( + lambda _handler, request: ( + "\"ok\"", + 200, + ) + ) + result = await client.stream_sse(url, method="GET") + self.assertEqual(result.status_code, 200) + self.assertEqual(result.json(), "ok") + + +class HttpSSETest(tb.BaseHttpTest): async def test_immediate_connection_drop_streaming(self): """Test handling of a connection that is dropped immediately by the server""" async def mock_drop_server( - _reader: asyncio.StreamReader, writer: asyncio.StreamWriter + reader: asyncio.StreamReader, writer: asyncio.StreamWriter ): - # Close connection immediately without sending any response + # Close connection immediately after reading a byte without sending + # any response + await reader.read(1) writer.close() await writer.wait_closed() @@ -160,31 +184,13 @@ async def mock_drop_server( try: with http.HttpClient(100) as client: with self.assertRaisesRegex( - Exception, "Connection reset by peer" + Exception, "Connection reset by peer|IncompleteMessage" ): await client.stream_sse(url) finally: server.close() await server.wait_closed() - async def test_streaming_get_with_no_sse(self): - with http.HttpClient(100) as client: - example_request = ( - 'GET', - self.base_url, - '/test-get-with-sse', - ) - url = f"{example_request[1]}{example_request[2]}" - self.mock_server.register_route_handler(*example_request)( - lambda _handler, request: ( - "\"ok\"", - 200, - ) - ) - result = await client.stream_sse(url, method="GET") - self.assertEqual(result.status_code, 200) - self.assertEqual(result.json(), "ok") - async def test_sse_with_mock_server(self): """Since the regular mock server doesn't support SSE, we need to test with a real socket. 
We handle just enough HTTP to get the job done.""" @@ -256,3 +262,70 @@ async def client_task(): await asyncio.wait_for(client_future, timeout=5.0) assert is_closed + + async def test_sse_with_mock_server_close(self): + """Try to close the server-side stream and see if the client detects + an end for the iterator. Note that this is technically not correct SSE: + the client should actually try to reconnect after the specified retry + interval, _but_ we don't handle retries yet.""" + + is_closed = False + + async def mock_sse_server( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + nonlocal is_closed + + await reader.readline() + + headers = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/event-stream\r\n" + b"Cache-Control: no-cache\r\n" + b"Connection: keep-alive\r\n\r\n" + ) + writer.write(headers) + await writer.drain() + + for i in range(3): + writer.write(b": test comment that should be ignored\n\n") + await writer.drain() + + writer.write( + f"event: message\ndata: Event {i + 1}\n\n".encode() + ) + await writer.drain() + await asyncio.sleep(0.1) + + await writer.drain() + writer.close() + await writer.wait_closed() + + is_closed = True + + server = await asyncio.start_server(mock_sse_server, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + url = f'http://{addr[0]}:{addr[1]}/sse' + + async def client_task(): + with http.HttpClient(100) as client: + response = await client.stream_sse(url, method="GET") + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'text/event-stream' + assert isinstance(response, http.ResponseSSE) + + events = [] + async for event in response: + self.assertEqual(event.event, 'message') + events.append(event) + + assert len(events) == 3 + assert events[0].data == 'Event 1' + assert events[1].data == 'Event 2' + assert events[2].data == 'Event 3' + + async with server: + client_future = asyncio.create_task(client_task()) + await asyncio.wait_for(client_future, timeout=5.0) + + assert is_closed From 30490ad472a1a84bc68a41aedd2664b9d8540c5e Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Wed, 13 Nov 2024 11:32:00 -0800 Subject: [PATCH 04/16] Fix clippy complaints (#7987) These are firing on my machine error: this manual char comparison can be written more succinctly --> edb/edgeql-parser/src/position.rs:127:48 | 127 | if let Some(loff) = prefix_s.rfind(|c| c == '\r' || c == '\n') { | ^^^^^^^^^^^^^^^^^^^^^^^^^^ help: consider using an array of `char`: `['\r', '\n']` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#manual_pattern_char_comparison Not sure why these aren't picked up in CI, but the fixes seem to make sense. 
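For anyone chasing the same lint locally, here is a minimal sketch of the rewrite clippy is asking for (the string and variable names below are illustrative only, not taken from this diff):

```rust
fn main() {
    let line = "foo\r\nbar";
    // Closure pattern: the form clippy::manual_pattern_char_comparison flags.
    let manual = line.rfind(|c| c == '\r' || c == '\n');
    // Array-of-char pattern: the form clippy suggests; it matches the same characters.
    let suggested = line.rfind(['\r', '\n']);
    assert_eq!(manual, suggested); // both are Some(4), the index of '\n'
}
```

Both forms search for the same character set, so the change should be purely cosmetic.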
--- edb/edgeql-parser/src/helpers/bytes.rs | 2 +- edb/edgeql-parser/src/position.rs | 2 +- edb/edgeql-parser/src/validation.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/edb/edgeql-parser/src/helpers/bytes.rs b/edb/edgeql-parser/src/helpers/bytes.rs index cffd954bc4e..f2b8976c86c 100644 --- a/edb/edgeql-parser/src/helpers/bytes.rs +++ b/edb/edgeql-parser/src/helpers/bytes.rs @@ -1,6 +1,6 @@ pub fn unquote_bytes(value: &str) -> Result, String> { let idx = value - .find(|c| c == '\'' || c == '"') + .find(['\'', '"']) .ok_or_else(|| "invalid bytes literal: missing quotes".to_string())?; let prefix = &value[..idx]; match prefix { diff --git a/edb/edgeql-parser/src/position.rs b/edb/edgeql-parser/src/position.rs index a5d8bdd02b2..2f93b529a06 100644 --- a/edb/edgeql-parser/src/position.rs +++ b/edb/edgeql-parser/src/position.rs @@ -124,7 +124,7 @@ impl InflatedPos { let prefix_s = from_utf8(prefix).map_err(InflatingError::Utf8)?; let line_offset; let line; - if let Some(loff) = prefix_s.rfind(|c| c == '\r' || c == '\n') { + if let Some(loff) = prefix_s.rfind(['\r', '\n']) { line_offset = loff + 1; let mut lines = &prefix[..loff]; if data[loff] == b'\n' && loff > 0 && data[loff - 1] == b'\r' { diff --git a/edb/edgeql-parser/src/validation.rs b/edb/edgeql-parser/src/validation.rs index fd05d869975..b1bd57f1228 100644 --- a/edb/edgeql-parser/src/validation.rs +++ b/edb/edgeql-parser/src/validation.rs @@ -167,7 +167,7 @@ pub fn parse_value(token: &Token) -> Result, String> { return Err("number is out of range for std::float64".to_string()); } if num == 0.0 { - let mend = text.find(|c| c == 'e' || c == 'E').unwrap_or(text.len()); + let mend = text.find(['e', 'E']).unwrap_or(text.len()); let mantissa = &text[..mend]; if mantissa.chars().any(|c| c != '0' && c != '.') { return Err("number is out of range for std::float64".to_string()); From c6e249cfa620caac08c8d8585dd4865e0d5e69e5 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Fri, 15 Nov 2024 04:05:07 +0800 Subject: [PATCH 05/16] Fix auth_jwt under Cython 3 (#7986) --- .github/workflows.src/tests.inc.yml | 6 +++--- .github/workflows/tests-ha.yml | 6 +++--- .github/workflows/tests-inplace.yml | 6 +++--- .github/workflows/tests-managed-pg.yml | 14 +++++++------- .github/workflows/tests-patches.yml | 6 +++--- .github/workflows/tests-pg-versions.yml | 6 +++--- .github/workflows/tests-pool.yml | 4 ++-- .github/workflows/tests-reflection.yml | 8 ++++---- .github/workflows/tests.yml | 8 ++++---- edb/server/protocol/auth_helpers.pxd | 2 +- edb/server/protocol/auth_helpers.pyx | 2 +- tests/test_server_auth.py | 6 ++++++ 12 files changed, 40 insertions(+), 34 deletions(-) diff --git a/.github/workflows.src/tests.inc.yml b/.github/workflows.src/tests.inc.yml index aba64c5b64c..fe36c283b5a 100644 --- a/.github/workflows.src/tests.inc.yml +++ b/.github/workflows.src/tests.inc.yml @@ -89,7 +89,7 @@ id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -195,7 +195,7 @@ if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -377,7 +377,7 @@ id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ 
hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-ha.yml b/.github/workflows/tests-ha.yml index be61db8e813..899f4f32a9c 100644 --- a/.github/workflows/tests-ha.yml +++ b/.github/workflows/tests-ha.yml @@ -135,7 +135,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -241,7 +241,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -431,7 +431,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-inplace.yml b/.github/workflows/tests-inplace.yml index cabfd642236..bb736aaddb1 100644 --- a/.github/workflows/tests-inplace.yml +++ b/.github/workflows/tests-inplace.yml @@ -120,7 +120,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -226,7 +226,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -428,7 +428,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-managed-pg.yml b/.github/workflows/tests-managed-pg.yml index 766bc1f86a6..0f4b80752af 100644 --- a/.github/workflows/tests-managed-pg.yml +++ b/.github/workflows/tests-managed-pg.yml @@ -120,7 +120,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -226,7 +226,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -462,7 +462,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -704,7 +704,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ 
hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -994,7 +994,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -1250,7 +1250,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -1494,7 +1494,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-patches.yml b/.github/workflows/tests-patches.yml index bd5da952e44..025282975bd 100644 --- a/.github/workflows/tests-patches.yml +++ b/.github/workflows/tests-patches.yml @@ -122,7 +122,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -228,7 +228,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -431,7 +431,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-pg-versions.yml b/.github/workflows/tests-pg-versions.yml index 21f1cef46fe..1b1bfb28f62 100644 --- a/.github/workflows/tests-pg-versions.yml +++ b/.github/workflows/tests-pg-versions.yml @@ -120,7 +120,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -226,7 +226,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -455,7 +455,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-pool.yml b/.github/workflows/tests-pool.yml index 1047d76bc29..2e2c21c8bc3 100644 --- a/.github/workflows/tests-pool.yml +++ b/.github/workflows/tests-pool.yml @@ -130,7 +130,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build 
uses: actions/cache@v4 @@ -236,7 +236,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: diff --git a/.github/workflows/tests-reflection.yml b/.github/workflows/tests-reflection.yml index 54f7f731a0e..f9ee8c816ac 100644 --- a/.github/workflows/tests-reflection.yml +++ b/.github/workflows/tests-reflection.yml @@ -122,7 +122,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -228,7 +228,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -371,7 +371,7 @@ jobs: with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - + - name: Download requirements.txt uses: actions/cache@v4 with: @@ -418,7 +418,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 450a48b7d1b..6cf0fb4ac90 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -132,7 +132,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -238,7 +238,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -511,7 +511,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -691,7 +691,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/edb/server/protocol/auth_helpers.pxd b/edb/server/protocol/auth_helpers.pxd index e43d6c4f504..9100d670051 100644 --- a/edb/server/protocol/auth_helpers.pxd +++ b/edb/server/protocol/auth_helpers.pxd @@ -18,7 +18,7 @@ cdef extract_token_from_auth_data(bytes auth_data) -cdef auth_jwt(tenant, str prefixed_token, str user, str dbname) +cdef auth_jwt(tenant, prefixed_token, str user, str dbname) cdef _check_jwt_authz(tenant, claims, token_version, str user, str dbname) cdef _get_jwt_edb_scope(claims, claim) cdef scram_get_verifier(tenant, str user) diff --git 
a/edb/server/protocol/auth_helpers.pyx b/edb/server/protocol/auth_helpers.pyx index ec8f1aeb52f..f0b22bc764f 100644 --- a/edb/server/protocol/auth_helpers.pyx +++ b/edb/server/protocol/auth_helpers.pyx @@ -39,7 +39,7 @@ cdef extract_token_from_auth_data(auth_data: bytes): return scheme.lower(), payload.strip() -cdef auth_jwt(tenant, prefixed_token: str, user: str, dbname: str): +cdef auth_jwt(tenant, prefixed_token: str | None, user: str, dbname: str): if not prefixed_token: raise errors.AuthenticationError( 'authentication failed: no authorization data provided') diff --git a/tests/test_server_auth.py b/tests/test_server_auth.py index 7e1cf5643fe..0c48bf40c33 100644 --- a/tests/test_server_auth.py +++ b/tests/test_server_auth.py @@ -406,6 +406,12 @@ async def test_server_auth_jwt_1(self): ''') await conn.aclose() + with self.assertRaisesRegex( + edgedb.AuthenticationError, + 'authentication failed: no authorization data provided', + ): + await sd.connect() + # bad secret keys with self.assertRaisesRegex( edgedb.AuthenticationError, From d392bc4a5294d2c19bcc9800caca024aec588dc2 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Fri, 15 Nov 2024 01:39:51 -0500 Subject: [PATCH 06/16] Allow volatile WITH in dml statements. (#7969) Given `type Foo { property a -> float64; };` Allows the query: `with x := random() insert Foo { a := x };` --- edb/edgeql/compiler/inference/cardinality.py | 4 +- edb/edgeql/compiler/inference/multiplicity.py | 4 +- edb/edgeql/compiler/inference/volatility.py | 2 +- edb/edgeql/compiler/stmt.py | 8 +- edb/edgeql/compiler/viewgen.py | 22 - edb/ir/ast.py | 2 +- edb/pgsql/compiler/clauses.py | 185 ++- edb/pgsql/compiler/dml.py | 2 +- edb/pgsql/compiler/group.py | 2 +- edb/pgsql/compiler/stmt.py | 2 +- tests/test_edgeql_insert.py | 1061 ++++++++++++++++- 11 files changed, 1195 insertions(+), 99 deletions(-) diff --git a/edb/edgeql/compiler/inference/cardinality.py b/edb/edgeql/compiler/inference/cardinality.py index ee7504b83c9..e0691701674 100644 --- a/edb/edgeql/compiler/inference/cardinality.py +++ b/edb/edgeql/compiler/inference/cardinality.py @@ -1205,7 +1205,7 @@ def _infer_stmt_cardinality( scope_tree: irast.ScopeTreeNode, ctx: inference_context.InfCtx, ) -> qltypes.Cardinality: - for part in (ir.bindings or []): + for part, _ in (ir.bindings or []): infer_cardinality(part, scope_tree=scope_tree, ctx=ctx) result = ir.subject if isinstance(ir, irast.MutatingStmt) else ir.result @@ -1339,7 +1339,7 @@ def __infer_insert_stmt( scope_tree: irast.ScopeTreeNode, ctx: inference_context.InfCtx, ) -> qltypes.Cardinality: - for part in (ir.bindings or []): + for part, _ in (ir.bindings or []): infer_cardinality(part, scope_tree=scope_tree, ctx=ctx) infer_cardinality( diff --git a/edb/edgeql/compiler/inference/multiplicity.py b/edb/edgeql/compiler/inference/multiplicity.py index 12126886d55..0ec86dca3c4 100644 --- a/edb/edgeql/compiler/inference/multiplicity.py +++ b/edb/edgeql/compiler/inference/multiplicity.py @@ -588,7 +588,7 @@ def _infer_stmt_multiplicity( ) -> inf_ctx.MultiplicityInfo: # WITH block bindings need to be validated; they don't have to # have multiplicity UNIQUE, but their sub-expressions must be valid. 
- for part in (ir.bindings or []): + for part, _ in (ir.bindings or []): infer_multiplicity(part, scope_tree=scope_tree, ctx=ctx) subj = ir.subject if isinstance(ir, irast.MutatingStmt) else ir.result @@ -688,7 +688,7 @@ def __infer_insert_stmt( ) -> inf_ctx.MultiplicityInfo: # WITH block bindings need to be validated, they don't have to # have multiplicity UNIQUE, but their sub-expressions must be valid. - for part in (ir.bindings or []): + for part, _ in (ir.bindings or []): infer_multiplicity(part, scope_tree=scope_tree, ctx=ctx) # INSERT will always return a proper set, but we still want to diff --git a/edb/edgeql/compiler/inference/volatility.py b/edb/edgeql/compiler/inference/volatility.py index 12593f63bf6..2317e8f99a8 100644 --- a/edb/edgeql/compiler/inference/volatility.py +++ b/edb/edgeql/compiler/inference/volatility.py @@ -320,7 +320,7 @@ def __infer_select_stmt( components.append(ir.limit) if ir.bindings is not None: - components.extend(ir.bindings) + components.extend(part for part, _ in ir.bindings) return _common_volatility(components, env) diff --git a/edb/edgeql/compiler/stmt.py b/edb/edgeql/compiler/stmt.py index 4c92ef53dab..02bdf674d09 100644 --- a/edb/edgeql/compiler/stmt.py +++ b/edb/edgeql/compiler/stmt.py @@ -68,6 +68,7 @@ from . import context from . import config_desc from . import dispatch +from . import inference from . import pathctx from . import policies from . import setgen @@ -1309,7 +1310,7 @@ def process_with_block( *, ctx: context.ContextLevel, parent_ctx: context.ContextLevel, -) -> List[irast.Set]: +) -> list[tuple[irast.Set, qltypes.Volatility]]: if edgeql_tree.aliases is None: return [] @@ -1329,7 +1330,10 @@ def process_with_block( binding_kind=irast.BindingKind.With, ctx=scopectx, ) - results.append(binding) + volatility = inference.infer_volatility( + binding, ctx.env, exclude_dml=True + ) + results.append((binding, volatility)) if reason := setgen.should_materialize(binding, ctx=ctx): had_materialized = True diff --git a/edb/edgeql/compiler/viewgen.py b/edb/edgeql/compiler/viewgen.py index a1a5daee93c..fd4f86e2f6c 100644 --- a/edb/edgeql/compiler/viewgen.py +++ b/edb/edgeql/compiler/viewgen.py @@ -69,7 +69,6 @@ from . import context from . import dispatch from . import eta_expand -from . import inference from . import pathctx from . import schemactx from . import setgen @@ -1946,27 +1945,6 @@ def _normalize_view_ptr_expr( assert ptrcls is not None - if materialized and is_mutation and any( - x.is_binding == irast.BindingKind.With - and x.expr - # If it is a computed pointer, look just at the definition. - # TODO: It is a weird artifact of how our shapes are defined - # that if a shape element is defined to be some WITH-bound variable, - # that set can is both a is_binding and an irast.Pointer. It seems like - # the is_binding part should be nested inside it. 
- and (y := x.expr.expr if isinstance(x.expr, irast.Pointer) else x.expr) - and inference.infer_volatility( - y, ctx.env, exclude_dml=True).is_volatile() - - for reason in materialized - if isinstance(reason, irast.MaterializeVisible) - for _, x in reason.sets - ): - raise errors.QueryError( - f'cannot refer to volatile WITH bindings from DML', - span=compexpr.span if compexpr else None, - ) - if materialized and not is_mutation and ctx.qlstmt: assert ptrcls not in ctx.env.materialized_sets ctx.env.materialized_sets[ptrcls] = ctx.qlstmt, materialized diff --git a/edb/ir/ast.py b/edb/ir/ast.py index e284fa55c75..157fcf72968 100644 --- a/edb/ir/ast.py +++ b/edb/ir/ast.py @@ -1143,7 +1143,7 @@ class Stmt(Expr): result: Set = DUMMY_SET parent_stmt: typing.Optional[Stmt] = None iterator_stmt: typing.Optional[Set] = None - bindings: typing.Optional[typing.List[Set]] = None + bindings: typing.Optional[list[tuple[Set, qltypes.Volatility]]] = None @property def typeref(self) -> TypeRef: diff --git a/edb/pgsql/compiler/clauses.py b/edb/pgsql/compiler/clauses.py index f8528b525ea..31e531edad4 100644 --- a/edb/pgsql/compiler/clauses.py +++ b/edb/pgsql/compiler/clauses.py @@ -34,6 +34,7 @@ from . import astutils from . import context from . import dispatch +from . import dml from . import enums as pgce from . import output from . import pathctx @@ -129,47 +130,56 @@ def compile_materialized_exprs( path_id=mat_set.materialized.path_id, ctx=matctx): continue - mat_ids = set(mat_set.uses) - - # We pack optional things into arrays also, since it works. - # TODO: use NULL? - card = mat_set.cardinality - assert card != qltypes.Cardinality.UNKNOWN - is_singleton = card.is_single() and not card.can_be_zero() - - old_scope = matctx.path_scope - matctx.path_scope = old_scope.new_child() - for mat_id in mat_ids: - for k in old_scope: - if k.startswith(mat_id): - matctx.path_scope[k] = None - mat_qry = relgen.set_as_subquery( - mat_set.materialized, as_value=True, ctx=matctx - ) + _compile_materialized_expr(query, mat_set, ctx=matctx) + + +def _compile_materialized_expr( + query: pgast.SelectStmt, + mat_set: irast.MaterializedSet, + *, + ctx: context.CompilerContextLevel, +) -> None: + mat_ids = set(mat_set.uses) + + # We pack optional things into arrays also, since it works. + # TODO: use NULL? 
+ card = mat_set.cardinality + assert card != qltypes.Cardinality.UNKNOWN + is_singleton = card.is_single() and not card.can_be_zero() + + old_scope = ctx.path_scope + ctx.path_scope = old_scope.new_child() + for mat_id in mat_ids: + for k in old_scope: + if k.startswith(mat_id): + ctx.path_scope[k] = None + mat_qry = relgen.set_as_subquery( + mat_set.materialized, as_value=True, ctx=ctx + ) - if not is_singleton: - mat_qry = relctx.set_to_array( - path_id=mat_set.materialized.path_id, - query=mat_qry, - ctx=matctx) + if not is_singleton: + mat_qry = relctx.set_to_array( + path_id=mat_set.materialized.path_id, + query=mat_qry, + ctx=ctx) - if not mat_qry.target_list[0].name: - mat_qry.target_list[0].name = ctx.env.aliases.get('v') + if not mat_qry.target_list[0].name: + mat_qry.target_list[0].name = ctx.env.aliases.get('v') - ref = pgast.ColumnRef( - name=[mat_qry.target_list[0].name], - is_packed_multi=not is_singleton, - ) - for mat_id in mat_ids: - pathctx.put_path_packed_output(mat_qry, mat_id, ref) - - mat_rvar = relctx.rvar_for_rel(mat_qry, lateral=True, ctx=matctx) - for mat_id in mat_ids: - relctx.include_rvar( - query, mat_rvar, path_id=mat_id, - flavor='packed', update_mask=False, pull_namespace=False, - ctx=matctx, - ) + ref = pgast.ColumnRef( + name=[mat_qry.target_list[0].name], + is_packed_multi=not is_singleton, + ) + for mat_id in mat_ids: + pathctx.put_path_packed_output(mat_qry, mat_id, ref) + + mat_rvar = relctx.rvar_for_rel(mat_qry, lateral=True, ctx=ctx) + for mat_id in mat_ids: + relctx.include_rvar( + query, mat_rvar, path_id=mat_id, + flavor='packed', update_mask=False, pull_namespace=False, + ctx=ctx, + ) def compile_iterator_expr( @@ -260,20 +270,103 @@ def compile_output( return val -def compile_dml_bindings( - stmt: irast.Stmt, *, - ctx: context.CompilerContextLevel) -> None: - for binding in (stmt.bindings or ()): +def compile_volatile_bindings( + stmt: irast.Stmt, + *, + ctx: context.CompilerContextLevel +) -> None: + for binding, volatility in (stmt.bindings or ()): # If something we are WITH binding contains DML, we want to # compile it *now*, in the context of its initial appearance - # and not where the variable is used. This will populate - # dml_stmts with the CTEs, which will be picked up when the - # variable is referenced. - if irutils.contains_dml(binding): + # and not where the variable is used. + # + # Similarly, if something we are WITH binding is volatile and the stmt + # contains dml, we similarly want to compile it *now*. + + # If the binding is a with binding for a DML stmt, manually construct + # the CTEs. + # + # Note: This condition is checked first, because if the binding + # *references* DML then contains_dml is true. If the binding is compiled + # normally, since the referenced DML was already compiled, the rvar will + # be retrieved, and no CTEs will be set up. + if volatility.is_volatile() and irutils.contains_dml(stmt): + _compile_volatile_binding_for_dml(stmt, binding, ctx=ctx) + + # For typical DML, just compile it. This will populate dml_stmts with + # the CTEs, which will be picked up when the variable is referenced. 
+ elif irutils.contains_dml(binding): with ctx.substmt() as bctx: dispatch.compile(binding, ctx=bctx) +def _compile_volatile_binding_for_dml( + stmt: irast.Stmt, + binding: irast.Set, + *, + ctx: context.CompilerContextLevel +) -> None: + materialized_set = None + if ( + stmt.materialized_sets + and binding.typeref.id in stmt.materialized_sets + ): + materialized_set = stmt.materialized_sets[binding.typeref.id] + assert materialized_set is not None + + last_iterator = ctx.enclosing_cte_iterator + + with ( + context.output_format(ctx, context.OutputFormat.NATIVE), + ctx.newrel() as matctx + ): + matctx.materializing |= {stmt} + matctx.expr_exposed = True + + dml.merge_iterator(last_iterator, matctx.rel, ctx=matctx) + setup_iterator_volatility(last_iterator, ctx=matctx) + + _compile_materialized_expr( + matctx.rel, materialized_set, ctx=matctx + ) + + # Add iterator identity + bind_pathid = ( + irast.PathId.new_dummy(ctx.env.aliases.get('bind_path')) + ) + with matctx.subrel() as bind_pathid_ctx: + relctx.create_iterator_identity_for_path( + bind_pathid, bind_pathid_ctx.rel, ctx=bind_pathid_ctx + ) + bind_id_rvar = relctx.rvar_for_rel( + bind_pathid_ctx.rel, lateral=True, ctx=matctx + ) + relctx.include_rvar( + matctx.rel, bind_id_rvar, path_id=bind_pathid, ctx=matctx + ) + + bind_cte = pgast.CommonTableExpr( + name=ctx.env.aliases.get('bind'), + query=matctx.rel, + materialized=False, + ) + + bind_iterator = pgast.IteratorCTE( + path_id=bind_pathid, + cte=bind_cte, + parent=last_iterator, + iterator_bond=True, + ) + ctx.toplevel_stmt.append_cte(bind_cte) + + # Merge the new iterator + ctx.path_scope = ctx.path_scope.new_child() + dml.merge_iterator(bind_iterator, ctx.rel, ctx=ctx) + setup_iterator_volatility(bind_iterator, ctx=ctx) + + ctx.enclosing_cte_iterator = bind_iterator + + def compile_filter_clause( ir_set: irast.Set, cardinality: qltypes.Cardinality, *, diff --git a/edb/pgsql/compiler/dml.py b/edb/pgsql/compiler/dml.py index 57da2adbe95..c443cbe433a 100644 --- a/edb/pgsql/compiler/dml.py +++ b/edb/pgsql/compiler/dml.py @@ -102,7 +102,7 @@ def init_dml_stmt( range_cte: Optional[pgast.CommonTableExpr] range_rvar: Optional[pgast.RelRangeVar] - clauses.compile_dml_bindings(ir_stmt, ctx=ctx) + clauses.compile_volatile_bindings(ir_stmt, ctx=ctx) if isinstance(ir_stmt, (irast.UpdateStmt, irast.DeleteStmt)): # UPDATE and DELETE operate over a range, so generate diff --git a/edb/pgsql/compiler/group.py b/edb/pgsql/compiler/group.py index 3a225f55877..a3c8e5a0c9a 100644 --- a/edb/pgsql/compiler/group.py +++ b/edb/pgsql/compiler/group.py @@ -173,7 +173,7 @@ def _compile_group( ctx: context.CompilerContextLevel, parent_ctx: context.CompilerContextLevel) -> pgast.BaseExpr: - clauses.compile_dml_bindings(stmt, ctx=ctx) + clauses.compile_volatile_bindings(stmt, ctx=ctx) query = ctx.stmt diff --git a/edb/pgsql/compiler/stmt.py b/edb/pgsql/compiler/stmt.py index 9dd11aae42a..0b0a1b75f76 100644 --- a/edb/pgsql/compiler/stmt.py +++ b/edb/pgsql/compiler/stmt.py @@ -54,7 +54,7 @@ def compile_SelectStmt( parent_ctx = ctx with parent_ctx.substmt() as ctx: # Common setup. 
- clauses.compile_dml_bindings(stmt, ctx=ctx) + clauses.compile_volatile_bindings(stmt, ctx=ctx) query = ctx.stmt diff --git a/tests/test_edgeql_insert.py b/tests/test_edgeql_insert.py index 2a6ea609752..7123ddb9d9a 100644 --- a/tests/test_edgeql_insert.py +++ b/tests/test_edgeql_insert.py @@ -5822,33 +5822,1054 @@ async def test_edgeql_insert_cardinality_assertion(self): @tb.needs_factoring_weakly async def test_edgeql_insert_volatile_01(self): - # Ideally we'll support these versions eventually - async with self.assertRaisesRegexTx( - edgedb.QueryError, - "cannot refer to volatile WITH bindings from DML"): - await self.con.execute(''' - WITH name := random(), - INSERT Person { name := name, tag := name }; - ''') + await self.con.execute(''' + WITH name := random(), + INSERT Person { name := name, tag := name }; + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_02(self): + await self.con.execute(''' + WITH + x := random(), + name := x ++ "!", + INSERT Person { name := name, tag := name }; + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_03(self): + await self.con.execute(''' + WITH + x := "!", + name := x ++ random(), + INSERT Person { name := name, tag := name }; + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_04(self): + await self.con.execute(''' + WITH + x := random(), + name := x ++ random(), + INSERT Person { name := name, tag := name }; + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_05(self): + await self.con.execute(''' + WITH name := random(), + SELECT (INSERT Person { name := name, tag := name }); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_06(self): + await self.con.execute(''' + WITH + x := random(), + name := x ++ "!", + SELECT (INSERT Person { name := name, tag := name }); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_07(self): + await self.con.execute(''' + WITH + x := "!", + name := x ++ random(), + SELECT (INSERT Person { name := name, tag := name }); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_08(self): + await self.con.execute(''' + WITH + x := random(), + name := x ++ random(), + SELECT (INSERT Person { name := name, tag := name }); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_09(self): + await self.con.execute(''' + WITH x := random() + SELECT ( + WITH name := x ++ "!" + INSERT Person { name := name, tag := name } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_10(self): + await self.con.execute(''' + WITH x := "!" 
+ SELECT ( + WITH name := x ++ random() + INSERT Person { name := name, tag := name } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_11(self): + await self.con.execute(''' + WITH x := random() + SELECT ( + WITH name := x ++ random() + INSERT Person { name := name, tag := name } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_12(self): + await self.con.execute(''' + WITH + x := random(), + y := x ++ random(), + SELECT ( + WITH name := y ++ random() + INSERT Person { name := name, tag := name } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_13(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { name := name, tag := name } + ) + SELECT ( + INSERT Person { name := x.name ++ "!", tag := x.tag } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_14(self): + await self.con.execute(''' + WITH + x := "!", + y := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + SELECT ( + INSERT Person { + name := x ++ y.name, + tag := x ++ y.tag, + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_15(self): + await self.con.execute(''' + WITH + x := random(), + y := ( + WITH name := "!", + INSERT Person { name := name, tag := name, tag2 := name } + ), + SELECT ( + INSERT Person { + name := x ++ y.name, + tag := x ++ y.tag, + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_16(self): + await self.con.execute(''' + WITH + x := random(), + y := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + SELECT ( + INSERT Person { + name := x ++ y.name, + tag := x ++ y.tag, + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def 
test_edgeql_insert_volatile_17(self): + await self.con.execute(''' + WITH + x := "!", + y := ( + WITH name := x ++ random(), + INSERT Person { name := name, tag := name, tag2 := x } + ), + SELECT ( + INSERT Person { + name := y.name ++ "!", + tag := y.tag ++ "!", + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_18(self): + await self.con.execute(''' + WITH + x := random(), + y := ( + WITH name := x ++ "!", + INSERT Person { name := name, tag := name, tag2 := x } + ), + SELECT ( + INSERT Person { + name := y.name ++ "!", + tag := y.tag ++ "!", + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_19(self): + await self.con.execute(''' + WITH + x := random(), + y := ( + WITH name := x ++ random(), + INSERT Person { name := name, tag := name, tag2 := x } + ), + SELECT ( + INSERT Person { + name := y.name ++ "!", + tag := y.tag ++ "!", + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_20(self): + await self.con.execute(''' + WITH + x := ( + WITH name := "!", + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := random(), + SELECT ( + INSERT Person { + name := x.name ++ y, + tag := x.tag ++ y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_21(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := "!", + SELECT ( + INSERT Person { + name := x.name ++ y, + tag := x.tag ++ y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_22(self): + await self.con.execute(''' + WITH + x := 
( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := random(), + SELECT ( + INSERT Person { + name := x.name ++ y, + tag := x.tag ++ y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_23(self): + await self.con.execute(''' + WITH + x := ( + WITH name := "!", + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := x.name ++ random(), + SELECT ( + INSERT Person { + name := y, + tag := y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_24(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := x.name ++ "!", + SELECT ( + INSERT Person { + name := y, + tag := y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_25(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := x.name ++ random(), + SELECT ( + INSERT Person { + name := y, + tag := y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_26(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { + name := name, + tag := name, + tag2 := name, + } + ), + y := ( + WITH r := random(), + INSERT Person { + name := x.name ++ r, + tag := x.tag ++ r, + tag2 := x.tag, + } + ), + SELECT ( + WITH r := random(), + INSERT Person { + name := y.name ++ r, + tag := y.name ++ r, + tag2 := y.tag ++ r, + } + ); + ''') - async with self.assertRaisesRegexTx( - edgedb.QueryError, - "cannot refer to volatile WITH bindings from DML"): - await self.con.execute(''' - WITH name := random(), - SELECT (INSERT Person { name := name, tag := name }); - ''') + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [3], + ) + await self.assert_query_result( + 
'SELECT count(distinct(Person.tag))', + [3], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [2], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_27(self): + await self.con.execute(''' + WITH x := "!" + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := random() + insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + await self.assert_query_result( + 'SELECT all(Note.name = Person.note)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_28(self): + await self.con.execute(''' + WITH x := random(), + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := random() + insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + await self.assert_query_result( + 'SELECT all(Note.name = Person.note)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_29(self): + await self.con.execute(''' + WITH x := "!", + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := x ++ random() + insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + await self.assert_query_result( + 'SELECT all(Note.name = Person.note)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_30(self): + await self.con.execute(''' + WITH x := random(), + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := x ++ "!" + insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + await self.assert_query_result( + 'SELECT all(Note.name = Person.note)', + [True], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_31(self): + await self.con.execute(''' + WITH x := random(), + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := x ++ random() + insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True], + ) + await self.assert_query_result( + 'SELECT all(Note.name = Person.note)', + [True], + ) + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_32(self): await self.con.execute(''' - FOR name in {random()} + FOR name in {random(), random()} UNION (INSERT Person { name := name, tag := name }); ''') await self.assert_query_result( - r''' - SELECT all(Person.name = Person.tag) - ''', - [True] + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_33(self): + await self.con.execute(''' + WITH x := "!" 
+ FOR y in {random(), random()} + UNION ( + WITH name := x ++ y + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_34(self): + await self.con.execute(''' + WITH x := random() + FOR y in {"A", "B"} + UNION ( + WITH name := x ++ y + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_35(self): + await self.con.execute(''' + WITH x := random() + FOR y in {random(), random()} + UNION ( + WITH name := x ++ y + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_36(self): + await self.con.execute(''' + WITH x := "!" + FOR name in {x ++ random(), x ++ random()} + UNION ( + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_37(self): + await self.con.execute(''' + WITH x := random() + FOR name in {x ++ "A", x ++ "B"} + UNION ( + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_38(self): + await self.con.execute(''' + WITH x := random() + FOR name in {x ++ random(), x ++ random()} + UNION ( + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_39(self): + await self.con.execute(''' + FOR x in {"A", "B"} + UNION ( + WITH name := x ++ random() + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await 
self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [2], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_40(self): + await self.con.execute(''' + FOR x in {random(), random()} + UNION ( + WITH name := x ++ "!" + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [2], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_41(self): + await self.con.execute(''' + FOR x in {random(), random()} + UNION ( + WITH name := x ++ random() + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [2], + ) + + @tb.needs_factoring_weakly + async def test_edgeql_insert_volatile_42(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { + name := name, + tag := name, + tag2 := name, + } + ) + FOR y in {random(), random()} + UNION ( + WITH name := x.name ++ y + INSERT Person { name := name, tag := name, tag2 := x.tag2 } + ); + ''') + + await self.assert_query_result( + 'SELECT all(Person.name = Person.tag)', + [True, True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], ) async def test_edgeql_insert_multi_exclusive_01(self): From cc14d6c453dfed2f0e960d36804c1a99a518f816 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Fri, 15 Nov 2024 10:15:26 -0700 Subject: [PATCH 07/16] Ignore hanging test (#7997) --- tests/test_http.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_http.py b/tests/test_http.py index 1afe0ffc34c..673f407215d 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -19,6 +19,7 @@ import asyncio import json import random +import unittest from edb.server import http from edb.testbase import http as tb @@ -263,6 +264,7 @@ async def client_task(): assert is_closed + @unittest.skip("Hangs on CI") async def test_sse_with_mock_server_close(self): """Try to close the server-side stream and see if the client detects an end for the iterator. 
Note that this is technically not correct SSE: From ef4bdc4e96da58729ab1332cbf55f3adf2aff988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Fri, 15 Nov 2024 18:39:56 +0100 Subject: [PATCH 08/16] docs: expand SQL adapter docs (#7895) --- docs/reference/index.rst | 2 +- .../{sql_support.rst => sql_adapter.rst} | 216 ++++++++++++++++-- docs/stdlib/cfg.rst | 2 +- 3 files changed, 200 insertions(+), 20 deletions(-) rename docs/reference/{sql_support.rst => sql_adapter.rst} (50%) diff --git a/docs/reference/index.rst b/docs/reference/index.rst index c418325fd31..527747c13db 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -24,7 +24,7 @@ Reference backend_ha configuration http - sql_support + sql_adapter protocol/index bindings/index admin/index diff --git a/docs/reference/sql_support.rst b/docs/reference/sql_adapter.rst similarity index 50% rename from docs/reference/sql_support.rst rename to docs/reference/sql_adapter.rst index 947252cf9a4..381c39e8cdf 100644 --- a/docs/reference/sql_support.rst +++ b/docs/reference/sql_adapter.rst @@ -1,9 +1,9 @@ .. versionadded:: 3.0 -.. _ref_sql_support: +.. _ref_sql_adapter: =========== -SQL support +SQL adapter =========== .. edb:youtube-embed:: 0KdY2MPb2oc @@ -11,12 +11,21 @@ SQL support Connecting ========== -EdgeDB supports running read-only SQL queries via the Postgres protocol to -enable connecting EdgeDB to existing BI and analytics solutions. Any -Postgres-compatible client can connect to your EdgeDB database by using the +EdgeDB server supports PostgreSQL connection interface. It implements PostgreSQL +wire protocol as well as SQL query language. + +As of EdgeDB 6.0, it also supports a subset of Data Modification Language, +namely INSERT, DELETE and UPDATE statements. + +It does not, however, support PostgreSQL Data Definition Language +(e.g. ``CREATE TABLE``). This means that it is not possible to use SQL +connections to EdgeDB to modify its schema. Instead, the schema should be +managed using ESDL (EdgeDB Schema Definition Language) and migration commands. + +Any Postgres-compatible client can connect to an EdgeDB database by using the same port that is used for the EdgeDB protocol and the -:versionreplace:`database;5.0:branch` name, username, and password you already -use for your database. +:versionreplace:`database;5.0:branch` name, username, and password already used +for the database. .. versionchanged:: _default @@ -52,7 +61,7 @@ use for your database. The insecure DSN returned by the CLI for EdgeDB Cloud instances will not contain the password. You will need to either :ref:`create a new role and - set the password `, using those values to connect + set the password `, using those values to connect to your SQL client, or change the password of the existing role, using that role name along with the newly created password. @@ -76,7 +85,7 @@ use for your database. ``libpq.dll``, click "Properties," and find the version on the "Details" tab. -.. _ref_sql_support_new_role: +.. _ref_sql_adapter_new_role: Creating a new role ------------------- @@ -177,24 +186,29 @@ Multi properties are in separate tables. ``source`` is the ``id`` of the Movie. SELECT source, target FROM "Movie.labels"; -When types are extended, parent object types' tables will by default contain -all objects of both the type and any types extended by it. The query below will +When using inheritance, parent object types' tables will by default contain +all objects of both the parent type and any child types. 
The query below will
+return all ``common::Content`` objects as well as all ``Movie`` objects.

 .. code-block:: sql

    SELECT id, title FROM common."Content";

-To omit objects of extended types, use ``ONLY``. This query will return
+To omit objects of child types, use ``ONLY``. This query will return
 ``common::Content`` objects but not ``Movie`` objects.

 .. code-block:: sql

    SELECT id, title FROM ONLY common."Content";

-The SQL connector supports read-only statements and will throw errors if the
-client attempts ``INSERT``, ``UPDATE``, ``DELETE``, or any DDL command. It
-supports all SQL expressions supported by Postgres.
+The SQL adapter supports a large majority of the SQL language, including:
+
+- ``SELECT`` and all read-only constructs (``WITH``, sub-queries, ``JOIN``, ...),
+- ``INSERT`` / ``UPDATE`` / ``DELETE``,
+- ``COPY ... FROM``,
+- ``SET`` / ``RESET`` / ``SHOW``,
+- transaction commands,
+- ``PREPARE`` / ``EXECUTE`` / ``DEALLOCATE``.

 .. code-block:: sql

@@ -207,8 +221,8 @@ supports all SQL expressions supported by Postgres.
     WHERE act.source = m.id
   );

-EdgeDB accomplishes this by emulating the ``information_schema`` and
-``pg_catalog`` views to mimic the catalogs provided by Postgres 13.
+The SQL adapter emulates the ``information_schema`` and ``pg_catalog`` views to
+mimic the catalogs provided by Postgres 13.

 .. note::

@@ -244,7 +258,7 @@ Tested SQL tools
 `XMIN Replication`_, incremental updates using "a user-defined monotonically
 increasing id," and full table updates.

 .. [2] dbt models are built and stored in the database as either tables or
-   views. Because the EdgeDB SQL connector does not allow writing or even
+   views. Because the EdgeDB SQL adapter does not allow writing or even
    creating schemas, view, or tables, any attempt to materialize dbt models
    will result in errors. If you want to build the models, we suggest first
    transferring your data to a true Postgres instance via pg_dump or Airbyte.

@@ -254,3 +268,169 @@ Tested SQL tools
    https://www.postgresql.org/docs/current/runtime-config-replication.html

 .. _XMIN Replication: https://www.postgresql.org/docs/15/ddl-system-columns.html
+
+
+ESDL to PostgreSQL
+==================
+
+As mentioned, the SQL schema of the database is managed through the EdgeDB
+Schema Definition Language (ESDL). Here is a breakdown of how each ESDL
+construct is mapped to the PostgreSQL schema; a short example of reading these
+tables from a regular PostgreSQL client follows the list:
+
+- Object types are mapped to tables.
+  Each table has columns ``id UUID`` and ``__type__ UUID`` and one column for
+  each single property or link.
+
+- Single properties are mapped to table columns.
+
+- Single links are mapped to table columns with the suffix ``_id`` and are of
+  type ``UUID``. They contain the ids of the link's target type.
+
+- Multi properties are mapped to tables with two columns:
+  - ``source UUID``, which contains the id of the property's source object,
+  - ``target``, which contains values of the property.
+
+- Multi links are mapped to tables with the following columns:
+  - ``source UUID``, which contains the id of the link's source object,
+  - ``target UUID``, which contains the ids of the link's target objects,
+  - one column for each link property, using the same rules as properties on
+    object types.
+
+- Aliases are not mapped to the PostgreSQL schema.
+
+- Globals are mapped to connection settings, prefixed with ``global ``.
+  For example, a ``global default::username: str`` can be set using
+  ``SET "global default::username" TO 'Tom'``.
+
+- Access policies are applied to object type tables when the
+  ``apply_access_policies_sql`` setting is set to ``true``.
+
+- Mutation rewrites and triggers are applied to all DML commands.
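+
+For illustration, here is a minimal sketch of reading this table layout from a
+regular PostgreSQL client. It assumes the ``asyncpg`` Python driver and the
+example ``Movie`` schema used above; the DSN, user, port, and branch name are
+placeholders and may need adjusting (including TLS options) for your instance:
+
+.. code-block:: python
+
+    import asyncio
+    import asyncpg
+
+    async def main() -> None:
+        # The adapter speaks the regular Postgres protocol on the EdgeDB port.
+        conn = await asyncpg.connect(
+            'postgres://edgedb:PASSWORD@localhost:5656/main')
+        try:
+            # Object types are plain tables with "id" and "__type__" columns.
+            movies = await conn.fetch('SELECT id, title FROM "Movie"')
+            # Multi links live in their own table keyed by source/target.
+            roles = await conn.fetch(
+                'SELECT source, target FROM "Movie.actors"')
+            print(len(movies), len(roles))
+        finally:
+            await conn.close()
+
+    asyncio.run(main())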
+
+
+DML commands
+============
+
+When using ``INSERT``, ``DELETE`` or ``UPDATE`` on any table, mutation rewrites
+and triggers are applied. These commands do not have a straightforward
+translation to EdgeQL DML commands; instead they use the following mapping:
+
+- ``INSERT INTO "Foo"`` object table maps to ``insert Foo``,
+
+- ``INSERT INTO "Foo.keywords"`` link/property table maps to an
+  ``update Foo { keywords += ... }``,
+
+- ``DELETE FROM "Foo"`` object table maps to ``delete Foo``,
+
+- ``DELETE FROM "Foo.keywords"`` link/property table maps to an
+  ``update Foo { keywords -= ... }``,
+
+- ``UPDATE "Foo"`` object table maps to ``update Foo set { ... }``,
+
+- ``UPDATE "Foo.keywords"`` is not supported.
+
+
+Connection settings
+===================
+
+The SQL adapter supports a limited subset of PostgreSQL connection settings,
+plus the following additional connection settings:
+
+- ``allow_user_specified_id`` (default ``false``),
+- ``apply_access_policies_sql`` (default ``false``),
+- settings prefixed with ``"global "``, which can be used to set the values
+  of globals.
+
+Note that if ``allow_user_specified_id`` or ``apply_access_policies_sql`` are
+unset, they default to the configuration set by the
+``configure current database`` EdgeQL command.
+
+
+Example: gradual transition from ORMs to EdgeDB
+===============================================
+
+When a project uses an Object-Relational Mapper (e.g. SQLAlchemy, Django,
+Hibernate ORM, TypeORM) and is considering a migration to EdgeDB, it might
+want to make the transition gradually, as opposed to doing a total rewrite of
+the project.
+
+In this case, the project can start the transition by migrating the ORM models
+to the EdgeDB Schema Definition Language.
+
+For example, a Hibernate ORM model like this one in Java:
+
+.. code-block::
+
+    @Entity
+    class Movie {
+        @Id
+        @GeneratedValue(strategy = GenerationType.UUID)
+        UUID id;
+
+        private String title;
+
+        @NotNull
+        private Integer releaseYear;
+
+        // ... getters and setters ...
+    }
+
+... would be translated to the following EdgeDB SDL:
+
+.. code-block:: sdl
+
+    type Movie {
+        title: str;
+
+        required releaseYear: int32;
+    }
+
+A new EdgeDB instance can now be created and migrated to the translated schema.
+At this stage, EdgeDB will allow SQL connections to write into the ``"Movie"``
+table, just as if it had been created with the following DDL command:
+
+.. code-block:: sql
+
+    CREATE TABLE "Movie" (
+        id UUID PRIMARY KEY DEFAULT (...),
+        __type__ UUID NOT NULL DEFAULT (...),
+        title TEXT,
+        releaseYear INTEGER NOT NULL
+    );
+
+When translating the old ORM model to EdgeDB SDL, one should aim to make the
+SQL schema of EdgeDB match the SQL schema that the ORM expects.
+
+When this match is accomplished, any query that used to work with the old,
+plain PostgreSQL database should now also work with EdgeDB. For example, we
+can execute the following query:
+
+.. code-block:: sql
+
+    INSERT INTO "Movie" (title, releaseYear)
+    VALUES ('Madagascar', 2012)
+    RETURNING id, title, releaseYear;
+
+To complete the migration, the data can be exported from the old database into
+an ``.sql`` file, which can then be imported into EdgeDB:
+
+..
code-block:: bash + + $ pg_dump {your PostgreSQL connection params} \ + --data-only --inserts --no-owner --no-privileges \ + > dump.sql + + $ psql {your EdgeDB connection params} --file dump.sql + +Now, the ORM can be pointed to EdgeDB instead of the old PostgreSQL database, +which has been fully replaced. + +Arguably, the development of new features with the ORM is now more complex for +the duration of the transition, since the developer has to modify two model +definitions: the ORM and the EdgeDB schema. + +But it allows any new models to use EdgeDB schema, EdgeQL and code generators +for the client language of choice. The ORM-based code can now also be gradually +rewritten to use EdgeQL, one model at the time. + +For a detailed migration example, see repository +`edgedb/hibernate-example `_. diff --git a/docs/stdlib/cfg.rst b/docs/stdlib/cfg.rst index 18fbc11cef7..3e3a37d8430 100644 --- a/docs/stdlib/cfg.rst +++ b/docs/stdlib/cfg.rst @@ -439,7 +439,7 @@ Client connections - EdgeDB binary protocol * - ``cfg::ConnectionTransport.TCP_PG`` - Postgres protocol for the - :ref:`SQL query mode ` + :ref:`SQL query mode ` * - ``cfg::ConnectionTransport.HTTP`` - EdgeDB binary protocol :ref:`tunneled over HTTP ` From f210a5ef155dff6d8da326d2218ae5858e1f1b0a Mon Sep 17 00:00:00 2001 From: dnwpark Date: Fri, 15 Nov 2024 13:13:19 -0500 Subject: [PATCH 09/16] Fix ISE when enumerating a call to an aggregate function. (#7988) Aggregate functions are not allowed in `ROWS FROM (...)`. --- edb/pgsql/compiler/relgen.py | 5 ++++- tests/test_edgeql_functions.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index c0e0390fe8b..49d4de42f71 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -2823,9 +2823,12 @@ def process_set_as_enumerate( or arg_expr.limit or arg_expr.offset ) + ) and not any( + f_arg.param_typemod == qltypes.TypeModifier.SetOfType + for _, f_arg in arg_subj.args.items() ) ): - # Enumeration of a SET-returning function + # Enumeration of a non-aggregate function rvars = process_set_as_func_enumerate(ir_set, ctx=ctx) else: rvars = process_set_as_simple_enumerate(ir_set, ctx=ctx) diff --git a/tests/test_edgeql_functions.py b/tests/test_edgeql_functions.py index bd6eb250742..a82fe40236e 100644 --- a/tests/test_edgeql_functions.py +++ b/tests/test_edgeql_functions.py @@ -851,6 +851,23 @@ async def test_edgeql_functions_enumerate_08(self): ]) ) + async def test_edgeql_functions_enumerate_09(self): + await self.assert_query_result( + 'SELECT enumerate(sum({1,2,3}))', + [[0, 6]] + ) + await self.assert_query_result( + 'SELECT enumerate(count(Issue))', + [[0, 4]] + ) + await self.assert_query_result( + ''' + WITH x := (SELECT enumerate(array_agg((select User)))), + SELECT (x.0, array_unpack(x.1).name) + ''', + [[0, 'Elvis'], [0, 'Yury']] + ) + async def test_edgeql_functions_array_get_01(self): await self.assert_query_result( r'''SELECT array_get([1, 2, 3], 2);''', From c951b9ecd3d7affa47ba74149af8bbd44997388d Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 31 Oct 2024 11:57:12 -0700 Subject: [PATCH 10/16] setup: Allow building the CLI from a local path --- Makefile | 4 ++++ setup.py | 35 ++++++++++++++++++++++------------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/Makefile b/Makefile index 44f57763c8e..f434cc0b3ba 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,10 @@ rust: build-reqs BUILD_EXT_MODE=rust-only python setup.py build_ext 
--inplace +cli: build-reqs + python setup.py build_cli + + docs: build-reqs find docs -name '*.rst' | xargs touch $(MAKE) -C docs html SPHINXOPTS=$(SPHINXOPTS) BUILDDIR="../build" diff --git a/setup.py b/setup.py index 22f78142163..3bb5408b248 100644 --- a/setup.py +++ b/setup.py @@ -413,21 +413,30 @@ def _compile_cli(build_base, build_temp): env = dict(os.environ) env['CARGO_TARGET_DIR'] = str(build_temp / 'rust' / 'cli') env['PSQL_DEFAULT_PATH'] = build_base / 'postgres' / 'install' / 'bin' - git_ref = env.get("EDGEDBCLI_GIT_REV") or EDGEDBCLI_COMMIT - git_rev = _get_git_rev(EDGEDBCLI_REPO, git_ref) - - subprocess.run( - [ - 'cargo', 'install', - '--verbose', '--verbose', + path = env.get("EDGEDBCLI_PATH") + args = [ + 'cargo', 'install', + '--verbose', '--verbose', + '--bin', 'edgedb', + '--root', rust_root, + '--features=dev_mode', + '--locked', + '--debug', + ] + if path: + args.extend([ + '--path', path, + ]) + else: + git_ref = env.get("EDGEDBCLI_GIT_REV") or EDGEDBCLI_COMMIT + git_rev = _get_git_rev(EDGEDBCLI_REPO, git_ref) + args.extend([ '--git', EDGEDBCLI_REPO, '--rev', git_rev, - '--bin', 'edgedb', - '--root', rust_root, - '--features=dev_mode', - '--locked', - '--debug', - ], + ]) + + subprocess.run( + args, env=env, check=True, ) From 33d29f3a145baeab0a52b7b8d1c8ac4ea92fd4bd Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Fri, 8 Nov 2024 16:22:46 -0800 Subject: [PATCH 11/16] edb.pgsql.parser: Only ignore .c files in root Various tooling picks up the ignore rule resulting in libpg_query stuff getting ignored inappropriately. --- edb/pgsql/parser/.gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/edb/pgsql/parser/.gitignore b/edb/pgsql/parser/.gitignore index 064a8d8ef55..b5a455aa984 100644 --- a/edb/pgsql/parser/.gitignore +++ b/edb/pgsql/parser/.gitignore @@ -1 +1 @@ -*.c +/*.c From b4bdbc235feb6f66f7183bc2bc813efb1fb2b237 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Mon, 30 Sep 2024 10:42:30 -0700 Subject: [PATCH 12/16] Initial support for SQL over native protocol The binary protocol version is bumped to 3.0 to accommodate the addition of the new `input_language` field to the `Parse` and `Execute` messages. The field can be set to `b'E'` for EdgeQL or to `b'S'` for SQL. The result of a query sent as SQL is always encoded as a free object. In other regards the SQL execution and caching pipeline is identical to EdgeQL. Caveats and limitations: Not all query features are supported when in SQL mode, specifically cardinality assertions and output format are always `MANY` and `BINARY` correspondingly. DDL and explicit prepared statements are unsupported. Multi-statement queries are not supported yet. 
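As an illustration of the new field's wire semantics, here is a small,
self-contained sketch (the enum values mirror `edb/protocol/messages.py`;
the encoding helper is purely illustrative and is not part of the server
code):

    import enum

    class InputLanguage(enum.Enum):
        EDGEQL = 0x45  # b'E'
        SQL = 0x53     # b'S'

    def encode_input_language(lang: InputLanguage) -> bytes:
        # `Parse` and `Execute` carry the field as a single unsigned byte.
        return bytes([lang.value])

    assert encode_input_language(InputLanguage.EDGEQL) == b'E'
    assert encode_input_language(InputLanguage.SQL) == b'S'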
--- .github/workflows.src/tests.inc.yml | 2 +- .github/workflows/tests-ha.yml | 2 +- .github/workflows/tests-inplace.yml | 2 +- .github/workflows/tests-managed-pg.yml | 2 +- .github/workflows/tests-patches.yml | 2 +- .github/workflows/tests-pg-versions.yml | 2 +- .github/workflows/tests-pool.yml | 2 +- .github/workflows/tests.yml | 2 +- .gitignore | 4 +- .gitmodules | 2 +- Makefile | 4 + docs/changelog/1_0_rc2.rst | 2 +- edb/buildmeta.py | 2 +- edb/edgeql/__init__.py | 1 + edb/edgeql/tokenizer.py | 6 + edb/pgsql/parser/__init__.py | 31 ++- edb/pgsql/parser/libpg_query | 2 +- edb/pgsql/parser/parser.pxd | 41 +++ edb/pgsql/parser/parser.pyx | 330 +++++++++++++++++++++++- edb/pgsql/resolver/__init__.py | 24 ++ edb/pgsql/resolver/context.py | 3 + edb/protocol/messages.py | 8 + edb/protocol/protocol.pyx | 2 + edb/server/compiler/__init__.py | 3 +- edb/server/compiler/compiler.py | 156 ++++++++++- edb/server/compiler/dbstate.py | 17 +- edb/server/compiler/ddl.py | 2 +- edb/server/compiler/enums.py | 5 + edb/server/compiler/rpc.pxd | 7 +- edb/server/compiler/rpc.pyi | 6 + edb/server/compiler/rpc.pyx | 160 ++++++++---- edb/server/compiler/sertypes.py | 8 +- edb/server/compiler/sql.py | 75 +++++- edb/server/dbview/dbview.pxd | 3 + edb/server/dbview/dbview.pyi | 2 +- edb/server/dbview/dbview.pyx | 153 ++++++++++- edb/server/defines.py | 2 +- edb/server/pgcon/pgcon.pxd | 1 + edb/server/pgcon/pgcon.pyi | 5 + edb/server/pgcon/pgcon.pyx | 160 +++++++++++- edb/server/pgproto | 2 +- edb/server/protocol/args_ser.pxd | 1 + edb/server/protocol/args_ser.pyx | 32 ++- edb/server/protocol/binary.pyx | 82 +++++- edb/server/protocol/execute.pyx | 4 +- edb/server/tenant.py | 2 +- edb/testbase/connection.py | 4 + edb/testbase/server.py | 1 + pyproject.toml | 4 +- setup.py | 16 +- tests/test_http_auth.py | 2 + tests/test_protocol.py | 7 + tests/test_server_config.py | 1 + tests/test_server_ops.py | 1 + 54 files changed, 1272 insertions(+), 130 deletions(-) create mode 100644 edb/pgsql/parser/parser.pxd diff --git a/.github/workflows.src/tests.inc.yml b/.github/workflows.src/tests.inc.yml index fe36c283b5a..126c66044cc 100644 --- a/.github/workflows.src/tests.inc.yml +++ b/.github/workflows.src/tests.inc.yml @@ -123,7 +123,7 @@ steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | diff --git a/.github/workflows/tests-ha.yml b/.github/workflows/tests-ha.yml index 899f4f32a9c..5178c014516 100644 --- a/.github/workflows/tests-ha.yml +++ b/.github/workflows/tests-ha.yml @@ -169,7 +169,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | diff --git a/.github/workflows/tests-inplace.yml b/.github/workflows/tests-inplace.yml index bb736aaddb1..6cc5a61f149 100644 --- a/.github/workflows/tests-inplace.yml +++ b/.github/workflows/tests-inplace.yml @@ -154,7 +154,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | diff --git a/.github/workflows/tests-managed-pg.yml b/.github/workflows/tests-managed-pg.yml index 
0f4b80752af..5c930923be2 100644 --- a/.github/workflows/tests-managed-pg.yml +++ b/.github/workflows/tests-managed-pg.yml @@ -154,7 +154,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | diff --git a/.github/workflows/tests-patches.yml b/.github/workflows/tests-patches.yml index 025282975bd..9078bb212e9 100644 --- a/.github/workflows/tests-patches.yml +++ b/.github/workflows/tests-patches.yml @@ -156,7 +156,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | diff --git a/.github/workflows/tests-pg-versions.yml b/.github/workflows/tests-pg-versions.yml index 1b1bfb28f62..e83d3ae547c 100644 --- a/.github/workflows/tests-pg-versions.yml +++ b/.github/workflows/tests-pg-versions.yml @@ -154,7 +154,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | diff --git a/.github/workflows/tests-pool.yml b/.github/workflows/tests-pool.yml index 2e2c21c8bc3..20b58139aef 100644 --- a/.github/workflows/tests-pool.yml +++ b/.github/workflows/tests-pool.yml @@ -164,7 +164,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6cf0fb4ac90..00a208bb865 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -166,7 +166,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | diff --git a/.gitignore b/.gitignore index 6b75b711cce..37c1924e3a3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,8 @@ *.pyo *.o *.so -.vscode +.vscode/ +.zed/ *~ .#* .*.swp @@ -37,3 +38,4 @@ docs/_build /.vagga /.dmypy.json /compile_commands.json +/pyrightconfig.json diff --git a/.gitmodules b/.gitmodules index e286e4030fa..26b22531174 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ url = https://github.com/MagicStack/py-pgproto.git [submodule "edb/pgsql/parser/libpg_query"] path = edb/pgsql/parser/libpg_query - url = https://github.com/msullivan/libpg_query.git + url = https://github.com/edgedb/libpg_query.git diff --git a/Makefile b/Makefile index f434cc0b3ba..b301fbb8221 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,10 @@ parsers: python setup.py build_parsers --inplace +libpg-query: + python setup.py build_libpg_query + + ui: build-reqs python setup.py build_ui diff --git a/docs/changelog/1_0_rc2.rst b/docs/changelog/1_0_rc2.rst index 427ba33c6be..f4a2fff8520 100644 --- a/docs/changelog/1_0_rc2.rst +++ b/docs/changelog/1_0_rc2.rst @@ -298,7 +298,7 @@ Server configuration ``EDGEDB_SERVER_SECURITY`` - ``strict == default`` - - ``insecure_dev_mode`` — disable password-based 
authentication and allow + - ``insecure_dev_mode`` — disable password-based authentication and allow unencrypted HTTP traffic ``EDGEDB_DOCKER_APPLY_MIGRATIONS`` (Docker only) diff --git a/edb/buildmeta.py b/edb/buildmeta.py index cfa5ff88228..54dd684c0c3 100644 --- a/edb/buildmeta.py +++ b/edb/buildmeta.py @@ -60,7 +60,7 @@ # The merge conflict there is a nice reminder that you probably need # to write a patch in edb/pgsql/patches.py, and then you should preserve # the old value. -EDGEDB_CATALOG_VERSION = 2024_11_12_00_00 +EDGEDB_CATALOG_VERSION = 2024_11_12_01_00 EDGEDB_MAJOR_VERSION = 6 diff --git a/edb/edgeql/__init__.py b/edb/edgeql/__init__.py index dcbdd259105..11d17479b1d 100644 --- a/edb/edgeql/__init__.py +++ b/edb/edgeql/__init__.py @@ -24,3 +24,4 @@ from .codegen import generate_source # NOQA from .parser import parse_fragment, parse_block, parse_query # NOQA from .parser.grammar import keywords # NOQA +from .quote import quote_literal, quote_ident # NOQA diff --git a/edb/edgeql/tokenizer.py b/edb/edgeql/tokenizer.py index b721d0adce8..4a85fa23300 100644 --- a/edb/edgeql/tokenizer.py +++ b/edb/edgeql/tokenizer.py @@ -77,6 +77,12 @@ def extra_counts(self) -> Sequence[int]: def extra_blobs(self) -> Sequence[bytes]: return () + def extra_formatted_as_text(self) -> bool: + return False + + def extra_type_oids(self) -> Sequence[int]: + return () + def serialize(self) -> bytes: return self._serialized diff --git a/edb/pgsql/parser/__init__.py b/edb/pgsql/parser/__init__.py index db3565bbdd0..063e51f2c6c 100644 --- a/edb/pgsql/parser/__init__.py +++ b/edb/pgsql/parser/__init__.py @@ -16,19 +16,40 @@ # limitations under the License. # -from typing import List +from __future__ import annotations + +from typing import ( + List, +) import json from edb.pgsql import ast as pgast -from .parser import pg_parse -from .ast_builder import build_stmts +from . import ast_builder +from . import parser +from .parser import ( + Source, + NormalizedSource, + deserialize, +) + + +__all__ = ( + "parse", + "Source", + "NormalizedSource", + "deserialize" +) def parse( sql_query: str, propagate_spans: bool = False ) -> List[pgast.Query | pgast.Statement]: - ast_json = pg_parse(bytes(sql_query, encoding="UTF8")) + ast_json = parser.pg_parse(bytes(sql_query, encoding="UTF8")) - return build_stmts(json.loads(ast_json), sql_query, propagate_spans) + return ast_builder.build_stmts( + json.loads(ast_json), + sql_query, + propagate_spans, + ) diff --git a/edb/pgsql/parser/libpg_query b/edb/pgsql/parser/libpg_query index c773fdd7100..b31c55490e2 160000 --- a/edb/pgsql/parser/libpg_query +++ b/edb/pgsql/parser/libpg_query @@ -1 +1 @@ -Subproject commit c773fdd7100c175d0bfbb8be3a79d1f46b370f46 +Subproject commit b31c55490e24328947ceaa090b4f5f97800cf26d diff --git a/edb/pgsql/parser/parser.pxd b/edb/pgsql/parser/parser.pxd new file mode 100644 index 00000000000..f5bdd79aac6 --- /dev/null +++ b/edb/pgsql/parser/parser.pxd @@ -0,0 +1,41 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2010-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from libc.stdint cimport uint8_t + +from edb.server.pgproto.pgproto cimport ( + ReadBuffer, + WriteBuffer, +) + + +cdef class Source: + cdef: + str _text + bytes _serialized + bytes _cache_key + + cdef WriteBuffer _serialize(self) + + +cdef class NormalizedSource(Source): + cdef: + str _orig_text + int _highest_extern_param_id + list _extracted_constants diff --git a/edb/pgsql/parser/parser.pyx b/edb/pgsql/parser/parser.pyx index 35aee3f811a..0f852185c8e 100644 --- a/edb/pgsql/parser/parser.pyx +++ b/edb/pgsql/parser/parser.pyx @@ -16,9 +16,27 @@ # limitations under the License. # +from typing import ( + Any, + NamedTuple, + Optional, +) + +import enum +import hashlib from .exceptions import PSqlParseError + +from edb.server.pgproto.pgproto cimport ( + FRBuffer, + ReadBuffer, + WriteBuffer, +) + +from libc.stdint cimport int8_t, uint8_t, int32_t + + cdef extern from "pg_query.h": ctypedef struct PgQueryError: char *message @@ -29,9 +47,40 @@ cdef extern from "pg_query.h": char *parse_tree PgQueryError *error - PgQueryParseResult pg_query_parse(const char* input) + ctypedef struct PgQueryNormalizeConstLocation: + int location + int length + int param_id + int token + char *val + + ctypedef struct PgQueryNormalizeResult: + char *normalized_query + PgQueryError *error + PgQueryNormalizeConstLocation *clocations + int clocations_count + int highest_extern_param_id + + PgQueryParseResult pg_query_parse(const char *input) + void pg_query_free_parse_result(PgQueryParseResult result) + + PgQueryNormalizeResult pg_query_normalize(const char *input) + void pg_query_free_normalize_result(PgQueryNormalizeResult result) - void pg_query_free_parse_result(PgQueryParseResult result); + +cdef extern from "protobuf/pg_query.pb-c.h": + ctypedef struct ProtobufCEnumValue: + const char *name + const char *c_name + int value + + ctypedef struct ProtobufCEnumDescriptor: + pass + + ProtobufCEnumDescriptor pg_query__token__descriptor + + const ProtobufCEnumValue *protobuf_c_enum_descriptor_get_value( + const ProtobufCEnumDescriptor *desc, int value) def pg_parse(query) -> str: @@ -49,3 +98,280 @@ def pg_parse(query) -> str: result_utf8 = result.parse_tree.decode('utf8') pg_query_free_parse_result(result) return result_utf8 + + +class LiteralTokenType(enum.StrEnum): + FCONST = "FCONST" + SCONST = "SCONST" + BCONST = "BCONST" + XCONST = "XCONST" + ICONST = "ICONST" + TRUE_P = "TRUE_P" + FALSE_P = "FALSE_P" + + +class PgLiteralTypeOID(enum.IntEnum): + BOOL = 16 + INT4 = 23 + TEXT = 25 + VARBIT = 1562 + NUMERIC = 1700 + + +class NormalizedQuery(NamedTuple): + text: str + highest_extern_param_id: int + extracted_constants: list[tuple[int, LiteralTokenType, bytes]] + + +def pg_normalize(query: str) -> NormalizedQuery: + cdef: + PgQueryNormalizeResult result + PgQueryNormalizeConstLocation loc + const ProtobufCEnumValue *token + int i + bytes queryb + bytes const + + queryb = query.encode("utf-8") + result = pg_query_normalize(queryb) + + try: + if result.error: + error = PSqlParseError( + result.error.message.decode('utf8'), + result.error.lineno, result.error.cursorpos + ) + raise error + + normalized_query = result.normalized_query.decode('utf8') + consts = [] + for i in range(result.clocations_count): + loc = result.clocations[i] + if loc.length != -1: + if loc.param_id < 0: + # Negative param_id means *relative* to highest explicit + # param id (after taking the absolute value). 
+ param_id = ( + abs(loc.param_id) + + result.highest_extern_param_id + ) + else: + # Otherwise it's the absolute param id. + param_id = loc.param_id + if loc.val != NULL: + token = protobuf_c_enum_descriptor_get_value( + &pg_query__token__descriptor, loc.token) + if token == NULL: + raise RuntimeError( + f"could not lookup pg_query enum descriptor " + f"for token value {loc.token}" + ) + consts.append(( + param_id, + LiteralTokenType(bytes(token.name).decode("ascii")), + bytes(loc.val), + )) + + return NormalizedQuery( + text=normalized_query, + highest_extern_param_id=result.highest_extern_param_id, + extracted_constants=consts, + ) + finally: + pg_query_free_normalize_result(result) + + +cdef ReadBuffer _init_deserializer(serialized: bytes, tag: uint8_t, cls: str): + cdef ReadBuffer buf + + buf = ReadBuffer.new_message_parser(serialized) + + if buf.read_byte() != tag: + raise ValueError(f"malformed {cls} serialization") + + return buf + + +cdef class Source: + def __init__( + self, + text: str, + serialized: Optional[bytes] = None, + ) -> None: + self._text = text + if serialized is not None: + self._serialized = serialized + else: + self._serialized = b'' + self._cache_key = b'' + + @classmethod + def _tag(self) -> int: + return 0 + + cdef WriteBuffer _serialize(self): + cdef WriteBuffer buf = WriteBuffer.new() + buf.write_byte(self._tag()) + buf.write_len_prefixed_utf8(self._text) + return buf + + def serialize(self) -> bytes: + if not self._serialized: + self._serialized = bytes(self._serialize()) + return self._serialized + + @classmethod + def from_serialized(cls, serialized: bytes) -> NormalizedSource: + cdef ReadBuffer buf + + buf = _init_deserializer(serialized, cls._tag(), cls.__name__) + text = buf.read_len_prefixed_utf8() + + return Source(text, serialized) + + def text(self) -> str: + return self._text + + def cache_key(self) -> bytes: + if not self._cache_key: + self._cache_key = hashlib.blake2b(self.serialize()).digest() + return self._cache_key + + def variables(self) -> dict[str, Any]: + return {} + + def first_extra(self) -> Optional[int]: + return None + + def extra_counts(self) -> Sequence[int]: + return [] + + def extra_blobs(self) -> Sequence[bytes]: + return () + + def extra_formatted_as_text(self) -> bool: + return True + + def extra_type_oids(self) -> Sequence[int]: + return () + + @classmethod + def from_string(cls, text: str) -> Source: + return Source(text) + + +cdef class NormalizedSource(Source): + def __init__( + self, + normalized: NormalizedQuery, + orig_text: str, + serialized: Optional[bytes] = None, + ) -> None: + super().__init__(text=normalized.text, serialized=serialized) + self._extracted_constants = normalized.extracted_constants + self._highest_extern_param_id = normalized.highest_extern_param_id + self._orig_text = orig_text + + @classmethod + def _tag(cls) -> int: + return 1 + + cdef WriteBuffer _serialize(self): + cdef WriteBuffer buf + + buf = Source._serialize(self) + buf.write_len_prefixed_utf8(self._orig_text) + buf.write_int32(self._highest_extern_param_id) + buf.write_int32(len(self._extracted_constants)) + for param_id, token, val in self._extracted_constants: + buf.write_int32(param_id) + buf.write_len_prefixed_utf8(token.value) + buf.write_len_prefixed_bytes(val) + + return buf + + def variables(self) -> dict[str, bytes]: + return {f"${n}": v[1] for n, _, v in self._extracted_constants} + + def first_extra(self) -> Optional[int]: + return ( + self._highest_extern_param_id + if self._extracted_constants + else None + ) + + def 
extra_counts(self) -> Sequence[int]: + return [len(self._extracted_constants)] + + def extra_blobs(self) -> list[bytes]: + cdef WriteBuffer buf + buf = WriteBuffer.new() + for _, _, v in self._extracted_constants: + buf.write_len_prefixed_bytes(v) + + return [bytes(buf)] + + def extra_type_oids(self) -> Sequence[int]: + oids = [] + for _, token, _ in self._extracted_constants: + if token is LiteralTokenType.FCONST: + oids.append(PgLiteralTypeOID.NUMERIC) + elif token is LiteralTokenType.ICONST: + oids.append(PgLiteralTypeOID.INT4) + elif ( + token is LiteralTokenType.FALSE_P + or token is LiteralTokenType.TRUE_P + ): + oids.append(PgLiteralTypeOID.BOOL) + elif token is LiteralTokenType.SCONST: + oids.append(PgLiteralTypeOID.TEXT) + elif ( + token is LiteralTokenType.XCONST + or token is LiteralTokenType.BCONST + ): + oids.append(PgLiteralTypeOID.VARBIT) + else: + raise AssertionError(f"unexpected literal token type: {token}") + + return oids + + @classmethod + def from_string(cls, text: str) -> NormalizedSource: + normalized = pg_normalize(text) + return NormalizedSource(normalized, text) + + @classmethod + def from_serialized(cls, serialized: bytes) -> NormalizedSource: + cdef ReadBuffer buf + + buf = _init_deserializer(serialized, cls._tag(), cls.__name__) + text = buf.read_len_prefixed_utf8() + orig_text = buf.read_len_prefixed_utf8() + highest_extern_param_id = buf.read_int32() + n_constants = buf.read_int32() + consts = [] + for _ in range(n_constants): + param_id = buf.read_int32() + token = buf.read_len_prefixed_utf8() + val = buf.read_len_prefixed_bytes() + consts.append((param_id, LiteralTokenType(token), val)) + + return NormalizedSource( + NormalizedQuery( + text=text, + highest_extern_param_id=highest_extern_param_id, + extracted_constants=consts, + ), + orig_text, + serialized, + ) + + +def deserialize(serialized: bytes) -> Source: + if serialized[0] == 0: + return Source.from_serialized(serialized) + elif serialized[0] == 1: + return NormalizedSource.from_serialized(serialized) + + raise ValueError(f"Invalid type/version byte: {serialized[0]}") diff --git a/edb/pgsql/resolver/__init__.py b/edb/pgsql/resolver/__init__.py index ddf37165c15..1dc3c2e8aa1 100644 --- a/edb/pgsql/resolver/__init__.py +++ b/edb/pgsql/resolver/__init__.py @@ -18,6 +18,8 @@ from __future__ import annotations from typing import Optional, List + +import copy import dataclasses from edb.common import debug @@ -41,6 +43,10 @@ class ResolvedSQL: # AST representing the query that can be sent to PostgreSQL ast: pgast.Base + # Optionally, AST representing the query returning data in EdgeQL + # format (i.e. single-column output). + edgeql_output_format_ast: Optional[pgast.Base] + # Special behavior for "tag" of "CommandComplete" message of this query. 
command_complete_tag: Optional[dbstate.CommandCompleteTag] @@ -56,6 +62,7 @@ def resolve( if debug.flags.sql_input: debug.header('SQL Input') + debug_sql_text = pgcodegen.generate_source( query, reordered=True, pretty=True ) @@ -108,8 +115,25 @@ def resolve( ) debug.dump_code(debug_sql_text, lexer='sql') + if options.include_edgeql_io_format_alternative: + edgeql_output_format_ast = copy.copy(resolved) + if isinstance(edgeql_output_format_ast, pgast.SelectStmt): + edgeql_output_format_ast.target_list = [ + pgast.ResTarget( + val=pgast.RowExpr( + args=[ + rt.val + for rt in edgeql_output_format_ast.target_list + ] + ) + ) + ] + else: + edgeql_output_format_ast = None + return ResolvedSQL( ast=resolved, + edgeql_output_format_ast=edgeql_output_format_ast, command_complete_tag=command_complete_tag, params=ctx.query_params, ) diff --git a/edb/pgsql/resolver/context.py b/edb/pgsql/resolver/context.py index 592cee58cf9..303dbb07693 100644 --- a/edb/pgsql/resolver/context.py +++ b/edb/pgsql/resolver/context.py @@ -52,6 +52,9 @@ class Options: # apply access policies to select & dml statements apply_access_policies: bool + # whether to generate an EdgeQL-compatible single-column output variant. + include_edgeql_io_format_alternative: Optional[bool] + @dataclass(kw_only=True) class Scope: diff --git a/edb/protocol/messages.py b/edb/protocol/messages.py index 784fc2e7165..5985e5f02c5 100644 --- a/edb/protocol/messages.py +++ b/edb/protocol/messages.py @@ -488,6 +488,12 @@ def dump(self) -> bytes: ############################################################################### +class InputLanguage(enum.Enum): + + EDGEQL = 0x45 # b'E' + SQL = 0x53 # b'S' + + class OutputFormat(enum.Enum): BINARY = 0x62 @@ -789,6 +795,7 @@ class Parse(ClientMessage): compilation_flags = EnumOf(UInt64, CompilationFlag, 'A bit mask of query options.') implicit_limit = UInt64('Implicit LIMIT clause on returned sets.') + input_language = EnumOf(UInt8, InputLanguage, 'Command source language.') output_format = EnumOf(UInt8, OutputFormat, 'Data output format.') expected_cardinality = EnumOf(UInt8, Cardinality, 'Expected result cardinality.') @@ -807,6 +814,7 @@ class Execute(ClientMessage): compilation_flags = EnumOf(UInt64, CompilationFlag, 'A bit mask of query options.') implicit_limit = UInt64('Implicit LIMIT clause on returned sets.') + input_language = EnumOf(UInt8, InputLanguage, 'Command source language.') output_format = EnumOf(UInt8, OutputFormat, 'Data output format.') expected_cardinality = EnumOf(UInt8, Cardinality, 'Expected result cardinality.') diff --git a/edb/protocol/protocol.pyx b/edb/protocol/protocol.pyx index b91da984b0d..866e240e558 100644 --- a/edb/protocol/protocol.pyx +++ b/edb/protocol/protocol.pyx @@ -63,6 +63,7 @@ cdef class Connection: messages.Execute( annotations=[], command_text=query, + input_language=messages.InputLanguage.EDGEQL, output_format=messages.OutputFormat.NONE, expected_cardinality=messages.Cardinality.MANY, allowed_capabilities=messages.Capability.ALL, @@ -173,6 +174,7 @@ async def new_connection( tls_ca=tls_ca, tls_ca_file=tls_ca_file, tls_security=tls_security, + tls_server_name=None, wait_until_available=timeout, credentials=credentials, credentials_file=credentials_file, diff --git a/edb/server/compiler/__init__.py b/edb/server/compiler/__init__.py index 77be86bae03..3931622c080 100644 --- a/edb/server/compiler/__init__.py +++ b/edb/server/compiler/__init__.py @@ -27,7 +27,7 @@ from .compiler import maybe_force_database_error from .dbstate import QueryUnit, QueryUnitGroup 
from .enums import Capability, Cardinality -from .enums import InputFormat, OutputFormat +from .enums import InputFormat, OutputFormat, InputLanguage from .explain import analyze_explain_output from .ddl import repair_schema from .rpc import CompilationRequest @@ -43,6 +43,7 @@ 'QueryUnitGroup', 'Capability', 'InputFormat', + 'InputLanguage', 'OutputFormat', 'analyze_explain_output', 'compile_edgeql_script', diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index eb33544119e..9f833828ca8 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -118,7 +118,7 @@ class CompilerDatabaseState: cached_reflection: immutables.Map[str, Tuple[str, ...]] -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class CompileContext: compiler_state: CompilerState @@ -145,6 +145,8 @@ class CompileContext: log_ddl_as_migrations: bool = True dump_restore_mode: bool = False notebook: bool = False + branch_name: Optional[str] = None + role_name: Optional[str] = None cache_key: Optional[uuid.UUID] = None def get_cache_mode(self) -> config.QueryCacheMode: @@ -583,7 +585,7 @@ def compile_serialized_request( self.state.compilation_config_serializer, ) - units, cstate = self.compile( + return self.compile( user_schema=user_schema, global_schema=global_schema, reflection_cache=reflection_cache, @@ -591,7 +593,6 @@ def compile_serialized_request( system_config=system_config, request=request, ) - return units, cstate def compile( self, @@ -641,10 +642,21 @@ def compile( json_parameters=request.input_format is enums.InputFormat.JSON, source=request.source, protocol_version=request.protocol_version, + role_name=request.role_name, + branch_name=request.branch_name, cache_key=request.get_cache_key(), ) - unit_group = compile(ctx=ctx, source=request.source) + match request.input_language: + case enums.InputLanguage.EDGEQL: + unit_group = compile(ctx=ctx, source=request.source) + case enums.InputLanguage.SQL: + unit_group = compile_sql_as_unit_group( + ctx=ctx, source=request.source) + case _: + raise NotImplementedError( + f"unnsupported input language: {request.input_language}") + tx_started = False for unit in unit_group: if unit.tx_id: @@ -727,8 +739,17 @@ def compile_in_tx( cache_key=request.get_cache_key(), ) - units = compile(ctx=ctx, source=request.source) - return units, ctx.state + match request.input_language: + case enums.InputLanguage.EDGEQL: + unit_group = compile(ctx=ctx, source=request.source) + case enums.InputLanguage.SQL: + unit_group = compile_sql_as_unit_group( + ctx=ctx, source=request.source) + case _: + raise NotImplementedError( + f"unnsupported input language: {request.input_language}") + + return unit_group, ctx.state def interpret_backend_error( self, @@ -2343,6 +2364,129 @@ def compile( raise original_err +def compile_sql_as_unit_group( + *, + ctx: CompileContext, + source: edgeql.Source, +) -> dbstate.QueryUnitGroup: + + setting = _get_config_val(ctx, 'allow_user_specified_id') + allow_user_specified_id = None + if setting: + allow_user_specified_id = sql.is_setting_truthy(setting) + + apply_access_policies_sql = None + setting = _get_config_val(ctx, 'apply_access_policies_sql') + if setting: + apply_access_policies_sql = sql.is_setting_truthy(setting) + + tx_state = ctx.state.current_tx() + schema = tx_state.get_schema(ctx.compiler_state.std_schema) + + settings = dbstate.DEFAULT_SQL_FE_SETTINGS + sql_tx_state = dbstate.SQLTransactionState( + in_tx=not tx_state.is_implicit(), + settings=settings, + 
in_tx_settings=settings, + in_tx_local_settings=settings, + savepoints=[ + (not_none(tx.name), settings, settings) + for tx in tx_state._savepoints.values() + ], + ) + + sql_units = sql.compile_sql( + source.text(), + schema=schema, + tx_state=sql_tx_state, + prepared_stmt_map={}, + current_database=ctx.branch_name or "", + current_user=ctx.role_name or "", + allow_user_specified_id=allow_user_specified_id, + apply_access_policies_sql=apply_access_policies_sql, + include_edgeql_io_format_alternative=True, + allow_prepared_statements=False, + ) + + qug = dbstate.QueryUnitGroup( + cardinality=sql_units[-1].cardinality, + cacheable=False, + ) + + for sql_unit in sql_units: + if sql_unit.eql_format_query is not None: + value_sql = sql_unit.eql_format_query.encode("utf-8") + intro_sql = sql_unit.query.encode("utf-8") + else: + value_sql = sql_unit.query.encode("utf-8") + intro_sql = None + if isinstance(sql_unit.command_complete_tag, dbstate.TagPlain): + status = sql_unit.command_complete_tag.tag + elif isinstance( + sql_unit.command_complete_tag, + (dbstate.TagCountMessages, dbstate.TagUnpackRow), + ): + status = sql_unit.command_complete_tag.prefix.encode("utf-8") + elif sql_unit.command_complete_tag is None: + status = b"SELECT" # XXX + else: + raise AssertionError( + f"unexpected SQLQueryUnit.command_complete_tag type: " + f"{sql_unit.command_complete_tag}" + ) + unit = dbstate.QueryUnit( + sql=value_sql, + introspection_sql=intro_sql, + status=status, + cardinality=sql_unit.cardinality, + capabilities=sql_unit.capabilities, + globals=[ + (str(sp.global_name), False) for sp in sql_unit.params + if isinstance(sp, dbstate.SQLParamGlobal) + ] if sql_unit.params else [], + output_format=( + enums.OutputFormat.NONE + if sql_unit.cardinality is enums.Cardinality.NO_RESULT + else enums.OutputFormat.BINARY + ), + ) + match sql_unit.tx_action: + case dbstate.TxAction.START: + ctx.state.start_tx() + tx_state = ctx.state.current_tx() + unit.tx_id = tx_state.id + case dbstate.TxAction.COMMIT: + ctx.state.commit_tx() + unit.tx_commit = True + case dbstate.TxAction.ROLLBACK: + ctx.state.rollback_tx() + unit.tx_rollback = True + case dbstate.TxAction.DECLARE_SAVEPOINT: + assert sql_unit.sp_name is not None + unit.tx_savepoint_declare = True + unit.sp_id = tx_state.declare_savepoint(sql_unit.sp_name) + unit.sp_name = sql_unit.sp_name + case dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: + assert sql_unit.sp_name is not None + tx_state.rollback_to_savepoint(sql_unit.sp_name) + unit.tx_savepoint_rollback = True + unit.sp_name = sql_unit.sp_name + case dbstate.TxAction.RELEASE_SAVEPOINT: + assert sql_unit.sp_name is not None + tx_state.release_savepoint(sql_unit.sp_name) + unit.sp_name = sql_unit.sp_name + case None: + pass + case _: + raise AssertionError( + f"unexpected SQLQueryUnit.tx_action: {sql_unit.tx_action}" + ) + + qug.append(unit) + + return qug + + def _try_compile( *, ctx: CompileContext, diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index ea44756af78..c442dc9dbe2 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -208,6 +208,8 @@ class Param: class QueryUnit: sql: bytes + introspection_sql: Optional[bytes] = None + # Status-line for the compiled command; returned to front-end # in a CommandComplete protocol message if the command is # executed successfully. 
When a QueryUnit contains multiple @@ -514,11 +516,22 @@ class SQLQueryUnit: query: str = dataclasses.field(repr=False) """Translated query text.""" + translation_data: Optional[pgcodegen.TranslationData] = None + """Translation source map.""" + + eql_format_query: Optional[str] = dataclasses.field( + repr=False, default=None) + """Translated query text returning data in single-column format.""" + + eql_format_translation_data: Optional[pgcodegen.TranslationData] = None + """Translation source map for single-column format query.""" + orig_query: str = dataclasses.field(repr=False) """Original query text before translation.""" - translation_data: Optional[pgcodegen.TranslationData] = None - """Translation source map.""" + cardinality: enums.Cardinality = enums.Cardinality.NO_RESULT + + capabilities: enums.Capability = enums.Capability.NONE fe_settings: SQLSettings """Frontend-only settings effective during translation of this unit.""" diff --git a/edb/server/compiler/ddl.py b/edb/server/compiler/ddl.py index fe1dd279839..76d443dbe26 100644 --- a/edb/server/compiler/ddl.py +++ b/edb/server/compiler/ddl.py @@ -262,7 +262,7 @@ def _compile_and_apply_ddl_stmt( (SELECT json_object_agg( "id"::text, - "backend_id" + json_build_array("backend_id", "name") ) FROM edgedb_VER."_SchemaType" diff --git a/edb/server/compiler/enums.py b/edb/server/compiler/enums.py index 0e1fb3ff4fc..15231ba8808 100644 --- a/edb/server/compiler/enums.py +++ b/edb/server/compiler/enums.py @@ -85,6 +85,11 @@ class InputFormat(strenum.StrEnum): JSON = 'JSON' +class InputLanguage(strenum.StrEnum): + EDGEQL = 'EDGEQL' + SQL = 'SQL' + + def cardinality_from_ir_value(card: ir.Cardinality) -> Cardinality: if card is ir.Cardinality.AT_MOST_ONE: return Cardinality.AT_MOST_ONE diff --git a/edb/server/compiler/rpc.pxd b/edb/server/compiler/rpc.pxd index afb0c01ec77..f8aca79144d 100644 --- a/edb/server/compiler/rpc.pxd +++ b/edb/server/compiler/rpc.pxd @@ -20,14 +20,17 @@ cimport cython cdef char serialize_output_format(val) cdef deserialize_output_format(char mode) +cdef char serialize_input_language(val) +cdef deserialize_input_language(char mode) @cython.final cdef class CompilationRequest: cdef: - object _serializer + object serializer readonly object source readonly object protocol_version + readonly object input_language readonly object output_format readonly object input_format readonly bint expect_one @@ -35,6 +38,8 @@ cdef class CompilationRequest: readonly bint inline_typeids readonly bint inline_typenames readonly bint inline_objectids + readonly str role_name + readonly str branch_name readonly object modaliases readonly object session_config diff --git a/edb/server/compiler/rpc.pyi b/edb/server/compiler/rpc.pyi index 47b6d2fe464..baada6b7d71 100644 --- a/edb/server/compiler/rpc.pyi +++ b/edb/server/compiler/rpc.pyi @@ -28,6 +28,7 @@ from edb.server.compiler import sertypes, enums class CompilationRequest: source: edgeql.Source protocol_version: defines.ProtocolVersion + input_language: enums.InputLanguage output_format: enums.OutputFormat input_format: enums.InputFormat expect_one: bool @@ -35,6 +36,8 @@ class CompilationRequest: inline_typeids: bool inline_typenames: bool inline_objectids: bool + role_name: str + branch_name: str modaliases: immutables.Map[str | None, str] | None session_config: immutables.Map[str, config.SettingValue] | None @@ -46,6 +49,7 @@ class CompilationRequest: protocol_version: defines.ProtocolVersion, schema_version: uuid.UUID, compilation_config_serializer: 
sertypes.CompilationConfigSerializer, + input_language: enums.InputLanguage = enums.InputLanguage.EDGEQL, output_format: enums.OutputFormat = enums.OutputFormat.BINARY, input_format: enums.InputFormat = enums.InputFormat.BINARY, expect_one: bool = False, @@ -57,6 +61,8 @@ class CompilationRequest: session_config: typing.Mapping[str, config.SettingValue] | None = None, database_config: typing.Mapping[str, config.SettingValue] | None = None, system_config: typing.Mapping[str, config.SettingValue] | None = None, + role_name: str = defines.EDGEDB_SUPERUSER, + branch_name: str = defines.EDGEDB_SUPERUSER_DB, ): ... diff --git a/edb/server/compiler/rpc.pyx b/edb/server/compiler/rpc.pyx index d510e131774..856fde06462 100644 --- a/edb/server/compiler/rpc.pyx +++ b/edb/server/compiler/rpc.pyx @@ -32,6 +32,7 @@ from edb.edgeql import qltypes from edb.edgeql import tokenizer from edb.server import config, defines from edb.server.pgproto.pgproto cimport WriteBuffer, ReadBuffer +from edb.pgsql import parser as pgparser from . import enums, sertypes @@ -43,6 +44,9 @@ cdef object OUT_FMT_NONE = enums.OutputFormat.NONE cdef object IN_FMT_BINARY = enums.InputFormat.BINARY cdef object IN_FMT_JSON = enums.InputFormat.JSON +cdef object IN_LANG_EDGEQL = enums.InputLanguage.EDGEQL +cdef object IN_LANG_SQL = enums.InputLanguage.SQL + cdef char MASK_JSON_PARAMETERS = 1 << 0 cdef char MASK_EXPECT_ONE = 1 << 1 cdef char MASK_INLINE_TYPEIDS = 1 << 2 @@ -74,7 +78,26 @@ cdef deserialize_output_format(char mode): return OUT_FMT_NONE else: raise errors.BinaryProtocolError( - f'unknown output mode "{repr(mode)[2:-1]}"') + f'unknown output format {mode.to_bytes(1, "big")!r}') + + +cdef char serialize_input_language(val): + if val is IN_LANG_EDGEQL: + return b'E' + elif val is IN_LANG_SQL: + return b'S' + else: + raise AssertionError("unreachable") + + +cdef deserialize_input_language(char lang): + if lang == b'E': + return IN_LANG_EDGEQL + elif lang == b'S': + return IN_LANG_SQL + else: + raise errors.BinaryProtocolError( + f'unknown input language {lang.to_bytes(1, "big")!r}') @cython.final @@ -86,6 +109,7 @@ cdef class CompilationRequest: protocol_version: defines.ProtocolVersion, schema_version: uuid.UUID, compilation_config_serializer: sertypes.CompilationConfigSerializer, + input_language: enums.InputLanguage = enums.InputLanguage.EDGEQL, output_format: enums.OutputFormat = OUT_FMT_BINARY, input_format: enums.InputFormat = IN_FMT_BINARY, expect_one: bint = False, @@ -97,10 +121,13 @@ cdef class CompilationRequest: session_config: Mapping[str, config.SettingValue] | None = None, database_config: Mapping[str, config.SettingValue] | None = None, system_config: Mapping[str, config.SettingValue] | None = None, + role_name: str = defines.EDGEDB_SUPERUSER, + branch_name: str = defines.EDGEDB_SUPERUSER_DB, ): - self._serializer = compilation_config_serializer + self.serializer = compilation_config_serializer self.source = source self.protocol_version = protocol_version + self.input_language = input_language self.output_format = output_format self.input_format = input_format self.expect_one = expect_one @@ -113,6 +140,8 @@ cdef class CompilationRequest: self.session_config = session_config self.database_config = database_config self.system_config = system_config + self.role_name = role_name + self.branch_name = branch_name self.serialized_cache = None self.cache_key = None @@ -124,7 +153,8 @@ cdef class CompilationRequest: source=self.source, protocol_version=self.protocol_version, schema_version=self.schema_version, - 
compilation_config_serializer=self._serializer, + compilation_config_serializer=self.serializer, + input_language=self.input_language, output_format=self.output_format, input_format=self.input_format, expect_one=self.expect_one, @@ -136,6 +166,8 @@ cdef class CompilationRequest: session_config=self.session_config, database_config=self.database_config, system_config=self.system_config, + role_name=self.role_name, + branch_name=self.branch_name, ) rv.serialized_cache = self.serialized_cache rv.cache_key = self.cache_key @@ -178,15 +210,8 @@ cdef class CompilationRequest: query_text: str, compilation_config_serializer: sertypes.CompilationConfigSerializer, ) -> CompilationRequest: - buf = ReadBuffer.new_message_parser(data) - - if data[0] == 0: - return _deserialize_comp_req_v0( - buf, query_text, compilation_config_serializer) - else: - raise errors.UnsupportedProtocolVersionError( - f"unsupported compile cache: version {data[0]}" - ) + return _deserialize_comp_req( + data, query_text, compilation_config_serializer) def serialize(self) -> bytes: if self.serialized_cache is None: @@ -199,8 +224,12 @@ cdef class CompilationRequest: return self.cache_key cdef _serialize(self): - cache_key, buf = _serialize_comp_req_v0(self) - self.cache_key = cache_key + cdef WriteBuffer buf + + hash_obj, buf = _serialize_comp_req(self) + cache_key = hash_obj.digest() + buf.write_bytes(cache_key) + self.cache_key = uuidgen.from_bytes(cache_key) self.serialized_cache = bytes(buf) def __hash__(self): @@ -210,17 +239,44 @@ cdef class CompilationRequest: return ( self.source.cache_key() == other.source.cache_key() and self.protocol_version == other.protocol_version and + self.input_language == other.input_language and self.output_format == other.output_format and self.input_format == other.input_format and self.expect_one == other.expect_one and self.implicit_limit == other.implicit_limit and self.inline_typeids == other.inline_typeids and self.inline_typenames == other.inline_typenames and - self.inline_objectids == other.inline_objectids + self.inline_objectids == other.inline_objectids and + self.role_name == other.role_name and + self.branch_name == other.branch_name + ) + + +cdef CompilationRequest _deserialize_comp_req( + data: bytes, + query_text: str, + compilation_config_serializer: sertypes.CompilationConfigSerializer, +): + cdef: + ReadBuffer buf = ReadBuffer.new_message_parser(data) + CompilationRequest req + + if data[0] == 1: + req = _deserialize_comp_req_v1( + buf, query_text, compilation_config_serializer) + else: + raise errors.UnsupportedProtocolVersionError( + f"unsupported compile cache: version {data[0]}" ) + # Cache key is always trailing regardless of the version. + req.cache_key = uuidgen.from_bytes(buf.read_bytes(16)) + req.serialized_cache = data + + return req -cdef _deserialize_comp_req_v0( + +cdef _deserialize_comp_req_v1( buf: ReadBuffer, query_text: str, compilation_config_serializer: sertypes.CompilationConfigSerializer, @@ -249,21 +305,14 @@ cdef _deserialize_comp_req_v0( # * Session config: int32-length-prefixed serialized data # * Serialized Source or NormalizedSource without the original query # string - # * 16-byte cache key = BLAKE-2b hash of: - # * All above serialized, - # * Except that the source is replaced with Source.cache_key(), and - # * Except that the serialized session config is replaced by - # serialized combined config (session -> database -> system) - # that only affects compilation. - # * The schema version - # * OPTIONALLY, the schema version. 
We wanted to bump the protocol - # version to include this, but 5.x hard crashes when it reads a - # persistent cache with entries it doesn't understand, so instead - # we stick it on the end where it will be ignored by old versions. + # * The schema version ID. + # * 1 byte input language (the same as in the binary protocol) + # * role_name as a UTF-8 encoded string + # * branch_name as a UTF-8 encoded string cdef char flags - assert buf.read_byte() == 0 # version + assert buf.read_byte() == 1 # version flags = buf.read_byte() if flags & MASK_JSON_PARAMETERS > 0: @@ -318,18 +367,29 @@ cdef _deserialize_comp_req_v0( else: session_config = None - source = tokenizer.deserialize( - buf.read_len_prefixed_bytes(), query_text - ) - - cache_key = uuidgen.from_bytes(buf.read_bytes(16)) + serialized_source = buf.read_len_prefixed_bytes() schema_version = uuidgen.from_bytes(buf.read_bytes(16)) + input_language = deserialize_input_language(buf.read_byte()) + role_name = buf.read_len_prefixed_utf8() + branch_name = buf.read_len_prefixed_utf8() + + if input_language is enums.InputLanguage.EDGEQL: + source = tokenizer.deserialize(serialized_source, query_text) + elif input_language is enums.InputLanguage.SQL: + source = pgparser.deserialize(serialized_source) + else: + raise AssertionError( + f"unexpected source language in serialized " + f"CompilationRequest: {input_language}" + ) + req = CompilationRequest( source=source, protocol_version=protocol_version, schema_version=schema_version, compilation_config_serializer=serializer, + input_language=input_language, output_format=output_format, input_format=input_format, expect_one=expect_one, @@ -339,19 +399,18 @@ cdef _deserialize_comp_req_v0( inline_objectids=inline_objectids, modaliases=modaliases, session_config=session_config, + role_name=role_name, + branch_name=branch_name, ) - req.serialized_cache = data - req.cache_key = cache_key - return req -cdef _serialize_comp_req_v0(req: CompilationRequest): - # Please see _deserialize_v0 for the format doc +cdef _serialize_comp_req(req: CompilationRequest): + # Please see _deserialize_comp_req_v1 for the format doc cdef: - char version = 0, flags + char version = 1, flags WriteBuffer out = WriteBuffer.new() out.write_byte(version) @@ -385,7 +444,7 @@ cdef _serialize_comp_req_v0(req: CompilationRequest): out.write_str(k, "utf-8") out.write_str(v, "utf-8") - type_id, desc = req._serializer.describe() + type_id, desc = req.serializer.describe() out.write_bytes(type_id.bytes) out.write_len_prefixed_bytes(desc) @@ -395,7 +454,7 @@ cdef _serialize_comp_req_v0(req: CompilationRequest): if req.session_config is None: session_config = b"" else: - session_config = req._serializer.encode_configs( + session_config = req.serializer.encode_configs( req.session_config ) out.write_len_prefixed_bytes(session_config) @@ -403,7 +462,7 @@ cdef _serialize_comp_req_v0(req: CompilationRequest): # Build config that affects compilation: session -> database -> system. # This is only used for calculating cache_key, while session # config itreq is separately stored above in the serialized format. 
- serialized_comp_config = req._serializer.encode_configs( + serialized_comp_config = req.serializer.encode_configs( req.system_config, req.database_config, req.session_config ) hash_obj.update(serialized_comp_config) @@ -412,11 +471,18 @@ cdef _serialize_comp_req_v0(req: CompilationRequest): assert req.schema_version is not None hash_obj.update(req.schema_version.bytes) - cache_key_bytes = hash_obj.digest() - cache_key = uuidgen.from_bytes(cache_key_bytes) - out.write_len_prefixed_bytes(req.source.serialize()) - out.write_bytes(cache_key_bytes) out.write_bytes(req.schema_version.bytes) - return cache_key, out + out.write_byte(serialize_input_language(req.input_language)) + hash_obj.update(req.input_language.value.encode("utf-8")) + + role_name = req.role_name.encode("utf-8") + out.write_len_prefixed_bytes(role_name) + hash_obj.update(role_name) + + branch_name = req.branch_name.encode("utf-8") + out.write_len_prefixed_bytes(branch_name) + hash_obj.update(branch_name) + + return hash_obj, out diff --git a/edb/server/compiler/sertypes.py b/edb/server/compiler/sertypes.py index 0e562c235ff..fd440218730 100644 --- a/edb/server/compiler/sertypes.py +++ b/edb/server/compiler/sertypes.py @@ -374,7 +374,7 @@ def _describe_tuple(t: s_types.Tuple, *, ctx: Context) -> uuid.UUID: # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined - buf.append(_bool_packer(True)) + buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) @@ -420,7 +420,7 @@ def _describe_array(t: s_types.Array, *, ctx: Context) -> uuid.UUID: # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined - buf.append(_bool_packer(True)) + buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) @@ -459,7 +459,7 @@ def _describe_range(t: s_types.Range, *, ctx: Context) -> uuid.UUID: # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined - buf.append(_bool_packer(True)) + buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) @@ -494,7 +494,7 @@ def _describe_multirange(t: s_types.MultiRange, *, ctx: Context) -> uuid.UUID: # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined - buf.append(_bool_packer(True)) + buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) diff --git a/edb/server/compiler/sql.py b/edb/server/compiler/sql.py index 445abd21356..fa7529540b6 100644 --- a/edb/server/compiler/sql.py +++ b/edb/server/compiler/sql.py @@ -18,7 +18,7 @@ from __future__ import annotations -from typing import Tuple, Mapping, Sequence, List, TYPE_CHECKING, Optional +from typing import Mapping, Sequence, List, TYPE_CHECKING, Optional import dataclasses import functools @@ -35,6 +35,7 @@ from edb.pgsql import parser as pg_parser from . import dbstate +from . 
import enums if TYPE_CHECKING: from edb.pgsql import resolver as pg_resolver @@ -62,6 +63,8 @@ def compile_sql( current_user: str, allow_user_specified_id: Optional[bool], apply_access_policies_sql: Optional[bool], + include_edgeql_io_format_alternative: bool = False, + allow_prepared_statements: bool = True, ) -> List[dbstate.SQLQueryUnit]: opts = ResolverOptionsPartial( query_str=query_str, @@ -69,6 +72,8 @@ def compile_sql( current_user=current_user, allow_user_specified_id=allow_user_specified_id, apply_access_policies_sql=apply_access_policies_sql, + include_edgeql_io_format_alternative=( + include_edgeql_io_format_alternative), ) stmts = pg_parser.parse(query_str, propagate_spans=True) @@ -118,6 +123,8 @@ def compile_sql( unit.set_vars = {stmt.name: value} unit.is_local = stmt.scope == pgast.OptionsScope.TRANSACTION + if not unit.is_local: + unit.capabilities |= enums.Capability.SESSION_CONFIG elif isinstance(stmt, pgast.VariableShowStmt): unit.get_var = stmt.name @@ -143,33 +150,45 @@ def compile_sql( elif isinstance(stmt, (pgast.BeginStmt, pgast.StartStmt)): unit.tx_action = dbstate.TxAction.START + unit.command_complete_tag = dbstate.TagPlain( + tag=b"START TRANSACTION") elif isinstance(stmt, pgast.CommitStmt): unit.tx_action = dbstate.TxAction.COMMIT unit.tx_chain = stmt.chain or False + unit.command_complete_tag = dbstate.TagPlain(tag=b"COMMIT") elif isinstance(stmt, pgast.RollbackStmt): unit.tx_action = dbstate.TxAction.ROLLBACK unit.tx_chain = stmt.chain or False + unit.command_complete_tag = dbstate.TagPlain(tag=b"ROLLBACK") elif isinstance(stmt, pgast.SavepointStmt): unit.tx_action = dbstate.TxAction.DECLARE_SAVEPOINT unit.sp_name = stmt.savepoint_name + unit.command_complete_tag = dbstate.TagPlain(tag=b"SAVEPOINT") elif isinstance(stmt, pgast.ReleaseStmt): unit.tx_action = dbstate.TxAction.RELEASE_SAVEPOINT unit.sp_name = stmt.savepoint_name + unit.command_complete_tag = dbstate.TagPlain(tag=b"RELEASE") elif isinstance(stmt, pgast.RollbackToStmt): unit.tx_action = dbstate.TxAction.ROLLBACK_TO_SAVEPOINT unit.sp_name = stmt.savepoint_name + unit.command_complete_tag = dbstate.TagPlain(tag=b"ROLLBACK") elif isinstance(stmt, pgast.TwoPhaseTransactionStmt): raise NotImplementedError( "two-phase transactions are not supported" ) elif isinstance(stmt, pgast.PrepareStmt): + if not allow_prepared_statements: + raise errors.UnsupportedFeatureError( + "SQL prepared statements are not supported" + ) + # Translate the underlying query. 
- stmt_resolved, stmt_source = resolve_query( + stmt_resolved, stmt_source, _ = resolve_query( stmt.query, schema, tx_state, opts ) if stmt.argtypes: @@ -202,6 +221,11 @@ def compile_sql( unit.command_complete_tag = dbstate.TagPlain(tag=b"PREPARE") elif isinstance(stmt, pgast.ExecuteStmt): + if not allow_prepared_statements: + raise errors.UnsupportedFeatureError( + "SQL prepared statements are not supported" + ) + orig_name = stmt.name mangled_name = prepared_stmt_map.get(orig_name) if not mangled_name: @@ -216,7 +240,12 @@ def compile_sql( stmt_name=orig_name, be_stmt_name=mangled_name.encode("utf-8"), ) + unit.cardinality = enums.Cardinality.MANY elif isinstance(stmt, pgast.DeallocateStmt): + if not allow_prepared_statements: + raise errors.UnsupportedFeatureError( + "SQL prepared statements are not supported" + ) orig_name = stmt.name mangled_name = prepared_stmt_map.get(orig_name) if not mangled_name: @@ -238,19 +267,35 @@ def compile_sql( raise NotImplementedError("exclusive lock is not supported") # just ignore unit.query = "DO $$ BEGIN END $$;" - else: - assert isinstance(stmt, (pgast.Query, pgast.CopyStmt)) - stmt_resolved, stmt_source = resolve_query( + elif isinstance(stmt, (pgast.Query, pgast.CopyStmt)): + stmt_resolved, stmt_source, edgeql_fmt_src = resolve_query( stmt, schema, tx_state, opts ) - unit.query = stmt_source.text unit.translation_data = stmt_source.translation_data + if edgeql_fmt_src is not None: + unit.eql_format_query = edgeql_fmt_src.text + unit.eql_format_translation_data = ( + edgeql_fmt_src.translation_data) unit.command_complete_tag = stmt_resolved.command_complete_tag unit.params = stmt_resolved.params + if isinstance(stmt, pgast.DMLQuery) and not stmt.returning_list: + unit.cardinality = enums.Cardinality.NO_RESULT + else: + unit.cardinality = enums.Cardinality.MANY + else: + raise errors.UnsupportedFeatureError( + f"SQL {stmt.__class__.__name__} is not supported" + ) unit.stmt_name = compute_stmt_name(unit.query, tx_state).encode("utf-8") + if isinstance(stmt, pgast.DMLQuery): + unit.capabilities |= enums.Capability.MODIFICATIONS + + if unit.tx_action is not None: + unit.capabilities |= enums.Capability.TRANSACTION + tx_state.apply(unit) sql_units.append(unit) @@ -274,6 +319,7 @@ class ResolverOptionsPartial: query_str: str allow_user_specified_id: Optional[bool] apply_access_policies_sql: Optional[bool] + include_edgeql_io_format_alternative: Optional[bool] def resolve_query( @@ -281,7 +327,11 @@ def resolve_query( schema: s_schema.Schema, tx_state: dbstate.SQLTransactionState, opts: ResolverOptionsPartial, -) -> Tuple[pg_resolver.ResolvedSQL, pg_codegen.SQLSource]: +) -> tuple[ + pg_resolver.ResolvedSQL, + pg_codegen.SQLSource, + Optional[pg_codegen.SQLSource], +]: from edb.pgsql import resolver as pg_resolver search_path: Sequence[str] = ("public",) @@ -314,10 +364,19 @@ def resolve_query( search_path=search_path, allow_user_specified_id=allow_user_specified_id, apply_access_policies=apply_access_policies, + include_edgeql_io_format_alternative=( + opts.include_edgeql_io_format_alternative), ) resolved = pg_resolver.resolve(stmt, schema, options) source = pg_codegen.generate(resolved.ast, with_translation_data=True) - return resolved, source + if resolved.edgeql_output_format_ast is not None: + edgeql_format_source = pg_codegen.generate( + resolved.edgeql_output_format_ast, + with_translation_data=True, + ) + else: + edgeql_format_source = None + return resolved, source, edgeql_format_source def lookup_bool_setting( diff --git 
a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index 45598caefb7..4e3342912cb 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -40,6 +40,8 @@ cdef class CompiledQuery: cdef public object first_extra # Optional[int] cdef public object extra_counts cdef public object extra_blobs + cdef public bint extra_formatted_as_text + cdef public object extra_type_oids cdef public object request cdef public object recompiled_cache cdef public bint use_pending_func_cache @@ -90,6 +92,7 @@ cdef class Database: readonly bytes user_schema_pickle readonly object reflection_cache readonly object backend_ids + readonly object backend_id_to_name readonly object extensions readonly object _feature_used_metrics diff --git a/edb/server/dbview/dbview.pyi b/edb/server/dbview/dbview.pyi index 6a0d59fefd4..710eb09db3f 100644 --- a/edb/server/dbview/dbview.pyi +++ b/edb/server/dbview/dbview.pyi @@ -182,7 +182,7 @@ class DatabaseIndex: schema_version: Optional[uuid.UUID], db_config: Optional[Config], reflection_cache: Optional[Mapping[str, tuple[str, ...]]], - backend_ids: Optional[Mapping[str, int]], + backend_ids: Optional[Mapping[str, tuple[int, str]]], extensions: Optional[set[str]], ext_config_settings: Optional[list[config.Setting]], early: bool = False, diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index f91f34763b1..b2be2587ec9 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -40,7 +40,8 @@ from edb.common import debug, lru, uuidgen, asyncutil from edb import edgeql from edb.edgeql import qltypes from edb.schema import schema as s_schema -from edb.server import compiler, defines, config, metrics +from edb.schema import name as s_name +from edb.server import compiler, defines, config, metrics, pgcon from edb.server.compiler import dbstate, enums, sertypes from edb.server.protocol import execute from edb.pgsql import dbops @@ -96,6 +97,8 @@ cdef class CompiledQuery: first_extra: Optional[int]=None, extra_counts=(), extra_blobs=(), + extra_formatted_as_text: bool = False, + extra_type_oids: Sequence[int] = (), request=None, recompiled_cache=None, use_pending_func_cache=False, @@ -104,6 +107,8 @@ cdef class CompiledQuery: self.first_extra = first_extra self.extra_counts = extra_counts self.extra_blobs = extra_blobs + self.extra_formatted_as_text = extra_formatted_as_text + self.extra_type_oids = tuple(extra_type_oids) self.request = request self.recompiled_cache = recompiled_cache self.use_pending_func_cache = use_pending_func_cache @@ -160,6 +165,12 @@ cdef class Database: self.user_config_spec = config.FlatSpec(*ext_config_settings) self.reflection_cache = reflection_cache self.backend_ids = backend_ids + if backend_ids is not None: + self.backend_id_to_name = { + v[0]: v[1] for k, v in backend_ids.items() + } + else: + self.backend_id_to_name = {} self.extensions = set() self._set_extensions(extensions) self._observe_auth_ext_config() @@ -367,6 +378,9 @@ cdef class Database: if backend_ids is not None: self.backend_ids = backend_ids + self.backend_id_to_name = { + v[0]: v[1] for k, v in backend_ids.items() + } if reflection_cache is not None: self.reflection_cache = reflection_cache if db_config is not None: @@ -402,6 +416,9 @@ cdef class Database: cdef _update_backend_ids(self, new_types): self.backend_ids.update(new_types) + self.backend_id_to_name.update({ + v[0]: v[1] for k, v in new_types.items() + }) cdef _invalidate_caches(self): self._sql_to_compiled.clear() @@ -727,15 +744,18 @@ cdef class 
DatabaseConnectionView: if self._in_tx: try: - return int(self._in_tx_new_types[type_id]) + tinfo = self._in_tx_new_types[type_id] except KeyError: pass + else: + return int(tinfo[0]) - tid = self._db.backend_ids.get(type_id) - if tid is None: + tinfo = self._db.backend_ids.get(type_id) + if tinfo is None: raise RuntimeError( f'cannot resolve backend OID for type {type_id}') - return tid + + return int(tinfo[0]) cdef bytes serialize_state(self): cdef list state @@ -1220,9 +1240,10 @@ cdef class DatabaseConnectionView: async def parse( self, query_req: rpc.CompilationRequest, - cached_globally=False, - bint use_metrics=True, - uint64_t allow_capabilities = compiler.Capability.ALL, + cached_globally: bint = False, + use_metrics: bint = True, + allow_capabilities: uint64_t = compiler.Capability.ALL, + pgcon: pgcon.PGConnection | None = None, ) -> CompiledQuery: query_unit_group = None if self._query_cache_enabled: @@ -1307,6 +1328,18 @@ cdef class DatabaseConnectionView: ) self._check_in_tx_error(query_unit_group) + if query_req.input_language is enums.InputLanguage.SQL: + if pgcon is None: + raise errors.InternalServerError( + "a valid backend connection is required to fully " + "compile a query in SQL mode", + ) + await self._amend_typedesc_in_sql( + query_req, + query_unit_group, + pgcon, + ) + if self._query_cache_enabled and query_unit_group.cacheable: if cached_globally: self.server.system_compile_cache[query_req] = ( @@ -1366,10 +1399,112 @@ cdef class DatabaseConnectionView: first_extra=source.first_extra(), extra_counts=source.extra_counts(), extra_blobs=source.extra_blobs(), + extra_formatted_as_text=source.extra_formatted_as_text(), + extra_type_oids=source.extra_type_oids(), request=query_req, recompiled_cache=recompiled_cache, ) + async def _amend_typedesc_in_sql( + self, + query_req: rpc.CompilationRequest, + qug: dbstate.QueryUnitGroup, + pgcon: pgcon.PGConnection, + ) -> None: + # The SQL QueryUnitGroup as initially returned from the compiler + # is missing the input/output type descriptors because we currently + # don't run static SQL type inference. To mend that we ask Postgres + # to infer the the result types (as an OID tuple) and then use + # our OID -> scalar type mapping to construct an EdgeQL free shape with + # corresponding properties which we then send to the compiler to + # compute the type descriptors. 
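As a minimal sketch of the shape construction described in the comment above (the column names, OIDs, and type names here are invented for illustration; the real code also quotes identifiers via edgeql.quote_ident and handles parameters and empty result shapes):

    # Assumed Describe output: (column name, type OID) pairs, plus an assumed
    # OID -> EdgeQL type name mapping (backend_id_to_name in the real code).
    result_desc = [("title", 25), ("pages", 23)]
    oid_to_name = {25: "std::str", 23: "std::int64"}

    shape = ", ".join(
        f"{col} := <{oid_to_name[toid]}>{{}}" for col, toid in result_desc
    )
    intro_qry = "select {" + shape + "}"
    # intro_qry == "select {title := <std::str>{}, pages := <std::int64>{}}"
    # Compiling this free-shape EdgeQL query yields the output type
    # descriptor, which is then copied back onto the SQL query unit.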
+ to_describe = [] + + desc_map = {} + source = query_req.source + first_extra = source.first_extra() + if first_extra is not None: + all_type_oids = [0] * first_extra + source.extra_type_oids() + else: + all_type_oids = [] + + for i, query_unit in enumerate(qug): + if query_unit.cardinality is enums.Cardinality.NO_RESULT: + continue + + intro_sql = query_unit.introspection_sql + if intro_sql is None: + intro_sql = query_unit.sql[0] + param_desc, result_desc = await pgcon.sql_describe( + intro_sql, all_type_oids) + result_types = [] + for col, toid in result_desc: + edb_type_expr = self._db.backend_id_to_name.get(toid) + if edb_type_expr is None: + raise errors.UnsupportedFeatureError( + f"unsupported SQL type in column \"{col}\"" + f"with type OID {toid}" + ) + + result_types.append( + f"{edgeql.quote_ident(col)} := <{edb_type_expr}>{{}}" + ) + if first_extra is not None: + param_desc = param_desc[:first_extra] + params = [] + for pi, toid in enumerate(param_desc): + edb_type_expr = self._db.backend_id_to_name.get(toid) + if edb_type_expr is None: + raise errors.UnsupportedFeatureError( + f"unsupported type in SQL parameter ${pi} " + f"with type OID {toid}" + ) + + params.append( + f"_p{pi} := <{edb_type_expr}>${pi}" + ) + + intro_qry = "" + if params: + intro_qry += "with _p := {" + ", ".join(params) + "} " + + if result_types: + intro_qry += "select {" + ", ".join(result_types) + "}" + else: + # No direct syntactic way of constructing an empty shape, + # so we have to do it this way. + intro_qry += "select {foo := 1}{}" + to_describe.append(intro_qry) + + desc_map[len(to_describe) - 1] = i + + if to_describe: + desc_req = rpc.CompilationRequest( + source=edgeql.Source.from_string(";\n".join(to_describe)), + protocol_version=query_req.protocol_version, + schema_version=query_req.schema_version, + compilation_config_serializer=query_req.serializer, + ) + + desc_qug = await self._compile(desc_req) + + for i, desc_qu in enumerate(desc_qug): + qu_i = desc_map[i] + qug[qu_i].out_type_data = desc_qu.out_type_data + qug[qu_i].out_type_id = desc_qu.out_type_id + qug[qu_i].in_type_data = desc_qu.in_type_data + qug[qu_i].in_type_id = desc_qu.in_type_id + qug[qu_i].in_type_args = desc_qu.in_type_args + qug[qu_i].in_type_args_real_count = ( + desc_qu.in_type_args_real_count) + + qug.out_type_data = desc_qug.out_type_data + qug.out_type_id = desc_qug.out_type_id + qug.in_type_data = desc_qug.in_type_data + qug.in_type_id = desc_qug.in_type_id + qug.in_type_args = desc_qug.in_type_args + qug.in_type_args_real_count = desc_qug.in_type_args_real_count + cdef inline _check_in_tx_error(self, query_unit_group): if self.in_tx_error(): # The current transaction is aborted, so we must fail @@ -1404,6 +1539,8 @@ cdef class DatabaseConnectionView: first_extra=source.first_extra(), extra_counts=source.extra_counts(), extra_blobs=source.extra_blobs(), + extra_formatted_as_text=source.extra_formatted_as_text(), + extra_type_oids=source.extra_type_oids(), use_pending_func_cache=use_pending_func_cache, ) diff --git a/edb/server/defines.py b/edb/server/defines.py index bbfa973bc74..a8e1d746d69 100644 --- a/edb/server/defines.py +++ b/edb/server/defines.py @@ -81,7 +81,7 @@ ProtocolVersion: TypeAlias = tuple[int, int] MIN_PROTOCOL: ProtocolVersion = (1, 0) -CURRENT_PROTOCOL: ProtocolVersion = (2, 0) +CURRENT_PROTOCOL: ProtocolVersion = (3, 0) MIN_SUGGESTED_CLIENT_POOL_SIZE = 10 MAX_SUGGESTED_CLIENT_POOL_SIZE = 100 diff --git a/edb/server/pgcon/pgcon.pxd b/edb/server/pgcon/pgcon.pxd index 4b6a0ed0608..8f95af2e3be 
100644 --- a/edb/server/pgcon/pgcon.pxd +++ b/edb/server/pgcon/pgcon.pxd @@ -156,6 +156,7 @@ cdef class PGConnection: cdef bint before_prepare( self, bytes stmt_name, int dbver, WriteBuffer outbuf) cdef write_sync(self, WriteBuffer outbuf) + cdef send_sync(self) cdef make_clean_stmt_message(self, bytes stmt_name) cdef send_query_unit_group( diff --git a/edb/server/pgcon/pgcon.pyi b/edb/server/pgcon/pgcon.pyi index e6c00c287b9..7c6821bf028 100644 --- a/edb/server/pgcon/pgcon.pyi +++ b/edb/server/pgcon/pgcon.pyi @@ -75,6 +75,11 @@ class PGConnection(asyncio.Protocol): use_prep_stmt: bool = False, state: Optional[bytes] = None, ) -> list[bytes]: ... + async def sql_describe( + self, + sql: bytes, + param_type_oids: list[int] | None = None, + ) -> tuple[list[int], list[tuple[str, int]]]: ... def terminate(self) -> None: ... def add_log_listener(self, cb: Callable[[str, str], None]) -> None: ... def get_server_parameter_status(self, parameter: str) -> Optional[str]: ... diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 0c66aa3e1c3..d4a7bf58317 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -453,6 +453,10 @@ cdef class PGConnection: outbuf.write_bytes(_SYNC_MESSAGE) self.waiting_for_sync += 1 + cdef send_sync(self): + self.write(_SYNC_MESSAGE) + self.waiting_for_sync += 1 + def _build_apply_state_req(self, bytes serstate, WriteBuffer out): cdef: WriteBuffer buf @@ -807,6 +811,152 @@ cdef class PGConnection: finally: self.buffer.finish_message() + async def _describe( + self, + query: bytes, + param_type_oids: Optional[list[int]], + ): + cdef: + WriteBuffer out + + out = WriteBuffer.new() + + buf = WriteBuffer.new_message(b"P") # Parse + buf.write_bytestring(b"") + buf.write_bytestring(query) + if param_type_oids: + buf.write_int16(len(param_type_oids)) + for oid in param_type_oids: + buf.write_int32(oid) + else: + buf.write_int16(0) + out.write_buffer(buf.end_message()) + + buf = WriteBuffer.new_message(b"D") # Describe + buf.write_byte(b"S") + buf.write_bytestring(b"") + out.write_buffer(buf.end_message()) + + out.write_bytes(FLUSH_MESSAGE) + + self.write(out) + + param_desc = None + result_desc = None + + try: + buf = None + while True: + if not self.buffer.take_message(): + await self.wait_for_message() + mtype = self.buffer.get_message_type() + + try: + if mtype == b'1': + # ParseComplete + self.buffer.discard_message() + + elif mtype == b't': + # ParameterDescription + param_desc = self._decode_param_desc(self.buffer) + + elif mtype == b'T': + # RowDescription + result_desc = self._decode_row_desc(self.buffer) + break + + elif mtype == b'n': + # NoData + self.buffer.discard_message() + param_desc = [] + result_desc = [] + break + + elif mtype == b'E': ## result + # ErrorResponse + er_cls, er_fields = self.parse_error_message() + raise er_cls(fields=er_fields) + + else: + self.fallthrough() + + finally: + self.buffer.finish_message() + except Exception: + self.send_sync() + await self.wait_for_sync() + raise + + if param_desc is None: + raise RuntimeError( + "did not receive ParameterDescription from backend " + "in response to Describe" + ) + + if result_desc is None: + raise RuntimeError( + "did not receive RowDescription from backend " + "in response to Describe" + ) + + return param_desc, result_desc + + def _decode_param_desc(self, buf: ReadBuffer): + cdef: + int16_t nparams + uint32_t p_oid + list result = [] + + nparams = buf.read_int16() + + for _ in range(nparams): + p_oid = buf.read_int32() + result.append(p_oid) + + return 
result + + def _decode_row_desc(self, buf: ReadBuffer): + cdef: + int16_t nfields + + bytes f_name + uint32_t f_table_oid + int16_t f_column_num + uint32_t f_dt_oid + int16_t f_dt_size + int32_t f_dt_mod + int16_t f_format + + list result + + nfields = buf.read_int16() + + result = [] + for _ in range(nfields): + f_name = buf.read_null_str() + f_table_oid = buf.read_int32() + f_column_num = buf.read_int16() + f_dt_oid = buf.read_int32() + f_dt_size = buf.read_int16() + f_dt_mod = buf.read_int32() + f_format = buf.read_int16() + + result.append((f_name.decode("utf-8"), f_dt_oid)) + + return result + + async def sql_describe( + self, + query: bytes, + param_type_oids: Optional[list[int]] = None, + ) -> tuple[list[int], list[tuple[str, int]]]: + self.before_command() + started_at = time.monotonic() + try: + return await self._describe(query, param_type_oids) + finally: + await self.after_command() + async def _parse_execute( self, query, @@ -817,6 +967,7 @@ cdef class PGConnection: int dbver, bint use_pending_func_cache, tx_isolation, + list param_data_types, ): cdef: WriteBuffer out @@ -938,7 +1089,12 @@ cdef class PGConnection: buf = WriteBuffer.new_message(b'P') buf.write_bytestring(stmt_name) buf.write_bytestring(sqls[0]) - buf.write_int16(0) + if param_data_types: + buf.write_int16(len(param_data_types)) + for oid in param_data_types: + buf.write_int32(oid) + else: + buf.write_int16(0) out.write_buffer(buf.end_message()) metrics.query_size.observe( len(sqls[0]), self.get_tenant_label(), 'compiled' @@ -1135,6 +1291,7 @@ cdef class PGConnection: *, query, WriteBuffer bind_data = NO_ARGS, + list param_data_types = None, frontend.AbstractFrontendConnection fe_conn = None, bint use_prep_stmt = False, bytes state = None, @@ -1154,6 +1311,7 @@ cdef class PGConnection: dbver, use_pending_func_cache, tx_isolation, + param_data_types, ) finally: metrics.backend_query_duration.observe( diff --git a/edb/server/pgproto b/edb/server/pgproto index b8109fb311a..9f415b2c834 160000 --- a/edb/server/pgproto +++ b/edb/server/pgproto @@ -1 +1 @@ -Subproject commit b8109fb311a59f30f9947567a13508da9a776564 +Subproject commit 9f415b2c834df119422c011e5163e21064bff6ad diff --git a/edb/server/protocol/args_ser.pxd b/edb/server/protocol/args_ser.pxd index 75550076d62..04bb2c3ad47 100644 --- a/edb/server/protocol/args_ser.pxd +++ b/edb/server/protocol/args_ser.pxd @@ -26,6 +26,7 @@ cdef WriteBuffer recode_bind_args( dbview.CompiledQuery compiled, bytes bind_args, list positions = ?, + list data_types = ?, ) diff --git a/edb/server/protocol/args_ser.pyx b/edb/server/protocol/args_ser.pyx index c5efda663f7..d95b2ebac61 100644 --- a/edb/server/protocol/args_ser.pyx +++ b/edb/server/protocol/args_ser.pyx @@ -102,6 +102,7 @@ cdef WriteBuffer recode_bind_args( bytes bind_args, # XXX do something better?!? 
list positions = None, + list data_types = None, ): cdef: FRBuffer in_buf @@ -121,10 +122,6 @@ cdef WriteBuffer recode_bind_args( cpython.PyBytes_AS_STRING(bind_args), cpython.Py_SIZE(bind_args)) - # all parameters are in binary - if live: - out_buf.write_int32(0x00010001) - # number of elements in the tuple # for empty tuple it's okay to send zero-length arguments qug = compiled.query_unit_group @@ -155,11 +152,36 @@ cdef WriteBuffer recode_bind_args( f"argument count mismatch {recv_args} != {compiled.first_extra}" num_args += compiled.extra_counts[0] - num_args += _count_globals(qug) + num_globals = _count_globals(qug) + num_args += num_globals if live: + if not compiled.extra_formatted_as_text: + # all parameter values are in binary + out_buf.write_int32(0x00010001) + elif not recv_args and not num_globals: + # all parameter values are in text (i.e extracted SQL constants) + out_buf.write_int16(0x0000) + else: + # got a mix of binary and text, spell them out explicitly + out_buf.write_int16(num_args) + # explicit args are in binary + for _ in range(recv_args): + out_buf.write_int16(0x0001) + # and extracted SQL constants are in text + for _ in range(compiled.extra_counts[0]): + out_buf.write_int16(0x0000) + # and injected globals are binary again + for _ in range(num_globals): + out_buf.write_int16(0x0001) + out_buf.write_int16(num_args) + if data_types is not None and compiled.extra_type_oids: + data_types.extend([0] * recv_args) + data_types.extend(compiled.extra_type_oids) + data_types.extend([0] * num_globals) + if qug.in_type_args: for param in qug.in_type_args: if positions is not None: diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index d35d5664dcf..6f23400bdeb 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -43,6 +43,8 @@ from edb import buildmeta from edb import edgeql from edb.edgeql import qltypes +from edb.pgsql import parser as pgparser + from edb.server.pgproto cimport hton from edb.server.pgproto.pgproto cimport ( WriteBuffer, @@ -94,6 +96,10 @@ cdef object CARD_AT_MOST_ONE = compiler.Cardinality.AT_MOST_ONE cdef object CARD_MANY = compiler.Cardinality.MANY cdef object FMT_NONE = compiler.OutputFormat.NONE +cdef object FMT_BINARY = compiler.OutputFormat.BINARY + +cdef object LANG_EDGEQL = compiler.InputLanguage.EDGEQL +cdef object LANG_SQL = compiler.InputLanguage.SQL cdef tuple DUMP_VER_MIN = (0, 7) cdef tuple DUMP_VER_MAX = edbdef.CURRENT_PROTOCOL @@ -486,12 +492,25 @@ cdef class EdgeConnection(frontend.FrontendConnection): fe_conn=self, ) - def _tokenize(self, eql: bytes) -> edgeql.Source: + def _tokenize( + self, + eql: bytes, + lang: enums.InputLanguage, + ) -> edgeql.Source: text = eql.decode('utf-8') - if debug.flags.edgeql_disable_normalization: - return edgeql.Source.from_string(text) + if lang is LANG_EDGEQL: + if debug.flags.edgeql_disable_normalization: + return edgeql.Source.from_string(text) + else: + return edgeql.NormalizedSource.from_string(text) + elif lang is LANG_SQL: + if debug.flags.edgeql_disable_normalization: + return pgparser.Source.from_string(text) + else: + return pgparser.NormalizedSource.from_string(text) else: - return edgeql.NormalizedSource.from_string(text) + raise errors.UnsupportedFeatureError( + f"unsupported input language: {lang}") async def _suppress_tx_timeout(self): async with self.with_pgcon() as conn: @@ -529,6 +548,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): 'Cache key', source.cache_key(), f"protocol_version={query_req.protocol_version}", + 
f"input_language={query_req.input_language}", f"output_format={query_req.output_format}", f"expect_one={query_req.expect_one}", f"implicit_limit={query_req.implicit_limit}", @@ -550,9 +570,18 @@ cdef class EdgeConnection(frontend.FrontendConnection): if suppress_timeout: await self._suppress_tx_timeout() try: - return await dbv.parse( - query_req, allow_capabilities=allow_capabilities - ) + if query_req.input_language is LANG_SQL: + async with self.with_pgcon() as pg_conn: + return await dbv.parse( + query_req, + allow_capabilities=allow_capabilities, + pgcon=pg_conn, + ) + else: + return await dbv.parse( + query_req, + allow_capabilities=allow_capabilities, + ) finally: if suppress_timeout: await self._restore_tx_timeout(dbv) @@ -752,6 +781,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): bint inline_typenames = False bint inline_typeids = False bint inline_objectids = False + object cardinality object output_format bint expect_one = False bytes query @@ -779,10 +809,29 @@ cdef class EdgeConnection(frontend.FrontendConnection): & messages.CompilationFlag.INJECT_OUTPUT_OBJECT_IDS ) + if self.protocol_version >= (3, 0): + lang = rpc.deserialize_input_language(self.buffer.read_byte()) + else: + lang = LANG_EDGEQL + output_format = rpc.deserialize_output_format(self.buffer.read_byte()) - expect_one = ( - self.parse_cardinality(self.buffer.read_byte()) is CARD_AT_MOST_ONE - ) + if ( + lang is LANG_SQL + and output_format is not FMT_NONE + and output_format is not FMT_BINARY + ): + raise errors.UnsupportedFeatureError( + "non-binary output format is not supported with " + "SQL as the input language" + ) + + cardinality = self.parse_cardinality(self.buffer.read_byte()) + expect_one = cardinality is CARD_AT_MOST_ONE + if lang is LANG_SQL and cardinality is not CARD_MANY: + raise errors.UnsupportedFeatureError( + "output cardinality assertions are not supported with " + "SQL as the input language" + ) query = self.buffer.read_len_prefixed_bytes() if not query: @@ -803,10 +852,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): cfg_ser = self.server.compilation_config_serializer rv = rpc.CompilationRequest( - source=self._tokenize(query), + source=self._tokenize(query, lang), protocol_version=self.protocol_version, schema_version=_dbview.schema_version, compilation_config_serializer=cfg_ser, + input_language=lang, output_format=output_format, expect_one=expect_one, implicit_limit=implicit_limit, @@ -817,6 +867,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): session_config=_dbview.get_session_config(), database_config=_dbview.get_database_config(), system_config=_dbview.get_compilation_system_config(), + role_name=self.username, + branch_name=self.dbname, ) return rv, allow_capabilities @@ -893,6 +945,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): else: compiled = _dbview.as_compiled(query_req, query_unit_group) + if query_req.input_language is LANG_SQL and len(query_unit_group) > 1: + raise errors.UnsupportedFeatureError( + "multi-statement SQL scripts are not supported yet") + self._query_count += 1 # Clear the _last_anon_compiled so that the next Execute - if @@ -1438,6 +1494,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): protocol_version=self.protocol_version, schema_version=_dbview.schema_version, compilation_config_serializer=cfg_ser, + role_name=self.username, + branch_name=self.dbname, ) compiled = await _dbview.parse(query_req) @@ -1838,6 +1896,8 @@ async def run_script( schema_version=_dbview.schema_version, 
compilation_config_serializer=cfg_ser, output_format=FMT_NONE, + role_name=user, + branch_name=database, ), ) if len(compiled.query_unit_group) > 1: diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index bb63859a204..43f5c896b29 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -275,8 +275,9 @@ async def execute( if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] else: + data_types = [] bound_args_buf = args_ser.recode_bind_args( - dbv, compiled, bind_args) + dbv, compiled, bind_args, None, data_types) assert not (query_unit.database_config and query_unit.needs_readback), ( @@ -289,6 +290,7 @@ async def execute( query=query_unit, fe_conn=fe_conn if not read_data else None, bind_data=bound_args_buf, + param_data_types=data_types, use_prep_stmt=use_prep_stmt, state=state, dbver=dbv.dbver, diff --git a/edb/server/tenant.py b/edb/server/tenant.py index 187f3d10b00..4523372f488 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -1209,7 +1209,7 @@ async def _introspect_db( SELECT json_object_agg( "id"::text, - "backend_id" + json_build_array("backend_id", "name") )::text FROM edgedb_VER."_SchemaType" diff --git a/edb/testbase/connection.py b/edb/testbase/connection.py index 73fca8ec430..92e026b1735 100644 --- a/edb/testbase/connection.py +++ b/edb/testbase/connection.py @@ -436,6 +436,7 @@ async def _fetchall( implicit_limit=__limit__, inline_typeids=__typeids__, inline_typenames=__typenames__, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.BINARY, allow_capabilities=__allow_capabilities__, ) @@ -457,6 +458,7 @@ async def _fetchall_json( qc=self._query_cache.query_cache, implicit_limit=__limit__, inline_typenames=False, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, ) ) @@ -469,6 +471,7 @@ async def _fetchall_json_elements(self, query: str, *args, **kwargs): kwargs=kwargs, reg=self._query_cache.codecs_registry, qc=self._query_cache.query_cache, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON_ELEMENTS, allow_capabilities=edgedb_enums.Capability.EXECUTE, # type: ignore ) @@ -499,6 +502,7 @@ def is_closed(self): async def connect(self, single_attempt=False): self._params, client_config = con_utils.parse_connect_arguments( **self._connect_args, + tls_server_name=None, command_timeout=None, server_settings=None, ) diff --git a/edb/testbase/server.py b/edb/testbase/server.py index 2054a63d536..1b8806897ed 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -572,6 +572,7 @@ def http_con_binary_request( compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text=query, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, expected_cardinality=protocol.Cardinality.AT_MOST_ONE, input_typedesc_id=b"\0" * 16, diff --git a/pyproject.toml b/pyproject.toml index 8963db4eec4..3f0088c0798 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ description = "Gel Server" requires-python = '>=3.12.0' dynamic = ["version"] dependencies = [ - 'edgedb==2.2.0', + 'edgedb==3.0.0b1', 'httptools>=0.6.0', 'immutables>=0.18', @@ -103,7 +103,7 @@ requires = [ "wheel", "parsing ~= 2.0", - "edgedb==2.2.0", + "edgedb==3.0.0b1", ] # Custom backend needed to set up build-time sys.path because # setup.py needs to import `edb.buildmeta`. 
diff --git a/setup.py b/setup.py index 3bb5408b248..d98045fabd8 100644 --- a/setup.py +++ b/setup.py @@ -691,8 +691,8 @@ class build_ext(setuptools_build_ext.build_ext): user_options = setuptools_build_ext.build_ext.user_options + [ ('cython-annotate', None, 'Produce a colorized HTML version of the Cython source.'), - ('cython-directives=', None, - 'Cython compiler directives'), + ('cython-extra-directives=', None, + 'Extra Cython compiler directives'), ] def initialize_options(self): @@ -707,17 +707,17 @@ def initialize_options(self): if os.environ.get('EDGEDB_DEBUG'): self.cython_always = True self.cython_annotate = True - self.cython_directives = "linetrace=True" + self.cython_extra_directives = "linetrace=True" self.define = 'PG_DEBUG,CYTHON_TRACE,CYTHON_TRACE_NOGIL' self.debug = True else: self.cython_always = False self.cython_annotate = None - self.cython_directives = None + self.cython_extra_directives = None self.debug = False self.build_mode = os.environ.get('BUILD_EXT_MODE', 'both') - def finalize_options(self): + def finalize_options(self) -> None: # finalize_options() may be called multiple times on the # same command object, so make sure not to override previously # set options. @@ -731,12 +731,12 @@ def finalize_options(self): super(build_ext, self).finalize_options() return - directives = { + directives: dict[str, str | bool] = { 'language_level': '3' } - if self.cython_directives: - for directive in self.cython_directives.split(','): + if self.cython_extra_directives: + for directive in self.cython_extra_directives.split(','): k, _, v = directive.partition('=') if v.lower() == 'false': v = False diff --git a/tests/test_http_auth.py b/tests/test_http_auth.py index e66fc81ba0a..4b173979ba1 100644 --- a/tests/test_http_auth.py +++ b/tests/test_http_auth.py @@ -265,6 +265,7 @@ def test_http_binary_proto_too_old(self): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text="SELECT 42", + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, expected_cardinality=protocol.Cardinality.AT_MOST_ONE, input_typedesc_id=b"\0" * 16, @@ -308,6 +309,7 @@ def test_http_binary_proto_old_supported(self): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text="SELECT 42", + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, expected_cardinality=protocol.Cardinality.AT_MOST_ONE, input_typedesc_id=b"\0" * 16, diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 13211593534..2de6c650c4d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -46,6 +46,7 @@ async def _execute( compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text=command_text, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.NONE, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=b'\0' * 16, @@ -150,6 +151,7 @@ async def test_proto_flush_01(self): allowed_capabilities=protocol.Capability.ALL, compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.BINARY, expected_cardinality=compiler.Cardinality.AT_MOST_ONE, command_text='SEL ECT 1', @@ -174,6 +176,7 @@ async def test_proto_flush_01(self): allowed_capabilities=protocol.Capability.ALL, compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.BINARY, 
expected_cardinality=compiler.Cardinality.AT_MOST_ONE, command_text='SELECT 1', @@ -425,6 +428,7 @@ async def _parse(self, query, output_format=protocol.OutputFormat.BINARY): allowed_capabilities=protocol.Capability.ALL, compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, + input_language=protocol.InputLanguage.EDGEQL, output_format=output_format, expected_cardinality=compiler.Cardinality.MANY, command_text=query, @@ -540,6 +544,7 @@ async def _parse_execute(self, query, args): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text=query, + input_language=protocol.InputLanguage.EDGEQL, output_format=output_format, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=res.input_typedesc_id, @@ -846,6 +851,7 @@ async def test_proto_connection_lost_cancel_query(self): UPDATE tclcq SET { p := 'inner' }; SELECT sys::_sleep(10); """, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.NONE, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=b'\0' * 16, @@ -914,6 +920,7 @@ async def test_proto_gh3170_connection_lost_error(self): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text='START TRANSACTION', + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.NONE, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=b'\0' * 16, diff --git a/tests/test_server_config.py b/tests/test_server_config.py index fa6c3f7af65..3f8f8f25591 100644 --- a/tests/test_server_config.py +++ b/tests/test_server_config.py @@ -1972,6 +1972,7 @@ async def test_server_config_idle_transaction(self): messages.Execute( annotations=[], command_text=query, + input_language=messages.InputLanguage.EDGEQL, output_format=messages.OutputFormat.NONE, expected_cardinality=messages.Cardinality.MANY, allowed_capabilities=messages.Capability.ALL, diff --git a/tests/test_server_ops.py b/tests/test_server_ops.py index 0a904be368a..e51df00e72d 100644 --- a/tests/test_server_ops.py +++ b/tests/test_server_ops.py @@ -684,6 +684,7 @@ async def _test_connection(self, con): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text='SELECT 1', + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.NONE, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=b'\0' * 16, From 9f5f94bb998457468282c98d301f2011830366cc Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Wed, 13 Nov 2024 17:23:57 -0800 Subject: [PATCH 13/16] Add scalar type declarations for some unexposed Postgres types Add text `json` (counterpart to `jsonb`) and base variants of date/time/interval types. 
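
For illustration only (a made-up query, not part of this change): these are the
Postgres types that now gain std::pg counterparts, with the corresponding new
scalar noted next to each column:

    SELECT
        '"x"'::json                                  AS j,   -- std::pg::json
        timestamp with time zone '2000-01-01 00:00'  AS tz,  -- std::pg::timestamptz
        timestamp '2000-01-01 00:00'                 AS ts,  -- std::pg::timestamp
        date '2000-01-01'                            AS d,   -- std::pg::date
        interval '1 day'                             AS i;   -- std::pg::interval
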
--- edb/api/types.txt | 6 ++++++ edb/lib/pg.edgeql | 6 ++++++ edb/pgsql/types.py | 11 +++++++++++ edb/schema/_types.py | 10 ++++++++++ edb/server/dbview/dbview.pyx | 2 +- edb/tools/gen_types.py | 2 +- 6 files changed, 35 insertions(+), 2 deletions(-) diff --git a/edb/api/types.txt b/edb/api/types.txt index 1197b178bb4..3825a01dd6d 100644 --- a/edb/api/types.txt +++ b/edb/api/types.txt @@ -49,3 +49,9 @@ 00000000-0000-0000-0000-000000000112 std::cal::date_duration 00000000-0000-0000-0000-000000000130 cfg::memory + +00000000-0000-0000-0000-000001000001 std::pg::json +00000000-0000-0000-0000-000001000002 std::pg::timestamptz +00000000-0000-0000-0000-000001000003 std::pg::timestamp +00000000-0000-0000-0000-000001000004 std::pg::date +00000000-0000-0000-0000-000001000005 std::pg::interval diff --git a/edb/lib/pg.edgeql b/edb/lib/pg.edgeql index d7ca8cdaf5c..a6359f62f93 100644 --- a/edb/lib/pg.edgeql +++ b/edb/lib/pg.edgeql @@ -84,3 +84,9 @@ create index match for std::cal::local_date using std::pg::brin; create index match for std::cal::local_time using std::pg::brin; create index match for std::cal::relative_duration using std::pg::brin; create index match for std::cal::date_duration using std::pg::brin; + +create scalar type std::pg::json extending std::anyscalar; +create scalar type std::pg::timestamptz extending std::anycontiguous; +create scalar type std::pg::timestamp extending std::anycontiguous; +create scalar type std::pg::date extending std::anydiscrete; +create scalar type std::pg::interval extending std::anycontiguous; diff --git a/edb/pgsql/types.py b/edb/pgsql/types.py index 9a2319c2e1a..f49033cf67c 100644 --- a/edb/pgsql/types.py +++ b/edb/pgsql/types.py @@ -68,6 +68,12 @@ ('edgedbt', 'date_duration_t'), s_obj.get_known_type_id('cfg::memory'): ('edgedbt', 'memory_t'), + + s_obj.get_known_type_id('std::pg::json'): ('json',), + s_obj.get_known_type_id('std::pg::timestamptz'): ('timestamptz',), + s_obj.get_known_type_id('std::pg::timestamp'): ('timestamp',), + s_obj.get_known_type_id('std::pg::date'): ('date',), + s_obj.get_known_type_id('std::pg::interval'): ('interval',), } type_to_range_name_map = { @@ -85,6 +91,9 @@ # custom range is a big hassle, and daterange already has the # correct canonicalization function ('edgedbt', 'date_t'): ('daterange',), + ('timestamptz',): ('tstzrange',), + ('timestamp',): ('tsrange',), + ('date',): ('daterange',), } # Construct a multirange map based on type_to_range_name_map by replacing @@ -143,6 +152,8 @@ 'edgedbt.memory_t': sn.QualName('cfg', 'memory'), 'memory_t': sn.QualName('cfg', 'memory'), + + 'json': sn.QualName('std::pg', 'json'), } pg_tsvector_typeref = irast.TypeRef( diff --git a/edb/schema/_types.py b/edb/schema/_types.py index d96435e99bf..bf4af583475 100644 --- a/edb/schema/_types.py +++ b/edb/schema/_types.py @@ -66,4 +66,14 @@ UUID('00000000-0000-0000-0000-000000000112'), sn.name_from_string('cfg::memory'): UUID('00000000-0000-0000-0000-000000000130'), + sn.name_from_string('std::pg::json'): + UUID('00000000-0000-0000-0000-000001000001'), + sn.name_from_string('std::pg::timestamptz'): + UUID('00000000-0000-0000-0000-000001000002'), + sn.name_from_string('std::pg::timestamp'): + UUID('00000000-0000-0000-0000-000001000003'), + sn.name_from_string('std::pg::date'): + UUID('00000000-0000-0000-0000-000001000004'), + sn.name_from_string('std::pg::interval'): + UUID('00000000-0000-0000-0000-000001000005'), } diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index b2be2587ec9..0a7d8ea620e 100644 --- 
a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -1442,7 +1442,7 @@ cdef class DatabaseConnectionView: edb_type_expr = self._db.backend_id_to_name.get(toid) if edb_type_expr is None: raise errors.UnsupportedFeatureError( - f"unsupported SQL type in column \"{col}\"" + f"unsupported SQL type in column \"{col}\" " f"with type OID {toid}" ) diff --git a/edb/tools/gen_types.py b/edb/tools/gen_types.py index f9bb0ae0058..b16713e3966 100644 --- a/edb/tools/gen_types.py +++ b/edb/tools/gen_types.py @@ -63,7 +63,7 @@ def main(*, stdout: bool): f'\n\n\n' f'from __future__ import annotations' f'\n' - f'from typing import * # NoQA' + f'from typing import Type' f'\n\n\n' f'import uuid' f'\n\n' From 96be0791de95e5384a2533b58e71edb4452beb3e Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Fri, 15 Nov 2024 10:08:41 -0800 Subject: [PATCH 14/16] Add some SQL-over-proto tests --- edb/common/assert_data_shape.py | 11 ++-- edb/testbase/connection.py | 6 +- edb/testbase/server.py | 108 ++++++++++++++++++++++---------- pyproject.toml | 4 +- tests/test_sql_query.py | 99 +++++++++++++++++++++++++++++ 5 files changed, 186 insertions(+), 42 deletions(-) diff --git a/edb/common/assert_data_shape.py b/edb/common/assert_data_shape.py index 54c8a40cd82..22cde8812bb 100644 --- a/edb/common/assert_data_shape.py +++ b/edb/common/assert_data_shape.py @@ -20,14 +20,13 @@ from __future__ import annotations +import datetime import decimal import math import pprint import uuid import unittest -from datetime import timedelta - import edgedb @@ -280,14 +279,14 @@ def _assert_generic_shape(path, data, shape): fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') - elif isinstance(shape, (str, int, bytes, timedelta, + elif isinstance(shape, (str, int, bytes, datetime.timedelta, decimal.Decimal)): if data != shape: fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.RelativeDuration): - if data != timedelta( + if data != datetime.timedelta( days=shape.months * 30 + shape.days, microseconds=shape.microseconds, ): @@ -295,7 +294,7 @@ def _assert_generic_shape(path, data, shape): f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.DateDuration): - if data != timedelta( + if data != datetime.timedelta( days=shape.months * 30 + shape.days, ): fail( @@ -352,7 +351,7 @@ def _assert_generic_shape(path, data, shape): fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') - elif isinstance(shape, (str, int, bytes, timedelta, + elif isinstance(shape, (str, int, bytes, datetime.timedelta, decimal.Decimal)): if data != shape: fail( diff --git a/edb/testbase/connection.py b/edb/testbase/connection.py index 92e026b1735..3da73069a89 100644 --- a/edb/testbase/connection.py +++ b/edb/testbase/connection.py @@ -57,6 +57,9 @@ def raise_first_warning(warnings, res): raise warnings[0] +InputLanguage = protocol.InputLanguage + + class BaseTransaction(abc.ABC): ID_COUNTER = 0 @@ -419,6 +422,7 @@ async def _fetchall( self, query: str, *args, + __language__: protocol.InputLanguage = protocol.InputLanguage.EDGEQL, __limit__: int = 0, __typeids__: bool = False, __typenames__: bool = False, @@ -436,7 +440,7 @@ async def _fetchall( implicit_limit=__limit__, inline_typeids=__typeids__, inline_typenames=__typenames__, - input_language=protocol.InputLanguage.EDGEQL, + input_language=__language__, output_format=protocol.OutputFormat.BINARY, allow_capabilities=__allow_capabilities__, ) diff --git a/edb/testbase/server.py 
b/edb/testbase/server.py index 1b8806897ed..a234a6662fb 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -27,6 +27,7 @@ Type, Union, Iterable, + Literal, Sequence, Dict, List, @@ -1112,40 +1113,51 @@ def assert_data_shape(self, data, shape, message=message, rel_tol=rel_tol, abs_tol=abs_tol, ) - async def assert_query_result(self, query, - exp_result_json, - exp_result_binary=..., - *, - always_typenames=False, - msg=None, sort=None, implicit_limit=0, - variables=None, json_only=False, - rel_tol=None, abs_tol=None): + async def assert_query_result( + self, + query, + exp_result_json, + exp_result_binary=..., + *, + always_typenames=False, + msg=None, + sort=None, + implicit_limit=0, + variables=None, + json_only=False, + binary_only=False, + rel_tol=None, + abs_tol=None, + language: Literal["sql", "edgeql"] = "edgeql", + ): fetch_args = variables if isinstance(variables, tuple) else () fetch_kw = variables if isinstance(variables, dict) else {} - try: - tx = self.con.transaction() - await tx.start() - try: - res = await self.con._fetchall_json( - query, - *fetch_args, - __limit__=implicit_limit, - **fetch_kw) - finally: - await tx.rollback() - res = json.loads(res) - if sort is not None: - assert_data_shape.sort_results(res, sort) - assert_data_shape.assert_data_shape( - res, exp_result_json, self.fail, - message=msg, rel_tol=rel_tol, abs_tol=abs_tol, - ) - except Exception: - self.add_fail_notes(serialization='json') - if msg: - self.add_fail_notes(msg=msg) - raise + if not binary_only and language != "sql": + try: + tx = self.con.transaction() + await tx.start() + try: + res = await self.con._fetchall_json( + query, + *fetch_args, + __limit__=implicit_limit, + **fetch_kw) + finally: + await tx.rollback() + + res = json.loads(res) + if sort is not None: + assert_data_shape.sort_results(res, sort) + assert_data_shape.assert_data_shape( + res, exp_result_json, self.fail, + message=msg, rel_tol=rel_tol, abs_tol=abs_tol, + ) + except Exception: + self.add_fail_notes(serialization='json') + if msg: + self.add_fail_notes(msg=msg) + raise if json_only: return @@ -1164,14 +1176,22 @@ async def assert_query_result(self, query, __typenames__=typenames, __typeids__=typeids, __limit__=implicit_limit, + __language__=( + tconn.InputLanguage.SQL if language == "sql" + else tconn.InputLanguage.EDGEQL + ), **fetch_kw ) res = serutils.serialize(res) if sort is not None: assert_data_shape.sort_results(res, sort) assert_data_shape.assert_data_shape( - res, exp_result_binary, self.fail, - message=msg, rel_tol=rel_tol, abs_tol=abs_tol, + res, + exp_result_binary, + self.fail, + message=msg, + rel_tol=rel_tol, + abs_tol=abs_tol, ) except Exception: self.add_fail_notes( @@ -1182,6 +1202,28 @@ async def assert_query_result(self, query, self.add_fail_notes(msg=msg) raise + async def assert_sql_query_result( + self, + query, + exp_result, + *, + msg=None, + sort=None, + variables=None, + rel_tol=None, + abs_tol=None, + ): + await self.assert_query_result( + query, + exp_result, + msg=msg, + sort=sort, + variables=variables, + rel_tol=rel_tol, + abs_tol=abs_tol, + language="sql", + ) + async def assert_index_use(self, query, *args, index_type): def look(obj): if ( diff --git a/pyproject.toml b/pyproject.toml index 3f0088c0798..875884969f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ description = "Gel Server" requires-python = '>=3.12.0' dynamic = ["version"] dependencies = [ - 'edgedb==3.0.0b1', + 'edgedb==3.0.0b2', 'httptools>=0.6.0', 'immutables>=0.18', @@ -103,7 +103,7 @@ 
requires = [ "wheel", "parsing ~= 2.0", - "edgedb==3.0.0b1", + "edgedb==3.0.0b2", ] # Custom backend needed to set up build-time sys.path because # setup.py needs to import `edb.buildmeta`. diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index 79a7326c9ec..7d086ee17ea 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -26,6 +26,8 @@ from edb.tools import test from edb.testbase import server as tb +import edgedb + try: import asyncpg from asyncpg import serverversion @@ -2017,3 +2019,100 @@ async def test_sql_query_access_policy_04(self): self.assertEqual(len(res), 0) await tran.rollback() + + async def test_native_sql_query_00(self): + await self.assert_sql_query_result( + """ + SELECT + 1 AS a, + 'two' AS b, + to_json('three') AS c, + timestamp '2000-12-16 12:21:13' AS d, + timestamp with time zone '2000-12-16 12:21:13' AS e, + date '0001-01-01 AD' AS f, + interval '2000 years' AS g, + ARRAY[1, 2, 3] AS h, + FALSE AS i + """, + [{ + "a": 1, + "b": "two", + "c": '"three"', + "d": "2000-12-16T12:21:13", + "e": "2000-12-16T12:21:13+00:00", + "f": "0001-01-01", + "g": edgedb.RelativeDuration(months=2000 * 12), + "h": [1, 2, 3], + "i": False, + }] + ) + + async def test_native_sql_query_01(self): + await self.assert_sql_query_result( + """ + SELECT + "Movie".title, + "Genre".name AS genre + FROM + "Movie", + "Genre" + WHERE + "Movie".genre_id = "Genre".id + AND "Genre".name = 'Drama' + ORDER BY + title + """, + [{ + "title": "Forrest Gump", + "genre": "Drama", + }, { + "title": "Saving Private Ryan", + "genre": "Drama", + }] + ) + + async def test_native_sql_query_02(self): + await self.assert_sql_query_result( + """ + SELECT + "Movie".title, + "Genre".name AS genre + FROM + "Movie", + "Genre" + WHERE + "Movie".genre_id = "Genre".id + AND "Genre".name = $1::text + AND length("Movie".title) > $2::int + ORDER BY + title + """, + [{ + "title": "Saving Private Ryan", + "genre": "Drama", + }], + variables={ + "0": "Drama", + "1": 14, + }, + ) + + async def test_native_sql_query_03(self): + # No output at all + await self.assert_sql_query_result( + """ + SELECT + WHERE NULL + """, + [], + ) + + # Empty tuples + await self.assert_sql_query_result( + """ + SELECT + FROM "Movie" + LIMIT 1 + """, + [{}], + ) From a3cb17c1aa592a104e7656ff6f4811a2b51ec92c Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 14 Nov 2024 15:41:00 -0800 Subject: [PATCH 15/16] edb.test: Report long-running tests in verbose output mode --- edb/tools/test/runner.py | 57 +++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/edb/tools/test/runner.py b/edb/tools/test/runner.py index 207f92e3252..2bfdfb39854 100644 --- a/edb/tools/test/runner.py +++ b/edb/tools/test/runner.py @@ -285,6 +285,17 @@ def monitor_thread(queue, result): method(*args, **kwargs) +def status_thread_func( + result: ParallelTextTestResult, + stop_event: threading.Event, +) -> None: + while True: + result.report_still_running() + time.sleep(1) + if stop_event.is_set(): + break + + class ParallelTestSuite(unittest.TestSuite): def __init__( self, tests, server_conn, num_workers, backend_dsn, init_worker @@ -310,10 +321,22 @@ def run(self, result): worker_param_queue.put((self.server_conn, self.backend_dsn)) result_thread = threading.Thread( - name='test-monitor', target=monitor_thread, - args=(result_queue, result), daemon=True) + name='test-monitor', + target=monitor_thread, + args=(result_queue, result), + daemon=True, + ) result_thread.start() + 
status_thread_stop_event = threading.Event() + status_thread = threading.Thread( + name='test-status', + target=status_thread_func, + args=(result, status_thread_stop_event), + daemon=True, + ) + status_thread.start() + initargs = ( status_queue, worker_param_queue, result_queue, self.init_worker ) @@ -357,12 +380,13 @@ def run(self, result): # Post the terminal message to the queue so that # test-monitor can stop. result_queue.put((None, None, None)) + status_thread_stop_event.set() - # Give the test-monitor thread some time to - # process the queue messages. If something - # goes wrong, the thread will be forcibly + # Give the test-monitor and test-status threads some time to process the + # queue messages. If something goes wrong, the thread will be forcibly # joined by a timeout. result_thread.join(timeout=3) + status_thread.join(timeout=3) return result @@ -450,6 +474,9 @@ def report(self, test, marker, description=None, *, currently_running): def report_start(self, test, *, currently_running): return + def report_still_running(self, still_running: dict[str, float]): + return + class SimpleRenderer(BaseRenderer): def report(self, test, marker, description=None, *, currently_running): @@ -480,6 +507,10 @@ def report(self, test, marker, description=None, *, currently_running): click.echo(style(self._render_test(test, marker, description)), file=self.stream) + def report_still_running(self, still_running: dict[str, float]) -> None: + items = [f"{t} for {d:.02f}s" for t, d in still_running.items()] + click.echo(f"still running:\n {'\n '.join(items)}") + class MultiLineRenderer(BaseRenderer): @@ -521,6 +552,10 @@ def report(self, test, marker, description=None, *, currently_running): def report_start(self, test, *, currently_running): self._render(currently_running) + def report_still_running(self, still_running: dict[str, float]): + # Still-running tests are already reported in normal repert + return + def _render_modname(self, name): return name.replace('.', '/') + '.py' @@ -727,6 +762,16 @@ def report_progress(self, test, marker, description=None): currently_running=list(self.currently_running), ) + def report_still_running(self): + now = time.monotonic() + still_running = {} + for test, start in self.currently_running.items(): + running_for = now - start + if running_for > 5.0: + still_running[test] = running_for + if still_running: + self.ren.report_still_running(still_running) + def record_test_stats(self, test, stats): self.test_stats.append((test, stats)) @@ -745,7 +790,7 @@ def getDescription(self, test): def startTest(self, test): super().startTest(test) - self.currently_running[test] = True + self.currently_running[test] = time.monotonic() self.ren.report_start( test, currently_running=list(self.currently_running)) From 743d96cd5414942ae18b70322be496e953fdd6ea Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 14 Nov 2024 15:47:50 -0800 Subject: [PATCH 16/16] Update to latest edgedb-rust --- Cargo.lock | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0cde3ad83d6..d82ddd3928b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -540,7 +540,7 @@ dependencies = [ [[package]] name = "edgedb-errors" version = "0.4.2" -source = "git+https://github.com/edgedb/edgedb-rust#f9d784470af6e013051d1503882c1d88c51a5dcb" +source = "git+https://github.com/edgedb/edgedb-rust#b38fb4af07ae0017329eb3cce30ca37fe12acd29" dependencies = [ "bytes", ] @@ -548,7 +548,7 @@ dependencies = [ [[package]] name = "edgedb-protocol" version = "0.6.1" -source = 
"git+https://github.com/edgedb/edgedb-rust#f9d784470af6e013051d1503882c1d88c51a5dcb" +source = "git+https://github.com/edgedb/edgedb-rust#b38fb4af07ae0017329eb3cce30ca37fe12acd29" dependencies = [ "bigdecimal", "bitflags", @@ -2234,6 +2234,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b835cb902660db3415a672d862905e791e54d306c6e8189168c7f3d9ae1c79d" dependencies = [ + "backtrace", "snafu-derive", ]