From a27ff6ac91d2c7432e0b2481104f662403327fad Mon Sep 17 00:00:00 2001 From: Scott Trinh Date: Tue, 22 Oct 2024 12:22:04 -0400 Subject: [PATCH 01/18] Add retries to std::net tests (#7898) --- tests/test_http_std_net.py | 137 ++++++++++++++++++++----------------- 1 file changed, 73 insertions(+), 64 deletions(-) diff --git a/tests/test_http_std_net.py b/tests/test_http_std_net.py index 4e44088fa43..826bf1ccf86 100644 --- a/tests/test_http_std_net.py +++ b/tests/test_http_std_net.py @@ -20,6 +20,7 @@ import typing import json +from edb import errors from edb.testbase import http as tb @@ -87,30 +88,34 @@ async def test_http_std_net_con_schedule_request_get_01(self): ) ) - result = await self.con.query_single( - """ - with - nh as module std::net::http, - net as module std::net, - url := $url, - request := ( - insert nh::ScheduledRequest { - created_at := datetime_of_statement(), - updated_at := datetime_of_statement(), - state := std::net::RequestState.Pending, - - url := url, - method := nh::Method.`GET`, - headers := [ - ("Accept", "application/json"), - ("x-test-header", "test-value"), - ], - } + async for tr in self.try_until_succeeds( + delay=2, timeout=120, ignore=(errors.TransactionSerializationError,) + ): + async with tr: + result = await self.con.query_single( + """ + with + nh as module std::net::http, + net as module std::net, + url := $url, + request := ( + insert nh::ScheduledRequest { + created_at := datetime_of_statement(), + updated_at := datetime_of_statement(), + state := std::net::RequestState.Pending, + + url := url, + method := nh::Method.`GET`, + headers := [ + ("Accept", "application/json"), + ("x-test-header", "test-value"), + ], + } + ) + select request {*}; + """, + url=url, ) - select request {*}; - """, - url=url, - ) requests_for_example = None async for tr in self.try_until_succeeds( @@ -160,29 +165,32 @@ async def test_http_std_net_con_schedule_request_post_01(self): ) ) - result = await self.con.query_single( - """ - with - nh as module std::net::http, - net as module std::net, - url := $url, - body := $body, - request := ( - nh::schedule_request( - url, - method := nh::Method.POST, - headers := [ - ("Accept", "application/json"), - ("x-test-header", "test-value"), - ], - body := body, - ) + async for tr in self.try_until_succeeds( + delay=2, timeout=120, ignore=(errors.TransactionSerializationError,) + ): + async with tr: + result = await self.con.query_single( + """ + with + nh as module std::net::http, + url := $url, + body := $body, + request := ( + nh::schedule_request( + url, + method := nh::Method.POST, + headers := [ + ("Accept", "application/json"), + ("x-test-header", "test-value"), + ], + body := body, + ) + ) + select request {*}; + """, + url=url, + body=b"Hello, world!", ) - select request {*}; - """, - url=url, - body=b"Hello, world!", - ) requests_for_example = None async for tr in self.try_until_succeeds( @@ -216,26 +224,27 @@ async def test_http_std_net_con_schedule_request_bad_address(self): # Test a request to a known-bad address bad_url = "http://256.256.256.256" - result = await self.con.query_single( - """ - with - nh as module std::net::http, - url := $url, - request := ( - nh::schedule_request( - url, - method := nh::Method.`GET` - ) + async for tr in self.try_until_succeeds( + delay=2, timeout=120, ignore=(errors.TransactionSerializationError,) + ): + async with tr: + result = await self.con.query_single( + """ + with + nh as module std::net::http, + url := $url, + request := ( + nh::schedule_request(url) + ) + select request { + id, 
+ state, + failure, + response, + }; + """, + url=bad_url, ) - select request { - id, - state, - failure, - response, - }; - """, - url=bad_url, - ) table_result = await self._wait_for_request_completion(result.id) self.assertEqual(str(table_result.state), 'Failed') From e65974fd6d2425e0390b8ae96fc374a120d1e4b7 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Tue, 22 Oct 2024 10:50:26 -0700 Subject: [PATCH 02/18] workflows: Pin Python to 3.12 when building on macOS (#7901) Python 3.13 is broken at the moment. --- .github/workflows.src/build.inc.yml | 2 +- .github/workflows/dryrun.yml | 4 ++-- .github/workflows/nightly.yml | 4 ++-- .github/workflows/release.yml | 4 ++-- .github/workflows/testing.yml | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows.src/build.inc.yml b/.github/workflows.src/build.inc.yml index 50e0291d780..81f7e495b1e 100644 --- a/.github/workflows.src/build.inc.yml +++ b/.github/workflows.src/build.inc.yml @@ -152,7 +152,7 @@ uses: actions/setup-python@v5 if: << 'false' if tgt.runs_on and 'self-hosted' in tgt.runs_on else 'true' >> with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 diff --git a/.github/workflows/dryrun.yml b/.github/workflows/dryrun.yml index c3355e44caa..6743410f2b6 100644 --- a/.github/workflows/dryrun.yml +++ b/.github/workflows/dryrun.yml @@ -994,7 +994,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 @@ -1065,7 +1065,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 12f611932db..62c2c1a8f25 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -999,7 +999,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 @@ -1070,7 +1070,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 28ec1b4fbc7..9bdf49912f9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -541,7 +541,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 @@ -610,7 +610,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index eaa7275798a..a328b2aafc7 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -563,7 +563,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 @@ -633,7 +633,7 @@ jobs: uses: actions/setup-python@v5 if: true with: - python-version: "3.x" + python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 From 2a7a4a6f92334533114cb4475ffc2689afb7dfde Mon Sep 17 00:00:00 2001 From: "Michael J. 
Sullivan" Date: Tue, 22 Oct 2024 11:48:29 -0700 Subject: [PATCH 03/18] Run StdNetTestCases in a more intentional state (#7903) They were accidentally being run in a mode without transaction isolation but still concurrently. That's actually probably good here, but set it up more intentionally. --- tests/test_http_std_net.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_http_std_net.py b/tests/test_http_std_net.py index 826bf1ccf86..c0cdf2a4196 100644 --- a/tests/test_http_std_net.py +++ b/tests/test_http_std_net.py @@ -28,7 +28,18 @@ class StdNetTestCase(tb.BaseHttpTest): mock_server: typing.Optional[tb.MockHttpServer] = None base_url: str + # Queries need to run outside of transactions here, but we *also* + # want them to be able to run concurrently against the same + # database (to exercise concurrency issues), so we also override + # parallelism granuality below. + TRANSACTION_ISOLATION = False + + @classmethod + def get_parallelism_granularity(cls): + return 'default' + def setUp(self): + super().setUp() self.mock_server = tb.MockHttpServer() self.mock_server.start() self.base_url = self.mock_server.get_base_url().rstrip("/") @@ -37,6 +48,7 @@ def tearDown(self): if self.mock_server is not None: self.mock_server.stop() self.mock_server = None + super().tearDown() async def _wait_for_request_completion(self, request_id: str): async for tr in self.try_until_succeeds( From f65d0dfc3274cbec72a7d8f7d3bf371d03687d62 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Tue, 22 Oct 2024 11:54:43 -0700 Subject: [PATCH 04/18] Initialize `CompilationRequest` in one step (#7897) The current two-step initialization of `CompilationRequest` seems unnecessary and error-prone. Have a normal constructor and make `deserialize` a classmethod instead. 
--- edb/server/compiler/compiler.py | 18 +- edb/server/compiler/rpc.pxd | 3 +- edb/server/compiler/rpc.pyi | 25 +- edb/server/compiler/rpc.pyx | 470 ++++++++++++++++++-------------- edb/server/dbview/dbview.pyx | 8 +- edb/server/protocol/binary.pyx | 40 +-- edb/server/protocol/execute.pyx | 8 +- tests/test_server_compiler.py | 19 +- 8 files changed, 335 insertions(+), 256 deletions(-) diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index a3922c6db30..94f4e2a9006 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -577,10 +577,11 @@ def compile_request( ) -> Tuple[ dbstate.QueryUnitGroup, Optional[dbstate.CompilerConnectionState] ]: - request = rpc.CompilationRequest( - self.state.compilation_config_serializer + request = rpc.CompilationRequest.deserialize( + serialized_request, + original_query, + self.state.compilation_config_serializer, ) - request.deserialize(serialized_request, original_query) units, cstate = self.compile( user_schema, @@ -598,7 +599,7 @@ def compile_request( request.inline_typenames, request.protocol_version, request.inline_objectids, - request.json_parameters, + request.input_format is enums.InputFormat.JSON, cache_key=request.get_cache_key(), ) return units, cstate @@ -684,10 +685,11 @@ def compile_in_tx_request( ) -> Tuple[ dbstate.QueryUnitGroup, Optional[dbstate.CompilerConnectionState] ]: - request = rpc.CompilationRequest( - self.state.compilation_config_serializer + request = rpc.CompilationRequest.deserialize( + serialized_request, + original_query, + self.state.compilation_config_serializer, ) - request.deserialize(serialized_request, original_query) # Apply session differences if any if ( @@ -712,7 +714,7 @@ def compile_in_tx_request( request.inline_typenames, request.protocol_version, request.inline_objectids, - request.json_parameters, + request.input_format is enums.InputFormat.JSON, expect_rollback=expect_rollback, cache_key=request.get_cache_key(), ) diff --git a/edb/server/compiler/rpc.pxd b/edb/server/compiler/rpc.pxd index 946e91b1804..afb0c01ec77 100644 --- a/edb/server/compiler/rpc.pxd +++ b/edb/server/compiler/rpc.pxd @@ -29,7 +29,7 @@ cdef class CompilationRequest: readonly object source readonly object protocol_version readonly object output_format - readonly bint json_parameters + readonly object input_format readonly bint expect_one readonly int implicit_limit readonly bint inline_typeids @@ -46,4 +46,3 @@ cdef class CompilationRequest: object cache_key cdef _serialize(self) - cdef _deserialize_v0(self, bytes data, str query_text) diff --git a/edb/server/compiler/rpc.pyi b/edb/server/compiler/rpc.pyi index 758f76521c2..94cd289e4cc 100644 --- a/edb/server/compiler/rpc.pyi +++ b/edb/server/compiler/rpc.pyi @@ -29,7 +29,7 @@ class CompilationRequest: source: edgeql.Source protocol_version: defines.ProtocolVersion output_format: enums.OutputFormat - json_parameters: bool + input_format: enums.InputFormat expect_one: bool implicit_limit: int inline_typeids: bool @@ -41,6 +41,21 @@ class CompilationRequest: def __init__( self, + *, + source: edgeql.Source, + protocol_version: defines.ProtocolVersion, + output_format: enums.OutputFormat = enums.OutputFormat.BINARY, + input_format: enums.InputFormat = enums.InputFormat.BINARY, + expect_one: bool = False, + implicit_limit: int = 0, + inline_typeids: bool = False, + inline_typenames: bool = False, + inline_objectids: bool = True, + schema_version: uuid.UUID | None = None, + modaliases: typing.Mapping[str | None, str] | None = None, + 
session_config: typing.Mapping[str, config.SettingValue] | None = None, + database_config: typing.Mapping[str, config.SettingValue] | None = None, + system_config: typing.Mapping[str, config.SettingValue] | None = None, compilation_config_serializer: sertypes.CompilationConfigSerializer, ): ... @@ -86,7 +101,13 @@ class CompilationRequest: def serialize(self) -> bytes: ... - def deserialize(self, data: bytes, query_text: str) -> CompilationRequest: + @classmethod + def deserialize( + cls, + data: bytes, + query_text: str, + compilation_config_serializer: sertypes.CompilationConfigSerializer, + ) -> CompilationRequest: ... def get_cache_key(self) -> uuid.UUID: diff --git a/edb/server/compiler/rpc.pyx b/edb/server/compiler/rpc.pyx index fd956a9703e..d510e131774 100644 --- a/edb/server/compiler/rpc.pyx +++ b/edb/server/compiler/rpc.pyx @@ -16,6 +16,10 @@ # limitations under the License. # +from typing import ( + Mapping, +) + import hashlib import uuid @@ -77,35 +81,11 @@ cdef deserialize_output_format(char mode): cdef class CompilationRequest: def __cinit__( self, - compilation_config_serializer: sertypes.CompilationConfigSerializer, - ): - self._serializer = compilation_config_serializer - - def __copy__(self): - cdef CompilationRequest rv = CompilationRequest(self._serializer) - rv.source = self.source - rv.protocol_version = self.protocol_version - rv.output_format = self.output_format - rv.json_parameters = self.json_parameters - rv.expect_one = self.expect_one - rv.implicit_limit = self.implicit_limit - rv.inline_typeids = self.inline_typeids - rv.inline_typenames = self.inline_typenames - rv.inline_objectids = self.inline_objectids - rv.modaliases = self.modaliases - rv.session_config = self.session_config - rv.database_config = self.database_config - rv.system_config = self.system_config - rv.schema_version = self.schema_version - rv.serialized_cache = self.serialized_cache - rv.cache_key = self.cache_key - return rv - - def update( - self, + *, source: edgeql.Source, protocol_version: defines.ProtocolVersion, - *, + schema_version: uuid.UUID, + compilation_config_serializer: sertypes.CompilationConfigSerializer, output_format: enums.OutputFormat = OUT_FMT_BINARY, input_format: enums.InputFormat = IN_FMT_BINARY, expect_one: bint = False, @@ -113,20 +93,53 @@ cdef class CompilationRequest: inline_typeids: bint = False, inline_typenames: bint = False, inline_objectids: bint = True, - ) -> CompilationRequest: + modaliases: Mapping[str | None, str] | None = None, + session_config: Mapping[str, config.SettingValue] | None = None, + database_config: Mapping[str, config.SettingValue] | None = None, + system_config: Mapping[str, config.SettingValue] | None = None, + ): + self._serializer = compilation_config_serializer self.source = source self.protocol_version = protocol_version self.output_format = output_format - self.json_parameters = input_format is IN_FMT_JSON + self.input_format = input_format self.expect_one = expect_one self.implicit_limit = implicit_limit self.inline_typeids = inline_typeids self.inline_typenames = inline_typenames self.inline_objectids = inline_objectids + self.schema_version = schema_version + self.modaliases = modaliases + self.session_config = session_config + self.database_config = database_config + self.system_config = system_config self.serialized_cache = None self.cache_key = None - return self + + def __copy__(self): + cdef CompilationRequest rv + + rv = CompilationRequest( + source=self.source, + protocol_version=self.protocol_version, + 
schema_version=self.schema_version, + compilation_config_serializer=self._serializer, + output_format=self.output_format, + input_format=self.input_format, + expect_one=self.expect_one, + implicit_limit=self.implicit_limit, + inline_typeids=self.inline_typeids, + inline_typenames=self.inline_typenames, + inline_objectids=self.inline_objectids, + modaliases=self.modaliases, + session_config=self.session_config, + database_config=self.database_config, + system_config=self.system_config, + ) + rv.serialized_cache = self.serialized_cache + rv.cache_key = self.cache_key + return rv def set_modaliases(self, value) -> CompilationRequest: self.modaliases = value @@ -158,14 +171,22 @@ cdef class CompilationRequest: self.cache_key = None return self - def deserialize(self, bytes data, str query_text) -> CompilationRequest: + @classmethod + def deserialize( + cls, + data: bytes, + query_text: str, + compilation_config_serializer: sertypes.CompilationConfigSerializer, + ) -> CompilationRequest: + buf = ReadBuffer.new_message_parser(data) + if data[0] == 0: - self._deserialize_v0(data, query_text) + return _deserialize_comp_req_v0( + buf, query_text, compilation_config_serializer) else: raise errors.UnsupportedProtocolVersionError( f"unsupported compile cache: version {data[0]}" ) - return self def serialize(self) -> bytes: if self.serialized_cache is None: @@ -178,182 +199,9 @@ cdef class CompilationRequest: return self.cache_key cdef _serialize(self): - # Please see _deserialize_v0 for the format doc - - cdef: - char version = 0, flags - WriteBuffer out = WriteBuffer.new() - - out.write_byte(version) - - flags = ( - (MASK_JSON_PARAMETERS if self.json_parameters else 0) | - (MASK_EXPECT_ONE if self.expect_one else 0) | - (MASK_INLINE_TYPEIDS if self.inline_typeids else 0) | - (MASK_INLINE_TYPENAMES if self.inline_typenames else 0) | - (MASK_INLINE_OBJECTIDS if self.inline_objectids else 0) - ) - out.write_byte(flags) - - out.write_int16(self.protocol_version[0]) - out.write_int16(self.protocol_version[1]) - out.write_byte(serialize_output_format(self.output_format)) - out.write_int64(self.implicit_limit) - - if self.modaliases is None: - out.write_int32(-1) - else: - out.write_int32(len(self.modaliases)) - for k, v in sorted( - self.modaliases.items(), - key=lambda i: (0, i[0]) if i[0] is None else (1, i[0]) - ): - if k is None: - out.write_byte(0) - else: - out.write_byte(1) - out.write_str(k, "utf-8") - out.write_str(v, "utf-8") - - type_id, desc = self._serializer.describe() - out.write_bytes(type_id.bytes) - out.write_len_prefixed_bytes(desc) - - hash_obj = hashlib.blake2b(memoryview(out), digest_size=16) - hash_obj.update(self.source.cache_key()) - - if self.session_config is None: - session_config = b"" - else: - session_config = self._serializer.encode_configs( - self.session_config - ) - out.write_len_prefixed_bytes(session_config) - - # Build config that affects compilation: session -> database -> system. - # This is only used for calculating cache_key, while session - # config itself is separately stored above in the serialized format. 
- serialized_comp_config = self._serializer.encode_configs( - self.system_config, self.database_config, self.session_config - ) - hash_obj.update(serialized_comp_config) - - # Must set_schema_version() before serializing compilation request - assert self.schema_version is not None - hash_obj.update(self.schema_version.bytes) - - cache_key_bytes = hash_obj.digest() - self.cache_key = uuidgen.from_bytes(cache_key_bytes) - - out.write_len_prefixed_bytes(self.source.serialize()) - out.write_bytes(cache_key_bytes) - out.write_bytes(self.schema_version.bytes) - - self.serialized_cache = bytes(out) - - cdef _deserialize_v0(self, bytes data, str query_text): - # Format: - # - # * 1 byte of version (0) - # * 1 byte of bit flags: - # * json_parameters - # * expect_one - # * inline_typeids - # * inline_typenames - # * inline_objectids - # * protocol_version (major: int64, minor: int16) - # * 1 byte output_format (the same as in the binary protocol) - # * implicit_limit: int64 - # * Module aliases: - # * length: int32 (negative means the modaliases is None) - # * For each alias pair: - # * 1 byte, 0 if the name is None - # * else, C-String as the name - # * C-String as the alias - # * Session config type descriptor - # * 16 bytes type ID - # * int32-length-prefixed serialized type descriptor - # * Session config: int32-length-prefixed serialized data - # * Serialized Source or NormalizedSource without the original query - # string - # * 16-byte cache key = BLAKE-2b hash of: - # * All above serialized, - # * Except that the source is replaced with Source.cache_key(), and - # * Except that the serialized session config is replaced by - # serialized combined config (session -> database -> system) - # that only affects compilation. - # * The schema version - # * OPTIONALLY, the schema version. We wanted to bump the protocol - # version to include this, but 5.x hard crashes when it reads a - # persistent cache with entries it doesn't understand, so instead - # we stick it on the end where it will be ignored by old versions. 
- - cdef char flags - - self.serialized_cache = data - - buf = ReadBuffer.new_message_parser(data) - - assert buf.read_byte() == 0 # version - - flags = buf.read_byte() - self.json_parameters = flags & MASK_JSON_PARAMETERS > 0 - self.expect_one = flags & MASK_EXPECT_ONE > 0 - self.inline_typeids = flags & MASK_INLINE_TYPEIDS > 0 - self.inline_typenames = flags & MASK_INLINE_TYPENAMES > 0 - self.inline_objectids = flags & MASK_INLINE_OBJECTIDS > 0 - - self.protocol_version = buf.read_int16(), buf.read_int16() - self.output_format = deserialize_output_format(buf.read_byte()) - self.implicit_limit = buf.read_int64() - - size = buf.read_int32() - if size >= 0: - modaliases = [] - for _ in range(size): - if buf.read_byte(): - k = buf.read_null_str().decode("utf-8") - else: - k = None - v = buf.read_null_str().decode("utf-8") - modaliases.append((k, v)) - self.modaliases = immutables.Map(modaliases) - else: - self.modaliases = None - - type_id = uuidgen.from_bytes(buf.read_bytes(16)) - if type_id == self._serializer.type_id: - serializer = self._serializer - buf.read_len_prefixed_bytes() - else: - serializer = sertypes.CompilationConfigSerializer( - type_id, buf.read_len_prefixed_bytes(), defines.CURRENT_PROTOCOL - ) - self._serializer = serializer - - data = buf.read_len_prefixed_bytes() - if data: - self.session_config = immutables.Map( - ( - k, - config.SettingValue( - name=k, - value=v, - source='session', - scope=qltypes.ConfigScope.SESSION, - ) - ) for k, v in serializer.decode(data).items() - ) - else: - self.session_config = None - - self.source = tokenizer.deserialize( - buf.read_len_prefixed_bytes(), query_text - ) - self.cache_key = uuidgen.from_bytes(buf.read_bytes(16)) - - if buf._length >= 16: - self.schema_version = uuidgen.from_bytes(buf.read_bytes(16)) + cache_key, buf = _serialize_comp_req_v0(self) + self.cache_key = cache_key + self.serialized_cache = bytes(buf) def __hash__(self): return hash(self.get_cache_key()) @@ -363,10 +211,212 @@ cdef class CompilationRequest: self.source.cache_key() == other.source.cache_key() and self.protocol_version == other.protocol_version and self.output_format == other.output_format and - self.json_parameters == other.json_parameters and + self.input_format == other.input_format and self.expect_one == other.expect_one and self.implicit_limit == other.implicit_limit and self.inline_typeids == other.inline_typeids and self.inline_typenames == other.inline_typenames and self.inline_objectids == other.inline_objectids ) + + +cdef _deserialize_comp_req_v0( + buf: ReadBuffer, + query_text: str, + compilation_config_serializer: sertypes.CompilationConfigSerializer, +): + # Format: + # + # * 1 byte of version (0) + # * 1 byte of bit flags: + # * json_parameters + # * expect_one + # * inline_typeids + # * inline_typenames + # * inline_objectids + # * protocol_version (major: int64, minor: int16) + # * 1 byte output_format (the same as in the binary protocol) + # * implicit_limit: int64 + # * Module aliases: + # * length: int32 (negative means the modaliases is None) + # * For each alias pair: + # * 1 byte, 0 if the name is None + # * else, C-String as the name + # * C-String as the alias + # * Session config type descriptor + # * 16 bytes type ID + # * int32-length-prefixed serialized type descriptor + # * Session config: int32-length-prefixed serialized data + # * Serialized Source or NormalizedSource without the original query + # string + # * 16-byte cache key = BLAKE-2b hash of: + # * All above serialized, + # * Except that the source is replaced 
with Source.cache_key(), and + # * Except that the serialized session config is replaced by + # serialized combined config (session -> database -> system) + # that only affects compilation. + # * The schema version + # * OPTIONALLY, the schema version. We wanted to bump the protocol + # version to include this, but 5.x hard crashes when it reads a + # persistent cache with entries it doesn't understand, so instead + # we stick it on the end where it will be ignored by old versions. + + cdef char flags + + assert buf.read_byte() == 0 # version + + flags = buf.read_byte() + if flags & MASK_JSON_PARAMETERS > 0: + input_format = IN_FMT_JSON + else: + input_format = IN_FMT_BINARY + expect_one = flags & MASK_EXPECT_ONE > 0 + inline_typeids = flags & MASK_INLINE_TYPEIDS > 0 + inline_typenames = flags & MASK_INLINE_TYPENAMES > 0 + inline_objectids = flags & MASK_INLINE_OBJECTIDS > 0 + + protocol_version = buf.read_int16(), buf.read_int16() + output_format = deserialize_output_format(buf.read_byte()) + implicit_limit = buf.read_int64() + + size = buf.read_int32() + if size >= 0: + modaliases = [] + for _ in range(size): + if buf.read_byte(): + k = buf.read_null_str().decode("utf-8") + else: + k = None + v = buf.read_null_str().decode("utf-8") + modaliases.append((k, v)) + modaliases = immutables.Map(modaliases) + else: + modaliases = None + + serializer = compilation_config_serializer + type_id = uuidgen.from_bytes(buf.read_bytes(16)) + if type_id == serializer.type_id: + buf.read_len_prefixed_bytes() + else: + serializer = sertypes.CompilationConfigSerializer( + type_id, buf.read_len_prefixed_bytes(), defines.CURRENT_PROTOCOL + ) + + data = buf.read_len_prefixed_bytes() + if data: + session_config = immutables.Map( + ( + k, + config.SettingValue( + name=k, + value=v, + source='session', + scope=qltypes.ConfigScope.SESSION, + ) + ) for k, v in serializer.decode(data).items() + ) + else: + session_config = None + + source = tokenizer.deserialize( + buf.read_len_prefixed_bytes(), query_text + ) + + cache_key = uuidgen.from_bytes(buf.read_bytes(16)) + schema_version = uuidgen.from_bytes(buf.read_bytes(16)) + + req = CompilationRequest( + source=source, + protocol_version=protocol_version, + schema_version=schema_version, + compilation_config_serializer=serializer, + output_format=output_format, + input_format=input_format, + expect_one=expect_one, + implicit_limit=implicit_limit, + inline_typeids=inline_typeids, + inline_typenames=inline_typenames, + inline_objectids=inline_objectids, + modaliases=modaliases, + session_config=session_config, + ) + + req.serialized_cache = data + req.cache_key = cache_key + + return req + + +cdef _serialize_comp_req_v0(req: CompilationRequest): + # Please see _deserialize_v0 for the format doc + + cdef: + char version = 0, flags + WriteBuffer out = WriteBuffer.new() + + out.write_byte(version) + + flags = ( + (MASK_JSON_PARAMETERS if req.input_format is IN_FMT_JSON else 0) | + (MASK_EXPECT_ONE if req.expect_one else 0) | + (MASK_INLINE_TYPEIDS if req.inline_typeids else 0) | + (MASK_INLINE_TYPENAMES if req.inline_typenames else 0) | + (MASK_INLINE_OBJECTIDS if req.inline_objectids else 0) + ) + out.write_byte(flags) + + out.write_int16(req.protocol_version[0]) + out.write_int16(req.protocol_version[1]) + out.write_byte(serialize_output_format(req.output_format)) + out.write_int64(req.implicit_limit) + + if req.modaliases is None: + out.write_int32(-1) + else: + out.write_int32(len(req.modaliases)) + for k, v in sorted( + req.modaliases.items(), + key=lambda i: (0, 
i[0]) if i[0] is None else (1, i[0]) + ): + if k is None: + out.write_byte(0) + else: + out.write_byte(1) + out.write_str(k, "utf-8") + out.write_str(v, "utf-8") + + type_id, desc = req._serializer.describe() + out.write_bytes(type_id.bytes) + out.write_len_prefixed_bytes(desc) + + hash_obj = hashlib.blake2b(memoryview(out), digest_size=16) + hash_obj.update(req.source.cache_key()) + + if req.session_config is None: + session_config = b"" + else: + session_config = req._serializer.encode_configs( + req.session_config + ) + out.write_len_prefixed_bytes(session_config) + + # Build config that affects compilation: session -> database -> system. + # This is only used for calculating cache_key, while session + # config itreq is separately stored above in the serialized format. + serialized_comp_config = req._serializer.encode_configs( + req.system_config, req.database_config, req.session_config + ) + hash_obj.update(serialized_comp_config) + + # Must set_schema_version() before serializing compilation request + assert req.schema_version is not None + hash_obj.update(req.schema_version.bytes) + + cache_key_bytes = hash_obj.digest() + cache_key = uuidgen.from_bytes(cache_key_bytes) + + out.write_len_prefixed_bytes(req.source.serialize()) + out.write_bytes(cache_key_bytes) + out.write_bytes(req.schema_version.bytes) + + return cache_key, out diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 2c6f3e18e32..ed376cf16ba 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -419,9 +419,11 @@ cdef class Database: def hydrate_cache(self, query_cache): for _, in_data, out_data in query_cache: - query_req = rpc.CompilationRequest( - self.server.compilation_config_serializer) - query_req.deserialize(in_data, "") + query_req = rpc.CompilationRequest.deserialize( + in_data, + "", + self.server.compilation_config_serializer, + ) if query_req not in self._eql_to_compiled: unit = dbstate.QueryUnit.deserialize(out_data) diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 5af1d14f626..72a8f265d9b 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -801,21 +801,23 @@ cdef class EdgeConnection(frontend.FrontendConnection): self.write(self.make_state_data_description_msg()) raise - rv = rpc.CompilationRequest(self.server.compilation_config_serializer) - rv.update( - self._tokenize(query), - self.protocol_version, + cfg_ser = self.server.compilation_config_serializer + rv = rpc.CompilationRequest( + source=self._tokenize(query), + protocol_version=self.protocol_version, + schema_version=_dbview.schema_version, + compilation_config_serializer=cfg_ser, output_format=output_format, expect_one=expect_one, implicit_limit=implicit_limit, inline_typeids=inline_typeids, inline_typenames=inline_typenames, inline_objectids=inline_objectids, - ).set_schema_version(_dbview.schema_version) - rv.set_modaliases(_dbview.get_modaliases()) - rv.set_session_config(_dbview.get_session_config()) - rv.set_database_config(_dbview.get_database_config()) - rv.set_system_config(_dbview.get_compilation_system_config()) + modaliases=_dbview.get_modaliases(), + session_config=_dbview.get_session_config(), + database_config=_dbview.get_database_config(), + system_config=_dbview.get_compilation_system_config(), + ) return rv, allow_capabilities async def parse(self): @@ -1430,12 +1432,13 @@ cdef class EdgeConnection(frontend.FrontendConnection): async def _execute_utility_stmt(self, eql: str, pgcon): cdef 
dbview.DatabaseConnectionView _dbview = self.get_dbview() + cfg_ser = self.server.compilation_config_serializer query_req = rpc.CompilationRequest( - self.server.compilation_config_serializer + source=edgeql.Source.from_string(eql), + protocol_version=self.protocol_version, + schema_version=_dbview.schema_version, + compilation_config_serializer=cfg_ser, ) - query_req.update( - edgeql.Source.from_string(eql), self.protocol_version - ).set_schema_version(_dbview.schema_version) compiled = await _dbview.parse(query_req) query_unit_group = compiled.query_unit_group @@ -1826,14 +1829,15 @@ async def run_script( await conn._start_connection(database) try: _dbview = conn.get_dbview() + cfg_ser = server.compilation_config_serializer compiled = await _dbview.parse( rpc.CompilationRequest( - server.compilation_config_serializer - ).update( - edgeql.Source.from_string(script), - conn.protocol_version, + source=edgeql.Source.from_string(script), + protocol_version=conn.protocol_version, + schema_version=_dbview.schema_version, + compilation_config_serializer=cfg_ser, output_format=FMT_NONE, - ).set_schema_version(_dbview.schema_version) + ), ) if len(compiled.query_unit_group) > 1: await conn._execute_script(compiled, b'') diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 2717aabea65..7814ed1495d 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -206,13 +206,13 @@ async def _parse( ) query_req = rpc.CompilationRequest( - db.server.compilation_config_serializer - ).update( - edgeql.Source.from_string(query), + source=edgeql.Source.from_string(query), protocol_version=edbdef.CURRENT_PROTOCOL, + schema_version=dbv.schema_version, + compilation_config_serializer=db.server.compilation_config_serializer, input_format=input_format, output_format=output_format, - ).set_schema_version(dbv.schema_version) + ) compiled = await dbv.parse( query_req, diff --git a/tests/test_server_compiler.py b/tests/test_server_compiler.py index d628999bd1c..0de77042ece 100644 --- a/tests/test_server_compiler.py +++ b/tests/test_server_compiler.py @@ -442,13 +442,14 @@ async def _test_pool_disconnect_queue(self, pool_class): ) orig_query = 'SELECT 123' + cfg_ser = compiler.state.compilation_config_serializer request = rpc.CompilationRequest( - compiler.state.compilation_config_serializer - ).update( source=edgeql.Source.from_string(orig_query), protocol_version=(1, 0), + schema_version=uuid.uuid4(), + compilation_config_serializer=cfg_ser, implicit_limit=101, - ).set_schema_version(uuid.uuid4()) + ) await asyncio.gather(*(pool_.compile_in_tx( None, @@ -476,15 +477,15 @@ def test_server_compiler_rpc_hash_eq(self): ) def test(source: edgeql.Source): + cfg_ser = compiler.state.compilation_config_serializer request1 = rpc.CompilationRequest( - compiler.state.compilation_config_serializer - ).update( source=source, protocol_version=(1, 0), - ).set_schema_version(uuid.uuid4()) - request2 = rpc.CompilationRequest( - compiler.state.compilation_config_serializer - ).deserialize(request1.serialize(), "") + schema_version=uuid.uuid4(), + compilation_config_serializer=cfg_ser, + ) + request2 = rpc.CompilationRequest.deserialize( + request1.serialize(), "", cfg_ser) self.assertEqual(hash(request1), hash(request2)) self.assertEqual(request1, request2) From fc43c044d9696311c1e310f59bc13b0c1d455c87 Mon Sep 17 00:00:00 2001 From: Scott Trinh Date: Tue, 22 Oct 2024 16:02:14 -0400 Subject: [PATCH 05/18] Call super setup/teardown test scripts (#7902) --- 
tests/test_http_ext_auth.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 90419421029..357e829b8e0 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -362,6 +362,7 @@ def setUp(self): self.mock_net_server = tb.MockHttpServer() self.mock_net_server.start() + super().setUp() def tearDown(self): if self.mock_oauth_server is not None: @@ -369,6 +370,7 @@ def tearDown(self): if self.mock_net_server is not None: self.mock_net_server.stop() self.mock_oauth_server = None + super().tearDown() @classmethod def get_setup_script(cls): From d3670b3a34a545e2f734da9717ca4d4a507539bb Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Tue, 22 Oct 2024 13:13:50 -0700 Subject: [PATCH 06/18] Reduce compiler RPC surface (#7905) The server is now mostly passing the compilation parameters in `CompilationRequest`, so convert the remaining users of the old `compile()` API to benefit from reduction in argument count. --- edb/edgeql/compiler/options.py | 4 +- edb/server/compiler/__init__.py | 2 + edb/server/compiler/compiler.py | 139 +++++++----------- edb/server/compiler/rpc.pyi | 4 +- .../compiler_pool/multitenant_worker.py | 34 +++-- edb/server/compiler_pool/worker.py | 34 +++-- tests/test_server_compiler.py | 7 +- 7 files changed, 107 insertions(+), 117 deletions(-) diff --git a/edb/edgeql/compiler/options.py b/edb/edgeql/compiler/options.py index 9dceb12fcc4..2a02abf0edb 100644 --- a/edb/edgeql/compiler/options.py +++ b/edb/edgeql/compiler/options.py @@ -37,7 +37,7 @@ SourceOrPathId = s_types.Type | s_pointers.Pointer | pathid.PathId -@dataclass +@dataclass(kw_only=True) class GlobalCompilerOptions: """Compiler toggles that affect compilation as a whole.""" @@ -102,7 +102,7 @@ class GlobalCompilerOptions: dump_restore_mode: bool = False -@dataclass +@dataclass(kw_only=True) class CompilerOptions(GlobalCompilerOptions): #: Module name aliases. 
diff --git a/edb/server/compiler/__init__.py b/edb/server/compiler/__init__.py index c9fc5c13421..77be86bae03 100644 --- a/edb/server/compiler/__init__.py +++ b/edb/server/compiler/__init__.py @@ -30,9 +30,11 @@ from .enums import InputFormat, OutputFormat from .explain import analyze_explain_output from .ddl import repair_schema +from .rpc import CompilationRequest __all__ = ( 'Cardinality', + 'CompilationRequest', 'Compiler', 'CompilerState', 'CompileContext', diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 94f4e2a9006..7cc04096f60 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -565,7 +565,7 @@ def compile_sql( apply_access_policies_sql=apply_access_policies_sql, ) - def compile_request( + def compile_serialized_request( self, user_schema: s_schema.Schema, global_schema: s_schema.Schema, @@ -584,48 +584,28 @@ def compile_request( ) units, cstate = self.compile( - user_schema, - global_schema, - reflection_cache, - database_config, - system_config, - request.source, - request.modaliases, - request.session_config, - request.output_format, - request.expect_one, - request.implicit_limit, - request.inline_typeids, - request.inline_typenames, - request.protocol_version, - request.inline_objectids, - request.input_format is enums.InputFormat.JSON, - cache_key=request.get_cache_key(), + user_schema=user_schema, + global_schema=global_schema, + reflection_cache=reflection_cache, + database_config=database_config, + system_config=system_config, + request=request, ) return units, cstate def compile( self, + *, user_schema: s_schema.Schema, global_schema: s_schema.Schema, reflection_cache: immutables.Map[str, Tuple[str, ...]], database_config: Optional[immutables.Map[str, config.SettingValue]], system_config: Optional[immutables.Map[str, config.SettingValue]], - source: edgeql.Source, - sess_modaliases: Optional[immutables.Map[Optional[str], str]], - sess_config: Optional[immutables.Map[str, config.SettingValue]], - output_format: enums.OutputFormat, - expect_one: bool, - implicit_limit: int, - inline_typeids: bool, - inline_typenames: bool, - protocol_version: defines.ProtocolVersion, - inline_objectids: bool = True, - json_parameters: bool = False, - cache_key: Optional[uuid.UUID] = None, + request: rpc.CompilationRequest, ) -> Tuple[dbstate.QueryUnitGroup, Optional[dbstate.CompilerConnectionState]]: + sess_config = request.session_config if sess_config is None: sess_config = EMPTY_MAP @@ -635,6 +615,7 @@ def compile( if system_config is None: system_config = EMPTY_MAP + sess_modaliases = request.modaliases if sess_modaliases is None: sess_modaliases = DEFAULT_MODULE_ALIASES_MAP @@ -651,19 +632,19 @@ def compile( ctx = CompileContext( compiler_state=self.state, state=state, - output_format=output_format, - expected_cardinality_one=expect_one, - implicit_limit=implicit_limit, - inline_typeids=inline_typeids, - inline_typenames=inline_typenames, - inline_objectids=inline_objectids, - json_parameters=json_parameters, - source=source, - protocol_version=protocol_version, - cache_key=cache_key, + output_format=request.output_format, + expected_cardinality_one=request.expect_one, + implicit_limit=request.implicit_limit, + inline_typeids=request.inline_typeids, + inline_typenames=request.inline_typenames, + inline_objectids=request.inline_objectids, + json_parameters=request.input_format is enums.InputFormat.JSON, + source=request.source, + protocol_version=request.protocol_version, + cache_key=request.get_cache_key(), ) - 
unit_group = compile(ctx=ctx, source=source) + unit_group = compile(ctx=ctx, source=request.source) tx_started = False for unit in unit_group: if unit.tx_id: @@ -675,7 +656,7 @@ def compile( else: return unit_group, None - def compile_in_tx_request( + def compile_serialized_request_in_tx( self, state: dbstate.CompilerConnectionState, txid: int, @@ -690,7 +671,23 @@ def compile_in_tx_request( original_query, self.state.compilation_config_serializer, ) + return self.compile_in_tx( + state=state, + txid=txid, + request=request, + expect_rollback=expect_rollback, + ) + def compile_in_tx( + self, + *, + state: dbstate.CompilerConnectionState, + txid: int, + request: rpc.CompilationRequest, + expect_rollback: bool = False, + ) -> Tuple[ + dbstate.QueryUnitGroup, Optional[dbstate.CompilerConnectionState] + ]: # Apply session differences if any if ( request.modaliases is not None @@ -703,39 +700,6 @@ def compile_in_tx_request( ): state.current_tx().update_session_config(session_config) - units, cstate = self.compile_in_tx( - state, - txid, - request.source, - request.output_format, - request.expect_one, - request.implicit_limit, - request.inline_typeids, - request.inline_typenames, - request.protocol_version, - request.inline_objectids, - request.input_format is enums.InputFormat.JSON, - expect_rollback=expect_rollback, - cache_key=request.get_cache_key(), - ) - return units, cstate - - def compile_in_tx( - self, - state: dbstate.CompilerConnectionState, - txid: int, - source: edgeql.Source, - output_format: enums.OutputFormat, - expect_one: bool, - implicit_limit: int, - inline_typeids: bool, - inline_typenames: bool, - protocol_version: defines.ProtocolVersion, - inline_objectids: bool = True, - json_parameters: bool = False, - expect_rollback: bool = False, - cache_key: Optional[uuid.UUID] = None, - ) -> Tuple[dbstate.QueryUnitGroup, dbstate.CompilerConnectionState]: if ( expect_rollback and state.current_tx().id != txid and @@ -743,27 +707,28 @@ def compile_in_tx( ): # This is a special case when COMMIT MIGRATION fails, the compiler # doesn't have the right transaction state, so we just roll back. 
- return self._try_compile_rollback(source)[0], state + return self._try_compile_rollback(request.source)[0], state else: state.sync_tx(txid) ctx = CompileContext( compiler_state=self.state, state=state, - output_format=output_format, - expected_cardinality_one=expect_one, - implicit_limit=implicit_limit, - inline_typeids=inline_typeids, - inline_typenames=inline_typenames, - inline_objectids=inline_objectids, - source=source, - protocol_version=protocol_version, - json_parameters=json_parameters, + output_format=request.output_format, + expected_cardinality_one=request.expect_one, + implicit_limit=request.implicit_limit, + inline_typeids=request.inline_typeids, + inline_typenames=request.inline_typenames, + inline_objectids=request.inline_objectids, + source=request.source, + protocol_version=request.protocol_version, + json_parameters=request.input_format is enums.InputFormat.JSON, expect_rollback=expect_rollback, - cache_key=cache_key, + cache_key=request.get_cache_key(), ) - return compile(ctx=ctx, source=source), ctx.state + units = compile(ctx=ctx, source=request.source) + return units, ctx.state def interpret_backend_error( self, diff --git a/edb/server/compiler/rpc.pyi b/edb/server/compiler/rpc.pyi index 94cd289e4cc..47b6d2fe464 100644 --- a/edb/server/compiler/rpc.pyi +++ b/edb/server/compiler/rpc.pyi @@ -44,6 +44,8 @@ class CompilationRequest: *, source: edgeql.Source, protocol_version: defines.ProtocolVersion, + schema_version: uuid.UUID, + compilation_config_serializer: sertypes.CompilationConfigSerializer, output_format: enums.OutputFormat = enums.OutputFormat.BINARY, input_format: enums.InputFormat = enums.InputFormat.BINARY, expect_one: bool = False, @@ -51,12 +53,10 @@ class CompilationRequest: inline_typeids: bool = False, inline_typenames: bool = False, inline_objectids: bool = True, - schema_version: uuid.UUID | None = None, modaliases: typing.Mapping[str | None, str] | None = None, session_config: typing.Mapping[str, config.SettingValue] | None = None, database_config: typing.Mapping[str, config.SettingValue] | None = None, system_config: typing.Mapping[str, config.SettingValue] | None = None, - compilation_config_serializer: sertypes.CompilationConfigSerializer, ): ... 
diff --git a/edb/server/compiler_pool/multitenant_worker.py b/edb/server/compiler_pool/multitenant_worker.py index 2fb3cc4cbda..105000467dc 100644 --- a/edb/server/compiler_pool/multitenant_worker.py +++ b/edb/server/compiler_pool/multitenant_worker.py @@ -28,6 +28,7 @@ from edb import graphql from edb.common import debug +from edb.common import uuidgen from edb.pgsql import params as pgparams from edb.schema import schema as s_schema from edb.server import compiler @@ -195,7 +196,7 @@ def compile( ): client_schema = clients[client_id] db = client_schema.dbs[dbname] - units, cstate = COMPILER.compile_request( + units, cstate = COMPILER.compile_serialized_request( db.user_schema, client_schema.global_schema, db.reflection_cache, @@ -225,6 +226,7 @@ def compile_in_tx( ): global LAST_STATE if cstate == state.REUSE_LAST_STATE_MARKER: + assert LAST_STATE is not None cstate = LAST_STATE else: cstate = pickle.loads(cstate) @@ -236,7 +238,8 @@ def compile_in_tx( client_schema = clients[client_id] db = client_schema.dbs[dbname] cstate.set_root_user_schema(db.user_schema) - units, cstate = COMPILER.compile_in_tx_request(cstate, *args, **kwargs) + units, cstate = COMPILER.compile_serialized_request_in_tx( + cstate, *args, **kwargs) LAST_STATE = cstate return units, pickle.dumps(cstate, -1) @@ -286,23 +289,30 @@ def compile_graphql( edgeql.generate_source(gql_op.edgeql_ast, pretty=True), ) - unit_group, _ = COMPILER.compile( - user_schema=db.user_schema, - global_schema=client_schema.global_schema, - reflection_cache=db.reflection_cache, - database_config=db.database_config, - system_config=client_schema.instance_config, + cfg_ser = COMPILER.state.compilation_config_serializer + request = compiler.CompilationRequest( source=source, - sess_modaliases=None, - sess_config=None, + protocol_version=defines.CURRENT_PROTOCOL, + schema_version=uuidgen.uuid4(), + compilation_config_serializer=cfg_ser, output_format=compiler.OutputFormat.JSON, + input_format=compiler.InputFormat.JSON, expect_one=True, implicit_limit=0, inline_typeids=False, inline_typenames=False, inline_objectids=False, - json_parameters=True, - protocol_version=defines.CURRENT_PROTOCOL, + modaliases=None, + session_config=None, + ) + + unit_group, _ = COMPILER.compile( + user_schema=db.user_schema, + global_schema=client_schema.global_schema, + reflection_cache=db.reflection_cache, + database_config=db.database_config, + system_config=client_schema.instance_config, + request=request, ) return unit_group, gql_op diff --git a/edb/server/compiler_pool/worker.py b/edb/server/compiler_pool/worker.py index 2e9ac518c62..99e4708e5b1 100644 --- a/edb/server/compiler_pool/worker.py +++ b/edb/server/compiler_pool/worker.py @@ -26,6 +26,7 @@ from edb import edgeql from edb import graphql +from edb.common import uuidgen from edb.pgsql import params as pgparams from edb.schema import schema as s_schema from edb.server import compiler @@ -175,7 +176,7 @@ def compile( system_config, ) - units, cstate = COMPILER.compile_request( + units, cstate = COMPILER.compile_serialized_request( db.user_schema, GLOBAL_SCHEMA, db.reflection_cache, @@ -199,6 +200,7 @@ def compile_in_tx( ): global LAST_STATE if cstate == state.REUSE_LAST_STATE_MARKER: + assert LAST_STATE is not None cstate = LAST_STATE else: cstate = pickle.loads(cstate) @@ -207,7 +209,8 @@ def compile_in_tx( cstate.set_root_user_schema(pickle.loads(user_schema)) else: cstate.set_root_user_schema(DBS[dbname].user_schema) - units, cstate = COMPILER.compile_in_tx_request(cstate, *args, **kwargs) + units, 
cstate = COMPILER.compile_serialized_request_in_tx( + cstate, *args, **kwargs) LAST_STATE = cstate return units, pickle.dumps(cstate, -1) @@ -275,23 +278,30 @@ def compile_graphql( edgeql.generate_source(gql_op.edgeql_ast, pretty=True), ) - unit_group, _ = COMPILER.compile( - user_schema=db.user_schema, - global_schema=GLOBAL_SCHEMA, - reflection_cache=db.reflection_cache, - database_config=db.database_config, - system_config=INSTANCE_CONFIG, + cfg_ser = COMPILER.state.compilation_config_serializer + request = compiler.CompilationRequest( source=source, - sess_modaliases=None, - sess_config=None, + protocol_version=defines.CURRENT_PROTOCOL, + schema_version=uuidgen.uuid4(), + compilation_config_serializer=cfg_ser, output_format=compiler.OutputFormat.JSON, + input_format=compiler.InputFormat.JSON, expect_one=True, implicit_limit=0, inline_typeids=False, inline_typenames=False, inline_objectids=False, - json_parameters=True, - protocol_version=defines.CURRENT_PROTOCOL, + modaliases=None, + session_config=None, + ) + + unit_group, _ = COMPILER.compile( + user_schema=db.user_schema, + global_schema=GLOBAL_SCHEMA, + reflection_cache=db.reflection_cache, + database_config=db.database_config, + system_config=INSTANCE_CONFIG, + request=request, ) return unit_group, gql_op diff --git a/tests/test_server_compiler.py b/tests/test_server_compiler.py index 0de77042ece..c4f8f6ebfa5 100644 --- a/tests/test_server_compiler.py +++ b/tests/test_server_compiler.py @@ -33,6 +33,7 @@ from edb import edgeql from edb.testbase import lang as tb from edb.testbase import server as tbs +from edb.pgsql import params as pg_params from edb.server import args as edbargs from edb.server import compiler as edbcompiler from edb.server.compiler import rpc @@ -392,14 +393,16 @@ def setUpClass(cls): super().setUpClass() cls._std_schema = tb._load_std_schema() result = tb._load_reflection_schema() - cls._refl_schema, cls._schema_class_layout = result + cls._refl_schema, _schema_class_layout = result + assert _schema_class_layout is not None + cls._schema_class_layout = _schema_class_layout async def _test_pool_disconnect_queue(self, pool_class): with tempfile.TemporaryDirectory() as td: pool_ = await pool.create_compiler_pool( runstate_dir=td, pool_size=2, - backend_runtime_params=None, + backend_runtime_params=pg_params.get_default_runtime_params(), std_schema=self._std_schema, refl_schema=self._refl_schema, schema_class_layout=self._schema_class_layout, From 024c744097d632729883cbae84e9082c6fbe06cd Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 22 Oct 2024 13:34:39 -0700 Subject: [PATCH 07/18] Reduce isolation level on the GC query also (#7900) --- edb/server/net_worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/edb/server/net_worker.py b/edb/server/net_worker.py index 9038b3c230b..5b0ad51edca 100644 --- a/edb/server/net_worker.py +++ b/edb/server/net_worker.py @@ -350,6 +350,7 @@ def _warn(e): """, variables={"expires_in": expires_in.to_backend_str()}, cached_globally=True, + tx_isolation=defines.TxIsolationLevel.RepeatableRead, ) if len(result) > 0: logger.info(f"Deleted requests: {result!r}") From a131a2b5c3fa7eacd71472867f4388a9f8907714 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 22 Oct 2024 13:35:46 -0700 Subject: [PATCH 08/18] Make sure that simple inlined INSERTs don't select from the table (#7906) The code we were generating for inlined calls to a function containing an INSERT was selecting from the table (overlayed with the INSERT) after performing the insert. 
This is not incorrect, exactly, but it's suboptimal, and we think that this might be the source of the mysterious transaction serialization errors that we have been seeing in the std::net tests. The source of the issue is that the function call was only producing a `value` aspect, but to select properties out of it a `source` aspect was needed, and so `ensure_source_rvar` joins one back in. Fix it by making sure we can output a source aspect from an inlined function call. --- edb/pgsql/compiler/relgen.py | 14 +++++++- tests/test_edgeql_sql_codegen.py | 59 ++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index 8f51168255f..c0e0390fe8b 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -3323,7 +3323,13 @@ def _compile_func_epilogue( aspect=pgce.PathAspect.VALUE, ) - aspects: Tuple[pgce.PathAspect, ...] = (pgce.PathAspect.VALUE,) + aspects: Tuple[pgce.PathAspect, ...] + if expr.body: + # For inlined functions, we want all of the aspects provided. + aspects = tuple(pathctx.list_path_aspects(func_rel, ir_set.path_id)) + else: + # Otherwise we just know we have value. + aspects = (pgce.PathAspect.VALUE,) func_rvar = relctx.new_rel_rvar(ir_set, func_rel, ctx=ctx) relctx.include_rvar( @@ -3605,6 +3611,12 @@ def process_set_as_func_expr( _compile_inlined_call_args(ir_set, ctx=newctx) set_expr = dispatch.compile(expr.body, ctx=newctx) + # Map the path id so that we can extract source aspects + # from it, which we want so that we can directly select + # from an INSERT instead of using overlays. + pathctx.put_path_id_map( + newctx.rel, ir_set.path_id, expr.body.path_id + ) else: args = _compile_call_args(ir_set, ctx=newctx) diff --git a/tests/test_edgeql_sql_codegen.py b/tests/test_edgeql_sql_codegen.py index 3367889fd43..0020b93f009 100644 --- a/tests/test_edgeql_sql_codegen.py +++ b/tests/test_edgeql_sql_codegen.py @@ -42,6 +42,18 @@ class TestEdgeQLSQLCodegen(tb.BaseEdgeQLCompilerTest): SCHEMA_cards = os.path.join(os.path.dirname(__file__), 'schemas', 'cards.esdl') + @classmethod + def get_schema_script(cls): + script = super().get_schema_script() + # Setting internal params like is_inlined in the schema + # doesn't work right so we override the script to add DDL. + return script + ''' + create function cards::ins_bot(name: str) -> cards::Bot { + set is_inlined := true; + using (insert cards::Bot { name := "asdf" }); + }; + ''' + def _compile_to_tree(self, source): qltree = qlparser.parse_query(source) ir = compiler.compile_ast_to_ir( @@ -461,3 +473,50 @@ def test_codegen_unless_conflict_03(self): "ON CONFLICT", sql, "insert unless conflict not using ON CONFLICT" ) + + def test_codegen_inlined_insert_01(self): + # Test that we don't use an overlay when selecting from a + # simple function that does an INSERT. + sql = self._compile(''' + WITH MODULE cards + select ins_bot("asdf") { id, name } + ''') + + table_obj = self.schema.get("cards::Bot") + count = sql.count(str(table_obj.id)) + # The table should only be referenced once, in the INSERT. + # If we reference it more than that, we're probably selecting it. 
+ self.assertEqual( + count, + 1, + f"Bot selected from and not just inserted: {sql}") + + def test_codegen_inlined_insert_02(self): + # Test that we don't use an overlay when selecting from a + # net::http::schedule_request + sql = self._compile(''' + with + nh as module std::net::http, + url := $url, + request := ( + nh::schedule_request( + url, + method := nh::Method.`GET` + ) + ) + select request { + id, + state, + failure, + response, + } + ''') + + table_obj = self.schema.get("std::net::http::ScheduledRequest") + count = sql.count(str(table_obj.id)) + # The table should only be referenced once, in the INSERT. + # If we reference it more than that, we're probably selecting it. + self.assertEqual( + count, + 1, + f"ScheduledRequest selected from and not just inserted: {sql}") From 4a292d32d1ad678edb57978c60c460c796f3eb2d Mon Sep 17 00:00:00 2001 From: dnwpark Date: Tue, 22 Oct 2024 19:18:42 -0400 Subject: [PATCH 09/18] Fix error when dropping non overloaded function. (#7899) Related #7692 --- edb/pgsql/delta.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index 56c7db1de4a..c5548e45f0d 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -1629,13 +1629,16 @@ def apply( if not overload: variadic = func.get_params(schema).find_variadic(schema) - self.pgops.add( - dbops.DropFunction( - name=self.get_pgname(func, schema), - args=self.compile_args(func, schema), - has_variadic=variadic is not None, + if func.get_volatility(schema) != ql_ft.Volatility.Modifying: + # Modifying functions are not compiled. + # See: compile_edgeql_function + self.pgops.add( + dbops.DropFunction( + name=self.get_pgname(func, schema), + args=self.compile_args(func, schema), + has_variadic=variadic is not None, + ) ) - ) return super().apply(schema, context) From 275d0e44a2351410af674b7f8266a2a360b4d3f9 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 22 Oct 2024 16:50:21 -0700 Subject: [PATCH 10/18] Revert "Use new connection pool by default (#7706)" (#7909) This reverts commit eda72c0441d91ca700592a52832f964cd75375e4. There was a logic error in this commit, and so it did not actually make the new connection pool the default. I tried making it the default properly, and hit a number of new test failures, so for now let's just clarify the situation. --- edb/server/connpool/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/edb/server/connpool/__init__.py b/edb/server/connpool/__init__.py index 6ce2fa62bd7..dda78b56bd7 100644 --- a/edb/server/connpool/__init__.py +++ b/edb/server/connpool/__init__.py @@ -22,7 +22,7 @@ # During the transition period we allow for the pool to be swapped out. The # current default is to use the old pool, however this will be switched to use # the new pool once we've fully implemented all required features. -if os.environ.get("EDGEDB_USE_NEW_CONNPOOL", "") == "0": +if os.environ.get("EDGEDB_USE_NEW_CONNPOOL", "") == "1": Pool = Pool2Impl Pool2 = Pool1Impl else: From 107408766feb8daad302566c32286fea18ffabb0 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Tue, 22 Oct 2024 18:15:50 -0700 Subject: [PATCH 11/18] Remove Postgres NOTICE spew generated by `DROP FUNCTION IF EXISTS` (#7911) Use an explicit conditional when using `DropFunction` in various places of bootstrap and for reflection cache functions. This makes it actually possible to NOTICE postgres messages in bootstrap (pun intended). 
Equipped with this new visibility, fix the name of the indexes we place
on properties.
---
 edb/pgsql/delta.py              |  4 ++--
 edb/pgsql/deltadbops.py         |  8 +++++++-
 edb/server/compiler/compiler.py |  9 ++++++++-
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py
index c5548e45f0d..8fdc26c87e8 100644
--- a/edb/pgsql/delta.py
+++ b/edb/pgsql/delta.py
@@ -5373,10 +5373,10 @@ def _create_table(
             id = sn.QualName(
                 module=prop.get_name(schema).module, name=str(prop.id))
 
-            index_name = common.convert_name(id, 'idx0', catenate=True)
+            index_name = common.convert_name(id, 'idx0', catenate=False)
 
             pg_index = dbops.Index(
-                name=index_name, table_name=new_table_name,
+                name=index_name[1], table_name=new_table_name,
                 unique=False, columns=[src_col],
                 metadata={'code': DEFAULT_INDEX_CODE},
             )
diff --git a/edb/pgsql/deltadbops.py b/edb/pgsql/deltadbops.py
index c6e0806e7f5..a9bbf3d2399 100644
--- a/edb/pgsql/deltadbops.py
+++ b/edb/pgsql/deltadbops.py
@@ -485,7 +485,13 @@ def create_constr_trigger_function(
         return [dbops.CreateFunction(func, or_replace=True)]
 
     def drop_constr_trigger_function(self, proc_name: Tuple[str, ...]):
-        return [dbops.DropFunction(name=proc_name, args=(), if_exists=True)]
+        return [dbops.DropFunction(
+            name=proc_name,
+            args=(),
+            # Use a condition instead of if_exists to reduce annoying
+            # debug spew from postgres.
+            conditions=[dbops.FunctionExists(name=proc_name, args=())],
+        )]
 
     def create_constraint(self, constraint: SchemaConstraintTableConstraint):
         # Add the constraint normally to our table
diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py
index 7cc04096f60..38bd694117c 100644
--- a/edb/server/compiler/compiler.py
+++ b/edb/server/compiler/compiler.py
@@ -1243,7 +1243,14 @@ def compile_schema_storage_in_delta(
                 # We drop first instead of using or_replace, in case
                 # something about the arguments changed.
                 df = pg_dbops.DropFunction(
-                    name=func.name, args=func.args or (), if_exists=True
+                    name=func.name,
+                    args=func.args or (),
+                    # Use a condition instead of if_exists to reduce annoying
+                    # debug spew from postgres.
+                    conditions=[pg_dbops.FunctionExists(
+                        name=func.name,
+                        args=func.args or (),
+                    )],
                 )
 
                 df.generate(funcblock)

From 15f80880631e0113eb1b478cc437cf485f7eb079 Mon Sep 17 00:00:00 2001
From: "Michael J. Sullivan"
Date: Tue, 22 Oct 2024 19:10:10 -0700
Subject: [PATCH 12/18] Try to fix some of the DROP DATABASE flakes (#7907)

I think one of the sources of trouble might be that _early_introspection
is hanging on to connections and can't be forced to drop them.

Try to work around it by forcing them closed at the end of
introspecting.

Also, when in dev/test mode, dump out all the info about the other
connections. Hopefully this will let us figure out where they are
coming from.
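For reference, the diagnostic this enables boils down to asking
pg_stat_activity for the full activity rows of everything else attached
to the database. A minimal out-of-tree sketch, assuming asyncpg and a
placeholder DSN and database name; none of this is code from the patch:

    import asyncio
    import json

    import asyncpg

    async def dump_competing_connections(dsn: str, dbname: str) -> None:
        conn = await asyncpg.connect(dsn)
        try:
            rows = await conn.fetch(
                """
                SELECT row_to_json(pg_stat_activity) AS info
                FROM pg_stat_activity
                WHERE datname = $1 AND pid != pg_backend_pid()
                """,
                dbname,
            )
            # Each row carries the whole activity record (pid, query,
            # application_name, client_addr, state, ...), which is the
            # information the improved error message embeds in dev/test
            # mode.
            for row in rows:
                print(json.dumps(json.loads(row["info"]), indent=2))
        finally:
            await conn.close()

    if __name__ == "__main__":
        asyncio.run(dump_competing_connections(
            "postgresql://localhost/postgres", "mydb"))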
---
 edb/server/tenant.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/edb/server/tenant.py b/edb/server/tenant.py
index 7e2c6528afc..b063a1e3ab1 100644
--- a/edb/server/tenant.py
+++ b/edb/server/tenant.py
@@ -993,7 +993,7 @@ async def _pg_ensure_database_not_connected(self, dbname: str) -> None:
             conns = await pgcon.sql_fetch_col(
                 b"""
                 SELECT
-                    pid
+                    row_to_json(pg_stat_activity)
                 FROM
                     pg_stat_activity
                 WHERE
@@ -1003,8 +1003,14 @@ async def _pg_ensure_database_not_connected(self, dbname: str) -> None:
             )
 
         if conns:
+            debug_info = ""
+            if self.server.in_dev_mode() or self.server.in_test_mode():
+                jconns = [json.loads(conn) for conn in conns]
+                debug_info = ": " + json.dumps(jconns)
+
             raise errors.ExecutionError(
-                f"database branch {dbname!r} is being accessed by other users"
+                f"database branch {dbname!r} is being accessed by "
+                f"other users{debug_info}"
             )
 
     @contextlib.asynccontextmanager
@@ -1222,6 +1228,18 @@ async def _early_introspect_db(self, dbname: str) -> None:
                     early=True,
                 )
 
+            # Early introspection runs *before* we start accepting tasks.
+            # This means that if we are one of multiple frontends, and we
+            # get an ensure-database-not-used message, we aren't able to
+            # handle it. This can result in us hanging onto a connection
+            # that another frontend wants to get rid of.
+            #
+            # We still want to use the pool, though, since it limits our
+            # connections in the way we want.
+            #
+            # Hack around this by pruning the connection ourselves.
+            await self._pg_pool.prune_inactive_connections(dbname)
+
     async def _introspect_dbs(self) -> None:
         async with self.use_sys_pgcon() as syscon:
             dbnames = await self._server.get_dbnames(syscon)

From 2e5d2be4457d3a307ac678f117d86857382276f1 Mon Sep 17 00:00:00 2001
From: "Michael J. Sullivan"
Date: Tue, 22 Oct 2024 19:12:14 -0700
Subject: [PATCH 13/18] Revert "Add retries to std::net tests (#7898)" (#7908)

This reverts commit a27ff6ac91d2c7432e0b2481104f662403327fad.

I'll run the nightlies against this and see how it goes.
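For context, the idiom being reverted wrapped each statement in the
test base's retry loop. A rough reconstruction (names taken from the
diff below; `self` is assumed to be a test case exposing
try_until_succeeds, and this helper is hypothetical, not the exact
test code):

    from edb import errors

    async def query_with_retries(self, query: str, **variables):
        # Re-run the statement until it commits without losing a
        # serialization race; give up once the timeout is exceeded.
        async for tr in self.try_until_succeeds(
            delay=2, timeout=120,
            ignore=(errors.TransactionSerializationError,),
        ):
            async with tr:
                return await self.con.query_single(query, **variables)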
--- tests/test_http_std_net.py | 135 +++++++++++++++++-------------------- 1 file changed, 63 insertions(+), 72 deletions(-) diff --git a/tests/test_http_std_net.py b/tests/test_http_std_net.py index c0cdf2a4196..ced7103775a 100644 --- a/tests/test_http_std_net.py +++ b/tests/test_http_std_net.py @@ -20,7 +20,6 @@ import typing import json -from edb import errors from edb.testbase import http as tb @@ -100,34 +99,30 @@ async def test_http_std_net_con_schedule_request_get_01(self): ) ) - async for tr in self.try_until_succeeds( - delay=2, timeout=120, ignore=(errors.TransactionSerializationError,) - ): - async with tr: - result = await self.con.query_single( - """ - with - nh as module std::net::http, - net as module std::net, - url := $url, - request := ( - insert nh::ScheduledRequest { - created_at := datetime_of_statement(), - updated_at := datetime_of_statement(), - state := std::net::RequestState.Pending, + result = await self.con.query_single( + """ + with + nh as module std::net::http, + net as module std::net, + url := $url, + request := ( + insert nh::ScheduledRequest { + created_at := datetime_of_statement(), + updated_at := datetime_of_statement(), + state := std::net::RequestState.Pending, - url := url, - method := nh::Method.`GET`, - headers := [ - ("Accept", "application/json"), - ("x-test-header", "test-value"), - ], - } - ) - select request {*}; - """, - url=url, + url := url, + method := nh::Method.`GET`, + headers := [ + ("Accept", "application/json"), + ("x-test-header", "test-value"), + ], + } ) + select request {*}; + """, + url=url, + ) requests_for_example = None async for tr in self.try_until_succeeds( @@ -177,32 +172,29 @@ async def test_http_std_net_con_schedule_request_post_01(self): ) ) - async for tr in self.try_until_succeeds( - delay=2, timeout=120, ignore=(errors.TransactionSerializationError,) - ): - async with tr: - result = await self.con.query_single( - """ - with - nh as module std::net::http, - url := $url, - body := $body, - request := ( - nh::schedule_request( - url, - method := nh::Method.POST, - headers := [ - ("Accept", "application/json"), - ("x-test-header", "test-value"), - ], - body := body, - ) - ) - select request {*}; - """, - url=url, - body=b"Hello, world!", + result = await self.con.query_single( + """ + with + nh as module std::net::http, + net as module std::net, + url := $url, + body := $body, + request := ( + nh::schedule_request( + url, + method := nh::Method.POST, + headers := [ + ("Accept", "application/json"), + ("x-test-header", "test-value"), + ], + body := body, + ) ) + select request {*}; + """, + url=url, + body=b"Hello, world!", + ) requests_for_example = None async for tr in self.try_until_succeeds( @@ -236,27 +228,26 @@ async def test_http_std_net_con_schedule_request_bad_address(self): # Test a request to a known-bad address bad_url = "http://256.256.256.256" - async for tr in self.try_until_succeeds( - delay=2, timeout=120, ignore=(errors.TransactionSerializationError,) - ): - async with tr: - result = await self.con.query_single( - """ - with - nh as module std::net::http, - url := $url, - request := ( - nh::schedule_request(url) - ) - select request { - id, - state, - failure, - response, - }; - """, - url=bad_url, + result = await self.con.query_single( + """ + with + nh as module std::net::http, + url := $url, + request := ( + nh::schedule_request( + url, + method := nh::Method.`GET` + ) ) + select request { + id, + state, + failure, + response, + }; + """, + url=bad_url, + ) table_result = await 
self._wait_for_request_completion(result.id)
        self.assertEqual(str(table_result.state), 'Failed')

From 037886628b61c2569c6bc294ebd34c6188c8f806 Mon Sep 17 00:00:00 2001
From: dnwpark
Date: Tue, 22 Oct 2024 22:13:18 -0400
Subject: [PATCH 14/18] Sort migrations in tests before checking data. (#7912)

---
 tests/test_edgeql_ddl.py | 129 ++++++++++++++++++++++++---------------
 1 file changed, 81 insertions(+), 48 deletions(-)

diff --git a/tests/test_edgeql_ddl.py b/tests/test_edgeql_ddl.py
index 969f435c806..ac9713639d9 100644
--- a/tests/test_edgeql_ddl.py
+++ b/tests/test_edgeql_ddl.py
@@ -13664,6 +13664,35 @@ async def test_edgeql_ddl_errors_03(self):
                 DROP FUNCTION foo___1(a: int64);
             ''')
 
+    @staticmethod
+    def order_migrations(migrations):
+        # Migrations are implicitly ordered based on parent_id.
+        # For now, assume that there is a single "initial" migration
+        # and all migrations have at most one child.
+
+        # Find the initial migration with no parents.
+        ordered = [
+            migration
+            for migration in migrations
+            if not migration['parents']
+        ]
+
+        # Repeatedly find descendants until no more can be found.
+        prev_ids = [migration['id'] for migration in ordered]
+        while prev_ids:
+            curr_migrations = [
+                migration
+                for migration in migrations
+                if any(
+                    parent['id'] in prev_ids
+                    for parent in migration['parents']
+                )
+            ]
+            ordered.extend(curr_migrations)
+            prev_ids = [migration['id'] for migration in curr_migrations]
+
+        return ordered
+
     async def test_edgeql_ddl_migration_sdl_01(self):
         await self.con.execute('''
             CONFIGURE SESSION SET store_migration_sdl :=
@@ -13696,42 +13725,12 @@ async def test_edgeql_ddl_migration_sdl_01(self):
             drop type A;
         ''')
 
-        # Migrations implicitly ordered based on parent_id.
-        # Fetch the migrations and sort them here.
-        migrations = json.loads(await self.con.query_json(
-            'select schema::Migration { id, parent_ids := .parents.id, sdl }'
-        ))
-
-        # Find migrations with no parents.
-        sdl = [
-            migration['sdl']
-            for migration in migrations
-            if not migration['parent_ids']
-        ]
-
-        # Repeatedly find descendents until no more can be found.
- prev_ids = [ - migration['id'] - for migration in migrations - if not migration['parent_ids'] - ] - while prev_ids: - sdl.extend( - migration['sdl'] - for migration in migrations - if any( - parent_id in prev_ids - for parent_id in migration['parent_ids'] - ) - ) - prev_ids = [ - migration['id'] - for migration in migrations - if any( - parent_id in prev_ids - for parent_id in migration['parent_ids'] - ) - ] + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { id, parents: { id }, sdl } + ''')) + ) + sdl = [migration['sdl'] for migration in migrations] self.assert_data_shape( sdl, @@ -13823,8 +13822,15 @@ async def test_edgeql_ddl_create_migration_01(self): }] ) - await self.assert_query_result( - 'select schema::Migration { script, sdl }', + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { + id, parents: { id }, script, sdl + } + ''')) + ) + self.assert_data_shape( + migrations, [ { 'script': ( @@ -13871,8 +13877,15 @@ async def test_edgeql_ddl_create_migration_02(self): }; ''') - await self.assert_query_result( - 'select schema::Migration { script, sdl }', + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { + id, parents: { id }, script, sdl + } + ''')) + ) + self.assert_data_shape( + migrations, [ { 'script': ( @@ -13933,10 +13946,20 @@ async def test_edgeql_ddl_create_migration_03(self): }; ''') - await self.assert_query_result( - ''' - SELECT schema::Migration { message, generated_by, script, sdl } - ''', + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { + id, + parents: { id }, + message, + generated_by, + script, + sdl, + } + ''')) + ) + self.assert_data_shape( + migrations, [ { 'generated_by': 'DevMode', @@ -13964,10 +13987,20 @@ async def test_edgeql_ddl_create_migration_03(self): CREATE TYPE Type3 ''') - await self.assert_query_result( - ''' - SELECT schema::Migration { message, generated_by, script, sdl } - ''', + migrations = TestEdgeQLDDL.order_migrations( + json.loads(await self.con.query_json(''' + select schema::Migration { + id, + parents: { id }, + message, + generated_by, + script, + sdl, + } + ''')) + ) + self.assert_data_shape( + migrations, [ { 'generated_by': 'DevMode', From b57279343a635a6438ea2595e0ac14cf42964ea8 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 23 Oct 2024 07:58:50 -0700 Subject: [PATCH 15/18] Fix older postgres versions (#7892) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This still *leaves out* fields that were removed in postgres 17, though. They seem pretty marginal, but it's worth discussing what we really want to do there. I think this approach is good enough to ship an alpha with, though. 
--------- Co-authored-by: Aljaž Mur Eržen --- edb/pgsql/metaschema.py | 12 ++++++------ edb/server/bootstrap.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 49e45f7c7ee..98bd3f41e28 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -6479,6 +6479,7 @@ def _generate_sql_information_schema( SELECT attrelid, attname, atttypid, + attstattarget, attlen, attnum, attnum as attnum_internal, @@ -6486,8 +6487,8 @@ def _generate_sql_information_schema( attcacheoff, atttypmod, attbyval, - attalign, attstorage, + attalign, attnotnull, atthasdef, atthasmissing, @@ -6497,7 +6498,6 @@ def _generate_sql_information_schema( attislocal, attinhcount, attcollation, - attstattarget, attacl, attoptions, attfdwoptions, @@ -6522,6 +6522,7 @@ def _generate_sql_information_schema( SELECT pc_oid as attrelid, col_name as attname, COALESCE(atttypid, 25) as atttypid, -- defaults to TEXT + COALESCE(attstattarget, -1) as attstattarget, COALESCE(attlen, -1) as attlen, (ROW_NUMBER() OVER ( PARTITION BY pc_oid @@ -6532,8 +6533,8 @@ def _generate_sql_information_schema( COALESCE(attcacheoff, -1) as attcacheoff, COALESCE(atttypmod, -1) as atttypmod, COALESCE(attbyval, FALSE) as attbyval, - COALESCE(attalign, 'i') as attalign, COALESCE(attstorage, 'x') as attstorage, + COALESCE(attalign, 'i') as attalign, required as attnotnull, -- Always report no default, to avoid expr trouble false as atthasdef, @@ -6544,7 +6545,6 @@ def _generate_sql_information_schema( COALESCE(attislocal, TRUE) as attislocal, COALESCE(attinhcount, 0) as attinhcount, COALESCE(attcollation, 0) as attcollation, - COALESCE(attstattarget, -1) as attstattarget, attacl, attoptions, attfdwoptions, @@ -6645,14 +6645,15 @@ def _generate_sql_information_schema( attrelid, attname, atttypid, + attstattarget, attlen, attnum, attndims, attcacheoff, atttypmod, attbyval, - attalign, attstorage, + attalign, attnotnull, atthasdef, atthasmissing, @@ -6662,7 +6663,6 @@ def _generate_sql_information_schema( attislocal, attinhcount, attcollation, - attstattarget, attacl, attoptions, attfdwoptions, diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index cb90c125220..0a118c0bcdf 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -1135,6 +1135,9 @@ async def create_branch( elif line.startswith('CREATE TYPE'): if any(skip in line for skip in to_skip): skipping = True + elif line == 'SET transaction_timeout = 0;': + continue + if skipping: continue new_lines.append(line) @@ -1535,6 +1538,15 @@ def cleanup_tpldbdump(tpldbdump: bytes) -> bytes: flags=re.MULTILINE, ) + # PostgreSQL 17 adds a transaction_timeout config setting that + # didn't exist on earlier versions. + tpldbdump = re.sub( + rb'^SET transaction_timeout = 0;$', + rb'', + tpldbdump, + flags=re.MULTILINE, + ) + return tpldbdump @@ -2144,7 +2156,6 @@ async def _populate_misc_instance_data( json.dumps(json_single_role_metadata), ) - assert backend_params.has_create_database if not backend_params.has_create_database: await _store_static_json_cache( ctx, From d01b12fac8b7bb104455b600fb2888d723f42b65 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Wed, 23 Oct 2024 11:49:48 -0600 Subject: [PATCH 16/18] Migrate OAuth HTTP code to new worker (#7883) This migrates the auth extensions from httpx to the new Rust HTTP handler. We add a new, shared HttpClient to the tenant which anyone can make use of. 
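As a usage sketch (names come from the new module in the diff below;
the tenant object and issuer URL are placeholders, and this helper is
hypothetical, not part of the patch):

    async def fetch_openid_config(tenant) -> dict:
        # The client is created lazily and shared across the tenant.
        client = tenant.get_http_client()
        resp = await client.get(
            "https://accounts.example.com/.well-known/openid-configuration",
            headers={"Accept": "application/json"},
        )
        if resp.status_code != 200:
            raise RuntimeError(f"unexpected status: {resp.status_code}")
        # Response.json() decodes the body; header lookup goes through
        # CaseInsensitiveDict, so casing does not matter.
        return resp.json()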
--- edb/server/http.py | 191 ++++++++++++++++++++ edb/server/net_worker.py | 74 +------- edb/server/protocol/auth_ext/base.py | 2 +- edb/server/protocol/auth_ext/http.py | 39 +++- edb/server/protocol/auth_ext/http_client.py | 72 ++++---- edb/server/protocol/auth_ext/oauth.py | 16 +- edb/server/tenant.py | 13 ++ edb/testbase/http.py | 14 +- tests/test_http_ext_auth.py | 10 +- 9 files changed, 302 insertions(+), 129 deletions(-) create mode 100644 edb/server/http.py diff --git a/edb/server/http.py b/edb/server/http.py new file mode 100644 index 00000000000..a01c2f82899 --- /dev/null +++ b/edb/server/http.py @@ -0,0 +1,191 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +from typing import ( + Tuple, + Any, + Mapping, + Optional, +) + +import asyncio +import dataclasses +import logging +import os +import json as json_lib +import urllib.parse + +from edb.server._http import Http + +logger = logging.getLogger("edb.server") + + +class HttpClient: + def __init__(self, limit: int): + self._client = Http(limit) + self._fd = self._client._fd + self._task = None + self._skip_reads = 0 + self._loop = asyncio.get_running_loop() + self._task = self._loop.create_task(self._boot(self._loop)) + self._next_id = 0 + self._requests: dict[int, asyncio.Future] = {} + + def __del__(self) -> None: + if self._task: + self._task.cancel() + self._task = None + + def _update_limit(self, limit: int): + self._client._update_limit(limit) + + async def request( + self, + *, + method: str, + url: str, + content: bytes | None, + headers: list[tuple[str, str]] | None, + ) -> tuple[int, bytes, dict[str, str]]: + if content is None: + content = bytes() + if headers is None: + headers = [] + id = self._next_id + self._next_id += 1 + self._requests[id] = asyncio.Future() + try: + self._client._request(id, url, method, content, headers) + resp = await self._requests[id] + return resp + finally: + del self._requests[id] + + async def get( + self, path: str, *, headers: dict[str, str] | None = None + ) -> Response: + headers_list = [(k, v) for k, v in headers.items()] if headers else None + result = await self.request( + method="GET", url=path, content=None, headers=headers_list + ) + return Response.from_tuple(result) + + async def post( + self, + path: str, + *, + headers: dict[str, str] | None = None, + data: bytes | str | dict[str, str] | None = None, + json: Any | None = None, + ) -> Response: + if json is not None: + data = json_lib.dumps(json).encode('utf-8') + headers = headers or {} + headers['Content-Type'] = 'application/json' + elif isinstance(data, str): + data = data.encode('utf-8') + elif isinstance(data, dict): + data = urllib.parse.urlencode(data).encode('utf-8') + headers = headers or {} + headers['Content-Type'] = 'application/x-www-form-urlencoded' + + headers_list = [(k, v) for k, v in headers.items()] if headers else None + result = 
await self.request( + method="POST", url=path, content=data, headers=headers_list + ) + return Response.from_tuple(result) + + async def _boot(self, loop: asyncio.AbstractEventLoop) -> None: + logger.info("Python-side HTTP client booted") + reader = asyncio.StreamReader(loop=loop) + reader_protocol = asyncio.StreamReaderProtocol(reader) + fd = os.fdopen(self._client._fd, 'rb') + transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd) + try: + while len(await reader.read(1)) == 1: + if not self._client or not self._task: + break + if self._skip_reads > 0: + self._skip_reads -= 1 + continue + msg = self._client._read() + if not msg: + break + self._process_message(msg) + finally: + transport.close() + + def _process_message(self, msg): + msg_type, id, data = msg + + if id in self._requests: + if msg_type == 1: + self._requests[id].set_result(data) + elif msg_type == 0: + self._requests[id].set_exception(Exception(data)) + + +class CaseInsensitiveDict(dict): + def __init__(self, data: Optional[list[Tuple[str, str]]] = None): + super().__init__() + if data: + for k, v in data: + self[k.lower()] = v + + def __setitem__(self, key: str, value: str): + super().__setitem__(key.lower(), value) + + def __getitem__(self, key: str): + return super().__getitem__(key.lower()) + + def get(self, key: str, default=None): + return super().get(key.lower(), default) + + def update(self, *args, **kwargs: str) -> None: + if args: + data = args[0] + if isinstance(data, Mapping): + for key, value in data.items(): + self[key] = value + else: + for key, value in data: + self[key] = value + for key, value in kwargs.items(): + self[key] = value + + +@dataclasses.dataclass +class Response: + status_code: int + body: bytes + headers: CaseInsensitiveDict + + @classmethod + def from_tuple(cls, data: Tuple[int, bytes, dict[str, str]]): + status_code, body, headers_list = data + headers = CaseInsensitiveDict([(k, v) for k, v in headers_list.items()]) + return cls(status_code, body, headers) + + def json(self): + return json_lib.loads(self.body.decode('utf-8')) + + @property + def text(self) -> str: + return self.body.decode('utf-8') diff --git a/edb/server/net_worker.py b/edb/server/net_worker.py index 5b0ad51edca..33cd658d0a6 100644 --- a/edb/server/net_worker.py +++ b/edb/server/net_worker.py @@ -24,12 +24,11 @@ import asyncio import logging import base64 -import os from edb.ir import statypes from edb.server import defines from edb.server.protocol import execute -from edb.server._http import Http +from edb.server.http import HttpClient from edb.common import retryloop from . 
import dbview @@ -99,77 +98,6 @@ async def _http_task(tenant: edbtenant.Tenant, http_client) -> None: ) -class HttpClient: - def __init__(self, limit: int): - self._client = Http(limit) - self._fd = self._client._fd - self._task = None - self._skip_reads = 0 - self._loop = asyncio.get_running_loop() - self._task = self._loop.create_task(self._boot(self._loop)) - self._next_id = 0 - self._requests: dict[int, asyncio.Future] = {} - - def __del__(self) -> None: - if self._task: - self._task.cancel() - self._task = None - - def _update_limit(self, limit: int): - self._client._update_limit(limit) - - async def request( - self, - *, - method: str, - url: str, - content: bytes | None, - headers: list[tuple[str, str]] | None, - ): - if content is None: - content = bytes() - if headers is None: - headers = [] - id = self._next_id - self._next_id += 1 - self._requests[id] = asyncio.Future() - try: - self._client._request(id, url, method, content, headers) - resp = await self._requests[id] - return resp - finally: - del self._requests[id] - - async def _boot(self, loop: asyncio.AbstractEventLoop) -> None: - logger.info("Python-side HTTP client booted") - reader = asyncio.StreamReader(loop=loop) - reader_protocol = asyncio.StreamReaderProtocol(reader) - fd = os.fdopen(self._client._fd, 'rb') - transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd) - try: - while len(await reader.read(1)) == 1: - if not self._client or not self._task: - break - if self._skip_reads > 0: - self._skip_reads -= 1 - continue - msg = self._client._read() - if not msg: - break - self._process_message(msg) - finally: - transport.close() - - def _process_message(self, msg): - msg_type, id, data = msg - - if id in self._requests: - if msg_type == 1: - self._requests[id].set_result(data) - elif msg_type == 0: - self._requests[id].set_exception(Exception(data)) - - def create_http(tenant: edbtenant.Tenant): net_http_max_connections = tenant._server.config_lookup( 'net_http_max_connections', tenant.get_sys_config() diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index e97edbc8c4f..d42bac0cbf8 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -37,7 +37,7 @@ def __init__( client_secret: str, *, additional_scope: str | None, - http_factory: Callable[..., http_client.HttpClient], + http_factory: Callable[..., http_client.AuthHttpClient], ): self.name = name self.issuer_url = issuer_url diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 6457f62df59..2837edfd241 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -31,7 +31,15 @@ import mimetypes import uuid -from typing import Any, Optional, Tuple, FrozenSet, cast, TYPE_CHECKING +from typing import ( + Any, + Optional, + Tuple, + FrozenSet, + cast, + TYPE_CHECKING, + Callable, +) import aiosmtplib from jwcrypto import jwk, jwt @@ -80,6 +88,27 @@ def __init__( self.tenant = tenant self.test_mode = tenant.server.in_test_mode() + def _get_url_munger( + self, request: protocol.HttpRequest + ) -> Callable[[str], str] | None: + """ + Returns a callable that can be used to modify the base URL + when making requests to the OAuth provider. + + This is used to redirect requests to the test OAuth provider + when running in test mode. 
+ """ + if not self.test_mode: + return None + test_url = ( + request.params[b'oauth-test-server'].decode() + if (request.params and b'oauth-test-server' in request.params) + else None + ) + if test_url: + return lambda path: f"{test_url}{urllib.parse.quote(path)}" + return None + async def handle_request( self, request: protocol.HttpRequest, @@ -270,7 +299,10 @@ async def handle_authorize( query, "challenge", fallback_keys=["code_challenge"] ) oauth_client = oauth.Client( - db=self.db, provider_name=provider_name, base_url=self.test_url + db=self.db, + provider_name=provider_name, + url_munger=self._get_url_munger(request), + http_client=self.tenant.get_http_client(), ) await pkce.create(self.db, challenge) authorize_url = await oauth_client.get_authorize_url( @@ -369,7 +401,8 @@ async def handle_callback( oauth_client = oauth.Client( db=self.db, provider_name=provider_name, - base_url=self.test_url, + url_munger=self._get_url_munger(request), + http_client=self.tenant.get_http_client(), ) ( identity, diff --git a/edb/server/protocol/auth_ext/http_client.py b/edb/server/protocol/auth_ext/http_client.py index 52302b97c31..563605b5900 100644 --- a/edb/server/protocol/auth_ext/http_client.py +++ b/edb/server/protocol/auth_ext/http_client.py @@ -16,51 +16,49 @@ # limitations under the License. # -from typing import Any -import urllib.parse +from typing import Any, Callable, Self -import hishel -import httpx +from edb.server import http -class HttpClient(httpx.AsyncClient): +class AuthHttpClient: def __init__( self, - *args: Any, - edgedb_test_url: str | None, - base_url: str, - **kwargs: Any, + http_client: http.HttpClient, + url_munger: Callable[[str], str] | None = None, + base_url: str | None = None, ): - self.edgedb_orig_base_url = None - if edgedb_test_url: - self.edgedb_orig_base_url = urllib.parse.quote(base_url, safe='') - base_url = edgedb_test_url - cache = hishel.AsyncCacheTransport( - transport=httpx.AsyncHTTPTransport(), - storage=hishel.AsyncInMemoryStorage(capacity=5), - ) - super().__init__( - *args, base_url=base_url, transport=cache, **kwargs - ) + self.url_munger = url_munger + self.http_client = http_client + self.base_url = base_url - async def post( # type: ignore[override] + async def post( self, path: str, - *args: Any, - **kwargs: Any, - ) -> httpx.Response: - if self.edgedb_orig_base_url: - path = f'{self.edgedb_orig_base_url}{path}' - return await super().post( - path, *args, **kwargs + *, + headers: dict[str, str] | None = None, + data: bytes | str | dict[str, str] | None = None, + json: Any | None = None, + ) -> http.Response: + if self.base_url: + path = self.base_url + path + if self.url_munger: + path = self.url_munger(path) + return await self.http_client.post( + path, headers=headers, data=data, json=json ) - async def get( # type: ignore[override] - self, - path: str, - *args: Any, - **kwargs: Any, - ) -> httpx.Response: - if self.edgedb_orig_base_url: - path = f'{self.edgedb_orig_base_url}{path}' - return await super().get(path, *args, **kwargs) + async def get( + self, path: str, *, headers: dict[str, str] | None = None + ) -> http.Response: + if self.base_url: + path = self.base_url + path + if self.url_munger: + path = self.url_munger(path) + return await self.http_client.get(path, headers=headers) + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *args) -> None: # type: ignore + pass diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py index a34d49f2fae..73d1cf03666 100644 --- 
a/edb/server/protocol/auth_ext/oauth.py +++ b/edb/server/protocol/auth_ext/oauth.py @@ -19,23 +19,29 @@ import json -from typing import cast, Any +from typing import cast, Any, Callable from edb.server.protocol import execute +from edb.server.http import HttpClient from . import github, google, azure, apple, discord, slack -from . import config, errors, util, data, base, http_client +from . import config, errors, util, data, base, http_client as _http_client class Client: provider: base.BaseProvider def __init__( - self, db: Any, provider_name: str, base_url: str | None = None + self, + *, + db: Any, + provider_name: str, + http_client: HttpClient, + url_munger: Callable[[str], str] | None = None, ): self.db = db - http_factory = lambda *args, **kwargs: http_client.HttpClient( - *args, edgedb_test_url=base_url, **kwargs + http_factory = lambda *args, **kwargs: _http_client.AuthHttpClient( + *args, url_munger=url_munger, http_client=http_client, **kwargs # type: ignore ) provider_config = self._get_provider_config(provider_name) diff --git a/edb/server/tenant.py b/edb/server/tenant.py index b063a1e3ab1..d0ce0c0ecb1 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -68,6 +68,7 @@ from .ha import adaptive as adaptive_ha from .ha import base as ha_base +from .http import HttpClient from .pgcon import errors as pgcon_errors if TYPE_CHECKING: @@ -80,6 +81,9 @@ logger = logging.getLogger("edb.server") +HTTP_MAX_CONNECTIONS = 100 + + class RoleDescriptor(TypedDict): superuser: bool name: str @@ -132,6 +136,8 @@ class Tenant(ha_base.ClusterProtocol): _jwt_revocation_list_file: pathlib.Path | None _jwt_revocation_list: frozenset[str] | None + _http_client: HttpClient | None + def __init__( self, cluster: pgcluster.BaseCluster, @@ -203,6 +209,8 @@ def __init__( self._jwt_revocation_list_file = None self._jwt_revocation_list = None + self._http_client = None + # If it isn't stored in instdata, it is the old default. self.default_database = defines.EDGEDB_OLD_DEFAULT_DB @@ -238,6 +246,11 @@ def set_server(self, server: edbserver.BaseServer) -> None: self._server = server self.__loop = server.get_loop() + def get_http_client(self) -> HttpClient: + if self._http_client is None: + self._http_client = HttpClient(HTTP_MAX_CONNECTIONS) + return self._http_client + def on_switch_over(self): # Bumping this serial counter will "cancel" all pending connections # to the old master. 
diff --git a/edb/testbase/http.py b/edb/testbase/http.py index 99fc8d9f4ff..52a80af9c2e 100644 --- a/edb/testbase/http.py +++ b/edb/testbase/http.py @@ -330,9 +330,12 @@ def log_message(self, *args): class MultiHostMockHttpServerHandler(MockHttpServerHandler): def get_server_and_path(self) -> tuple[str, str]: - server, path = self.path.lstrip('/').split('/', 1) - server = urllib.parse.unquote(server) - return server, path + # Path looks like: + # http://127.0.0.1:32881/https%3A//slack.com/.well-known/openid-configuration + raw_url = urllib.parse.unquote(self.path.lstrip('/')) + url = urllib.parse.urlparse(raw_url) + return (f'{url.scheme}://{url.netloc}', + url.path.lstrip('/')) ResponseType = tuple[str, int] | tuple[str, int, dict[str, str]] @@ -416,9 +419,10 @@ def handle_request( body=body, ) self.requests[key].append(request_details) - if key not in self.routes: - handler.send_error(404) + error_message = (f"No route handler for {key}\n\n" + f"Available routes:\n{self.routes}") + handler.send_error(404, message=error_message) return registered_handler = self.routes[key] diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 357e829b8e0..eaefccde7e6 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -1045,7 +1045,7 @@ async def test_http_auth_ext_discord_callback_01(self): ) ) - user_request = ("GET", "https://discord.com/api/v10", "users/@me") + user_request = ("GET", "https://discord.com", "api/v10/users/@me") self.mock_oauth_server.register_route_handler(*user_request)( ( json.dumps( @@ -1419,8 +1419,8 @@ async def test_http_auth_ext_azure_authorize_01(self): discovery_request = ( "GET", - "https://login.microsoftonline.com/common/v2.0", - ".well-known/openid-configuration", + "https://login.microsoftonline.com", + "common/v2.0/.well-known/openid-configuration", ) self.mock_oauth_server.register_route_handler(*discovery_request)( ( @@ -1493,8 +1493,8 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: discovery_request = ( "GET", - "https://login.microsoftonline.com/common/v2.0", - ".well-known/openid-configuration", + "https://login.microsoftonline.com", + "common/v2.0/.well-known/openid-configuration", ) self.mock_oauth_server.register_route_handler(*discovery_request)( ( From a72c8d9db4558b9c9144b0e7307b9a8702913cf4 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 23 Oct 2024 14:05:09 -0400 Subject: [PATCH 17/18] ext: fix issues in net gc task (#7904) * don't run on dropping databases * don't call asyncio.wait([]) * fix wrong logging without decoding JSON --- edb/server/net_worker.py | 16 ++++++++++----- edb/server/protocol/auth_ext/pkce.py | 30 +++++++++++++++++----------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/edb/server/net_worker.py b/edb/server/net_worker.py index 33cd658d0a6..4267689836d 100644 --- a/edb/server/net_worker.py +++ b/edb/server/net_worker.py @@ -265,7 +265,11 @@ def _warn(e): ) async for iteration in rloop: async with iteration: - result = await execute.parse_execute_json( + if not db.tenant.is_database_connectable(db.name): + # Don't run the net_worker if the database is not + # connectable, e.g. 
being dropped + continue + result_json = await execute.parse_execute_json( db, """ with requests := ( @@ -274,14 +278,15 @@ def _warn(e): and (datetime_of_statement() - .updated_at) > $expires_in ) - delete requests; + select count((delete requests)); """, variables={"expires_in": expires_in.to_backend_str()}, cached_globally=True, tx_isolation=defines.TxIsolationLevel.RepeatableRead, ) - if len(result) > 0: - logger.info(f"Deleted requests: {result!r}") + result: list[int] = json.loads(result_json) + if result[0] > 0: + logger.info(f"Deleted {result[0]} requests") else: logger.info(f"No requests to delete") @@ -311,7 +316,8 @@ async def gc(server: edbserver.BaseServer) -> None: if tenant.accept_new_tasks ] try: - await asyncio.wait(tasks) + if tasks: + await asyncio.wait(tasks) except Exception as ex: logger.debug( "GC of std::net::http::ScheduledRequest failed", exc_info=ex diff --git a/edb/server/protocol/auth_ext/pkce.py b/edb/server/protocol/auth_ext/pkce.py index 028a6e9e162..a8acfdc5752 100644 --- a/edb/server/protocol/auth_ext/pkce.py +++ b/edb/server/protocol/auth_ext/pkce.py @@ -158,23 +158,29 @@ async def delete(db: edbtenant.dbview.Database, id: str) -> None: assert len(result_json) == 1 +async def _delete_challenge(db: edbtenant.dbview.Database) -> None: + if not db.tenant.is_database_connectable(db.name): + # Don't run gc if the database is not connectable, e.g. being dropped + return + + await execute.parse_execute_json( + db, + """ + delete ext::auth::PKCEChallenge filter + (datetime_of_statement() - .created_at) > + $validity + """, + variables={"validity": VALIDITY.to_backend_str()}, + cached_globally=True, + ) + + async def _gc(tenant: edbtenant.Tenant) -> None: try: async with asyncio.TaskGroup() as g: for db in tenant.iter_dbs(): if "auth" in db.extensions: - g.create_task( - execute.parse_execute_json( - db, - """ - delete ext::auth::PKCEChallenge filter - (datetime_of_statement() - .created_at) > - $validity - """, - variables={"validity": VALIDITY.to_backend_str()}, - cached_globally=True, - ), - ) + g.create_task(_delete_challenge(db)) except Exception as ex: logger.debug( "GC of ext::auth::PKCEChallenge failed (instance: %s)", From 584795dd444b615d29ddaf1f6f642f93ec487ce9 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 23 Oct 2024 11:25:25 -0700 Subject: [PATCH 18/18] Rename net_http_max_connections to http_max_connections (#7914) We're going to use it to cap other outgoing http connections also. --- docs/reference/configuration.rst | 5 ++--- edb/buildmeta.py | 2 +- edb/lib/cfg.edgeql | 4 ++-- edb/server/net_worker.py | 12 ++++++------ 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/docs/reference/configuration.rst b/docs/reference/configuration.rst index d10dae96d20..2179cd93054 100644 --- a/docs/reference/configuration.rst +++ b/docs/reference/configuration.rst @@ -74,9 +74,8 @@ Resource usage :eql:synopsis:`shared_buffers -> cfg::memory` The amount of memory used for shared memory buffers. -:eql:synopsis:`net_http_max_connections -> int64` - The maximum number of concurrent HTTP connections to allow when using the - ``std::net::http`` module. +:eql:synopsis:`http_max_connections -> int64` + The maximum number of concurrent outbound HTTP connections to allow. 
Query planning -------------- diff --git a/edb/buildmeta.py b/edb/buildmeta.py index 90becaf4c85..ca4e4bf9783 100644 --- a/edb/buildmeta.py +++ b/edb/buildmeta.py @@ -60,7 +60,7 @@ # The merge conflict there is a nice reminder that you probably need # to write a patch in edb/pgsql/patches.py, and then you should preserve # the old value. -EDGEDB_CATALOG_VERSION = 2024_10_17_00_00 +EDGEDB_CATALOG_VERSION = 2024_10_23_00_00 EDGEDB_MAJOR_VERSION = 6 diff --git a/edb/lib/cfg.edgeql b/edb/lib/cfg.edgeql index bdca5d2300b..939913da510 100644 --- a/edb/lib/cfg.edgeql +++ b/edb/lib/cfg.edgeql @@ -238,8 +238,8 @@ ALTER TYPE cfg::AbstractConfig { 'Where the query cache is finally stored'; }; - # std::net::http Configuration - CREATE PROPERTY net_http_max_connections -> std::int64 { + # HTTP Worker Configuration + CREATE PROPERTY http_max_connections -> std::int64 { SET default := 10; CREATE ANNOTATION std::description := 'The maximum number of concurrent HTTP connections.'; diff --git a/edb/server/net_worker.py b/edb/server/net_worker.py index 4267689836d..8c82fa75e84 100644 --- a/edb/server/net_worker.py +++ b/edb/server/net_worker.py @@ -46,10 +46,10 @@ async def _http_task(tenant: edbtenant.Tenant, http_client) -> None: - net_http_max_connections = tenant._server.config_lookup( - 'net_http_max_connections', tenant.get_sys_config() + http_max_connections = tenant._server.config_lookup( + 'http_max_connections', tenant.get_sys_config() ) - http_client._update_limit(net_http_max_connections) + http_client._update_limit(http_max_connections) try: async with (asyncio.TaskGroup() as g,): for db in tenant.iter_dbs(): @@ -99,10 +99,10 @@ async def _http_task(tenant: edbtenant.Tenant, http_client) -> None: def create_http(tenant: edbtenant.Tenant): - net_http_max_connections = tenant._server.config_lookup( - 'net_http_max_connections', tenant.get_sys_config() + http_max_connections = tenant._server.config_lookup( + 'http_max_connections', tenant.get_sys_config() ) - return HttpClient(net_http_max_connections) + return HttpClient(http_max_connections) async def http(server: edbserver.BaseServer) -> None:
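For operators, the setting is now configured under its new name. A
hypothetical smoke check (assumes an async driver connection `con`;
the scope shown is an assumption, not something this patch specifies):

    async def raise_http_connection_cap(con, limit: int = 50) -> None:
        # Assumption: http_max_connections is a non-system setting and
        # can therefore be set on the current branch.
        await con.execute(
            f'configure current database set http_max_connections := {limit};'
        )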