Skip to content

Commit

Permalink
feat(framework) Enable automatic node authentication (#4867)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
Co-authored-by: jafermarq <[email protected]>
  • Loading branch information
3 people authored Jan 29, 2025
1 parent 41f9d3d commit 54a3536
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 8 deletions.
15 changes: 10 additions & 5 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
generate_key_pairs,
)
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
from flwr.common.typing import Fab, Run, RunNotRunningException
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
Expand Down Expand Up @@ -130,12 +133,14 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
if isinstance(root_certificates, str):
root_certificates = Path(root_certificates).read_bytes()

interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
if authentication_keys is not None:
interceptors = AuthenticateClientInterceptor(
authentication_keys[0], authentication_keys[1]
)
# Automatic node auth: generate keys if user didn't provide any
if authentication_keys is None:
authentication_keys = generate_key_pairs()

# Always configure auth interceptor, with either user-provided or generated keys
interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
AuthenticateClientInterceptor(*authentication_keys),
]
channel = create_channel(
server_address=server_address,
insecure=insecure,
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/common/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def create_channel(
log(DEBUG, "Opened secure gRPC connection using certificates")

if interceptors is not None:
channel = grpc.intercept_channel(channel, interceptors)
channel = grpc.intercept_channel(channel, *interceptors)

return channel

Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,9 @@ def run_superlink() -> None:
bckg_threads.append(fleet_thread)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
node_public_keys = _try_load_public_keys_node_authentication(args)
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
auto_auth = True
if node_public_keys is not None:
auto_auth = False
state = state_factory.state()
state.clear_supernode_auth_keys()
state.store_node_public_keys(node_public_keys)
Expand All @@ -384,7 +385,10 @@ def run_superlink() -> None:
"Node authentication enabled with %d known public keys",
len(node_public_keys),
)
interceptors = [AuthenticateServerInterceptor(state_factory)]
else:
log(DEBUG, "Automatic node authentication enabled")

interceptors = [AuthenticateServerInterceptor(state_factory, auto_auth)]

fleet_server = _run_fleet_api_grpc_rere(
address=fleet_address,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def intercept_service(
# Continue the RPC call
expected_node_id = state.get_node_id(node_pk_bytes)
if not handler_call_details.method.endswith("CreateNode"):
# All calls, except for `CreateNode`, must provide a public key that is
# already mapped to a `node_id` (in `LinkState`)
if expected_node_id is None:
return _unary_unary_rpc_terminator("Invalid node ID")
# One of the method handlers in
Expand Down

0 comments on commit 54a3536

Please sign in to comment.