From 1c256eb05bfd7ad94e512736f084a1901f0ff9b7 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 31 Jan 2024 16:47:50 -0500 Subject: [PATCH] Recompile all cache entries on DDL --- edb/pgsql/metaschema.py | 30 ++++++++++++++ edb/server/compiler_pool/pool.py | 13 ++++++ edb/server/dbview/dbview.pyx | 70 +++++++++++++++++++++++++++++++- edb/server/protocol/execute.pyx | 32 +++++++++++++-- 4 files changed, 140 insertions(+), 5 deletions(-) diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index a422fa851aa..2abce36f5fe 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -206,6 +206,35 @@ def __init__(self) -> None: ) +class ClearQueryCacheFunction(dbops.Function): + + # TODO(fantix): this may consume a lot of memory in Postgres + text = f''' + DECLARE + row record; + BEGIN + FOR row IN + DELETE FROM "edgedb"."_query_cache" + RETURNING "input", "evict" + LOOP + EXECUTE row."evict"; + RETURN NEXT row."input"; + END LOOP; + END; + ''' + + def __init__(self) -> None: + super().__init__( + name=('edgedb', '_clear_query_cache'), + args=[], + returns=('bytea',), + set_returning=True, + language='plpgsql', + volatility='volatile', + text=self.text, + ) + + class BigintDomain(dbops.Domain): """Bigint: a variant of numeric that enforces zero digits after the dot. @@ -4437,6 +4466,7 @@ async def bootstrap( dbops.CreateTable(QueryCacheTable()), dbops.Query(DMLDummyTable.SETUP_QUERY), dbops.CreateFunction(EvictQueryCacheFunction()), + dbops.CreateFunction(ClearQueryCacheFunction()), dbops.CreateFunction(UuidGenerateV1mcFunction('edgedbext')), dbops.CreateFunction(UuidGenerateV4Function('edgedbext')), dbops.CreateFunction(UuidGenerateV5Function('edgedbext')), diff --git a/edb/server/compiler_pool/pool.py b/edb/server/compiler_pool/pool.py index be70d9880d5..3c4b23f2ccc 100644 --- a/edb/server/compiler_pool/pool.py +++ b/edb/server/compiler_pool/pool.py @@ -670,6 +670,9 @@ async def analyze_explain_output( def get_debug_info(self): return {} + def get_size_hint(self) -> int: + raise NotImplementedError + class BaseLocalPool( AbstractPool, amsg.ServerProtocol, asyncio.SubprocessProtocol @@ -948,6 +951,9 @@ async def _stop(self): await trans._wait() trans.close() + def get_size_hint(self) -> int: + return self._pool_size + @srvargs.CompilerPoolMode.OnDemand.assign_implementation class SimpleAdaptivePool(BaseLocalPool): @@ -1071,6 +1077,9 @@ def _scale_down(self): )[:-self._pool_size]: worker.close() + def get_size_hint(self) -> int: + return self._max_num_workers + class RemoteWorker(BaseWorker): def __init__(self, con, secret, *args): @@ -1098,6 +1107,7 @@ def __init__(self, *, address, pool_size, **kwargs): self._worker = None self._sync_lock = asyncio.Lock() self._semaphore = asyncio.BoundedSemaphore(pool_size) + self._pool_size = pool_size secret = os.environ.get("_EDGEDB_SERVER_COMPILER_POOL_SECRET") if not secret: raise AssertionError( @@ -1249,6 +1259,9 @@ def get_debug_info(self): free=self._semaphore._value, # type: ignore ) + def get_size_hint(self) -> int: + return self._pool_size + @dataclasses.dataclass class TenantSchema: diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 309c727b274..57ea4ce7491 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -34,12 +34,13 @@ import weakref import immutables from edb import errors -from edb.common import lru, uuidgen +from edb.common import lru, taskgroup, uuidgen from edb import edgeql from edb.edgeql import qltypes from edb.schema import schema as s_schema from edb.server import compiler, defines, config, metrics from edb.server.compiler import dbstate, enums, sertypes +from edb.server.protocol import execute from edb.pgsql import dbops from edb.server.compiler_pool import state as compiler_state_mod @@ -914,6 +915,73 @@ cdef class DatabaseConnectionView: self._reset_tx_state() return side_effects + async def clear_cache_keys(self, conn) -> list[bytes]: + rows = await conn.sql_fetch(b'SELECT "edgedb"."_clear_query_cache"()') + self._db._query_cache.clear() + return [row[0] for row in rows or []] + + async def recompile_all(self, conn, requests: typing.Iterable[bytes]): + compiler_pool = self.server.get_compiler_pool() + concurrency = max(1, compiler_pool.get_size_hint() - 1) + i = asyncio.Queue(maxsize=concurrency) + o = asyncio.Queue() + + async def recompile_request(): + while True: + request = await i.get() + if request is None: + o.put_nowait((None, None)) + break + try: + result = await compiler_pool.compile( + self.dbname, + self.get_user_schema_pickle(), + self.get_global_schema_pickle(), + self.reflection_cache, + self.get_database_config(), + self.get_compilation_system_config(), + request, + client_id=self.tenant.client_id, + ) + except Exception: + # discard cache entry that cannot be recompiled + pass + else: + o.put_nowait((request, result[0])) + + async def persist_cache(): + count = concurrency + while count > 0: + request, response = await o.get() + if request is None: + count -= 1 + else: + query_unit_group = pickle.loads(response) + assert len(query_unit_group) == 1 + query_unit = query_unit_group[0] + key = query_unit_group.cache_key + assert key is not None + await execute.persist_cache_spec( + conn, + self, + query_unit, + request, + response, + key, + ) + self._db._query_cache[key] = ( + query_unit_group, self.schema_version + ) + + async with taskgroup.TaskGroup() as g: + for _ in range(concurrency): + g.create_task(recompile_request()) + g.create_task(persist_cache()) + for data in requests: + await i.put(data) + for _ in range(concurrency): + await i.put(None) + async def apply_config_ops(self, conn, ops): settings = self.get_config_spec() diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index ab4ec9a2877..577efbb65e1 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from typing import ( Any, Mapping, @@ -64,6 +63,23 @@ async def persist_cache( ): assert len(compiled.query_unit_group) == 1 query_unit = compiled.query_unit_group[0] + return await persist_cache_spec( + be_conn, + dbv, + query_unit, + compiled.request.serialize(), + compiled.serialized, + compiled.request.get_cache_key(), + ) + +async def persist_cache_spec( + be_conn: pgcon.PGConnection, + dbv: dbview.DatabaseConnectionView, + query_unit, + request, + response, + cache_key, +): persist, evict = query_unit.cache_sql await be_conn.sql_execute((evict, persist)) await be_conn.sql_fetch( @@ -73,10 +89,10 @@ async def persist_cache( b'ON CONFLICT (key) DO UPDATE SET ' b'"schema_version"=$2, "input"=$3, "output"=$4, "evict"=$5', args=( - compiled.request.get_cache_key().bytes, + cache_key.bytes, dbv.schema_version.bytes, - compiled.request.serialize(), - compiled.serialized, + request, + response, evict, ), use_prep_stmt=True, @@ -114,6 +130,7 @@ async def execute( new_types = None server = dbv.server tenant = dbv.tenant + recompile_requests = None data = None @@ -140,7 +157,12 @@ async def execute( await persist_cache(be_conn, dbv, compiled) if query_unit.sql: + if query_unit.has_ddl: + # TODO(fantix): do this in the same transaction + recompile_requests = await dbv.clear_cache_keys(be_conn) if query_unit.ddl_stmt_id: + await be_conn.sql_execute( + b'delete from "edgedb"."_query_cache"') ddl_ret = await be_conn.run_ddl(query_unit, state) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] @@ -230,6 +252,8 @@ async def execute( # 1. An orphan ROLLBACK command without a paring start tx # 2. There was no SQL, so the state can't have been synced. be_conn.last_state = state + if recompile_requests: + await dbv.recompile_all(be_conn, recompile_requests) finally: if query_unit.drop_db: tenant.allow_database_connections(query_unit.drop_db)