diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index ee8271b8eca4..6a4061a72505 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -41,6 +41,7 @@ PING_CALL_TIMEOUT = 5 PING_BASE_MULTIPLIER = 0.8 PING_RANDOM_RANGE = (-0.1, 0.1) +PING_MAX_INTERVAL = 1e300 class MessageType: diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 9fa7656198e5..39edd606b464 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -43,7 +43,7 @@ def create_node( ) -> CreateNodeResponse: """.""" # Create node - node_id = state.create_node() + node_id = state.create_node(ping_interval=request.ping_interval) return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 79b3b6ea3937..9c27fca79c12 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -24,7 +24,7 @@ from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState -from flwr.common.constant import ErrorCode +from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode from flwr.common.logger import log from flwr.common.message import Error from flwr.common.object_ref import load_app @@ -44,7 +44,7 @@ def _register_nodes( nodes_mapping: NodeToPartitionMapping = {} state = state_factory.state() for i in range(num_nodes): - node_id = state.create_node() + node_id = state.create_node(ping_interval=PING_MAX_INTERVAL) nodes_mapping[node_id] = i log(INFO, "Registered %i nodes", len(nodes_mapping)) return nodes_mapping diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 6fc57707ac36..2ce6dcd4599a 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -182,16 +182,14 @@ def num_task_res(self) -> int: """ return len(self.task_res_store) - def create_node(self) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in state, and return `node_id`.""" # Sample a random int64 as node_id node_id: int = int.from_bytes(os.urandom(8), "little", signed=True) with self.lock: if node_id not in self.node_ids: - # Default ping interval is 30s - # TODO: change 1e9 to 30s # pylint: disable=W0511 - self.node_ids[node_id] = (time.time() + 1e9, 1e9) + self.node_ids[node_id] = (time.time() + ping_interval, ping_interval) return node_id log(ERROR, "Unexpected node registration failure.") return 0 diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 6996d51d2a9b..b68d19bd96d9 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -468,7 +468,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: return None - def create_node(self) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in state, and return `node_id`.""" # Sample a random int64 as node_id node_id: int = int.from_bytes(os.urandom(8), "little", signed=True) @@ -478,9 +478,7 @@ def create_node(self) -> int: ) try: - # Default ping interval is 30s - # TODO: change 1e9 to 30s # pylint: disable=W0511 - self.query(query, (node_id, time.time() + 1e9, 1e9)) + self.query(query, (node_id, time.time() + ping_interval, ping_interval)) except sqlite3.IntegrityError: log(ERROR, "Unexpected node registration failure.") return 0 diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 313290eb1022..b356cd47befa 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -132,7 +132,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" @abc.abstractmethod - def create_node(self) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in state, and return `node_id`.""" @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 1757cfac4255..8e49a380bb16 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -319,7 +319,7 @@ def test_create_node_and_get_nodes(self) -> None: # Execute for _ in range(10): - node_ids.append(state.create_node()) + node_ids.append(state.create_node(ping_interval=10)) retrieved_node_ids = state.get_nodes(run_id) # Assert @@ -331,7 +331,7 @@ def test_delete_node(self) -> None: # Prepare state: State = self.state_factory() run_id = state.create_run() - node_id = state.create_node() + node_id = state.create_node(ping_interval=10) # Execute state.delete_node(node_id) @@ -346,7 +346,7 @@ def test_get_nodes_invalid_run_id(self) -> None: state: State = self.state_factory() state.create_run() invalid_run_id = 61016 - state.create_node() + state.create_node(ping_interval=10) # Execute retrieved_node_ids = state.get_nodes(invalid_run_id) @@ -399,7 +399,7 @@ def test_acknowledge_ping(self) -> None: # Prepare state: State = self.state_factory() run_id = state.create_run() - node_ids = [state.create_node() for _ in range(100)] + node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) for node_id in node_ids[70:]: