diff --git a/Makefile b/Makefile index 7024cfd9d20..8ad40c159b9 100644 --- a/Makefile +++ b/Makefile @@ -15,6 +15,13 @@ cython: build-reqs BUILD_EXT_MODE=py-only python setup.py build_ext --inplace +# Just rebuild actually changed cython. This *should* work, since +# that is how build systems are supposed to be, but it sometimes +# fails in annoying ways. +cython-fast: build-reqs + BUILD_EXT_MODE=py-only python setup.py build_ext --inplace + + rust: build-reqs BUILD_EXT_MODE=rust-only python setup.py build_ext --inplace diff --git a/edb/common/asyncutil.py b/edb/common/asyncutil.py index 19eb55ba565..7d439a6172c 100644 --- a/edb/common/asyncutil.py +++ b/edb/common/asyncutil.py @@ -18,7 +18,7 @@ from __future__ import annotations -from typing import TypeVar, Awaitable +from typing import Callable, TypeVar, Awaitable import asyncio @@ -62,3 +62,81 @@ async def deferred_shield(arg: Awaitable[_T]) -> _T: if ex: raise ex return task.result() + + +async def debounce( + input: Callable[[], Awaitable[_T]], + output: Callable[[list[_T]], Awaitable[None]], + *, + max_wait: float, + delay_amt: float, + max_batch_size: int, +) -> None: + '''Debounce and batch async events. + + Loops forever unless an operation fails, so should probably be run + from a task. + + The basic algorithm is that if an event comes in less than + `delay_amt` since the previous one, then instead of sending it + immediately, we wait an additional `delay_amt` from then. If we are + already waiting, any message also extends the wait, up to + `max_wait`. + + Also, cap the maximum batch size to `max_batch_size`. + ''' + # I think the algorithm reads more clearly with the params + # capitalized as constants, though we don't want them like that in + # the argument list, so reassign them. + MAX_WAIT, DELAY_AMT, MAX_BATCH_SIZE = max_wait, delay_amt, max_batch_size + + loop = asyncio.get_running_loop() + + batch = [] + last_signal = -MAX_WAIT + target_time = None + + while True: + try: + if target_time is None: + v = await input() + else: + async with asyncio.timeout_at(target_time): + v = await input() + except TimeoutError: + t = loop.time() + else: + batch.append(v) + + t = loop.time() + + # If we aren't current waiting, and we got a + # notification recently, arrange to wait some before + # sending it. + if ( + target_time is None + and t - last_signal < DELAY_AMT + ): + target_time = t + DELAY_AMT + # If we were already waiting, wait a little longer, though + # not longer than MAX_WAIT. + elif ( + target_time is not None + ): + target_time = min( + max(t + DELAY_AMT, target_time), + last_signal + MAX_WAIT, + ) + + # Skip sending the event if we need to wait longer. 
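+        # (If the batch has already reached MAX_BATCH_SIZE, fall through
+        # and flush it now rather than waiting out the timer.)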
+ if ( + target_time is not None + and t < target_time + and len(batch) < MAX_BATCH_SIZE + ): + continue + + await output(batch) + batch = [] + last_signal = t + target_time = None diff --git a/edb/common/debug.py b/edb/common/debug.py index 6d9ed365dd6..051959e656c 100644 --- a/edb/common/debug.py +++ b/edb/common/debug.py @@ -176,7 +176,7 @@ class flags(metaclass=FlagsMeta): zombodb = Flag(doc="Enabled zombodb and disables postgres FTS") - persistent_cache = Flag(doc="Use persistent cache") + disable_persistent_cache = Flag(doc="Don't use persistent cache") # Function cache is an experimental feature that may not fully work func_cache = Flag(doc="Use stored functions for persistent cache") diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 327bedd3ebc..3d9109b225a 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -870,7 +870,7 @@ def compile_request( request.protocol_version, request.inline_objectids, request.json_parameters, - persistent_cache=bool(debug.flags.persistent_cache), + persistent_cache=not debug.flags.disable_persistent_cache, cache_key=request.get_cache_key(), ) return units, cstate @@ -976,7 +976,7 @@ def compile_in_tx_request( request.inline_objectids, request.json_parameters, expect_rollback=expect_rollback, - persistent_cache=bool(debug.flags.persistent_cache), + persistent_cache=not debug.flags.disable_persistent_cache, cache_key=request.get_cache_key(), ) return units, cstate diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index b627337ad1a..cf184ff36c8 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -398,6 +398,8 @@ class QueryUnitGroup: state_serializer: Optional[sertypes.StateSerializer] = None + cache_state: int = 0 + @property def units(self) -> List[QueryUnit]: if self._unpacked_units is None: diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index 2008f0a4ba5..d85696fb864 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -73,6 +73,11 @@ cdef class Database: object _state_serializers readonly object user_config_spec + object _cache_worker_task + object _cache_queue + object _cache_notify_task + object _cache_notify_queue + readonly str name readonly object schema_version readonly object dbver diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 669ecf9e21b..2afe8417dfc 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -35,7 +35,7 @@ import weakref import immutables from edb import errors -from edb.common import debug, lru, uuidgen +from edb.common import debug, lru, uuidgen, asyncutil from edb import edgeql from edb.edgeql import qltypes from edb.schema import schema as s_schema @@ -76,6 +76,13 @@ cdef next_dbver(): return VER_COUNTER + +cdef enum CacheState: + Pending = 0, + Present, + Evicted + + @cython.final cdef class CompiledQuery: @@ -140,6 +147,16 @@ cdef class Database: self.extensions = extensions self._observe_auth_ext_config() + self._cache_worker_task = self._cache_queue = None + self._cache_notify_task = self._cache_notify_queue = None + if not debug.flags.disable_persistent_cache: + self._cache_queue = asyncio.Queue() + self._cache_worker_task = asyncio.create_task( + self.monitor(self.cache_worker, 'cache_worker')) + self._cache_notify_queue = asyncio.Queue() + self._cache_notify_task = asyncio.create_task( + self.monitor(self.cache_notifier, 'cache_notifier')) + @property def server(self): return 
self._index._server @@ -148,6 +165,83 @@ cdef class Database: def tenant(self): return self._index._tenant + def stop(self): + if self._cache_worker_task: + self._cache_worker_task.cancel() + self._cache_worker_task = None + if self._cache_notify_task: + self._cache_notify_task.cancel() + self._cache_notify_task = None + + async def monitor(self, worker, name): + while True: + try: + await worker() + except Exception as ex: + debug.dump(ex) + metrics.background_errors.inc( + 1.0, self.tenant._instance_name, name + ) + # Give things time to recover, since the likely + # failure mode here is a failover or some such. + await asyncio.sleep(0.1) + + async def cache_worker(self): + while True: + # First, handle any evictions + keys = [] + while self._eql_to_compiled.needs_cleanup(): + query_req, unit_group = self._eql_to_compiled.cleanup_one() + if len(unit_group) == 1 and unit_group.cache_state == 1: + keys.append(query_req.get_cache_key()) + unit_group.cache_state = CacheState.Evicted + if keys: + await self.tenant.evict_query_cache(self.name, keys) + + # Now, populate the cache + # Empty the queue, for batching reasons. + # N.B: This empty/get_nowait loop is safe because this is + # an asyncio Queue. If it was threaded, it would be racy. + ops = [await self._cache_queue.get()] + while not self._cache_queue.empty(): + ops.append(self._cache_queue.get_nowait()) + # Filter ops for only what we need + ops = [ + (query_req, units) for query_req, units in ops + if len(units) == 1 + and units[0].cache_sql + and units.cache_state == CacheState.Pending + ] + if not ops: + continue + + # TODO: Should we do any sort of error handling here? + g = execute.build_cache_persistence_units(ops) + conn = await self.tenant.acquire_pgcon(self.name) + try: + await g.execute(conn, self) + finally: + self.tenant.release_pgcon(self.name, conn) + + for _, units in ops: + units.cache_state = CacheState.Present + self._cache_notify_queue.put_nowait(str(units[0].cache_key)) + + async def cache_notifier(self): + await asyncutil.debounce( + lambda: self._cache_notify_queue.get(), + lambda keys: self.tenant.signal_sysevent( + 'query-cache-changes', + dbname=self.name, + keys=keys, + ), + max_wait=1.0, + delay_amt=0.2, + # 100 keys will take up about 4000 bytes, which + # fits in the 8000 allowed in events. 
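+            # (Each UUID key serializes to roughly 40 bytes of JSON,
+            # hence ~4000 bytes for a full batch of 100.)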
+ max_batch_size=100, + ) + cdef schedule_config_update(self): self._index._tenant.on_local_database_config_change(self.name) @@ -214,17 +308,9 @@ cdef class Database: return self._eql_to_compiled[key] = compiled - # TODO(fantix): merge in-memory cleanup into the task below - keys = [] - while self._eql_to_compiled.needs_cleanup(): - query_req, unit_group = self._eql_to_compiled.cleanup_one() - if len(unit_group) == 1: - keys.append(query_req.get_cache_key()) - if keys and debug.flags.persistent_cache: - self.tenant.create_task( - self.tenant.evict_query_cache(self.name, keys), - interruptable=True, - ) + + if self._cache_queue is not None: + self._cache_queue.put_nowait((key, compiled)) def cache_compiled_sql(self, key, compiled: list[str], schema_version): existing, ver = self._sql_to_compiled.get(key, DICTDEFAULT) @@ -267,23 +353,18 @@ cdef class Database: return old_serializer def hydrate_cache(self, query_cache): - new = set() for _, in_data, out_data in query_cache: query_req = rpc.CompilationRequest( self.server.compilation_config_serializer) query_req.deserialize(in_data, "") - new.add(query_req) if query_req not in self._eql_to_compiled: unit = dbstate.QueryUnit.deserialize(out_data) group = dbstate.QueryUnitGroup() group.append(unit) + group.cache_state = CacheState.Present self._eql_to_compiled[query_req] = group - for query_req in list(self._eql_to_compiled): - if query_req not in new: - del self._eql_to_compiled[query_req] - def iter_views(self): yield from self._views @@ -295,7 +376,8 @@ cdef class Database: async with self._introspection_lock: if self.user_schema_pickle is None: await self.tenant.introspect_db( - self.name, hydrate_cache=debug.flags.persistent_cache + self.name, + hydrate_cache=not debug.flags.disable_persistent_cache, ) @@ -731,8 +813,7 @@ cdef class DatabaseConnectionView: cdef cache_compiled_query(self, object key, object query_unit_group): assert query_unit_group.cacheable - if not self._in_tx_with_ddl: - self._db._cache_compiled_query(key, query_unit_group) + self._db._cache_compiled_query(key, query_unit_group) cdef lookup_compiled_query(self, object key): if (self._tx_error or @@ -1381,7 +1462,8 @@ cdef class DatabaseIndex: return db def unregister_db(self, dbname): - self._dbs.pop(dbname) + db = self._dbs.pop(dbname) + db.stop() self.set_current_branches() cdef inline set_current_branches(self): diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index c96e8bad8e0..8a0126549af 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -2932,7 +2932,8 @@ cdef class PGConnection: self.tenant.on_remote_database_quarantine(dbname) elif event == 'query-cache-changes': dbname = event_payload['dbname'] - self.tenant.on_remote_query_cache_change(dbname) + keys = event_payload.get('keys') + self.tenant.on_remote_query_cache_change(dbname, keys=keys) else: raise AssertionError(f'unexpected system event: {event!r}') diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 689cec91b4e..61af5e736ae 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -839,14 +839,6 @@ cdef class EdgeConnection(frontend.FrontendConnection): await _dbview.reload_state_serializer() query_req, allow_capabilities = self.parse_execute_request() compiled = await self._parse(query_req, allow_capabilities) - units = compiled.query_unit_group - if len(units) == 1 and units[0].cache_sql: - conn = await self.get_pgcon() - try: - g = execute.build_cache_persistence_units([(query_req, units)]) - await 
g.execute(conn, _dbview) - finally: - self.maybe_release_pgcon(conn) buf = self.make_command_data_description_msg(compiled) diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 78baf73034e..b0ae8e469d9 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -70,7 +70,7 @@ cdef class ExecutionGroup: async def execute( self, pgcon.PGConnection be_conn, - dbview.DatabaseConnectionView dbv, + object dbv, # can be DatabaseConnectionView or Database fe_conn: frontend.AbstractFrontendConnection = None, bytes state = None, ): @@ -166,7 +166,6 @@ async def execute( bytes state = None, orig_state = None WriteBuffer bound_args_buf ExecutionGroup group - bint persist_cache, persist_recompiled_query_cache query_unit = compiled.query_unit_group[0] @@ -177,16 +176,6 @@ async def execute( server = dbv.server tenant = dbv.tenant - # If we have both the compilation request and a pair of SQLs for the cache - # (persist, evict), we should follow the persistent cache route. - persist_cache = bool(compiled.request and query_unit.cache_sql) - - # Recompilation is a standalone feature than persistent cache. - # This flag indicates both features are in use, and we actually have - # recompiled the query cache to persist. - persist_recompiled_query_cache = bool( - debug.flags.persistent_cache and compiled.recompiled_cache) - data = None try: @@ -213,24 +202,7 @@ async def execute( if query_unit.sql: if query_unit.user_schema: - if persist_recompiled_query_cache: - # If we have recompiled the query cache, writeback to - # the cache table here in an implicit transaction (if - # not in one already), so that whenever the transaction - # commits, we flip to using the new cache at once. - group = build_cache_persistence_units( - compiled.recompiled_cache) - group.append(query_unit) - if query_unit.ddl_stmt_id is None: - await group.execute(be_conn, dbv) - ddl_ret = None - else: - ddl_ret = be_conn.load_ddl_return( - query_unit, - await group.execute(be_conn, dbv, state=state), - ) - else: - ddl_ret = await be_conn.run_ddl(query_unit, state) + ddl_ret = await be_conn.run_ddl(query_unit, state) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] else: @@ -244,31 +216,14 @@ async def execute( read_data = ( query_unit.needs_readback or query_unit.is_explain) - if persist_cache: - # Persistent cache needs to happen before the actual - # query because the query may depend on the function - # created persisting the cache entry. 
- group = build_cache_persistence_units( - [(compiled.request, compiled.query_unit_group)] - ) - if not use_prep_stmt: - query_unit.sql_hash = b'' - group.append(query_unit, bound_args_buf) - data = await group.execute( - be_conn, - dbv, - fe_conn=fe_conn if not read_data else None, - state=state, - ) - else: - data = await be_conn.parse_execute( - query=query_unit, - fe_conn=fe_conn if not read_data else None, - bind_data=bound_args_buf, - use_prep_stmt=use_prep_stmt, - state=state, - dbver=dbv.dbver, - ) + data = await be_conn.parse_execute( + query=query_unit, + fe_conn=fe_conn if not read_data else None, + bind_data=bound_args_buf, + use_prep_stmt=use_prep_stmt, + state=state, + dbver=dbv.dbver, + ) if query_unit.needs_readback and data: config_ops = [ @@ -363,8 +318,6 @@ async def execute( if compiled.recompiled_cache: for req, qu_group in compiled.recompiled_cache: dbv.cache_compiled_query(req, qu_group) - if persist_cache or persist_recompiled_query_cache: - signal_query_cache_changes(dbv) finally: if query_unit.drop_db: tenant.allow_database_connections(query_unit.drop_db) @@ -648,26 +601,6 @@ def signal_side_effects(dbv, side_effects): ) -def signal_query_cache_changes(dbv): - # FIXME: This is disabled because the increased sysevent traffic - # caused by doing this was causing test failures on aarch64 when - # sometimes the signals failed due to a failure to look up - # transactions. We need to figure out what is going on with that - # and restore it. We also probably want to rate limit - # query-cache-changes, or include a more detailed payload, since - # it can force pretty aggressive amounts of cache reloading work - # on the targets. - - # dbv.tenant.create_task( - # dbv.tenant.signal_sysevent( - # 'query-cache-changes', - # dbname=dbv.dbname, - # ), - # interruptable=False, - # ) - pass - - async def parse_execute_json( db: dbview.Database, query: str, diff --git a/edb/server/tenant.py b/edb/server/tenant.py index bd651263ec0..ae5935c1bfd 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -28,6 +28,7 @@ AsyncGenerator, Dict, Set, + Optional, TypedDict, TYPE_CHECKING, ) @@ -1480,16 +1481,35 @@ async def task(): self.create_task(task(), interruptable=True) async def _load_query_cache( - self, conn: pgcon.PGConnection + self, + conn: pgcon.PGConnection, + keys: Optional[Iterable[uuid.UUID]] = None, ) -> list[tuple[bytes, ...]] | None: - return await conn.sql_fetch( - b'SELECT "schema_version", "input", "output" ' - b'FROM "edgedb"."_query_cache"', - use_prep_stmt=True, - ) + if keys is None: + return await conn.sql_fetch( + b''' + SELECT "schema_version", "input", "output" + FROM "edgedb"."_query_cache" + ''', + use_prep_stmt=True, + ) + else: + # If keys were specified, just load those keys. + # TODO: Or should we do something time based? + return await conn.sql_fetch( + b''' + SELECT "schema_version", "input", "output" + ROWS FROM json_array_elements($1) j(ikey) + INNER JOIN "edgedb"."_query_cache" + ON (to_jsonb(ARRAY[ikey])->>0)::uuid = key + ''', + args=(json.dumps(keys).encode('utf-8'),), + use_prep_stmt=True, + ) async def evict_query_cache( - self, dbname: str, + self, + dbname: str, keys: Iterable[uuid.UUID], ) -> None: try: @@ -1507,7 +1527,10 @@ async def evict_query_cache( finally: self.release_pgcon(dbname, conn) - await self.signal_sysevent("query-cache-changes", dbname=dbname) + # XXX: TODO: We don't need to signal here in the + # non-function version, but in the function caching + # situation this will be fraught. 
+ # await self.signal_sysevent("query-cache-changes", dbname=dbname) except Exception: logger.exception("error in evict_query_cache():") @@ -1515,7 +1538,9 @@ async def evict_query_cache( 1.0, self._instance_name, "evict_query_cache" ) - def on_remote_query_cache_change(self, dbname: str) -> None: + def on_remote_query_cache_change( + self, dbname: str, keys: Optional[list[str]], + ) -> None: if not self._accept_new_tasks: return @@ -1526,7 +1551,7 @@ async def task(): return try: - query_cache = await self._load_query_cache(conn) + query_cache = await self._load_query_cache(conn, keys=keys) finally: self.release_pgcon(dbname, conn) diff --git a/pyproject.toml b/pyproject.toml index e00c4988278..786d2a3d909 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ test = [ 'ruff~=0.1.6', 'asyncpg~=0.29.0', + # Needed for testing asyncutil + 'async_solipsism==0.5.0', + # Needed for test_docs_sphinx_ext 'requests-xml~=0.2.3', diff --git a/tests/common/test_asyncutil.py b/tests/common/test_asyncutil.py new file mode 100644 index 00000000000..b3f24c8774d --- /dev/null +++ b/tests/common/test_asyncutil.py @@ -0,0 +1,121 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import asyncio +import unittest + +from edb.common import asyncutil + +try: + import async_solipsism +except ImportError: + async_solipsism = None # type: ignore + + +def with_fake_event_loop(f): + # async_solpsism creates an event loop with, among other things, + # a totally fake clock. 
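+    # Its clock only advances when every task is blocked waiting, so the
+    # sleeps in these tests complete instantly and deterministically.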
+ def new(*args, **kwargs): + loop = async_solipsism.EventLoop() + try: + loop.run_until_complete(f(*args, **kwargs)) + finally: + loop.close() + + return new + + +@unittest.skipIf(async_solipsism is None, 'async_solipsism is missing') +class TestDebounce(unittest.TestCase): + + @with_fake_event_loop + async def test_debounce_01(self): + loop = asyncio.get_running_loop() + outs = [] + ins = asyncio.Queue() + + async def output(vs): + assert loop.time() == int(loop.time()) + outs.append((int(loop.time()), vs)) + + async def sleep_until(t): + await asyncio.sleep(t - loop.time()) + + task = asyncio.create_task(asyncutil.debounce( + ins.get, + output, + # Use integers for delays to avoid any possibility of + # floating point nonsense + max_wait=500, + delay_amt=200, + max_batch_size=4, + )) + + ins.put_nowait(1) + await sleep_until(10) + ins.put_nowait(2) + ins.put_nowait(3) + await sleep_until(300) + ins.put_nowait(4) + ins.put_nowait(5) + ins.put_nowait(6) + await sleep_until(1000) + + # Time 1000 now + ins.put_nowait(7) + await sleep_until(1150) + ins.put_nowait(8) + ins.put_nowait(9) + ins.put_nowait(10) + await sleep_until(1250) + ins.put_nowait(11) + + ins.put_nowait(12) + await asyncio.sleep(190) + ins.put_nowait(13) + await asyncio.sleep(190) + ins.put_nowait(14) + await asyncio.sleep(190) + self.assertEqual(loop.time(), 1820) + ins.put_nowait(15) + + # Make sure everything clears out and stop it + await asyncio.sleep(10000) + task.cancel() + + self.assertEqual( + outs, + [ + # First one right away + (0, [1]), + # Next two added at 10 + 200 tick + (210, [2, 3]), + # Next three added at 300 + 200 tick + (500, [4, 5, 6]), + # First at 1000 + (1000, [7]), + # Next group at 1250 when the batch fills up + (1250, [8, 9, 10, 11]), + # And more at 1750 when time expires on that batch + (1750, [12, 13, 14]), + # And the next one (queued at 1820) at 200 after it was queued, + # since there had been a recent signal when it was queued. + (2020, [15]), + ], + )
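
For reference, a minimal sketch of driving the new `asyncutil.debounce` helper on its own, outside the dbview cache machinery; the queue, the `flush` callback, and the timings below are hypothetical stand-ins chosen only to illustrate the call pattern used by `cache_notifier` above:

import asyncio

from edb.common import asyncutil


async def main() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue()

    async def flush(keys: list[str]) -> None:
        # Stand-in for signal_sysevent('query-cache-changes', ...).
        print('flushing batch:', keys)

    # debounce() loops forever, so run it as a background task.
    task = asyncio.create_task(asyncutil.debounce(
        queue.get,
        flush,
        max_wait=1.0,
        delay_amt=0.2,
        max_batch_size=100,
    ))

    for key in ('a', 'b', 'c'):
        queue.put_nowait(key)
    # 'a' is flushed immediately; 'b' and 'c' arrive within delay_amt
    # of it, so they are batched and flushed together ~0.2s later.
    await asyncio.sleep(1)
    task.cancel()


if __name__ == '__main__':
    asyncio.run(main())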