diff --git a/.github/workflows.src/build.inc.yml b/.github/workflows.src/build.inc.yml index 50e0291d780..81f7e495b1e 100644 --- a/.github/workflows.src/build.inc.yml +++ b/.github/workflows.src/build.inc.yml @@ -152,7 +152,7 @@ uses: actions/setup-python@v5 if: << 'false' if tgt.runs_on and 'self-hosted' in tgt.runs_on else 'true' >> with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 diff --git a/.github/workflows/dryrun.yml b/.github/workflows/dryrun.yml index c3355e44caa..6743410f2b6 100644 --- a/.github/workflows/dryrun.yml +++ b/.github/workflows/dryrun.yml @@ -994,7 +994,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 @@ -1065,7 +1065,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 12f611932db..62c2c1a8f25 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -999,7 +999,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 @@ -1070,7 +1070,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 28ec1b4fbc7..9bdf49912f9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -541,7 +541,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 @@ -610,7 +610,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index eaa7275798a..a328b2aafc7 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -563,7 +563,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 @@ -633,7 +633,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 diff --git a/docs/reference/configuration.rst b/docs/reference/configuration.rst index d10dae96d20..2179cd93054 100644 --- a/docs/reference/configuration.rst +++ b/docs/reference/configuration.rst @@ -74,9 +74,8 @@ Resource usage :eql:synopsis:`shared_buffers -> cfg::memory` The amount of memory used for shared memory buffers. -:eql:synopsis:`net_http_max_connections -> int64` - The maximum number of concurrent HTTP connections to allow when using the - ``std::net::http`` module. +:eql:synopsis:`http_max_connections -> int64` + The maximum number of concurrent outbound HTTP connections to allow. Query planning -------------- diff --git a/edb/buildmeta.py b/edb/buildmeta.py index 53f6f716ca7..ff4a3ab16a4 100644 --- a/edb/buildmeta.py +++ b/edb/buildmeta.py @@ -60,7 +60,7 @@ # The merge conflict there is a nice reminder that you probably need # to write a patch in edb/pgsql/patches.py, and then you should preserve # the old value. 
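Editor's note on the `http_max_connections` rename documented in configuration.rst above (the setting was previously `net_http_max_connections`): here is a hedged sketch of how an operator might set and read the knob once this patch lands. The instance-level scope, the Python client calls, and the value 50 are assumptions for illustration, not part of the patch.

```python
import edgedb  # assumption: the standard EdgeDB Python client

client = edgedb.create_client()

# Raise the cap on concurrent outbound HTTP connections; 10 is the
# default declared in cfg.edgeql later in this diff.
client.execute("configure instance set http_max_connections := 50;")

# Read the effective value back from the config object.
print(client.query_single(
    "select assert_single(cfg::Config.http_max_connections)"))
```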
-EDGEDB_CATALOG_VERSION = 2024_10_21_00_00 +EDGEDB_CATALOG_VERSION = 2024_10_23_14_39 EDGEDB_MAJOR_VERSION = 6 diff --git a/edb/edgeql/compiler/options.py b/edb/edgeql/compiler/options.py index 9dceb12fcc4..2a02abf0edb 100644 --- a/edb/edgeql/compiler/options.py +++ b/edb/edgeql/compiler/options.py @@ -37,7 +37,7 @@ SourceOrPathId = s_types.Type | s_pointers.Pointer | pathid.PathId -@dataclass +@dataclass(kw_only=True) class GlobalCompilerOptions: """Compiler toggles that affect compilation as a whole.""" @@ -102,7 +102,7 @@ class GlobalCompilerOptions: dump_restore_mode: bool = False -@dataclass +@dataclass(kw_only=True) class CompilerOptions(GlobalCompilerOptions): #: Module name aliases. diff --git a/edb/lib/cfg.edgeql b/edb/lib/cfg.edgeql index bdca5d2300b..939913da510 100644 --- a/edb/lib/cfg.edgeql +++ b/edb/lib/cfg.edgeql @@ -238,8 +238,8 @@ ALTER TYPE cfg::AbstractConfig { 'Where the query cache is finally stored'; }; - # std::net::http Configuration - CREATE PROPERTY net_http_max_connections -> std::int64 { + # HTTP Worker Configuration + CREATE PROPERTY http_max_connections -> std::int64 { SET default := 10; CREATE ANNOTATION std::description := 'The maximum number of concurrent HTTP connections.'; diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index 8f51168255f..c0e0390fe8b 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -3323,7 +3323,13 @@ def _compile_func_epilogue( aspect=pgce.PathAspect.VALUE, ) - aspects: Tuple[pgce.PathAspect, ...] = (pgce.PathAspect.VALUE,) + aspects: Tuple[pgce.PathAspect, ...] + if expr.body: + # For inlined functions, we want all of the aspects provided. + aspects = tuple(pathctx.list_path_aspects(func_rel, ir_set.path_id)) + else: + # Otherwise we just know we have value. + aspects = (pgce.PathAspect.VALUE,) func_rvar = relctx.new_rel_rvar(ir_set, func_rel, ctx=ctx) relctx.include_rvar( @@ -3605,6 +3611,12 @@ def process_set_as_func_expr( _compile_inlined_call_args(ir_set, ctx=newctx) set_expr = dispatch.compile(expr.body, ctx=newctx) + # Map the path id so that we can extract source aspects + # from it, which we want so that we can directly select + # from an INSERT instead of using overlays. + pathctx.put_path_id_map( + newctx.rel, ir_set.path_id, expr.body.path_id + ) else: args = _compile_call_args(ir_set, ctx=newctx) diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index 56c7db1de4a..8fdc26c87e8 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -1629,13 +1629,16 @@ def apply( if not overload: variadic = func.get_params(schema).find_variadic(schema) - self.pgops.add( - dbops.DropFunction( - name=self.get_pgname(func, schema), - args=self.compile_args(func, schema), - has_variadic=variadic is not None, + if func.get_volatility(schema) != ql_ft.Volatility.Modifying: + # Modifying functions are not compiled. 
+ # See: compile_edgeql_function + self.pgops.add( + dbops.DropFunction( + name=self.get_pgname(func, schema), + args=self.compile_args(func, schema), + has_variadic=variadic is not None, + ) ) - ) return super().apply(schema, context) @@ -5370,10 +5373,10 @@ def _create_table( id = sn.QualName( module=prop.get_name(schema).module, name=str(prop.id)) - index_name = common.convert_name(id, 'idx0', catenate=True) + index_name = common.convert_name(id, 'idx0', catenate=False) pg_index = dbops.Index( - name=index_name, table_name=new_table_name, + name=index_name[1], table_name=new_table_name, unique=False, columns=[src_col], metadata={'code': DEFAULT_INDEX_CODE}, ) diff --git a/edb/pgsql/deltadbops.py b/edb/pgsql/deltadbops.py index c6e0806e7f5..a9bbf3d2399 100644 --- a/edb/pgsql/deltadbops.py +++ b/edb/pgsql/deltadbops.py @@ -485,7 +485,13 @@ def create_constr_trigger_function( return [dbops.CreateFunction(func, or_replace=True)] def drop_constr_trigger_function(self, proc_name: Tuple[str, ...]): - return [dbops.DropFunction(name=proc_name, args=(), if_exists=True)] + return [dbops.DropFunction( + name=proc_name, + args=(), + # Use a condition instead of if_exists ot reduce annoying + # debug spew from postgres. + conditions=[dbops.FunctionExists(name=proc_name, args=())], + )] def create_constraint(self, constraint: SchemaConstraintTableConstraint): # Add the constraint normally to our table diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 754ede7ac14..2c213cad20c 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -6673,6 +6673,7 @@ def _generate_sql_information_schema( SELECT attrelid, attname, atttypid, + attstattarget, attlen, attnum, attnum as attnum_internal, @@ -6680,8 +6681,8 @@ def _generate_sql_information_schema( attcacheoff, atttypmod, attbyval, - attalign, attstorage, + attalign, attnotnull, atthasdef, atthasmissing, @@ -6691,7 +6692,6 @@ def _generate_sql_information_schema( attislocal, attinhcount, attcollation, - attstattarget, attacl, attoptions, attfdwoptions, @@ -6716,6 +6716,7 @@ def _generate_sql_information_schema( SELECT pc_oid as attrelid, col_name as attname, COALESCE(atttypid, 25) as atttypid, -- defaults to TEXT + COALESCE(attstattarget, -1) as attstattarget, COALESCE(attlen, -1) as attlen, (ROW_NUMBER() OVER ( PARTITION BY pc_oid @@ -6726,8 +6727,8 @@ def _generate_sql_information_schema( COALESCE(attcacheoff, -1) as attcacheoff, COALESCE(atttypmod, -1) as atttypmod, COALESCE(attbyval, FALSE) as attbyval, - COALESCE(attalign, 'i') as attalign, COALESCE(attstorage, 'x') as attstorage, + COALESCE(attalign, 'i') as attalign, required as attnotnull, -- Always report no default, to avoid expr trouble false as atthasdef, @@ -6738,7 +6739,6 @@ def _generate_sql_information_schema( COALESCE(attislocal, TRUE) as attislocal, COALESCE(attinhcount, 0) as attinhcount, COALESCE(attcollation, 0) as attcollation, - COALESCE(attstattarget, -1) as attstattarget, attacl, attoptions, attfdwoptions, @@ -6839,14 +6839,15 @@ def _generate_sql_information_schema( attrelid, attname, atttypid, + attstattarget, attlen, attnum, attndims, attcacheoff, atttypmod, attbyval, - attalign, attstorage, + attalign, attnotnull, atthasdef, atthasmissing, @@ -6856,7 +6857,6 @@ def _generate_sql_information_schema( attislocal, attinhcount, attcollation, - attstattarget, attacl, attoptions, attfdwoptions, diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index cb90c125220..0a118c0bcdf 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ 
-1135,6 +1135,9 @@ async def create_branch( elif line.startswith('CREATE TYPE'): if any(skip in line for skip in to_skip): skipping = True + elif line == 'SET transaction_timeout = 0;': + continue + if skipping: continue new_lines.append(line) @@ -1535,6 +1538,15 @@ def cleanup_tpldbdump(tpldbdump: bytes) -> bytes: flags=re.MULTILINE, ) + # PostgreSQL 17 adds a transaction_timeout config setting that + # didn't exist on earlier versions. + tpldbdump = re.sub( + rb'^SET transaction_timeout = 0;$', + rb'', + tpldbdump, + flags=re.MULTILINE, + ) + return tpldbdump @@ -2144,7 +2156,6 @@ async def _populate_misc_instance_data( json.dumps(json_single_role_metadata), ) - assert backend_params.has_create_database if not backend_params.has_create_database: await _store_static_json_cache( ctx, diff --git a/edb/server/compiler/__init__.py b/edb/server/compiler/__init__.py index c9fc5c13421..77be86bae03 100644 --- a/edb/server/compiler/__init__.py +++ b/edb/server/compiler/__init__.py @@ -30,9 +30,11 @@ from .enums import InputFormat, OutputFormat from .explain import analyze_explain_output from .ddl import repair_schema +from .rpc import CompilationRequest __all__ = ( 'Cardinality', + 'CompilationRequest', 'Compiler', 'CompilerState', 'CompileContext', diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 1d6194f9ea2..df3cdd5ba11 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -567,7 +567,7 @@ def compile_sql( backend_runtime_params=self.state.backend_runtime_params, ) - def compile_request( + def compile_serialized_request( self, user_schema: s_schema.Schema, global_schema: s_schema.Schema, @@ -579,54 +579,35 @@ def compile_request( ) -> Tuple[ dbstate.QueryUnitGroup, Optional[dbstate.CompilerConnectionState] ]: - request = rpc.CompilationRequest( - self.state.compilation_config_serializer + request = rpc.CompilationRequest.deserialize( + serialized_request, + original_query, + self.state.compilation_config_serializer, ) - request.deserialize(serialized_request, original_query) units, cstate = self.compile( - user_schema, - global_schema, - reflection_cache, - database_config, - system_config, - request.source, - request.modaliases, - request.session_config, - request.output_format, - request.expect_one, - request.implicit_limit, - request.inline_typeids, - request.inline_typenames, - request.protocol_version, - request.inline_objectids, - request.json_parameters, - cache_key=request.get_cache_key(), + user_schema=user_schema, + global_schema=global_schema, + reflection_cache=reflection_cache, + database_config=database_config, + system_config=system_config, + request=request, ) return units, cstate def compile( self, + *, user_schema: s_schema.Schema, global_schema: s_schema.Schema, reflection_cache: immutables.Map[str, Tuple[str, ...]], database_config: Optional[immutables.Map[str, config.SettingValue]], system_config: Optional[immutables.Map[str, config.SettingValue]], - source: edgeql.Source, - sess_modaliases: Optional[immutables.Map[Optional[str], str]], - sess_config: Optional[immutables.Map[str, config.SettingValue]], - output_format: enums.OutputFormat, - expect_one: bool, - implicit_limit: int, - inline_typeids: bool, - inline_typenames: bool, - protocol_version: defines.ProtocolVersion, - inline_objectids: bool = True, - json_parameters: bool = False, - cache_key: Optional[uuid.UUID] = None, + request: rpc.CompilationRequest, ) -> Tuple[dbstate.QueryUnitGroup, Optional[dbstate.CompilerConnectionState]]: + sess_config = 
request.session_config if sess_config is None: sess_config = EMPTY_MAP @@ -636,6 +617,7 @@ def compile( if system_config is None: system_config = EMPTY_MAP + sess_modaliases = request.modaliases if sess_modaliases is None: sess_modaliases = DEFAULT_MODULE_ALIASES_MAP @@ -652,19 +634,19 @@ def compile( ctx = CompileContext( compiler_state=self.state, state=state, - output_format=output_format, - expected_cardinality_one=expect_one, - implicit_limit=implicit_limit, - inline_typeids=inline_typeids, - inline_typenames=inline_typenames, - inline_objectids=inline_objectids, - json_parameters=json_parameters, - source=source, - protocol_version=protocol_version, - cache_key=cache_key, + output_format=request.output_format, + expected_cardinality_one=request.expect_one, + implicit_limit=request.implicit_limit, + inline_typeids=request.inline_typeids, + inline_typenames=request.inline_typenames, + inline_objectids=request.inline_objectids, + json_parameters=request.input_format is enums.InputFormat.JSON, + source=request.source, + protocol_version=request.protocol_version, + cache_key=request.get_cache_key(), ) - unit_group = compile(ctx=ctx, source=source) + unit_group = compile(ctx=ctx, source=request.source) tx_started = False for unit in unit_group: if unit.tx_id: @@ -676,7 +658,7 @@ def compile( else: return unit_group, None - def compile_in_tx_request( + def compile_serialized_request_in_tx( self, state: dbstate.CompilerConnectionState, txid: int, @@ -686,11 +668,28 @@ def compile_in_tx_request( ) -> Tuple[ dbstate.QueryUnitGroup, Optional[dbstate.CompilerConnectionState] ]: - request = rpc.CompilationRequest( - self.state.compilation_config_serializer + request = rpc.CompilationRequest.deserialize( + serialized_request, + original_query, + self.state.compilation_config_serializer, + ) + return self.compile_in_tx( + state=state, + txid=txid, + request=request, + expect_rollback=expect_rollback, ) - request.deserialize(serialized_request, original_query) + def compile_in_tx( + self, + *, + state: dbstate.CompilerConnectionState, + txid: int, + request: rpc.CompilationRequest, + expect_rollback: bool = False, + ) -> Tuple[ + dbstate.QueryUnitGroup, Optional[dbstate.CompilerConnectionState] + ]: # Apply session differences if any if ( request.modaliases is not None @@ -703,39 +702,6 @@ def compile_in_tx_request( ): state.current_tx().update_session_config(session_config) - units, cstate = self.compile_in_tx( - state, - txid, - request.source, - request.output_format, - request.expect_one, - request.implicit_limit, - request.inline_typeids, - request.inline_typenames, - request.protocol_version, - request.inline_objectids, - request.json_parameters, - expect_rollback=expect_rollback, - cache_key=request.get_cache_key(), - ) - return units, cstate - - def compile_in_tx( - self, - state: dbstate.CompilerConnectionState, - txid: int, - source: edgeql.Source, - output_format: enums.OutputFormat, - expect_one: bool, - implicit_limit: int, - inline_typeids: bool, - inline_typenames: bool, - protocol_version: defines.ProtocolVersion, - inline_objectids: bool = True, - json_parameters: bool = False, - expect_rollback: bool = False, - cache_key: Optional[uuid.UUID] = None, - ) -> Tuple[dbstate.QueryUnitGroup, dbstate.CompilerConnectionState]: if ( expect_rollback and state.current_tx().id != txid and @@ -743,27 +709,28 @@ def compile_in_tx( ): # This is a special case when COMMIT MIGRATION fails, the compiler # doesn't have the right transaction state, so we just roll back. 
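Editor's note: the renames above split each compile entry point into a thin `*_serialized_*` wrapper and a keyword-only core that takes an `rpc.CompilationRequest`, which is what `compile_serialized_request_in_tx()` now does internally. A hedged sketch of the in-transaction path; `compiler`, `serialized`, `query_text`, `cstate`, and `txid` are hypothetical stand-ins.

```python
# Illustrative only: how the serialized wrapper relates to the keyword-only
# core after this refactor.
from edb.server.compiler import rpc

request = rpc.CompilationRequest.deserialize(
    serialized,
    query_text,
    compiler.state.compilation_config_serializer,
)

units, cstate = compiler.compile_in_tx(
    state=cstate,
    txid=txid,
    request=request,
)
```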
- return self._try_compile_rollback(source)[0], state + return self._try_compile_rollback(request.source)[0], state else: state.sync_tx(txid) ctx = CompileContext( compiler_state=self.state, state=state, - output_format=output_format, - expected_cardinality_one=expect_one, - implicit_limit=implicit_limit, - inline_typeids=inline_typeids, - inline_typenames=inline_typenames, - inline_objectids=inline_objectids, - source=source, - protocol_version=protocol_version, - json_parameters=json_parameters, + output_format=request.output_format, + expected_cardinality_one=request.expect_one, + implicit_limit=request.implicit_limit, + inline_typeids=request.inline_typeids, + inline_typenames=request.inline_typenames, + inline_objectids=request.inline_objectids, + source=request.source, + protocol_version=request.protocol_version, + json_parameters=request.input_format is enums.InputFormat.JSON, expect_rollback=expect_rollback, - cache_key=cache_key, + cache_key=request.get_cache_key(), ) - return compile(ctx=ctx, source=source), ctx.state + units = compile(ctx=ctx, source=request.source) + return units, ctx.state def interpret_backend_error( self, @@ -1278,7 +1245,14 @@ def compile_schema_storage_in_delta( # We drop first instead of using or_replace, in case # something about the arguments changed. df = pg_dbops.DropFunction( - name=func.name, args=func.args or (), if_exists=True + name=func.name, + args=func.args or (), + # Use a condition instead of if_exists ot reduce annoying + # debug spew from postgres. + conditions=[pg_dbops.FunctionExists( + name=func.name, + args=func.args or (), + )], ) df.generate(funcblock) diff --git a/edb/server/compiler/rpc.pxd b/edb/server/compiler/rpc.pxd index 946e91b1804..afb0c01ec77 100644 --- a/edb/server/compiler/rpc.pxd +++ b/edb/server/compiler/rpc.pxd @@ -29,7 +29,7 @@ cdef class CompilationRequest: readonly object source readonly object protocol_version readonly object output_format - readonly bint json_parameters + readonly object input_format readonly bint expect_one readonly int implicit_limit readonly bint inline_typeids @@ -46,4 +46,3 @@ cdef class CompilationRequest: object cache_key cdef _serialize(self) - cdef _deserialize_v0(self, bytes data, str query_text) diff --git a/edb/server/compiler/rpc.pyi b/edb/server/compiler/rpc.pyi index 758f76521c2..47b6d2fe464 100644 --- a/edb/server/compiler/rpc.pyi +++ b/edb/server/compiler/rpc.pyi @@ -29,7 +29,7 @@ class CompilationRequest: source: edgeql.Source protocol_version: defines.ProtocolVersion output_format: enums.OutputFormat - json_parameters: bool + input_format: enums.InputFormat expect_one: bool implicit_limit: int inline_typeids: bool @@ -41,7 +41,22 @@ class CompilationRequest: def __init__( self, + *, + source: edgeql.Source, + protocol_version: defines.ProtocolVersion, + schema_version: uuid.UUID, compilation_config_serializer: sertypes.CompilationConfigSerializer, + output_format: enums.OutputFormat = enums.OutputFormat.BINARY, + input_format: enums.InputFormat = enums.InputFormat.BINARY, + expect_one: bool = False, + implicit_limit: int = 0, + inline_typeids: bool = False, + inline_typenames: bool = False, + inline_objectids: bool = True, + modaliases: typing.Mapping[str | None, str] | None = None, + session_config: typing.Mapping[str, config.SettingValue] | None = None, + database_config: typing.Mapping[str, config.SettingValue] | None = None, + system_config: typing.Mapping[str, config.SettingValue] | None = None, ): ... 
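Editor's note: since the constructor declared above is now keyword-only and carries everything the compiler needs, callers build the request up front and pass it as a single argument. A hedged sketch (every concrete value here, such as `source`, `schema_version`, `cfg_ser` and the schema/config maps, is a placeholder, not taken from the patch):

```python
from edb.server import defines
from edb.server.compiler import enums, rpc

request = rpc.CompilationRequest(
    source=source,                          # an edgeql.Source
    protocol_version=defines.CURRENT_PROTOCOL,
    schema_version=schema_version,          # uuid.UUID of the current schema
    compilation_config_serializer=cfg_ser,  # usually the compiler state's serializer
    output_format=enums.OutputFormat.BINARY,
    input_format=enums.InputFormat.BINARY,
    expect_one=False,
    implicit_limit=0,
    modaliases=None,
    session_config=None,
)

units, cstate = compiler.compile(
    user_schema=user_schema,
    global_schema=global_schema,
    reflection_cache=reflection_cache,
    database_config=database_config,
    system_config=system_config,
    request=request,
)
```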
@@ -86,7 +101,13 @@ class CompilationRequest: def serialize(self) -> bytes: ... - def deserialize(self, data: bytes, query_text: str) -> CompilationRequest: + @classmethod + def deserialize( + cls, + data: bytes, + query_text: str, + compilation_config_serializer: sertypes.CompilationConfigSerializer, + ) -> CompilationRequest: ... def get_cache_key(self) -> uuid.UUID: diff --git a/edb/server/compiler/rpc.pyx b/edb/server/compiler/rpc.pyx index fd956a9703e..d510e131774 100644 --- a/edb/server/compiler/rpc.pyx +++ b/edb/server/compiler/rpc.pyx @@ -16,6 +16,10 @@ # limitations under the License. # +from typing import ( + Mapping, +) + import hashlib import uuid @@ -77,35 +81,11 @@ cdef deserialize_output_format(char mode): cdef class CompilationRequest: def __cinit__( self, - compilation_config_serializer: sertypes.CompilationConfigSerializer, - ): - self._serializer = compilation_config_serializer - - def __copy__(self): - cdef CompilationRequest rv = CompilationRequest(self._serializer) - rv.source = self.source - rv.protocol_version = self.protocol_version - rv.output_format = self.output_format - rv.json_parameters = self.json_parameters - rv.expect_one = self.expect_one - rv.implicit_limit = self.implicit_limit - rv.inline_typeids = self.inline_typeids - rv.inline_typenames = self.inline_typenames - rv.inline_objectids = self.inline_objectids - rv.modaliases = self.modaliases - rv.session_config = self.session_config - rv.database_config = self.database_config - rv.system_config = self.system_config - rv.schema_version = self.schema_version - rv.serialized_cache = self.serialized_cache - rv.cache_key = self.cache_key - return rv - - def update( - self, + *, source: edgeql.Source, protocol_version: defines.ProtocolVersion, - *, + schema_version: uuid.UUID, + compilation_config_serializer: sertypes.CompilationConfigSerializer, output_format: enums.OutputFormat = OUT_FMT_BINARY, input_format: enums.InputFormat = IN_FMT_BINARY, expect_one: bint = False, @@ -113,20 +93,53 @@ cdef class CompilationRequest: inline_typeids: bint = False, inline_typenames: bint = False, inline_objectids: bint = True, - ) -> CompilationRequest: + modaliases: Mapping[str | None, str] | None = None, + session_config: Mapping[str, config.SettingValue] | None = None, + database_config: Mapping[str, config.SettingValue] | None = None, + system_config: Mapping[str, config.SettingValue] | None = None, + ): + self._serializer = compilation_config_serializer self.source = source self.protocol_version = protocol_version self.output_format = output_format - self.json_parameters = input_format is IN_FMT_JSON + self.input_format = input_format self.expect_one = expect_one self.implicit_limit = implicit_limit self.inline_typeids = inline_typeids self.inline_typenames = inline_typenames self.inline_objectids = inline_objectids + self.schema_version = schema_version + self.modaliases = modaliases + self.session_config = session_config + self.database_config = database_config + self.system_config = system_config self.serialized_cache = None self.cache_key = None - return self + + def __copy__(self): + cdef CompilationRequest rv + + rv = CompilationRequest( + source=self.source, + protocol_version=self.protocol_version, + schema_version=self.schema_version, + compilation_config_serializer=self._serializer, + output_format=self.output_format, + input_format=self.input_format, + expect_one=self.expect_one, + implicit_limit=self.implicit_limit, + inline_typeids=self.inline_typeids, + inline_typenames=self.inline_typenames, + 
inline_objectids=self.inline_objectids, + modaliases=self.modaliases, + session_config=self.session_config, + database_config=self.database_config, + system_config=self.system_config, + ) + rv.serialized_cache = self.serialized_cache + rv.cache_key = self.cache_key + return rv def set_modaliases(self, value) -> CompilationRequest: self.modaliases = value @@ -158,14 +171,22 @@ cdef class CompilationRequest: self.cache_key = None return self - def deserialize(self, bytes data, str query_text) -> CompilationRequest: + @classmethod + def deserialize( + cls, + data: bytes, + query_text: str, + compilation_config_serializer: sertypes.CompilationConfigSerializer, + ) -> CompilationRequest: + buf = ReadBuffer.new_message_parser(data) + if data[0] == 0: - self._deserialize_v0(data, query_text) + return _deserialize_comp_req_v0( + buf, query_text, compilation_config_serializer) else: raise errors.UnsupportedProtocolVersionError( f"unsupported compile cache: version {data[0]}" ) - return self def serialize(self) -> bytes: if self.serialized_cache is None: @@ -178,182 +199,9 @@ cdef class CompilationRequest: return self.cache_key cdef _serialize(self): - # Please see _deserialize_v0 for the format doc - - cdef: - char version = 0, flags - WriteBuffer out = WriteBuffer.new() - - out.write_byte(version) - - flags = ( - (MASK_JSON_PARAMETERS if self.json_parameters else 0) | - (MASK_EXPECT_ONE if self.expect_one else 0) | - (MASK_INLINE_TYPEIDS if self.inline_typeids else 0) | - (MASK_INLINE_TYPENAMES if self.inline_typenames else 0) | - (MASK_INLINE_OBJECTIDS if self.inline_objectids else 0) - ) - out.write_byte(flags) - - out.write_int16(self.protocol_version[0]) - out.write_int16(self.protocol_version[1]) - out.write_byte(serialize_output_format(self.output_format)) - out.write_int64(self.implicit_limit) - - if self.modaliases is None: - out.write_int32(-1) - else: - out.write_int32(len(self.modaliases)) - for k, v in sorted( - self.modaliases.items(), - key=lambda i: (0, i[0]) if i[0] is None else (1, i[0]) - ): - if k is None: - out.write_byte(0) - else: - out.write_byte(1) - out.write_str(k, "utf-8") - out.write_str(v, "utf-8") - - type_id, desc = self._serializer.describe() - out.write_bytes(type_id.bytes) - out.write_len_prefixed_bytes(desc) - - hash_obj = hashlib.blake2b(memoryview(out), digest_size=16) - hash_obj.update(self.source.cache_key()) - - if self.session_config is None: - session_config = b"" - else: - session_config = self._serializer.encode_configs( - self.session_config - ) - out.write_len_prefixed_bytes(session_config) - - # Build config that affects compilation: session -> database -> system. - # This is only used for calculating cache_key, while session - # config itself is separately stored above in the serialized format. 
- serialized_comp_config = self._serializer.encode_configs( - self.system_config, self.database_config, self.session_config - ) - hash_obj.update(serialized_comp_config) - - # Must set_schema_version() before serializing compilation request - assert self.schema_version is not None - hash_obj.update(self.schema_version.bytes) - - cache_key_bytes = hash_obj.digest() - self.cache_key = uuidgen.from_bytes(cache_key_bytes) - - out.write_len_prefixed_bytes(self.source.serialize()) - out.write_bytes(cache_key_bytes) - out.write_bytes(self.schema_version.bytes) - - self.serialized_cache = bytes(out) - - cdef _deserialize_v0(self, bytes data, str query_text): - # Format: - # - # * 1 byte of version (0) - # * 1 byte of bit flags: - # * json_parameters - # * expect_one - # * inline_typeids - # * inline_typenames - # * inline_objectids - # * protocol_version (major: int64, minor: int16) - # * 1 byte output_format (the same as in the binary protocol) - # * implicit_limit: int64 - # * Module aliases: - # * length: int32 (negative means the modaliases is None) - # * For each alias pair: - # * 1 byte, 0 if the name is None - # * else, C-String as the name - # * C-String as the alias - # * Session config type descriptor - # * 16 bytes type ID - # * int32-length-prefixed serialized type descriptor - # * Session config: int32-length-prefixed serialized data - # * Serialized Source or NormalizedSource without the original query - # string - # * 16-byte cache key = BLAKE-2b hash of: - # * All above serialized, - # * Except that the source is replaced with Source.cache_key(), and - # * Except that the serialized session config is replaced by - # serialized combined config (session -> database -> system) - # that only affects compilation. - # * The schema version - # * OPTIONALLY, the schema version. We wanted to bump the protocol - # version to include this, but 5.x hard crashes when it reads a - # persistent cache with entries it doesn't understand, so instead - # we stick it on the end where it will be ignored by old versions. 
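Editor's note: the cache-key rule spelled out in the format comment above (and repeated in the new module-level helpers below) reduces to a single BLAKE2b digest. A minimal sketch of just that computation, with all four inputs as hypothetical stand-ins for the real buffers:

```python
import hashlib
import uuid

def v0_cache_key(
    header: bytes,            # everything serialized before the source
    source_cache_key: bytes,  # Source.cache_key(), not the raw query text
    combined_config: bytes,   # serialized combined compilation config
    schema_version: uuid.UUID,
) -> uuid.UUID:
    h = hashlib.blake2b(header, digest_size=16)
    h.update(source_cache_key)
    h.update(combined_config)
    h.update(schema_version.bytes)
    return uuid.UUID(bytes=h.digest())
```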
- - cdef char flags - - self.serialized_cache = data - - buf = ReadBuffer.new_message_parser(data) - - assert buf.read_byte() == 0 # version - - flags = buf.read_byte() - self.json_parameters = flags & MASK_JSON_PARAMETERS > 0 - self.expect_one = flags & MASK_EXPECT_ONE > 0 - self.inline_typeids = flags & MASK_INLINE_TYPEIDS > 0 - self.inline_typenames = flags & MASK_INLINE_TYPENAMES > 0 - self.inline_objectids = flags & MASK_INLINE_OBJECTIDS > 0 - - self.protocol_version = buf.read_int16(), buf.read_int16() - self.output_format = deserialize_output_format(buf.read_byte()) - self.implicit_limit = buf.read_int64() - - size = buf.read_int32() - if size >= 0: - modaliases = [] - for _ in range(size): - if buf.read_byte(): - k = buf.read_null_str().decode("utf-8") - else: - k = None - v = buf.read_null_str().decode("utf-8") - modaliases.append((k, v)) - self.modaliases = immutables.Map(modaliases) - else: - self.modaliases = None - - type_id = uuidgen.from_bytes(buf.read_bytes(16)) - if type_id == self._serializer.type_id: - serializer = self._serializer - buf.read_len_prefixed_bytes() - else: - serializer = sertypes.CompilationConfigSerializer( - type_id, buf.read_len_prefixed_bytes(), defines.CURRENT_PROTOCOL - ) - self._serializer = serializer - - data = buf.read_len_prefixed_bytes() - if data: - self.session_config = immutables.Map( - ( - k, - config.SettingValue( - name=k, - value=v, - source='session', - scope=qltypes.ConfigScope.SESSION, - ) - ) for k, v in serializer.decode(data).items() - ) - else: - self.session_config = None - - self.source = tokenizer.deserialize( - buf.read_len_prefixed_bytes(), query_text - ) - self.cache_key = uuidgen.from_bytes(buf.read_bytes(16)) - - if buf._length >= 16: - self.schema_version = uuidgen.from_bytes(buf.read_bytes(16)) + cache_key, buf = _serialize_comp_req_v0(self) + self.cache_key = cache_key + self.serialized_cache = bytes(buf) def __hash__(self): return hash(self.get_cache_key()) @@ -363,10 +211,212 @@ cdef class CompilationRequest: self.source.cache_key() == other.source.cache_key() and self.protocol_version == other.protocol_version and self.output_format == other.output_format and - self.json_parameters == other.json_parameters and + self.input_format == other.input_format and self.expect_one == other.expect_one and self.implicit_limit == other.implicit_limit and self.inline_typeids == other.inline_typeids and self.inline_typenames == other.inline_typenames and self.inline_objectids == other.inline_objectids ) + + +cdef _deserialize_comp_req_v0( + buf: ReadBuffer, + query_text: str, + compilation_config_serializer: sertypes.CompilationConfigSerializer, +): + # Format: + # + # * 1 byte of version (0) + # * 1 byte of bit flags: + # * json_parameters + # * expect_one + # * inline_typeids + # * inline_typenames + # * inline_objectids + # * protocol_version (major: int64, minor: int16) + # * 1 byte output_format (the same as in the binary protocol) + # * implicit_limit: int64 + # * Module aliases: + # * length: int32 (negative means the modaliases is None) + # * For each alias pair: + # * 1 byte, 0 if the name is None + # * else, C-String as the name + # * C-String as the alias + # * Session config type descriptor + # * 16 bytes type ID + # * int32-length-prefixed serialized type descriptor + # * Session config: int32-length-prefixed serialized data + # * Serialized Source or NormalizedSource without the original query + # string + # * 16-byte cache key = BLAKE-2b hash of: + # * All above serialized, + # * Except that the source is replaced 
with Source.cache_key(), and + # * Except that the serialized session config is replaced by + # serialized combined config (session -> database -> system) + # that only affects compilation. + # * The schema version + # * OPTIONALLY, the schema version. We wanted to bump the protocol + # version to include this, but 5.x hard crashes when it reads a + # persistent cache with entries it doesn't understand, so instead + # we stick it on the end where it will be ignored by old versions. + + cdef char flags + + assert buf.read_byte() == 0 # version + + flags = buf.read_byte() + if flags & MASK_JSON_PARAMETERS > 0: + input_format = IN_FMT_JSON + else: + input_format = IN_FMT_BINARY + expect_one = flags & MASK_EXPECT_ONE > 0 + inline_typeids = flags & MASK_INLINE_TYPEIDS > 0 + inline_typenames = flags & MASK_INLINE_TYPENAMES > 0 + inline_objectids = flags & MASK_INLINE_OBJECTIDS > 0 + + protocol_version = buf.read_int16(), buf.read_int16() + output_format = deserialize_output_format(buf.read_byte()) + implicit_limit = buf.read_int64() + + size = buf.read_int32() + if size >= 0: + modaliases = [] + for _ in range(size): + if buf.read_byte(): + k = buf.read_null_str().decode("utf-8") + else: + k = None + v = buf.read_null_str().decode("utf-8") + modaliases.append((k, v)) + modaliases = immutables.Map(modaliases) + else: + modaliases = None + + serializer = compilation_config_serializer + type_id = uuidgen.from_bytes(buf.read_bytes(16)) + if type_id == serializer.type_id: + buf.read_len_prefixed_bytes() + else: + serializer = sertypes.CompilationConfigSerializer( + type_id, buf.read_len_prefixed_bytes(), defines.CURRENT_PROTOCOL + ) + + data = buf.read_len_prefixed_bytes() + if data: + session_config = immutables.Map( + ( + k, + config.SettingValue( + name=k, + value=v, + source='session', + scope=qltypes.ConfigScope.SESSION, + ) + ) for k, v in serializer.decode(data).items() + ) + else: + session_config = None + + source = tokenizer.deserialize( + buf.read_len_prefixed_bytes(), query_text + ) + + cache_key = uuidgen.from_bytes(buf.read_bytes(16)) + schema_version = uuidgen.from_bytes(buf.read_bytes(16)) + + req = CompilationRequest( + source=source, + protocol_version=protocol_version, + schema_version=schema_version, + compilation_config_serializer=serializer, + output_format=output_format, + input_format=input_format, + expect_one=expect_one, + implicit_limit=implicit_limit, + inline_typeids=inline_typeids, + inline_typenames=inline_typenames, + inline_objectids=inline_objectids, + modaliases=modaliases, + session_config=session_config, + ) + + req.serialized_cache = data + req.cache_key = cache_key + + return req + + +cdef _serialize_comp_req_v0(req: CompilationRequest): + # Please see _deserialize_v0 for the format doc + + cdef: + char version = 0, flags + WriteBuffer out = WriteBuffer.new() + + out.write_byte(version) + + flags = ( + (MASK_JSON_PARAMETERS if req.input_format is IN_FMT_JSON else 0) | + (MASK_EXPECT_ONE if req.expect_one else 0) | + (MASK_INLINE_TYPEIDS if req.inline_typeids else 0) | + (MASK_INLINE_TYPENAMES if req.inline_typenames else 0) | + (MASK_INLINE_OBJECTIDS if req.inline_objectids else 0) + ) + out.write_byte(flags) + + out.write_int16(req.protocol_version[0]) + out.write_int16(req.protocol_version[1]) + out.write_byte(serialize_output_format(req.output_format)) + out.write_int64(req.implicit_limit) + + if req.modaliases is None: + out.write_int32(-1) + else: + out.write_int32(len(req.modaliases)) + for k, v in sorted( + req.modaliases.items(), + key=lambda i: (0, 
i[0]) if i[0] is None else (1, i[0]) + ): + if k is None: + out.write_byte(0) + else: + out.write_byte(1) + out.write_str(k, "utf-8") + out.write_str(v, "utf-8") + + type_id, desc = req._serializer.describe() + out.write_bytes(type_id.bytes) + out.write_len_prefixed_bytes(desc) + + hash_obj = hashlib.blake2b(memoryview(out), digest_size=16) + hash_obj.update(req.source.cache_key()) + + if req.session_config is None: + session_config = b"" + else: + session_config = req._serializer.encode_configs( + req.session_config + ) + out.write_len_prefixed_bytes(session_config) + + # Build config that affects compilation: session -> database -> system. + # This is only used for calculating cache_key, while session + # config itreq is separately stored above in the serialized format. + serialized_comp_config = req._serializer.encode_configs( + req.system_config, req.database_config, req.session_config + ) + hash_obj.update(serialized_comp_config) + + # Must set_schema_version() before serializing compilation request + assert req.schema_version is not None + hash_obj.update(req.schema_version.bytes) + + cache_key_bytes = hash_obj.digest() + cache_key = uuidgen.from_bytes(cache_key_bytes) + + out.write_len_prefixed_bytes(req.source.serialize()) + out.write_bytes(cache_key_bytes) + out.write_bytes(req.schema_version.bytes) + + return cache_key, out diff --git a/edb/server/compiler_pool/multitenant_worker.py b/edb/server/compiler_pool/multitenant_worker.py index 2fb3cc4cbda..105000467dc 100644 --- a/edb/server/compiler_pool/multitenant_worker.py +++ b/edb/server/compiler_pool/multitenant_worker.py @@ -28,6 +28,7 @@ from edb import graphql from edb.common import debug +from edb.common import uuidgen from edb.pgsql import params as pgparams from edb.schema import schema as s_schema from edb.server import compiler @@ -195,7 +196,7 @@ def compile( ): client_schema = clients[client_id] db = client_schema.dbs[dbname] - units, cstate = COMPILER.compile_request( + units, cstate = COMPILER.compile_serialized_request( db.user_schema, client_schema.global_schema, db.reflection_cache, @@ -225,6 +226,7 @@ def compile_in_tx( ): global LAST_STATE if cstate == state.REUSE_LAST_STATE_MARKER: + assert LAST_STATE is not None cstate = LAST_STATE else: cstate = pickle.loads(cstate) @@ -236,7 +238,8 @@ def compile_in_tx( client_schema = clients[client_id] db = client_schema.dbs[dbname] cstate.set_root_user_schema(db.user_schema) - units, cstate = COMPILER.compile_in_tx_request(cstate, *args, **kwargs) + units, cstate = COMPILER.compile_serialized_request_in_tx( + cstate, *args, **kwargs) LAST_STATE = cstate return units, pickle.dumps(cstate, -1) @@ -286,23 +289,30 @@ def compile_graphql( edgeql.generate_source(gql_op.edgeql_ast, pretty=True), ) - unit_group, _ = COMPILER.compile( - user_schema=db.user_schema, - global_schema=client_schema.global_schema, - reflection_cache=db.reflection_cache, - database_config=db.database_config, - system_config=client_schema.instance_config, + cfg_ser = COMPILER.state.compilation_config_serializer + request = compiler.CompilationRequest( source=source, - sess_modaliases=None, - sess_config=None, + protocol_version=defines.CURRENT_PROTOCOL, + schema_version=uuidgen.uuid4(), + compilation_config_serializer=cfg_ser, output_format=compiler.OutputFormat.JSON, + input_format=compiler.InputFormat.JSON, expect_one=True, implicit_limit=0, inline_typeids=False, inline_typenames=False, inline_objectids=False, - json_parameters=True, - protocol_version=defines.CURRENT_PROTOCOL, + modaliases=None, + 
session_config=None, + ) + + unit_group, _ = COMPILER.compile( + user_schema=db.user_schema, + global_schema=client_schema.global_schema, + reflection_cache=db.reflection_cache, + database_config=db.database_config, + system_config=client_schema.instance_config, + request=request, ) return unit_group, gql_op diff --git a/edb/server/compiler_pool/worker.py b/edb/server/compiler_pool/worker.py index 2e9ac518c62..99e4708e5b1 100644 --- a/edb/server/compiler_pool/worker.py +++ b/edb/server/compiler_pool/worker.py @@ -26,6 +26,7 @@ from edb import edgeql from edb import graphql +from edb.common import uuidgen from edb.pgsql import params as pgparams from edb.schema import schema as s_schema from edb.server import compiler @@ -175,7 +176,7 @@ def compile( system_config, ) - units, cstate = COMPILER.compile_request( + units, cstate = COMPILER.compile_serialized_request( db.user_schema, GLOBAL_SCHEMA, db.reflection_cache, @@ -199,6 +200,7 @@ def compile_in_tx( ): global LAST_STATE if cstate == state.REUSE_LAST_STATE_MARKER: + assert LAST_STATE is not None cstate = LAST_STATE else: cstate = pickle.loads(cstate) @@ -207,7 +209,8 @@ def compile_in_tx( cstate.set_root_user_schema(pickle.loads(user_schema)) else: cstate.set_root_user_schema(DBS[dbname].user_schema) - units, cstate = COMPILER.compile_in_tx_request(cstate, *args, **kwargs) + units, cstate = COMPILER.compile_serialized_request_in_tx( + cstate, *args, **kwargs) LAST_STATE = cstate return units, pickle.dumps(cstate, -1) @@ -275,23 +278,30 @@ def compile_graphql( edgeql.generate_source(gql_op.edgeql_ast, pretty=True), ) - unit_group, _ = COMPILER.compile( - user_schema=db.user_schema, - global_schema=GLOBAL_SCHEMA, - reflection_cache=db.reflection_cache, - database_config=db.database_config, - system_config=INSTANCE_CONFIG, + cfg_ser = COMPILER.state.compilation_config_serializer + request = compiler.CompilationRequest( source=source, - sess_modaliases=None, - sess_config=None, + protocol_version=defines.CURRENT_PROTOCOL, + schema_version=uuidgen.uuid4(), + compilation_config_serializer=cfg_ser, output_format=compiler.OutputFormat.JSON, + input_format=compiler.InputFormat.JSON, expect_one=True, implicit_limit=0, inline_typeids=False, inline_typenames=False, inline_objectids=False, - json_parameters=True, - protocol_version=defines.CURRENT_PROTOCOL, + modaliases=None, + session_config=None, + ) + + unit_group, _ = COMPILER.compile( + user_schema=db.user_schema, + global_schema=GLOBAL_SCHEMA, + reflection_cache=db.reflection_cache, + database_config=db.database_config, + system_config=INSTANCE_CONFIG, + request=request, ) return unit_group, gql_op diff --git a/edb/server/connpool/__init__.py b/edb/server/connpool/__init__.py index 6ce2fa62bd7..dda78b56bd7 100644 --- a/edb/server/connpool/__init__.py +++ b/edb/server/connpool/__init__.py @@ -22,7 +22,7 @@ # During the transition period we allow for the pool to be swapped out. The # current default is to use the old pool, however this will be switched to use # the new pool once we've fully implemented all required features. 
-if os.environ.get("EDGEDB_USE_NEW_CONNPOOL", "") == "0": +if os.environ.get("EDGEDB_USE_NEW_CONNPOOL", "") == "1": Pool = Pool2Impl Pool2 = Pool1Impl else: diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 2c6f3e18e32..ed376cf16ba 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -419,9 +419,11 @@ cdef class Database: def hydrate_cache(self, query_cache): for _, in_data, out_data in query_cache: - query_req = rpc.CompilationRequest( - self.server.compilation_config_serializer) - query_req.deserialize(in_data, "") + query_req = rpc.CompilationRequest.deserialize( + in_data, + "", + self.server.compilation_config_serializer, + ) if query_req not in self._eql_to_compiled: unit = dbstate.QueryUnit.deserialize(out_data) diff --git a/edb/server/http.py b/edb/server/http.py new file mode 100644 index 00000000000..a01c2f82899 --- /dev/null +++ b/edb/server/http.py @@ -0,0 +1,191 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +from typing import ( + Tuple, + Any, + Mapping, + Optional, +) + +import asyncio +import dataclasses +import logging +import os +import json as json_lib +import urllib.parse + +from edb.server._http import Http + +logger = logging.getLogger("edb.server") + + +class HttpClient: + def __init__(self, limit: int): + self._client = Http(limit) + self._fd = self._client._fd + self._task = None + self._skip_reads = 0 + self._loop = asyncio.get_running_loop() + self._task = self._loop.create_task(self._boot(self._loop)) + self._next_id = 0 + self._requests: dict[int, asyncio.Future] = {} + + def __del__(self) -> None: + if self._task: + self._task.cancel() + self._task = None + + def _update_limit(self, limit: int): + self._client._update_limit(limit) + + async def request( + self, + *, + method: str, + url: str, + content: bytes | None, + headers: list[tuple[str, str]] | None, + ) -> tuple[int, bytes, dict[str, str]]: + if content is None: + content = bytes() + if headers is None: + headers = [] + id = self._next_id + self._next_id += 1 + self._requests[id] = asyncio.Future() + try: + self._client._request(id, url, method, content, headers) + resp = await self._requests[id] + return resp + finally: + del self._requests[id] + + async def get( + self, path: str, *, headers: dict[str, str] | None = None + ) -> Response: + headers_list = [(k, v) for k, v in headers.items()] if headers else None + result = await self.request( + method="GET", url=path, content=None, headers=headers_list + ) + return Response.from_tuple(result) + + async def post( + self, + path: str, + *, + headers: dict[str, str] | None = None, + data: bytes | str | dict[str, str] | None = None, + json: Any | None = None, + ) -> Response: + if json is not None: + data = json_lib.dumps(json).encode('utf-8') + headers = headers or {} + headers['Content-Type'] = 'application/json' + elif 
isinstance(data, str): + data = data.encode('utf-8') + elif isinstance(data, dict): + data = urllib.parse.urlencode(data).encode('utf-8') + headers = headers or {} + headers['Content-Type'] = 'application/x-www-form-urlencoded' + + headers_list = [(k, v) for k, v in headers.items()] if headers else None + result = await self.request( + method="POST", url=path, content=data, headers=headers_list + ) + return Response.from_tuple(result) + + async def _boot(self, loop: asyncio.AbstractEventLoop) -> None: + logger.info("Python-side HTTP client booted") + reader = asyncio.StreamReader(loop=loop) + reader_protocol = asyncio.StreamReaderProtocol(reader) + fd = os.fdopen(self._client._fd, 'rb') + transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd) + try: + while len(await reader.read(1)) == 1: + if not self._client or not self._task: + break + if self._skip_reads > 0: + self._skip_reads -= 1 + continue + msg = self._client._read() + if not msg: + break + self._process_message(msg) + finally: + transport.close() + + def _process_message(self, msg): + msg_type, id, data = msg + + if id in self._requests: + if msg_type == 1: + self._requests[id].set_result(data) + elif msg_type == 0: + self._requests[id].set_exception(Exception(data)) + + +class CaseInsensitiveDict(dict): + def __init__(self, data: Optional[list[Tuple[str, str]]] = None): + super().__init__() + if data: + for k, v in data: + self[k.lower()] = v + + def __setitem__(self, key: str, value: str): + super().__setitem__(key.lower(), value) + + def __getitem__(self, key: str): + return super().__getitem__(key.lower()) + + def get(self, key: str, default=None): + return super().get(key.lower(), default) + + def update(self, *args, **kwargs: str) -> None: + if args: + data = args[0] + if isinstance(data, Mapping): + for key, value in data.items(): + self[key] = value + else: + for key, value in data: + self[key] = value + for key, value in kwargs.items(): + self[key] = value + + +@dataclasses.dataclass +class Response: + status_code: int + body: bytes + headers: CaseInsensitiveDict + + @classmethod + def from_tuple(cls, data: Tuple[int, bytes, dict[str, str]]): + status_code, body, headers_list = data + headers = CaseInsensitiveDict([(k, v) for k, v in headers_list.items()]) + return cls(status_code, body, headers) + + def json(self): + return json_lib.loads(self.body.decode('utf-8')) + + @property + def text(self) -> str: + return self.body.decode('utf-8') diff --git a/edb/server/net_worker.py b/edb/server/net_worker.py index 9038b3c230b..8c82fa75e84 100644 --- a/edb/server/net_worker.py +++ b/edb/server/net_worker.py @@ -24,12 +24,11 @@ import asyncio import logging import base64 -import os from edb.ir import statypes from edb.server import defines from edb.server.protocol import execute -from edb.server._http import Http +from edb.server.http import HttpClient from edb.common import retryloop from . 
import dbview @@ -47,10 +46,10 @@ async def _http_task(tenant: edbtenant.Tenant, http_client) -> None: - net_http_max_connections = tenant._server.config_lookup( - 'net_http_max_connections', tenant.get_sys_config() + http_max_connections = tenant._server.config_lookup( + 'http_max_connections', tenant.get_sys_config() ) - http_client._update_limit(net_http_max_connections) + http_client._update_limit(http_max_connections) try: async with (asyncio.TaskGroup() as g,): for db in tenant.iter_dbs(): @@ -99,82 +98,11 @@ async def _http_task(tenant: edbtenant.Tenant, http_client) -> None: ) -class HttpClient: - def __init__(self, limit: int): - self._client = Http(limit) - self._fd = self._client._fd - self._task = None - self._skip_reads = 0 - self._loop = asyncio.get_running_loop() - self._task = self._loop.create_task(self._boot(self._loop)) - self._next_id = 0 - self._requests: dict[int, asyncio.Future] = {} - - def __del__(self) -> None: - if self._task: - self._task.cancel() - self._task = None - - def _update_limit(self, limit: int): - self._client._update_limit(limit) - - async def request( - self, - *, - method: str, - url: str, - content: bytes | None, - headers: list[tuple[str, str]] | None, - ): - if content is None: - content = bytes() - if headers is None: - headers = [] - id = self._next_id - self._next_id += 1 - self._requests[id] = asyncio.Future() - try: - self._client._request(id, url, method, content, headers) - resp = await self._requests[id] - return resp - finally: - del self._requests[id] - - async def _boot(self, loop: asyncio.AbstractEventLoop) -> None: - logger.info("Python-side HTTP client booted") - reader = asyncio.StreamReader(loop=loop) - reader_protocol = asyncio.StreamReaderProtocol(reader) - fd = os.fdopen(self._client._fd, 'rb') - transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd) - try: - while len(await reader.read(1)) == 1: - if not self._client or not self._task: - break - if self._skip_reads > 0: - self._skip_reads -= 1 - continue - msg = self._client._read() - if not msg: - break - self._process_message(msg) - finally: - transport.close() - - def _process_message(self, msg): - msg_type, id, data = msg - - if id in self._requests: - if msg_type == 1: - self._requests[id].set_result(data) - elif msg_type == 0: - self._requests[id].set_exception(Exception(data)) - - def create_http(tenant: edbtenant.Tenant): - net_http_max_connections = tenant._server.config_lookup( - 'net_http_max_connections', tenant.get_sys_config() + http_max_connections = tenant._server.config_lookup( + 'http_max_connections', tenant.get_sys_config() ) - return HttpClient(net_http_max_connections) + return HttpClient(http_max_connections) async def http(server: edbserver.BaseServer) -> None: @@ -337,7 +265,11 @@ def _warn(e): ) async for iteration in rloop: async with iteration: - result = await execute.parse_execute_json( + if not db.tenant.is_database_connectable(db.name): + # Don't run the net_worker if the database is not + # connectable, e.g. 
being dropped + continue + result_json = await execute.parse_execute_json( db, """ with requests := ( @@ -346,13 +278,15 @@ def _warn(e): and (datetime_of_statement() - .updated_at) > $expires_in ) - delete requests; + select count((delete requests)); """, variables={"expires_in": expires_in.to_backend_str()}, cached_globally=True, + tx_isolation=defines.TxIsolationLevel.RepeatableRead, ) - if len(result) > 0: - logger.info(f"Deleted requests: {result!r}") + result: list[int] = json.loads(result_json) + if result[0] > 0: + logger.info(f"Deleted {result[0]} requests") else: logger.info(f"No requests to delete") @@ -382,7 +316,8 @@ async def gc(server: edbserver.BaseServer) -> None: if tenant.accept_new_tasks ] try: - await asyncio.wait(tasks) + if tasks: + await asyncio.wait(tasks) except Exception as ex: logger.debug( "GC of std::net::http::ScheduledRequest failed", exc_info=ex diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index e97edbc8c4f..d42bac0cbf8 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -37,7 +37,7 @@ def __init__( client_secret: str, *, additional_scope: str | None, - http_factory: Callable[..., http_client.HttpClient], + http_factory: Callable[..., http_client.AuthHttpClient], ): self.name = name self.issuer_url = issuer_url diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 6457f62df59..2837edfd241 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -31,7 +31,15 @@ import mimetypes import uuid -from typing import Any, Optional, Tuple, FrozenSet, cast, TYPE_CHECKING +from typing import ( + Any, + Optional, + Tuple, + FrozenSet, + cast, + TYPE_CHECKING, + Callable, +) import aiosmtplib from jwcrypto import jwk, jwt @@ -80,6 +88,27 @@ def __init__( self.tenant = tenant self.test_mode = tenant.server.in_test_mode() + def _get_url_munger( + self, request: protocol.HttpRequest + ) -> Callable[[str], str] | None: + """ + Returns a callable that can be used to modify the base URL + when making requests to the OAuth provider. + + This is used to redirect requests to the test OAuth provider + when running in test mode. + """ + if not self.test_mode: + return None + test_url = ( + request.params[b'oauth-test-server'].decode() + if (request.params and b'oauth-test-server' in request.params) + else None + ) + if test_url: + return lambda path: f"{test_url}{urllib.parse.quote(path)}" + return None + async def handle_request( self, request: protocol.HttpRequest, @@ -270,7 +299,10 @@ async def handle_authorize( query, "challenge", fallback_keys=["code_challenge"] ) oauth_client = oauth.Client( - db=self.db, provider_name=provider_name, base_url=self.test_url + db=self.db, + provider_name=provider_name, + url_munger=self._get_url_munger(request), + http_client=self.tenant.get_http_client(), ) await pkce.create(self.db, challenge) authorize_url = await oauth_client.get_authorize_url( @@ -369,7 +401,8 @@ async def handle_callback( oauth_client = oauth.Client( db=self.db, provider_name=provider_name, - base_url=self.test_url, + url_munger=self._get_url_munger(request), + http_client=self.tenant.get_http_client(), ) ( identity, diff --git a/edb/server/protocol/auth_ext/http_client.py b/edb/server/protocol/auth_ext/http_client.py index 52302b97c31..563605b5900 100644 --- a/edb/server/protocol/auth_ext/http_client.py +++ b/edb/server/protocol/auth_ext/http_client.py @@ -16,51 +16,49 @@ # limitations under the License. 
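Editor's note: before the `AuthHttpClient` wrapper that follows, a quick hedged usage sketch of the `edb.server.http.HttpClient` introduced earlier in this patch and consumed by the net worker and the auth extension. The URLs, payload, and connection limit are invented for illustration:

```python
import asyncio

from edb.server.http import HttpClient

async def main() -> None:
    # The limit mirrors http_max_connections; 10 is the shipped default.
    client = HttpClient(10)

    # POST a JSON body; the helper sets the Content-Type header itself.
    resp = await client.post(
        "https://example.com/hook",   # placeholder endpoint
        json={"event": "ping"},
    )
    print(resp.status_code, resp.json())

    # Plain GET with caller-supplied headers.
    resp = await client.get(
        "https://example.com/health",
        headers={"Accept": "application/json"},
    )
    print(resp.status_code, resp.text)

asyncio.run(main())
```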
# -from typing import Any -import urllib.parse +from typing import Any, Callable, Self -import hishel -import httpx +from edb.server import http -class HttpClient(httpx.AsyncClient): +class AuthHttpClient: def __init__( self, - *args: Any, - edgedb_test_url: str | None, - base_url: str, - **kwargs: Any, + http_client: http.HttpClient, + url_munger: Callable[[str], str] | None = None, + base_url: str | None = None, ): - self.edgedb_orig_base_url = None - if edgedb_test_url: - self.edgedb_orig_base_url = urllib.parse.quote(base_url, safe='') - base_url = edgedb_test_url - cache = hishel.AsyncCacheTransport( - transport=httpx.AsyncHTTPTransport(), - storage=hishel.AsyncInMemoryStorage(capacity=5), - ) - super().__init__( - *args, base_url=base_url, transport=cache, **kwargs - ) + self.url_munger = url_munger + self.http_client = http_client + self.base_url = base_url - async def post( # type: ignore[override] + async def post( self, path: str, - *args: Any, - **kwargs: Any, - ) -> httpx.Response: - if self.edgedb_orig_base_url: - path = f'{self.edgedb_orig_base_url}{path}' - return await super().post( - path, *args, **kwargs + *, + headers: dict[str, str] | None = None, + data: bytes | str | dict[str, str] | None = None, + json: Any | None = None, + ) -> http.Response: + if self.base_url: + path = self.base_url + path + if self.url_munger: + path = self.url_munger(path) + return await self.http_client.post( + path, headers=headers, data=data, json=json ) - async def get( # type: ignore[override] - self, - path: str, - *args: Any, - **kwargs: Any, - ) -> httpx.Response: - if self.edgedb_orig_base_url: - path = f'{self.edgedb_orig_base_url}{path}' - return await super().get(path, *args, **kwargs) + async def get( + self, path: str, *, headers: dict[str, str] | None = None + ) -> http.Response: + if self.base_url: + path = self.base_url + path + if self.url_munger: + path = self.url_munger(path) + return await self.http_client.get(path, headers=headers) + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *args) -> None: # type: ignore + pass diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py index a34d49f2fae..73d1cf03666 100644 --- a/edb/server/protocol/auth_ext/oauth.py +++ b/edb/server/protocol/auth_ext/oauth.py @@ -19,23 +19,29 @@ import json -from typing import cast, Any +from typing import cast, Any, Callable from edb.server.protocol import execute +from edb.server.http import HttpClient from . import github, google, azure, apple, discord, slack -from . import config, errors, util, data, base, http_client +from . 
import config, errors, util, data, base, http_client as _http_client class Client: provider: base.BaseProvider def __init__( - self, db: Any, provider_name: str, base_url: str | None = None + self, + *, + db: Any, + provider_name: str, + http_client: HttpClient, + url_munger: Callable[[str], str] | None = None, ): self.db = db - http_factory = lambda *args, **kwargs: http_client.HttpClient( - *args, edgedb_test_url=base_url, **kwargs + http_factory = lambda *args, **kwargs: _http_client.AuthHttpClient( + *args, url_munger=url_munger, http_client=http_client, **kwargs # type: ignore ) provider_config = self._get_provider_config(provider_name) diff --git a/edb/server/protocol/auth_ext/pkce.py b/edb/server/protocol/auth_ext/pkce.py index 028a6e9e162..a8acfdc5752 100644 --- a/edb/server/protocol/auth_ext/pkce.py +++ b/edb/server/protocol/auth_ext/pkce.py @@ -158,23 +158,29 @@ async def delete(db: edbtenant.dbview.Database, id: str) -> None: assert len(result_json) == 1 +async def _delete_challenge(db: edbtenant.dbview.Database) -> None: + if not db.tenant.is_database_connectable(db.name): + # Don't run gc if the database is not connectable, e.g. being dropped + return + + await execute.parse_execute_json( + db, + """ + delete ext::auth::PKCEChallenge filter + (datetime_of_statement() - .created_at) > + $validity + """, + variables={"validity": VALIDITY.to_backend_str()}, + cached_globally=True, + ) + + async def _gc(tenant: edbtenant.Tenant) -> None: try: async with asyncio.TaskGroup() as g: for db in tenant.iter_dbs(): if "auth" in db.extensions: - g.create_task( - execute.parse_execute_json( - db, - """ - delete ext::auth::PKCEChallenge filter - (datetime_of_statement() - .created_at) > - $validity - """, - variables={"validity": VALIDITY.to_backend_str()}, - cached_globally=True, - ), - ) + g.create_task(_delete_challenge(db)) except Exception as ex: logger.debug( "GC of ext::auth::PKCEChallenge failed (instance: %s)", diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 5af1d14f626..72a8f265d9b 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -801,21 +801,23 @@ cdef class EdgeConnection(frontend.FrontendConnection): self.write(self.make_state_data_description_msg()) raise - rv = rpc.CompilationRequest(self.server.compilation_config_serializer) - rv.update( - self._tokenize(query), - self.protocol_version, + cfg_ser = self.server.compilation_config_serializer + rv = rpc.CompilationRequest( + source=self._tokenize(query), + protocol_version=self.protocol_version, + schema_version=_dbview.schema_version, + compilation_config_serializer=cfg_ser, output_format=output_format, expect_one=expect_one, implicit_limit=implicit_limit, inline_typeids=inline_typeids, inline_typenames=inline_typenames, inline_objectids=inline_objectids, - ).set_schema_version(_dbview.schema_version) - rv.set_modaliases(_dbview.get_modaliases()) - rv.set_session_config(_dbview.get_session_config()) - rv.set_database_config(_dbview.get_database_config()) - rv.set_system_config(_dbview.get_compilation_system_config()) + modaliases=_dbview.get_modaliases(), + session_config=_dbview.get_session_config(), + database_config=_dbview.get_database_config(), + system_config=_dbview.get_compilation_system_config(), + ) return rv, allow_capabilities async def parse(self): @@ -1430,12 +1432,13 @@ cdef class EdgeConnection(frontend.FrontendConnection): async def _execute_utility_stmt(self, eql: str, pgcon): cdef dbview.DatabaseConnectionView _dbview = self.get_dbview() + 
cfg_ser = self.server.compilation_config_serializer query_req = rpc.CompilationRequest( - self.server.compilation_config_serializer + source=edgeql.Source.from_string(eql), + protocol_version=self.protocol_version, + schema_version=_dbview.schema_version, + compilation_config_serializer=cfg_ser, ) - query_req.update( - edgeql.Source.from_string(eql), self.protocol_version - ).set_schema_version(_dbview.schema_version) compiled = await _dbview.parse(query_req) query_unit_group = compiled.query_unit_group @@ -1826,14 +1829,15 @@ async def run_script( await conn._start_connection(database) try: _dbview = conn.get_dbview() + cfg_ser = server.compilation_config_serializer compiled = await _dbview.parse( rpc.CompilationRequest( - server.compilation_config_serializer - ).update( - edgeql.Source.from_string(script), - conn.protocol_version, + source=edgeql.Source.from_string(script), + protocol_version=conn.protocol_version, + schema_version=_dbview.schema_version, + compilation_config_serializer=cfg_ser, output_format=FMT_NONE, - ).set_schema_version(_dbview.schema_version) + ), ) if len(compiled.query_unit_group) > 1: await conn._execute_script(compiled, b'') diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 2717aabea65..7814ed1495d 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -206,13 +206,13 @@ async def _parse( ) query_req = rpc.CompilationRequest( - db.server.compilation_config_serializer - ).update( - edgeql.Source.from_string(query), + source=edgeql.Source.from_string(query), protocol_version=edbdef.CURRENT_PROTOCOL, + schema_version=dbv.schema_version, + compilation_config_serializer=db.server.compilation_config_serializer, input_format=input_format, output_format=output_format, - ).set_schema_version(dbv.schema_version) + ) compiled = await dbv.parse( query_req, diff --git a/edb/server/tenant.py b/edb/server/tenant.py index 7e2c6528afc..d0ce0c0ecb1 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -68,6 +68,7 @@ from .ha import adaptive as adaptive_ha from .ha import base as ha_base +from .http import HttpClient from .pgcon import errors as pgcon_errors if TYPE_CHECKING: @@ -80,6 +81,9 @@ logger = logging.getLogger("edb.server") +HTTP_MAX_CONNECTIONS = 100 + + class RoleDescriptor(TypedDict): superuser: bool name: str @@ -132,6 +136,8 @@ class Tenant(ha_base.ClusterProtocol): _jwt_revocation_list_file: pathlib.Path | None _jwt_revocation_list: frozenset[str] | None + _http_client: HttpClient | None + def __init__( self, cluster: pgcluster.BaseCluster, @@ -203,6 +209,8 @@ def __init__( self._jwt_revocation_list_file = None self._jwt_revocation_list = None + self._http_client = None + # If it isn't stored in instdata, it is the old default. self.default_database = defines.EDGEDB_OLD_DEFAULT_DB @@ -238,6 +246,11 @@ def set_server(self, server: edbserver.BaseServer) -> None: self._server = server self.__loop = server.get_loop() + def get_http_client(self) -> HttpClient: + if self._http_client is None: + self._http_client = HttpClient(HTTP_MAX_CONNECTIONS) + return self._http_client + def on_switch_over(self): # Bumping this serial counter will "cancel" all pending connections # to the old master. 
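For context, a condensed stand-alone sketch of the lazily created, tenant-wide HTTP client added to tenant.py above. `PooledHttpClient` is only a stand-in for `edb.server.http.HttpClient`, and the limit of 100 mirrors the new `HTTP_MAX_CONNECTIONS` constant.

from __future__ import annotations

class PooledHttpClient:
    # Stand-in for edb.server.http.HttpClient: a pool that caps the
    # number of concurrent outbound HTTP connections.
    def __init__(self, max_connections: int) -> None:
        self.max_connections = max_connections

class TenantSketch:
    def __init__(self) -> None:
        self._http_client: PooledHttpClient | None = None

    def get_http_client(self) -> PooledHttpClient:
        # Created lazily on first use and then reused, so one connection
        # cap applies to all of the tenant's outbound HTTP traffic.
        if self._http_client is None:
            self._http_client = PooledHttpClient(100)
        return self._http_client

tenant = TenantSketch()
assert tenant.get_http_client() is tenant.get_http_client()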
@@ -993,7 +1006,7 @@ async def _pg_ensure_database_not_connected(self, dbname: str) -> None: conns = await pgcon.sql_fetch_col( b""" SELECT - pid + row_to_json(pg_stat_activity) FROM pg_stat_activity WHERE @@ -1003,8 +1016,14 @@ async def _pg_ensure_database_not_connected(self, dbname: str) -> None: ) if conns: + debug_info = "" + if self.server.in_dev_mode() or self.server.in_test_mode(): + jconns = [json.loads(conn) for conn in conns] + debug_info = ": " + json.dumps(jconns) + raise errors.ExecutionError( - f"database branch {dbname!r} is being accessed by other users" + f"database branch {dbname!r} is being accessed by " + f"other users{debug_info}" ) @contextlib.asynccontextmanager @@ -1222,6 +1241,18 @@ async def _early_introspect_db(self, dbname: str) -> None: early=True, ) + # Early introspection runs *before* we start accepting tasks. + # This means that if we are one of multiple frontends, and we + # get a ensure-database-not-used message, we aren't able to + # handle it. This can result in us hanging onto a connection + # that another frontend wants to get rid of. + # + # We still want to use the pool, though, since it limits our + # connections in the way we want. + # + # Hack around this by pruning the connection ourself. + await self._pg_pool.prune_inactive_connections(dbname) + async def _introspect_dbs(self) -> None: async with self.use_sys_pgcon() as syscon: dbnames = await self._server.get_dbnames(syscon) diff --git a/edb/testbase/http.py b/edb/testbase/http.py index 99fc8d9f4ff..52a80af9c2e 100644 --- a/edb/testbase/http.py +++ b/edb/testbase/http.py @@ -330,9 +330,12 @@ def log_message(self, *args): class MultiHostMockHttpServerHandler(MockHttpServerHandler): def get_server_and_path(self) -> tuple[str, str]: - server, path = self.path.lstrip('/').split('/', 1) - server = urllib.parse.unquote(server) - return server, path + # Path looks like: + # http://127.0.0.1:32881/https%3A//slack.com/.well-known/openid-configuration + raw_url = urllib.parse.unquote(self.path.lstrip('/')) + url = urllib.parse.urlparse(raw_url) + return (f'{url.scheme}://{url.netloc}', + url.path.lstrip('/')) ResponseType = tuple[str, int] | tuple[str, int, dict[str, str]] @@ -416,9 +419,10 @@ def handle_request( body=body, ) self.requests[key].append(request_details) - if key not in self.routes: - handler.send_error(404) + error_message = (f"No route handler for {key}\n\n" + f"Available routes:\n{self.routes}") + handler.send_error(404, message=error_message) return registered_handler = self.routes[key] diff --git a/tests/test_edgeql_ddl.py b/tests/test_edgeql_ddl.py index 969f435c806..ac9713639d9 100644 --- a/tests/test_edgeql_ddl.py +++ b/tests/test_edgeql_ddl.py @@ -13664,6 +13664,35 @@ async def test_edgeql_ddl_errors_03(self): DROP FUNCTION foo___1(a: int64); ''') + @staticmethod + def order_migrations(migrations): + # Migrations are implicitly ordered based on parent_id. + # For now, assume that there is a single "initial" migration + # and all migrations have at most one child. + + # Find initial migration with no parents. + ordered = [ + migration + for migration in migrations + if not migration['parents'] + ] + + # Repeatedly find descendents until no more can be found. 
+ prev_ids = [migration['id'] for migration in ordered] + while prev_ids: + curr_migrations = [ + migration + for migration in migrations + if any( + parent['id'] in prev_ids + for parent in migration['parents'] + ) + ] + ordered.extend(curr_migrations) + prev_ids = [migration['id'] for migration in curr_migrations] + + return ordered + async def test_edgeql_ddl_migration_sdl_01(self): await self.con.execute(''' CONFIGURE SESSION SET store_migration_sdl := @@ -13696,42 +13725,12 @@ async def test_edgeql_ddl_migration_sdl_01(self): drop type A; ''') - # Migrations implicitly ordered based on parent_id. - # Fetch the migrations and sort them here. - migrations = json.loads(await self.con.query_json( - 'select schema::Migration { id, parent_ids := .parents.id, sdl }' - )) - - # Find migrations with no parents. - sdl = [ - migration['sdl'] - for migration in migrations - if not migration['parent_ids'] - ] - - # Repeatedly find descendents until no more can be found. - prev_ids = [ - migration['id'] - for migration in migrations - if not migration['parent_ids'] - ] - while prev_ids: - sdl.extend( - migration['sdl'] - for migration in migrations - if any( - parent_id in prev_ids - for parent_id in migration['parent_ids'] - ) - ) - prev_ids = [ - migration['id'] - for migration in migrations - if any( - parent_id in prev_ids - for parent_id in migration['parent_ids'] - ) - ] + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { id, parents: { id }, sdl } + ''')) + ) + sdl = [migration['sdl'] for migration in migrations] self.assert_data_shape( sdl, @@ -13823,8 +13822,15 @@ async def test_edgeql_ddl_create_migration_01(self): }] ) - await self.assert_query_result( - 'select schema::Migration { script, sdl }', + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { + id, parents: { id }, script, sdl + } + ''')) + ) + self.assert_data_shape( + migrations, [ { 'script': ( @@ -13871,8 +13877,15 @@ async def test_edgeql_ddl_create_migration_02(self): }; ''') - await self.assert_query_result( - 'select schema::Migration { script, sdl }', + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { + id, parents: { id }, script, sdl + } + ''')) + ) + self.assert_data_shape( + migrations, [ { 'script': ( @@ -13933,10 +13946,20 @@ async def test_edgeql_ddl_create_migration_03(self): }; ''') - await self.assert_query_result( - ''' - SELECT schema::Migration { message, generated_by, script, sdl } - ''', + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { + id, + parents: { id }, + message, + generated_by, + script, + sdl, + } + ''')) + ) + self.assert_data_shape( + migrations, [ { 'generated_by': 'DevMode', @@ -13964,10 +13987,20 @@ async def test_edgeql_ddl_create_migration_03(self): CREATE TYPE Type3 ''') - await self.assert_query_result( - ''' - SELECT schema::Migration { message, generated_by, script, sdl } - ''', + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { + id, + parents: { id }, + message, + generated_by, + script, + sdl, + } + ''')) + ) + self.assert_data_shape( + migrations, [ { 'generated_by': 'DevMode', diff --git a/tests/test_edgeql_sql_codegen.py b/tests/test_edgeql_sql_codegen.py index 3367889fd43..0020b93f009 100644 --- a/tests/test_edgeql_sql_codegen.py +++ 
b/tests/test_edgeql_sql_codegen.py @@ -42,6 +42,18 @@ class TestEdgeQLSQLCodegen(tb.BaseEdgeQLCompilerTest): SCHEMA_cards = os.path.join(os.path.dirname(__file__), 'schemas', 'cards.esdl') + @classmethod + def get_schema_script(cls): + script = super().get_schema_script() + # Setting internal params like is_inlined in the schema + # doesn't work right so we override the script to add DDL. + return script + ''' + create function cards::ins_bot(name: str) -> cards::Bot { + set is_inlined := true; + using (insert cards::Bot { name := "asdf" }); + }; + ''' + def _compile_to_tree(self, source): qltree = qlparser.parse_query(source) ir = compiler.compile_ast_to_ir( @@ -461,3 +473,50 @@ def test_codegen_unless_conflict_03(self): "ON CONFLICT", sql, "insert unless conflict not using ON CONFLICT" ) + + def test_codegen_inlined_insert_01(self): + # Test that we don't use an overlay when selecting from a + # simple function that does an INSERT. + sql = self._compile(''' + WITH MODULE cards + select ins_bot("asdf") { id, name } + ''') + + table_obj = self.schema.get("cards::Bot") + count = sql.count(str(table_obj.id)) + # The table should only be referenced once, in the INSERT. + # If we reference it more than that, we're probably selecting it. + self.assertEqual( + count, + 1, + f"Bot selected from and not just inserted: {sql}") + + def test_codegen_inlined_insert_02(self): + # Test that we don't use an overlay when selecting from a + # net::http::schedule_request + sql = self._compile(''' + with + nh as module std::net::http, + url := $url, + request := ( + nh::schedule_request( + url, + method := nh::Method.`GET` + ) + ) + select request { + id, + state, + failure, + response, + } + ''') + + table_obj = self.schema.get("std::net::http::ScheduledRequest") + count = sql.count(str(table_obj.id)) + # The table should only be referenced once, in the INSERT. + # If we reference it more than that, we're probably selecting it. 
+ self.assertEqual( + count, + 1, + f"ScheduledRequest selected from and not just inserted: {sql}") diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 90419421029..eaefccde7e6 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -362,6 +362,7 @@ def setUp(self): self.mock_net_server = tb.MockHttpServer() self.mock_net_server.start() + super().setUp() def tearDown(self): if self.mock_oauth_server is not None: @@ -369,6 +370,7 @@ def tearDown(self): if self.mock_net_server is not None: self.mock_net_server.stop() self.mock_oauth_server = None + super().tearDown() @classmethod def get_setup_script(cls): @@ -1043,7 +1045,7 @@ async def test_http_auth_ext_discord_callback_01(self): ) ) - user_request = ("GET", "https://discord.com/api/v10", "users/@me") + user_request = ("GET", "https://discord.com", "api/v10/users/@me") self.mock_oauth_server.register_route_handler(*user_request)( ( json.dumps( @@ -1417,8 +1419,8 @@ async def test_http_auth_ext_azure_authorize_01(self): discovery_request = ( "GET", - "https://login.microsoftonline.com/common/v2.0", - ".well-known/openid-configuration", + "https://login.microsoftonline.com", + "common/v2.0/.well-known/openid-configuration", ) self.mock_oauth_server.register_route_handler(*discovery_request)( ( @@ -1491,8 +1493,8 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: discovery_request = ( "GET", - "https://login.microsoftonline.com/common/v2.0", - ".well-known/openid-configuration", + "https://login.microsoftonline.com", + "common/v2.0/.well-known/openid-configuration", ) self.mock_oauth_server.register_route_handler(*discovery_request)( ( diff --git a/tests/test_http_std_net.py b/tests/test_http_std_net.py index 4e44088fa43..ced7103775a 100644 --- a/tests/test_http_std_net.py +++ b/tests/test_http_std_net.py @@ -27,7 +27,18 @@ class StdNetTestCase(tb.BaseHttpTest): mock_server: typing.Optional[tb.MockHttpServer] = None base_url: str + # Queries need to run outside of transactions here, but we *also* + # want them to be able to run concurrently against the same + # database (to exercise concurrency issues), so we also override + # parallelism granuality below. 
+ TRANSACTION_ISOLATION = False + + @classmethod + def get_parallelism_granularity(cls): + return 'default' + def setUp(self): + super().setUp() self.mock_server = tb.MockHttpServer() self.mock_server.start() self.base_url = self.mock_server.get_base_url().rstrip("/") @@ -36,6 +47,7 @@ def tearDown(self): if self.mock_server is not None: self.mock_server.stop() self.mock_server = None + super().tearDown() async def _wait_for_request_completion(self, request_id: str): async for tr in self.try_until_succeeds( diff --git a/tests/test_server_compiler.py b/tests/test_server_compiler.py index d628999bd1c..c4f8f6ebfa5 100644 --- a/tests/test_server_compiler.py +++ b/tests/test_server_compiler.py @@ -33,6 +33,7 @@ from edb import edgeql from edb.testbase import lang as tb from edb.testbase import server as tbs +from edb.pgsql import params as pg_params from edb.server import args as edbargs from edb.server import compiler as edbcompiler from edb.server.compiler import rpc @@ -392,14 +393,16 @@ def setUpClass(cls): super().setUpClass() cls._std_schema = tb._load_std_schema() result = tb._load_reflection_schema() - cls._refl_schema, cls._schema_class_layout = result + cls._refl_schema, _schema_class_layout = result + assert _schema_class_layout is not None + cls._schema_class_layout = _schema_class_layout async def _test_pool_disconnect_queue(self, pool_class): with tempfile.TemporaryDirectory() as td: pool_ = await pool.create_compiler_pool( runstate_dir=td, pool_size=2, - backend_runtime_params=None, + backend_runtime_params=pg_params.get_default_runtime_params(), std_schema=self._std_schema, refl_schema=self._refl_schema, schema_class_layout=self._schema_class_layout, @@ -442,13 +445,14 @@ async def _test_pool_disconnect_queue(self, pool_class): ) orig_query = 'SELECT 123' + cfg_ser = compiler.state.compilation_config_serializer request = rpc.CompilationRequest( - compiler.state.compilation_config_serializer - ).update( source=edgeql.Source.from_string(orig_query), protocol_version=(1, 0), + schema_version=uuid.uuid4(), + compilation_config_serializer=cfg_ser, implicit_limit=101, - ).set_schema_version(uuid.uuid4()) + ) await asyncio.gather(*(pool_.compile_in_tx( None, @@ -476,15 +480,15 @@ def test_server_compiler_rpc_hash_eq(self): ) def test(source: edgeql.Source): + cfg_ser = compiler.state.compilation_config_serializer request1 = rpc.CompilationRequest( - compiler.state.compilation_config_serializer - ).update( source=source, protocol_version=(1, 0), - ).set_schema_version(uuid.uuid4()) - request2 = rpc.CompilationRequest( - compiler.state.compilation_config_serializer - ).deserialize(request1.serialize(), "") + schema_version=uuid.uuid4(), + compilation_config_serializer=cfg_ser, + ) + request2 = rpc.CompilationRequest.deserialize( + request1.serialize(), "", cfg_ser) self.assertEqual(hash(request1), hash(request2)) self.assertEqual(request1, request2)
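To make the request-construction change visible in the tests above easier to follow, here is a toy illustration (not the real rpc.CompilationRequest) of the style this diff moves to: all compilation parameters are supplied as keyword-only arguments up front rather than patched in via .update(...)/.set_*() calls, so two requests built from the same inputs compare and hash equal immediately.

import uuid
from dataclasses import dataclass

@dataclass(frozen=True, kw_only=True)
class ToyCompilationRequest:
    # Hypothetical stand-in with a few of the fields seen above.
    source: str
    protocol_version: tuple[int, int]
    schema_version: uuid.UUID
    implicit_limit: int = 0

schema_version = uuid.uuid4()
req1 = ToyCompilationRequest(
    source='SELECT 123',
    protocol_version=(1, 0),
    schema_version=schema_version,
    implicit_limit=101,
)
req2 = ToyCompilationRequest(
    source='SELECT 123',
    protocol_version=(1, 0),
    schema_version=schema_version,
    implicit_limit=101,
)
assert req1 == req2 and hash(req1) == hash(req2)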