From 9653f62ccd320aa9167c8e54d3a035b1e2fa35b5 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Wed, 5 Feb 2025 18:30:57 +0000 Subject: [PATCH] Add a `cors-always-allowed-origins` option (#8233) For cloud instances we currently default to setting the `cors_allow_origins` config to `*` to let the cloud ui make queries to the instance. But the user can change that config, which breaks the UI, and also for security it's probably better to leave the default `cors_allow_origins` config empty. Since we want to allow the cloud UI to always be able to query the instance, regardless of how the user configures `cors_allow_origins`, this server option will allow us to set the cloud ui as an always allowed cors origin. Supporting multiple origins and wildcards is to handle the gel rename (allow both our new and old cloud ui urls to work) and for preview ui builds. --- edb/server/args.py | 12 ++++++++++++ edb/server/main.py | 1 + edb/server/multitenant.py | 1 + edb/server/protocol/protocol.pyx | 12 ++++++++++-- edb/server/server.py | 14 ++++++++++++++ 5 files changed, 38 insertions(+), 2 deletions(-) diff --git a/edb/server/args.py b/edb/server/args.py index ceed1d4f6f4..8c630d3e3c1 100644 --- a/edb/server/args.py +++ b/edb/server/args.py @@ -292,6 +292,8 @@ class ServerConfig(NamedTuple): admin_ui: bool + cors_always_allowed_origins: Optional[str] + class PathPath(click.Path): name = 'path' @@ -1089,6 +1091,16 @@ def resolve_envvar_value(self, ctx: click.Context): ), default='default', help='Enable admin UI.'), + click.option( + '--cors-always-allowed-origins', + envvar="GEL_SERVER_CORS_ALWAYS_ALLOWED_ORIGINS", + cls=EnvvarResolver, + hidden=True, + help='A comma separated list of origins to always allow CORS requests ' + 'from regardless of the `cors_allow_orgin` config. The `*` ' + 'character can be used as a wildcard. Intended for use by cloud ' + 'to always allow the cloud UI to make requests to the instance.' + ), click.option( '--disable-dynamic-system-config', is_flag=True, envvar="GEL_SERVER_DISABLE_DYNAMIC_SYSTEM_CONFIG", diff --git a/edb/server/main.py b/edb/server/main.py index 302616a3b16..ac32ac803a7 100644 --- a/edb/server/main.py +++ b/edb/server/main.py @@ -248,6 +248,7 @@ async def _run_server( pidfile_dir=args.pidfile_dir, new_instance=new_instance, admin_ui=args.admin_ui, + cors_always_allowed_origins=args.cors_always_allowed_origins, disable_dynamic_system_config=args.disable_dynamic_system_config, compiler_state=compiler.state, tenant=tenant, diff --git a/edb/server/multitenant.py b/edb/server/multitenant.py index afa70a2b8b7..9765d59c89f 100644 --- a/edb/server/multitenant.py +++ b/edb/server/multitenant.py @@ -467,6 +467,7 @@ async def run_server( default_auth_method=args.default_auth_method, testmode=args.testmode, admin_ui=args.admin_ui, + cors_always_allowed_origins=args.cors_always_allowed_origins, disable_dynamic_system_config=args.disable_dynamic_system_config, compiler_pool_size=args.compiler_pool_size, compiler_pool_mode=srvargs.CompilerPoolMode.MultiTenant, diff --git a/edb/server/protocol/protocol.pyx b/edb/server/protocol/protocol.pyx index 3b1fe928050..3cf84e5c633 100644 --- a/edb/server/protocol/protocol.pyx +++ b/edb/server/protocol/protocol.pyx @@ -840,13 +840,21 @@ cdef class HttpProtocol: config = self.tenant.get_sys_config().get('cors_allow_origins') allowed_origins = config.value if config else None + overrides = self.server.get_cors_always_allowed_origins() - if allowed_origins is None: + if allowed_origins is None and overrides == []: return False origin = request.origin.decode() if request.origin else None + origin_allowed = origin is not None and ( - origin in allowed_origins or '*' in allowed_origins) + any( + override.match(origin) if isinstance(override, re.Pattern) + else origin == override + for override in overrides + ) + or (origin in allowed_origins or '*' in allowed_origins) + ) if origin_allowed: response.custom_headers['Access-Control-Allow-Origin'] = origin diff --git a/edb/server/server.py b/edb/server/server.py index 664d6402a8c..ddd841803a9 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -20,6 +20,7 @@ from __future__ import annotations +import re from typing import ( Any, Callable, @@ -160,6 +161,7 @@ def __init__( default_auth_method: srvargs.ServerAuthMethods = ( srvargs.DEFAULT_AUTH_METHODS), admin_ui: bool = False, + cors_always_allowed_origins: Optional[str] = None, disable_dynamic_system_config: bool = False, compiler_state: edbcompiler.CompilerState, use_monitor_fs: bool = False, @@ -252,6 +254,15 @@ def __init__( self._admin_ui = admin_ui + self._cors_always_allowed_origins = [ + re.compile( + '^' + origin + .replace('.', '\\.') + .replace('*', '.*') + '$' + ) if '*' in origin else origin + for origin in cors_always_allowed_origins.split(',') + ] if cors_always_allowed_origins else [] + self._file_watch_handles = [] self._tls_certs_reload_retry_handle: Any | asyncio.TimerHandle = None @@ -308,6 +319,9 @@ def in_test_mode(self): def is_admin_ui_enabled(self): return self._admin_ui + def get_cors_always_allowed_origins(self): + return self._cors_always_allowed_origins + def on_binary_client_created(self) -> str: self._binary_proto_id_counter += 1