From 1c8e13b99bd9a57b7efe60609356e19cca464a0c Mon Sep 17 00:00:00 2001 From: James Clarke Date: Wed, 5 Feb 2025 18:30:57 +0000 Subject: [PATCH] Add a `cors-always-allowed-origins` option (#8233) For cloud instances we currently default to setting the `cors_allow_origins` config to `*` to let the cloud ui make queries to the instance. But the user can change that config, which breaks the UI, and also for security it's probably better to leave the default `cors_allow_origins` config empty. Since we want to allow the cloud UI to always be able to query the instance, regardless of how the user configures `cors_allow_origins`, this server option will allow us to set the cloud ui as an always allowed cors origin. Supporting multiple origins and wildcards is to handle the gel rename (allow both our new and old cloud ui urls to work) and for preview ui builds. --- edb/server/args.py | 12 ++++++++++++ edb/server/main.py | 1 + edb/server/multitenant.py | 1 + edb/server/protocol/protocol.pyx | 12 ++++++++++-- edb/server/server.py | 14 ++++++++++++++ 5 files changed, 38 insertions(+), 2 deletions(-) diff --git a/edb/server/args.py b/edb/server/args.py index 6428dbf33a2..17a1f8a5132 100644 --- a/edb/server/args.py +++ b/edb/server/args.py @@ -292,6 +292,8 @@ class ServerConfig(NamedTuple): admin_ui: bool + cors_always_allowed_origins: Optional[str] + class PathPath(click.Path): name = 'path' @@ -1089,6 +1091,16 @@ def resolve_envvar_value(self, ctx: click.Context): ), default='default', help='Enable admin UI.'), + click.option( + '--cors-always-allowed-origins', + envvar="GEL_SERVER_CORS_ALWAYS_ALLOWED_ORIGINS", + cls=EnvvarResolver, + hidden=True, + help='A comma separated list of origins to always allow CORS requests ' + 'from regardless of the `cors_allow_orgin` config. The `*` ' + 'character can be used as a wildcard. Intended for use by cloud ' + 'to always allow the cloud UI to make requests to the instance.' + ), click.option( '--disable-dynamic-system-config', is_flag=True, envvar="GEL_SERVER_DISABLE_DYNAMIC_SYSTEM_CONFIG", diff --git a/edb/server/main.py b/edb/server/main.py index 302616a3b16..ac32ac803a7 100644 --- a/edb/server/main.py +++ b/edb/server/main.py @@ -248,6 +248,7 @@ async def _run_server( pidfile_dir=args.pidfile_dir, new_instance=new_instance, admin_ui=args.admin_ui, + cors_always_allowed_origins=args.cors_always_allowed_origins, disable_dynamic_system_config=args.disable_dynamic_system_config, compiler_state=compiler.state, tenant=tenant, diff --git a/edb/server/multitenant.py b/edb/server/multitenant.py index afa70a2b8b7..9765d59c89f 100644 --- a/edb/server/multitenant.py +++ b/edb/server/multitenant.py @@ -467,6 +467,7 @@ async def run_server( default_auth_method=args.default_auth_method, testmode=args.testmode, admin_ui=args.admin_ui, + cors_always_allowed_origins=args.cors_always_allowed_origins, disable_dynamic_system_config=args.disable_dynamic_system_config, compiler_pool_size=args.compiler_pool_size, compiler_pool_mode=srvargs.CompilerPoolMode.MultiTenant, diff --git a/edb/server/protocol/protocol.pyx b/edb/server/protocol/protocol.pyx index 3b1fe928050..3cf84e5c633 100644 --- a/edb/server/protocol/protocol.pyx +++ b/edb/server/protocol/protocol.pyx @@ -840,13 +840,21 @@ cdef class HttpProtocol: config = self.tenant.get_sys_config().get('cors_allow_origins') allowed_origins = config.value if config else None + overrides = self.server.get_cors_always_allowed_origins() - if allowed_origins is None: + if allowed_origins is None and overrides == []: return False origin = request.origin.decode() if request.origin else None + origin_allowed = origin is not None and ( - origin in allowed_origins or '*' in allowed_origins) + any( + override.match(origin) if isinstance(override, re.Pattern) + else origin == override + for override in overrides + ) + or (origin in allowed_origins or '*' in allowed_origins) + ) if origin_allowed: response.custom_headers['Access-Control-Allow-Origin'] = origin diff --git a/edb/server/server.py b/edb/server/server.py index 6fa8d595375..5885cfd684c 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -20,6 +20,7 @@ from __future__ import annotations +import re from typing import ( Any, Callable, @@ -160,6 +161,7 @@ def __init__( default_auth_method: srvargs.ServerAuthMethods = ( srvargs.DEFAULT_AUTH_METHODS), admin_ui: bool = False, + cors_always_allowed_origins: Optional[str] = None, disable_dynamic_system_config: bool = False, compiler_state: edbcompiler.CompilerState, use_monitor_fs: bool = False, @@ -252,6 +254,15 @@ def __init__( self._admin_ui = admin_ui + self._cors_always_allowed_origins = [ + re.compile( + '^' + origin + .replace('.', '\\.') + .replace('*', '.*') + '$' + ) if '*' in origin else origin + for origin in cors_always_allowed_origins.split(',') + ] if cors_always_allowed_origins else [] + self._file_watch_handles = [] self._tls_certs_reload_retry_handle: Any | asyncio.TimerHandle = None @@ -308,6 +319,9 @@ def in_test_mode(self): def is_admin_ui_enabled(self): return self._admin_ui + def get_cors_always_allowed_origins(self): + return self._cors_always_allowed_origins + def on_binary_client_created(self) -> str: self._binary_proto_id_counter += 1