Skip to content

Commit

Permalink
Use daemon thread and replace time.sleep with Event.wait (#3224)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Apr 10, 2024
1 parent 6e0dbdb commit af1c375
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
4 changes: 3 additions & 1 deletion src/py/flwr/client/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def start_ping_loop(
asynchronous ping operations. The loop can be terminated through the provided stop
event.
"""
thread = threading.Thread(target=_ping_loop, args=(ping_fn, stop_event))
thread = threading.Thread(
target=_ping_loop, args=(ping_fn, stop_event), daemon=True
)
thread.start()

return thread
5 changes: 3 additions & 2 deletions src/py/flwr/server/compat/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


import threading
import time
from typing import Dict, Tuple

from ..client_manager import ClientManager
Expand Down Expand Up @@ -60,6 +59,7 @@ def start_update_client_manager_thread(
client_manager,
f_stop,
),
daemon=True,
)
thread.start()

Expand Down Expand Up @@ -99,4 +99,5 @@ def _update_client_manager(
raise RuntimeError("Could not register node.")

# Sleep for 3 seconds
time.sleep(3)
if not f_stop.is_set():
f_stop.wait(3)
39 changes: 24 additions & 15 deletions src/py/flwr/server/compat/app_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import time
import unittest
from threading import Event
from typing import Optional
from unittest.mock import Mock, patch

from ..client_manager import SimpleClientManager
Expand All @@ -29,30 +31,37 @@ class TestUtils(unittest.TestCase):
def test_start_update_client_manager_thread(self) -> None:
"""Test start_update_client_manager_thread function."""
# Prepare
sleep = time.sleep
sleep_patch = patch("time.sleep", lambda x: sleep(x / 100))
sleep_patch.start()
expected_node_ids = list(range(100))
updated_expected_node_ids = list(range(80, 120))
driver = Mock()
driver.grpc_driver = Mock()
driver.run_id = 123
driver.get_node_ids.return_value = expected_node_ids
client_manager = SimpleClientManager()
original_wait = Event.wait

def custom_wait(self: Event, timeout: Optional[float] = None) -> None:
if timeout is not None:
timeout /= 100
original_wait(self, timeout)

# Execute
thread, f_stop = start_update_client_manager_thread(driver, client_manager)
# Wait until all nodes are registered via `client_manager.sample()`
client_manager.sample(len(expected_node_ids))
# Retrieve all nodes in `client_manager`
node_ids = {proxy.node_id for proxy in client_manager.all().values()}
# Update the GetNodesResponse and wait until the `client_manager` is updated
driver.get_node_ids.return_value = updated_expected_node_ids
sleep(0.1)
# Retrieve all nodes in `client_manager`
updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()}
# Stop the thread
f_stop.set()
# Patching Event.wait with our custom function
with patch.object(Event, "wait", new=custom_wait):
thread, f_stop = start_update_client_manager_thread(driver, client_manager)
# Wait until all nodes are registered via `client_manager.sample()`
client_manager.sample(len(expected_node_ids))
# Retrieve all nodes in `client_manager`
node_ids = {proxy.node_id for proxy in client_manager.all().values()}
# Update the GetNodesResponse and wait until the `client_manager` is updated
driver.get_node_ids.return_value = updated_expected_node_ids
time.sleep(0.1)
# Retrieve all nodes in `client_manager`
updated_node_ids = {
proxy.node_id for proxy in client_manager.all().values()
}
# Stop the thread
f_stop.set()

# Assert
assert node_ids == set(expected_node_ids)
Expand Down

0 comments on commit af1c375

Please sign in to comment.