Persist cache in the same transaction
fantix committed Mar 1, 2024
1 parent 00a77fa commit ed9bb13
Showing 1 changed file with 40 additions and 20 deletions.
edb/server/protocol/execute.pyx (40 additions, 20 deletions)
@@ -71,6 +71,7 @@ cdef class ExecutionGroup:
         pgcon.PGConnection be_conn,
         dbview.DatabaseConnectionView dbv,
         fe_conn: frontend.AbstractFrontendConnection = None,
+        bytes state = None,
     ):
         cdef int dbver
 
@@ -82,12 +83,14 @@
             self.group,
             True, # sync
             self.bind_datas,
-            None, # state
+            state,
             0, # start
             len(self.group), # end
             dbver,
             parse_array,
         )
+        if state is not None:
+            await be_conn.wait_for_state_resp(state, state_sync=0)
         for i, unit in enumerate(self.group):
             if unit.output_format == FMT_NONE:
                 for sql in unit.sql:
@@ -105,11 +108,12 @@
         return rv
 
 
-async def persist_cache(
-    be_conn: pgcon.PGConnection,
-    dbv: dbview.DatabaseConnectionView,
+cdef ExecutionGroup build_cache_persistence_units(
     pairs: list[tuple[rpc.CompilationRequest, compiler.QueryUnitGroup]],
+    ExecutionGroup group = None,
 ):
+    if group is None:
+        group = ExecutionGroup()
     insert_sql = b'''
         INSERT INTO "edgedb"."_query_cache"
         ("key", "schema_version", "input", "output", "evict")
@@ -118,7 +122,6 @@ async def persist_cache(
             "schema_version"=$2, "input"=$3, "output"=$4, "evict"=$5
     '''
     sql_hash = hashlib.sha1(insert_sql).hexdigest().encode('latin1')
-    group = ExecutionGroup()
     for request, units in pairs:
         # FIXME: this is temporary; drop this assertion when we support scripts
         assert len(units) == 1
@@ -146,6 +149,15 @@
                 evict,
             )),
         )
+    return group
+
+
+async def persist_cache(
+    be_conn: pgcon.PGConnection,
+    dbv: dbview.DatabaseConnectionView,
+    pairs: list[tuple[rpc.CompilationRequest, compiler.QueryUnitGroup]],
+):
+    cdef group = build_cache_persistence_units(pairs)
 
     try:
         await group.execute(be_conn, dbv)
@@ -182,6 +194,7 @@ async def execute(
     cdef:
         bytes state = None, orig_state = None
         WriteBuffer bound_args_buf
+        ExecutionGroup group
 
     query_unit = compiled.query_unit_group[0]
 
@@ -216,13 +229,6 @@
         else:
             config_ops = query_unit.config_ops
 
-            if compiled.request and query_unit.cache_sql:
-                await persist_cache(
-                    be_conn,
-                    dbv,
-                    [(compiled.request, compiled.query_unit_group)],
-                )
-
             if query_unit.sql:
                 if query_unit.user_schema:
                     ddl_ret = await be_conn.run_ddl(query_unit, state)
@@ -239,14 +245,28 @@
                     read_data = (
                         query_unit.needs_readback or query_unit.is_explain)
 
-                    data = await be_conn.parse_execute(
-                        query=query_unit,
-                        fe_conn=fe_conn if not read_data else None,
-                        bind_data=bound_args_buf,
-                        use_prep_stmt=use_prep_stmt,
-                        state=state,
-                        dbver=dbv.dbver,
-                    )
+                    if compiled.request and query_unit.cache_sql:
+                        group = build_cache_persistence_units(
+                            [(compiled.request, compiled.query_unit_group)]
+                        )
+                        if not use_prep_stmt:
+                            query_unit.sql_hash = b''
+                        group.append(query_unit, bound_args_buf)
+                        data = await group.execute(
+                            be_conn,
+                            dbv,
+                            fe_conn=fe_conn if not read_data else None,
+                            state=state,
+                        )
+                    else:
+                        data = await be_conn.parse_execute(
+                            query=query_unit,
+                            fe_conn=fe_conn if not read_data else None,
+                            bind_data=bound_args_buf,
+                            use_prep_stmt=use_prep_stmt,
+                            state=state,
+                            dbver=dbv.dbver,
+                        )
 
                     if query_unit.needs_readback and data:
                         config_ops = [

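The effect of the change, in rough terms: the cache upsert built by build_cache_persistence_units() is appended to the same ExecutionGroup as the user's query, so the cache row and the query run under one sync and commit in the same implicit transaction, instead of persist_cache() doing a separate round trip first. Below is a minimal, self-contained Python sketch of that pattern; FakeBackend, run_batch and execute_with_cache are hypothetical stand-ins for illustration only, not EdgeDB's actual pgcon/ExecutionGroup API.

# sketch_persist_cache_same_txn.py
#
# Minimal sketch of the idea in this commit: instead of persisting the
# compiled-query cache in its own round trip (and therefore its own implicit
# transaction) before running the user's query, the cache upsert is appended
# to the same statement batch as the query, so both commit together.
# FakeBackend, run_batch and execute_with_cache are hypothetical stand-ins.

class FakeBackend:
    """Stand-in for a backend connection that executes a statement batch."""

    def __init__(self):
        self.committed = []

    def run_batch(self, statements):
        # Model of an extended-protocol pipeline ended by a single Sync:
        # every statement in the batch commits (or fails) together.
        results = [f"EXECUTED: {sql}" for sql in statements]
        self.committed.extend(results)  # one implicit transaction
        return results[-1]              # result of the user's query


def build_cache_persistence_statements(pairs):
    # Plays the role of build_cache_persistence_units(): turn
    # (cache key, compiled SQL) pairs into cache-upsert statements that can
    # be appended to an existing batch.
    return [
        f"INSERT INTO _query_cache (key, output) VALUES ({key!r}, {sql!r})"
        for key, sql in pairs
    ]


def execute_with_cache(backend, user_sql, cache_pairs):
    # Before this commit: the cache was persisted first, then the query was
    # sent separately.  After: one batch carries both.
    batch = build_cache_persistence_statements(cache_pairs)
    batch.append(user_sql)
    return backend.run_batch(batch)


if __name__ == "__main__":
    be = FakeBackend()
    print(execute_with_cache(be, "SELECT 1", [("key-1", "SELECT 1")]))
    print(be.committed)  # cache upsert and query, committed together

The actual diff also threads a state argument through ExecutionGroup.execute() and waits for the state response, so session state can still be synchronized when the group replaces a plain parse_execute() call.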