Skip to content

Commit

Permalink
Allow multiple authentication methods per transport in `--default-aut…
Browse files Browse the repository at this point in the history
…h-method` (#7224)

Currently, `--default-auth-method` only allows one default method per
transport. This isn't very flexible and so allow multiple authentication
methods to be tried in sequence (according to the specified order in
`--default-auth-method`).

Technically, this infrastructure also allows trying multiple configured
`cfg::Auth` methods in order of `.priority`, but that is a bigger (and
breaking) change so I left it out.
  • Loading branch information
elprans authored Apr 19, 2024
1 parent 3c75d6c commit d82e14f
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 125 deletions.
75 changes: 45 additions & 30 deletions edb/server/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,29 +167,31 @@ class ServerAuthMethods:

def __init__(
self,
methods: Mapping[ServerConnTransport, ServerAuthMethod],
methods: Mapping[ServerConnTransport, list[ServerAuthMethod]],
) -> None:
self._methods = dict(methods)

def get(self, transport: ServerConnTransport) -> ServerAuthMethod:
def get(self, transport: ServerConnTransport) -> list[ServerAuthMethod]:
return self._methods[transport]

def items(self) -> ItemsView[ServerConnTransport, ServerAuthMethod]:
def items(self) -> ItemsView[ServerConnTransport, list[ServerAuthMethod]]:
return self._methods.items()

def __str__(self):
return ','.join(
f'{t.lower()}:{m.lower()}' for t, m in self._methods.items()
f'{t.lower()}:{'/'.join(m.lower() for m in mm)}'
for t, mm in self._methods.items()
)


DEFAULT_AUTH_METHODS = ServerAuthMethods({
ServerConnTransport.TCP: ServerAuthMethod.Scram,
ServerConnTransport.TCP_PG: ServerAuthMethod.Scram,
ServerConnTransport.HTTP: ServerAuthMethod.JWT,
ServerConnTransport.SIMPLE_HTTP: ServerAuthMethod.Password,
ServerConnTransport.HTTP_METRICS: ServerAuthMethod.Auto,
ServerConnTransport.HTTP_HEALTH: ServerAuthMethod.Auto,
ServerConnTransport.TCP: [ServerAuthMethod.Scram],
ServerConnTransport.TCP_PG: [ServerAuthMethod.Scram],
ServerConnTransport.HTTP: [ServerAuthMethod.JWT],
ServerConnTransport.SIMPLE_HTTP: [
ServerAuthMethod.Password, ServerAuthMethod.JWT],
ServerConnTransport.HTTP_METRICS: [ServerAuthMethod.Auto],
ServerConnTransport.HTTP_HEALTH: [ServerAuthMethod.Auto],
})


Expand Down Expand Up @@ -516,7 +518,7 @@ def _validate_default_auth_method(
):
continue

methods[t] = method
methods[t] = [method]
elif "," not in value and ":" not in value:
raise click.BadParameter(
f"invalid authentication method: {value}, "
Expand All @@ -531,23 +533,26 @@ def _validate_default_auth_method(
}
for transport_spec in transport_specs:
transport_spec = transport_spec.strip()
transport_name, _, method_name = transport_spec.partition(':')
if not method_name:
transport_name, _, method_names = transport_spec.partition(':')
if not method_names:
raise click.BadParameter(
"format is <transport>:<method>[,...]")
"format is <transport>:<method>[/method...][,...]")
transport = transport_names.get(transport_name.lower())
if not transport:
raise click.BadParameter(
f"invalid connection transport: {transport_name}, "
f"supported values are: {', '.join(transport_names)})"
)
method = names.get(method_name)
if not method:
raise click.BadParameter(
f"invalid authentication method: {method_name}, "
f"supported values are: {', '.join(names)})"
)
methods[transport] = method
transport_methods = []
for method_name in method_names.split('/'):
method = names.get(method_name)
if not method:
raise click.BadParameter(
f"invalid authentication method: {method_name}, "
f"supported values are: {', '.join(names)})"
)
transport_methods.append(method)
methods[transport] = transport_methods

return ServerAuthMethods(methods)

Expand Down Expand Up @@ -1201,7 +1206,7 @@ def parse_args(**kwargs: Any):
kwargs['http_endpoint_security'] = 'optional'
if not kwargs['default_auth_method']:
kwargs['default_auth_method'] = ServerAuthMethods({
t: ServerAuthMethod.Trust
t: [ServerAuthMethod.Trust]
for t in ServerConnTransport.__members__.values()
})
if kwargs['tls_cert_mode'] == 'default':
Expand All @@ -1210,10 +1215,11 @@ def parse_args(**kwargs: Any):
elif not kwargs['default_auth_method']:
kwargs['default_auth_method'] = DEFAULT_AUTH_METHODS

methods = dict(kwargs['default_auth_method'].items())
transport_methods = dict(kwargs['default_auth_method'].items())
for transport in ServerConnTransport.__members__.values():
method = methods[transport]
if method is ServerAuthMethod.Auto:
methods = transport_methods[transport]
if ServerAuthMethod.Auto in methods:
pos = methods.index(ServerAuthMethod.Auto)
if transport in (
ServerConnTransport.HTTP_METRICS,
ServerConnTransport.HTTP_HEALTH,
Expand All @@ -1222,24 +1228,33 @@ def parse_args(**kwargs: Any):
method = ServerAuthMethod.Trust
else:
method = ServerAuthMethod.mTLS
methods[pos] = method
else:
method = DEFAULT_AUTH_METHODS.get(transport)
methods[transport] = method
methods = (
methods[:pos]
+ DEFAULT_AUTH_METHODS.get(transport)
+ methods[pos + 1:]
)
transport_methods[transport] = [method]
elif transport in (
ServerConnTransport.HTTP_METRICS,
ServerConnTransport.HTTP_HEALTH,
):
if method is ServerAuthMethod.mTLS:
if ServerAuthMethod.mTLS in methods:
if kwargs['tls_client_ca_file'] is None:
abort('--tls-client-ca-file is required '
'for mTLS authentication')
elif method is not ServerAuthMethod.Trust:

if not all(
m is ServerAuthMethod.Trust or m is ServerAuthMethod.mTLS
for m in methods
):
abort(
f'--default-auth-method of {transport} can only be one '
f'of: {ServerAuthMethod.Trust}, {ServerAuthMethod.mTLS} '
f'or {ServerAuthMethod.Auto}'
)
kwargs['default_auth_method'] = ServerAuthMethods(methods)
kwargs['default_auth_method'] = ServerAuthMethods(transport_methods)

if kwargs['binary_endpoint_security'] == 'default':
kwargs['binary_endpoint_security'] = 'tls'
Expand Down
58 changes: 39 additions & 19 deletions edb/server/protocol/frontend.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -596,29 +596,49 @@ cdef class FrontendConnection(AbstractFrontendConnection):
# The user has already been authenticated by other means
# (such as the ability to write to a protected socket).
if self._external_auth:
authmethod_name = 'Trust'
authmethods = [
self.server.config_settings.get_type_by_name('cfg::Trust')()
]
else:
authmethod = await self.tenant.get_auth_method(
authmethods = await self.tenant.get_auth_methods(
user, self._transport_proto)

auth_errors = {}

for authmethod in authmethods:
authmethod_name = authmethod._tspec.name.split('::')[1]

if authmethod_name == 'SCRAM':
await self._auth_scram(user)
elif authmethod_name == 'JWT':
self._auth_jwt(user, database, params)
elif authmethod_name == 'Trust':
self._auth_trust(user)
elif authmethod_name == 'Password':
raise errors.AuthenticationError(
'authentication failed: '
'Simple password authentication required but it is only '
'supported for HTTP endpoints'
)
elif authmethod_name == 'mTLS':
auth_helpers.auth_mtls_with_user(self._transport, user)
else:
raise errors.InternalServerError(
f'unimplemented auth method: {authmethod_name}')
try:
if authmethod_name == 'SCRAM':
await self._auth_scram(user)
elif authmethod_name == 'JWT':
self._auth_jwt(user, database, params)
elif authmethod_name == 'Trust':
self._auth_trust(user)
elif authmethod_name == 'Password':
raise errors.AuthenticationError(
'authentication failed: '
'Simple password authentication required but it is '
'only supported for HTTP endpoints'
)
elif authmethod_name == 'mTLS':
auth_helpers.auth_mtls_with_user(self._transport, user)
else:
raise errors.InternalServerError(
f'unimplemented auth method: {authmethod_name}')
except errors.AuthenticationError as e:
auth_errors[authmethod_name] = e
else:
break

if len(auth_errors) == len(authmethods):
if len(auth_errors) > 1:
desc = "; ".join(
f"{k}: {e.args[0]}" for k, e in auth_errors.items())
raise errors.AuthenticationError(
f"all authentication methods failed: {desc}")
else:
raise next(iter(auth_errors.values()))

cdef WriteBuffer _make_authentication_sasl_initial(self, list methods):
raise NotImplementedError
Expand Down
137 changes: 88 additions & 49 deletions edb/server/protocol/protocol.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#

# This source file is part of the EdgeDB open source project.
#
# Copyright 2021-present MagicStack Inc. and the EdgeDB authors.
Expand Down Expand Up @@ -865,43 +865,65 @@ cdef class HttpProtocol:
username, opt_password = auth_helpers.extract_http_user(
scheme, auth_payload, request.params)

# Fetch the configured auth method
authmethod = await self.tenant.get_auth_method(
# Fetch the configured auth methods
authmethods = await self.tenant.get_auth_methods(
username, srvargs.ServerConnTransport.SIMPLE_HTTP)
authmethod_name = authmethod._tspec.name.split('::')[1]

# If the auth method and the provided auth information match,
# try to resolve the authentication.
if authmethod_name == 'JWT' and scheme == 'bearer':
if not self.is_tls:
raise errors.AuthenticationError(
'JWT HTTP auth must use HTTPS')
auth_errors = {}

for authmethod in authmethods:
authmethod_name = authmethod._tspec.name.split('::')[1]
try:
# If the auth method and the provided auth information
# match, try to resolve the authentication.
if authmethod_name == 'JWT' and scheme == 'bearer':
if not self.is_tls:
raise errors.AuthenticationError(
'JWT HTTP auth must use HTTPS')

auth_helpers.auth_jwt(
self.tenant, auth_payload, username, dbname)
elif authmethod_name == 'Password' and scheme == 'basic':
if not self.is_tls:
raise errors.AuthenticationError(
'Basic HTTP auth must use HTTPS')

auth_helpers.auth_basic(
self.tenant, username, opt_password)
elif authmethod_name == 'Trust':
pass
elif authmethod_name == 'SCRAM':
raise errors.AuthenticationError(
'authentication failed: '
'SCRAM authentication required but not '
'supported for HTTP'
)
elif authmethod_name == 'mTLS':
if (
self.http_endpoint_security
is srvargs.ServerEndpointSecurityMode.Tls
or self.is_tls
):
auth_helpers.auth_mtls_with_user(
self.transport, username)
else:
raise errors.AuthenticationError(
'authentication failed: wrong method used')

auth_helpers.auth_jwt(
self.tenant, auth_payload, username, dbname)
elif authmethod_name == 'Password' and scheme == 'basic':
if not self.is_tls:
raise errors.AuthenticationError(
'Basic HTTP auth must use HTTPS')
except errors.AuthenticationError as e:
auth_errors[authmethod_name] = e

auth_helpers.auth_basic(self.tenant, username, opt_password)
elif authmethod_name == 'Trust':
pass
elif authmethod_name == 'SCRAM':
raise errors.AuthenticationError(
'authentication failed: '
'SCRAM authentication required but not supported for HTTP'
)
elif authmethod_name == 'mTLS':
if (
self.http_endpoint_security
is srvargs.ServerEndpointSecurityMode.Tls
or self.is_tls
):
auth_helpers.auth_mtls_with_user(self.transport, username)
else:
raise errors.AuthenticationError(
'authentication failed: wrong method used')
else:
break

if len(auth_errors) == len(authmethods):
if len(auth_errors) > 1:
desc = "; ".join(
f"{k}: {e.args[0]}" for k, e in auth_errors.items())
raise errors.AuthenticationError(
f"all authentication methods failed: {desc}")
else:
raise next(iter(auth_errors.values()))

except Exception as ex:
if debug.flags.server:
Expand All @@ -926,22 +948,39 @@ cdef class HttpProtocol:
transport: srvargs.ServerConnTransport,
):
try:
auth_method = self.server.get_default_auth_method(transport)
auth_methods = self.server.get_default_auth_methods(transport)
auth_errors = {}

for auth_method in auth_methods:
authmethod_name = auth_method._tspec.name.split('::')[1]
try:
# If the auth method and the provided auth information
# match, try to resolve the authentication.
if authmethod_name == 'Trust':
pass
elif authmethod_name == 'mTLS':
if (
self.http_endpoint_security
is srvargs.ServerEndpointSecurityMode.Tls
or self.is_tls
):
auth_helpers.auth_mtls(self.transport)
else:
raise errors.AuthenticationError(
'authentication failed: wrong method used')
except errors.AuthenticationError as e:
auth_errors[authmethod_name] = e
else:
break

# If the auth method and the provided auth information match,
# try to resolve the authentication.
if auth_method is srvargs.ServerAuthMethod.Trust:
pass
elif auth_method is srvargs.ServerAuthMethod.mTLS:
if (
self.http_endpoint_security
is srvargs.ServerEndpointSecurityMode.Tls
or self.is_tls
):
auth_helpers.auth_mtls(self.transport)
else:
raise errors.AuthenticationError(
'authentication failed: wrong method used')
if len(auth_errors) == len(auth_methods):
if len(auth_errors) > 1:
desc = "; ".join(
f"{k}: {e.args[0]}" for k, e in auth_errors.items())
raise errors.AuthenticationError(
f"all authentication methods failed: {desc}")
else:
raise next(iter(auth_errors.values()))

except Exception as ex:
if debug.flags.server:
Expand Down
Loading

0 comments on commit d82e14f

Please sign in to comment.