diff --git a/edb/server/args.py b/edb/server/args.py index 6428dbf33a2..17a1f8a5132 100644 --- a/edb/server/args.py +++ b/edb/server/args.py @@ -292,6 +292,8 @@ class ServerConfig(NamedTuple): admin_ui: bool + cors_always_allowed_origins: Optional[str] + class PathPath(click.Path): name = 'path' @@ -1089,6 +1091,16 @@ def resolve_envvar_value(self, ctx: click.Context): ), default='default', help='Enable admin UI.'), + click.option( + '--cors-always-allowed-origins', + envvar="GEL_SERVER_CORS_ALWAYS_ALLOWED_ORIGINS", + cls=EnvvarResolver, + hidden=True, + help='A comma separated list of origins to always allow CORS requests ' + 'from regardless of the `cors_allow_orgin` config. The `*` ' + 'character can be used as a wildcard. Intended for use by cloud ' + 'to always allow the cloud UI to make requests to the instance.' + ), click.option( '--disable-dynamic-system-config', is_flag=True, envvar="GEL_SERVER_DISABLE_DYNAMIC_SYSTEM_CONFIG", diff --git a/edb/server/main.py b/edb/server/main.py index 302616a3b16..ac32ac803a7 100644 --- a/edb/server/main.py +++ b/edb/server/main.py @@ -248,6 +248,7 @@ async def _run_server( pidfile_dir=args.pidfile_dir, new_instance=new_instance, admin_ui=args.admin_ui, + cors_always_allowed_origins=args.cors_always_allowed_origins, disable_dynamic_system_config=args.disable_dynamic_system_config, compiler_state=compiler.state, tenant=tenant, diff --git a/edb/server/multitenant.py b/edb/server/multitenant.py index afa70a2b8b7..9765d59c89f 100644 --- a/edb/server/multitenant.py +++ b/edb/server/multitenant.py @@ -467,6 +467,7 @@ async def run_server( default_auth_method=args.default_auth_method, testmode=args.testmode, admin_ui=args.admin_ui, + cors_always_allowed_origins=args.cors_always_allowed_origins, disable_dynamic_system_config=args.disable_dynamic_system_config, compiler_pool_size=args.compiler_pool_size, compiler_pool_mode=srvargs.CompilerPoolMode.MultiTenant, diff --git a/edb/server/protocol/protocol.pyx b/edb/server/protocol/protocol.pyx index 3b1fe928050..3cf84e5c633 100644 --- a/edb/server/protocol/protocol.pyx +++ b/edb/server/protocol/protocol.pyx @@ -840,13 +840,21 @@ cdef class HttpProtocol: config = self.tenant.get_sys_config().get('cors_allow_origins') allowed_origins = config.value if config else None + overrides = self.server.get_cors_always_allowed_origins() - if allowed_origins is None: + if allowed_origins is None and overrides == []: return False origin = request.origin.decode() if request.origin else None + origin_allowed = origin is not None and ( - origin in allowed_origins or '*' in allowed_origins) + any( + override.match(origin) if isinstance(override, re.Pattern) + else origin == override + for override in overrides + ) + or (origin in allowed_origins or '*' in allowed_origins) + ) if origin_allowed: response.custom_headers['Access-Control-Allow-Origin'] = origin diff --git a/edb/server/server.py b/edb/server/server.py index 6fa8d595375..5885cfd684c 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -20,6 +20,7 @@ from __future__ import annotations +import re from typing import ( Any, Callable, @@ -160,6 +161,7 @@ def __init__( default_auth_method: srvargs.ServerAuthMethods = ( srvargs.DEFAULT_AUTH_METHODS), admin_ui: bool = False, + cors_always_allowed_origins: Optional[str] = None, disable_dynamic_system_config: bool = False, compiler_state: edbcompiler.CompilerState, use_monitor_fs: bool = False, @@ -252,6 +254,15 @@ def __init__( self._admin_ui = admin_ui + self._cors_always_allowed_origins = [ + re.compile( + '^' + origin + .replace('.', '\\.') + .replace('*', '.*') + '$' + ) if '*' in origin else origin + for origin in cors_always_allowed_origins.split(',') + ] if cors_always_allowed_origins else [] + self._file_watch_handles = [] self._tls_certs_reload_retry_handle: Any | asyncio.TimerHandle = None @@ -308,6 +319,9 @@ def in_test_mode(self): def is_admin_ui_enabled(self): return self._admin_ui + def get_cors_always_allowed_origins(self): + return self._cors_always_allowed_origins + def on_binary_client_created(self) -> str: self._binary_proto_id_counter += 1