From 7482017e6afefb225116758ad672236349acccbf Mon Sep 17 00:00:00 2001 From: Fantix King Date: Thu, 7 Mar 2024 17:32:21 -0500 Subject: [PATCH] Add mTLS support (#6460) * Add --tls-client-ca-file * Use --default-auth-method to configure auth for /metrics and /server * lso fixes an ISE with `--default-auth-method tcp:auto`. --- edb/server/args.py | 65 +++++++++- edb/server/main.py | 12 +- edb/server/multitenant.py | 11 +- edb/server/protocol/auth_helpers.pxd | 2 + edb/server/protocol/auth_helpers.pyx | 36 +++++- edb/server/protocol/binary.pxd | 1 + edb/server/protocol/binary.pyx | 9 +- edb/server/protocol/frontend.pyx | 2 + edb/server/protocol/protocol.pxd | 2 + edb/server/protocol/protocol.pyx | 69 ++++++++++- edb/server/server.py | 19 ++- edb/testbase/server.py | 88 +++++++++++++- tests/test_http_auth.py | 40 ++----- tests/test_server_auth.py | 171 +++++++++++++++++++++++++-- tests/test_server_ops.py | 24 ++++ 15 files changed, 487 insertions(+), 64 deletions(-) diff --git a/edb/server/args.py b/edb/server/args.py index 199a6277d45..8771d9d8611 100644 --- a/edb/server/args.py +++ b/edb/server/args.py @@ -258,6 +258,7 @@ class ServerConfig(NamedTuple): tls_cert_file: pathlib.Path tls_key_file: pathlib.Path tls_cert_mode: ServerTlsCertMode + tls_client_ca_file: Optional[pathlib.Path] jws_key_file: pathlib.Path jose_key_mode: JOSEKeyMode @@ -501,8 +502,21 @@ def _validate_default_auth_method( if method in {ServerAuthMethod.Auto, ServerAuthMethod.Scram}: pass else: - for m in methods: - methods[m] = method + for t in methods: + # HTTP_METRICS and HTTP_HEALTH support only mTLS, but for + # backward compatibility, default them to `auto` if unsupported + # method is passed explicitly. + if t in ( + ServerConnTransport.HTTP_METRICS, + ServerConnTransport.HTTP_HEALTH, + ): + if method not in ( + ServerAuthMethod.Trust, + ServerAuthMethod.mTLS, + ): + continue + + methods[t] = method elif "," not in value and ":" not in value: raise click.BadParameter( f"invalid authentication method: {value}, " @@ -815,6 +829,18 @@ def resolve_envvar_value(self, ctx: click.Context): '"require_file" when the --security option is set to "strict", ' 'and "generate_self_signed" when the --security option is set to ' '"insecure_dev_mode"'), + click.option( + '--tls-client-ca-file', + type=PathPath(), + envvar='EDGEDB_SERVER_TLS_CLIENT_CA_FILE', + help='Specifies a path to a file containing a TLS CA certificate to ' + 'verify client certificates on demand. When set, the default ' + 'authentication method of HTTP_METRICS(/metrics) and HTTP_HEALTH' + '(/server/*) will also become "mTLS", unless explicitly set in ' + '--default-auth-method. Note, the protection of such HTTP ' + 'endpoints is only complete if --http-endpoint-security is also ' + 'set to `tls`, or they are still accessible in plaintext HTTP.' + ), click.option( '--generate-self-signed-cert', type=bool, default=False, is_flag=True, help='DEPRECATED.\n\n' @@ -1174,16 +1200,47 @@ def parse_args(**kwargs: Any): if kwargs['http_endpoint_security'] == 'default': kwargs['http_endpoint_security'] = 'optional' if not kwargs['default_auth_method']: - kwargs['default_auth_method'] = { + kwargs['default_auth_method'] = ServerAuthMethods({ t: ServerAuthMethod.Trust for t in ServerConnTransport.__members__.values() - } + }) if kwargs['tls_cert_mode'] == 'default': kwargs['tls_cert_mode'] = 'generate_self_signed' elif not kwargs['default_auth_method']: kwargs['default_auth_method'] = DEFAULT_AUTH_METHODS + methods = dict(kwargs['default_auth_method'].items()) + for transport in ServerConnTransport.__members__.values(): + method = methods[transport] + if method is ServerAuthMethod.Auto: + if transport in ( + ServerConnTransport.HTTP_METRICS, + ServerConnTransport.HTTP_HEALTH, + ): + if kwargs['tls_client_ca_file'] is None: + method = ServerAuthMethod.Trust + else: + method = ServerAuthMethod.mTLS + else: + method = DEFAULT_AUTH_METHODS.get(transport) + methods[transport] = method + elif transport in ( + ServerConnTransport.HTTP_METRICS, + ServerConnTransport.HTTP_HEALTH, + ): + if method is ServerAuthMethod.mTLS: + if kwargs['tls_client_ca_file'] is None: + abort('--tls-client-ca-file is required ' + 'for mTLS authentication') + elif method is not ServerAuthMethod.Trust: + abort( + f'--default-auth-method of {transport} can only be one ' + f'of: {ServerAuthMethod.Trust}, {ServerAuthMethod.mTLS} ' + f'or {ServerAuthMethod.Auto}' + ) + kwargs['default_auth_method'] = ServerAuthMethods(methods) + if kwargs['binary_endpoint_security'] == 'default': kwargs['binary_endpoint_security'] = 'tls' diff --git a/edb/server/main.py b/edb/server/main.py index 3a1e0c9a874..85e69e0aa08 100644 --- a/edb/server/main.py +++ b/edb/server/main.py @@ -272,7 +272,11 @@ async def _run_server( return ss.init_tls( - args.tls_cert_file, args.tls_key_file, tls_cert_newly_generated) + args.tls_cert_file, + args.tls_key_file, + tls_cert_newly_generated, + args.tls_client_ca_file, + ) ss.init_jwcrypto(args.jws_key_file, jws_keys_newly_generated) @@ -290,7 +294,11 @@ def load_configuration(_signum): try: if args.readiness_state_file: tenant.reload_readiness_state() - ss.reload_tls(args.tls_cert_file, args.tls_key_file) + ss.reload_tls( + args.tls_cert_file, + args.tls_key_file, + args.tls_client_ca_file, + ) ss.load_jwcrypto(args.jws_key_file) except Exception: logger.critical( diff --git a/edb/server/multitenant.py b/edb/server/multitenant.py index 6341c616ca6..bb9a5a38d10 100644 --- a/edb/server/multitenant.py +++ b/edb/server/multitenant.py @@ -417,7 +417,10 @@ async def run_server( tls_cert_newly_generated, jws_keys_newly_generated ) = await ss.maybe_generate_pki(args, ss) ss.init_tls( - args.tls_cert_file, args.tls_key_file, tls_cert_newly_generated + args.tls_cert_file, + args.tls_key_file, + tls_cert_newly_generated, + args.tls_client_ca_file, ) ss.init_jwcrypto(args.jws_key_file, jws_keys_newly_generated) @@ -433,7 +436,11 @@ def load_configuration(_signum): logger.info("reloading configuration") try: - ss.reload_tls(args.tls_cert_file, args.tls_key_file) + ss.reload_tls( + args.tls_cert_file, + args.tls_key_file, + args.tls_client_ca_file, + ) ss.load_jwcrypto(args.jws_key_file) ss.reload_tenants() except Exception: diff --git a/edb/server/protocol/auth_helpers.pxd b/edb/server/protocol/auth_helpers.pxd index f5ad83c9203..e43d6c4f504 100644 --- a/edb/server/protocol/auth_helpers.pxd +++ b/edb/server/protocol/auth_helpers.pxd @@ -25,3 +25,5 @@ cdef scram_get_verifier(tenant, str user) cdef parse_basic_auth(str auth_payload) cdef extract_http_user(scheme, auth_payload, params) cdef auth_basic(tenant, str username, str password) +cdef auth_mtls(transport) +cdef auth_mtls_with_user(transport, str username) diff --git a/edb/server/protocol/auth_helpers.pyx b/edb/server/protocol/auth_helpers.pyx index c7ecbb32b01..ec8f1aeb52f 100644 --- a/edb/server/protocol/auth_helpers.pyx +++ b/edb/server/protocol/auth_helpers.pyx @@ -28,9 +28,6 @@ import logging from jwcrypto import jwt from edb import errors -from edb.common import debug - -from edb.server.protocol cimport args_ser cdef object logger = logging.getLogger('edb.server') @@ -235,3 +232,36 @@ cdef auth_basic(tenant, username: str, password: str): verifier, mock_auth = scram_get_verifier(tenant, username) if not scram_verify_password(password, verifier) or mock_auth: raise errors.AuthenticationError('authentication failed') + + +cdef auth_mtls(transport): + sslobj = transport.get_extra_info('ssl_object') + if sslobj is None: + raise errors.AuthenticationError( + "mTLS authentication is not supported over plaintext transport") + cert_data = sslobj.getpeercert() + if not cert_data: # None or empty dict + # If --tls-client-ca-file is specified, the SSLContext used here would + # have done load_verify_locations() in `server/server.py`, and we will + # have a valid client certificate (non-empty dict) now if one was + # provided by the client and passed validation; empty dict otherwise. + # `None` just means the peer didn't send a client certificate. + raise errors.AuthenticationError( + "valid client certificate required") + return cert_data + + +cdef auth_mtls_with_user(transport, str username): + cert_data = auth_mtls(transport) + try: + for rdn in cert_data["subject"]: + if rdn[0][0] == 'commonName': + if rdn[0][1] == username: + return + except Exception as ex: + raise errors.AuthenticationError( + "bad client certificate") from ex + + raise errors.AuthenticationError( + f"Common Name of client certificate doesn't match {username!r}", + ) diff --git a/edb/server/protocol/binary.pxd b/edb/server/protocol/binary.pxd index fb3322d3751..d6d480834d6 100644 --- a/edb/server/protocol/binary.pxd +++ b/edb/server/protocol/binary.pxd @@ -109,3 +109,4 @@ cdef class VirtualTransport: cdef: WriteBuffer buf bint closed + object transport diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 293dc75d4ce..08d96c2c9f7 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -1722,9 +1722,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): @cython.final cdef class VirtualTransport: - def __init__(self): + def __init__(self, transport): self.buf = WriteBuffer.new() self.closed = False + self.transport = transport def write(self, data): self.buf.write_bytes(bytes(data)) @@ -1741,6 +1742,9 @@ cdef class VirtualTransport: def abort(self): self.closed = True + def get_extra_info(self, name, default=None): + return self.transport.get_extra_info(name, default) + async def eval_buffer( server, @@ -1751,12 +1755,13 @@ async def eval_buffer( protocol_version: edbdef.ProtocolVersion, auth_data: bytes, transport: srvargs.ServerConnTransport, + tcp_transport: asyncio.Transport, ): cdef: VirtualTransport vtr EdgeConnection proto - vtr = VirtualTransport() + vtr = VirtualTransport(tcp_transport) proto = new_edge_connection( server, diff --git a/edb/server/protocol/frontend.pyx b/edb/server/protocol/frontend.pyx index c3b99900e26..9713084722f 100644 --- a/edb/server/protocol/frontend.pyx +++ b/edb/server/protocol/frontend.pyx @@ -591,6 +591,8 @@ cdef class FrontendConnection(AbstractFrontendConnection): 'Simple password authentication required but it is only ' 'supported for HTTP endpoints' ) + elif authmethod_name == 'mTLS': + auth_helpers.auth_mtls_with_user(self._transport, user) else: raise errors.InternalServerError( f'unimplemented auth method: {authmethod_name}') diff --git a/edb/server/protocol/protocol.pxd b/edb/server/protocol/protocol.pxd index 8fe8e444e26..56a7ee62776 100644 --- a/edb/server/protocol/protocol.pxd +++ b/edb/server/protocol/protocol.pxd @@ -74,6 +74,8 @@ cdef class HttpProtocol: str message = ?) cdef _bad_request(self, HttpRequest request, HttpResponse response, str message) + cdef _unauthorized(self, HttpRequest request, HttpResponse response, + str message) cdef _return_binary_error(self, binary.EdgeConnection proto) cdef _write(self, bytes req_version, bytes resp_status, bytes content_type, dict custom_headers, bytes body, diff --git a/edb/server/protocol/protocol.pyx b/edb/server/protocol/protocol.pyx index d0f70de8721..0584088b073 100644 --- a/edb/server/protocol/protocol.pyx +++ b/edb/server/protocol/protocol.pyx @@ -572,6 +572,7 @@ cdef class HttpProtocol: protocol_version=proto_ver, auth_data=self.current_request.authorization, transport=srvargs.ServerConnTransport.HTTP, + tcp_transport=self.transport, ) response.status = http.HTTPStatus.OK response.content_type = PROTO_MIME @@ -660,6 +661,13 @@ cdef class HttpProtocol: self.tenant, ) elif route == 'server': + if not await self._authenticate_for_default_conn_transport( + request, + response, + srvargs.ServerConnTransport.HTTP_HEALTH, + ): + return + # System API request await system_api.handle_request( request, @@ -669,6 +677,13 @@ cdef class HttpProtocol: self.tenant, ) elif path_parts == ['metrics'] and request.method == b'GET': + if not await self._authenticate_for_default_conn_transport( + request, + response, + srvargs.ServerConnTransport.HTTP_METRICS, + ): + return + # Quoting the Open Metrics spec: # Implementers MUST expose metrics in the OpenMetrics # text format in response to a simple HTTP GET request @@ -779,6 +794,16 @@ cdef class HttpProtocol: return False + cdef _unauthorized( + self, + HttpRequest request, + HttpResponse response, + str message, + ): + response.body = message.encode("utf-8") + response.status = http.HTTPStatus.UNAUTHORIZED + response.close_connection = True + async def _check_http_auth( self, HttpRequest request, @@ -822,6 +847,13 @@ cdef class HttpProtocol: 'authentication failed: ' 'SCRAM authentication required but not supported for HTTP' ) + elif authmethod_name == 'mTLS': + if ( + self.http_endpoint_security + is srvargs.ServerEndpointSecurityMode.Tls + or self.is_tls + ): + auth_helpers.auth_mtls_with_user(self.transport, username) else: raise errors.AuthenticationError( 'authentication failed: wrong method used') @@ -830,9 +862,7 @@ cdef class HttpProtocol: if debug.flags.server: markup.dump(ex) - response.body = str(ex).encode() - response.status = http.HTTPStatus.UNAUTHORIZED - response.close_connection = True + self._unauthorized(request, response, str(ex)) # If no scheme was specified, add a WWW-Authenticate header if scheme == '': @@ -844,6 +874,39 @@ cdef class HttpProtocol: return True + async def _authenticate_for_default_conn_transport( + self, + HttpRequest request, + HttpResponse response, + transport: srvargs.ServerConnTransport, + ): + try: + auth_method = self.server.get_default_auth_method(transport) + + # If the auth method and the provided auth information match, + # try to resolve the authentication. + if auth_method is srvargs.ServerAuthMethod.Trust: + pass + elif auth_method is srvargs.ServerAuthMethod.mTLS: + if ( + self.http_endpoint_security + is srvargs.ServerEndpointSecurityMode.Tls + or self.is_tls + ): + auth_helpers.auth_mtls(self.transport) + else: + raise errors.AuthenticationError( + 'authentication failed: wrong method used') + + except Exception as ex: + if debug.flags.server: + markup.dump(ex) + + self._unauthorized(request, response, str(ex)) + + return False + + return True def get_request_url(request, is_tls): request_url = request.url diff --git a/edb/server/server.py b/edb/server/server.py index 452c8df5b82..94c340edfbb 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -783,7 +783,7 @@ def _sni_callback(self, sslobj, server_name, sslctx): # Used in multi-tenant server only. This method must not fail. pass - def reload_tls(self, tls_cert_file, tls_key_file): + def reload_tls(self, tls_cert_file, tls_key_file, client_ca_file): logger.info("loading TLS certificates") tls_password_needed = False if self._tls_certs_reload_retry_handle is not None: @@ -843,6 +843,16 @@ def _tls_private_key_password(): raise StartupError(f"Cannot load TLS certificates - {e}") from e + if client_ca_file is not None: + try: + sslctx.load_verify_locations(client_ca_file) + sslctx_pgext.load_verify_locations(client_ca_file) + except ssl.SSLError as e: + raise StartupError( + f"Cannot load client CA certificates - {e}") from e + sslctx.verify_mode = ssl.CERT_OPTIONAL + sslctx_pgext.verify_mode = ssl.CERT_OPTIONAL + sslctx.set_alpn_protocols(['edgedb-binary', 'http/1.1']) sslctx.sni_callback = self._sni_callback sslctx_pgext.sni_callback = self._sni_callback @@ -854,16 +864,17 @@ def init_tls( tls_cert_file, tls_key_file, tls_cert_newly_generated, + client_ca_file, ): assert self._sslctx is self._sslctx_pgext is None - self.reload_tls(tls_cert_file, tls_key_file) + self.reload_tls(tls_cert_file, tls_key_file, client_ca_file) self._tls_cert_file = str(tls_cert_file) self._tls_cert_newly_generated = tls_cert_newly_generated def reload_tls(_file_modified, _event, retry=0): try: - self.reload_tls(tls_cert_file, tls_key_file) + self.reload_tls(tls_cert_file, tls_key_file, client_ca_file) except (StartupError, FileNotFoundError) as e: if retry > defines._TLS_CERT_RELOAD_MAX_RETRIES: logger.critical(str(e)) @@ -891,6 +902,8 @@ def reload_tls(_file_modified, _event, retry=0): self.monitor_fs(tls_cert_file, reload_tls) if tls_cert_file != tls_key_file: self.monitor_fs(tls_key_file, reload_tls) + if client_ca_file is not None: + self.monitor_fs(client_ca_file, reload_tls) def load_jwcrypto(self, jws_key_file: pathlib.Path) -> None: try: diff --git a/edb/testbase/server.py b/edb/testbase/server.py index 3c71c029d38..aaf3157bb9e 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -53,6 +53,7 @@ import shlex import socket import ssl +import struct import subprocess import sys import tempfile @@ -75,6 +76,7 @@ from edb.common import retryloop from edb.common import secretkey +from edb import protocol from edb.protocol import protocol as test_protocol from edb.testbase import serutils @@ -382,13 +384,22 @@ def get_api_prefix(cls): @classmethod @contextlib.contextmanager - def http_con(cls, server, keep_alive=True, server_hostname=None): + def http_con( + cls, + server, + keep_alive=True, + server_hostname=None, + client_cert_file=None, + client_key_file=None, + ): conn_args = server.get_connect_args() tls_context = ssl.create_default_context( ssl.Purpose.SERVER_AUTH, cafile=conn_args["tls_ca_file"], ) tls_context.check_hostname = False + if any((client_cert_file, client_key_file)): + tls_context.load_cert_chain(client_cert_file, client_key_file) if keep_alive: ConCls = StubbornHttpConnection else: @@ -495,6 +506,53 @@ def http_con_json_request( return result, headers, status + @classmethod + def http_con_binary_request( + cls, + http_con: http.client.HTTPConnection, + query: str, + proto_ver=edgedb_defines.CURRENT_PROTOCOL, + bearer_token: Optional[str] = None, + user: str = "edgedb", + database: str = "main", + ): + proto_ver_str = f"v_{proto_ver[0]}_{proto_ver[1]}" + mime_type = f"application/x.edgedb.{proto_ver_str}.binary" + headers = {"Content-Type": mime_type, "X-EdgeDB-User": user} + if bearer_token: + headers["Authorization"] = f"Bearer {bearer_token}" + content, headers, status = cls.http_con_request( + http_con, + method="POST", + path=f"db/{database}", + prefix="", + body=protocol.Execute( + annotations=[], + allowed_capabilities=protocol.Capability.ALL, + compilation_flags=protocol.CompilationFlag(0), + implicit_limit=0, + command_text=query, + output_format=protocol.OutputFormat.JSON, + expected_cardinality=protocol.Cardinality.AT_MOST_ONE, + input_typedesc_id=b"\0" * 16, + output_typedesc_id=b"\0" * 16, + state_typedesc_id=b"\0" * 16, + arguments=b"", + state_data=b"", + ).dump() + protocol.Sync().dump(), + headers=headers, + ) + content = memoryview(content) + uint32_unpack = struct.Struct("!L").unpack + msgs = [] + while content: + mtype = content[0] + (msize,) = uint32_unpack(content[1:5]) + msg = protocol.ServerMessage.parse(mtype, content[5: msize + 1]) + msgs.append(msg) + content = content[msize + 1 :] + return msgs, headers, status + _default_cluster = None @@ -792,13 +850,22 @@ async def assertRaisesRegexTx(self, exception, regex, msg=None, **kwargs): @classmethod @contextlib.contextmanager - def http_con(cls, server=None, keep_alive=True, server_hostname=None): + def http_con( + cls, + server=None, + keep_alive=True, + server_hostname=None, + client_cert_file=None, + client_key_file=None, + ): if server is None: server = cls with super().http_con( server, keep_alive=keep_alive, server_hostname=server_hostname, + client_cert_file=client_cert_file, + client_key_file=client_key_file, ) as http_con: yield http_con @@ -2022,7 +2089,9 @@ def __init__( reset_auth: Optional[bool] = None, tenant_id: Optional[str] = None, security: edgedb_args.ServerSecurityMode, - default_auth_method: Optional[edgedb_args.ServerAuthMethod] = None, + default_auth_method: Optional[ + edgedb_args.ServerAuthMethod | edgedb_args.ServerAuthMethods + ] = None, binary_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode] = None, http_endpoint_security: Optional[ @@ -2034,6 +2103,7 @@ def __init__( tls_key_file: Optional[os.PathLike] = None, tls_cert_mode: edgedb_args.ServerTlsCertMode = ( edgedb_args.ServerTlsCertMode.SelfSigned), + tls_client_ca_file: Optional[os.PathLike] = None, jws_key_file: Optional[os.PathLike] = None, jwt_sub_allowlist_file: Optional[os.PathLike] = None, jwt_revocation_list_file: Optional[os.PathLike] = None, @@ -2067,6 +2137,7 @@ def __init__( self.tls_cert_file = tls_cert_file self.tls_key_file = tls_key_file self.tls_cert_mode = tls_cert_mode + self.tls_client_ca_file = tls_client_ca_file self.jws_key_file = jws_key_file self.jwt_sub_allowlist_file = jwt_sub_allowlist_file self.jwt_revocation_list_file = jwt_revocation_list_file @@ -2225,11 +2296,14 @@ async def __aenter__(self): if self.tls_key_file: cmd += ['--tls-key-file', self.tls_key_file] + if self.tls_client_ca_file: + cmd += ['--tls-client-ca-file', str(self.tls_client_ca_file)] + if self.readiness_state_file: cmd += ['--readiness-state-file', self.readiness_state_file] if self.jws_key_file: - cmd += ['--jws-key-file', self.jws_key_file] + cmd += ['--jws-key-file', str(self.jws_key_file)] if self.jwt_sub_allowlist_file: cmd += ['--jwt-sub-allowlist-file', self.jwt_sub_allowlist_file] @@ -2351,7 +2425,9 @@ def start_edgedb_server( tenant_id: Optional[str] = None, security: edgedb_args.ServerSecurityMode = ( edgedb_args.ServerSecurityMode.Strict), - default_auth_method: Optional[edgedb_args.ServerAuthMethod] = None, + default_auth_method: Optional[ + edgedb_args.ServerAuthMethod | edgedb_args.ServerAuthMethods + ] = None, binary_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode] = None, http_endpoint_security: Optional[ @@ -2363,6 +2439,7 @@ def start_edgedb_server( tls_key_file: Optional[os.PathLike] = None, tls_cert_mode: edgedb_args.ServerTlsCertMode = ( edgedb_args.ServerTlsCertMode.SelfSigned), + tls_client_ca_file: Optional[os.PathLike] = None, jws_key_file: Optional[os.PathLike] = None, jwt_sub_allowlist_file: Optional[os.PathLike] = None, jwt_revocation_list_file: Optional[os.PathLike] = None, @@ -2429,6 +2506,7 @@ def start_edgedb_server( tls_cert_file=tls_cert_file, tls_key_file=tls_key_file, tls_cert_mode=tls_cert_mode, + tls_client_ca_file=tls_client_ca_file, jws_key_file=jws_key_file, jwt_sub_allowlist_file=jwt_sub_allowlist_file, jwt_revocation_list_file=jwt_revocation_list_file, diff --git a/tests/test_http_auth.py b/tests/test_http_auth.py index da210c7ed80..5a19aa051e7 100644 --- a/tests/test_http_auth.py +++ b/tests/test_http_auth.py @@ -18,7 +18,6 @@ import base64 -import struct import urllib import edgedb @@ -144,46 +143,21 @@ def test_http_auth_scram_valid(self): server_final = base64.b64decode(values["data"]) server_sig = scram.parse_server_final_message(server_final) self.assertEqual(server_sig, expected_server_sig) + proto_ver = edbdef.CURRENT_PROTOCOL proto_ver_str = f"v_{proto_ver[0]}_{proto_ver[1]}" mime_type = f"application/x.edgedb.{proto_ver_str}.binary" with self.http_con() as con: - con.request( - "POST", - f"/db/{args['database']}", - body=protocol.Execute( - annotations=[], - allowed_capabilities=protocol.Capability.ALL, - compilation_flags=protocol.CompilationFlag(0), - implicit_limit=0, - command_text="SELECT 42", - output_format=protocol.OutputFormat.JSON, - expected_cardinality=protocol.Cardinality.AT_MOST_ONE, - input_typedesc_id=b"\0" * 16, - output_typedesc_id=b"\0" * 16, - state_typedesc_id=b"\0" * 16, - arguments=b"", - state_data=b"", - ).dump() - + protocol.Sync().dump(), - headers={ - "Content-Type": mime_type, - "Authorization": f"Bearer {token.decode('ascii')}", - "X-EdgeDB-User": args["user"], - }, + msgs, headers, status = self.http_con_binary_request( + con, + "SELECT 42", + bearer_token=token.decode("ascii"), + user=args["user"], + database=args["database"], ) - content, headers, status = self.http_con_read_response(con) self.assertEqual(status, 200) self.assertEqual(headers, headers | {"content-type": mime_type}) - uint32_unpack = struct.Struct("!L").unpack - msgs = [] - while content: - mtype = content[0] - (msize,) = uint32_unpack(content[1:5]) - msg = protocol.ServerMessage.parse(mtype, content[5 : msize + 1]) - msgs.append(msg) - content = content[msize + 1 :] self.assertIsInstance(msgs[0], protocol.CommandDataDescription) self.assertIsInstance(msgs[1], protocol.Data) self.assertEqual(bytes(msgs[1].data[0].data), b"42") diff --git a/tests/test_server_auth.py b/tests/test_server_auth.py index dccb9ccd90c..c9f01e84d1a 100644 --- a/tests/test_server_auth.py +++ b/tests/test_server_auth.py @@ -20,15 +20,19 @@ import os import pathlib import signal +import ssl import tempfile import unittest import urllib.error import urllib.request +import asyncpg import jwcrypto.jwk import edgedb +from edb import errors +from edb import protocol from edb.common import secretkey from edb.server import args from edb.server import cluster as edbcluster @@ -303,22 +307,39 @@ async def _basic_http_request( resp_status = resp.status return resp_body, resp_status - async def _jwt_http_request( - self, server, sk, username='edgedb', db='edgedb', proto='edgeql' + async def _http_request( + self, + server, + sk=None, + username='edgedb', + db='edgedb', + proto='edgeql', + client_cert_file=None, + client_key_file=None, ): - with self.http_con(server, keep_alive=False) as con: + with self.http_con( + server, + keep_alive=False, + client_cert_file=client_cert_file, + client_key_file=client_key_file, + ) as con: + headers = {'X-EdgeDB-User': username} + if sk is not None: + headers['Authorization'] = f'bearer {sk}' return self.http_con_request( con, path=f'/db/{db}/{proto}', # ... the graphql ones will produce an error, but that's # still a 200 params=dict(query='select 1'), - headers={ - 'Authorization': f'bearer {sk}', - 'X-EdgeDB-User': username, - }, + headers=headers, ) + async def _jwt_http_request( + self, server, sk, username='edgedb', db='edgedb', proto='edgeql' + ): + return await self._http_request(server, sk, username, db, proto) + def _jwt_gql_request(self, server, sk): return self._jwt_http_request(server, sk, proto='graphql') @@ -560,3 +581,139 @@ async def test_server_auth_in_transaction(self): await self.con.query(''' DROP ROLE foo; ''') + + @unittest.skipIf( + "EDGEDB_SERVER_MULTITENANT_CONFIG_FILE" in os.environ, + "cannot use CONFIGURE INSTANCE in multi-tenant mode", + ) + async def test_server_auth_mtls(self): + if not self.has_create_role: + self.skipTest('create role is not supported by the backend') + + certs = pathlib.Path(__file__).parent / 'certs' + client_ca_cert_file = certs / 'client_ca.cert.pem' + client_ssl_cert_file = certs / 'client.cert.pem' + client_ssl_key_file = certs / 'client.key.pem' + async with tb.start_edgedb_server( + tls_client_ca_file=client_ca_cert_file, + security=args.ServerSecurityMode.Strict, + ) as sd: + # Setup mTLS and extensions + conn = await sd.connect() + try: + await conn.query("CREATE SUPERUSER ROLE ssl_user;") + await conn.query("CREATE EXTENSION edgeql_http;") + await self._test_mtls( + sd, client_ssl_cert_file, client_ssl_key_file, False) + await conn.query(""" + CONFIGURE INSTANCE INSERT Auth { + comment := 'test', + priority := 0, + method := (INSERT mTLS { + transports := { + cfg::ConnectionTransport.TCP, + cfg::ConnectionTransport.TCP_PG, + cfg::ConnectionTransport.HTTP, + cfg::ConnectionTransport.SIMPLE_HTTP, + }, + }), + } + """) + await self._test_mtls( + sd, client_ssl_cert_file, client_ssl_key_file, True) + finally: + await conn.aclose() + + async def _test_mtls( + self, sd, client_ssl_cert_file, client_ssl_key_file, granted + ): + # Verifies mTLS authentication on edgeql_http + if granted: + body, _, code = await self._http_request(sd, username="ssl_user") + self.assertEqual(code, 401, f"Wrong result: {body}") + body, _, code = await self._http_request( + sd, + username="ssl_user", + client_cert_file=client_ssl_cert_file, + client_key_file=client_ssl_key_file, + ) + if granted: + self.assertEqual(code, 200, f"Wrong result: {body}") + else: + self.assertEqual(code, 401, f"Wrong result: {body}") + + # Verifies mTLS authentication on the binary protocol + if granted: + with self.assertRaisesRegex( + edgedb.AuthenticationError, + 'client certificate required', + ): + await sd.connect() + # FIXME: add mTLS support in edgedb-python + + # Verifies mTLS authentication on binary protocol over HTTP + if granted: + with self.http_con( + sd, + keep_alive=False, + ) as con: + msgs, _, status = self.http_con_binary_request( + con, "select 42", user="ssl_user") + self.assertEqual(status, 200) + self.assertIsInstance(msgs[0], protocol.ErrorResponse) + self.assertEqual( + msgs[0].error_code, errors.AuthenticationError.get_code()) + with self.http_con( + sd, + keep_alive=False, + client_cert_file=client_ssl_cert_file, + client_key_file=client_ssl_key_file, + ) as con: + msgs, _, status = self.http_con_binary_request( + con, "select 42", user="ssl_user") + if granted: + self.assertEqual(status, 200) + self.assertIsInstance(msgs[0], protocol.CommandDataDescription) + self.assertIsInstance(msgs[1], protocol.Data) + self.assertEqual(bytes(msgs[1].data[0].data), b"42") + self.assertIsInstance(msgs[2], protocol.CommandComplete) + self.assertEqual(msgs[2].status, "SELECT") + self.assertIsInstance(msgs[3], protocol.ReadyForCommand) + else: + self.assertEqual(status, 200) + self.assertIsInstance(msgs[0], protocol.ErrorResponse) + self.assertEqual( + msgs[0].error_code, errors.AuthenticationError.get_code()) + + # Verifies mTLS authentication on emulated Postgres protocol + conargs = sd.get_connect_args() + tls_context = ssl.create_default_context( + ssl.Purpose.SERVER_AUTH, + cafile=conargs["tls_ca_file"], + ) + tls_context.check_hostname = False + conargs = dict( + host=conargs['host'], + port=conargs['port'], + user="ssl_user", + database=conargs.get('database', 'main'), + ssl=tls_context, + ) + if granted: + with self.assertRaisesRegex( + asyncpg.InvalidAuthorizationSpecificationError, + 'client certificate required', + ): + await asyncpg.connect(**conargs) + tls_context.load_cert_chain( + client_ssl_cert_file, client_ssl_key_file) + if granted: + conn = await asyncpg.connect(**conargs) + self.assertEqual(await conn.fetchval("select 42"), 42) + await conn.close() + else: + with self.assertRaisesRegex( + asyncpg.InvalidAuthorizationSpecificationError, + 'authentication failed', + ): + await asyncpg.connect(**conargs) diff --git a/tests/test_server_ops.py b/tests/test_server_ops.py index 3208ac6d128..1934ee43a28 100644 --- a/tests/test_server_ops.py +++ b/tests/test_server_ops.py @@ -37,6 +37,7 @@ import uuid import edgedb +import httpx from edgedb import errors from edb import protocol @@ -783,6 +784,29 @@ async def test_server_ops_cleartext_http_allowed(self): finally: await con.aclose() + async def test_server_ops_mtls_http_transports(self): + certs = pathlib.Path(__file__).parent / 'certs' + client_ca_cert_file = certs / 'client_ca.cert.pem' + client_ssl_cert_file = certs / 'client.cert.pem' + client_ssl_key_file = certs / 'client.key.pem' + async with tb.start_edgedb_server( + tls_client_ca_file=client_ca_cert_file, + security=args.ServerSecurityMode.Strict, + ) as sd: + def test(url): + resp = httpx.get(url, verify=sd.tls_cert_file) + self.assertFalse(resp.is_success) + + resp = httpx.get( + url, + verify=sd.tls_cert_file, + cert=(str(client_ssl_cert_file), str(client_ssl_key_file)), + ) + self.assertTrue(resp.is_success) + + test(f'https://{sd.host}:{sd.port}/metrics') + test(f'https://{sd.host}:{sd.port}/server/status/alive') + @unittest.skipIf( "EDGEDB_SERVER_MULTITENANT_CONFIG_FILE" in os.environ, "--readiness-state-file is not allowed in multi-tenant mode",