Skip to content

Commit

Permalink
Recompile all cache entries on DDL
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Jan 31, 2024
1 parent d816895 commit 1c256eb
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 5 deletions.
30 changes: 30 additions & 0 deletions edb/pgsql/metaschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,35 @@ def __init__(self) -> None:
)


class ClearQueryCacheFunction(dbops.Function):
    """Define SQL function ``edgedb._clear_query_cache()``.

    The function deletes *every* row from ``edgedb._query_cache``; for
    each deleted row it executes the row's ``evict`` SQL (which drops the
    backend artifacts of that cache entry) and emits the row's ``input``
    column, so the caller receives the serialized compile requests of all
    evicted entries and can recompile them.
    """

    # TODO(fantix): this may consume a lot of memory in Postgres
    # NOTE(review): the DELETE ... RETURNING loop materializes all evicted
    # rows before the function returns, hence the memory concern above.
    text = f'''
        DECLARE
            row record;
        BEGIN
            FOR row IN
                DELETE FROM "edgedb"."_query_cache"
                RETURNING "input", "evict"
            LOOP
                EXECUTE row."evict";
                RETURN NEXT row."input";
            END LOOP;
        END;
    '''

    def __init__(self) -> None:
        super().__init__(
            name=('edgedb', '_clear_query_cache'),
            args=[],
            # Each emitted row is one serialized compile request.
            returns=('bytea',),
            set_returning=True,
            language='plpgsql',
            # Volatile: the function has side effects (DELETE + EXECUTE).
            volatility='volatile',
            text=self.text,
        )


class BigintDomain(dbops.Domain):
"""Bigint: a variant of numeric that enforces zero digits after the dot.
Expand Down Expand Up @@ -4437,6 +4466,7 @@ async def bootstrap(
dbops.CreateTable(QueryCacheTable()),
dbops.Query(DMLDummyTable.SETUP_QUERY),
dbops.CreateFunction(EvictQueryCacheFunction()),
dbops.CreateFunction(ClearQueryCacheFunction()),
dbops.CreateFunction(UuidGenerateV1mcFunction('edgedbext')),
dbops.CreateFunction(UuidGenerateV4Function('edgedbext')),
dbops.CreateFunction(UuidGenerateV5Function('edgedbext')),
Expand Down
13 changes: 13 additions & 0 deletions edb/server/compiler_pool/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,9 @@ async def analyze_explain_output(
def get_debug_info(self):
return {}

def get_size_hint(self) -> int:
    """Return the approximate worker capacity of this compiler pool.

    Abstract on the base pool; concrete pool implementations override
    this.  Callers use it as a concurrency hint, not an exact count.
    """
    raise NotImplementedError


class BaseLocalPool(
AbstractPool, amsg.ServerProtocol, asyncio.SubprocessProtocol
Expand Down Expand Up @@ -948,6 +951,9 @@ async def _stop(self):
await trans._wait()
trans.close()

def get_size_hint(self) -> int:
    """Return the configured fixed size of this local pool."""
    return self._pool_size


@srvargs.CompilerPoolMode.OnDemand.assign_implementation
class SimpleAdaptivePool(BaseLocalPool):
Expand Down Expand Up @@ -1071,6 +1077,9 @@ def _scale_down(self):
)[:-self._pool_size]:
worker.close()

def get_size_hint(self) -> int:
    """Return the adaptive pool's worker-count upper bound.

    The pool scales between limits at runtime; the maximum is the most
    useful concurrency hint for callers.
    """
    return self._max_num_workers


class RemoteWorker(BaseWorker):
def __init__(self, con, secret, *args):
Expand Down Expand Up @@ -1098,6 +1107,7 @@ def __init__(self, *, address, pool_size, **kwargs):
self._worker = None
self._sync_lock = asyncio.Lock()
self._semaphore = asyncio.BoundedSemaphore(pool_size)
self._pool_size = pool_size
secret = os.environ.get("_EDGEDB_SERVER_COMPILER_POOL_SECRET")
if not secret:
raise AssertionError(
Expand Down Expand Up @@ -1249,6 +1259,9 @@ def get_debug_info(self):
free=self._semaphore._value, # type: ignore
)

def get_size_hint(self) -> int:
    """Return the configured size of the remote compiler pool.

    Matches the capacity of the bounded semaphore guarding requests.
    """
    return self._pool_size


@dataclasses.dataclass
class TenantSchema:
Expand Down
70 changes: 69 additions & 1 deletion edb/server/dbview/dbview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ import weakref
import immutables

from edb import errors
from edb.common import lru, uuidgen
from edb.common import lru, taskgroup, uuidgen
from edb import edgeql
from edb.edgeql import qltypes
from edb.schema import schema as s_schema
from edb.server import compiler, defines, config, metrics
from edb.server.compiler import dbstate, enums, sertypes
from edb.server.protocol import execute
from edb.pgsql import dbops
from edb.server.compiler_pool import state as compiler_state_mod

Expand Down Expand Up @@ -914,6 +915,73 @@ cdef class DatabaseConnectionView:
self._reset_tx_state()
return side_effects

async def clear_cache_keys(self, conn) -> list[bytes]:
    """Evict all persisted query-cache entries and report their inputs.

    Invokes ``edgedb._clear_query_cache()`` on the backend connection,
    clears the in-memory cache to match, and returns the serialized
    compile requests of the evicted entries so they can be recompiled.
    """
    fetched = await conn.sql_fetch(
        b'SELECT "edgedb"."_clear_query_cache"()')
    self._db._query_cache.clear()
    if not fetched:
        return []
    return [record[0] for record in fetched]

async def recompile_all(self, conn, requests: typing.Iterable[bytes]):
    """Recompile the given serialized compile requests and re-persist them.

    Used after DDL invalidates the query cache: ``requests`` are the
    ``input`` blobs returned by ``clear_cache_keys()``.  Compilation is
    fanned out over the compiler pool with bounded concurrency, while a
    single consumer task persists results sequentially on ``conn``.
    """
    compiler_pool = self.server.get_compiler_pool()
    # Leave one worker free for foreground compilations.
    concurrency = max(1, compiler_pool.get_size_hint() - 1)
    # i: bounded request queue (producer below applies backpressure);
    # o: unbounded result queue drained by the single persist task.
    i = asyncio.Queue(maxsize=concurrency)
    o = asyncio.Queue()

    async def recompile_request():
        # One of `concurrency` compile workers; a None request is the
        # shutdown sentinel, forwarded to the consumer as (None, None).
        while True:
            request = await i.get()
            if request is None:
                o.put_nowait((None, None))
                break
            try:
                result = await compiler_pool.compile(
                    self.dbname,
                    self.get_user_schema_pickle(),
                    self.get_global_schema_pickle(),
                    self.reflection_cache,
                    self.get_database_config(),
                    self.get_compilation_system_config(),
                    request,
                    client_id=self.tenant.client_id,
                )
            except Exception:
                # discard cache entry that cannot be recompiled
                pass
            else:
                # result[0] is the pickled query unit group (see
                # pickle.loads below); forward it with its request.
                o.put_nowait((request, result[0]))

    async def persist_cache():
        # Single consumer: stops after receiving one sentinel per worker.
        count = concurrency
        while count > 0:
            request, response = await o.get()
            if request is None:
                count -= 1
            else:
                query_unit_group = pickle.loads(response)
                # Only single-unit groups are cacheable here.
                assert len(query_unit_group) == 1
                query_unit = query_unit_group[0]
                key = query_unit_group.cache_key
                assert key is not None
                # Write-through: persist in the backend table first,
                # then repopulate the in-memory cache.
                await execute.persist_cache_spec(
                    conn,
                    self,
                    query_unit,
                    request,
                    response,
                    key,
                )
                self._db._query_cache[key] = (
                    query_unit_group, self.schema_version
                )

    async with taskgroup.TaskGroup() as g:
        for _ in range(concurrency):
            g.create_task(recompile_request())
        g.create_task(persist_cache())
        # Feed all requests, then one sentinel per worker so every
        # worker (and thus the consumer) terminates.
        for data in requests:
            await i.put(data)
        for _ in range(concurrency):
            await i.put(None)

async def apply_config_ops(self, conn, ops):
settings = self.get_config_spec()

Expand Down
32 changes: 28 additions & 4 deletions edb/server/protocol/execute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import (
Any,
Mapping,
Expand Down Expand Up @@ -64,6 +63,23 @@ async def persist_cache(
):
assert len(compiled.query_unit_group) == 1
query_unit = compiled.query_unit_group[0]
return await persist_cache_spec(
be_conn,
dbv,
query_unit,
compiled.request.serialize(),
compiled.serialized,
compiled.request.get_cache_key(),
)

async def persist_cache_spec(
be_conn: pgcon.PGConnection,
dbv: dbview.DatabaseConnectionView,
query_unit,
request,
response,
cache_key,
):
persist, evict = query_unit.cache_sql
await be_conn.sql_execute((evict, persist))
await be_conn.sql_fetch(
Expand All @@ -73,10 +89,10 @@ async def persist_cache(
b'ON CONFLICT (key) DO UPDATE SET '
b'"schema_version"=$2, "input"=$3, "output"=$4, "evict"=$5',
args=(
compiled.request.get_cache_key().bytes,
cache_key.bytes,
dbv.schema_version.bytes,
compiled.request.serialize(),
compiled.serialized,
request,
response,
evict,
),
use_prep_stmt=True,
Expand Down Expand Up @@ -114,6 +130,7 @@ async def execute(
new_types = None
server = dbv.server
tenant = dbv.tenant
recompile_requests = None

data = None

Expand All @@ -140,7 +157,12 @@ async def execute(
await persist_cache(be_conn, dbv, compiled)

if query_unit.sql:
if query_unit.has_ddl:
# TODO(fantix): do this in the same transaction
recompile_requests = await dbv.clear_cache_keys(be_conn)
if query_unit.ddl_stmt_id:
await be_conn.sql_execute(
b'delete from "edgedb"."_query_cache"')
ddl_ret = await be_conn.run_ddl(query_unit, state)
if ddl_ret and ddl_ret['new_types']:
new_types = ddl_ret['new_types']
Expand Down Expand Up @@ -230,6 +252,8 @@ async def execute(
# 1. An orphan ROLLBACK command without a paring start tx
# 2. There was no SQL, so the state can't have been synced.
be_conn.last_state = state
if recompile_requests:
await dbv.recompile_all(be_conn, recompile_requests)
finally:
if query_unit.drop_db:
tenant.allow_database_connections(query_unit.drop_db)
Expand Down

0 comments on commit 1c256eb

Please sign in to comment.