diff --git a/src/py/flwr/common/args.py b/src/py/flwr/common/args.py index 241aed935a2d..37935b127156 100644 --- a/src/py/flwr/common/args.py +++ b/src/py/flwr/common/args.py @@ -20,13 +20,9 @@ from logging import DEBUG, ERROR, WARN from os.path import isfile from pathlib import Path -from typing import Optional +from typing import Optional, Union -from flwr.common.constant import ( - TRANSPORT_TYPE_GRPC_ADAPTER, - TRANSPORT_TYPE_GRPC_RERE, - TRANSPORT_TYPE_REST, -) +from flwr.common.constant import TRANSPORT_TYPE_REST from flwr.common.logger import log @@ -55,9 +51,9 @@ def add_args_flwr_app_common(parser: argparse.ArgumentParser) -> None: def try_obtain_root_certificates( args: argparse.Namespace, grpc_server_address: str, -) -> Optional[bytes]: +) -> Optional[Union[bytes, str]]: """Validate and return the root certificates.""" - root_cert_path = args.root_certificates + root_cert_path: Optional[str] = args.root_certificates if args.insecure: if root_cert_path is not None: sys.exit( @@ -93,56 +89,38 @@ def try_obtain_root_certificates( grpc_server_address, root_cert_path, ) + if args.transport == TRANSPORT_TYPE_REST: + return root_cert_path return root_certificates def try_obtain_server_certificates( args: argparse.Namespace, - transport_type: str, ) -> Optional[tuple[bytes, bytes, bytes]]: """Validate and return the CA cert, server cert, and server private key.""" if args.insecure: log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.") return None # Check if certificates are provided - if transport_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]: - if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile: - if not isfile(args.ssl_ca_certfile): - sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.") - if not isfile(args.ssl_certfile): - sys.exit("Path argument `--ssl-certfile` does not point to a file.") - if not isfile(args.ssl_keyfile): - sys.exit("Path argument `--ssl-keyfile` does not point to a file.") - certificates = ( - Path(args.ssl_ca_certfile).read_bytes(), # CA certificate - Path(args.ssl_certfile).read_bytes(), # server certificate - Path(args.ssl_keyfile).read_bytes(), # server private key - ) - return certificates - if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile: - sys.exit( - "You need to provide valid file paths to `--ssl-certfile`, " - "`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure " - "connection in Fleet API server (gRPC-rere)." - ) - if transport_type == TRANSPORT_TYPE_REST: - if args.ssl_certfile and args.ssl_keyfile: - if not isfile(args.ssl_certfile): - sys.exit("Path argument `--ssl-certfile` does not point to a file.") - if not isfile(args.ssl_keyfile): - sys.exit("Path argument `--ssl-keyfile` does not point to a file.") - certificates = ( - b"", - Path(args.ssl_certfile).read_bytes(), # server certificate - Path(args.ssl_keyfile).read_bytes(), # server private key - ) - return certificates - if args.ssl_certfile or args.ssl_keyfile: - sys.exit( - "You need to provide valid file paths to `--ssl-certfile` " - "and `--ssl-keyfile` to create a secure connection " - "in Fleet API server (REST, experimental)." - ) + if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile: + if not isfile(args.ssl_ca_certfile): + sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.") + if not isfile(args.ssl_certfile): + sys.exit("Path argument `--ssl-certfile` does not point to a file.") + if not isfile(args.ssl_keyfile): + sys.exit("Path argument `--ssl-keyfile` does not point to a file.") + certificates = ( + Path(args.ssl_ca_certfile).read_bytes(), # CA certificate + Path(args.ssl_certfile).read_bytes(), # server certificate + Path(args.ssl_keyfile).read_bytes(), # server private key + ) + return certificates + if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile: + sys.exit( + "You need to provide valid file paths to `--ssl-certfile`, " + "`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure " + "connection in Fleet API server (gRPC-rere)." + ) log( ERROR, "Certificates are required unless running in insecure mode. " diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 8cdadc5e3ac0..079e4b8af80d 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -265,7 +265,7 @@ def run_superlink() -> None: simulationio_address, _, _ = _format_address(args.simulationio_api_address) # Obtain certificates - certificates = try_obtain_server_certificates(args, args.fleet_api_type) + certificates = try_obtain_server_certificates(args) # Disable the user auth TLS check if args.disable_oidc_tls_cert_verification is # provided @@ -353,17 +353,13 @@ def run_superlink() -> None: ) is None: flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST) - _, ssl_certfile, ssl_keyfile = ( - certificates if certificates is not None else (None, None, None) - ) - fleet_thread = threading.Thread( target=_run_fleet_api_rest, args=( host, port, - ssl_keyfile, - ssl_certfile, + args.ssl_keyfile, + args.ssl_certfile, state_factory, ffs_factory, num_workers,