Skip to content

Commit

Permalink
Send ping from SuperNode (#3181)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Apr 1, 2024
1 parent 41b491b commit 9842e41
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 53 deletions.
94 changes: 69 additions & 25 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,24 @@
"""Contextmanager for a gRPC request-response channel to the Flower server."""


import random
import threading
from contextlib import contextmanager
from copy import copy
from logging import DEBUG, ERROR
from pathlib import Path
from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast
from typing import Callable, Iterator, Optional, Tuple, Union, cast

from flwr.client.heartbeat import start_ping_loop
from flwr.client.message_handler.message_handler import validate_out_message
from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.constant import (
PING_BASE_MULTIPLIER,
PING_CALL_TIMEOUT,
PING_DEFAULT_INTERVAL,
PING_RANDOM_RANGE,
)
from flwr.common.grpc import create_channel
from flwr.common.logger import log, warn_experimental_feature
from flwr.common.message import Message, Metadata
Expand All @@ -32,24 +41,23 @@
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
PingRequest,
PingResponse,
PullTaskInsRequest,
PushTaskResRequest,
)
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611

KEY_NODE = "node"
KEY_METADATA = "in_message_metadata"


def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
log(DEBUG, channel_connectivity)


@contextmanager
def grpc_request_response(
def grpc_request_response( # pylint: disable=R0914, R0915
server_address: str,
insecure: bool,
retry_invoker: RetryInvoker,
Expand Down Expand Up @@ -107,47 +115,81 @@ def grpc_request_response(
max_message_length=max_message_length,
)
channel.subscribe(on_channel_state_change)
stub = FleetStub(channel)

# Necessary state to validate messages to be sent
state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None}

# Enable create_node and delete_node to store node
node_store: Dict[str, Optional[Node]] = {KEY_NODE: None}
# Shared variables for inner functions
stub = FleetStub(channel)
metadata: Optional[Metadata] = None
node: Optional[Node] = None
ping_thread: Optional[threading.Thread] = None
ping_stop_event = threading.Event()

###########################################################################
# receive/send functions
# ping/create_node/delete_node/receive/send functions
###########################################################################

def ping() -> None:
# Get Node
if node is None:
log(ERROR, "Node instance missing")
return

# Construct the ping request
req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)

# Call FleetAPI
res: PingResponse = stub.Ping(req, timeout=PING_CALL_TIMEOUT)

# Check if success
if not res.success:
raise RuntimeError("Ping failed unexpectedly.")

# Wait
rd = random.uniform(*PING_RANDOM_RANGE)
next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
next_interval *= PING_BASE_MULTIPLIER + rd
if not ping_stop_event.is_set():
ping_stop_event.wait(next_interval)

def create_node() -> None:
"""Set create_node."""
# Call FleetAPI
create_node_request = CreateNodeRequest()
create_node_response = retry_invoker.invoke(
stub.CreateNode,
request=create_node_request,
)
node_store[KEY_NODE] = create_node_response.node

# Remember the node and the ping-loop thread
nonlocal node, ping_thread
node = cast(Node, create_node_response.node)
ping_thread = start_ping_loop(ping, ping_stop_event)

def delete_node() -> None:
"""Set delete_node."""
# Get Node
if node_store[KEY_NODE] is None:
nonlocal node
if node is None:
log(ERROR, "Node instance missing")
return
node: Node = cast(Node, node_store[KEY_NODE])

# Stop the ping-loop thread
ping_stop_event.set()
if ping_thread is not None:
ping_thread.join()

# Call FleetAPI
delete_node_request = DeleteNodeRequest(node=node)
retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)

del node_store[KEY_NODE]
# Cleanup
node = None

def receive() -> Optional[Message]:
"""Receive next task from server."""
# Get Node
if node_store[KEY_NODE] is None:
if node is None:
log(ERROR, "Node instance missing")
return None
node: Node = cast(Node, node_store[KEY_NODE])

# Request instructions (task) from server
request = PullTaskInsRequest(node=node)
Expand All @@ -167,26 +209,27 @@ def receive() -> Optional[Message]:
in_message = message_from_taskins(task_ins) if task_ins else None

# Remember `metadata` of the in message
state[KEY_METADATA] = copy(in_message.metadata) if in_message else None
nonlocal metadata
metadata = copy(in_message.metadata) if in_message else None

# Return the message if available
return in_message

def send(message: Message) -> None:
"""Send task result back to server."""
# Get Node
if node_store[KEY_NODE] is None:
if node is None:
log(ERROR, "Node instance missing")
return

# Get incoming message
in_metadata = state[KEY_METADATA]
if in_metadata is None:
# Get the metadata of the incoming message
nonlocal metadata
if metadata is None:
log(ERROR, "No current message")
return

# Validate out message
if not validate_out_message(message, in_metadata):
if not validate_out_message(message, metadata):
log(ERROR, "Invalid out message")
return

Expand All @@ -197,7 +240,8 @@ def send(message: Message) -> None:
request = PushTaskResRequest(task_res_list=[task_res])
_ = retry_invoker.invoke(stub.PushTaskRes, request)

state[KEY_METADATA] = None
# Cleanup
metadata = None

try:
# Yield methods
Expand Down
72 changes: 72 additions & 0 deletions src/py/flwr/client/heartbeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Heartbeat utility functions."""


import threading
from typing import Callable

import grpc

from flwr.common.constant import PING_CALL_TIMEOUT
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential


def _ping_loop(ping_fn: Callable[[], None], stop_event: threading.Event) -> None:
def wait_fn(wait_time: float) -> None:
if not stop_event.is_set():
stop_event.wait(wait_time)

def on_backoff(state: RetryState) -> None:
err = state.exception
if not isinstance(err, grpc.RpcError):
return
status_code = err.code()
# If ping call timeout is triggered
if status_code == grpc.StatusCode.DEADLINE_EXCEEDED:
# Avoid long wait time.
if state.actual_wait is None:
return
state.actual_wait = max(state.actual_wait - PING_CALL_TIMEOUT, 0.0)

def wrapped_ping() -> None:
if not stop_event.is_set():
ping_fn()

retrier = RetryInvoker(
exponential,
grpc.RpcError,
max_tries=None,
max_time=None,
on_backoff=on_backoff,
wait_function=wait_fn,
)
while not stop_event.is_set():
retrier.invoke(wrapped_ping)


def start_ping_loop(
ping_fn: Callable[[], None], stop_event: threading.Event
) -> threading.Thread:
"""Start a ping loop in a separate thread.
This function initializes a new thread that runs a ping loop, allowing for
asynchronous ping operations. The loop can be terminated through the provided stop
event.
"""
thread = threading.Thread(target=_ping_loop, args=(ping_fn, stop_event))
thread.start()

return thread
59 changes: 59 additions & 0 deletions src/py/flwr/client/heartbeat_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for heartbeat utility functions."""


import threading
import time
import unittest
from unittest.mock import MagicMock

from .heartbeat import start_ping_loop


class TestStartPingLoopWithFailures(unittest.TestCase):
"""Test heartbeat utility functions."""

def test_ping_loop_terminates(self) -> None:
"""Test if the ping loop thread terminates when flagged."""
# Prepare
ping_fn = MagicMock()
stop_event = threading.Event()

# Execute
thread = start_ping_loop(ping_fn, stop_event)
time.sleep(1)
stop_event.set()
thread.join(timeout=1)

# Assert
self.assertTrue(ping_fn.called)
self.assertFalse(thread.is_alive())

def test_ping_loop_with_failures_terminates(self) -> None:
"""Test if the ping loop thread with failures terminates when flagged."""
# Prepare
ping_fn = MagicMock(side_effect=RuntimeError())
stop_event = threading.Event()

# Execute
thread = start_ping_loop(ping_fn, stop_event)
time.sleep(1)
stop_event.set()
thread.join(timeout=1)

# Assert
self.assertTrue(ping_fn.called)
self.assertFalse(thread.is_alive())
Loading

0 comments on commit 9842e41

Please sign in to comment.