diff --git a/src/py/flwr/client/heartbeat.py b/src/py/flwr/client/heartbeat.py index 0cc979ddfd13..b68e6163cc01 100644 --- a/src/py/flwr/client/heartbeat.py +++ b/src/py/flwr/client/heartbeat.py @@ -66,7 +66,9 @@ def start_ping_loop( asynchronous ping operations. The loop can be terminated through the provided stop event. """ - thread = threading.Thread(target=_ping_loop, args=(ping_fn, stop_event)) + thread = threading.Thread( + target=_ping_loop, args=(ping_fn, stop_event), daemon=True + ) thread.start() return thread diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py index 696ec1132c4a..709be3d96a33 100644 --- a/src/py/flwr/server/compat/app_utils.py +++ b/src/py/flwr/server/compat/app_utils.py @@ -16,7 +16,6 @@ import threading -import time from typing import Dict, Tuple from ..client_manager import ClientManager @@ -60,6 +59,7 @@ def start_update_client_manager_thread( client_manager, f_stop, ), + daemon=True, ) thread.start() @@ -99,4 +99,5 @@ def _update_client_manager( raise RuntimeError("Could not register node.") # Sleep for 3 seconds - time.sleep(3) + if not f_stop.is_set(): + f_stop.wait(3) diff --git a/src/py/flwr/server/compat/app_utils_test.py b/src/py/flwr/server/compat/app_utils_test.py index 7e47e6eaaf32..023d65b0dc72 100644 --- a/src/py/flwr/server/compat/app_utils_test.py +++ b/src/py/flwr/server/compat/app_utils_test.py @@ -17,6 +17,8 @@ import time import unittest +from threading import Event +from typing import Optional from unittest.mock import Mock, patch from ..client_manager import SimpleClientManager @@ -29,9 +31,6 @@ class TestUtils(unittest.TestCase): def test_start_update_client_manager_thread(self) -> None: """Test start_update_client_manager_thread function.""" # Prepare - sleep = time.sleep - sleep_patch = patch("time.sleep", lambda x: sleep(x / 100)) - sleep_patch.start() expected_node_ids = list(range(100)) updated_expected_node_ids = list(range(80, 120)) driver = Mock() @@ -39,20 +38,30 @@ def test_start_update_client_manager_thread(self) -> None: driver.run_id = 123 driver.get_node_ids.return_value = expected_node_ids client_manager = SimpleClientManager() + original_wait = Event.wait + + def custom_wait(self: Event, timeout: Optional[float] = None) -> None: + if timeout is not None: + timeout /= 100 + original_wait(self, timeout) # Execute - thread, f_stop = start_update_client_manager_thread(driver, client_manager) - # Wait until all nodes are registered via `client_manager.sample()` - client_manager.sample(len(expected_node_ids)) - # Retrieve all nodes in `client_manager` - node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Update the GetNodesResponse and wait until the `client_manager` is updated - driver.get_node_ids.return_value = updated_expected_node_ids - sleep(0.1) - # Retrieve all nodes in `client_manager` - updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Stop the thread - f_stop.set() + # Patching Event.wait with our custom function + with patch.object(Event, "wait", new=custom_wait): + thread, f_stop = start_update_client_manager_thread(driver, client_manager) + # Wait until all nodes are registered via `client_manager.sample()` + client_manager.sample(len(expected_node_ids)) + # Retrieve all nodes in `client_manager` + node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # Update the GetNodesResponse and wait until the `client_manager` is updated + driver.get_node_ids.return_value = updated_expected_node_ids + time.sleep(0.1) + # Retrieve all nodes in `client_manager` + updated_node_ids = { + proxy.node_id for proxy in client_manager.all().values() + } + # Stop the thread + f_stop.set() # Assert assert node_ids == set(expected_node_ids)