diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 5bd5fc09a1c6..6e2671482a37 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -40,6 +40,9 @@ from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + generate_key_pairs, +) from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto from flwr.common.typing import Fab, Run, RunNotRunningException from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 @@ -130,12 +133,14 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917 if isinstance(root_certificates, str): root_certificates = Path(root_certificates).read_bytes() - interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None - if authentication_keys is not None: - interceptors = AuthenticateClientInterceptor( - authentication_keys[0], authentication_keys[1] - ) + # Automatic node auth: generate keys if user didn't provide any + if authentication_keys is None: + authentication_keys = generate_key_pairs() + # Always configure auth interceptor, with either user-provided or generated keys + interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [ + AuthenticateClientInterceptor(*authentication_keys), + ] channel = create_channel( server_address=server_address, insecure=insecure, diff --git a/src/py/flwr/common/grpc.py b/src/py/flwr/common/grpc.py index 8c2801b818b1..8ffa3ffeefa1 100644 --- a/src/py/flwr/common/grpc.py +++ b/src/py/flwr/common/grpc.py @@ -80,7 +80,7 @@ def create_channel( log(DEBUG, "Opened secure gRPC connection using certificates") if interceptors is not None: - channel = grpc.intercept_channel(channel, interceptors) + channel = grpc.intercept_channel(channel, *interceptors) return channel diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 94f656a7083b..8cdadc5e3ac0 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -374,8 +374,9 @@ def run_superlink() -> None: bckg_threads.append(fleet_thread) elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: node_public_keys = _try_load_public_keys_node_authentication(args) - interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None + auto_auth = True if node_public_keys is not None: + auto_auth = False state = state_factory.state() state.clear_supernode_auth_keys() state.store_node_public_keys(node_public_keys) @@ -384,7 +385,10 @@ def run_superlink() -> None: "Node authentication enabled with %d known public keys", len(node_public_keys), ) - interceptors = [AuthenticateServerInterceptor(state_factory)] + else: + log(DEBUG, "Automatic node authentication enabled") + + interceptors = [AuthenticateServerInterceptor(state_factory, auto_auth)] fleet_server = _run_fleet_api_grpc_rere( address=fleet_address, diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 2197ee266ac9..2e60a8e0220c 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -106,6 +106,8 @@ def intercept_service( # Continue the RPC call expected_node_id = state.get_node_id(node_pk_bytes) if not handler_call_details.method.endswith("CreateNode"): + # All calls, except for `CreateNode`, must provide a public key that is + # already mapped to a `node_id` (in `LinkState`) if expected_node_id is None: return _unary_unary_rpc_terminator("Invalid node ID") # One of the method handlers in