Skip to content

Commit

Permalink
merge w/ main
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Mar 27, 2024
2 parents 74202ba + 8326617 commit d8dcd88
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 23 deletions.
51 changes: 36 additions & 15 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

import os
import threading
import time
from logging import ERROR
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, Tuple
from uuid import UUID, uuid4

from flwr.common import log, now
Expand All @@ -31,7 +32,8 @@ class InMemoryState(State):
"""In-memory State implementation."""

def __init__(self) -> None:
self.node_ids: Set[int] = set()
# Map node_id to (online_until, ping_interval)
self.node_ids: Dict[int, Tuple[float, float]] = {}
self.run_ids: Set[int] = set()
self.task_ins_store: Dict[UUID, TaskIns] = {}
self.task_res_store: Dict[UUID, TaskRes] = {}
Expand Down Expand Up @@ -185,17 +187,21 @@ def create_node(self) -> int:
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

if node_id not in self.node_ids:
self.node_ids.add(node_id)
return node_id
with self.lock:
if node_id not in self.node_ids:
# Default ping interval is 30s
# TODO: change 1e9 to 30s # pylint: disable=W0511
self.node_ids[node_id] = (time.time() + 1e9, 1e9)
return node_id
log(ERROR, "Unexpected node registration failure.")
return 0

def delete_node(self, node_id: int) -> None:
"""Delete a client node."""
if node_id not in self.node_ids:
raise ValueError(f"Node {node_id} not found")
self.node_ids.remove(node_id)
with self.lock:
if node_id not in self.node_ids:
raise ValueError(f"Node {node_id} not found")
del self.node_ids[node_id]

def get_nodes(self, run_id: int) -> Set[int]:
"""Return all available client nodes.
Expand All @@ -205,17 +211,32 @@ def get_nodes(self, run_id: int) -> Set[int]:
If the provided `run_id` does not exist or has no matching nodes,
an empty `Set` MUST be returned.
"""
if run_id not in self.run_ids:
return set()
return self.node_ids
with self.lock:
if run_id not in self.run_ids:
return set()
current_time = time.time()
return {
node_id
for node_id, (online_until, _) in self.node_ids.items()
if online_until > current_time
}

def create_run(self) -> int:
"""Create one run."""
# Sample a random int64 as run_id
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
with self.lock:
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

if run_id not in self.run_ids:
self.run_ids.add(run_id)
return run_id
if run_id not in self.run_ids:
self.run_ids.add(run_id)
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0

def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
    """Record a heartbeat from `node_id` and push back its online deadline.

    Parameters
    ----------
    node_id : int
        The node that sent the ping.
    ping_interval : float
        Number of seconds, counted from now, until the node's entry expires
        unless another ping is acknowledged.

    Returns
    -------
    bool
        True if `node_id` is registered and its entry was refreshed;
        False if the node is unknown (nothing is modified in that case).
    """
    with self.lock:
        # Guard clause: pings from unregistered nodes are rejected.
        if node_id not in self.node_ids:
            return False
        # Entries are stored as (online_until, ping_interval).
        deadline = time.time() + ping_interval
        self.node_ids[node_id] = (deadline, ping_interval)
        return True
36 changes: 30 additions & 6 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import re
import sqlite3
import time
from logging import DEBUG, ERROR
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
from uuid import UUID, uuid4
Expand All @@ -32,10 +33,16 @@

SQL_CREATE_TABLE_NODE = """
CREATE TABLE IF NOT EXISTS node(
node_id INTEGER UNIQUE
node_id INTEGER UNIQUE,
online_until REAL,
ping_interval REAL
);
"""

SQL_CREATE_INDEX_ONLINE_UNTIL = """
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
"""

SQL_CREATE_TABLE_RUN = """
CREATE TABLE IF NOT EXISTS run(
run_id INTEGER UNIQUE
Expand Down Expand Up @@ -83,7 +90,7 @@
);
"""

DictOrTuple = Union[Tuple[Any], Dict[str, Any]]
DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]


class SqliteState(State):
Expand Down Expand Up @@ -124,6 +131,7 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]:
cur.execute(SQL_CREATE_TABLE_TASK_INS)
cur.execute(SQL_CREATE_TABLE_TASK_RES)
cur.execute(SQL_CREATE_TABLE_NODE)
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
res = cur.execute("SELECT name FROM sqlite_schema;")

return res.fetchall()
Expand Down Expand Up @@ -465,9 +473,14 @@ def create_node(self) -> int:
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

query = "INSERT INTO node VALUES(:node_id);"
query = (
"INSERT INTO node (node_id, online_until, ping_interval) VALUES (?, ?, ?)"
)

try:
self.query(query, {"node_id": node_id})
# Default ping interval is 30s
# TODO: change 1e9 to 30s # pylint: disable=W0511
self.query(query, (node_id, time.time() + 1e9, 1e9))
except sqlite3.IntegrityError:
log(ERROR, "Unexpected node registration failure.")
return 0
Expand All @@ -492,8 +505,8 @@ def get_nodes(self, run_id: int) -> Set[int]:
return set()

# Get nodes
query = "SELECT * FROM node;"
rows = self.query(query)
query = "SELECT node_id FROM node WHERE online_until > ?;"
rows = self.query(query, (time.time(),))
result: Set[int] = {row["node_id"] for row in rows}
return result

Expand All @@ -512,6 +525,17 @@ def create_run(self) -> int:
log(ERROR, "Unexpected run creation failure.")
return 0

def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
    """Acknowledge a ping received from a node, serving as a heartbeat.

    Parameters
    ----------
    node_id : int
        The `node_id` from which the ping was received.
    ping_interval : float
        The interval (in seconds) from the current timestamp within which the
        next ping from this node must be received.

    Returns
    -------
    is_acknowledged : bool
        True if a row for `node_id` exists and its `online_until` and
        `ping_interval` columns were updated; otherwise, False.
    """
    # An UPDATE whose WHERE clause matches no row is not an error in SQLite: it
    # succeeds and changes nothing, so `sqlite3.IntegrityError` is never raised
    # for an unknown `node_id`. Check for the row explicitly so that unknown
    # nodes are reported as not acknowledged.
    # NOTE(review): the check and the update are two separate statements; a
    # node deleted in between would still yield True — confirm whether callers
    # need this to be atomic.
    existing = self.query("SELECT node_id FROM node WHERE node_id = ?;", (node_id,))
    if not existing:
        log(ERROR, "`node_id` does not exist.")
        return False

    # Update `online_until` and `ping_interval` for the given `node_id`
    query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
    try:
        self.query(query, (time.time() + ping_interval, ping_interval, node_id))
    except sqlite3.IntegrityError:
        log(ERROR, "`node_id` does not exist.")
        return False
    return True


def dict_factory(
cursor: sqlite3.Cursor,
Expand Down
19 changes: 19 additions & 0 deletions src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,22 @@ def get_nodes(self, run_id: int) -> Set[int]:
@abc.abstractmethod
def create_run(self) -> int:
"""Create one run."""

@abc.abstractmethod
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
    """Acknowledge a ping received from a node, serving as a heartbeat.

    Parameters
    ----------
    node_id : int
        The `node_id` from which the ping was received.
    ping_interval : float
        The interval (in seconds) from the current timestamp within which the next
        ping from this node must be received. This acts as a hard deadline to ensure
        an accurate assessment of the node's availability.

    Returns
    -------
    is_acknowledged : bool
        True if the ping is successfully acknowledged; otherwise, False.
    """
24 changes: 22 additions & 2 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from abc import abstractmethod
from datetime import datetime, timezone
from typing import List
from unittest.mock import patch
from uuid import uuid4

from flwr.common import DEFAULT_TTL
Expand Down Expand Up @@ -393,6 +394,25 @@ def test_num_task_res(self) -> None:
# Assert
assert num == 2

def test_acknowledge_ping(self) -> None:
    """Test that acknowledge_ping works and get_nodes returns only online nodes."""
    # Prepare
    state: State = self.state_factory()
    run_id = state.create_run()
    node_ids = [state.create_node() for _ in range(100)]
    # The first 70 nodes get a 30s deadline, the remaining 30 get 90s.
    short_lived, long_lived = node_ids[:70], node_ids[70:]
    for node_id in short_lived:
        state.acknowledge_ping(node_id, ping_interval=30)
    for node_id in long_lived:
        state.acknowledge_ping(node_id, ping_interval=90)

    # Execute
    # Move the clock 50s ahead: past the 30s deadlines, before the 90s ones.
    shifted_now = time.time() + 50
    with patch("time.time", side_effect=lambda: shifted_now):
        actual_node_ids = state.get_nodes(run_id)

    # Assert
    self.assertSetEqual(actual_node_ids, set(long_lived))


def create_task_ins(
consumer_node_id: int,
Expand Down Expand Up @@ -478,7 +498,7 @@ def test_initialize(self) -> None:
result = state.query("SELECT name FROM sqlite_schema;")

# Assert
assert len(result) == 8
assert len(result) == 9


class SqliteFileBasedTest(StateTest, unittest.TestCase):
Expand All @@ -503,7 +523,7 @@ def test_initialize(self) -> None:
result = state.query("SELECT name FROM sqlite_schema;")

# Assert
assert len(result) == 8
assert len(result) == 9


if __name__ == "__main__":
Expand Down

0 comments on commit d8dcd88

Please sign in to comment.