diff --git a/edb/server/config/__init__.py b/edb/server/config/__init__.py index 6ddd9fdd2a4..45b4176c7f7 100644 --- a/edb/server/config/__init__.py +++ b/edb/server/config/__init__.py @@ -18,7 +18,7 @@ from __future__ import annotations -from typing import Any, Mapping +from typing import Any, Literal, Mapping, TypedDict import immutables @@ -51,9 +51,20 @@ 'get_compilation_config', 'coerce_single_value', 'QueryCacheMode', + 'ConState', 'ConStateType', ) +# See edb/server/pgcon/connect.py for documentation of the types +ConStateType = Literal['C', 'B', 'A', 'E'] + + +class ConState(TypedDict): + name: str + value: Any + type: ConStateType + + def lookup( name: str, *configs: Mapping[str, SettingValue], diff --git a/edb/server/main.py b/edb/server/main.py index c5dbf4e9394..93deb9b0a1e 100644 --- a/edb/server/main.py +++ b/edb/server/main.py @@ -197,6 +197,7 @@ async def _run_server( do_setproctitle: bool, new_instance: bool, compiler_state: edbcompiler.CompilerState, + init_con_data: list[config.ConState], ): sockets = service_manager.get_activation_listen_sockets() @@ -216,6 +217,7 @@ async def _run_server( backend_adaptive_ha=args.backend_adaptive_ha, extensions_dir=args.extensions_dir, ) + tenant.set_init_con_data(init_con_data) tenant.set_reloadable_files( readiness_state_file=args.readiness_state_file, jwt_sub_allowlist_file=args.jwt_sub_allowlist_file, @@ -522,10 +524,12 @@ async def run_server( compiler_state.config_spec, ) - sys_config, backend_settings = initialize_static_cfg( - args, - is_remote_cluster=True, - config_spec=compiler_state.config_spec, + sys_config, backend_settings, init_con_data = ( + initialize_static_cfg( + args, + is_remote_cluster=True, + config_spec=compiler_state.config_spec, + ) ) with _internal_state_dir(runstate_dir, args) as ( int_runstate_dir, @@ -547,6 +551,7 @@ async def run_server( internal_runstate_dir=int_runstate_dir, do_setproctitle=do_setproctitle, compiler_state=compiler_state, + init_con_data=init_con_data, ) except server.StartupError as e: abort(str(e)) @@ -606,7 +611,7 @@ async def run_server( new_instance, compiler_state = await _init_cluster(cluster, args) - _, backend_settings = initialize_static_cfg( + _, backend_settings, init_con_data = initialize_static_cfg( args, is_remote_cluster=not is_local_cluster, config_spec=compiler_state.config_spec, @@ -663,6 +668,7 @@ async def run_server( do_setproctitle=do_setproctitle, new_instance=new_instance, compiler_state=compiler_state, + init_con_data=init_con_data, ) except server.StartupError as e: @@ -810,7 +816,9 @@ def initialize_static_cfg( args: srvargs.ServerConfig, is_remote_cluster: bool, config_spec: config.Spec -) -> Tuple[Mapping[str, config.SettingValue], Dict[str, str]]: +) -> Tuple[ + Mapping[str, config.SettingValue], Dict[str, str], list[config.ConState] +]: result = {} init_con_script_data = [] backend_settings = {} @@ -898,11 +906,7 @@ def iter_environ(): if args.port: add_config("listen_port", args.port, command_line_argument) - if init_con_script_data: - from . import pgcon - pgcon.set_init_con_script_data(init_con_script_data) - - return immutables.Map(result), backend_settings + return immutables.Map(result), backend_settings, init_con_script_data if __name__ == '__main__': diff --git a/edb/server/multitenant.py b/edb/server/multitenant.py index c4fdbfc33c1..f66f5b50ceb 100644 --- a/edb/server/multitenant.py +++ b/edb/server/multitenant.py @@ -70,6 +70,7 @@ class MultiTenantServer(server.BaseServer): _config_file: pathlib.Path _sys_config: Mapping[str, config.SettingValue] + _init_con_data: list[config.ConState] _backend_settings: Mapping[str, str] _tenants_by_sslobj: MutableMapping @@ -89,6 +90,7 @@ def __init__( *, compiler_pool_tenant_cache_size: int, sys_config: Mapping[str, config.SettingValue], + init_con_data: list[config.ConState], backend_settings: Mapping[str, str], sys_queries: Mapping[str, bytes], report_config_typedesc: dict[defines.ProtocolVersion, bytes], @@ -97,6 +99,7 @@ def __init__( super().__init__(**kwargs) self._config_file = config_file self._sys_config = sys_config + self._init_con_data = init_con_data self._backend_settings = backend_settings self._compiler_pool_tenant_cache_size = compiler_pool_tenant_cache_size @@ -243,6 +246,7 @@ async def _create_tenant(self, conf: TenantConfig) -> edbtenant.Tenant: max_backend_connections=max_conns, backend_adaptive_ha=conf.get("backend-adaptive-ha", False), ) + tenant.set_init_con_data(self._init_con_data) tenant.set_reloadable_files( readiness_state_file=conf.get("readiness-state-file"), jwt_sub_allowlist_file=conf.get("jwt-sub-allowlist-file"), @@ -429,6 +433,7 @@ async def run_server( args: srvargs.ServerConfig, *, sys_config: Mapping[str, config.SettingValue], + init_con_data: list[config.ConState], backend_settings: Mapping[str, str], sys_queries: Mapping[str, bytes], report_config_typedesc: dict[defines.ProtocolVersion, bytes], @@ -444,6 +449,7 @@ async def run_server( ss = MultiTenantServer( multitenant_config_file, sys_config=sys_config, + init_con_data=init_con_data, backend_settings=backend_settings, sys_queries=sys_queries, report_config_typedesc=report_config_typedesc, diff --git a/edb/server/pgcon/__init__.py b/edb/server/pgcon/__init__.py index c6bc771fb67..2edf5755519 100644 --- a/edb/server/pgcon/__init__.py +++ b/edb/server/pgcon/__init__.py @@ -31,14 +31,12 @@ ) from .connect import ( pg_connect, - set_init_con_script_data, SETUP_TEMP_TABLE_SCRIPT, SETUP_CONFIG_CACHE_SCRIPT, ) __all__ = ( 'pg_connect', - 'set_init_con_script_data', 'PGConnection', 'BackendError', 'BackendConnectionError', diff --git a/edb/server/pgcon/connect.py b/edb/server/pgcon/connect.py index 7d912b76c70..051871eeb13 100644 --- a/edb/server/pgcon/connect.py +++ b/edb/server/pgcon/connect.py @@ -18,12 +18,10 @@ from __future__ import annotations -import json import logging import textwrap from edb.pgsql.common import quote_ident as pg_qi -from edb.pgsql.common import quote_literal as pg_ql from edb.pgsql import params as pg_params from edb.server import pgcon @@ -33,7 +31,6 @@ logger = logging.getLogger('edb.server') INIT_CON_SCRIPT: bytes | None = None -INIT_CON_SCRIPT_DATA = '' # The '_edgecon_state table' is used to store information about # the current session. The `type` column is one character, with one @@ -45,6 +42,8 @@ # a corresponding Postgres config setting. # * 'A': an instance-level config setting from command-line arguments # * 'E': an instance-level config setting from environment variable +# +# Please also update ConStateType in edb/server/config/__init__.py if changed. SETUP_TEMP_TABLE_SCRIPT = ''' CREATE TEMPORARY TABLE _edgecon_state ( name text NOT NULL, @@ -82,8 +81,6 @@ def _build_init_con_script(*, check_pg_is_in_recovery: bool) -> bytes: {SETUP_TEMP_TABLE_SCRIPT} {SETUP_CONFIG_CACHE_SCRIPT} - {INIT_CON_SCRIPT_DATA} - PREPARE _clear_state AS WITH x1 AS ( DELETE FROM _config_cache @@ -193,15 +190,3 @@ async def pg_connect( raise return pgconn - - -def set_init_con_script_data(cfg): - global INIT_CON_SCRIPT, INIT_CON_SCRIPT_DATA - INIT_CON_SCRIPT = None - INIT_CON_SCRIPT_DATA = ( - f''' - INSERT INTO _edgecon_state - SELECT * FROM jsonb_to_recordset({pg_ql(json.dumps(cfg))}::jsonb) - AS cfg(name text, value jsonb, type text); - ''' - ).strip() diff --git a/edb/server/pgcon/pgcon.pyi b/edb/server/pgcon/pgcon.pyi index 7c6821bf028..03eb86cecc1 100644 --- a/edb/server/pgcon/pgcon.pyi +++ b/edb/server/pgcon/pgcon.pyi @@ -36,8 +36,6 @@ class BackendConnectionError(BackendError): ... class BackendPrivilegeError(BackendError): ... class BackendCatalogNameError(BackendError): ... -def set_init_con_script_data(cfg: list[dict[str, Any]]): ... - class PGConnection(asyncio.Protocol): idle: bool diff --git a/edb/server/tenant.py b/edb/server/tenant.py index 77047a6cd38..e8b5358da55 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -42,6 +42,7 @@ import pickle import struct import sys +import textwrap import time import tomllib import uuid @@ -116,6 +117,8 @@ class Tenant(ha_base.ClusterProtocol): _suggested_client_pool_size: int _pg_pool: connpool.Pool _pg_unavailable_msg: str | None + _init_con_data: list[config.ConState] + _init_con_sql: bytes | None _ha_master_serial: int _backend_adaptive_ha: adaptive_ha.AdaptiveHASupport | None @@ -199,6 +202,8 @@ def __init__( self._pg_unavailable_msg = None self._block_new_connections = set() self._report_config_data = {} + self._init_con_data = [] + self._init_con_sql = None # DB state will be initialized in init(). self._dbindex = None @@ -694,6 +699,21 @@ def terminate_sys_pgcon(self) -> None: self.__sys_pgcon = None del self._sys_pgcon_waiter + def set_init_con_data(self, data: list[config.ConState]) -> None: + self._init_con_data = data + self._init_con_sql = None + if data: + from edb.pgsql import common + + quoted_json = common.quote_literal(json.dumps(data)) + self._init_con_sql = textwrap.dedent( + f''' + INSERT INTO _edgecon_state + SELECT * FROM jsonb_to_recordset({quoted_json}::jsonb) + AS cfg(name text, value jsonb, type text); + ''' + ).strip().encode() + async def _pg_connect( self, dbname: str, @@ -713,6 +733,10 @@ async def _pg_connect( ) if self._server.stmt_cache_size is not None: rv.set_stmt_cache_size(self._server.stmt_cache_size) + + if self._init_con_sql is not None: + await rv.sql_execute(self._init_con_sql) + except Exception: metrics.backend_connection_establishment_errors.inc( 1.0, self._instance_name