Do persistent cache writebacks asynchronously (#7025)
Whenever an entry is added to the in-memory cache, enqueue it in a
queue that a per-db worker task consumes. The worker task manages all
cache inserts and evictions, which should fully eliminate the
serialization error dangers. This also lets us persistently cache
queries that were first compiled from within a rolled-back
transaction.
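
Roughly, the pattern looks like this (an illustrative sketch with
made-up names, not the actual dbview code):

    import asyncio

    class Db:
        def __init__(self) -> None:
            self._mem_cache: dict[object, object] = {}
            self._cache_queue: asyncio.Queue = asyncio.Queue()
            # Single per-db consumer: all persistent inserts and
            # evictions happen here, serialized, so concurrent
            # writers can't trip serialization errors.
            self._worker = asyncio.create_task(self._cache_worker())

        def cache_query(self, key, compiled) -> None:
            self._mem_cache[key] = compiled            # immediate, in-memory
            self._cache_queue.put_nowait((key, compiled))  # deferred writeback

        async def _cache_worker(self) -> None:
            while True:
                batch = [await self._cache_queue.get()]
                while not self._cache_queue.empty():   # drain for batching
                    batch.append(self._cache_queue.get_nowait())
                await self._persist(batch)

        async def _persist(self, batch) -> None:
            ...  # hypothetical: write the batch to the backing store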

Also, re-enable query-cache-notifications, but make them include the
specific keys to query, and implement a debouncing/batching algorithm
to cut down on traffic.

Unfortunately, even with query-cache-notifications off, the async caching
starts triggering the bizarre failures to look up transactions in postgres on
arm64 linux.
I'm going to put up a follow-up that disables the cache there, so we can
move ahead with testing and the release.

FUTURE NOTE: The situation for function caching, once we want to
reintroduce it, will be a little more complicated. We'll need to put
both the SQL text and the function call code in the in-memory cache,
and use the text until the function is visible everywhere (at which
point we can drop the text). We'll also need to do some thinking about
how to test it properly, because the downside of the approach is that
in the typical path, the first execution of a query can't use the
function. We may need some sort of testing path to allow us to
exercise the functions easily in the test code.
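
(Purely to make that note concrete — a hypothetical shape for such a
dual entry, not code from this commit:)

    from dataclasses import dataclass

    @dataclass
    class CachedEntry:
        sql_text: str                   # always executable
        func_call: str | None = None    # e.g. a call to the cached function
        func_visible: bool = False      # set once the function exists everywhere

        def sql(self) -> str:
            # Fall back to the raw text until the function is visible,
            # then the text can be dropped and the call used instead.
            if self.func_call is not None and self.func_visible:
                return self.func_call
            return self.sql_text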
msullivan authored Mar 13, 2024
1 parent 6394465 commit 6bbfde4
Showing 13 changed files with 371 additions and 122 deletions.
7 changes: 7 additions & 0 deletions Makefile
@@ -15,6 +15,13 @@ cython: build-reqs
	BUILD_EXT_MODE=py-only python setup.py build_ext --inplace


# Just rebuild the Cython that actually changed. This *should* work,
# since that is how build systems are supposed to work, but it
# sometimes fails in annoying ways.
cython-fast: build-reqs
	BUILD_EXT_MODE=py-only python setup.py build_ext --inplace


rust: build-reqs
	BUILD_EXT_MODE=rust-only python setup.py build_ext --inplace
80 changes: 79 additions & 1 deletion edb/common/asyncutil.py
@@ -18,7 +18,7 @@


from __future__ import annotations
-from typing import TypeVar, Awaitable
+from typing import Callable, TypeVar, Awaitable

import asyncio

@@ -62,3 +62,81 @@ async def deferred_shield(arg: Awaitable[_T]) -> _T:
    if ex:
        raise ex
    return task.result()


async def debounce(
    input: Callable[[], Awaitable[_T]],
    output: Callable[[list[_T]], Awaitable[None]],
    *,
    max_wait: float,
    delay_amt: float,
    max_batch_size: int,
) -> None:
    '''Debounce and batch async events.

    Loops forever unless an operation fails, so should probably be run
    from a task.

    The basic algorithm is that if an event comes in less than
    `delay_amt` since the previous one, then instead of sending it
    immediately, we wait an additional `delay_amt` from then. If we are
    already waiting, any message also extends the wait, up to
    `max_wait`.

    Also, cap the maximum batch size to `max_batch_size`.
    '''
    # I think the algorithm reads more clearly with the params
    # capitalized as constants, though we don't want them like that in
    # the argument list, so reassign them.
    MAX_WAIT, DELAY_AMT, MAX_BATCH_SIZE = max_wait, delay_amt, max_batch_size

    loop = asyncio.get_running_loop()

    batch = []
    last_signal = -MAX_WAIT
    target_time = None

    while True:
        try:
            if target_time is None:
                v = await input()
            else:
                async with asyncio.timeout_at(target_time):
                    v = await input()
        except TimeoutError:
            t = loop.time()
        else:
            batch.append(v)

            t = loop.time()

            # If we aren't currently waiting, and we got a
            # notification recently, arrange to wait some before
            # sending it.
            if (
                target_time is None
                and t - last_signal < DELAY_AMT
            ):
                target_time = t + DELAY_AMT
            # If we were already waiting, wait a little longer, though
            # not longer than MAX_WAIT.
            elif (
                target_time is not None
            ):
                target_time = min(
                    max(t + DELAY_AMT, target_time),
                    last_signal + MAX_WAIT,
                )

        # Skip sending the event if we need to wait longer.
        if (
            target_time is not None
            and t < target_time
            and len(batch) < MAX_BATCH_SIZE
        ):
            continue

        await output(batch)
        batch = []
        last_signal = t
        target_time = None
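
(For reference, a minimal sketch of driving debounce from an
asyncio.Queue — this mirrors how cache_notifier in dbview.pyx below
wires it up; the queue and print-based output here are stand-ins.
Note that asyncio.timeout_at requires Python 3.11+.)

import asyncio

async def demo() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue()

    async def send(batch: list[str]) -> None:
        print('flushing', batch)   # stand-in for signal_sysevent()

    task = asyncio.create_task(debounce(
        queue.get,            # input: awaits the next event
        send,                 # output: called with each batch
        max_wait=1.0,         # upper bound on how long events sit in a batch
        delay_amt=0.2,        # each new event extends the wait by 0.2s
        max_batch_size=100,   # flush immediately once 100 items queue up
    ))
    for i in range(5):
        queue.put_nowait(f'key-{i}')
    await asyncio.sleep(2)    # give the debouncer time to flush
    task.cancel()

asyncio.run(demo())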
2 changes: 1 addition & 1 deletion edb/common/debug.py
@@ -176,7 +176,7 @@ class flags(metaclass=FlagsMeta):

    zombodb = Flag(doc="Enabled zombodb and disables postgres FTS")

-    persistent_cache = Flag(doc="Use persistent cache")
+    disable_persistent_cache = Flag(doc="Don't use persistent cache")

    # Function cache is an experimental feature that may not fully work
    func_cache = Flag(doc="Use stored functions for persistent cache")
4 changes: 2 additions & 2 deletions edb/server/compiler/compiler.py
@@ -870,7 +870,7 @@ def compile_request(
        request.protocol_version,
        request.inline_objectids,
        request.json_parameters,
-        persistent_cache=bool(debug.flags.persistent_cache),
+        persistent_cache=not debug.flags.disable_persistent_cache,
        cache_key=request.get_cache_key(),
    )
    return units, cstate
return units, cstate
Expand Down Expand Up @@ -976,7 +976,7 @@ def compile_in_tx_request(
request.inline_objectids,
request.json_parameters,
expect_rollback=expect_rollback,
persistent_cache=bool(debug.flags.persistent_cache),
persistent_cache=not debug.flags.disable_persistent_cache,
cache_key=request.get_cache_key(),
)
return units, cstate
Expand Down
2 changes: 2 additions & 0 deletions edb/server/compiler/dbstate.py
@@ -398,6 +398,8 @@ class QueryUnitGroup:

state_serializer: Optional[sertypes.StateSerializer] = None

cache_state: int = 0

@property
def units(self) -> List[QueryUnit]:
if self._unpacked_units is None:
Expand Down
5 changes: 5 additions & 0 deletions edb/server/dbview/dbview.pxd
@@ -73,6 +73,11 @@ cdef class Database:
    object _state_serializers
    readonly object user_config_spec

    object _cache_worker_task
    object _cache_queue
    object _cache_notify_task
    object _cache_notify_queue

    readonly str name
    readonly object schema_version
    readonly object dbver
126 changes: 104 additions & 22 deletions edb/server/dbview/dbview.pyx
@@ -35,7 +35,7 @@ import weakref
import immutables

from edb import errors
-from edb.common import debug, lru, uuidgen
+from edb.common import debug, lru, uuidgen, asyncutil
from edb import edgeql
from edb.edgeql import qltypes
from edb.schema import schema as s_schema
@@ -76,6 +76,13 @@ cdef next_dbver():
    return VER_COUNTER


cdef enum CacheState:
    Pending = 0,
    Present,
    Evicted


@cython.final
cdef class CompiledQuery:

@@ -140,6 +147,16 @@ cdef class Database:
        self.extensions = extensions
        self._observe_auth_ext_config()

        self._cache_worker_task = self._cache_queue = None
        self._cache_notify_task = self._cache_notify_queue = None
        if not debug.flags.disable_persistent_cache:
            self._cache_queue = asyncio.Queue()
            self._cache_worker_task = asyncio.create_task(
                self.monitor(self.cache_worker, 'cache_worker'))
            self._cache_notify_queue = asyncio.Queue()
            self._cache_notify_task = asyncio.create_task(
                self.monitor(self.cache_notifier, 'cache_notifier'))
    @property
    def server(self):
        return self._index._server

@@ -148,6 +165,83 @@
    def tenant(self):
        return self._index._tenant

    def stop(self):
        if self._cache_worker_task:
            self._cache_worker_task.cancel()
            self._cache_worker_task = None
        if self._cache_notify_task:
            self._cache_notify_task.cancel()
            self._cache_notify_task = None

    async def monitor(self, worker, name):
        while True:
            try:
                await worker()
            except Exception as ex:
                debug.dump(ex)
                metrics.background_errors.inc(
                    1.0, self.tenant._instance_name, name
                )
                # Give things time to recover, since the likely
                # failure mode here is a failover or some such.
                await asyncio.sleep(0.1)

    async def cache_worker(self):
        while True:
            # First, handle any evictions
            keys = []
            while self._eql_to_compiled.needs_cleanup():
                query_req, unit_group = self._eql_to_compiled.cleanup_one()
                if len(unit_group) == 1 and unit_group.cache_state == 1:
                    keys.append(query_req.get_cache_key())
                    unit_group.cache_state = CacheState.Evicted
            if keys:
                await self.tenant.evict_query_cache(self.name, keys)

            # Now, populate the cache
            # Empty the queue, for batching reasons.
            # N.B: This empty/get_nowait loop is safe because this is
            # an asyncio Queue. If it was threaded, it would be racy.
            ops = [await self._cache_queue.get()]
            while not self._cache_queue.empty():
                ops.append(self._cache_queue.get_nowait())
            # Filter ops for only what we need
            ops = [
                (query_req, units) for query_req, units in ops
                if len(units) == 1
                and units[0].cache_sql
                and units.cache_state == CacheState.Pending
            ]
            if not ops:
                continue

            # TODO: Should we do any sort of error handling here?
            g = execute.build_cache_persistence_units(ops)
            conn = await self.tenant.acquire_pgcon(self.name)
            try:
                await g.execute(conn, self)
            finally:
                self.tenant.release_pgcon(self.name, conn)

            for _, units in ops:
                units.cache_state = CacheState.Present
                self._cache_notify_queue.put_nowait(str(units[0].cache_key))

    async def cache_notifier(self):
        await asyncutil.debounce(
            lambda: self._cache_notify_queue.get(),
            lambda keys: self.tenant.signal_sysevent(
                'query-cache-changes',
                dbname=self.name,
                keys=keys,
            ),
            max_wait=1.0,
            delay_amt=0.2,
            # 100 keys will take up about 4000 bytes, which
            # fits in the 8000 allowed in events.
            max_batch_size=100,
        )

    cdef schedule_config_update(self):
        self._index._tenant.on_local_database_config_change(self.name)
@@ -214,17 +308,9 @@ cdef class Database:
            return

        self._eql_to_compiled[key] = compiled
-        # TODO(fantix): merge in-memory cleanup into the task below
-        keys = []
-        while self._eql_to_compiled.needs_cleanup():
-            query_req, unit_group = self._eql_to_compiled.cleanup_one()
-            if len(unit_group) == 1:
-                keys.append(query_req.get_cache_key())
-        if keys and debug.flags.persistent_cache:
-            self.tenant.create_task(
-                self.tenant.evict_query_cache(self.name, keys),
-                interruptable=True,
-            )

+        if self._cache_queue is not None:
+            self._cache_queue.put_nowait((key, compiled))

    def cache_compiled_sql(self, key, compiled: list[str], schema_version):
        existing, ver = self._sql_to_compiled.get(key, DICTDEFAULT)
@@ -267,23 +353,18 @@
        return old_serializer

    def hydrate_cache(self, query_cache):
-        new = set()
        for _, in_data, out_data in query_cache:
            query_req = rpc.CompilationRequest(
                self.server.compilation_config_serializer)
            query_req.deserialize(in_data, "<unknown>")
-            new.add(query_req)

            if query_req not in self._eql_to_compiled:
                unit = dbstate.QueryUnit.deserialize(out_data)
                group = dbstate.QueryUnitGroup()
                group.append(unit)
+                group.cache_state = CacheState.Present
                self._eql_to_compiled[query_req] = group

-        for query_req in list(self._eql_to_compiled):
-            if query_req not in new:
-                del self._eql_to_compiled[query_req]

    def iter_views(self):
        yield from self._views

@@ -295,7 +376,8 @@
        async with self._introspection_lock:
            if self.user_schema_pickle is None:
                await self.tenant.introspect_db(
-                    self.name, hydrate_cache=debug.flags.persistent_cache
+                    self.name,
+                    hydrate_cache=not debug.flags.disable_persistent_cache,
                )


@@ -731,8 +813,7 @@ cdef class DatabaseConnectionView:
    cdef cache_compiled_query(self, object key, object query_unit_group):
        assert query_unit_group.cacheable

-        if not self._in_tx_with_ddl:
-            self._db._cache_compiled_query(key, query_unit_group)
+        self._db._cache_compiled_query(key, query_unit_group)

    cdef lookup_compiled_query(self, object key):
        if (self._tx_error or
@@ -1381,7 +1462,8 @@ cdef class DatabaseIndex:
        return db

    def unregister_db(self, dbname):
-        self._dbs.pop(dbname)
+        db = self._dbs.pop(dbname)
+        db.stop()
        self.set_current_branches()

    cdef inline set_current_branches(self):
3 changes: 2 additions & 1 deletion edb/server/pgcon/pgcon.pyx
@@ -2932,7 +2932,8 @@ cdef class PGConnection:
            self.tenant.on_remote_database_quarantine(dbname)
        elif event == 'query-cache-changes':
            dbname = event_payload['dbname']
-            self.tenant.on_remote_query_cache_change(dbname)
+            keys = event_payload.get('keys')
+            self.tenant.on_remote_query_cache_change(dbname, keys=keys)
        else:
            raise AssertionError(f'unexpected system event: {event!r}')

8 changes: 0 additions & 8 deletions edb/server/protocol/binary.pyx
@@ -839,14 +839,6 @@ cdef class EdgeConnection(frontend.FrontendConnection):
        await _dbview.reload_state_serializer()
        query_req, allow_capabilities = self.parse_execute_request()
        compiled = await self._parse(query_req, allow_capabilities)
-        units = compiled.query_unit_group
-        if len(units) == 1 and units[0].cache_sql:
-            conn = await self.get_pgcon()
-            try:
-                g = execute.build_cache_persistence_units([(query_req, units)])
-                await g.execute(conn, _dbview)
-            finally:
-                self.maybe_release_pgcon(conn)

        buf = self.make_command_data_description_msg(compiled)
