Skip to content

Commit

Permalink
feat(framework) Add get_serverapp_context and `set_serverapp_contex…
Browse files Browse the repository at this point in the history
…t` to `LinkState` (#4365)
  • Loading branch information
panh99 authored Oct 24, 2024
1 parent 8c449f5 commit cc61019
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 16 deletions.
13 changes: 12 additions & 1 deletion src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Optional
from uuid import UUID, uuid4

from flwr.common import log, now
from flwr.common import Context, log, now
from flwr.common.constant import (
MESSAGE_TTL_TOLERANCE,
NODE_ID_NUM_BYTES,
Expand Down Expand Up @@ -65,6 +65,7 @@ def __init__(self) -> None:

# Map run_id to RunRecord
self.run_ids: dict[int, RunRecord] = {}
self.contexts: dict[int, Context] = {}
self.task_ins_store: dict[UUID, TaskIns] = {}
self.task_res_store: dict[UUID, TaskRes] = {}

Expand Down Expand Up @@ -500,3 +501,13 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
return True
return False

def get_serverapp_context(self, run_id: int) -> Optional[Context]:
"""Get the context for the specified `run_id`."""
return self.contexts.get(run_id)

def set_serverapp_context(self, run_id: int, context: Context) -> None:
"""Set the context for the specified `run_id`."""
if run_id not in self.run_ids:
raise ValueError(f"Run {run_id} not found")
self.contexts[run_id] = context
29 changes: 29 additions & 0 deletions src/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Optional
from uuid import UUID

from flwr.common import Context
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

Expand Down Expand Up @@ -270,3 +271,31 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
is_acknowledged : bool
True if the ping is successfully acknowledged; otherwise, False.
"""

@abc.abstractmethod
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
"""Get the context for the specified `run_id`.
Parameters
----------
run_id : int
The identifier of the run for which to retrieve the context.
Returns
-------
Optional[Context]
The context associated with the specified `run_id`, or `None` if no context
exists for the given `run_id`.
"""

@abc.abstractmethod
def set_serverapp_context(self, run_id: int, context: Context) -> None:
"""Set the context for the specified `run_id`.
Parameters
----------
run_id : int
The identifier of the run for which to set the context.
context : Context
The context to be associated with the specified `run_id`.
"""
56 changes: 48 additions & 8 deletions src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,21 @@
from unittest.mock import patch
from uuid import UUID

from flwr.common import DEFAULT_TTL
from flwr.common import DEFAULT_TTL, Context, RecordSet
from flwr.common.constant import ErrorCode, Status, SubStatus
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
generate_key_pairs,
private_key_to_bytes,
public_key_to_bytes,
)
from flwr.common.typing import RunStatus
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611

# pylint: disable=E0611
from flwr.proto.node_pb2 import Node
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes

# pylint: enable=E0611
from flwr.server.superlink.linkstate import (
InMemoryLinkState,
LinkState,
Expand Down Expand Up @@ -998,6 +1002,42 @@ def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None:
# Assert
assert task_res_uuid is None

def test_get_set_serverapp_context(self) -> None:
"""Test get and set serverapp context."""
# Prepare
state: LinkState = self.state_factory()
context = Context(
node_id=0,
node_config={"mock": "mock"},
state=RecordSet(),
run_config={"test": "test"},
)
run_id = state.create_run(None, None, "9f86d08", {})

# Execute
init = state.get_serverapp_context(run_id)
state.set_serverapp_context(run_id, context)
retrieved_context = state.get_serverapp_context(run_id)

# Assert
assert init is None
assert retrieved_context == context

def test_set_context_invalid_run_id(self) -> None:
"""Test set_serverapp_context with invalid run_id."""
# Prepare
state: LinkState = self.state_factory()
context = Context(
node_id=0,
node_config={"mock": "mock"},
state=RecordSet(),
run_config={"test": "test"},
)

# Execute and assert
with self.assertRaises(ValueError):
state.set_serverapp_context(61016, context) # Invalid run_id


def create_task_ins(
consumer_node_id: int,
Expand All @@ -1019,7 +1059,7 @@ def create_task_ins(
producer=Node(node_id=0, anonymous=True),
consumer=consumer,
task_type="mock",
recordset=RecordSet(parameters={}, metrics={}, configs={}),
recordset=ProtoRecordSet(parameters={}, metrics={}, configs={}),
ttl=DEFAULT_TTL,
created_at=time.time(),
),
Expand All @@ -1044,7 +1084,7 @@ def create_task_res(
consumer=Node(node_id=0, anonymous=True),
ancestry=ancestry,
task_type="mock",
recordset=RecordSet(parameters={}, metrics={}, configs={}),
recordset=ProtoRecordSet(parameters={}, metrics={}, configs={}),
ttl=DEFAULT_TTL,
created_at=time.time(),
),
Expand Down Expand Up @@ -1083,7 +1123,7 @@ def test_initialize(self) -> None:
result = state.query("SELECT name FROM sqlite_schema;")

# Assert
assert len(result) == 13
assert len(result) == 15


class SqliteFileBasedTest(StateTest, unittest.TestCase):
Expand All @@ -1108,7 +1148,7 @@ def test_initialize(self) -> None:
result = state.query("SELECT name FROM sqlite_schema;")

# Assert
assert len(result) == 13
assert len(result) == 15


if __name__ == "__main__":
Expand Down
57 changes: 51 additions & 6 deletions src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,34 @@
import json
import re
import sqlite3
import threading
import time
from collections.abc import Sequence
from logging import DEBUG, ERROR, WARNING
from typing import Any, Optional, Union, cast
from uuid import UUID, uuid4

from flwr.common import log, now
from flwr.common import Context, log, now
from flwr.common.constant import (
MESSAGE_TTL_TOLERANCE,
NODE_ID_NUM_BYTES,
RUN_ID_NUM_BYTES,
Status,
)
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611

# pylint: disable=E0611
from flwr.proto.node_pb2 import Node
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes

# pylint: enable=E0611
from flwr.server.utils.validator import validate_task_ins_or_res

from .linkstate import LinkState
from .utils import (
context_from_bytes,
context_to_bytes,
convert_sint64_to_uint64,
convert_sint64_values_in_dict_to_uint64,
convert_uint64_to_sint64,
Expand Down Expand Up @@ -92,6 +99,14 @@
);
"""

SQL_CREATE_TABLE_CONTEXT = """
CREATE TABLE IF NOT EXISTS context(
run_id INTEGER UNIQUE,
context BLOB,
FOREIGN KEY(run_id) REFERENCES run(run_id)
);
"""

SQL_CREATE_TABLE_TASK_INS = """
CREATE TABLE IF NOT EXISTS task_ins(
task_id TEXT UNIQUE,
Expand Down Expand Up @@ -152,6 +167,7 @@ def __init__(
"""
self.database_path = database_path
self.conn: Optional[sqlite3.Connection] = None
self.lock = threading.RLock()

def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
"""Create tables if they don't exist yet.
Expand All @@ -175,6 +191,7 @@ def initialize(self, log_queries: bool = False) -> list[tuple[str]]:

# Create each table if not exists queries
cur.execute(SQL_CREATE_TABLE_RUN)
cur.execute(SQL_CREATE_TABLE_CONTEXT)
cur.execute(SQL_CREATE_TABLE_TASK_INS)
cur.execute(SQL_CREATE_TABLE_TASK_RES)
cur.execute(SQL_CREATE_TABLE_NODE)
Expand Down Expand Up @@ -970,6 +987,34 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
log(ERROR, "`node_id` does not exist.")
return False

def get_serverapp_context(self, run_id: int) -> Optional[Context]:
"""Get the context for the specified `run_id`."""
# Retrieve context if any
query = "SELECT context FROM context WHERE run_id = ?;"
rows = self.query(query, (convert_uint64_to_sint64(run_id),))
context = context_from_bytes(rows[0]["context"]) if rows else None
return context

def set_serverapp_context(self, run_id: int, context: Context) -> None:
"""Set the context for the specified `run_id`."""
# Convert context to bytes
context_bytes = context_to_bytes(context)
sint_run_id = convert_uint64_to_sint64(run_id)

# Check if any existing Context assigned to the run_id
query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
if self.query(query, (sint_run_id,))[0]["COUNT(*)"] > 0:
# Update context
query = "UPDATE context SET context = ? WHERE run_id = ?;"
self.query(query, (context_bytes, sint_run_id))
else:
try:
# Store context
query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
self.query(query, (sint_run_id, context_bytes))
except sqlite3.IntegrityError:
raise ValueError(f"Run {run_id} not found") from None

def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
"""Check if the TaskIns exists and is valid (not expired).
Expand Down Expand Up @@ -1054,7 +1099,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:

def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
"""Turn task_dict into protobuf message."""
recordset = RecordSet()
recordset = ProtoRecordSet()
recordset.ParseFromString(task_dict["recordset"])

result = TaskIns(
Expand Down Expand Up @@ -1084,7 +1129,7 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:

def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
"""Turn task_dict into protobuf message."""
recordset = RecordSet()
recordset = ProtoRecordSet()
recordset.ParseFromString(task_dict["recordset"])

result = TaskRes(
Expand Down
13 changes: 12 additions & 1 deletion src/py/flwr/server/superlink/linkstate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from os import urandom
from uuid import uuid4

from flwr.common import log
from flwr.common import Context, log, serde
from flwr.common.constant import ErrorCode, Status, SubStatus
from flwr.common.typing import RunStatus
from flwr.proto.error_pb2 import Error # pylint: disable=E0611
from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611

Expand Down Expand Up @@ -135,6 +136,16 @@ def convert_sint64_values_in_dict_to_uint64(
data_dict[key] = convert_sint64_to_uint64(data_dict[key])


def context_to_bytes(context: Context) -> bytes:
"""Serialize `Context` to bytes."""
return serde.context_to_proto(context).SerializeToString()


def context_from_bytes(context_bytes: bytes) -> Context:
"""Deserialize `Context` from bytes."""
return serde.context_from_proto(ProtoContext.FromString(context_bytes))


def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
"""Generate a TaskRes with a node unavailable error from a TaskIns."""
current_time = time.time()
Expand Down

0 comments on commit cc61019

Please sign in to comment.