From 9842e41615a6ae8fb47c3aac78473bf31d8cb368 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 1 Apr 2024 18:03:22 +0100 Subject: [PATCH] Send ping from SuperNode (#3181) --- .../client/grpc_rere_client/connection.py | 94 +++++++++---- src/py/flwr/client/heartbeat.py | 72 ++++++++++ src/py/flwr/client/heartbeat_test.py | 59 ++++++++ src/py/flwr/client/rest_client/connection.py | 128 ++++++++++++++---- src/py/flwr/common/constant.py | 6 + .../fleet/message_handler/message_handler.py | 3 +- .../superlink/fleet/rest_rere/rest_api.py | 28 ++++ 7 files changed, 337 insertions(+), 53 deletions(-) create mode 100644 src/py/flwr/client/heartbeat.py create mode 100644 src/py/flwr/client/heartbeat_test.py diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index e6e22998b947..06573ffaafb7 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -15,15 +15,24 @@ """Contextmanager for a gRPC request-response channel to the Flower server.""" +import random +import threading from contextlib import contextmanager from copy import copy from logging import DEBUG, ERROR from pathlib import Path -from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast +from typing import Callable, Iterator, Optional, Tuple, Union, cast +from flwr.client.heartbeat import start_ping_loop from flwr.client.message_handler.message_handler import validate_out_message from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common.constant import ( + PING_BASE_MULTIPLIER, + PING_CALL_TIMEOUT, + PING_DEFAULT_INTERVAL, + PING_RANDOM_RANGE, +) from flwr.common.grpc import create_channel from flwr.common.logger import log, warn_experimental_feature from flwr.common.message import Message, Metadata @@ -32,6 +41,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, + PingRequest, + PingResponse, PullTaskInsRequest, PushTaskResRequest, ) @@ -39,9 +50,6 @@ from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 -KEY_NODE = "node" -KEY_METADATA = "in_message_metadata" - def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" @@ -49,7 +57,7 @@ def on_channel_state_change(channel_connectivity: str) -> None: @contextmanager -def grpc_request_response( +def grpc_request_response( # pylint: disable=R0914, R0915 server_address: str, insecure: bool, retry_invoker: RetryInvoker, @@ -107,47 +115,81 @@ def grpc_request_response( max_message_length=max_message_length, ) channel.subscribe(on_channel_state_change) - stub = FleetStub(channel) - - # Necessary state to validate messages to be sent - state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None} - # Enable create_node and delete_node to store node - node_store: Dict[str, Optional[Node]] = {KEY_NODE: None} + # Shared variables for inner functions + stub = FleetStub(channel) + metadata: Optional[Metadata] = None + node: Optional[Node] = None + ping_thread: Optional[threading.Thread] = None + ping_stop_event = threading.Event() ########################################################################### - # receive/send functions + # ping/create_node/delete_node/receive/send functions ########################################################################### + def ping() -> None: + # Get Node + if node is None: + log(ERROR, "Node instance missing") + return + + # Construct the ping request + req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL) + + # Call FleetAPI + res: PingResponse = stub.Ping(req, timeout=PING_CALL_TIMEOUT) + + # Check if success + if not res.success: + raise RuntimeError("Ping failed unexpectedly.") + + # Wait + rd = random.uniform(*PING_RANDOM_RANGE) + next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT + next_interval *= PING_BASE_MULTIPLIER + rd + if not ping_stop_event.is_set(): + ping_stop_event.wait(next_interval) + def create_node() -> None: """Set create_node.""" + # Call FleetAPI create_node_request = CreateNodeRequest() create_node_response = retry_invoker.invoke( stub.CreateNode, request=create_node_request, ) - node_store[KEY_NODE] = create_node_response.node + + # Remember the node and the ping-loop thread + nonlocal node, ping_thread + node = cast(Node, create_node_response.node) + ping_thread = start_ping_loop(ping, ping_stop_event) def delete_node() -> None: """Set delete_node.""" # Get Node - if node_store[KEY_NODE] is None: + nonlocal node + if node is None: log(ERROR, "Node instance missing") return - node: Node = cast(Node, node_store[KEY_NODE]) + # Stop the ping-loop thread + ping_stop_event.set() + if ping_thread is not None: + ping_thread.join() + + # Call FleetAPI delete_node_request = DeleteNodeRequest(node=node) retry_invoker.invoke(stub.DeleteNode, request=delete_node_request) - del node_store[KEY_NODE] + # Cleanup + node = None def receive() -> Optional[Message]: """Receive next task from server.""" # Get Node - if node_store[KEY_NODE] is None: + if node is None: log(ERROR, "Node instance missing") return None - node: Node = cast(Node, node_store[KEY_NODE]) # Request instructions (task) from server request = PullTaskInsRequest(node=node) @@ -167,7 +209,8 @@ def receive() -> Optional[Message]: in_message = message_from_taskins(task_ins) if task_ins else None # Remember `metadata` of the in message - state[KEY_METADATA] = copy(in_message.metadata) if in_message else None + nonlocal metadata + metadata = copy(in_message.metadata) if in_message else None # Return the message if available return in_message @@ -175,18 +218,18 @@ def receive() -> Optional[Message]: def send(message: Message) -> None: """Send task result back to server.""" # Get Node - if node_store[KEY_NODE] is None: + if node is None: log(ERROR, "Node instance missing") return - # Get incoming message - in_metadata = state[KEY_METADATA] - if in_metadata is None: + # Get the metadata of the incoming message + nonlocal metadata + if metadata is None: log(ERROR, "No current message") return # Validate out message - if not validate_out_message(message, in_metadata): + if not validate_out_message(message, metadata): log(ERROR, "Invalid out message") return @@ -197,7 +240,8 @@ def send(message: Message) -> None: request = PushTaskResRequest(task_res_list=[task_res]) _ = retry_invoker.invoke(stub.PushTaskRes, request) - state[KEY_METADATA] = None + # Cleanup + metadata = None try: # Yield methods diff --git a/src/py/flwr/client/heartbeat.py b/src/py/flwr/client/heartbeat.py new file mode 100644 index 000000000000..0cc979ddfd13 --- /dev/null +++ b/src/py/flwr/client/heartbeat.py @@ -0,0 +1,72 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Heartbeat utility functions.""" + + +import threading +from typing import Callable + +import grpc + +from flwr.common.constant import PING_CALL_TIMEOUT +from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential + + +def _ping_loop(ping_fn: Callable[[], None], stop_event: threading.Event) -> None: + def wait_fn(wait_time: float) -> None: + if not stop_event.is_set(): + stop_event.wait(wait_time) + + def on_backoff(state: RetryState) -> None: + err = state.exception + if not isinstance(err, grpc.RpcError): + return + status_code = err.code() + # If ping call timeout is triggered + if status_code == grpc.StatusCode.DEADLINE_EXCEEDED: + # Avoid long wait time. + if state.actual_wait is None: + return + state.actual_wait = max(state.actual_wait - PING_CALL_TIMEOUT, 0.0) + + def wrapped_ping() -> None: + if not stop_event.is_set(): + ping_fn() + + retrier = RetryInvoker( + exponential, + grpc.RpcError, + max_tries=None, + max_time=None, + on_backoff=on_backoff, + wait_function=wait_fn, + ) + while not stop_event.is_set(): + retrier.invoke(wrapped_ping) + + +def start_ping_loop( + ping_fn: Callable[[], None], stop_event: threading.Event +) -> threading.Thread: + """Start a ping loop in a separate thread. + + This function initializes a new thread that runs a ping loop, allowing for + asynchronous ping operations. The loop can be terminated through the provided stop + event. + """ + thread = threading.Thread(target=_ping_loop, args=(ping_fn, stop_event)) + thread.start() + + return thread diff --git a/src/py/flwr/client/heartbeat_test.py b/src/py/flwr/client/heartbeat_test.py new file mode 100644 index 000000000000..286429e075b1 --- /dev/null +++ b/src/py/flwr/client/heartbeat_test.py @@ -0,0 +1,59 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for heartbeat utility functions.""" + + +import threading +import time +import unittest +from unittest.mock import MagicMock + +from .heartbeat import start_ping_loop + + +class TestStartPingLoopWithFailures(unittest.TestCase): + """Test heartbeat utility functions.""" + + def test_ping_loop_terminates(self) -> None: + """Test if the ping loop thread terminates when flagged.""" + # Prepare + ping_fn = MagicMock() + stop_event = threading.Event() + + # Execute + thread = start_ping_loop(ping_fn, stop_event) + time.sleep(1) + stop_event.set() + thread.join(timeout=1) + + # Assert + self.assertTrue(ping_fn.called) + self.assertFalse(thread.is_alive()) + + def test_ping_loop_with_failures_terminates(self) -> None: + """Test if the ping loop thread with failures terminates when flagged.""" + # Prepare + ping_fn = MagicMock(side_effect=RuntimeError()) + stop_event = threading.Event() + + # Execute + thread = start_ping_loop(ping_fn, stop_event) + time.sleep(1) + stop_event.set() + thread.join(timeout=1) + + # Assert + self.assertTrue(ping_fn.called) + self.assertFalse(thread.is_alive()) diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index d2cc71ba3b3f..514635103f01 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -15,16 +15,25 @@ """Contextmanager for a REST request-response channel to the Flower server.""" +import random import sys +import threading from contextlib import contextmanager from copy import copy from logging import ERROR, INFO, WARN -from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast +from typing import Callable, Iterator, Optional, Tuple, Union +from flwr.client.heartbeat import start_ping_loop from flwr.client.message_handler.message_handler import validate_out_message from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins from flwr.common import GRPC_MAX_MESSAGE_LENGTH -from flwr.common.constant import MISSING_EXTRA_REST +from flwr.common.constant import ( + MISSING_EXTRA_REST, + PING_BASE_MULTIPLIER, + PING_CALL_TIMEOUT, + PING_DEFAULT_INTERVAL, + PING_RANDOM_RANGE, +) from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker @@ -33,6 +42,8 @@ CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, + PingRequest, + PingResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, @@ -47,19 +58,15 @@ sys.exit(MISSING_EXTRA_REST) -KEY_NODE = "node" -KEY_METADATA = "in_message_metadata" - - PATH_CREATE_NODE: str = "api/v0/fleet/create-node" PATH_DELETE_NODE: str = "api/v0/fleet/delete-node" PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins" PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res" +PATH_PING: str = "api/v0/fleet/ping" @contextmanager -# pylint: disable-next=too-many-statements -def http_request_response( +def http_request_response( # pylint: disable=R0914, R0915 server_address: str, insecure: bool, # pylint: disable=unused-argument retry_invoker: RetryInvoker, @@ -127,16 +134,71 @@ def http_request_response( "must be provided as a string path to the client.", ) - # Necessary state to validate messages to be sent - state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None} - - # Enable create_node and delete_node to store node - node_store: Dict[str, Optional[Node]] = {KEY_NODE: None} + # Shared variables for inner functions + metadata: Optional[Metadata] = None + node: Optional[Node] = None + ping_thread: Optional[threading.Thread] = None + ping_stop_event = threading.Event() ########################################################################### - # receive/send functions + # ping/create_node/delete_node/receive/send functions ########################################################################### + def ping() -> None: + # Get Node + if node is None: + log(ERROR, "Node instance missing") + return + + # Construct the ping request + req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL) + req_bytes: bytes = req.SerializeToString() + + # Send the request + res = requests.post( + url=f"{base_url}/{PATH_PING}", + headers={ + "Accept": "application/protobuf", + "Content-Type": "application/protobuf", + }, + data=req_bytes, + verify=verify, + timeout=PING_CALL_TIMEOUT, + ) + + # Check status code and headers + if res.status_code != 200: + return + if "content-type" not in res.headers: + log( + WARN, + "[Node] POST /%s: missing header `Content-Type`", + PATH_PULL_TASK_INS, + ) + return + if res.headers["content-type"] != "application/protobuf": + log( + WARN, + "[Node] POST /%s: header `Content-Type` has wrong value", + PATH_PULL_TASK_INS, + ) + return + + # Deserialize ProtoBuf from bytes + ping_res = PingResponse() + ping_res.ParseFromString(res.content) + + # Check if success + if not ping_res.success: + raise RuntimeError("Ping failed unexpectedly.") + + # Wait + rd = random.uniform(*PING_RANDOM_RANGE) + next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT + next_interval *= PING_BASE_MULTIPLIER + rd + if not ping_stop_event.is_set(): + ping_stop_event.wait(next_interval) + def create_node() -> None: """Set create_node.""" create_node_req_proto = CreateNodeRequest() @@ -175,15 +237,25 @@ def create_node() -> None: # Deserialize ProtoBuf from bytes create_node_response_proto = CreateNodeResponse() create_node_response_proto.ParseFromString(res.content) - # pylint: disable-next=no-member - node_store[KEY_NODE] = create_node_response_proto.node + + # Remember the node and the ping-loop thread + nonlocal node, ping_thread + node = create_node_response_proto.node + ping_thread = start_ping_loop(ping, ping_stop_event) def delete_node() -> None: """Set delete_node.""" - if node_store[KEY_NODE] is None: + nonlocal node + if node is None: log(ERROR, "Node instance missing") return - node: Node = cast(Node, node_store[KEY_NODE]) + + # Stop the ping-loop thread + ping_stop_event.set() + if ping_thread is not None: + ping_thread.join() + + # Send DeleteNode request delete_node_req_proto = DeleteNodeRequest(node=node) delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString() res = retry_invoker.invoke( @@ -215,13 +287,15 @@ def delete_node() -> None: PATH_PULL_TASK_INS, ) + # Cleanup + node = None + def receive() -> Optional[Message]: """Receive next task from server.""" # Get Node - if node_store[KEY_NODE] is None: + if node is None: log(ERROR, "Node instance missing") return None - node: Node = cast(Node, node_store[KEY_NODE]) # Request instructions (task) from server pull_task_ins_req_proto = PullTaskInsRequest(node=node) @@ -273,29 +347,29 @@ def receive() -> Optional[Message]: task_ins = None # Return the Message if available + nonlocal metadata message = None - state[KEY_METADATA] = None if task_ins is not None: message = message_from_taskins(task_ins) - state[KEY_METADATA] = copy(message.metadata) + metadata = copy(message.metadata) log(INFO, "[Node] POST /%s: success", PATH_PULL_TASK_INS) return message def send(message: Message) -> None: """Send task result back to server.""" # Get Node - if node_store[KEY_NODE] is None: + if node is None: log(ERROR, "Node instance missing") return # Get incoming message - in_metadata = state[KEY_METADATA] - if in_metadata is None: + nonlocal metadata + if metadata is None: log(ERROR, "No current message") return # Validate out message - if not validate_out_message(message, in_metadata): + if not validate_out_message(message, metadata): log(ERROR, "Invalid out message") return @@ -321,7 +395,7 @@ def send(message: Message) -> None: timeout=None, ) - state[KEY_METADATA] = None + metadata = None # Check status code and headers if res.status_code != 200: diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 7d30a10f5881..99ba2d1d1c63 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -36,6 +36,12 @@ TRANSPORT_TYPE_VCE, ] +# Constants for ping +PING_DEFAULT_INTERVAL = 30 +PING_CALL_TIMEOUT = 5 +PING_BASE_MULTIPLIER = 0.8 +PING_RANDOM_RANGE = (-0.1, 0.1) + class MessageType: """Message type.""" diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index d4e63a8f2d46..9fa7656198e5 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -63,7 +63,8 @@ def ping( state: State, # pylint: disable=unused-argument ) -> PingResponse: """.""" - return PingResponse(success=True) + res = state.acknowledge_ping(request.node.node_id, request.ping_interval) + return PingResponse(success=res) def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse: diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index b022b34c68c8..33d17ef1d579 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -21,6 +21,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, + PingRequest, PullTaskInsRequest, PushTaskResRequest, ) @@ -152,11 +153,38 @@ async def push_task_res(request: Request) -> Response: # Check if token is need ) +async def ping(request: Request) -> Response: + """Ping.""" + _check_headers(request.headers) + + # Get the request body as raw bytes + ping_request_bytes: bytes = await request.body() + + # Deserialize ProtoBuf + ping_request_proto = PingRequest() + ping_request_proto.ParseFromString(ping_request_bytes) + + # Get state from app + state: State = app.state.STATE_FACTORY.state() + + # Handle message + ping_response_proto = message_handler.ping(request=ping_request_proto, state=state) + + # Return serialized ProtoBuf + ping_response_bytes = ping_response_proto.SerializeToString() + return Response( + status_code=200, + content=ping_response_bytes, + headers={"Content-Type": "application/protobuf"}, + ) + + routes = [ Route("/api/v0/fleet/create-node", create_node, methods=["POST"]), Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]), Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]), Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]), + Route("/api/v0/fleet/ping", ping, methods=["POST"]), ] app: Starlette = Starlette(