Skip to content

Commit

Permalink
merge w/ main
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Apr 2, 2024
2 parents ab85130 + 0173567 commit 219e610
Show file tree
Hide file tree
Showing 19 changed files with 425 additions and 111 deletions.
1 change: 1 addition & 0 deletions baselines/flwr_baselines/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ wget = "^3.2"
virtualenv = "^20.24.6"
pandas = "^1.5.3"
pyhamcrest = "^2.0.4"
pillow = "==10.2.0"

[tool.poetry.dev-dependencies]
isort = "==5.13.2"
Expand Down
2 changes: 1 addition & 1 deletion src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ service Fleet {
}

// CreateNode messages
message CreateNodeRequest {}
message CreateNodeRequest { double ping_interval = 1; }
message CreateNodeResponse { Node node = 1; }

// DeleteNode messages
Expand Down
6 changes: 2 additions & 4 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
TRANSPORT_TYPES,
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.object_ref import load_app, validate
from flwr.common.retry_invoker import RetryInvoker, exponential
Expand Down Expand Up @@ -385,8 +385,6 @@ def _load_client_app() -> ClientApp:
return ClientApp(client_fn=client_fn)

load_client_app_fn = _load_client_app
else:
warn_experimental_feature("`load_client_app_fn`")

# At this point, only `load_client_app_fn` should be used
# Both `client` and `client_fn` must not be used directly
Expand All @@ -397,7 +395,7 @@ def _load_client_app() -> ClientApp:
)

retry_invoker = RetryInvoker(
wait_factory=exponential,
wait_gen_factory=exponential,
recoverable_exceptions=connection_error_type,
max_tries=max_retries,
max_time=max_wait_time,
Expand Down
7 changes: 7 additions & 0 deletions src/py/flwr/client/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flwr.client.mod.utils import make_ffn
from flwr.client.typing import ClientFn, Mod
from flwr.common import Context, Message, MessageType
from flwr.common.logger import warn_preview_feature

from .typing import ClientAppCallable

Expand Down Expand Up @@ -123,6 +124,8 @@ def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable:
if self._call:
raise _registration_error(MessageType.TRAIN)

warn_preview_feature("ClientApp-register-train-function")

# Register provided function with the ClientApp object
# Wrap mods around the wrapped step function
self._train = make_ffn(train_fn, self._mods)
Expand Down Expand Up @@ -151,6 +154,8 @@ def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable:
if self._call:
raise _registration_error(MessageType.EVALUATE)

warn_preview_feature("ClientApp-register-evaluate-function")

# Register provided function with the ClientApp object
# Wrap mods around the wrapped step function
self._evaluate = make_ffn(evaluate_fn, self._mods)
Expand Down Expand Up @@ -179,6 +184,8 @@ def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable:
if self._call:
raise _registration_error(MessageType.QUERY)

warn_preview_feature("ClientApp-register-query-function")

# Register provided function with the ClientApp object
# Wrap mods around the wrapped step function
self._query = make_ffn(query_fn, self._mods)
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def run_client() -> int:
server_address=f"[::]:{port}",
insecure=True,
retry_invoker=RetryInvoker(
wait_factory=exponential,
wait_gen_factory=exponential,
recoverable_exceptions=grpc.RpcError,
max_tries=1,
max_time=None,
Expand Down
100 changes: 71 additions & 29 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,49 @@
"""Contextmanager for a gRPC request-response channel to the Flower server."""


import random
import threading
from contextlib import contextmanager
from copy import copy
from logging import DEBUG, ERROR
from pathlib import Path
from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast
from typing import Callable, Iterator, Optional, Tuple, Union, cast

from flwr.client.heartbeat import start_ping_loop
from flwr.client.message_handler.message_handler import validate_out_message
from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.constant import (
PING_BASE_MULTIPLIER,
PING_CALL_TIMEOUT,
PING_DEFAULT_INTERVAL,
PING_RANDOM_RANGE,
)
from flwr.common.grpc import create_channel
from flwr.common.logger import log, warn_experimental_feature
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
PingRequest,
PingResponse,
PullTaskInsRequest,
PushTaskResRequest,
)
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611

KEY_NODE = "node"
KEY_METADATA = "in_message_metadata"


def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
log(DEBUG, channel_connectivity)


@contextmanager
def grpc_request_response(
def grpc_request_response( # pylint: disable=R0914, R0915
server_address: str,
insecure: bool,
retry_invoker: RetryInvoker,
Expand Down Expand Up @@ -95,8 +103,6 @@ def grpc_request_response(
create_node : Optional[Callable]
delete_node : Optional[Callable]
"""
warn_experimental_feature("`grpc-rere`")

if isinstance(root_certificates, str):
root_certificates = Path(root_certificates).read_bytes()

Expand All @@ -107,47 +113,81 @@ def grpc_request_response(
max_message_length=max_message_length,
)
channel.subscribe(on_channel_state_change)
stub = FleetStub(channel)

# Necessary state to validate messages to be sent
state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None}

# Enable create_node and delete_node to store node
node_store: Dict[str, Optional[Node]] = {KEY_NODE: None}
# Shared variables for inner functions
stub = FleetStub(channel)
metadata: Optional[Metadata] = None
node: Optional[Node] = None
ping_thread: Optional[threading.Thread] = None
ping_stop_event = threading.Event()

###########################################################################
# receive/send functions
# ping/create_node/delete_node/receive/send functions
###########################################################################

def ping() -> None:
# Get Node
if node is None:
log(ERROR, "Node instance missing")
return

# Construct the ping request
req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)

# Call FleetAPI
res: PingResponse = stub.Ping(req, timeout=PING_CALL_TIMEOUT)

# Check if success
if not res.success:
raise RuntimeError("Ping failed unexpectedly.")

# Wait
rd = random.uniform(*PING_RANDOM_RANGE)
next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
next_interval *= PING_BASE_MULTIPLIER + rd
if not ping_stop_event.is_set():
ping_stop_event.wait(next_interval)

def create_node() -> None:
"""Set create_node."""
create_node_request = CreateNodeRequest()
# Call FleetAPI
create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
create_node_response = retry_invoker.invoke(
stub.CreateNode,
request=create_node_request,
)
node_store[KEY_NODE] = create_node_response.node

# Remember the node and the ping-loop thread
nonlocal node, ping_thread
node = cast(Node, create_node_response.node)
ping_thread = start_ping_loop(ping, ping_stop_event)

def delete_node() -> None:
"""Set delete_node."""
# Get Node
if node_store[KEY_NODE] is None:
nonlocal node
if node is None:
log(ERROR, "Node instance missing")
return
node: Node = cast(Node, node_store[KEY_NODE])

# Stop the ping-loop thread
ping_stop_event.set()
if ping_thread is not None:
ping_thread.join()

# Call FleetAPI
delete_node_request = DeleteNodeRequest(node=node)
retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)

del node_store[KEY_NODE]
# Cleanup
node = None

def receive() -> Optional[Message]:
"""Receive next task from server."""
# Get Node
if node_store[KEY_NODE] is None:
if node is None:
log(ERROR, "Node instance missing")
return None
node: Node = cast(Node, node_store[KEY_NODE])

# Request instructions (task) from server
request = PullTaskInsRequest(node=node)
Expand All @@ -167,26 +207,27 @@ def receive() -> Optional[Message]:
in_message = message_from_taskins(task_ins) if task_ins else None

# Remember `metadata` of the in message
state[KEY_METADATA] = copy(in_message.metadata) if in_message else None
nonlocal metadata
metadata = copy(in_message.metadata) if in_message else None

# Return the message if available
return in_message

def send(message: Message) -> None:
"""Send task result back to server."""
# Get Node
if node_store[KEY_NODE] is None:
if node is None:
log(ERROR, "Node instance missing")
return

# Get incoming message
in_metadata = state[KEY_METADATA]
if in_metadata is None:
# Get the metadata of the incoming message
nonlocal metadata
if metadata is None:
log(ERROR, "No current message")
return

# Validate out message
if not validate_out_message(message, in_metadata):
if not validate_out_message(message, metadata):
log(ERROR, "Invalid out message")
return

Expand All @@ -197,7 +238,8 @@ def send(message: Message) -> None:
request = PushTaskResRequest(task_res_list=[task_res])
_ = retry_invoker.invoke(stub.PushTaskRes, request)

state[KEY_METADATA] = None
# Cleanup
metadata = None

try:
# Yield methods
Expand Down
72 changes: 72 additions & 0 deletions src/py/flwr/client/heartbeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Heartbeat utility functions."""


import threading
from typing import Callable

import grpc

from flwr.common.constant import PING_CALL_TIMEOUT
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential


def _ping_loop(ping_fn: Callable[[], None], stop_event: threading.Event) -> None:
    """Call `ping_fn` over and over until `stop_event` is set.

    Every call is routed through a `RetryInvoker` that uses the `exponential`
    wait generator, treats `grpc.RpcError` as recoverable, and has no limit
    on the number of tries or on the elapsed time.

    Parameters
    ----------
    ping_fn : Callable[[], None]
        The function that performs a single ping.
    stop_event : threading.Event
        Once set, no further ping is issued and the loop exits.
    """

    def interruptible_wait(seconds: float) -> None:
        # Wait on the event itself so that setting it ends the wait early.
        if stop_event.is_set():
            return
        stop_event.wait(seconds)

    def shorten_wait_on_timeout(retry_state: RetryState) -> None:
        exc = retry_state.exception
        if not isinstance(exc, grpc.RpcError):
            return
        # Only a triggered ping call timeout changes the wait time.
        timed_out = exc.code() == grpc.StatusCode.DEADLINE_EXCEEDED
        if not timed_out or retry_state.actual_wait is None:
            return
        # Avoid a long wait: deduct the call timeout, never going below zero.
        retry_state.actual_wait = max(
            retry_state.actual_wait - PING_CALL_TIMEOUT, 0.0
        )

    def ping_unless_stopped() -> None:
        if stop_event.is_set():
            return
        ping_fn()

    invoker = RetryInvoker(
        wait_gen_factory=exponential,
        recoverable_exceptions=grpc.RpcError,
        max_tries=None,
        max_time=None,
        on_backoff=shorten_wait_on_timeout,
        wait_function=interruptible_wait,
    )
    while not stop_event.is_set():
        invoker.invoke(ping_unless_stopped)


def start_ping_loop(
    ping_fn: Callable[[], None], stop_event: threading.Event
) -> threading.Thread:
    """Launch the ping loop on a new thread and hand that thread back.

    The thread is started before this function returns. The loop it runs can
    be terminated by setting the provided stop event.

    Parameters
    ----------
    ping_fn : Callable[[], None]
        The function passed to the ping loop to perform each ping.
    stop_event : threading.Event
        The event passed to the ping loop; setting it terminates the loop.

    Returns
    -------
    threading.Thread
        The already-started thread running the ping loop.
    """
    worker = threading.Thread(target=_ping_loop, args=(ping_fn, stop_event))
    worker.start()
    return worker
Loading

0 comments on commit 219e610

Please sign in to comment.