From fb09377c957567a5dd703cec0a698c7ddc0ccf50 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Tue, 5 Mar 2024 18:16:38 -0500 Subject: [PATCH] Schema version in query cache key (#6990) so that the cache table can be append-only to avoid concurrent update errors. --- edb/server/compiler/rpc.pxd | 1 + edb/server/compiler/rpc.pyi | 3 + edb/server/compiler/rpc.pyx | 10 ++ edb/server/dbview/dbview.pxd | 8 +- edb/server/dbview/dbview.pyx | 234 +++++++++++++++++--------------- edb/server/protocol/binary.pyx | 61 ++++----- edb/server/protocol/execute.pyx | 6 +- tests/test_server_compiler.py | 12 +- 8 files changed, 183 insertions(+), 152 deletions(-) diff --git a/edb/server/compiler/rpc.pxd b/edb/server/compiler/rpc.pxd index 3c8e9f08f7b..946e91b1804 100644 --- a/edb/server/compiler/rpc.pxd +++ b/edb/server/compiler/rpc.pxd @@ -40,6 +40,7 @@ cdef class CompilationRequest: readonly object session_config object database_config object system_config + object schema_version bytes serialized_cache object cache_key diff --git a/edb/server/compiler/rpc.pyi b/edb/server/compiler/rpc.pyi index a9037611463..dd45efed8ef 100644 --- a/edb/server/compiler/rpc.pyi +++ b/edb/server/compiler/rpc.pyi @@ -80,6 +80,9 @@ class CompilationRequest: ) -> CompilationRequest: ... + def set_schema_version(self, version: uuid.UUID) -> CompilationRequest: + ... + def serialize(self) -> bytes: ... diff --git a/edb/server/compiler/rpc.pyx b/edb/server/compiler/rpc.pyx index 76e7ff9ca2d..4a9a28821ad 100644 --- a/edb/server/compiler/rpc.pyx +++ b/edb/server/compiler/rpc.pyx @@ -131,6 +131,12 @@ cdef class CompilationRequest: self.cache_key = None return self + def set_schema_version(self, version: uuid.UUID) -> CompilationRequest: + self.schema_version = version + self.serialized_cache = None + self.cache_key = None + return self + def deserialize(self, bytes data, str query_text) -> CompilationRequest: if data[0] == 0: self._deserialize_v0(data, query_text) @@ -217,6 +223,10 @@ cdef class CompilationRequest: ) hash_obj.update(serialized_comp_config) + # Must set_schema_version() before serializing compilation request + assert self.schema_version is not None + hash_obj.update(self.schema_version.bytes) + cache_key_bytes = hash_obj.digest() self.cache_key = uuidgen.from_bytes(cache_key_bytes) diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index 96f2050e410..b83357c3a4a 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -83,7 +83,7 @@ cdef class Database: cdef schedule_config_update(self) cdef _invalidate_caches(self) - cdef _cache_compiled_query(self, key, compiled, schema_version) + cdef _cache_compiled_query(self, key, compiled) cdef _new_view(self, query_cache, protocol_version) cdef _remove_view(self, view) cdef _update_backend_ids(self, new_types) @@ -152,6 +152,7 @@ cdef class DatabaseConnectionView: object __weakref__ cdef _reset_tx_state(self) + cdef inline _check_in_tx_error(self, query_unit_group) cdef clear_tx_error(self) cdef rollback_tx_to_savepoint(self, name) @@ -162,10 +163,9 @@ cdef class DatabaseConnectionView: cpdef in_tx(self) cpdef in_tx_error(self) - cdef cache_compiled_query( - self, object key, object query_unit_group, schema_version - ) + cdef cache_compiled_query(self, object key, object query_unit_group) cdef lookup_compiled_query(self, object key) + cdef as_compiled(self, query_req, query_unit_group, bint use_metrics=?) cdef tx_error(self) diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index cc732cc4b8e..4cfc366b040 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -182,22 +182,19 @@ cdef class Database: self._sql_to_compiled.clear() self._index.invalidate_caches() - cdef _cache_compiled_query( - self, key, compiled: dbstate.QueryUnitGroup, schema_version - ): + cdef _cache_compiled_query(self, key, compiled: dbstate.QueryUnitGroup): # `dbver` must be the schema version `compiled` was compiled upon assert compiled.cacheable - existing, existing_ver = self._eql_to_compiled.get(key, DICTDEFAULT) - if existing is not None and existing_ver == self.schema_version: - # We already have a cached query for a more recent DB version. + if key in self._eql_to_compiled: + # We already have a cached query for the current user schema return - # Store the matching schema version, see also the comments at origin - self._eql_to_compiled[key] = compiled, schema_version + self._eql_to_compiled[key] = compiled + # TODO(fantix): merge in-memory cleanup into the task below keys = [] while self._eql_to_compiled.needs_cleanup(): - query_req, (unit_group, _) = self._eql_to_compiled.cleanup_one() + query_req, unit_group = self._eql_to_compiled.cleanup_one() if len(unit_group) == 1: keys.append(query_req.get_cache_key()) if keys: @@ -248,19 +245,17 @@ cdef class Database: def hydrate_cache(self, query_cache): new = set() - for schema_version, in_data, out_data in query_cache: + for _, in_data, out_data in query_cache: query_req = rpc.CompilationRequest( self.server.compilation_config_serializer) query_req.deserialize(in_data, "") - schema_version = uuidgen.from_bytes(schema_version) new.add(query_req) - _, cached_ver = self._eql_to_compiled.get(query_req, DICTDEFAULT) - if cached_ver != schema_version: + if query_req not in self._eql_to_compiled: unit = dbstate.QueryUnit.deserialize(out_data) group = dbstate.QueryUnitGroup() group.append(unit) - self._eql_to_compiled[query_req] = group, schema_version + self._eql_to_compiled[query_req] = group for query_req in list(self._eql_to_compiled): if query_req not in new: @@ -710,15 +705,11 @@ cdef class DatabaseConnectionView: cpdef in_tx_error(self): return self._tx_error - cdef cache_compiled_query( - self, object key, object query_unit_group, schema_version - ): + cdef cache_compiled_query(self, object key, object query_unit_group): assert query_unit_group.cacheable if not self._in_tx_with_ddl: - self._db._cache_compiled_query( - key, query_unit_group, schema_version - ) + self._db._cache_compiled_query(key, query_unit_group) cdef lookup_compiled_query(self, object key): if (self._tx_error or @@ -726,12 +717,7 @@ cdef class DatabaseConnectionView: self._in_tx_with_ddl): return None - query_unit_group, qu_ver = self._db._eql_to_compiled.get( - key, DICTDEFAULT) - if query_unit_group is not None and qu_ver != self._db.schema_version: - query_unit_group = None - - return query_unit_group + return self._db._eql_to_compiled.get(key, None) cdef tx_error(self): if self._in_tx: @@ -951,13 +937,15 @@ cdef class DatabaseConnectionView: async with concurrency_control: try: schema_version = self.schema_version + database_config = self.get_database_config() + system_config = self.get_compilation_system_config() result = await compiler_pool.compile( self.dbname, self.get_user_schema_pickle(), self.get_global_schema_pickle(), self.reflection_cache, - self.get_database_config(), - self.get_compilation_system_config(), + database_config, + system_config, query_req.serialize(), "", client_id=self.tenant.client_id, @@ -966,9 +954,16 @@ cdef class DatabaseConnectionView: # discard cache entry that cannot be recompiled self._db._eql_to_compiled.pop(query_req, None) else: - await compiled_queue.put( - (query_req, result[0], schema_version) - ) + # schema_version, database_config and system_config are not + # serialized but only affect the cache key. We only update + # these values *after* the compilation so that we can evict + # the in-memory cache by the right key when recompilation + # fails in the `except` branch above. + query_req.set_schema_version(schema_version) + query_req.set_database_config(database_config) + query_req.set_system_config(system_config) + + await compiled_queue.put((query_req, result[0])) async def persist_cache_task(): if not debug.flags.func_cache: @@ -988,9 +983,9 @@ cdef class DatabaseConnectionView: await execute.persist_cache( conn, self, [item[:2] for item in buf] ) - for query_req, query_unit_group, schema_version in buf: + for query_req, query_unit_group in buf: self._db._cache_compiled_query( - query_req, query_unit_group, schema_version) + query_req, query_unit_group) buf.clear() async with asyncio.TaskGroup() as g: @@ -1030,108 +1025,99 @@ cdef class DatabaseConnectionView: self, query_req: rpc.CompilationRequest, cached_globally=False, - use_metrics=True, - # allow passing in the query_unit_group, in case the call site needs - # to make decisions based on whether the cache is hit - query_unit_group=None, + bint use_metrics=True, uint64_t allow_capabilities = compiler.Capability.ALL, ) -> CompiledQuery: - source = query_req.source - if cached_globally: - # WARNING: only set cached_globally to True when the query is - # strictly referring to only shared stable objects in user schema - # or anything from std schema, for example: - # YES: select ext::auth::UIConfig { ... } - # NO: select default::User { ... } - query_unit_group = ( - self.server.system_compile_cache.get(query_req) - if self._query_cache_enabled - else None - ) - elif query_unit_group is None: - query_unit_group = self.lookup_compiled_query(query_req) + query_unit_group = None + if self._query_cache_enabled: + if cached_globally: + # WARNING: only set cached_globally to True when the query is + # strictly referring to only shared stable objects in user + # schema or anything from std schema, for example: + # YES: select ext::auth::UIConfig { ... } + # NO: select default::User { ... } + query_unit_group = ( + self.server.system_compile_cache.get(query_req) + ) + else: + query_unit_group = self.lookup_compiled_query(query_req) + + # Fast-path to skip all the locks if it's a cache HIT + if query_unit_group is not None: + return self.as_compiled( + query_req, query_unit_group, use_metrics) lock = None - cached = True + schema_version = self.schema_version - if query_unit_group is None: - # Lock on the query compilation to avoid other coroutines running - # the same compile and waste computational resources - if cached_globally: - lock_table = self.server.system_compile_cache_locks - else: - lock_table = self._db._cache_locks + # Lock on the query compilation to avoid other coroutines running + # the same compile and waste computational resources + if cached_globally: + lock_table = self.server.system_compile_cache_locks + else: + lock_table = self._db._cache_locks + while True: + # We need a loop here because schema_version is a part of the key, + # there could be a DDL while we're waiting for the lock. lock = lock_table.get(query_req) if lock is None: lock = asyncio.Lock() lock_table[query_req] = lock await lock.acquire() + if self.schema_version == schema_version: + break + else: + lock.release() + if not lock._waiters: + del lock_table[query_req] + schema_version = self.schema_version + # Updating the schema_version will make query_req a new key + query_req.set_schema_version(schema_version) try: # Check the cache again with the lock acquired - if query_unit_group is None and self._query_cache_enabled: + if self._query_cache_enabled: if cached_globally: query_unit_group = ( self.server.system_compile_cache.get(query_req) ) else: query_unit_group = self.lookup_compiled_query(query_req) + if query_unit_group is not None: + return self.as_compiled( + query_req, query_unit_group, use_metrics) - if query_unit_group is None: - # Cache miss; need to compile this query. - cached = False - # Remember the schema version we are compiling on, so that we - # can cache the result with the matching version. In case of - # concurrent schema update, we're only storing an outdated - # cache entry, and the next identical query could get - # recompiled on the new schema. - schema_version = self.schema_version - - try: - query_unit_group = await self._compile(query_req) - except (errors.EdgeQLSyntaxError, errors.InternalServerError): + try: + query_unit_group = await self._compile(query_req) + except (errors.EdgeQLSyntaxError, errors.InternalServerError): + raise + except errors.EdgeDBError: + if self.in_tx_error(): + # Because we are in an error state it's more reasonable + # to fail with TransactionError("commands ignored") + # rather than with a potentially more cryptic error. + # An exception from this rule are syntax errors and + # ISEs, because these could arise while the user is + # trying to properly rollback this failed transaction. + self.raise_in_tx_error() + else: raise - except errors.EdgeDBError: - if self.in_tx_error(): - # Because we are in an error state it's more reasonable - # to fail with TransactionError("commands ignored") - # rather than with a potentially more cryptic error. - # An exception from this rule are syntax errors and - # ISEs, because these could arise while the user is - # trying to properly rollback this failed transaction. - self.raise_in_tx_error() - else: - raise - - self.check_capabilities( - query_unit_group.capabilities, - allow_capabilities, - errors.DisabledCapabilityError, - "disabled by the client", - ) - if self.in_tx_error(): - # The current transaction is aborted, so we must fail - # all commands except ROLLBACK or ROLLBACK TO SAVEPOINT. - first = query_unit_group[0] - if ( - not ( - first.tx_rollback - or first.tx_savepoint_rollback - or first.tx_abort_migration - ) or len(query_unit_group) > 1 - ): - self.raise_in_tx_error() + self.check_capabilities( + query_unit_group.capabilities, + allow_capabilities, + errors.DisabledCapabilityError, + "disabled by the client", + ) + self._check_in_tx_error(query_unit_group) - if not cached and query_unit_group.cacheable: + if self._query_cache_enabled and query_unit_group.cacheable: if cached_globally: self.server.system_compile_cache[query_req] = ( query_unit_group ) else: - self.cache_compiled_query( - query_req, query_unit_group, schema_version - ) + self.cache_compiled_query(query_req, query_unit_group) finally: if lock is not None: lock.release() @@ -1140,11 +1126,10 @@ cdef class DatabaseConnectionView: if use_metrics: metrics.edgeql_query_compilations.inc( - 1.0, - self.tenant.get_instance_name(), - 'cache' if cached else 'compiler', + 1.0, self.tenant.get_instance_name(), 'compiler' ) + source = query_req.source return CompiledQuery( query_unit_group=query_unit_group, first_extra=source.first_extra(), @@ -1153,6 +1138,35 @@ cdef class DatabaseConnectionView: request=query_req, ) + cdef inline _check_in_tx_error(self, query_unit_group): + if self.in_tx_error(): + # The current transaction is aborted, so we must fail + # all commands except ROLLBACK or ROLLBACK TO SAVEPOINT. + first = query_unit_group[0] + if ( + not ( + first.tx_rollback + or first.tx_savepoint_rollback + or first.tx_abort_migration + ) or len(query_unit_group) > 1 + ): + self.raise_in_tx_error() + + cdef as_compiled(self, query_req, query_unit_group, bint use_metrics=True): + self._check_in_tx_error(query_unit_group) + if use_metrics: + metrics.edgeql_query_compilations.inc( + 1.0, self.tenant.get_instance_name(), 'cache' + ) + + source = query_req.source + return CompiledQuery( + query_unit_group=query_unit_group, + first_extra=source.first_extra(), + extra_counts=source.extra_counts(), + extra_blobs=source.extra_blobs(), + ) + async def _compile( self, query_req: rpc.CompilationRequest, diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index ef7cfc63444..e26ff3a7ed8 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -547,24 +547,21 @@ cdef class EdgeConnection(frontend.FrontendConnection): 'after', source.first_extra()) query_unit_group = dbv.lookup_compiled_query(query_req) - - # If we have to do a compile within a transaction, suppress - # the idle_in_transaction_session_timeout. - suppress_timeout = ( - dbv.in_tx() and not dbv.in_tx_error() and query_unit_group is None - ) - if suppress_timeout: - await self._suppress_tx_timeout() - - try: - return await dbv.parse( - query_req, - query_unit_group=query_unit_group, - allow_capabilities=allow_capabilities, - ) - finally: + if query_unit_group is None: + # If we have to do a compile within a transaction, suppress + # the idle_in_transaction_session_timeout. + suppress_timeout = dbv.in_tx() and not dbv.in_tx_error() if suppress_timeout: - await self._restore_tx_timeout(dbv) + await self._suppress_tx_timeout() + try: + return await dbv.parse( + query_req, allow_capabilities=allow_capabilities + ) + finally: + if suppress_timeout: + await self._restore_tx_timeout(dbv) + else: + return dbv.as_compiled(query_req, query_unit_group) cdef parse_cardinality(self, bytes card): if card[0] == CARD_MANY.value: @@ -758,6 +755,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): object output_format bint expect_one = False bytes query + dbview.DatabaseConnectionView _dbview allow_capabilities = self.buffer.read_int64() compilation_flags = self.buffer.read_int64() @@ -790,10 +788,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): if not query: raise errors.BinaryProtocolError('empty query') + _dbview = self.get_dbview() state_tid = self.buffer.read_bytes(16) state_data = self.buffer.read_len_prefixed_bytes() try: - self.get_dbview().decode_state(state_tid, state_data) + _dbview.decode_state(state_tid, state_data) except errors.StateMismatchError: self.write(self.make_state_data_description_msg()) raise @@ -808,7 +807,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): inline_typeids=inline_typeids, inline_typenames=inline_typenames, inline_objectids=inline_objectids, - ) + ).set_schema_version(_dbview.schema_version) + rv.set_modaliases(_dbview.get_modaliases()) + rv.set_session_config(_dbview.get_session_config()) + rv.set_database_config(_dbview.get_database_config()) + rv.set_system_config(_dbview.get_compilation_system_config()) return rv, allow_capabilities async def parse(self): @@ -828,10 +831,6 @@ cdef class EdgeConnection(frontend.FrontendConnection): if _dbview.get_state_serializer() is None: await _dbview.reload_state_serializer() query_req, allow_capabilities = self.parse_execute_request() - query_req.set_modaliases(_dbview.get_modaliases()) - query_req.set_session_config(_dbview.get_session_config()) - query_req.set_database_config(_dbview.get_database_config()) - query_req.set_system_config(_dbview.get_compilation_system_config()) compiled = await self._parse(query_req, allow_capabilities) units = compiled.query_unit_group if len(units) == 1 and units[0].cache_sql: @@ -872,10 +871,6 @@ cdef class EdgeConnection(frontend.FrontendConnection): in_tid = self.buffer.read_bytes(16) out_tid = self.buffer.read_bytes(16) args = self.buffer.read_len_prefixed_bytes() - query_req.set_modaliases(_dbview.get_modaliases()) - query_req.set_session_config(_dbview.get_session_config()) - query_req.set_database_config(_dbview.get_database_config()) - query_req.set_system_config(_dbview.get_compilation_system_config()) self.buffer.finish_message() if ( @@ -1430,16 +1425,14 @@ cdef class EdgeConnection(frontend.FrontendConnection): self.flush() async def _execute_utility_stmt(self, eql: str, pgcon): - cdef dbview.DatabaseConnectionView _dbview + cdef dbview.DatabaseConnectionView _dbview = self.get_dbview() query_req = rpc.CompilationRequest( self.server.compilation_config_serializer ) query_req.update( edgeql.Source.from_string(eql), self.protocol_version - ) - - _dbview = self.get_dbview() + ).set_schema_version(_dbview.schema_version) compiled = await _dbview.parse(query_req) query_unit_group = compiled.query_unit_group @@ -1825,17 +1818,19 @@ async def run_script( cdef: EdgeConnection conn dbview.CompiledQuery compiled + dbview.DatabaseConnectionView _dbview conn = new_edge_connection(server, tenant) await conn._start_connection(database) try: - compiled = await conn.get_dbview().parse( + _dbview = conn.get_dbview() + compiled = await _dbview.parse( rpc.CompilationRequest( server.compilation_config_serializer ).update( edgeql.Source.from_string(script), conn.protocol_version, output_format=FMT_NONE, - ) + ).set_schema_version(_dbview.schema_version) ) if len(compiled.query_unit_group) > 1: await conn._execute_script(compiled, b'') diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index e1d14861a7a..f35584da29c 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -376,7 +376,7 @@ async def execute( else: recompile_requests = [ req - for req, (grp, _) in dbv._db._eql_to_compiled.items() + for req, grp in dbv._db._eql_to_compiled.items() if len(grp) == 1 ] await dbv.recompile_all(be_conn, recompile_requests) @@ -581,7 +581,7 @@ async def execute_script( else: recompile_requests = [ req - for req, (grp, _) in dbv._db._eql_to_compiled.items() + for req, grp in dbv._db._eql_to_compiled.items() if len(grp) == 1 ] @@ -740,7 +740,7 @@ async def parse_execute_json( protocol_version=edbdef.CURRENT_PROTOCOL, input_format=compiler.InputFormat.JSON, output_format=output_format, - ) + ).set_schema_version(dbv.schema_version) compiled = await dbv.parse( query_req, diff --git a/tests/test_server_compiler.py b/tests/test_server_compiler.py index 23b2fa9dfc6..af6331e3460 100644 --- a/tests/test_server_compiler.py +++ b/tests/test_server_compiler.py @@ -26,6 +26,7 @@ import tempfile import time import unittest.mock +import uuid import immutables @@ -442,7 +443,7 @@ async def _test_pool_disconnect_queue(self, pool_class): source=edgeql.Source.from_string(orig_query), protocol_version=(1, 0), implicit_limit=101, - ) + ).set_schema_version(uuid.uuid4()) await asyncio.gather(*(pool_.compile_in_tx( context.state.current_tx().id, @@ -473,12 +474,19 @@ def test(source: edgeql.Source): ).update( source=source, protocol_version=(1, 0), - ) + ).set_schema_version(uuid.uuid4()) request2 = rpc.CompilationRequest( compiler.state.compilation_config_serializer ).deserialize(request1.serialize(), "") self.assertEqual(hash(request1), hash(request2)) self.assertEqual(request1, request2) + # schema_version affects the cache_key, hence the hash. + # But, it's not serialized so the 2 requests are still equal. + # This makes request2 a new key as being used in dicts. + request2.set_schema_version(uuid.uuid4()) + self.assertNotEqual(hash(request1), hash(request2)) + self.assertEqual(request1, request2) + test(edgeql.Source.from_string("SELECT 42")) test(edgeql.NormalizedSource.from_string("SELECT 42"))