Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Implement GrpcRereFleetConnection, GrpcBidiConnection, GrpcAdapterFleetConnection, and RestFleetConnection #4056

Open
wants to merge 24 commits into
base: mv-conns
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 36 additions & 75 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import signal
import sys
import time
from contextlib import AbstractContextManager
from dataclasses import dataclass
from logging import ERROR, INFO, WARN
from os import urandom
Expand All @@ -37,7 +36,7 @@
from flwr.client.clientapp.app import flwr_clientapp
from flwr.client.nodestate.nodestate_factory import NodeStateFactory
from flwr.client.typing import ClientFnExt
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, event
from flwr.common.address import parse_address
from flwr.common.constant import (
CLIENT_OCTET,
Expand All @@ -59,13 +58,16 @@
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
from flwr.common.typing import Run, RunNotRunningException, UserConfig
from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server

from .clientapp.clientappio_servicer import ClientAppInputs, ClientAppIoServicer
from .connection.grpc_adapter.connection import grpc_adapter
from .connection.grpc_bidi.connection import grpc_connection
from .connection.grpc_rere.connection import grpc_request_response
from .connection import (
FleetConnection,
GrpcAdapterFleetConnection,
GrpcBidiFleetConnection,
GrpcRereFleetConnection,
)
from .message_handler.message_handler import handle_control_message
from .numpy_client import NumPyClient
from .run_info_store import DeprecatedRunInfoStore
Expand Down Expand Up @@ -406,41 +408,24 @@ def _on_backoff(retry_state: RetryState) -> None:
root_certificates,
authentication_keys,
) as conn:
receive, send, create_node, delete_node, get_run, get_fab = conn

# Register node when connecting the first time
if run_info_store is None:
if create_node is None:
if transport not in ["grpc-bidi", None]:
raise NotImplementedError(
"All transports except `grpc-bidi` require "
"an implementation for `create_node()`.'"
)
# gRPC-bidi doesn't have the concept of node_id,
# so we set it to -1
run_info_store = DeprecatedRunInfoStore(
node_id=-1,
node_config={},
)
else:
# Call create_node fn to register node
# and store node_id in state
if (node_id := create_node()) is None:
raise ValueError(
"Failed to register SuperNode with the SuperLink"
)
state.set_node_id(node_id)
run_info_store = DeprecatedRunInfoStore(
node_id=state.get_node_id(),
node_config=node_config,
)
# Call create_node fn to register node
# and store node_id in state
if (node_id := conn.create_node()) is None:
raise ValueError("Failed to register SuperNode with the SuperLink")
state.set_node_id(node_id)
run_info_store = DeprecatedRunInfoStore(
node_id=state.get_node_id(),
node_config=node_config,
)

app_state_tracker.register_signal_handler()
# pylint: disable=too-many-nested-blocks
while not app_state_tracker.interrupt:
try:
# Receive
message = receive()
message = conn.receive()
if message is None:
time.sleep(3) # Wait for 3s before asking again
continue
Expand All @@ -463,21 +448,17 @@ def _on_backoff(retry_state: RetryState) -> None:
# Handle control message
out_message, sleep_duration = handle_control_message(message)
if out_message:
send(out_message)
conn.send(out_message)
break

# Get run info
run_id = message.metadata.run_id
if run_id not in runs:
if get_run is not None:
runs[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
runs[run_id] = Run.create_empty(run_id=run_id)
runs[run_id] = conn.get_run(run_id)

run: Run = runs[run_id]
if get_fab is not None and run.fab_hash:
fab = get_fab(run.fab_hash, run_id)
if run.fab_hash:
fab = conn.get_fab(run.fab_hash, run_id)
if not isolation:
# If `ClientApp` runs in the same process, install the FAB
install_from_fab(fab.content, flwr_path, True)
Expand Down Expand Up @@ -612,7 +593,7 @@ def _on_backoff(retry_state: RetryState) -> None:
)

# Send
send(reply_message)
conn.send(reply_message)
log(INFO, "Sent reply")

except RunNotRunningException:
Expand All @@ -631,8 +612,8 @@ def _on_backoff(retry_state: RetryState) -> None:
# pylint: enable=too-many-nested-blocks

# Unregister node
if delete_node is not None and app_state_tracker.is_connected:
delete_node() # pylint: disable=not-callable
if app_state_tracker.is_connected:
conn.delete_node()

if sleep_duration == 0:
log(INFO, "Disconnect and shut down")
Expand Down Expand Up @@ -749,30 +730,9 @@ def start_numpy_client(
)


def _init_connection(transport: Optional[str], server_address: str) -> tuple[
Callable[
[
str,
bool,
RetryInvoker,
int,
Union[bytes, str, None],
Optional[tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]],
],
AbstractContextManager[
tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
Optional[Callable[[str, int], Fab]],
]
],
],
str,
type[Exception],
]:
def _init_connection(
transport: Optional[str], server_address: str
) -> tuple[type[FleetConnection], str, type[Exception]]:
# Parse IP address
parsed_address = parse_address(server_address)
if not parsed_address:
Expand All @@ -784,26 +744,27 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
if transport is None:
transport = TRANSPORT_TYPE_GRPC_BIDI

# Use either gRPC bidirectional streaming or REST request/response
# Use grpc-rere/grpc-adapter/rest/grpc-bidi transport layer
connection: Optional[type[FleetConnection]] = None
if transport == TRANSPORT_TYPE_REST:
try:
from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.exceptions import RequestException

from .connection.rest.connection import http_request_response
from .connection import RestFleetConnection
except ModuleNotFoundError:
sys.exit(MISSING_EXTRA_REST)
if server_address[:4] != "http":
sys.exit(
"When using the REST API, please provide `https://` or "
"`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
)
connection, error_type = http_request_response, RequestsConnectionError
connection, error_type = RestFleetConnection, RequestException
elif transport == TRANSPORT_TYPE_GRPC_RERE:
connection, error_type = grpc_request_response, RpcError
connection, error_type = GrpcRereFleetConnection, RpcError
elif transport == TRANSPORT_TYPE_GRPC_ADAPTER:
connection, error_type = grpc_adapter, RpcError
connection, error_type = GrpcAdapterFleetConnection, RpcError
elif transport == TRANSPORT_TYPE_GRPC_BIDI:
connection, error_type = grpc_connection, RpcError
connection, error_type = GrpcBidiFleetConnection, RpcError
else:
raise ValueError(
f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"
Expand Down
8 changes: 8 additions & 0 deletions src/py/flwr/client/connection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@


from .fleet_connection import FleetConnection
from .grpc_adapter import GrpcAdapterFleetConnection
from .grpc_bidi import GrpcBidiFleetConnection
from .grpc_rere import GrpcRereFleetConnection
from .rest.rest_fleet_connection import RestFleetConnection

__all__ = [
"FleetConnection",
"GrpcAdapterFleetConnection",
"GrpcBidiFleetConnection",
"GrpcRereFleetConnection",
"RestFleetConnection",
]
80 changes: 80 additions & 0 deletions src/py/flwr/client/connection/fleet_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fleet API definition for the grpc-rere transport layer."""


from abc import ABC, abstractmethod
from typing import Any

from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
CreateNodeResponse,
DeleteNodeRequest,
DeleteNodeResponse,
PingRequest,
PingResponse,
PullMessagesRequest,
PullMessagesResponse,
PushMessagesRequest,
PushMessagesResponse,
)
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611


class FleetApi(ABC):
"""Fleet API that provides low-level access to Fleet API server."""

@abstractmethod
def Ping( # pylint: disable=C0103
self, request: PingRequest, **kwargs: Any
) -> PingResponse:
"""Fleet.Ping."""

@abstractmethod
def CreateNode( # pylint: disable=C0103
self, request: CreateNodeRequest, **kwargs: Any
) -> CreateNodeResponse:
"""Fleet.CreateNode."""

@abstractmethod
def DeleteNode( # pylint: disable=C0103
self, request: DeleteNodeRequest, **kwargs: Any
) -> DeleteNodeResponse:
"""Fleet.DeleteNode."""

@abstractmethod
def PullMessages( # pylint: disable=C0103
self, request: PullMessagesRequest, **kwargs: Any
) -> PullMessagesResponse:
"""Fleet.PullMessages."""

@abstractmethod
def PushMessages( # pylint: disable=C0103
self, request: PushMessagesRequest, **kwargs: Any
) -> PushMessagesResponse:
"""Fleet.PushMessages."""

@abstractmethod
def GetRun( # pylint: disable=C0103
self, request: GetRunRequest, **kwargs: Any
) -> GetRunResponse:
"""Fleet.GetRun."""

@abstractmethod
def GetFab( # pylint: disable=C0103
self, request: GetFabRequest, **kwargs: Any
) -> GetFabResponse:
"""Fleet.GetFab."""
7 changes: 7 additions & 0 deletions src/py/flwr/client/connection/grpc_adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@
# limitations under the License.
# ==============================================================================
"""Client-side part of the GrpcAdapter transport layer."""


from .grpc_adapter_fleet_connection import GrpcAdapterFleetConnection

__all__ = [
"GrpcAdapterFleetConnection",
]
Loading