def _ping_loop(ping_fn: Callable[[], None], stop_event: threading.Event) -> None:
    """Invoke `ping_fn` repeatedly until `stop_event` is set.

    Every call is routed through a `RetryInvoker` built with the
    `exponential` wait generator, `grpc.RpcError`, and no limit on the number
    of tries or the elapsed time (`max_tries=None`, `max_time=None`).

    NOTE(review): presumably only `grpc.RpcError` is retried and any other
    exception raised by `ping_fn` propagates out of this loop -- confirm
    against `RetryInvoker`.

    Parameters
    ----------
    ping_fn : Callable[[], None]
        Function that performs a single ping.
    stop_event : threading.Event
        Event that ends the loop once set. Waits between retries block on
        this event, so setting it also cuts a pending wait short.
    """

    def wait_fn(wait_time: float) -> None:
        # Block on the event instead of sleeping so that setting `stop_event`
        # wakes the loop up right away.
        if not stop_event.is_set():
            stop_event.wait(wait_time)

    def on_backoff(state: RetryState) -> None:
        err = state.exception
        if not isinstance(err, grpc.RpcError):
            return
        # The `grpc.RpcError` base class itself does not define `code()`; only
        # errors that also implement `grpc.Call` do. Leave the wait untouched
        # rather than raising `AttributeError` from inside the handler.
        get_code = getattr(err, "code", None)
        if not callable(get_code):
            return
        # If ping call timeout is triggered
        if get_code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            # Avoid long wait time: shorten the upcoming wait by
            # PING_CALL_TIMEOUT, never going below zero.
            if state.actual_wait is None:
                return
            state.actual_wait = max(state.actual_wait - PING_CALL_TIMEOUT, 0.0)

    def wrapped_ping() -> None:
        # Skip the ping once a stop was requested so that a retry scheduled
        # before the stop returns without calling `ping_fn` again.
        if not stop_event.is_set():
            ping_fn()

    retrier = RetryInvoker(
        exponential,
        grpc.RpcError,
        max_tries=None,
        max_time=None,
        on_backoff=on_backoff,
        wait_function=wait_fn,
    )
    while not stop_event.is_set():
        retrier.invoke(wrapped_ping)
class TestStartPingLoopWithFailures(unittest.TestCase):
    """Check that the thread returned by `start_ping_loop` stops running."""

    def _assert_pings_and_stops(self, ping_fn: MagicMock) -> None:
        """Run the loop for one second, set the stop event, then verify.

        Verifies that `ping_fn` was called at least once and that the thread
        is no longer alive after being joined with a one-second timeout.
        """
        stop_event = threading.Event()

        # Execute
        loop_thread = start_ping_loop(ping_fn, stop_event)
        time.sleep(1)
        stop_event.set()
        loop_thread.join(timeout=1)

        # Assert
        self.assertTrue(ping_fn.called)
        self.assertFalse(loop_thread.is_alive())

    def test_ping_loop_terminates(self) -> None:
        """Test if the ping loop thread terminates when flagged."""
        self._assert_pings_and_stops(MagicMock())

    def test_ping_loop_with_failures_terminates(self) -> None:
        """Test if the ping loop thread with failures terminates when flagged."""
        # NOTE(review): `RuntimeError` is not a `grpc.RpcError`, the only
        # exception type handed to the retry invoker in `_ping_loop`; the
        # thread may therefore end because of the unhandled error rather than
        # because of the stop event -- confirm this covers the intended path.
        self._assert_pings_and_stops(MagicMock(side_effect=RuntimeError()))
if not ping_stop_event.is_set(): - ping_stop_event.wait(PING_CALL_TIMEOUT) return if res.headers["content-type"] != "application/protobuf": log( @@ -188,8 +182,6 @@ def ping() -> None: "[Node] POST /%s: header `Content-Type` has wrong value", PATH_PULL_TASK_INS, ) - if not ping_stop_event.is_set(): - ping_stop_event.wait(PING_CALL_TIMEOUT) return # Deserialize ProtoBuf from bytes