Skip to content

Commit

Permalink
Add mTLS support (#6460)
Browse files Browse the repository at this point in the history
* Add --tls-client-ca-file
* Use --default-auth-method to configure auth for /metrics and /server
* lso fixes an ISE with `--default-auth-method tcp:auto`.
  • Loading branch information
fantix authored and msullivan committed Mar 8, 2024
1 parent 69c65cc commit 7482017
Show file tree
Hide file tree
Showing 15 changed files with 487 additions and 64 deletions.
65 changes: 61 additions & 4 deletions edb/server/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ class ServerConfig(NamedTuple):
tls_cert_file: pathlib.Path
tls_key_file: pathlib.Path
tls_cert_mode: ServerTlsCertMode
tls_client_ca_file: Optional[pathlib.Path]

jws_key_file: pathlib.Path
jose_key_mode: JOSEKeyMode
Expand Down Expand Up @@ -501,8 +502,21 @@ def _validate_default_auth_method(
if method in {ServerAuthMethod.Auto, ServerAuthMethod.Scram}:
pass
else:
for m in methods:
methods[m] = method
for t in methods:
# HTTP_METRICS and HTTP_HEALTH support only mTLS, but for
# backward compatibility, default them to `auto` if unsupported
# method is passed explicitly.
if t in (
ServerConnTransport.HTTP_METRICS,
ServerConnTransport.HTTP_HEALTH,
):
if method not in (
ServerAuthMethod.Trust,
ServerAuthMethod.mTLS,
):
continue

methods[t] = method
elif "," not in value and ":" not in value:
raise click.BadParameter(
f"invalid authentication method: {value}, "
Expand Down Expand Up @@ -815,6 +829,18 @@ def resolve_envvar_value(self, ctx: click.Context):
'"require_file" when the --security option is set to "strict", '
'and "generate_self_signed" when the --security option is set to '
'"insecure_dev_mode"'),
click.option(
'--tls-client-ca-file',
type=PathPath(),
envvar='EDGEDB_SERVER_TLS_CLIENT_CA_FILE',
help='Specifies a path to a file containing a TLS CA certificate to '
'verify client certificates on demand. When set, the default '
'authentication method of HTTP_METRICS(/metrics) and HTTP_HEALTH'
'(/server/*) will also become "mTLS", unless explicitly set in '
'--default-auth-method. Note, the protection of such HTTP '
'endpoints is only complete if --http-endpoint-security is also '
'set to `tls`, or they are still accessible in plaintext HTTP.'
),
click.option(
'--generate-self-signed-cert', type=bool, default=False, is_flag=True,
help='DEPRECATED.\n\n'
Expand Down Expand Up @@ -1174,16 +1200,47 @@ def parse_args(**kwargs: Any):
if kwargs['http_endpoint_security'] == 'default':
kwargs['http_endpoint_security'] = 'optional'
if not kwargs['default_auth_method']:
kwargs['default_auth_method'] = {
kwargs['default_auth_method'] = ServerAuthMethods({
t: ServerAuthMethod.Trust
for t in ServerConnTransport.__members__.values()
}
})
if kwargs['tls_cert_mode'] == 'default':
kwargs['tls_cert_mode'] = 'generate_self_signed'

elif not kwargs['default_auth_method']:
kwargs['default_auth_method'] = DEFAULT_AUTH_METHODS

methods = dict(kwargs['default_auth_method'].items())
for transport in ServerConnTransport.__members__.values():
method = methods[transport]
if method is ServerAuthMethod.Auto:
if transport in (
ServerConnTransport.HTTP_METRICS,
ServerConnTransport.HTTP_HEALTH,
):
if kwargs['tls_client_ca_file'] is None:
method = ServerAuthMethod.Trust
else:
method = ServerAuthMethod.mTLS
else:
method = DEFAULT_AUTH_METHODS.get(transport)
methods[transport] = method
elif transport in (
ServerConnTransport.HTTP_METRICS,
ServerConnTransport.HTTP_HEALTH,
):
if method is ServerAuthMethod.mTLS:
if kwargs['tls_client_ca_file'] is None:
abort('--tls-client-ca-file is required '
'for mTLS authentication')
elif method is not ServerAuthMethod.Trust:
abort(
f'--default-auth-method of {transport} can only be one '
f'of: {ServerAuthMethod.Trust}, {ServerAuthMethod.mTLS} '
f'or {ServerAuthMethod.Auto}'
)
kwargs['default_auth_method'] = ServerAuthMethods(methods)

if kwargs['binary_endpoint_security'] == 'default':
kwargs['binary_endpoint_security'] = 'tls'

Expand Down
12 changes: 10 additions & 2 deletions edb/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,11 @@ async def _run_server(
return

ss.init_tls(
args.tls_cert_file, args.tls_key_file, tls_cert_newly_generated)
args.tls_cert_file,
args.tls_key_file,
tls_cert_newly_generated,
args.tls_client_ca_file,
)

ss.init_jwcrypto(args.jws_key_file, jws_keys_newly_generated)

Expand All @@ -290,7 +294,11 @@ def load_configuration(_signum):
try:
if args.readiness_state_file:
tenant.reload_readiness_state()
ss.reload_tls(args.tls_cert_file, args.tls_key_file)
ss.reload_tls(
args.tls_cert_file,
args.tls_key_file,
args.tls_client_ca_file,
)
ss.load_jwcrypto(args.jws_key_file)
except Exception:
logger.critical(
Expand Down
11 changes: 9 additions & 2 deletions edb/server/multitenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,10 @@ async def run_server(
tls_cert_newly_generated, jws_keys_newly_generated
) = await ss.maybe_generate_pki(args, ss)
ss.init_tls(
args.tls_cert_file, args.tls_key_file, tls_cert_newly_generated
args.tls_cert_file,
args.tls_key_file,
tls_cert_newly_generated,
args.tls_client_ca_file,
)
ss.init_jwcrypto(args.jws_key_file, jws_keys_newly_generated)

Expand All @@ -433,7 +436,11 @@ def load_configuration(_signum):

logger.info("reloading configuration")
try:
ss.reload_tls(args.tls_cert_file, args.tls_key_file)
ss.reload_tls(
args.tls_cert_file,
args.tls_key_file,
args.tls_client_ca_file,
)
ss.load_jwcrypto(args.jws_key_file)
ss.reload_tenants()
except Exception:
Expand Down
2 changes: 2 additions & 0 deletions edb/server/protocol/auth_helpers.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ cdef scram_get_verifier(tenant, str user)
cdef parse_basic_auth(str auth_payload)
cdef extract_http_user(scheme, auth_payload, params)
cdef auth_basic(tenant, str username, str password)
cdef auth_mtls(transport)
cdef auth_mtls_with_user(transport, str username)
36 changes: 33 additions & 3 deletions edb/server/protocol/auth_helpers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ import logging
from jwcrypto import jwt

from edb import errors
from edb.common import debug

from edb.server.protocol cimport args_ser


cdef object logger = logging.getLogger('edb.server')
Expand Down Expand Up @@ -235,3 +232,36 @@ cdef auth_basic(tenant, username: str, password: str):
verifier, mock_auth = scram_get_verifier(tenant, username)
if not scram_verify_password(password, verifier) or mock_auth:
raise errors.AuthenticationError('authentication failed')


cdef auth_mtls(transport):
sslobj = transport.get_extra_info('ssl_object')
if sslobj is None:
raise errors.AuthenticationError(
"mTLS authentication is not supported over plaintext transport")
cert_data = sslobj.getpeercert()
if not cert_data: # None or empty dict
# If --tls-client-ca-file is specified, the SSLContext used here would
# have done load_verify_locations() in `server/server.py`, and we will
# have a valid client certificate (non-empty dict) now if one was
# provided by the client and passed validation; empty dict otherwise.
# `None` just means the peer didn't send a client certificate.
raise errors.AuthenticationError(
"valid client certificate required")
return cert_data


cdef auth_mtls_with_user(transport, str username):
cert_data = auth_mtls(transport)
try:
for rdn in cert_data["subject"]:
if rdn[0][0] == 'commonName':
if rdn[0][1] == username:
return
except Exception as ex:
raise errors.AuthenticationError(
"bad client certificate") from ex

raise errors.AuthenticationError(
f"Common Name of client certificate doesn't match {username!r}",
)
1 change: 1 addition & 0 deletions edb/server/protocol/binary.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,4 @@ cdef class VirtualTransport:
cdef:
WriteBuffer buf
bint closed
object transport
9 changes: 7 additions & 2 deletions edb/server/protocol/binary.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1722,9 +1722,10 @@ cdef class EdgeConnection(frontend.FrontendConnection):

@cython.final
cdef class VirtualTransport:
def __init__(self):
def __init__(self, transport):
self.buf = WriteBuffer.new()
self.closed = False
self.transport = transport

def write(self, data):
self.buf.write_bytes(bytes(data))
Expand All @@ -1741,6 +1742,9 @@ cdef class VirtualTransport:
def abort(self):
self.closed = True

def get_extra_info(self, name, default=None):
return self.transport.get_extra_info(name, default)


async def eval_buffer(
server,
Expand All @@ -1751,12 +1755,13 @@ async def eval_buffer(
protocol_version: edbdef.ProtocolVersion,
auth_data: bytes,
transport: srvargs.ServerConnTransport,
tcp_transport: asyncio.Transport,
):
cdef:
VirtualTransport vtr
EdgeConnection proto

vtr = VirtualTransport()
vtr = VirtualTransport(tcp_transport)

proto = new_edge_connection(
server,
Expand Down
2 changes: 2 additions & 0 deletions edb/server/protocol/frontend.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,8 @@ cdef class FrontendConnection(AbstractFrontendConnection):
'Simple password authentication required but it is only '
'supported for HTTP endpoints'
)
elif authmethod_name == 'mTLS':
auth_helpers.auth_mtls_with_user(self._transport, user)
else:
raise errors.InternalServerError(
f'unimplemented auth method: {authmethod_name}')
Expand Down
2 changes: 2 additions & 0 deletions edb/server/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ cdef class HttpProtocol:
str message = ?)
cdef _bad_request(self, HttpRequest request, HttpResponse response,
str message)
cdef _unauthorized(self, HttpRequest request, HttpResponse response,
str message)
cdef _return_binary_error(self, binary.EdgeConnection proto)
cdef _write(self, bytes req_version, bytes resp_status,
bytes content_type, dict custom_headers, bytes body,
Expand Down
69 changes: 66 additions & 3 deletions edb/server/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ cdef class HttpProtocol:
protocol_version=proto_ver,
auth_data=self.current_request.authorization,
transport=srvargs.ServerConnTransport.HTTP,
tcp_transport=self.transport,
)
response.status = http.HTTPStatus.OK
response.content_type = PROTO_MIME
Expand Down Expand Up @@ -660,6 +661,13 @@ cdef class HttpProtocol:
self.tenant,
)
elif route == 'server':
if not await self._authenticate_for_default_conn_transport(
request,
response,
srvargs.ServerConnTransport.HTTP_HEALTH,
):
return

# System API request
await system_api.handle_request(
request,
Expand All @@ -669,6 +677,13 @@ cdef class HttpProtocol:
self.tenant,
)
elif path_parts == ['metrics'] and request.method == b'GET':
if not await self._authenticate_for_default_conn_transport(
request,
response,
srvargs.ServerConnTransport.HTTP_METRICS,
):
return

# Quoting the Open Metrics spec:
# Implementers MUST expose metrics in the OpenMetrics
# text format in response to a simple HTTP GET request
Expand Down Expand Up @@ -779,6 +794,16 @@ cdef class HttpProtocol:

return False

cdef _unauthorized(
self,
HttpRequest request,
HttpResponse response,
str message,
):
response.body = message.encode("utf-8")
response.status = http.HTTPStatus.UNAUTHORIZED
response.close_connection = True

async def _check_http_auth(
self,
HttpRequest request,
Expand Down Expand Up @@ -822,6 +847,13 @@ cdef class HttpProtocol:
'authentication failed: '
'SCRAM authentication required but not supported for HTTP'
)
elif authmethod_name == 'mTLS':
if (
self.http_endpoint_security
is srvargs.ServerEndpointSecurityMode.Tls
or self.is_tls
):
auth_helpers.auth_mtls_with_user(self.transport, username)
else:
raise errors.AuthenticationError(
'authentication failed: wrong method used')
Expand All @@ -830,9 +862,7 @@ cdef class HttpProtocol:
if debug.flags.server:
markup.dump(ex)

response.body = str(ex).encode()
response.status = http.HTTPStatus.UNAUTHORIZED
response.close_connection = True
self._unauthorized(request, response, str(ex))

# If no scheme was specified, add a WWW-Authenticate header
if scheme == '':
Expand All @@ -844,6 +874,39 @@ cdef class HttpProtocol:

return True

async def _authenticate_for_default_conn_transport(
self,
HttpRequest request,
HttpResponse response,
transport: srvargs.ServerConnTransport,
):
try:
auth_method = self.server.get_default_auth_method(transport)

# If the auth method and the provided auth information match,
# try to resolve the authentication.
if auth_method is srvargs.ServerAuthMethod.Trust:
pass
elif auth_method is srvargs.ServerAuthMethod.mTLS:
if (
self.http_endpoint_security
is srvargs.ServerEndpointSecurityMode.Tls
or self.is_tls
):
auth_helpers.auth_mtls(self.transport)
else:
raise errors.AuthenticationError(
'authentication failed: wrong method used')

except Exception as ex:
if debug.flags.server:
markup.dump(ex)

self._unauthorized(request, response, str(ex))

return False

return True

def get_request_url(request, is_tls):
request_url = request.url
Expand Down
Loading

0 comments on commit 7482017

Please sign in to comment.