Skip to content

Commit

Permalink
Optimize pickled in-tx CompilerConnectionState (#7047)
Browse files Browse the repository at this point in the history
User schema is excluded from the pickle until DDL.
  • Loading branch information
fantix authored Mar 15, 2024
1 parent 3c8ba43 commit fe3d79c
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 19 deletions.
42 changes: 38 additions & 4 deletions edb/server/compiler/dbstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ class TransactionState(NamedTuple):

id: int
name: Optional[str]
user_schema: s_schema.FlatSchema
local_user_schema: s_schema.FlatSchema | None
global_schema: s_schema.FlatSchema
modaliases: immutables.Map[Optional[str], str]
session_config: immutables.Map[str, config.SettingValue]
Expand All @@ -687,6 +687,13 @@ class TransactionState(NamedTuple):
migration_state: Optional[MigrationState] = None
migration_rewrite_state: Optional[MigrationRewriteState] = None

@property
def user_schema(self) -> s_schema.FlatSchema:
    """Effective user schema of this transaction state.

    ``local_user_schema`` stays ``None`` until the schema diverges from
    the transaction's root schema (the commit keeps the user schema out
    of the pickle until DDL happens); until then we serve the root
    schema.  ``self.tx`` is presumably the owning Transaction — not
    visible in this hunk, confirm against the full class definition.
    """
    local = self.local_user_schema
    return self.tx.root_user_schema if local is None else local


class Transaction:

Expand Down Expand Up @@ -717,7 +724,9 @@ def __init__(
self._current = TransactionState(
id=self._id,
name=None,
user_schema=user_schema,
local_user_schema=(
None if user_schema is self.root_user_schema else user_schema
),
global_schema=global_schema,
modaliases=modaliases,
session_config=session_config,
Expand All @@ -734,6 +743,10 @@ def __init__(
def id(self) -> int:
return self._id

@property
def root_user_schema(self) -> s_schema.FlatSchema:
    # Delegate to the owning CompilerConnectionState: the root user
    # schema is held once per connection state, not per transaction.
    return self._constate.root_user_schema

def is_implicit(self) -> bool:
return self._implicit

Expand Down Expand Up @@ -873,7 +886,7 @@ def update_schema(self, new_schema: s_schema.Schema) -> None:
global_schema = new_schema.get_global_schema()
assert isinstance(global_schema, s_schema.FlatSchema)
self._current = self._current._replace(
user_schema=user_schema,
local_user_schema=user_schema,
global_schema=global_schema,
)

Expand Down Expand Up @@ -909,11 +922,15 @@ def update_migration_rewrite_state(
self._current = self._current._replace(migration_rewrite_state=mrstate)


CStateStateType = Tuple[Dict[int, TransactionState], Transaction, int]


class CompilerConnectionState:

__slots__ = ('_savepoints_log', '_current_tx', '_tx_count',)
__slots__ = ('_savepoints_log', '_current_tx', '_tx_count', '_user_schema')

_savepoints_log: Dict[int, TransactionState]
_user_schema: Optional[s_schema.FlatSchema]

def __init__(
self,
Expand All @@ -926,6 +943,8 @@ def __init__(
system_config: immutables.Map[str, config.SettingValue],
cached_reflection: immutables.Map[str, Tuple[str, ...]],
):
assert isinstance(user_schema, s_schema.FlatSchema)
self._user_schema = user_schema
self._tx_count = time.monotonic_ns()
self._init_current_tx(
user_schema=user_schema,
Expand All @@ -938,6 +957,21 @@ def __init__(
)
self._savepoints_log = {}

def __getstate__(self) -> CStateStateType:
    # Custom pickling: deliberately omit ``_user_schema`` so the
    # (potentially large) root user schema is not serialized with the
    # in-transaction state on every round trip.
    return self._savepoints_log, self._current_tx, self._tx_count

def __setstate__(self, state: CStateStateType) -> None:
    # Restore everything except the root user schema; the unpickling
    # side (a compiler worker) must re-attach it via
    # ``set_root_user_schema()`` before it is read.
    self._savepoints_log, self._current_tx, self._tx_count = state
    self._user_schema = None

@property
def root_user_schema(self) -> s_schema.FlatSchema:
    # Must have been set — either in __init__ or, after unpickling,
    # via set_root_user_schema() — before being read.
    assert self._user_schema is not None
    return self._user_schema

def set_root_user_schema(self, user_schema: s_schema.FlatSchema) -> None:
    # Called by a compiler worker after unpickling to re-attach the
    # root user schema that __getstate__ intentionally left out.
    self._user_schema = user_schema

def _new_txid(self) -> int:
self._tx_count += 1
return self._tx_count
Expand Down
18 changes: 17 additions & 1 deletion edb/server/compiler_pool/multitenant_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,28 @@ def compile(
return units, pickled_state


def compile_in_tx(cstate, _, *args, **kwargs):
def compile_in_tx(
    _,
    client_id: Optional[int],
    dbname: Optional[str],
    user_schema: Optional[bytes],
    cstate,
    *args,
    **kwargs,
):
    """Compile inside an open transaction (multi-tenant worker entry).

    The first positional argument is an unused state-id slot kept for
    wire compatibility with the remote-compiler protocol (the pool
    prepends a fake ID).  The pickled connection state excludes the
    root user schema, so it is re-attached from one of two sources:

    * ``client_id is None`` — the schema arrives inline as pickled
      bytes in ``user_schema``;
    * otherwise ``dbname`` names a database in this worker's cached
      ``clients[client_id]`` schemas, and the in-memory copy is reused.
    """
    global LAST_STATE
    if cstate == state.REUSE_LAST_STATE_MARKER:
        # The pool knows this worker was the last one to hold the state,
        # so it sent a marker instead of re-pickling it over the wire;
        # the root user schema is still attached from the previous call.
        cstate = LAST_STATE
    else:
        cstate = pickle.loads(cstate)
        if client_id is None:
            assert user_schema is not None
            cstate.set_root_user_schema(pickle.loads(user_schema))
        else:
            assert dbname is not None
            client_schema = clients[client_id]
            db = client_schema.dbs[dbname]
            cstate.set_root_user_schema(db.user_schema)
    units, cstate = COMPILER.compile_in_tx_request(cstate, *args, **kwargs)
    LAST_STATE = cstate
    return units, pickle.dumps(cstate, -1)
Expand Down
131 changes: 121 additions & 10 deletions edb/server/compiler_pool/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,14 @@ async def compile(
self._release_worker(worker)

async def compile_in_tx(
self, txid, pickled_state, state_id, *compile_args
self,
dbname,
user_schema_pickle,
txid,
pickled_state,
state_id,
*compile_args,
**compiler_args,
):
# When we compile a query, the compiler returns a tuple:
# a QueryUnit and the state the compiler is in if it's in a
Expand All @@ -413,7 +420,8 @@ async def compile_in_tx(
# stored in edgecon; we never modify it, so `is` is sufficient and
# is faster than `==`.
worker = await self._acquire_worker(
condition=lambda w: (w._last_pickled_state is pickled_state)
condition=lambda w: (w._last_pickled_state is pickled_state),
compiler_args=compiler_args,
)

if worker._last_pickled_state is pickled_state:
Expand All @@ -422,10 +430,21 @@ async def compile_in_tx(
# state over the network. So we replace the state with a marker,
# that the compiler process will recognize.
pickled_state = state.REUSE_LAST_STATE_MARKER
dbname = user_schema_pickle = None
else:
worker_db = worker._dbs.get(dbname)
if worker_db is None:
dbname = None
elif worker_db.user_schema_pickle is user_schema_pickle:
user_schema_pickle = None
else:
dbname = None

try:
units, new_pickled_state = await worker.call(
'compile_in_tx',
dbname,
user_schema_pickle,
pickled_state,
txid,
*compile_args
Expand Down Expand Up @@ -1154,22 +1173,35 @@ def _release_worker(self, worker, *, put_in_front: bool = True):
self._semaphore.release()

async def compile_in_tx(
self, txid, pickled_state, state_id, *compile_args
self,
dbname,
user_schema_pickle,
txid,
pickled_state,
state_id,
*compile_args,
**compiler_args,
):
worker = await self._acquire_worker()
try:
return await worker.call(
'compile_in_tx',
state.REUSE_LAST_STATE_MARKER,
state_id,
None, # client_id
None, # dbname
None, # user_schema_pickle
state.REUSE_LAST_STATE_MARKER,
txid,
*compile_args
)
except state.StateNotFound:
return await worker.call(
'compile_in_tx',
0, # state_id
None, # client_id
None, # dbname
user_schema_pickle,
pickled_state,
0,
txid,
*compile_args
)
Expand Down Expand Up @@ -1285,11 +1317,6 @@ def flush_invalidation(self) -> None:
self._cache.pop(client_id, None)
self._last_used_by_client.pop(client_id, None)

async def call(self, method_name, *args, sync_state=None):
if method_name == "compile_in_tx":
args = (args[0], 0, *args[1:])
return await super().call(method_name, *args, sync_state=sync_state)


@srvargs.CompilerPoolMode.MultiTenant.assign_implementation
class MultiTenantPool(FixedPool):
Expand Down Expand Up @@ -1501,6 +1528,90 @@ def sync_worker_state_cb(
dbname,
), callback

async def compile_in_tx(
    self,
    dbname,
    user_schema_pickle,
    txid,
    pickled_state,
    state_id,
    *compile_args,
    **compiler_args,
):
    """Dispatch an in-transaction compile to the best multi-tenant worker.

    Negotiates with the chosen worker to send as little as possible:
    the reuse marker alone, the state plus a (client_id, dbname)
    reference to a schema the worker already caches, or the state plus
    the full pickled root user schema.  Returns
    ``(units, new_pickled_state, 0)`` — the trailing 0 is presumably a
    state-id placeholder, TODO confirm against callers.
    """
    client_id = compiler_args.get("client_id")

    # Prefer a worker we used last time in the transaction (condition), or
    # (weighter) one with the user schema at tx start so that we can pass
    # over only the pickled state. Then prefer the least-recently used one
    # if many workers passed any check in the weighter, or the most vacant.
    def weighter(w: MultiTenantWorker):
        if ts := w.get_tenant_schema(client_id):
            if db := ts.dbs.get(dbname):
                return (
                    True,
                    db.user_schema_pickle is user_schema_pickle,
                    w.last_used(client_id),
                )
            else:
                return True, False, w.last_used(client_id)
        else:
            return False, False, self._cache_size - w.cache_size()

    worker = await self._acquire_worker(
        condition=lambda w: (w._last_pickled_state is pickled_state),
        weighter=weighter,
        **compiler_args,
    )

    # Avoid sending information that we know the worker already has.
    # Identity (`is`) comparisons are deliberate: the pickles are
    # interned objects, so `is` is sufficient and cheaper than `==`.
    if worker._last_pickled_state is pickled_state:
        pickled_state = state.REUSE_LAST_STATE_MARKER
        dbname = client_id = user_schema_pickle = None
    else:
        assert isinstance(worker, MultiTenantWorker)
        assert client_id is not None
        tenant_schema = worker.get_tenant_schema(client_id)
        if tenant_schema is None:
            # Just pass state + root user schema if this is a new client in
            # the worker; we don't want to initialize the client as we
            # don't have enough information to do so.
            dbname = client_id = None
        else:
            worker_db = tenant_schema.dbs.get(dbname)
            if worker_db is None:
                # The worker has the client but not the database
                dbname = client_id = None
            elif worker_db.user_schema_pickle is user_schema_pickle:
                # Avoid sending the root user schema because the worker has
                # it - just send client_id + dbname to reference it, as
                # well as the state of course.
                user_schema_pickle = None
            else:
                # The worker has a different root user schema
                dbname = client_id = None

    try:
        units, new_pickled_state = await worker.call(
            'compile_in_tx',
            # multitenant_worker is also used in MultiSchemaPool for remote
            # compilers where the first argument "state_id" is used to find
            # worker without passing the pickled state. Here in multi-
            # tenant mode, we already have the pickled state, so "state_id"
            # is not used. Just prepend a fake ID to comply to the API.
            0,  # state_id
            client_id,
            dbname,
            user_schema_pickle,
            pickled_state,
            txid,
            *compile_args
        )
        worker._last_pickled_state = new_pickled_state
        return units, new_pickled_state, 0

    finally:
        self._release_worker(worker, put_in_front=False)


async def create_compiler_pool(
*,
Expand Down
2 changes: 1 addition & 1 deletion edb/server/compiler_pool/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def acquire(
if condition(w):
self._queue.remove(w)
return w
elif weighter is not None:
if weighter is not None:
rv = self._queue[0]
weight = weighter(rv)
it = iter(self._queue)
Expand Down
20 changes: 18 additions & 2 deletions edb/server/compiler_pool/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,15 @@ async def _call_for_client(
self._release_worker(worker)

async def compile_in_tx(
self, pickled_state, state_id, txid, *compile_args, msg=None
self,
state_id,
client_id,
dbname,
user_schema_pickle,
pickled_state,
txid,
*compile_args,
msg=None,
):
if pickled_state == state_mod.REUSE_LAST_STATE_MARKER:
worker = await self._acquire_worker(
Expand All @@ -414,7 +422,15 @@ async def compile_in_tx(
worker = await self._acquire_worker()
try:
resp = await worker.call(
"compile_in_tx", pickled_state, txid, *compile_args, msg=msg
"compile_in_tx",
state_id,
client_id,
dbname,
user_schema_pickle,
pickled_state,
txid,
*compile_args,
msg=msg,
)
status, *data = pickle.loads(resp)
if status == 0:
Expand Down
13 changes: 12 additions & 1 deletion edb/server/compiler_pool/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,23 @@ def compile(
return units, pickled_state


def compile_in_tx(cstate, *args, **kwargs):
def compile_in_tx(
    dbname: Optional[str],
    user_schema: Optional[bytes],
    cstate,
    *args,
    **kwargs
):
    """Compile inside an open transaction (single-tenant worker entry).

    The pickled connection state excludes the root user schema, so it
    is re-attached before compiling: from the inline pickled bytes in
    ``user_schema`` when ``dbname`` is ``None``, otherwise from this
    worker's in-memory ``DBS[dbname]`` cache.
    """
    global LAST_STATE
    if cstate == state.REUSE_LAST_STATE_MARKER:
        # The pool sent a marker instead of re-pickling the state it
        # knows this worker already holds; the root user schema is
        # still attached from the previous call.
        cstate = LAST_STATE
    else:
        cstate = pickle.loads(cstate)
        if dbname is None:
            assert user_schema is not None
            cstate.set_root_user_schema(pickle.loads(user_schema))
        else:
            cstate.set_root_user_schema(DBS[dbname].user_schema)
    units, cstate = COMPILER.compile_in_tx_request(cstate, *args, **kwargs)
    LAST_STATE = cstate
    return units, pickle.dumps(cstate, -1)
Expand Down
1 change: 1 addition & 0 deletions edb/server/dbview/dbview.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ cdef class DatabaseConnectionView:
object _txid
object _in_tx_db_config
object _in_tx_savepoints
object _in_tx_root_user_schema_pickle
object _in_tx_user_schema_pickle
object _in_tx_user_schema_version
object _in_tx_user_config_spec
Expand Down
Loading

0 comments on commit fe3d79c

Please sign in to comment.