Skip to content

Commit

Permalink
merge w/ main
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Apr 2, 2024
2 parents e2ad0df + 4b57017 commit d627097
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
PING_CALL_TIMEOUT = 5
PING_BASE_MULTIPLIER = 0.8
PING_RANDOM_RANGE = (-0.1, 0.1)
PING_MAX_INTERVAL = 1e300


class MessageType:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def create_node(
) -> CreateNodeResponse:
"""."""
# Create node
node_id = state.create_node()
node_id = state.create_node(ping_interval=request.ping_interval)
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))


Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
from flwr.client.node_state import NodeState
from flwr.common.constant import ErrorCode
from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
from flwr.common.logger import log
from flwr.common.message import Error
from flwr.common.object_ref import load_app
Expand All @@ -44,7 +44,7 @@ def _register_nodes(
nodes_mapping: NodeToPartitionMapping = {}
state = state_factory.state()
for i in range(num_nodes):
node_id = state.create_node()
node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
nodes_mapping[node_id] = i
log(INFO, "Registered %i nodes", len(nodes_mapping))
return nodes_mapping
Expand Down
6 changes: 2 additions & 4 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,14 @@ def num_task_res(self) -> int:
"""
return len(self.task_res_store)

def create_node(self) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in state, and return `node_id`."""
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

with self.lock:
if node_id not in self.node_ids:
# Default ping interval is 30s
# TODO: change 1e9 to 30s # pylint: disable=W0511
self.node_ids[node_id] = (time.time() + 1e9, 1e9)
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
return node_id
log(ERROR, "Unexpected node registration failure.")
return 0
Expand Down
6 changes: 2 additions & 4 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:

return None

def create_node(self) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in state, and return `node_id`."""
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
Expand All @@ -478,9 +478,7 @@ def create_node(self) -> int:
)

try:
# Default ping interval is 30s
# TODO: change 1e9 to 30s # pylint: disable=W0511
self.query(query, (node_id, time.time() + 1e9, 1e9))
self.query(query, (node_id, time.time() + ping_interval, ping_interval))
except sqlite3.IntegrityError:
log(ERROR, "Unexpected node registration failure.")
return 0
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:
"""Delete all delivered TaskIns/TaskRes pairs."""

@abc.abstractmethod
def create_node(self) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in state, and return `node_id`."""

@abc.abstractmethod
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def test_create_node_and_get_nodes(self) -> None:

# Execute
for _ in range(10):
node_ids.append(state.create_node())
node_ids.append(state.create_node(ping_interval=10))
retrieved_node_ids = state.get_nodes(run_id)

# Assert
Expand All @@ -331,7 +331,7 @@ def test_delete_node(self) -> None:
# Prepare
state: State = self.state_factory()
run_id = state.create_run()
node_id = state.create_node()
node_id = state.create_node(ping_interval=10)

# Execute
state.delete_node(node_id)
Expand All @@ -346,7 +346,7 @@ def test_get_nodes_invalid_run_id(self) -> None:
state: State = self.state_factory()
state.create_run()
invalid_run_id = 61016
state.create_node()
state.create_node(ping_interval=10)

# Execute
retrieved_node_ids = state.get_nodes(invalid_run_id)
Expand Down Expand Up @@ -399,7 +399,7 @@ def test_acknowledge_ping(self) -> None:
# Prepare
state: State = self.state_factory()
run_id = state.create_run()
node_ids = [state.create_node() for _ in range(100)]
node_ids = [state.create_node(ping_interval=10) for _ in range(100)]
for node_id in node_ids[:70]:
state.acknowledge_ping(node_id, ping_interval=30)
for node_id in node_ids[70:]:
Expand Down

0 comments on commit d627097

Please sign in to comment.