Skip to content

Commit

Permalink
fix(framework) Fix issue preventing REST-based Fleet API from using S…
Browse files Browse the repository at this point in the history
…SL (#4890)
  • Loading branch information
panh99 authored Jan 31, 2025
1 parent 212768b commit 4e31542
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 54 deletions.
72 changes: 25 additions & 47 deletions src/py/flwr/common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,9 @@
from logging import DEBUG, ERROR, WARN
from os.path import isfile
from pathlib import Path
from typing import Optional
from typing import Optional, Union

from flwr.common.constant import (
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_REST,
)
from flwr.common.constant import TRANSPORT_TYPE_REST
from flwr.common.logger import log


Expand Down Expand Up @@ -55,9 +51,9 @@ def add_args_flwr_app_common(parser: argparse.ArgumentParser) -> None:
def try_obtain_root_certificates(
args: argparse.Namespace,
grpc_server_address: str,
) -> Optional[bytes]:
) -> Optional[Union[bytes, str]]:
"""Validate and return the root certificates."""
root_cert_path = args.root_certificates
root_cert_path: Optional[str] = args.root_certificates
if args.insecure:
if root_cert_path is not None:
sys.exit(
Expand Down Expand Up @@ -93,56 +89,38 @@ def try_obtain_root_certificates(
grpc_server_address,
root_cert_path,
)
if args.transport == TRANSPORT_TYPE_REST:
return root_cert_path
return root_certificates


def try_obtain_server_certificates(
args: argparse.Namespace,
transport_type: str,
) -> Optional[tuple[bytes, bytes, bytes]]:
"""Validate and return the CA cert, server cert, and server private key."""
if args.insecure:
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
return None
# Check if certificates are provided
if transport_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]:
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
if not isfile(args.ssl_ca_certfile):
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
if not isfile(args.ssl_certfile):
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
if not isfile(args.ssl_keyfile):
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
certificates = (
Path(args.ssl_ca_certfile).read_bytes(), # CA certificate
Path(args.ssl_certfile).read_bytes(), # server certificate
Path(args.ssl_keyfile).read_bytes(), # server private key
)
return certificates
if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile:
sys.exit(
"You need to provide valid file paths to `--ssl-certfile`, "
"`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure "
"connection in Fleet API server (gRPC-rere)."
)
if transport_type == TRANSPORT_TYPE_REST:
if args.ssl_certfile and args.ssl_keyfile:
if not isfile(args.ssl_certfile):
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
if not isfile(args.ssl_keyfile):
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
certificates = (
b"",
Path(args.ssl_certfile).read_bytes(), # server certificate
Path(args.ssl_keyfile).read_bytes(), # server private key
)
return certificates
if args.ssl_certfile or args.ssl_keyfile:
sys.exit(
"You need to provide valid file paths to `--ssl-certfile` "
"and `--ssl-keyfile` to create a secure connection "
"in Fleet API server (REST, experimental)."
)
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
if not isfile(args.ssl_ca_certfile):
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
if not isfile(args.ssl_certfile):
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
if not isfile(args.ssl_keyfile):
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
certificates = (
Path(args.ssl_ca_certfile).read_bytes(), # CA certificate
Path(args.ssl_certfile).read_bytes(), # server certificate
Path(args.ssl_keyfile).read_bytes(), # server private key
)
return certificates
if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile:
sys.exit(
"You need to provide valid file paths to `--ssl-certfile`, "
"`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure "
"connection in Fleet API server (gRPC-rere)."
)
log(
ERROR,
"Certificates are required unless running in insecure mode. "
Expand Down
10 changes: 3 additions & 7 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def run_superlink() -> None:
simulationio_address, _, _ = _format_address(args.simulationio_api_address)

# Obtain certificates
certificates = try_obtain_server_certificates(args, args.fleet_api_type)
certificates = try_obtain_server_certificates(args)

# Disable the user auth TLS check if args.disable_oidc_tls_cert_verification is
# provided
Expand Down Expand Up @@ -353,17 +353,13 @@ def run_superlink() -> None:
) is None:
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)

_, ssl_certfile, ssl_keyfile = (
certificates if certificates is not None else (None, None, None)
)

fleet_thread = threading.Thread(
target=_run_fleet_api_rest,
args=(
host,
port,
ssl_keyfile,
ssl_certfile,
args.ssl_keyfile,
args.ssl_certfile,
state_factory,
ffs_factory,
num_workers,
Expand Down

0 comments on commit 4e31542

Please sign in to comment.