diff --git a/edb/server/args.py b/edb/server/args.py index 8396ac5d018..ea6a9fbce6b 100644 --- a/edb/server/args.py +++ b/edb/server/args.py @@ -164,6 +164,11 @@ class ReloadTrigger(enum.StrEnum): """Watch the files for changes and reload when it happens.""" +class NetWorkerMode(enum.StrEnum): + Default = "default" + Disabled = "disabled" + + class ServerAuthMethods: def __init__( @@ -258,6 +263,7 @@ class ServerConfig(NamedTuple): readiness_state_file: Optional[pathlib.Path] disable_dynamic_system_config: bool reload_config_files: ReloadTrigger + net_worker_mode: NetWorkerMode startup_script: Optional[StartupScript] status_sinks: List[Callable[[str], None]] @@ -1043,6 +1049,16 @@ def resolve_envvar_value(self, ctx: click.Context): help='Specifies when to reload the config files. See the docstring of ' 'ReloadTrigger for more information.', ), + click.option( + "--net-worker-mode", + envvar="EDGEDB_SERVER_NET_WORKER_MODE", cls=EnvvarResolver, + type=click.Choice( + list(NetWorkerMode.__members__.values()), case_sensitive=True + ), + hidden=True, + default='default', + help='Controls how the std::net workers work.', + ), ]) @@ -1555,6 +1571,7 @@ def parse_args(**kwargs: Any): kwargs['reload_config_files'] = ReloadTrigger( kwargs['reload_config_files'] ) + kwargs['net_worker_mode'] = NetWorkerMode(kwargs['net_worker_mode']) if 'EDGEDB_SERVER_CONFIG_cfg::listen_addresses' in os.environ: abort( diff --git a/edb/server/cluster.py b/edb/server/cluster.py index 1116eec07d7..f7bbd2fdd75 100644 --- a/edb/server/cluster.py +++ b/edb/server/cluster.py @@ -68,6 +68,9 @@ def __init__( compiler_pool_mode: Optional[ edgedb_args.CompilerPoolMode ] = None, + net_worker_mode: Optional[ + edgedb_args.NetWorkerMode + ] = None, ): self._edgedb_cmd = [sys.executable, '-m', 'edb.server.main'] @@ -114,6 +117,12 @@ def __init__( str(compiler_pool_mode), )) + if net_worker_mode is not None: + self._edgedb_cmd.extend(( + '--net-worker-mode', + str(net_worker_mode), + )) + self._log_level = log_level self._runstate_dir = runstate_dir self._edgedb_cmd.extend(['--runstate-dir', str(runstate_dir)]) diff --git a/edb/server/main.py b/edb/server/main.py index edee0be9e17..8050c88e404 100644 --- a/edb/server/main.py +++ b/edb/server/main.py @@ -251,6 +251,7 @@ async def _run_server( srvargs.ReloadTrigger.Default, srvargs.ReloadTrigger.FileSystemEvent, ], + net_worker_mode=args.net_worker_mode, ) # This coroutine runs as long as the server, # and compiler_state is *heavy*, so make sure we don't diff --git a/edb/server/server.py b/edb/server/server.py index 647cb9a8988..168b4a198c8 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -162,6 +162,7 @@ def __init__( disable_dynamic_system_config: bool = False, compiler_state: edbcompiler.CompilerState, use_monitor_fs: bool = False, + net_worker_mode: srvargs.NetWorkerMode = srvargs.NetWorkerMode.Default, ): self.__loop = asyncio.get_running_loop() self._use_monitor_fs = use_monitor_fs @@ -228,6 +229,7 @@ def __init__( self._http_request_logger = None self._auth_gc = None self._net_worker_http = None + self._net_worker_mode = net_worker_mode self._stop_evt = asyncio.Event() self._tls_cert_file: str | Any = None @@ -1044,7 +1046,10 @@ async def start(self): await self._after_start_servers() self._auth_gc = self.__loop.create_task(pkce.gc(self)) - self._net_worker_http = self.__loop.create_task(net_worker.http(self)) + if self._net_worker_mode is srvargs.NetWorkerMode.Default: + self._net_worker_http = self.__loop.create_task( + net_worker.http(self) + ) if self._echo_runtime_info: ri = { diff --git a/edb/testbase/server.py b/edb/testbase/server.py index 77bfba6c7ef..eb8b5697e50 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -764,7 +764,18 @@ def _extract_background_errors(metrics: str) -> str | None: async def drop_db(conn, dbname): - await conn.execute(f'DROP DATABASE {dbname}') + # net_worker (#7742) may issue a query while DROP DATABASE happens. + # This is a WIP bug that should be solved when adopting notification- + # driven net_worker, but hack around it with a retry loop for now. + # Without this, tests would flake a lot. + async for tr in TestCase.try_until_succeeds( + ignore=(edgedb.ExecutionError, edgedb.ClientConnectionError), + timeout=30 + ): + async with tr: + await conn.execute( + f'DROP DATABASE {dbname};' + ) class ClusterTestCase(BaseHTTPTestCase): @@ -2234,6 +2245,7 @@ def __init__( default_branch: Optional[str] = None, env: Optional[Dict[str, str]] = None, extra_args: Optional[List[str]] = None, + net_worker_mode: Optional[str] = None, ) -> None: self.bind_addrs = bind_addrs self.auto_shutdown_after = auto_shutdown_after @@ -2268,6 +2280,7 @@ def __init__( self.default_branch = default_branch self.env = env self.extra_args = extra_args + self.net_worker_mode = net_worker_mode async def wait_for_server_readiness(self, stream: asyncio.StreamReader): while True: @@ -2438,6 +2451,9 @@ async def __aenter__(self): if not self.multitenant_config: cmd += ['--instance-name=localtest'] + if self.net_worker_mode: + cmd += ['--net-worker-mode', self.net_worker_mode] + if self.extra_args: cmd.extend(self.extra_args) @@ -2596,6 +2612,7 @@ def start_edgedb_server( env: Optional[Dict[str, str]] = None, extra_args: Optional[List[str]] = None, default_branch: Optional[str] = None, + net_worker_mode: Optional[str] = None, ): if (not devmode.is_in_dev_mode() or adjacent_to) and not runstate_dir: if backend_dsn or adjacent_to: @@ -2665,6 +2682,7 @@ def start_edgedb_server( env=env, extra_args=extra_args, default_branch=default_branch, + net_worker_mode=net_worker_mode, ) diff --git a/tests/test_server_ops.py b/tests/test_server_ops.py index 80b7cb30ef6..99fb8110ba4 100644 --- a/tests/test_server_ops.py +++ b/tests/test_server_ops.py @@ -720,6 +720,7 @@ async def test_server_ops_cache_recompile_01(self): async with tb.start_edgedb_server( data_dir=temp_dir, default_auth_method=args.ServerAuthMethod.Trust, + net_worker_mode='disabled', ) as sd: con = await sd.connect() try: @@ -775,6 +776,7 @@ async def test_server_ops_cache_recompile_01(self): async with tb.start_edgedb_server( data_dir=temp_dir, default_auth_method=args.ServerAuthMethod.Trust, + net_worker_mode='disabled', ) as sd: con = await sd.connect() try: