Skip to content

Commit

Permalink
Refactor to put init_con_data in Tenant
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Dec 5, 2024
1 parent 82fac25 commit 4307873
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 33 deletions.
13 changes: 12 additions & 1 deletion edb/server/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


from __future__ import annotations
from typing import Any, Mapping
from typing import Any, Literal, Mapping, TypedDict

import immutables

Expand Down Expand Up @@ -51,9 +51,20 @@
'get_compilation_config',
'coerce_single_value',
'QueryCacheMode',
'ConState', 'ConStateType',
)


# See edb/server/pgcon/connect.py for documentation of the types
ConStateType = Literal['C', 'B', 'A', 'E']


class ConState(TypedDict):
name: str
value: Any
type: ConStateType


def lookup(
name: str,
*configs: Mapping[str, SettingValue],
Expand Down
26 changes: 15 additions & 11 deletions edb/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ async def _run_server(
do_setproctitle: bool,
new_instance: bool,
compiler_state: edbcompiler.CompilerState,
init_con_data: list[config.ConState],
):

sockets = service_manager.get_activation_listen_sockets()
Expand All @@ -216,6 +217,7 @@ async def _run_server(
backend_adaptive_ha=args.backend_adaptive_ha,
extensions_dir=args.extensions_dir,
)
tenant.set_init_con_data(init_con_data)
tenant.set_reloadable_files(
readiness_state_file=args.readiness_state_file,
jwt_sub_allowlist_file=args.jwt_sub_allowlist_file,
Expand Down Expand Up @@ -522,10 +524,12 @@ async def run_server(
compiler_state.config_spec,
)

sys_config, backend_settings = initialize_static_cfg(
args,
is_remote_cluster=True,
config_spec=compiler_state.config_spec,
sys_config, backend_settings, init_con_data = (
initialize_static_cfg(
args,
is_remote_cluster=True,
config_spec=compiler_state.config_spec,
)
)
with _internal_state_dir(runstate_dir, args) as (
int_runstate_dir,
Expand All @@ -547,6 +551,7 @@ async def run_server(
internal_runstate_dir=int_runstate_dir,
do_setproctitle=do_setproctitle,
compiler_state=compiler_state,
init_con_data=init_con_data,
)
except server.StartupError as e:
abort(str(e))
Expand Down Expand Up @@ -606,7 +611,7 @@ async def run_server(

new_instance, compiler_state = await _init_cluster(cluster, args)

_, backend_settings = initialize_static_cfg(
_, backend_settings, init_con_data = initialize_static_cfg(
args,
is_remote_cluster=not is_local_cluster,
config_spec=compiler_state.config_spec,
Expand Down Expand Up @@ -663,6 +668,7 @@ async def run_server(
do_setproctitle=do_setproctitle,
new_instance=new_instance,
compiler_state=compiler_state,
init_con_data=init_con_data,
)

except server.StartupError as e:
Expand Down Expand Up @@ -810,7 +816,9 @@ def initialize_static_cfg(
args: srvargs.ServerConfig,
is_remote_cluster: bool,
config_spec: config.Spec
) -> Tuple[Mapping[str, config.SettingValue], Dict[str, str]]:
) -> Tuple[
Mapping[str, config.SettingValue], Dict[str, str], list[config.ConState]
]:
result = {}
init_con_script_data = []
backend_settings = {}
Expand Down Expand Up @@ -898,11 +906,7 @@ def iter_environ():
if args.port:
add_config("listen_port", args.port, command_line_argument)

if init_con_script_data:
from . import pgcon
pgcon.set_init_con_script_data(init_con_script_data)

return immutables.Map(result), backend_settings
return immutables.Map(result), backend_settings, init_con_script_data


if __name__ == '__main__':
Expand Down
6 changes: 6 additions & 0 deletions edb/server/multitenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
class MultiTenantServer(server.BaseServer):
_config_file: pathlib.Path
_sys_config: Mapping[str, config.SettingValue]
_init_con_data: list[config.ConState]
_backend_settings: Mapping[str, str]

_tenants_by_sslobj: MutableMapping
Expand All @@ -89,6 +90,7 @@ def __init__(
*,
compiler_pool_tenant_cache_size: int,
sys_config: Mapping[str, config.SettingValue],
init_con_data: list[config.ConState],
backend_settings: Mapping[str, str],
sys_queries: Mapping[str, bytes],
report_config_typedesc: dict[defines.ProtocolVersion, bytes],
Expand All @@ -97,6 +99,7 @@ def __init__(
super().__init__(**kwargs)
self._config_file = config_file
self._sys_config = sys_config
self._init_con_data = init_con_data
self._backend_settings = backend_settings
self._compiler_pool_tenant_cache_size = compiler_pool_tenant_cache_size

Expand Down Expand Up @@ -243,6 +246,7 @@ async def _create_tenant(self, conf: TenantConfig) -> edbtenant.Tenant:
max_backend_connections=max_conns,
backend_adaptive_ha=conf.get("backend-adaptive-ha", False),
)
tenant.set_init_con_data(self._init_con_data)
tenant.set_reloadable_files(
readiness_state_file=conf.get("readiness-state-file"),
jwt_sub_allowlist_file=conf.get("jwt-sub-allowlist-file"),
Expand Down Expand Up @@ -429,6 +433,7 @@ async def run_server(
args: srvargs.ServerConfig,
*,
sys_config: Mapping[str, config.SettingValue],
init_con_data: list[config.ConState],
backend_settings: Mapping[str, str],
sys_queries: Mapping[str, bytes],
report_config_typedesc: dict[defines.ProtocolVersion, bytes],
Expand All @@ -444,6 +449,7 @@ async def run_server(
ss = MultiTenantServer(
multitenant_config_file,
sys_config=sys_config,
init_con_data=init_con_data,
backend_settings=backend_settings,
sys_queries=sys_queries,
report_config_typedesc=report_config_typedesc,
Expand Down
2 changes: 0 additions & 2 deletions edb/server/pgcon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,12 @@
)
from .connect import (
pg_connect,
set_init_con_script_data,
SETUP_TEMP_TABLE_SCRIPT,
SETUP_CONFIG_CACHE_SCRIPT,
)

__all__ = (
'pg_connect',
'set_init_con_script_data',
'PGConnection',
'BackendError',
'BackendConnectionError',
Expand Down
19 changes: 2 additions & 17 deletions edb/server/pgcon/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@

from __future__ import annotations

import json
import logging
import textwrap

from edb.pgsql.common import quote_ident as pg_qi
from edb.pgsql.common import quote_literal as pg_ql
from edb.pgsql import params as pg_params
from edb.server import pgcon

Expand All @@ -33,7 +31,6 @@
logger = logging.getLogger('edb.server')

INIT_CON_SCRIPT: bytes | None = None
INIT_CON_SCRIPT_DATA = ''

# The '_edgecon_state table' is used to store information about
# the current session. The `type` column is one character, with one
Expand All @@ -45,6 +42,8 @@
# a corresponding Postgres config setting.
# * 'A': an instance-level config setting from command-line arguments
# * 'E': an instance-level config setting from environment variable
#
# Please also update ConStateType in edb/server/config/__init__.py if changed.
SETUP_TEMP_TABLE_SCRIPT = '''
CREATE TEMPORARY TABLE _edgecon_state (
name text NOT NULL,
Expand Down Expand Up @@ -82,8 +81,6 @@ def _build_init_con_script(*, check_pg_is_in_recovery: bool) -> bytes:
{SETUP_TEMP_TABLE_SCRIPT}
{SETUP_CONFIG_CACHE_SCRIPT}
{INIT_CON_SCRIPT_DATA}
PREPARE _clear_state AS
WITH x1 AS (
DELETE FROM _config_cache
Expand Down Expand Up @@ -193,15 +190,3 @@ async def pg_connect(
raise

return pgconn


def set_init_con_script_data(cfg):
global INIT_CON_SCRIPT, INIT_CON_SCRIPT_DATA
INIT_CON_SCRIPT = None
INIT_CON_SCRIPT_DATA = (
f'''
INSERT INTO _edgecon_state
SELECT * FROM jsonb_to_recordset({pg_ql(json.dumps(cfg))}::jsonb)
AS cfg(name text, value jsonb, type text);
'''
).strip()
2 changes: 0 additions & 2 deletions edb/server/pgcon/pgcon.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class BackendConnectionError(BackendError): ...
class BackendPrivilegeError(BackendError): ...
class BackendCatalogNameError(BackendError): ...

def set_init_con_script_data(cfg: list[dict[str, Any]]): ...

class PGConnection(asyncio.Protocol):

idle: bool
Expand Down
24 changes: 24 additions & 0 deletions edb/server/tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import pickle
import struct
import sys
import textwrap
import time
import tomllib
import uuid
Expand Down Expand Up @@ -116,6 +117,8 @@ class Tenant(ha_base.ClusterProtocol):
_suggested_client_pool_size: int
_pg_pool: connpool.Pool
_pg_unavailable_msg: str | None
_init_con_data: list[config.ConState]
_init_con_sql: bytes | None

_ha_master_serial: int
_backend_adaptive_ha: adaptive_ha.AdaptiveHASupport | None
Expand Down Expand Up @@ -199,6 +202,8 @@ def __init__(
self._pg_unavailable_msg = None
self._block_new_connections = set()
self._report_config_data = {}
self._init_con_data = []
self._init_con_sql = None

# DB state will be initialized in init().
self._dbindex = None
Expand Down Expand Up @@ -694,6 +699,21 @@ def terminate_sys_pgcon(self) -> None:
self.__sys_pgcon = None
del self._sys_pgcon_waiter

def set_init_con_data(self, data: list[config.ConState]) -> None:
self._init_con_data = data
self._init_con_sql = None
if data:
from edb.pgsql import common

quoted_json = common.quote_literal(json.dumps(data))
self._init_con_sql = textwrap.dedent(
f'''
INSERT INTO _edgecon_state
SELECT * FROM jsonb_to_recordset({quoted_json}::jsonb)
AS cfg(name text, value jsonb, type text);
'''
).strip().encode()

async def _pg_connect(
self,
dbname: str,
Expand All @@ -713,6 +733,10 @@ async def _pg_connect(
)
if self._server.stmt_cache_size is not None:
rv.set_stmt_cache_size(self._server.stmt_cache_size)

if self._init_con_sql is not None:
await rv.sql_execute(self._init_con_sql)

except Exception:
metrics.backend_connection_establishment_errors.inc(
1.0, self._instance_name
Expand Down

0 comments on commit 4307873

Please sign in to comment.