diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 3f86dfdc01bf..5c1d1aba96f7 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -21,6 +21,11 @@ from pathlib import Path from typing import Callable, ContextManager, Optional, Tuple, Type, Union +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import ( + load_ssh_private_key, + load_ssh_public_key, +) from grpc import RpcError from flwr.client.client import Client @@ -41,6 +46,9 @@ from flwr.common.message import Error from flwr.common.object_ref import load_app, validate from flwr.common.retry_invoker import RetryInvoker, exponential +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + ssh_types_to_elliptic_curve, +) from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response @@ -114,18 +122,51 @@ def _load() -> ClientApp: return client_app + authentication_keys = _try_setup_client_authentication(args) + _start_client_internal( server_address=args.server, load_client_app_fn=_load, transport="rest" if args.rest else "grpc-rere", root_certificates=root_certificates, insecure=args.insecure, + authentication_keys=authentication_keys, max_retries=args.max_retries, max_wait_time=args.max_wait_time, ) register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE) +def _try_setup_client_authentication( + args: argparse.Namespace, +) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: + if not args.authentication_keys: + return None + + ssh_private_key = load_ssh_private_key( + Path(args.authentication_keys[0]).read_bytes(), + None, + ) + ssh_public_key = load_ssh_public_key(Path(args.authentication_keys[1]).read_bytes()) + + try: + client_private_key, client_public_key = ssh_types_to_elliptic_curve( + ssh_private_key, ssh_public_key + ) + except TypeError: + sys.exit( + "The file paths provided could not be read as a private and public " + "key pair. Client authentication requires an elliptic curve public and " + "private key pair. Please provide the file paths containing elliptic " + "curve private and public keys to '--authentication-keys'." + ) + + return ( + client_private_key, + client_public_key, + ) + + def _parse_args_run_client_app() -> argparse.ArgumentParser: """Parse flower-client-app command line arguments.""" parser = argparse.ArgumentParser( @@ -165,6 +206,9 @@ def start_client( root_certificates: Optional[Union[bytes, str]] = None, insecure: Optional[bool] = None, transport: Optional[str] = None, + authentication_keys: Optional[ + Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, ) -> None: @@ -249,6 +293,7 @@ class `flwr.client.Client` (default: None) root_certificates=root_certificates, insecure=insecure, transport=transport, + authentication_keys=authentication_keys, max_retries=max_retries, max_wait_time=max_wait_time, ) @@ -269,6 +314,9 @@ def _start_client_internal( root_certificates: Optional[Union[bytes, str]] = None, insecure: Optional[bool] = None, transport: Optional[str] = None, + authentication_keys: Optional[ + Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, ) -> None: @@ -393,6 +441,7 @@ def _load_client_app() -> ClientApp: retry_invoker, grpc_max_message_length, root_certificates, + authentication_keys, ) as conn: # pylint: disable-next=W0612 receive, send, create_node, delete_node, get_run = conn @@ -606,7 +655,14 @@ def start_numpy_client( def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Callable[ - [str, bool, RetryInvoker, int, Union[bytes, str, None]], + [ + str, + bool, + RetryInvoker, + int, + Union[bytes, str, None], + Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]], + ], ContextManager[ Tuple[ Callable[[], Optional[Message]], diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 2611b160f830..7be8e526942d 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -117,3 +117,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: "app from there." " Default: current working directory.", ) + parser.add_argument( + "--authentication-keys", + nargs=2, + metavar=("CLIENT_PRIVATE_KEY", "CLIENT_PUBLIC_KEY"), + type=str, + help="Provide two file paths: (1) the client's private " + "key file, and (2) the client's public key file.", + )