From 0205fbc25dfbfc5d11a6a51aeb7cdc80ddf9d921 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Mon, 13 Jan 2025 17:40:41 -0800 Subject: [PATCH] Add TOML config file support (#8121) (#8215) This adds a hidden server command-line argument --config-file, as well as the "config-file" key in multi- tenant config file. Replaces #8059, Fixes #1325, Fixes #7990 Co-authored-by: Fantix King --- edb/common/asyncutil.py | 198 ++++++++++++++++++++++++++++- edb/pgsql/metaschema.py | 37 ++++-- edb/pgsql/patches.py | 3 +- edb/server/args.py | 10 ++ edb/server/config/__init__.py | 20 ++- edb/server/main.py | 77 +++++++----- edb/server/multitenant.py | 18 ++- edb/server/pgcon/__init__.py | 4 +- edb/server/pgcon/connect.py | 29 ++--- edb/server/pgcon/pgcon.pxd | 1 + edb/server/pgcon/pgcon.pyi | 3 +- edb/server/smtp.py | 4 +- edb/server/tenant.py | 200 ++++++++++++++++++++++++++++-- edb/testbase/server.py | 15 +++ tests/common/test_asyncutil.py | 219 +++++++++++++++++++++++++++++++++ tests/test_server_config.py | 174 ++++++++++++++++++++++++++ tests/test_server_ops.py | 99 ++++++++++++--- 17 files changed, 1011 insertions(+), 100 deletions(-) diff --git a/edb/common/asyncutil.py b/edb/common/asyncutil.py index 7d439a6172c..7a806cc1017 100644 --- a/edb/common/asyncutil.py +++ b/edb/common/asyncutil.py @@ -18,9 +18,20 @@ from __future__ import annotations -from typing import Callable, TypeVar, Awaitable +from typing import ( + Any, + Awaitable, + Callable, + cast, + overload, + Self, + TypeVar, + Type, +) import asyncio +import inspect +import warnings _T = TypeVar('_T') @@ -140,3 +151,188 @@ async def debounce( batch = [] last_signal = t target_time = None + + +_Owner = TypeVar("_Owner") +HandlerFunction = Callable[[], Awaitable[None]] +HandlerMethod = Callable[[Any], Awaitable[None]] + + +class ExclusiveTask: + """Manages to run a repeatable task once at a time.""" + + _handler: HandlerFunction + _task: asyncio.Task | None + _scheduled: bool + _stop_requested: bool + + def __init__(self, handler: HandlerFunction) -> None: + self._handler = handler + self._task = None + self._scheduled = False + self._stop_requested = False + + @property + def scheduled(self) -> bool: + return self._scheduled + + async def _run(self) -> None: + if self._scheduled and not self._stop_requested: + self._scheduled = False + else: + return + try: + await self._handler() + finally: + if self._scheduled and not self._stop_requested: + self._task = asyncio.create_task(self._run()) + else: + self._task = None + + def schedule(self) -> None: + """Schedule to run the task as soon as possible. + + If already scheduled, nothing happens; it won't queue up. + + If the task is already running, it will be scheduled to run again as + soon as the running task is done. + """ + if not self._stop_requested: + self._scheduled = True + if self._task is None: + self._task = asyncio.create_task(self._run()) + + async def stop(self) -> None: + """Cancel scheduled task and wait for the running one to finish. + + After an ExclusiveTask is stopped, no more new schedules are allowed. + Note: "cancel scheduled task" only means setting self._scheduled to + False; if an asyncio task is scheduled, stop() will still wait for it. 
+ """ + self._scheduled = False + self._stop_requested = True + if self._task is not None: + await self._task + + +class ExclusiveTaskProperty: + _method: HandlerMethod + _name: str | None + + def __init__( + self, method: HandlerMethod, *, slot: str | None = None + ) -> None: + self._method = method + self._name = slot + + def __set_name__(self, owner: Type[_Owner], name: str) -> None: + if (slots := getattr(owner, "__slots__", None)) is not None: + if self._name is None: + raise TypeError("missing slot in @exclusive_task()") + if self._name not in slots: + raise TypeError( + f"slot {self._name!r} must be defined in __slots__" + ) + + if self._name is None: + self._name = name + + @overload + def __get__(self, instance: None, owner: Type[_Owner]) -> Self: ... + + @overload + def __get__( + self, instance: _Owner, owner: Type[_Owner] + ) -> ExclusiveTask: ... + + def __get__( + self, instance: _Owner | None, owner: Type[_Owner] + ) -> ExclusiveTask | Self: + # getattr on the class + if instance is None: + return self + + assert self._name is not None + + # getattr on an object with __dict__ + if (d := getattr(instance, "__dict__", None)) is not None: + if rv := d.get(self._name, None): + return rv + rv = ExclusiveTask(self._method.__get__(instance, owner)) + d[self._name] = rv + return rv + + # getattr on an object with __slots__ + else: + if rv := getattr(instance, self._name, None): + return rv + rv = ExclusiveTask(self._method.__get__(instance, owner)) + setattr(instance, self._name, rv) + return rv + + +ExclusiveTaskDecorator = Callable[ + [HandlerFunction | HandlerMethod], ExclusiveTask | ExclusiveTaskProperty +] + + +def _exclusive_task( + handler: HandlerFunction | HandlerMethod, *, slot: str | None +) -> ExclusiveTask | ExclusiveTaskProperty: + sig = inspect.signature(handler) + params = list(sig.parameters.values()) + if len(params) == 0: + handler = cast(HandlerFunction, handler) + if slot is not None: + warnings.warn( + "slot is specified but unused in @exclusive_task()", + stacklevel=2, + ) + return ExclusiveTask(handler) + elif len(params) == 1 and params[0].kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + handler = cast(HandlerMethod, handler) + return ExclusiveTaskProperty(handler, slot=slot) + else: + raise TypeError("bad signature") + + +@overload +def exclusive_task(handler: HandlerFunction) -> ExclusiveTask: ... + + +@overload +def exclusive_task( + handler: HandlerMethod, *, slot: str | None = None +) -> ExclusiveTaskProperty: ... + + +@overload +def exclusive_task(*, slot: str | None = None) -> ExclusiveTaskDecorator: ... + + +def exclusive_task( + handler: HandlerFunction | HandlerMethod | None = None, + *, + slot: str | None = None, +) -> ExclusiveTask | ExclusiveTaskProperty | ExclusiveTaskDecorator: + """Convert an async function into an ExclusiveTask. + + This decorator can be applied to either top-level functions or methods + in a class. In the latter case, the exclusiveness is bound to each object + of the owning class. If the owning class defines __slots__, you must also + define an extra slot to store the exclusive state and tell exclusive_task() + by providing the `slot` argument. 
+ """ + if handler is None: + + def decorator( + handler: HandlerFunction | HandlerMethod, + ) -> ExclusiveTask | ExclusiveTaskProperty: + return _exclusive_task(handler, slot=slot) + + return decorator + + return _exclusive_task(handler, slot=slot) diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 8e323b420be..8fcd01f5ed7 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -3580,6 +3580,32 @@ class SysConfigFullFunction(trampoline.VersionedFunction): SELECT * FROM config_defaults WHERE name like '%::%' ), + config_static AS ( + SELECT + s.name AS name, + s.value AS value, + (CASE + WHEN s.type = 'A' THEN 'command line' + -- Due to inplace upgrade limits, without adding a new + -- layer, configuration file values are manually squashed + -- into the `environment variables` layer, see below. + ELSE 'environment variable' + END) AS source, + config_spec.backend_setting IS NOT NULL AS is_backend + FROM + _edgecon_state s + INNER JOIN config_spec ON (config_spec.name = s.name) + WHERE + -- Give precedence to configuration file values over + -- environment variables manually. + s.type = 'A' OR s.type = 'F' OR ( + s.type = 'E' AND NOT EXISTS ( + SELECT 1 FROM _edgecon_state ss + WHERE ss.name = s.name AND ss.type = 'F' + ) + ) + ), + config_sys AS ( SELECT s.key AS name, @@ -3610,16 +3636,12 @@ class SysConfigFullFunction(trampoline.VersionedFunction): SELECT s.name AS name, s.value AS value, - (CASE - WHEN s.type = 'A' THEN 'command line' - WHEN s.type = 'E' THEN 'environment variable' - ELSE 'session' - END) AS source, - FALSE AS from_backend -- only 'B' is for backend settings + 'session' AS source, + FALSE AS is_backend -- only 'B' is for backend settings FROM _edgecon_state s WHERE - s.type != 'B' + s.type = 'C' ), pg_db_setting AS ( @@ -3789,6 +3811,7 @@ class SysConfigFullFunction(trampoline.VersionedFunction): FROM ( SELECT * FROM config_defaults UNION ALL + SELECT * FROM config_static UNION ALL SELECT * FROM config_sys UNION ALL SELECT * FROM config_db UNION ALL SELECT * FROM config_sess diff --git a/edb/pgsql/patches.py b/edb/pgsql/patches.py index f3af993070b..83fae3875fd 100644 --- a/edb/pgsql/patches.py +++ b/edb/pgsql/patches.py @@ -77,11 +77,12 @@ def _setup_patches(patches: list[tuple[str, str]]) -> list[tuple[str, str]]: * sql-introspection - refresh all sql introspection views """ PATCHES: list[tuple[str, str]] = _setup_patches([ - # 6.0b2? + # 6.0b2 # One of the sql-introspection's adds a param with a default to # uuid_to_oid, so we need to drop the original to avoid ambiguity. ('sql', ''' drop function if exists edgedbsql_v6_2f20b3fed0.uuid_to_oid(uuid) cascade '''), ('sql-introspection', ''), + ('metaschema-sql', 'SysConfigFullFunction'), ]) diff --git a/edb/server/args.py b/edb/server/args.py index f76505df3f3..820a261fdf3 100644 --- a/edb/server/args.py +++ b/edb/server/args.py @@ -150,6 +150,7 @@ class ReloadTrigger(enum.StrEnum): 3. Multi-tenant config file (server config) 4. Readiness state (server or tenant config) 5. JWT sub allowlist and revocation list (server or tenant config) + 6. 
The TOML config file (server or tenant config) """ Default = "default" @@ -265,6 +266,7 @@ class ServerConfig(NamedTuple): disable_dynamic_system_config: bool reload_config_files: ReloadTrigger net_worker_mode: NetWorkerMode + config_file: Optional[pathlib.Path] startup_script: Optional[StartupScript] status_sinks: List[Callable[[str], None]] @@ -1106,6 +1108,13 @@ def resolve_envvar_value(self, ctx: click.Context): default='default', help='Controls how the std::net workers work.', ), + click.option( + "--config-file", type=PathPath(), metavar="PATH", + envvar="GEL_SERVER_CONFIG_FILE", + cls=EnvvarResolver, + help='Path to a TOML file to configure the server.', + hidden=True, + ), ]) @@ -1534,6 +1543,7 @@ def parse_args(**kwargs: Any): "readiness_state_file", "jwt_sub_allowlist_file", "jwt_revocation_list_file", + "config_file", ): if kwargs.get(name): opt = "--" + name.replace("_", "-") diff --git a/edb/server/config/__init__.py b/edb/server/config/__init__.py index 4cf98290979..c3420ae0272 100644 --- a/edb/server/config/__init__.py +++ b/edb/server/config/__init__.py @@ -18,7 +18,9 @@ from __future__ import annotations -from typing import Any, Mapping +from typing import Any, Mapping, TypedDict + +import enum import immutables @@ -50,9 +52,25 @@ 'load_ext_settings_from_schema', 'get_compilation_config', 'QueryCacheMode', + 'ConState', 'ConStateType', ) +# See edb/server/pgcon/connect.py for documentation of the types +class ConStateType(enum.StrEnum): + session_config = "C" + backend_session_config = "B" + command_line_argument = "A" + environment_variable = "E" + config_file = "F" + + +class ConState(TypedDict): + name: str + value: Any + type: ConStateType + + def lookup( name: str, *configs: Mapping[str, SettingValue], diff --git a/edb/server/main.py b/edb/server/main.py index de1de206150..1bd36c8ed01 100644 --- a/edb/server/main.py +++ b/edb/server/main.py @@ -36,7 +36,6 @@ import asyncio import contextlib -import enum import json import logging import os @@ -199,6 +198,7 @@ async def _run_server( do_setproctitle: bool, new_instance: bool, compiler: edbcompiler.Compiler, + init_con_data: list[config.ConState], ): sockets = service_manager.get_activation_listen_sockets() @@ -218,10 +218,12 @@ async def _run_server( backend_adaptive_ha=args.backend_adaptive_ha, extensions_dir=args.extensions_dir, ) + tenant.set_init_con_data(init_con_data) tenant.set_reloadable_files( readiness_state_file=args.readiness_state_file, jwt_sub_allowlist_file=args.jwt_sub_allowlist_file, jwt_revocation_list_file=args.jwt_revocation_list_file, + config_file=args.config_file, ) ss = server.Server( runstate_dir=runstate_dir, @@ -258,6 +260,8 @@ async def _run_server( await tenant.load_sidechannel_configs( json.loads(magic_smtp), compiler=compiler ) + if args.config_file: + await tenant.load_config_file(compiler) # This coroutine runs as long as the server, # and compiler(.state) is *heavy*, so make sure we don't # keep a reference to it. 
@@ -303,6 +307,7 @@ def load_configuration(_signum): args.tls_client_ca_file, ) ss.load_jwcrypto(args.jws_key_file) + tenant.reload_config_file.schedule() except Exception: logger.critical( "Unexpected error occurred during reload configuration; " @@ -531,12 +536,19 @@ async def run_server( compiler_state.config_spec, ) - sys_config, backend_settings = initialize_static_cfg( - args, - is_remote_cluster=True, - compiler=compiler, + sys_config, backend_settings, init_con_data = ( + initialize_static_cfg( + args, + is_remote_cluster=True, + compiler=compiler, + ) ) del compiler + if backend_settings: + abort( + 'Static backend settings for remote backend are ' + 'not supported' + ) with _internal_state_dir(runstate_dir, args) as ( int_runstate_dir, args, @@ -544,7 +556,6 @@ async def run_server( return await multitenant.run_server( args, sys_config=sys_config, - backend_settings=backend_settings, sys_queries={ key: sql.encode("utf-8") for key, sql in sys_queries.items() @@ -557,6 +568,7 @@ async def run_server( internal_runstate_dir=int_runstate_dir, do_setproctitle=do_setproctitle, compiler_state=compiler_state, + init_con_data=init_con_data, ) except server.StartupError as e: abort(str(e)) @@ -616,17 +628,22 @@ async def run_server( new_instance, compiler = await _init_cluster(cluster, args) - _, backend_settings = initialize_static_cfg( + _, backend_settings, init_con_data = initialize_static_cfg( args, is_remote_cluster=not is_local_cluster, compiler=compiler, ) - if is_local_cluster and (new_instance or backend_settings): - logger.info('Restarting server to reload configuration...') - await cluster.stop() - await cluster.start(server_settings=backend_settings) - backend_settings = {} + if is_local_cluster: + if new_instance or backend_settings: + logger.info('Restarting server to reload configuration...') + await cluster.stop() + await cluster.start(server_settings=backend_settings) + elif backend_settings: + abort( + 'Static backend settings for remote backend are not supported' + ) + del backend_settings if ( not args.bootstrap_only @@ -673,6 +690,7 @@ async def run_server( do_setproctitle=do_setproctitle, new_instance=new_instance, compiler=compiler, + init_con_data=init_con_data, ) except server.StartupError as e: @@ -807,28 +825,23 @@ def main_dev(): main() -class Source(enum.StrEnum): - command_line_argument = "A" - environment_variable = "E" - - -sources = { - Source.command_line_argument: "command line argument", - Source.environment_variable: "environment variable", -} - - def initialize_static_cfg( args: srvargs.ServerConfig, is_remote_cluster: bool, compiler: edbcompiler.Compiler, -) -> Tuple[Mapping[str, config.SettingValue], Dict[str, str]]: +) -> Tuple[ + Mapping[str, config.SettingValue], Dict[str, str], list[config.ConState] +]: result = {} - init_con_script_data = [] + init_con_script_data: list[config.ConState] = [] backend_settings = {} config_spec = compiler.state.config_spec + sources = { + config.ConStateType.command_line_argument: "command line argument", + config.ConStateType.environment_variable: "environment variable", + } - def add_config_values(obj: dict[str, Any], source: Source): + def add_config_values(obj: dict[str, Any], source: config.ConStateType): settings = compiler.compile_structured_config( {"cfg::Config": obj}, source=sources[source] )["cfg::Config"] @@ -837,7 +850,7 @@ def add_config_values(obj: dict[str, Any], source: Source): if is_remote_cluster: if setting.backend_setting and setting.requires_restart: - if source == 
Source.command_line_argument: + if source == config.ConStateType.command_line_argument: where = "on command line" else: where = "as an environment variable" @@ -876,7 +889,7 @@ def add_config_values(obj: dict[str, Any], source: Source): if cfg != name: values[cfg] = value if values: - add_config_values(values, Source.environment_variable) + add_config_values(values, config.ConStateType.environment_variable) values = {} if args.bind_addresses: @@ -884,13 +897,9 @@ def add_config_values(obj: dict[str, Any], source: Source): if args.port: values["listen_port"] = args.port if values: - add_config_values(values, Source.command_line_argument) - - if init_con_script_data: - from . import pgcon - pgcon.set_init_con_script_data(init_con_script_data) + add_config_values(values, config.ConStateType.command_line_argument) - return immutables.Map(result), backend_settings + return immutables.Map(result), backend_settings, init_con_script_data if __name__ == '__main__': diff --git a/edb/server/multitenant.py b/edb/server/multitenant.py index c4fdbfc33c1..cc59e71c30d 100644 --- a/edb/server/multitenant.py +++ b/edb/server/multitenant.py @@ -63,6 +63,7 @@ "jwt-revocation-list-file": str, "readiness-state-file": str, "admin": bool, + "config-file": str, }, ) @@ -70,7 +71,7 @@ class MultiTenantServer(server.BaseServer): _config_file: pathlib.Path _sys_config: Mapping[str, config.SettingValue] - _backend_settings: Mapping[str, str] + _init_con_data: list[config.ConState] _tenants_by_sslobj: MutableMapping _tenants_conf: dict[str, dict[str, str]] @@ -89,7 +90,7 @@ def __init__( *, compiler_pool_tenant_cache_size: int, sys_config: Mapping[str, config.SettingValue], - backend_settings: Mapping[str, str], + init_con_data: list[config.ConState], sys_queries: Mapping[str, bytes], report_config_typedesc: dict[defines.ProtocolVersion, bytes], **kwargs, @@ -97,7 +98,7 @@ def __init__( super().__init__(**kwargs) self._config_file = config_file self._sys_config = sys_config - self._backend_settings = backend_settings + self._init_con_data = init_con_data self._compiler_pool_tenant_cache_size = compiler_pool_tenant_cache_size self._tenants_by_sslobj = weakref.WeakKeyDictionary() @@ -243,13 +244,18 @@ async def _create_tenant(self, conf: TenantConfig) -> edbtenant.Tenant: max_backend_connections=max_conns, backend_adaptive_ha=conf.get("backend-adaptive-ha", False), ) + tenant.set_init_con_data(self._init_con_data) + config_file = conf.get("config-file") tenant.set_reloadable_files( readiness_state_file=conf.get("readiness-state-file"), jwt_sub_allowlist_file=conf.get("jwt-sub-allowlist-file"), jwt_revocation_list_file=conf.get("jwt-revocation-list-file"), + config_file=config_file, ) tenant.set_server(self) tenant.load_jwcrypto() + if config_file: + await tenant.load_config_file(self.get_compiler_pool()) try: await tenant.init_sys_pgcon() await tenant.init() @@ -379,6 +385,7 @@ async def _reload_tenant(self, serial: int, sni: str, conf: TenantConfig): "readiness-state-file", "jwt-sub-allowlist-file", "jwt-revocation-list-file", + "config-file", } if diff: logger.warning( @@ -395,6 +402,7 @@ async def _reload_tenant(self, serial: int, sni: str, conf: TenantConfig): "jwt-sub-allowlist-file"), jwt_revocation_list_file=conf.get( "jwt-revocation-list-file"), + config_file=conf.get("config-file"), ): # none of the reloadable values was modified return @@ -429,7 +437,7 @@ async def run_server( args: srvargs.ServerConfig, *, sys_config: Mapping[str, config.SettingValue], - backend_settings: Mapping[str, str], + init_con_data: 
list[config.ConState], sys_queries: Mapping[str, bytes], report_config_typedesc: dict[defines.ProtocolVersion, bytes], runstate_dir: pathlib.Path, @@ -444,7 +452,7 @@ async def run_server( ss = MultiTenantServer( multitenant_config_file, sys_config=sys_config, - backend_settings=backend_settings, + init_con_data=init_con_data, sys_queries=sys_queries, report_config_typedesc=report_config_typedesc, runstate_dir=runstate_dir, diff --git a/edb/server/pgcon/__init__.py b/edb/server/pgcon/__init__.py index c6bc771fb67..8babc785151 100644 --- a/edb/server/pgcon/__init__.py +++ b/edb/server/pgcon/__init__.py @@ -31,14 +31,13 @@ ) from .connect import ( pg_connect, - set_init_con_script_data, SETUP_TEMP_TABLE_SCRIPT, SETUP_CONFIG_CACHE_SCRIPT, + RESET_STATIC_CFG_SCRIPT, ) __all__ = ( 'pg_connect', - 'set_init_con_script_data', 'PGConnection', 'BackendError', 'BackendConnectionError', @@ -46,4 +45,5 @@ 'BackendCatalogNameError', 'SETUP_TEMP_TABLE_SCRIPT', 'SETUP_CONFIG_CACHE_SCRIPT', + 'RESET_STATIC_CFG_SCRIPT' ) diff --git a/edb/server/pgcon/connect.py b/edb/server/pgcon/connect.py index 7d912b76c70..aafe1532f3f 100644 --- a/edb/server/pgcon/connect.py +++ b/edb/server/pgcon/connect.py @@ -18,12 +18,10 @@ from __future__ import annotations -import json import logging import textwrap from edb.pgsql.common import quote_ident as pg_qi -from edb.pgsql.common import quote_literal as pg_ql from edb.pgsql import params as pg_params from edb.server import pgcon @@ -33,7 +31,6 @@ logger = logging.getLogger('edb.server') INIT_CON_SCRIPT: bytes | None = None -INIT_CON_SCRIPT_DATA = '' # The '_edgecon_state table' is used to store information about # the current session. The `type` column is one character, with one @@ -45,12 +42,16 @@ # a corresponding Postgres config setting. # * 'A': an instance-level config setting from command-line arguments # * 'E': an instance-level config setting from environment variable +# * 'F': an instance/tenant-level config setting from the TOML config file +# +# Please also update ConStateType in edb/server/config/__init__.py if changed. 
SETUP_TEMP_TABLE_SCRIPT = ''' CREATE TEMPORARY TABLE _edgecon_state ( name text NOT NULL, value jsonb NOT NULL, type text NOT NULL CHECK( - type = 'C' OR type = 'B' OR type = 'A' OR type = 'E'), + type = 'C' OR type = 'B' OR type = 'A' OR type = 'E' + OR type = 'F'), UNIQUE(name, type) ); '''.strip() @@ -60,6 +61,12 @@ value edgedb._sys_config_val_t NOT NULL ); '''.strip() +RESET_STATIC_CFG_SCRIPT: bytes = b''' + WITH x1 AS ( + DELETE FROM _config_cache + ) + DELETE FROM _edgecon_state WHERE type = 'A' OR type = 'E' OR type = 'F'; +''' def _build_init_con_script(*, check_pg_is_in_recovery: bool) -> bytes: @@ -82,8 +89,6 @@ def _build_init_con_script(*, check_pg_is_in_recovery: bool) -> bytes: {SETUP_TEMP_TABLE_SCRIPT} {SETUP_CONFIG_CACHE_SCRIPT} - {INIT_CON_SCRIPT_DATA} - PREPARE _clear_state AS WITH x1 AS ( DELETE FROM _config_cache @@ -193,15 +198,3 @@ async def pg_connect( raise return pgconn - - -def set_init_con_script_data(cfg): - global INIT_CON_SCRIPT, INIT_CON_SCRIPT_DATA - INIT_CON_SCRIPT = None - INIT_CON_SCRIPT_DATA = ( - f''' - INSERT INTO _edgecon_state - SELECT * FROM jsonb_to_recordset({pg_ql(json.dumps(cfg))}::jsonb) - AS cfg(name text, value jsonb, type text); - ''' - ).strip() diff --git a/edb/server/pgcon/pgcon.pxd b/edb/server/pgcon/pgcon.pxd index bc7883ca340..6b9cf16b0c1 100644 --- a/edb/server/pgcon/pgcon.pxd +++ b/edb/server/pgcon/pgcon.pxd @@ -138,6 +138,7 @@ cdef class PGConnection: public object pinned_by object last_state + public object last_init_con_data str last_indirect_return diff --git a/edb/server/pgcon/pgcon.pyi b/edb/server/pgcon/pgcon.pyi index 7c6821bf028..8e961dfdc62 100644 --- a/edb/server/pgcon/pgcon.pyi +++ b/edb/server/pgcon/pgcon.pyi @@ -36,8 +36,6 @@ class BackendConnectionError(BackendError): ... class BackendPrivilegeError(BackendError): ... class BackendCatalogNameError(BackendError): ... -def set_init_con_script_data(cfg: list[dict[str, Any]]): ... - class PGConnection(asyncio.Protocol): idle: bool @@ -47,6 +45,7 @@ class PGConnection(asyncio.Protocol): parameter_status: dict[str, str] backend_secret: int is_ssl: bool + last_init_con_data: object def __init__(self, dbname): ... async def close(self): ... diff --git a/edb/server/smtp.py b/edb/server/smtp.py index 0265bc50df7..003730dbbec 100644 --- a/edb/server/smtp.py +++ b/edb/server/smtp.py @@ -56,7 +56,7 @@ class SMTPProviderConfig: class SMTP: def __init__(self, db: dbview.Database): - current_provider = _get_current_email_provider(db) + current_provider = get_current_email_provider(db) self.sender = current_provider.sender or "noreply@example.com" default_port = ( 465 @@ -205,7 +205,7 @@ def _send_test_mode_email(self, message: email.message.Message): pickle.dump(args, f) -def _get_current_email_provider( +def get_current_email_provider( db: dbview.Database, ) -> SMTPProviderConfig: current_provider_name = db.lookup_config("current_email_provider_name") diff --git a/edb/server/tenant.py b/edb/server/tenant.py index 5ef7903e8df..50130f2a17f 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -34,6 +34,7 @@ import asyncio import contextlib +import dataclasses import functools import json import logging @@ -42,6 +43,7 @@ import pickle import struct import sys +import textwrap import time import tomllib import uuid @@ -51,6 +53,7 @@ from edb import buildmeta from edb import errors +from edb.common import asyncutil from edb.common import retryloop from edb.common import verutils from edb.common.log import current_tenant @@ -76,6 +79,7 @@ from . import pgcluster from . 
import server as edbserver + from . import compiler_pool as edbcompiler_pool logger = logging.getLogger("edb.server") @@ -116,12 +120,15 @@ class Tenant(ha_base.ClusterProtocol): _suggested_client_pool_size: int _pg_pool: connpool.Pool _pg_unavailable_msg: str | None + _init_con_data: list[config.ConState] + _init_con_sql: bytes | None _ha_master_serial: int _backend_adaptive_ha: adaptive_ha.AdaptiveHASupport | None _readiness_state_file: pathlib.Path | None _readiness: srvargs.ReadinessState _readiness_reason: str + _config_file: pathlib.Path | None _extensions_dirs: tuple[pathlib.Path, ...] @@ -182,6 +189,7 @@ def __init__( self._readiness_state_file = None self._readiness = srvargs.ReadinessState.Default self._readiness_reason = "" + self._config_file = None self._max_backend_connections = max_backend_connections self._suggested_client_pool_size = max( @@ -199,6 +207,8 @@ def __init__( self._pg_unavailable_msg = None self._block_new_connections = set() self._report_config_data = {} + self._init_con_data = [] + self._init_con_sql = None # DB state will be initialized in init(). self._dbindex = None @@ -222,6 +232,7 @@ def set_reloadable_files( readiness_state_file: str | pathlib.Path | None = None, jwt_sub_allowlist_file: str | pathlib.Path | None = None, jwt_revocation_list_file: str | pathlib.Path | None = None, + config_file: str | pathlib.Path | None = None, ) -> bool: rv = False @@ -243,6 +254,12 @@ def set_reloadable_files( self._jwt_revocation_list_file = jwt_revocation_list_file rv = True + if isinstance(config_file, str): + config_file = pathlib.Path(config_file) + if self._config_file != config_file: + self._config_file = config_file + rv = True + return rv def set_server(self, server: edbserver.BaseServer) -> None: @@ -250,14 +267,26 @@ def set_server(self, server: edbserver.BaseServer) -> None: self.__loop = server.get_loop() async def load_sidechannel_configs( - self, value: Any, *, compiler: edbcompiler.Compiler | None = None + self, + value: Any, + *, + compiler: ( + edbcompiler.Compiler | edbcompiler_pool.AbstractPool | None + ) = None, ) -> None: if compiler is None: compiler = self._server.get_compiler_pool() - result = compiler.compile_structured_config( - {"cfg::Config": {"email_providers": value}}, source="magic", - allow_nested=True, - ) + objects = {"cfg::Config": {"email_providers": value}} + if isinstance(compiler, edbcompiler.Compiler): + result = compiler.compile_structured_config( + objects, source="magic", allow_nested=True + ) + else: + result = await compiler.compile_structured_config( + objects, + "magic", # source + True, # allow_nested + ) email_providers = result["cfg::Config"]["email_providers"] self._sidechannel_email_configs = list(email_providers.value) @@ -604,6 +633,15 @@ def reload_jwt_revocation_list_file(): ) ) + if self._config_file is not None: + + def reload_config_file(): + self.reload_config_file.schedule() + + self._file_watch_finalizers.append( + self.server.monitor_fs(self._config_file, reload_config_file) + ) + async def start_accepting_new_tasks(self) -> None: assert self._task_group is None self._task_group = asyncio.TaskGroup() @@ -701,6 +739,27 @@ def terminate_sys_pgcon(self) -> None: self.__sys_pgcon = None del self._sys_pgcon_waiter + def set_init_con_data(self, data: list[config.ConState]) -> None: + self._init_con_data = data + self._init_con_sql = None + if data: + self._init_con_sql = self._make_init_con_sql(data) + + def _make_init_con_sql(self, data: list[config.ConState]) -> bytes: + if not data: + return b"" + + from 
edb.pgsql import common + + quoted_json = common.quote_literal(json.dumps(data)) + return textwrap.dedent( + f''' + INSERT INTO _edgecon_state + SELECT * FROM jsonb_to_recordset({quoted_json}::jsonb) + AS cfg(name text, value jsonb, type text); + ''' + ).strip().encode() + async def _pg_connect( self, dbname: str, @@ -720,6 +779,11 @@ async def _pg_connect( ) if self._server.stmt_cache_size is not None: rv.set_stmt_cache_size(self._server.stmt_cache_size) + + if self._init_con_sql: + await rv.sql_execute(self._init_con_sql) + rv.last_init_con_data = self._init_con_data + except Exception: metrics.backend_connection_establishment_errors.inc( 1.0, self._instance_name @@ -971,11 +1035,24 @@ async def acquire_pgcon(self, dbname: str) -> pgcon.PGConnection: for _ in range(self._pg_pool.max_capacity): conn = await self._pg_pool.acquire(dbname) - if conn.is_healthy(): - return conn + if not conn.is_healthy(): + logger.warning("acquired an unhealthy pgcon; discard now") + elif conn.last_init_con_data is not self._init_con_data: + try: + await conn.sql_execute( + pgcon.RESET_STATIC_CFG_SCRIPT + + (self._init_con_sql or b'') + ) + except Exception as e: + logger.warning( + "failed to update pgcon; discard now: %s", e + ) + else: + conn.last_init_con_data = self._init_con_data + return conn else: - logger.warning("Acquired an unhealthy pgcon; discard now.") - self._pg_pool.release(dbname, conn, discard=True) + return conn + self._pg_pool.release(dbname, conn, discard=True) else: # This is unlikely to happen, but we defer to the caller to retry # when it does happen @@ -1355,9 +1432,13 @@ async def _load_reported_config(self) -> None: async def _load_sys_config( self, query_name: str = "sysconfig", + syscon: pgcon.PGConnection | None = None, ) -> Mapping[str, config.SettingValue]: - async with self.use_sys_pgcon() as syscon: - query = self._server.get_sys_query(query_name) + query = self._server.get_sys_query(query_name) + if syscon is None: + async with self.use_sys_pgcon() as syscon: + sys_config_json = await syscon.sql_fetch_val(query) + else: sys_config_json = await syscon.sql_fetch_val(query) return config.from_json(self._server.config_settings, sys_config_json) @@ -1569,6 +1650,93 @@ def set_readiness_state(self, state: srvargs.ReadinessState, reason: str): self._readiness = state self._readiness_reason = reason + @asyncutil.exclusive_task + async def reload_config_file(self): + if self._config_file is None: + return + + try: + await self._reload_config_file() + except Exception: + logger.error("failed to reload config file", exc_info=True) + metrics.background_errors.inc( + 1.0, self._instance_name, "reload_config_file" + ) + + async def load_config_file(self, compiler): + logger.info("loading config file") + + # Read the TOML file + with self._config_file.open('rb') as f: + toml_data = tomllib.load(f) + + # Handle special case for `magic_smtp_config` + magic_smtp_config = toml_data.pop("magic_smtp_config", None) + if magic_smtp_config: + await self.load_sidechannel_configs( + magic_smtp_config, compiler=compiler + ) + + # Parse TOML config file content into JSON + if toml_data and toml_data.get("cfg::Config"): + result = compiler.compile_structured_config( + toml_data, "configuration file" + ) + if asyncio.iscoroutine(result): + result = await result + + def setting_filter(value: config.SettingValue) -> bool: + if self._server.config_settings[value.name].backend_setting: + raise errors.ConfigurationError( + f"backend config {value.name!r} cannot be set " + f"via config file" + ) + return True 
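+            # The compiled values are serialized to JSON below; the merged
+            # list handed to set_init_con_data() is turned by
+            # _make_init_con_sql() into 'F'-typed rows of _edgecon_state.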
+ + json_obj = config.to_json_obj( + self._server.config_settings, + result["cfg::Config"], + include_source=False, + setting_filter=setting_filter, + ) + config_file_data = [ + { + "name": name, + "value": value, + "type": config.ConStateType.config_file, + } + for name, value in json_obj.items() + ] + else: + config_file_data = [] + + # Update init_con_data and SQL + self.set_init_con_data( + [ + cs + for cs in self._init_con_data + if cs["type"] != config.ConStateType.config_file + ] + + config_file_data + ) + + async def _reload_config_file(self): + # Load TOML config file + compiler = self._server.get_compiler_pool() + await self.load_config_file(compiler) + + # Update sys pgcon and reload system config + async with self.use_sys_pgcon() as syscon: + if syscon.last_init_con_data is not self._init_con_data: + await syscon.sql_execute( + pgcon.RESET_STATIC_CFG_SCRIPT + (self._init_con_sql or b'') + ) + syscon.last_init_con_data = self._init_con_data + sys_config = await self._load_sys_config(syscon=syscon) + # GOTCHA: no need to notify other EdgeDBs on the same backend about + # such change to sysconfig, because static config is instance-local + self._dbindex.update_sys_config(sys_config) + def reload(self): # In multi-tenant mode, the file paths for the following states may be # unset in a reload, while it's impossible in a regular server. @@ -1583,6 +1751,7 @@ def reload(self): self.reload_readiness_state() self.load_jwcrypto() + self.reload_config_file.schedule() self.start_watching_files() @@ -1921,6 +2090,8 @@ async def task(): self.create_task(task(), interruptable=True) def get_debug_info(self) -> dict[str, Any]: + from . import smtp + pgaddr = self.get_pgaddr() pgaddr.clear_server_settings() pgdict = pgaddr.__dict__ @@ -1949,6 +2120,12 @@ def get_debug_info(self) -> dict[str, Any]: if db.name in defines.EDGEDB_SPECIAL_DBS: continue + try: + email_provider = dataclasses.asdict( + smtp.get_current_email_provider(db) + ) + except errors.ConfigurationError: + email_provider = None dbs[db.name] = dict( name=db.name, dbver=db.dbver, @@ -1969,6 +2146,7 @@ def get_debug_info(self) -> dict[str, Any]: ) for view in db.iter_views() ], + current_email_provider=email_provider, ) obj["databases"] = dbs diff --git a/edb/testbase/server.py b/edb/testbase/server.py index 7b07d070c48..ed2f28aabe1 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -2388,6 +2388,7 @@ def __init__( jwt_sub_allowlist_file: Optional[os.PathLike] = None, jwt_revocation_list_file: Optional[os.PathLike] = None, multitenant_config: Optional[str] = None, + config_file: Optional[os.PathLike] = None, default_branch: Optional[str] = None, env: Optional[Dict[str, str]] = None, extra_args: Optional[List[str]] = None, @@ -2424,6 +2425,7 @@ def __init__( self.jwt_sub_allowlist_file = jwt_sub_allowlist_file self.jwt_revocation_list_file = jwt_revocation_list_file self.multitenant_config = multitenant_config + self.config_file = config_file self.default_branch = default_branch self.env = env self.extra_args = extra_args @@ -2591,6 +2593,9 @@ async def __aenter__(self): cmd += ['--jwt-revocation-list-file', self.jwt_revocation_list_file] + if self.config_file: + cmd += ['--config-file', self.config_file] + if not self.multitenant_config: cmd += ['--instance-name=localtest'] @@ -2752,6 +2757,7 @@ def start_edgedb_server( jwt_sub_allowlist_file: Optional[os.PathLike] = None, jwt_revocation_list_file: Optional[os.PathLike] = None, multitenant_config: Optional[str] = None, + config_file: Optional[os.PathLike] = None, 
env: Optional[Dict[str, str]] = None, extra_args: Optional[List[str]] = None, default_branch: Optional[str] = None, @@ -2832,6 +2838,7 @@ def start_edgedb_server( jwt_sub_allowlist_file=jwt_sub_allowlist_file, jwt_revocation_list_file=jwt_revocation_list_file, multitenant_config=multitenant_config, + config_file=config_file, env=env, extra_args=extra_args, default_branch=default_branch, @@ -3033,5 +3040,13 @@ async def g(self, *args, **kwargs): return decorator +@contextlib.asynccontextmanager +async def temp_file_with(data: bytes): + with tempfile.NamedTemporaryFile() as f: + f.write(data) + f.flush() + yield f + + needs_factoring = _needs_factoring(weakly=False) needs_factoring_weakly = _needs_factoring(weakly=True) diff --git a/tests/common/test_asyncutil.py b/tests/common/test_asyncutil.py index dfd99f37b2a..2f9f31afcc4 100644 --- a/tests/common/test_asyncutil.py +++ b/tests/common/test_asyncutil.py @@ -101,3 +101,222 @@ async def sleep_until(t): (2020, [15]), ], ) + + +class TestExclusiveTask(unittest.TestCase): + async def _test(self, task: asyncutil.ExclusiveTask, get_counter): + # double-schedule is effective only once + task.schedule() + self.assertTrue(task.scheduled) + task.schedule() + self.assertTrue(task.scheduled) + self.assertEqual(get_counter(), 0) + + # an exclusive task is running, schedule another one with a double shot + await asyncio.sleep(4) + self.assertFalse(task.scheduled) + await asyncio.sleep(1) + task.schedule() + self.assertTrue(task.scheduled) + task.schedule() + self.assertTrue(task.scheduled) + self.assertEqual(get_counter(), 1) + + # first task done, second follows immediately + await asyncio.sleep(5) + self.assertFalse(task.scheduled) + self.assertEqual(get_counter(), 3) + + # all done + await asyncio.sleep(9) + self.assertFalse(task.scheduled) + self.assertEqual(get_counter(), 4) + + # works repeatedly + await asyncio.sleep(1) + task.schedule() + self.assertTrue(task.scheduled) + await asyncio.sleep(3) + self.assertFalse(task.scheduled) + await asyncio.sleep(1) + task.schedule() + self.assertTrue(task.scheduled) + self.assertEqual(get_counter(), 5) + + # now stop the scheduled task and wait for the running one to finish + await asyncio.sleep(1) + await task.stop() + self.assertFalse(task.scheduled) + self.assertEqual(get_counter(), 6) + + # no further schedule allowed + task.schedule() + self.assertFalse(task.scheduled) + await asyncio.sleep(10) + self.assertEqual(get_counter(), 6) + + @with_fake_event_loop + async def test_exclusive_task_01(self): + counter = 0 + + @asyncutil.exclusive_task + async def task(): + nonlocal counter + counter += 1 + await asyncio.sleep(8) + counter += 1 + + await self._test(task, lambda: counter) + + @with_fake_event_loop + async def test_exclusive_task_02(self): + counter = 0 + + @asyncutil.exclusive_task() + async def task(): + nonlocal counter + counter += 1 + await asyncio.sleep(8) + counter += 1 + + await self._test(task, lambda: counter) + + @with_fake_event_loop + async def test_exclusive_task_03(self): + class MyClass: + def __init__(self): + self.counter = 0 + + @asyncutil.exclusive_task + async def task(self): + self.counter += 1 + await asyncio.sleep(8) + self.counter += 1 + + obj = MyClass() + await self._test(obj.task, lambda: obj.counter) + + @with_fake_event_loop + async def test_exclusive_task_04(self): + class MyClass: + def __init__(self): + self.counter = 0 + + @asyncutil.exclusive_task(slot="another") + async def task(self): + self.counter += 1 + await asyncio.sleep(8) + self.counter += 1 + + obj = 
MyClass() + await self._test(obj.task, lambda: obj.counter) + + @with_fake_event_loop + async def test_exclusive_task_05(self): + class MyClass: + __slots__ = ("counter", "another",) + + def __init__(self): + self.counter = 0 + + @asyncutil.exclusive_task(slot="another") + async def task(self): + self.counter += 1 + await asyncio.sleep(8) + self.counter += 1 + + obj = MyClass() + await self._test(obj.task, lambda: obj.counter) + + @with_fake_event_loop + async def test_exclusive_task_06(self): + class MyClass: + def __init__(self, factor: int): + self.counter = 0 + self.factor = factor + + @asyncutil.exclusive_task + async def task(self): + self.counter += self.factor + await asyncio.sleep(8) + self.counter += self.factor + + obj1 = MyClass(1) + obj2 = MyClass(2) + async with asyncio.TaskGroup() as g: + g.create_task( + self._test(obj1.task, lambda: obj1.counter // obj1.factor) + ) + await asyncio.sleep(3) + g.create_task( + self._test(obj2.task, lambda: obj2.counter // obj2.factor) + ) + + def test_exclusive_task_07(self): + with self.assertRaises(TypeError): + class MyClass: + __slots__ = () + + @asyncutil.exclusive_task + async def task(self): + pass + + def test_exclusive_task_08(self): + with self.assertRaises(TypeError): + class MyClass: + __slots__ = () + + @asyncutil.exclusive_task(slot="missing") + async def task(self): + pass + + def test_exclusive_task_09(self): + with self.assertRaises(TypeError): + @asyncutil.exclusive_task + async def task(*args, **kwargs): + pass + + def test_exclusive_task_10(self): + with self.assertRaises(TypeError): + @asyncutil.exclusive_task + async def task(*, p): + pass + + def test_exclusive_task_11(self): + with self.assertRaises(TypeError): + class MyClass: + @asyncutil.exclusive_task + async def task(self, p): + pass + + def test_exclusive_task_12(self): + with self.assertRaises(TypeError): + class MyClass: + @asyncutil.exclusive_task + @classmethod + async def task(cls): + pass + + @with_fake_event_loop + async def test_exclusive_task_13(self): + counter = 0 + + class MyClass: + @asyncutil.exclusive_task + @staticmethod + async def task(): + nonlocal counter + counter += 1 + await asyncio.sleep(8) + counter += 1 + + obj1 = MyClass() + obj2 = MyClass() + + async with asyncio.TaskGroup() as g: + g.create_task( + self._test(obj1.task, lambda: counter) + ) + g.create_task( + self._test(obj2.task, lambda: counter) + ) diff --git a/tests/test_server_config.py b/tests/test_server_config.py index 9206d102c4a..2bc4145e91a 100644 --- a/tests/test_server_config.py +++ b/tests/test_server_config.py @@ -19,10 +19,12 @@ import asyncio import datetime +import enum import json import os import platform import random +import signal import tempfile import textwrap import unittest @@ -2311,6 +2313,178 @@ async def test_server_config_default(self): finally: await conn.aclose() + async def test_server_config_file_01(self): + conf = textwrap.dedent(''' + ["cfg::Config"] + session_idle_timeout = "8m42s" + durprop = "996" + apply_access_policies = false + multiprop = "single" + current_email_provider_name = "localmock" + + [[magic_smtp_config]] + _tname = "cfg::SMTPProviderConfig" + name = "localmock" + sender = "sender@example.com" + timeout_per_email = "1 minute 48 seconds" + ''') + async with tb.temp_file_with( + conf.encode() + ) as config_file, tb.start_edgedb_server( + config_file=config_file.name, + http_endpoint_security=args.ServerEndpointSecurityMode.Optional, + ) as sd: + conn = await sd.connect() + try: + sysconfig = conn.get_settings()["system_config"] + 
self.assertEqual( + sysconfig.session_idle_timeout, + datetime.timedelta(minutes=8, seconds=42), + ) + + self.assertEqual( + await conn.query_single("""\ + select assert_single(cfg::Config.session_idle_timeout) + """), + datetime.timedelta(minutes=8, seconds=42), + ) + self.assertEqual( + await conn.query_single("""\ + select assert_single( + cfg::Config.durprop) + """), + datetime.timedelta(seconds=996), + ) + self.assertFalse( + await conn.query_single("""\ + select assert_single(cfg::Config.apply_access_policies) + """) + ) + self.assertEqual( + await conn.query("""\ + select assert_single(cfg::Config).multiprop + """), + ["single"], + ) + + dbname = await conn.query_single("""\ + select sys::get_current_branch() + """) + provider = sd.fetch_server_info()["databases"][dbname][ + "current_email_provider" + ] + self.assertEqual(provider['name'], 'localmock') + self.assertEqual(provider['sender'], 'sender@example.com') + self.assertEqual(provider['timeout_per_email'], 'PT1M48S') + + await conn.query("""\ + configure current database + set current_email_provider_name := 'non_exist'; + """) + async for tr in self.try_until_succeeds(ignore=AssertionError): + async with tr: + provider = sd.fetch_server_info()["databases"][dbname][ + "current_email_provider" + ] + self.assertIsNone(provider) + finally: + await conn.aclose() + + async def test_server_config_file_02(self): + conf = textwrap.dedent(''' + ["cfg::Config"] + allow_bare_ddl = "illegal_input" + ''') + with self.assertRaisesRegex( + cluster.ClusterError, + "'cfg::AllowBareDDL' enum has no member called 'illegal_input'" + ): + async with tb.temp_file_with( + conf.encode() + ) as config_file, tb.start_edgedb_server( + config_file=config_file.name, + ): + pass + + async def test_server_config_file_03(self): + conf = textwrap.dedent(''' + ["cfg::Config"] + apply_access_policies = "on" + ''') + with self.assertRaisesRegex( + cluster.ClusterError, + "can only be one of: true, false", + ): + async with tb.temp_file_with( + conf.encode() + ) as config_file, tb.start_edgedb_server( + config_file=config_file.name, + ): + pass + + async def test_server_config_file_04(self): + conf = textwrap.dedent(''' + ["cfg::Config"] + query_execution_timeout = "1 hour" + ''') + with self.assertRaisesRegex( + cluster.ClusterError, + "backend config 'query_execution_timeout' cannot be set " + "via config file" + ): + async with tb.temp_file_with( + conf.encode() + ) as config_file, tb.start_edgedb_server( + config_file=config_file.name, + ): + pass + + async def test_server_config_file_05(self): + class Prop(enum.Enum): + One = "One" + Two = "Two" + Three = "Three" + + conf = textwrap.dedent(''' + ["cfg::Config"] + enumprop = "One" + ''') + async with tb.temp_file_with( + conf.encode() + ) as config_file, tb.start_edgedb_server( + config_file=config_file.name, + ) as sd: + conn = await sd.connect() + try: + self.assertEqual( + await conn.query_single("""\ + select assert_single( + cfg::Config.enumprop) + """), + Prop.One, + ) + + config_file.seek(0) + config_file.truncate() + config_file.write(textwrap.dedent(''' + ["cfg::Config"] + enumprop = "Three" + ''').encode()) + config_file.flush() + os.kill(sd.pid, signal.SIGHUP) + + async for tr in self.try_until_succeeds(ignore=AssertionError): + async with tr: + self.assertEqual( + await conn.query_single("""\ + select assert_single( + cfg::Config.enumprop) + """), + Prop.Three, + ) + finally: + await conn.aclose() + class TestDynamicSystemConfig(tb.TestCase): async def test_server_dynamic_system_config(self): diff 
--git a/tests/test_server_ops.py b/tests/test_server_ops.py index 5ab3f198d6b..f156bd1dd7d 100644 --- a/tests/test_server_ops.py +++ b/tests/test_server_ops.py @@ -33,6 +33,7 @@ import ssl import sys import tempfile +import textwrap import time import unittest import urllib.error @@ -1531,7 +1532,6 @@ async def _init_pg_cluster(self, path): raise return cluster, connect_args - @unittest.skip('Test was failing mysteriously in CI. See #7933.') async def test_server_ops_multi_tenant(self): with ( tempfile.TemporaryDirectory() as td1, @@ -1539,18 +1539,28 @@ async def test_server_ops_multi_tenant(self): tempfile.NamedTemporaryFile("w+") as conf_file, tempfile.NamedTemporaryFile("w+") as rd1, tempfile.NamedTemporaryFile("w+") as rd2, + tempfile.NamedTemporaryFile("w+") as cf1, + tempfile.NamedTemporaryFile("w+") as cf2, ): fs = [] conf = {} - for i, td, rd in [(1, td1, rd1), (2, td2, rd2)]: + for i, td, rd, cf in [(1, td1, rd1, cf1), (2, td2, rd2, cf2)]: rd.file.write("default:ok") rd.file.flush() + cf.write(textwrap.dedent(f""" + [[magic_smtp_config]] + _tname = "cfg::SMTPProviderConfig" + name = "provider:{i}" + sender = "sender@host{i}.com" + """)) + cf.flush() fs.append(self.loop.create_task(self._init_pg_cluster(td))) conf[f"{i}.localhost"] = { "instance-name": f"localtest{i}", "backend-dsn": f'postgres:///?user=postgres&host={td}', "max-backend-connections": 10, "readiness-state-file": rd.name, + "config-file": cf.name, } await asyncio.wait(fs) cluster1, args1 = await fs[0] @@ -1566,12 +1576,22 @@ async def test_server_ops_multi_tenant(self): runstate_dir=runstate_dir, multitenant_config=conf_file.name, max_allowed_connections=None, + http_endpoint_security=args.ServerEndpointSecurityMode.Optional, ) async with srv as sd: mtargs = MultiTenantArgs( - srv, sd, conf_file, conf, args1, args2, rd1, rd2 + srv, + sd, + conf_file, + conf, + args1, + args2, + rd1, + rd2, + cf1, + cf2 ) - for i in range(1, 7): + for i in range(1, 8): name = f"_test_server_ops_multi_tenant_{i}" with self.subTest(name, i=i): await getattr(self, name)(mtargs) @@ -1604,9 +1624,6 @@ async def _test_server_ops_multi_tenant_3(self, mtargs: MultiTenantArgs): self.assertIn( '\nedgedb_server_mt_tenants_current 2.0\n', data ) - self.assertIn( - '\nedgedb_server_mt_config_reload_errors_total 0.0\n', data - ) self.assertIn( '\nedgedb_server_mt_tenant_add_total' '{tenant="localtest1"} 1.0\n', @@ -1634,10 +1651,6 @@ async def _test_server_ops_multi_tenant_3(self, mtargs: MultiTenantArgs): '\nedgedb_server_mt_tenants_current 1.0\n', data, ) - self.assertIn( - '\nedgedb_server_mt_config_reload_errors_total 0.0\n', - data, - ) self.assertIn( '\nedgedb_server_mt_tenant_add_total' '{tenant="localtest1"} 1.0\n', @@ -1665,10 +1678,6 @@ async def _test_server_ops_multi_tenant_3(self, mtargs: MultiTenantArgs): '\nedgedb_server_mt_tenants_current 2.0\n', data, ) - self.assertIn( - '\nedgedb_server_mt_config_reload_errors_total 0.0\n', - data, - ) self.assertIn( '\nedgedb_server_mt_tenant_add_total' '{tenant="localtest1"} 2.0\n', @@ -1733,7 +1742,7 @@ async def _test_server_ops_multi_tenant_5(self, mtargs: MultiTenantArgs): await self._test_server_ops_multi_tenant_2(mtargs) async def _test_server_ops_global_compile_cache( - self, mtargs: MultiTenantArgs, ddl, **kwargs + self, mtargs: MultiTenantArgs, ddl, i, **kwargs ): conn = await mtargs.sd.connect(**kwargs) try: @@ -1748,6 +1757,9 @@ async def _test_server_ops_global_compile_cache( insert ext::auth::EmailPasswordProviderConfig {{ require_verification := false, }}; + + configure current 
database set + current_email_provider_name := 'provider:{i}'; ''') finally: await conn.aclose() @@ -1775,14 +1787,49 @@ async def _test_server_ops_multi_tenant_6(self, mtargs: MultiTenantArgs): await self._test_server_ops_global_compile_cache( mtargs, "create type GlobalCache1 { create property name: str }", + 1, **mtargs.args1, ) await self._test_server_ops_global_compile_cache( mtargs, "create type GlobalCache2 { create property active: bool }", + 2, **mtargs.args2, ) + async def _test_server_ops_multi_tenant_7(self, mtargs: MultiTenantArgs): + self.assertEqual( + (await mtargs.current_email_provider(1))["sender"], + "sender@host1.com", + ) + self.assertEqual( + (await mtargs.current_email_provider(2))["sender"], + "sender@host2.com", + ) + + mtargs.cf1.seek(0) + mtargs.cf1.truncate(0) + mtargs.cf1.write(textwrap.dedent(""" + [[magic_smtp_config]] + _tname = "cfg::SMTPProviderConfig" + name = "provider:1" + sender = "updated@example.com" + """)) + mtargs.cf1.flush() + assert mtargs.srv.proc is not None + mtargs.srv.proc.send_signal(signal.SIGHUP) + + async for tr in self.try_until_succeeds(ignore=AssertionError): + async with tr: + self.assertEqual( + (await mtargs.current_email_provider(1))["sender"], + "updated@example.com", + ) + self.assertEqual( + (await mtargs.current_email_provider(2))["sender"], + "sender@host2.com", + ) + class MultiTenantArgs(NamedTuple): srv: tb._EdgeDBServer @@ -1793,6 +1840,9 @@ class MultiTenantArgs(NamedTuple): args2: dict[str, str] rd1: tempfile._TemporaryFileWrapper rd2: tempfile._TemporaryFileWrapper + cf1: tempfile._TemporaryFileWrapper + cf2: tempfile._TemporaryFileWrapper + dbnames: list[str | None] = [None, None] def reload_server(self): self.conf_file.file.seek(0) @@ -1801,6 +1851,23 @@ def reload_server(self): self.conf_file.file.flush() self.srv.proc.send_signal(signal.SIGHUP) + def fetch_server_info(self, i): + return self.sd.fetch_server_info()["tenants"][f"{i}.localhost"] + + async def current_email_provider(self, i): + tenant_info = self.fetch_server_info(i) + dbname = self.dbnames[i - 1] + if dbname is None: + conn = await self.sd.connect(**getattr(self, f"args{i}")) + try: + dbname = await conn.query_single("""\ + select sys::get_current_branch() + """) + self.dbnames[i - 1] = dbname + finally: + await conn.aclose() + return tenant_info["databases"][dbname]["current_email_provider"] + class TestPGExtensions(tb.TestCase): async def test_edb_stat_statements(self):
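
As a rough standalone illustration of the mechanism this patch centers on,
the sketch below mirrors Tenant._make_init_con_sql() from
edb/server/tenant.py. Here quote_literal is a simplified stand-in for
edb.pgsql.common.quote_literal, and the two entries are made-up ConState
rows ('A' for a command-line value, 'F' for a config-file value):

    import json
    import textwrap

    def quote_literal(text: str) -> str:
        # Simplified stand-in for edb.pgsql.common.quote_literal.
        return "'" + text.replace("'", "''") + "'"

    data = [
        {"name": "listen_port", "value": 5656, "type": "A"},
        {"name": "apply_access_policies", "value": False, "type": "F"},
    ]

    sql = textwrap.dedent(f'''
        INSERT INTO _edgecon_state
        SELECT * FROM jsonb_to_recordset({quote_literal(json.dumps(data))}::jsonb)
        AS cfg(name text, value jsonb, type text);
    ''').strip()

    print(sql)

On reload, acquire_pgcon() compares each connection's last_init_con_data to
the tenant's current list by identity and, when they differ, re-runs
RESET_STATIC_CFG_SCRIPT followed by this INSERT before handing the
connection out.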