From d82e14fa1a31026c3e9f08d40ba8b9ca6d1d73c4 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 18 Apr 2024 20:37:42 -0700 Subject: [PATCH] Allow multiple authentication methods per transport in `--default-auth-method` (#7224) Currently, `--default-auth-method` only allows one default method per transport. This isn't very flexible and so allow multiple authentication methods to be tried in sequence (according to the specified order in `--default-auth-method`). Technically, this infrastructure also allows trying multiple configured `cfg::Auth` methods in order of `.priority`, but that is a bigger (and breaking) change so I left it out. --- edb/server/args.py | 75 ++++++++++------- edb/server/protocol/frontend.pyx | 58 ++++++++----- edb/server/protocol/protocol.pyx | 137 ++++++++++++++++++++----------- edb/server/server.py | 28 +++++-- edb/server/tenant.py | 17 ++-- tests/test_server_auth.py | 124 ++++++++++++++++++++++++---- 6 files changed, 314 insertions(+), 125 deletions(-) diff --git a/edb/server/args.py b/edb/server/args.py index 8771d9d8611..321593a02e7 100644 --- a/edb/server/args.py +++ b/edb/server/args.py @@ -167,29 +167,31 @@ class ServerAuthMethods: def __init__( self, - methods: Mapping[ServerConnTransport, ServerAuthMethod], + methods: Mapping[ServerConnTransport, list[ServerAuthMethod]], ) -> None: self._methods = dict(methods) - def get(self, transport: ServerConnTransport) -> ServerAuthMethod: + def get(self, transport: ServerConnTransport) -> list[ServerAuthMethod]: return self._methods[transport] - def items(self) -> ItemsView[ServerConnTransport, ServerAuthMethod]: + def items(self) -> ItemsView[ServerConnTransport, list[ServerAuthMethod]]: return self._methods.items() def __str__(self): return ','.join( - f'{t.lower()}:{m.lower()}' for t, m in self._methods.items() + f'{t.lower()}:{'/'.join(m.lower() for m in mm)}' + for t, mm in self._methods.items() ) DEFAULT_AUTH_METHODS = ServerAuthMethods({ - ServerConnTransport.TCP: ServerAuthMethod.Scram, - ServerConnTransport.TCP_PG: ServerAuthMethod.Scram, - ServerConnTransport.HTTP: ServerAuthMethod.JWT, - ServerConnTransport.SIMPLE_HTTP: ServerAuthMethod.Password, - ServerConnTransport.HTTP_METRICS: ServerAuthMethod.Auto, - ServerConnTransport.HTTP_HEALTH: ServerAuthMethod.Auto, + ServerConnTransport.TCP: [ServerAuthMethod.Scram], + ServerConnTransport.TCP_PG: [ServerAuthMethod.Scram], + ServerConnTransport.HTTP: [ServerAuthMethod.JWT], + ServerConnTransport.SIMPLE_HTTP: [ + ServerAuthMethod.Password, ServerAuthMethod.JWT], + ServerConnTransport.HTTP_METRICS: [ServerAuthMethod.Auto], + ServerConnTransport.HTTP_HEALTH: [ServerAuthMethod.Auto], }) @@ -516,7 +518,7 @@ def _validate_default_auth_method( ): continue - methods[t] = method + methods[t] = [method] elif "," not in value and ":" not in value: raise click.BadParameter( f"invalid authentication method: {value}, " @@ -531,23 +533,26 @@ def _validate_default_auth_method( } for transport_spec in transport_specs: transport_spec = transport_spec.strip() - transport_name, _, method_name = transport_spec.partition(':') - if not method_name: + transport_name, _, method_names = transport_spec.partition(':') + if not method_names: raise click.BadParameter( - "format is :[,...]") + "format is :[/method...][,...]") transport = transport_names.get(transport_name.lower()) if not transport: raise click.BadParameter( f"invalid connection transport: {transport_name}, " f"supported values are: {', '.join(transport_names)})" ) - method = names.get(method_name) - if not method: - raise click.BadParameter( - f"invalid authentication method: {method_name}, " - f"supported values are: {', '.join(names)})" - ) - methods[transport] = method + transport_methods = [] + for method_name in method_names.split('/'): + method = names.get(method_name) + if not method: + raise click.BadParameter( + f"invalid authentication method: {method_name}, " + f"supported values are: {', '.join(names)})" + ) + transport_methods.append(method) + methods[transport] = transport_methods return ServerAuthMethods(methods) @@ -1201,7 +1206,7 @@ def parse_args(**kwargs: Any): kwargs['http_endpoint_security'] = 'optional' if not kwargs['default_auth_method']: kwargs['default_auth_method'] = ServerAuthMethods({ - t: ServerAuthMethod.Trust + t: [ServerAuthMethod.Trust] for t in ServerConnTransport.__members__.values() }) if kwargs['tls_cert_mode'] == 'default': @@ -1210,10 +1215,11 @@ def parse_args(**kwargs: Any): elif not kwargs['default_auth_method']: kwargs['default_auth_method'] = DEFAULT_AUTH_METHODS - methods = dict(kwargs['default_auth_method'].items()) + transport_methods = dict(kwargs['default_auth_method'].items()) for transport in ServerConnTransport.__members__.values(): - method = methods[transport] - if method is ServerAuthMethod.Auto: + methods = transport_methods[transport] + if ServerAuthMethod.Auto in methods: + pos = methods.index(ServerAuthMethod.Auto) if transport in ( ServerConnTransport.HTTP_METRICS, ServerConnTransport.HTTP_HEALTH, @@ -1222,24 +1228,33 @@ def parse_args(**kwargs: Any): method = ServerAuthMethod.Trust else: method = ServerAuthMethod.mTLS + methods[pos] = method else: - method = DEFAULT_AUTH_METHODS.get(transport) - methods[transport] = method + methods = ( + methods[:pos] + + DEFAULT_AUTH_METHODS.get(transport) + + methods[pos + 1:] + ) + transport_methods[transport] = [method] elif transport in ( ServerConnTransport.HTTP_METRICS, ServerConnTransport.HTTP_HEALTH, ): - if method is ServerAuthMethod.mTLS: + if ServerAuthMethod.mTLS in methods: if kwargs['tls_client_ca_file'] is None: abort('--tls-client-ca-file is required ' 'for mTLS authentication') - elif method is not ServerAuthMethod.Trust: + + if not all( + m is ServerAuthMethod.Trust or m is ServerAuthMethod.mTLS + for m in methods + ): abort( f'--default-auth-method of {transport} can only be one ' f'of: {ServerAuthMethod.Trust}, {ServerAuthMethod.mTLS} ' f'or {ServerAuthMethod.Auto}' ) - kwargs['default_auth_method'] = ServerAuthMethods(methods) + kwargs['default_auth_method'] = ServerAuthMethods(transport_methods) if kwargs['binary_endpoint_security'] == 'default': kwargs['binary_endpoint_security'] = 'tls' diff --git a/edb/server/protocol/frontend.pyx b/edb/server/protocol/frontend.pyx index 1d7829c4de3..d09307348a6 100644 --- a/edb/server/protocol/frontend.pyx +++ b/edb/server/protocol/frontend.pyx @@ -596,29 +596,49 @@ cdef class FrontendConnection(AbstractFrontendConnection): # The user has already been authenticated by other means # (such as the ability to write to a protected socket). if self._external_auth: - authmethod_name = 'Trust' + authmethods = [ + self.server.config_settings.get_type_by_name('cfg::Trust')() + ] else: - authmethod = await self.tenant.get_auth_method( + authmethods = await self.tenant.get_auth_methods( user, self._transport_proto) + + auth_errors = {} + + for authmethod in authmethods: authmethod_name = authmethod._tspec.name.split('::')[1] - if authmethod_name == 'SCRAM': - await self._auth_scram(user) - elif authmethod_name == 'JWT': - self._auth_jwt(user, database, params) - elif authmethod_name == 'Trust': - self._auth_trust(user) - elif authmethod_name == 'Password': - raise errors.AuthenticationError( - 'authentication failed: ' - 'Simple password authentication required but it is only ' - 'supported for HTTP endpoints' - ) - elif authmethod_name == 'mTLS': - auth_helpers.auth_mtls_with_user(self._transport, user) - else: - raise errors.InternalServerError( - f'unimplemented auth method: {authmethod_name}') + try: + if authmethod_name == 'SCRAM': + await self._auth_scram(user) + elif authmethod_name == 'JWT': + self._auth_jwt(user, database, params) + elif authmethod_name == 'Trust': + self._auth_trust(user) + elif authmethod_name == 'Password': + raise errors.AuthenticationError( + 'authentication failed: ' + 'Simple password authentication required but it is ' + 'only supported for HTTP endpoints' + ) + elif authmethod_name == 'mTLS': + auth_helpers.auth_mtls_with_user(self._transport, user) + else: + raise errors.InternalServerError( + f'unimplemented auth method: {authmethod_name}') + except errors.AuthenticationError as e: + auth_errors[authmethod_name] = e + else: + break + + if len(auth_errors) == len(authmethods): + if len(auth_errors) > 1: + desc = "; ".join( + f"{k}: {e.args[0]}" for k, e in auth_errors.items()) + raise errors.AuthenticationError( + f"all authentication methods failed: {desc}") + else: + raise next(iter(auth_errors.values())) cdef WriteBuffer _make_authentication_sasl_initial(self, list methods): raise NotImplementedError diff --git a/edb/server/protocol/protocol.pyx b/edb/server/protocol/protocol.pyx index d6b43e8252c..772c6cd5550 100644 --- a/edb/server/protocol/protocol.pyx +++ b/edb/server/protocol/protocol.pyx @@ -1,4 +1,4 @@ -# + # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. @@ -865,43 +865,65 @@ cdef class HttpProtocol: username, opt_password = auth_helpers.extract_http_user( scheme, auth_payload, request.params) - # Fetch the configured auth method - authmethod = await self.tenant.get_auth_method( + # Fetch the configured auth methods + authmethods = await self.tenant.get_auth_methods( username, srvargs.ServerConnTransport.SIMPLE_HTTP) - authmethod_name = authmethod._tspec.name.split('::')[1] - # If the auth method and the provided auth information match, - # try to resolve the authentication. - if authmethod_name == 'JWT' and scheme == 'bearer': - if not self.is_tls: - raise errors.AuthenticationError( - 'JWT HTTP auth must use HTTPS') + auth_errors = {} + + for authmethod in authmethods: + authmethod_name = authmethod._tspec.name.split('::')[1] + try: + # If the auth method and the provided auth information + # match, try to resolve the authentication. + if authmethod_name == 'JWT' and scheme == 'bearer': + if not self.is_tls: + raise errors.AuthenticationError( + 'JWT HTTP auth must use HTTPS') + + auth_helpers.auth_jwt( + self.tenant, auth_payload, username, dbname) + elif authmethod_name == 'Password' and scheme == 'basic': + if not self.is_tls: + raise errors.AuthenticationError( + 'Basic HTTP auth must use HTTPS') + + auth_helpers.auth_basic( + self.tenant, username, opt_password) + elif authmethod_name == 'Trust': + pass + elif authmethod_name == 'SCRAM': + raise errors.AuthenticationError( + 'authentication failed: ' + 'SCRAM authentication required but not ' + 'supported for HTTP' + ) + elif authmethod_name == 'mTLS': + if ( + self.http_endpoint_security + is srvargs.ServerEndpointSecurityMode.Tls + or self.is_tls + ): + auth_helpers.auth_mtls_with_user( + self.transport, username) + else: + raise errors.AuthenticationError( + 'authentication failed: wrong method used') - auth_helpers.auth_jwt( - self.tenant, auth_payload, username, dbname) - elif authmethod_name == 'Password' and scheme == 'basic': - if not self.is_tls: - raise errors.AuthenticationError( - 'Basic HTTP auth must use HTTPS') + except errors.AuthenticationError as e: + auth_errors[authmethod_name] = e - auth_helpers.auth_basic(self.tenant, username, opt_password) - elif authmethod_name == 'Trust': - pass - elif authmethod_name == 'SCRAM': - raise errors.AuthenticationError( - 'authentication failed: ' - 'SCRAM authentication required but not supported for HTTP' - ) - elif authmethod_name == 'mTLS': - if ( - self.http_endpoint_security - is srvargs.ServerEndpointSecurityMode.Tls - or self.is_tls - ): - auth_helpers.auth_mtls_with_user(self.transport, username) - else: - raise errors.AuthenticationError( - 'authentication failed: wrong method used') + else: + break + + if len(auth_errors) == len(authmethods): + if len(auth_errors) > 1: + desc = "; ".join( + f"{k}: {e.args[0]}" for k, e in auth_errors.items()) + raise errors.AuthenticationError( + f"all authentication methods failed: {desc}") + else: + raise next(iter(auth_errors.values())) except Exception as ex: if debug.flags.server: @@ -926,22 +948,39 @@ cdef class HttpProtocol: transport: srvargs.ServerConnTransport, ): try: - auth_method = self.server.get_default_auth_method(transport) + auth_methods = self.server.get_default_auth_methods(transport) + auth_errors = {} + + for auth_method in auth_methods: + authmethod_name = auth_method._tspec.name.split('::')[1] + try: + # If the auth method and the provided auth information + # match, try to resolve the authentication. + if authmethod_name == 'Trust': + pass + elif authmethod_name == 'mTLS': + if ( + self.http_endpoint_security + is srvargs.ServerEndpointSecurityMode.Tls + or self.is_tls + ): + auth_helpers.auth_mtls(self.transport) + else: + raise errors.AuthenticationError( + 'authentication failed: wrong method used') + except errors.AuthenticationError as e: + auth_errors[authmethod_name] = e + else: + break - # If the auth method and the provided auth information match, - # try to resolve the authentication. - if auth_method is srvargs.ServerAuthMethod.Trust: - pass - elif auth_method is srvargs.ServerAuthMethod.mTLS: - if ( - self.http_endpoint_security - is srvargs.ServerEndpointSecurityMode.Tls - or self.is_tls - ): - auth_helpers.auth_mtls(self.transport) - else: - raise errors.AuthenticationError( - 'authentication failed: wrong method used') + if len(auth_errors) == len(auth_methods): + if len(auth_errors) > 1: + desc = "; ".join( + f"{k}: {e.args[0]}" for k, e in auth_errors.items()) + raise errors.AuthenticationError( + f"all authentication methods failed: {desc}") + else: + raise next(iter(auth_errors.values())) except Exception as ex: if debug.flags.server: diff --git a/edb/server/server.py b/edb/server/server.py index 9a6e261b86f..ebcb48f3e41 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -234,7 +234,9 @@ def __init__( self._jws_key: jwk.JWK | None = None self._jws_keys_newly_generated = False - self._default_auth_method = default_auth_method + self._default_auth_method_spec = default_auth_method + self._default_auth_methods = self._get_auth_method_types( + default_auth_method) self._binary_endpoint_security = binary_endpoint_security self._http_endpoint_security = http_endpoint_security @@ -248,6 +250,22 @@ def __init__( self._disable_dynamic_system_config = disable_dynamic_system_config self._report_config_typedesc = {} + def _get_auth_method_types( + self, + auth_methods_spec: srvargs.ServerAuthMethods, + ) -> dict[srvargs.ServerConnTransport, list[config.CompositeConfigType]]: + mapping = {} + for transport, methods in auth_methods_spec.items(): + result = [] + for method in methods: + auth_type = self.config_settings.get_type_by_name( + f'cfg::{method.value}' + ) + result.append(auth_type()) + mapping[transport] = result + + return mapping + async def _request_stats_logger(self): last_seen = -1 while True: @@ -1063,7 +1081,7 @@ def get_debug_info(self): params=dict( dev_mode=self._devmode, test_mode=self._testmode, - default_auth_method=str(self._default_auth_method), + default_auth_methods=str(self._default_auth_method_spec), listen_hosts=self._listen_hosts, listen_port=self._listen_port, ), @@ -1081,10 +1099,10 @@ def get_report_config_typedesc( ) -> dict[defines.ProtocolVersion, bytes]: return self._report_config_typedesc - def get_default_auth_method( + def get_default_auth_methods( self, transport: srvargs.ServerConnTransport - ) -> srvargs.ServerAuthMethod: - return self._default_auth_method.get(transport) + ) -> list[config.CompositeConfigType]: + return self._default_auth_methods.get(transport, []) def get_std_schema(self) -> s_schema.Schema: return self._std_schema diff --git a/edb/server/tenant.py b/edb/server/tenant.py index d613d095ca7..cab56f7e6d9 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -1090,12 +1090,13 @@ def resolve_branch_name( assert database is not None return database - async def get_auth_method( + async def get_auth_methods( self, user: str, transport: srvargs.ServerConnTransport, - ) -> Any: + ) -> list[config.CompositeConfigType]: authlist = self._sys_auth + methods = [] if authlist: for auth in authlist: @@ -1105,13 +1106,13 @@ async def get_auth_method( ) if match: - return auth.method + methods.append(auth.method) + break - default_method = self._server.get_default_auth_method(transport) - auth_type = self._server.config_settings.get_type_by_name( - f'cfg::{default_method.value}' - ) - return auth_type() + if not methods: + methods = self._server.get_default_auth_methods(transport) + + return methods async def new_dbview( self, diff --git a/tests/test_server_auth.py b/tests/test_server_auth.py index 279cfb2e1df..6ed75ee4cfd 100644 --- a/tests/test_server_auth.py +++ b/tests/test_server_auth.py @@ -310,8 +310,10 @@ async def _basic_http_request( async def _http_request( self, server, + *, sk=None, username='edgedb', + password=None, db='edgedb', proto='edgeql', client_cert_file=None, @@ -326,6 +328,9 @@ async def _http_request( headers = {'X-EdgeDB-User': username} if sk is not None: headers['Authorization'] = f'bearer {sk}' + elif password is not None: + headers['Authorization'] = self.make_auth_header( + username, password) return self.http_con_request( con, path=f'/db/{db}/{proto}', @@ -336,12 +341,31 @@ async def _http_request( ) async def _jwt_http_request( - self, server, sk, username='edgedb', db='edgedb', proto='edgeql' + self, + server, + *, + sk=None, + username='edgedb', + password=None, + db='edgedb', + proto='edgeql', ): - return await self._http_request(server, sk, username, db, proto) + return await self._http_request( + server, + sk=sk, + username=username, + password=password, + db=db, + proto=proto, + ) - def _jwt_gql_request(self, server, sk): - return self._jwt_http_request(server, sk, proto='graphql') + def _jwt_gql_request(self, server, *, sk=None, password=None): + return self._jwt_http_request( + server, + sk=sk, + password=password, + proto='graphql', + ) @unittest.skipIf( "EDGEDB_SERVER_MULTITENANT_CONFIG_FILE" in os.environ, @@ -399,9 +423,9 @@ async def test_server_auth_jwt_1(self): ): await sd.connect(secret_key=corrupt_sk) - body, _, code = await self._jwt_http_request(sd, corrupt_sk) + body, _, code = await self._jwt_http_request(sd, sk=corrupt_sk) self.assertEqual(code, 401, f"Wrong result: {body}") - body, _, code = await self._jwt_gql_request(sd, corrupt_sk) + body, _, code = await self._jwt_gql_request(sd, sk=corrupt_sk) self.assertEqual(code, 401, f"Wrong result: {body}") # Try to mess up the *signature* part of it @@ -413,19 +437,19 @@ async def test_server_auth_jwt_1(self): await sd.connect(secret_key=wrong_sk) body, _, code = await self._jwt_http_request( - sd, corrupt_sk, db='non_existant') + sd, sk=corrupt_sk, db='non_existant') self.assertEqual(code, 401, f"Wrong result: {body}") # Good key (control check, mostly) - body, _, code = await self._jwt_http_request(sd, base_sk) + body, _, code = await self._jwt_http_request(sd, sk=base_sk) self.assertEqual(code, 200, f"Wrong result: {body}") # Good key but nonexistant user body, _, code = await self._jwt_http_request( - sd, base_sk, username='elonmusk') + sd, sk=base_sk, username='elonmusk') self.assertEqual(code, 401, f"Wrong result: {body}") # Good key but user needs password auth body, _, code = await self._jwt_http_request( - sd, base_sk, username='foo') + sd, sk=base_sk, username='foo') self.assertEqual(code, 401, f"Wrong result: {body}") good_keys = [ @@ -442,9 +466,9 @@ async def test_server_auth_jwt_1(self): conn = await sd.connect(secret_key=sk) await conn.aclose() - body, _, code = await self._jwt_http_request(sd, sk) + body, _, code = await self._jwt_http_request(sd, sk=sk) self.assertEqual(code, 200, f"Wrong result: {body}") - body, _, code = await self._jwt_gql_request(sd, sk) + body, _, code = await self._jwt_gql_request(sd, sk=sk) self.assertEqual(code, 200, f"Wrong result: {body}") bad_keys = { @@ -469,9 +493,9 @@ async def test_server_auth_jwt_1(self): ): await sd.connect(secret_key=sk) - body, _, code = await self._jwt_http_request(sd, sk) + body, _, code = await self._jwt_http_request(sd, sk=sk) self.assertEqual(code, 401, f"Wrong result: {body}") - body, _, code = await self._jwt_gql_request(sd, sk) + body, _, code = await self._jwt_gql_request(sd, sk=sk) self.assertEqual(code, 401, f"Wrong result: {body}") @unittest.skipIf( @@ -560,6 +584,78 @@ async def test_server_auth_jwt_2(self): ): await sd.connect(secret_key=sk) + @unittest.skipIf( + "EDGEDB_SERVER_MULTITENANT_CONFIG_FILE" in os.environ, + "cannot use CONFIGURE INSTANCE in multi-tenant mode", + ) + async def test_server_auth_multiple_methods(self): + jwk_fd, jwk_file = tempfile.mkstemp() + + key = jwcrypto.jwk.JWK(generate='EC') + with open(jwk_fd, "wb") as f: + f.write(key.export_to_pem(private_key=True, password=None)) + jwk = secretkey.load_secret_key(pathlib.Path(jwk_file)) + async with tb.start_edgedb_server( + jws_key_file=pathlib.Path(jwk_file), + default_auth_method=args.ServerAuthMethods({ + args.ServerConnTransport.TCP: [ + args.ServerAuthMethod.JWT, + args.ServerAuthMethod.Scram, + ], + args.ServerConnTransport.SIMPLE_HTTP: [ + args.ServerAuthMethod.Password, + args.ServerAuthMethod.JWT, + ], + }), + extra_args=["--instance-name=localtest"], + ) as sd: + base_sk = secretkey.generate_secret_key(jwk) + conn = await sd.connect(secret_key=base_sk) + await conn.execute(''' + CREATE EXTENSION edgeql_http; + CREATE EXTENSION graphql; + ''') + await conn.aclose() + + # bad secret keys + with self.assertRaisesRegex( + edgedb.AuthenticationError, + 'authentication failed: malformed JWT', + ): + await sd.connect(secret_key='wrong', password=None) + + # But connecting with the default password should still work + # because we are defaulting to Scram/JWT + c1 = await sd.connect(secret_key='wrong') + await c1.aclose() + + sk = secretkey.generate_secret_key(jwk) + + body, _, code = await self._jwt_http_request(sd, sk=sk) + self.assertEqual(code, 200, f"Wrong result: {body}") + body, _, code = await self._jwt_gql_request(sd, sk=sk) + self.assertEqual(code, 200, f"Wrong result: {body}") + + corrupt_sk = sk[:50] + "0" + sk[51:] + body, _, code = await self._jwt_http_request(sd, sk=corrupt_sk) + self.assertEqual(code, 401, f"Wrong result: {body}") + body, _, code = await self._jwt_gql_request(sd, sk=corrupt_sk) + self.assertEqual(code, 401, f"Wrong result: {body}") + + body, _, code = await self._jwt_http_request( + sd, password=sd.password) + self.assertEqual(code, 200, f"Wrong result: {body}") + body, _, code = await self._jwt_gql_request( + sd, password=sd.password) + self.assertEqual(code, 200, f"Wrong result: {body}") + + body, _, code = await self._jwt_http_request( + sd, password="wrong password") + self.assertEqual(code, 401, f"Wrong result: {body}") + body, _, code = await self._jwt_gql_request( + sd, password="wrong password") + self.assertEqual(code, 401, f"Wrong result: {body}") + async def test_server_auth_in_transaction(self): if not self.has_create_role: self.skipTest('create role is not supported by the backend')