Skip to content

Commit

Permalink
Add a cors-always-allowed-origins option (#8233)
Browse files Browse the repository at this point in the history
For cloud instances we currently default to setting the
`cors_allow_origins` config to `*` to let the cloud ui make queries to
the instance. But the user can change that config, which breaks the UI,
and also for security it's probably better to leave the default
`cors_allow_origins` config empty.

Since we want to allow the cloud UI to always be able to query the
instance, regardless of how the user configures `cors_allow_origins`,
this server option will allow us to set the cloud ui as an always
allowed cors origin. Supporting multiple origins and wildcards is to
handle the gel rename (allow both our new and old cloud ui urls to work)
and for preview ui builds.
  • Loading branch information
jaclarke authored and msullivan committed Feb 5, 2025
1 parent b5bc80c commit 1c8e13b
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 2 deletions.
12 changes: 12 additions & 0 deletions edb/server/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ class ServerConfig(NamedTuple):

admin_ui: bool

cors_always_allowed_origins: Optional[str]


class PathPath(click.Path):
name = 'path'
Expand Down Expand Up @@ -1089,6 +1091,16 @@ def resolve_envvar_value(self, ctx: click.Context):
),
default='default',
help='Enable admin UI.'),
click.option(
'--cors-always-allowed-origins',
envvar="GEL_SERVER_CORS_ALWAYS_ALLOWED_ORIGINS",
cls=EnvvarResolver,
hidden=True,
help='A comma separated list of origins to always allow CORS requests '
'from regardless of the `cors_allow_orgin` config. The `*` '
'character can be used as a wildcard. Intended for use by cloud '
'to always allow the cloud UI to make requests to the instance.'
),
click.option(
'--disable-dynamic-system-config', is_flag=True,
envvar="GEL_SERVER_DISABLE_DYNAMIC_SYSTEM_CONFIG",
Expand Down
1 change: 1 addition & 0 deletions edb/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ async def _run_server(
pidfile_dir=args.pidfile_dir,
new_instance=new_instance,
admin_ui=args.admin_ui,
cors_always_allowed_origins=args.cors_always_allowed_origins,
disable_dynamic_system_config=args.disable_dynamic_system_config,
compiler_state=compiler.state,
tenant=tenant,
Expand Down
1 change: 1 addition & 0 deletions edb/server/multitenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ async def run_server(
default_auth_method=args.default_auth_method,
testmode=args.testmode,
admin_ui=args.admin_ui,
cors_always_allowed_origins=args.cors_always_allowed_origins,
disable_dynamic_system_config=args.disable_dynamic_system_config,
compiler_pool_size=args.compiler_pool_size,
compiler_pool_mode=srvargs.CompilerPoolMode.MultiTenant,
Expand Down
12 changes: 10 additions & 2 deletions edb/server/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -840,13 +840,21 @@ cdef class HttpProtocol:
config = self.tenant.get_sys_config().get('cors_allow_origins')

allowed_origins = config.value if config else None
overrides = self.server.get_cors_always_allowed_origins()

if allowed_origins is None:
if allowed_origins is None and overrides == []:
return False

origin = request.origin.decode() if request.origin else None

origin_allowed = origin is not None and (
origin in allowed_origins or '*' in allowed_origins)
any(
override.match(origin) if isinstance(override, re.Pattern)
else origin == override
for override in overrides
)
or (origin in allowed_origins or '*' in allowed_origins)
)

if origin_allowed:
response.custom_headers['Access-Control-Allow-Origin'] = origin
Expand Down
14 changes: 14 additions & 0 deletions edb/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


from __future__ import annotations
import re
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(
default_auth_method: srvargs.ServerAuthMethods = (
srvargs.DEFAULT_AUTH_METHODS),
admin_ui: bool = False,
cors_always_allowed_origins: Optional[str] = None,
disable_dynamic_system_config: bool = False,
compiler_state: edbcompiler.CompilerState,
use_monitor_fs: bool = False,
Expand Down Expand Up @@ -252,6 +254,15 @@ def __init__(

self._admin_ui = admin_ui

self._cors_always_allowed_origins = [
re.compile(
'^' + origin
.replace('.', '\\.')
.replace('*', '.*') + '$'
) if '*' in origin else origin
for origin in cors_always_allowed_origins.split(',')
] if cors_always_allowed_origins else []

self._file_watch_handles = []
self._tls_certs_reload_retry_handle: Any | asyncio.TimerHandle = None

Expand Down Expand Up @@ -308,6 +319,9 @@ def in_test_mode(self):
def is_admin_ui_enabled(self):
return self._admin_ui

def get_cors_always_allowed_origins(self):
return self._cors_always_allowed_origins

def on_binary_client_created(self) -> str:
self._binary_proto_id_counter += 1

Expand Down

0 comments on commit 1c8e13b

Please sign in to comment.