From 818a3e507760b953e9a964b92c81670ad11afa26 Mon Sep 17 00:00:00 2001
From: Fantix King
Date: Thu, 7 Mar 2024 16:58:37 -0500
Subject: [PATCH] Split recompile and writeback in transaction (#6959)

---
 edb/server/compiler/rpc.pyx     |  20 ++++++
 edb/server/dbview/dbview.pxd    |   1 +
 edb/server/dbview/dbview.pyx    | 119 +++++++++++++++----------------
 edb/server/protocol/binary.pyx  |   3 +-
 edb/server/protocol/execute.pyx | 120 +++++++++++---------------------
 5 files changed, 119 insertions(+), 144 deletions(-)

diff --git a/edb/server/compiler/rpc.pyx b/edb/server/compiler/rpc.pyx
index 4a9a28821ad..79b5ba954f2 100644
--- a/edb/server/compiler/rpc.pyx
+++ b/edb/server/compiler/rpc.pyx
@@ -80,6 +80,26 @@ cdef class CompilationRequest:
     ):
         self._serializer = compilation_config_serializer
 
+    def __copy__(self):
+        cdef CompilationRequest rv = CompilationRequest(self._serializer)
+        rv.source = self.source
+        rv.protocol_version = self.protocol_version
+        rv.output_format = self.output_format
+        rv.json_parameters = self.json_parameters
+        rv.expect_one = self.expect_one
+        rv.implicit_limit = self.implicit_limit
+        rv.inline_typeids = self.inline_typeids
+        rv.inline_typenames = self.inline_typenames
+        rv.inline_objectids = self.inline_objectids
+        rv.modaliases = self.modaliases
+        rv.session_config = self.session_config
+        rv.database_config = self.database_config
+        rv.system_config = self.system_config
+        rv.schema_version = self.schema_version
+        rv.serialized_cache = self.serialized_cache
+        rv.cache_key = self.cache_key
+        return rv
+
     def update(
         self,
         source: edgeql.Source,
diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd
index b83357c3a4a..badcaed1e97 100644
--- a/edb/server/dbview/dbview.pxd
+++ b/edb/server/dbview/dbview.pxd
@@ -41,6 +41,7 @@ cdef class CompiledQuery:
     cdef public object extra_counts
     cdef public object extra_blobs
     cdef public object request
+    cdef public object recompiled_cache
 
 
 cdef class DatabaseIndex:
diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx
index 4cfc366b040..0551238d8cd 100644
--- a/edb/server/dbview/dbview.pyx
+++ b/edb/server/dbview/dbview.pyx
@@ -22,6 +22,7 @@ from typing import (
 
 import asyncio
 import base64
+import copy
 import json
 import os.path
 import pickle
@@ -85,12 +86,14 @@ cdef class CompiledQuery:
         extra_counts=(),
         extra_blobs=(),
         request=None,
+        recompiled_cache=None,
     ):
         self.query_unit_group = query_unit_group
         self.first_extra = first_extra
         self.extra_counts = extra_counts
         self.extra_blobs = extra_blobs
         self.request = request
+        self.recompiled_cache = recompiled_cache
 
 
 cdef class Database:
@@ -197,7 +200,7 @@ cdef class Database:
             query_req, unit_group = self._eql_to_compiled.cleanup_one()
             if len(unit_group) == 1:
                 keys.append(query_req.get_cache_key())
-        if keys:
+        if keys and debug.flags.persistent_cache:
             self.tenant.create_task(
                 self.tenant.evict_query_cache(self.name, keys),
                 interruptable=True,
@@ -909,39 +912,24 @@ cdef class DatabaseConnectionView:
         self._reset_tx_state()
         return side_effects
 
-    async def clear_cache_keys(self, conn) -> list[rpc.CompilationRequest]:
-        rows = await conn.sql_fetch(b'SELECT "edgedb"."_clear_query_cache"()')
-        rv = []
-        for row in rows:
-            query_req = rpc.CompilationRequest(
-                self.server.compilation_config_serializer
-            ).deserialize(row[0], "")
-            rv.append(query_req)
-            self._db._eql_to_compiled.pop(query_req, None)
-        execute.signal_query_cache_changes(self)
-        return rv
-
-    async def recompile_all(
-        self, conn, requests: typing.Iterable[rpc.CompilationRequest]
-    ):
-        # Assume the size of compiler pool is 100, we'll issue 50 concurrent
-        # compilation requests at the same time, cache up to 150 results and
-        # persist in one backend round-trip, in parallel.
+    async def recompile_cached_queries(self, user_schema, schema_version):
         compiler_pool = self.server.get_compiler_pool()
         compile_concurrency = max(1, compiler_pool.get_size_hint() // 2)
         concurrency_control = asyncio.Semaphore(compile_concurrency)
-        persist_batch_size = compile_concurrency * 3
-        compiled_queue = asyncio.Queue(persist_batch_size)
+        rv = []
 
         async def recompile_request(query_req: rpc.CompilationRequest):
             async with concurrency_control:
                 try:
-                    schema_version = self.schema_version
                     database_config = self.get_database_config()
                     system_config = self.get_compilation_system_config()
-                    result = await compiler_pool.compile(
+                    query_req = copy.copy(query_req)
+                    query_req.set_schema_version(schema_version)
+                    query_req.set_database_config(database_config)
+                    query_req.set_system_config(system_config)
+                    unit_group, _, _ = await compiler_pool.compile(
                         self.dbname,
-                        self.get_user_schema_pickle(),
+                        user_schema,
                         self.get_global_schema_pickle(),
                         self.reflection_cache,
                         database_config,
                         system_config,
                         query_req.serialize(),
                         client_id=self.tenant.client_id,
                     )
                 except Exception:
-                    # discard cache entry that cannot be recompiled
-                    self._db._eql_to_compiled.pop(query_req, None)
+                    # ignore cache entries that cannot be recompiled
+                    pass
                 else:
-                    # schema_version, database_config and system_config are not
-                    # serialized but only affect the cache key. We only update
-                    # these values *after* the compilation so that we can evict
-                    # the in-memory cache by the right key when recompilation
-                    # fails in the `except` branch above.
-                    query_req.set_schema_version(schema_version)
-                    query_req.set_database_config(database_config)
-                    query_req.set_system_config(system_config)
-
-                    await compiled_queue.put((query_req, result[0]))
-
-        async def persist_cache_task():
-            if not debug.flags.func_cache:
-                # TODO(fantix): sync _query_cache in one implicit tx
-                await conn.sql_fetch(b'SELECT "edgedb"."_clear_query_cache"()')
-
-            buf = []
-            running = True
-            while running:
-                while len(buf) < persist_batch_size:
-                    item = await compiled_queue.get()
-                    if item is None:
-                        running = False
-                        break
-                    buf.append(item)
-                if buf:
-                    await execute.persist_cache(
-                        conn, self, [item[:2] for item in buf]
-                    )
-                    for query_req, query_unit_group in buf:
-                        self._db._cache_compiled_query(
-                            query_req, query_unit_group)
-                    buf.clear()
+                    rv.append((query_req, unit_group))
 
         async with asyncio.TaskGroup() as g:
-            g.create_task(persist_cache_task())
-            async with asyncio.TaskGroup() as compile_group:
-                for req in requests:
-                    compile_group.create_task(recompile_request(req))
-            await compiled_queue.put(None)
+            for req, grp in self._db._eql_to_compiled.items():
+                if len(grp) == 1:
+                    g.create_task(recompile_request(req))
+        return rv
 
     async def apply_config_ops(self, conn, ops):
         settings = self.get_config_spec()
@@ -1124,6 +1079,41 @@ cdef class DatabaseConnectionView:
             if not lock._waiters:
                 del lock_table[query_req]
 
+        recompiled_cache = None
+        if (
+            not self.in_tx()
+            or len(query_unit_group) > 0
+            and query_unit_group[0].tx_commit
+        ):
+            # Recompile all cached queries if:
+            #  * Issued a DDL or committing a tx with DDL (recompilation
+            #    before in-tx DDL needs to fix _in_tx_with_ddl caching first)
+            #  * Config.auto_rebuild_query_cache is turned on
+            #
+            # Ideally we should compute the proper user_schema, database_config
+            # and system_config for recompilation from server/compiler.py with
+            # proper handling of config values. For now we just use the values
+            # in the current dbview and don't support certain marginal cases.
+            user_schema = None
+            user_schema_version = None
+            for unit in query_unit_group:
+                if unit.tx_rollback:
+                    break
+                if unit.user_schema:
+                    user_schema = unit.user_schema
+                    user_schema_version = unit.user_schema_version
+            if user_schema and not self.server.config_lookup(
+                "auto_rebuild_query_cache",
+                self.get_session_config(),
+                self.get_database_config(),
+                self.get_system_config(),
+            ):
+                user_schema = None
+            if user_schema:
+                recompiled_cache = await self.recompile_cached_queries(
+                    user_schema, user_schema_version
+                )
+
         if use_metrics:
             metrics.edgeql_query_compilations.inc(
                 1.0, self.tenant.get_instance_name(), 'compiler'
@@ -1136,6 +1126,7 @@ cdef class DatabaseConnectionView:
             extra_counts=source.extra_counts(),
             extra_blobs=source.extra_blobs(),
             request=query_req,
+            recompiled_cache=recompiled_cache,
         )
 
     cdef inline _check_in_tx_error(self, query_unit_group):
diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx
index 08d96c2c9f7..e578f37974c 100644
--- a/edb/server/protocol/binary.pyx
+++ b/edb/server/protocol/binary.pyx
@@ -838,7 +838,8 @@ cdef class EdgeConnection(frontend.FrontendConnection):
         if len(units) == 1 and units[0].cache_sql:
             conn = await self.get_pgcon()
             try:
-                await execute.persist_cache(conn, _dbview, [(query_req, units)])
+                g = execute.build_cache_persistence_units([(query_req, units)])
+                await g.execute(conn, _dbview)
             finally:
                 self.maybe_release_pgcon(conn)
 
diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx
index f35584da29c..02f491bdde1 100644
--- a/edb/server/protocol/execute.pyx
+++ b/edb/server/protocol/execute.pyx
@@ -92,7 +92,7 @@ cdef class ExecutionGroup:
         if state is not None:
             await be_conn.wait_for_state_resp(state, state_sync=0)
         for i, unit in enumerate(self.group):
-            if unit.output_format == FMT_NONE:
+            if unit.output_format == FMT_NONE and unit.ddl_stmt_id is None:
                 for sql in unit.sql:
                     await be_conn.wait_for_command(
                         unit, parse_array[i], dbver, ignore_data=True
@@ -108,7 +108,7 @@ cdef class ExecutionGroup:
         return rv
 
 
-cdef ExecutionGroup build_cache_persistence_units(
+cpdef ExecutionGroup build_cache_persistence_units(
     pairs: list[tuple[rpc.CompilationRequest, compiler.QueryUnitGroup]],
     ExecutionGroup group = None,
 ):
@@ -118,8 +118,7 @@ cdef ExecutionGroup build_cache_persistence_units(
         INSERT INTO "edgedb"."_query_cache"
         ("key", "schema_version", "input", "output", "evict")
         VALUES ($1, $2, $3, $4, $5)
-        ON CONFLICT (key) DO UPDATE SET
-            "schema_version"=$2, "input"=$3, "output"=$4, "evict"=$5
+        ON CONFLICT (key) DO NOTHING
     '''
     sql_hash = hashlib.sha1(insert_sql).hexdigest().encode('latin1')
     for request, units in pairs:
@@ -152,35 +151,6 @@ cdef ExecutionGroup build_cache_persistence_units(
     return group
 
 
-async def persist_cache(
-    be_conn: pgcon.PGConnection,
-    dbv: dbview.DatabaseConnectionView,
-    pairs: list[tuple[rpc.CompilationRequest, compiler.QueryUnitGroup]],
-):
-    cdef group = build_cache_persistence_units(pairs)
-
-    try:
-        await group.execute(be_conn, dbv)
-    except Exception as ex:
-        if (
-            isinstance(ex, pgerror.BackendError)
-            and ex.code_is(pgerror.ERROR_SERIALIZATION_FAILURE)
-            # If we are in a transaction, we have to let the error
-            # propagate, since we can't do anything else after the error.
-            # Hopefully the client will retry, hit the cache, and
-            # everything will be fine.
-            and not dbv.in_tx()
-        ):
-            # XXX: Is it OK to just ignore it? Can we rely on the conflict
-            # having set it to the same thing?
-            pass
-        else:
-            dbv.on_error()
-            raise
-    else:
-        signal_query_cache_changes(dbv)
-
-
 # TODO: can we merge execute and execute_script?
 async def execute(
     be_conn: pgcon.PGConnection,
@@ -195,6 +165,7 @@ async def execute(
         bytes state = None, orig_state = None
         WriteBuffer bound_args_buf
        ExecutionGroup group
+        bint persist_cache, persist_recompiled_query_cache
 
     query_unit = compiled.query_unit_group[0]
 
@@ -205,6 +176,16 @@ async def execute(
     server = dbv.server
     tenant = dbv.tenant
 
+    # If we have both the compilation request and a pair of SQLs for the cache
+    # (persist, evict), we should follow the persistent cache route.
+    persist_cache = bool(compiled.request and query_unit.cache_sql)
+
+    # Recompilation is a standalone feature, separate from the persistent
+    # cache. This flag indicates that both features are in use, and that we
+    # actually have a recompiled query cache to persist.
+    persist_recompiled_query_cache = bool(
+        debug.flags.persistent_cache and compiled.recompiled_cache)
+
     data = None
 
     try:
@@ -231,7 +212,24 @@ async def execute(
 
         if query_unit.sql:
             if query_unit.user_schema:
-                ddl_ret = await be_conn.run_ddl(query_unit, state)
+                if persist_recompiled_query_cache:
+                    # If we have recompiled the query cache, write it back to
+                    # the cache table here in an implicit transaction (if not
+                    # in one already), so that whenever the transaction
+                    # commits, we flip to using the new cache at once.
+                    group = build_cache_persistence_units(
+                        compiled.recompiled_cache)
+                    group.append(query_unit)
+                    if query_unit.ddl_stmt_id is None:
+                        await group.execute(be_conn, dbv)
+                        ddl_ret = None
+                    else:
+                        ddl_ret = be_conn.load_ddl_return(
+                            query_unit,
+                            await group.execute(be_conn, dbv, state=state),
+                        )
+                else:
+                    ddl_ret = await be_conn.run_ddl(query_unit, state)
                 if ddl_ret and ddl_ret['new_types']:
                     new_types = ddl_ret['new_types']
             else:
@@ -245,7 +243,10 @@ async def execute(
                 read_data = (
                     query_unit.needs_readback or query_unit.is_explain)
 
-                if compiled.request and query_unit.cache_sql:
+                if persist_cache:
+                    # Persisting the cache needs to happen before the actual
+                    # query, because the query may depend on the function
+                    # created while persisting the cache entry.
                     group = build_cache_persistence_units(
                         [(compiled.request, compiled.query_unit_group)]
                     )
@@ -358,28 +359,11 @@ async def execute(
                 # 1. An orphan ROLLBACK command without a pairing start tx
                 # 2. There was no SQL, so the state can't have been synced.
                 be_conn.last_state = state
-        if (
-            debug.flags.persistent_cache
-            and not dbv.in_tx()
-            and not query_unit.tx_rollback
-            and query_unit.user_schema
-            and server.config_lookup(
-                "auto_rebuild_query_cache",
-                dbv.get_session_config(),
-                dbv.get_database_config(),
-                dbv.get_system_config(),
-            )
-        ):
-            # TODO(fantix): recompile first and update cache in tx
-            if debug.flags.func_cache:
-                recompile_requests = await dbv.clear_cache_keys(be_conn)
-            else:
-                recompile_requests = [
-                    req
-                    for req, grp in dbv._db._eql_to_compiled.items()
-                    if len(grp) == 1
-                ]
-            await dbv.recompile_all(be_conn, recompile_requests)
+        if compiled.recompiled_cache:
+            for req, qu_group in compiled.recompiled_cache:
+                dbv.cache_compiled_query(req, qu_group)
+        if persist_cache or persist_recompiled_query_cache:
+            signal_query_cache_changes(dbv)
     finally:
         if query_unit.drop_db:
             tenant.allow_database_connections(query_unit.drop_db)
@@ -564,28 +548,6 @@ async def execute_script(
                 conn.last_state = state
         if unit_group.state_serializer is not None:
             dbv.set_state_serializer(unit_group.state_serializer)
-        if (
-            debug.flags.persistent_cache
-            and not in_tx
-            and any(query_unit.user_schema for query_unit in unit_group)
-            and dbv.server.config_lookup(
-                "auto_rebuild_query_cache",
-                dbv.get_session_config(),
-                dbv.get_database_config(),
-                dbv.get_system_config(),
-            )
-        ):
-            # TODO(fantix): recompile first and update cache in tx
-            if debug.flags.func_cache:
-                recompile_requests = await dbv.clear_cache_keys(conn)
-            else:
-                recompile_requests = [
-                    req
-                    for req, grp in dbv._db._eql_to_compiled.items()
-                    if len(grp) == 1
-                ]
-
-            await dbv.recompile_all(conn, recompile_requests)
     finally:
         if sent and not sync:
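
Note on the resulting flow: the net effect of this patch is an ordering change. Cached queries are recompiled up front (while the CompiledQuery is being built), the cache-writeback INSERTs ride in the same implicit transaction as the DDL statement itself, and the in-memory cache is swapped only after execution succeeds. Below is a minimal standalone sketch of that ordering; the names (recompile_one, execute_ddl_with_writeback, the CACHE dict) are hypothetical stand-ins for illustration, not the real dbview/ExecutionGroup API:

import asyncio
import copy

# Hypothetical in-memory query cache: request key -> compiled unit.
CACHE: dict[str, str] = {"q1": "old-unit-1", "q2": "old-unit-2"}


async def recompile_one(sem: asyncio.Semaphore, key: str) -> tuple[str, str] | None:
    # Step 1: recompile a single cached query under a concurrency limit,
    # mirroring recompile_cached_queries() in the patch.
    async with sem:
        try:
            # The real code copies the CompilationRequest and pins it to the
            # new schema version before calling the compiler pool.
            req = copy.copy(key)
            return req, f"new-unit({key})"
        except Exception:
            # Entries that fail to recompile are skipped, not fatal.
            return None


async def recompile_cached_queries() -> list[tuple[str, str]]:
    sem = asyncio.Semaphore(4)  # ~half the compiler pool size in the patch
    results = await asyncio.gather(*(recompile_one(sem, k) for k in CACHE))
    return [r for r in results if r is not None]


async def execute_ddl_with_writeback(
    ddl: str, recompiled: list[tuple[str, str]]
) -> None:
    # Step 2: the cache INSERTs and the DDL statement run in one implicit
    # transaction (build_cache_persistence_units + group.append(query_unit)),
    # so the persistent cache flips atomically with the schema change.
    tx = [
        f"INSERT INTO _query_cache ... ON CONFLICT (key) DO NOTHING  -- {key}"
        for key, _units in recompiled
    ]
    tx.append(ddl)
    for stmt in tx:  # stand-in for ExecutionGroup.execute()
        print("tx>", stmt)
    # Step 3: only after successful execution do we swap the in-memory cache.
    CACHE.update(recompiled)


async def main() -> None:
    recompiled = await recompile_cached_queries()
    await execute_ddl_with_writeback("ALTER TYPE Foo ...", recompiled)
    print("cache entries:", sorted(CACHE))


asyncio.run(main())

Because the INSERTs use ON CONFLICT (key) DO NOTHING and run atomically with the DDL, a concurrent writer can no longer leave the cache table half-updated, and a rollback discards the recompiled entries together with the schema change.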