From cc6101918fc8ab24cc55e2d9f18dbae93e42df50 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 24 Oct 2024 18:56:56 +0100 Subject: [PATCH] feat(framework) Add `get_serverapp_context` and `set_serverapp_context` to `LinkState` (#4365) --- .../linkstate/in_memory_linkstate.py | 13 ++++- .../server/superlink/linkstate/linkstate.py | 29 ++++++++++ .../superlink/linkstate/linkstate_test.py | 56 +++++++++++++++--- .../superlink/linkstate/sqlite_linkstate.py | 57 +++++++++++++++++-- .../flwr/server/superlink/linkstate/utils.py | 13 ++++- 5 files changed, 152 insertions(+), 16 deletions(-) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index a5b8a12cc9d5..c616dafa8951 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -22,7 +22,7 @@ from typing import Optional from uuid import UUID, uuid4 -from flwr.common import log, now +from flwr.common import Context, log, now from flwr.common.constant import ( MESSAGE_TTL_TOLERANCE, NODE_ID_NUM_BYTES, @@ -65,6 +65,7 @@ def __init__(self) -> None: # Map run_id to RunRecord self.run_ids: dict[int, RunRecord] = {} + self.contexts: dict[int, Context] = {} self.task_ins_store: dict[UUID, TaskIns] = {} self.task_res_store: dict[UUID, TaskRes] = {} @@ -500,3 +501,13 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: self.node_ids[node_id] = (time.time() + ping_interval, ping_interval) return True return False + + def get_serverapp_context(self, run_id: int) -> Optional[Context]: + """Get the context for the specified `run_id`.""" + return self.contexts.get(run_id) + + def set_serverapp_context(self, run_id: int, context: Context) -> None: + """Set the context for the specified `run_id`.""" + if run_id not in self.run_ids: + raise ValueError(f"Run {run_id} not found") + self.contexts[run_id] = context diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 6e20b6717207..0ca9180f4e39 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -19,6 +19,7 @@ from typing import Optional from uuid import UUID +from flwr.common import Context from flwr.common.typing import Run, RunStatus, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -270,3 +271,31 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: is_acknowledged : bool True if the ping is successfully acknowledged; otherwise, False. """ + + @abc.abstractmethod + def get_serverapp_context(self, run_id: int) -> Optional[Context]: + """Get the context for the specified `run_id`. + + Parameters + ---------- + run_id : int + The identifier of the run for which to retrieve the context. + + Returns + ------- + Optional[Context] + The context associated with the specified `run_id`, or `None` if no context + exists for the given `run_id`. + """ + + @abc.abstractmethod + def set_serverapp_context(self, run_id: int, context: Context) -> None: + """Set the context for the specified `run_id`. + + Parameters + ---------- + run_id : int + The identifier of the run for which to set the context. + context : Context + The context to be associated with the specified `run_id`. + """ diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 418e61168915..d29358a24825 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -23,7 +23,7 @@ from unittest.mock import patch from uuid import UUID -from flwr.common import DEFAULT_TTL +from flwr.common import DEFAULT_TTL, Context, RecordSet from flwr.common.constant import ErrorCode, Status, SubStatus from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( generate_key_pairs, @@ -31,9 +31,13 @@ public_key_to_bytes, ) from flwr.common.typing import RunStatus -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 + +# pylint: disable=E0611 +from flwr.proto.node_pb2 import Node +from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes + +# pylint: enable=E0611 from flwr.server.superlink.linkstate import ( InMemoryLinkState, LinkState, @@ -998,6 +1002,42 @@ def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None: # Assert assert task_res_uuid is None + def test_get_set_serverapp_context(self) -> None: + """Test get and set serverapp context.""" + # Prepare + state: LinkState = self.state_factory() + context = Context( + node_id=0, + node_config={"mock": "mock"}, + state=RecordSet(), + run_config={"test": "test"}, + ) + run_id = state.create_run(None, None, "9f86d08", {}) + + # Execute + init = state.get_serverapp_context(run_id) + state.set_serverapp_context(run_id, context) + retrieved_context = state.get_serverapp_context(run_id) + + # Assert + assert init is None + assert retrieved_context == context + + def test_set_context_invalid_run_id(self) -> None: + """Test set_serverapp_context with invalid run_id.""" + # Prepare + state: LinkState = self.state_factory() + context = Context( + node_id=0, + node_config={"mock": "mock"}, + state=RecordSet(), + run_config={"test": "test"}, + ) + + # Execute and assert + with self.assertRaises(ValueError): + state.set_serverapp_context(61016, context) # Invalid run_id + def create_task_ins( consumer_node_id: int, @@ -1019,7 +1059,7 @@ def create_task_ins( producer=Node(node_id=0, anonymous=True), consumer=consumer, task_type="mock", - recordset=RecordSet(parameters={}, metrics={}, configs={}), + recordset=ProtoRecordSet(parameters={}, metrics={}, configs={}), ttl=DEFAULT_TTL, created_at=time.time(), ), @@ -1044,7 +1084,7 @@ def create_task_res( consumer=Node(node_id=0, anonymous=True), ancestry=ancestry, task_type="mock", - recordset=RecordSet(parameters={}, metrics={}, configs={}), + recordset=ProtoRecordSet(parameters={}, metrics={}, configs={}), ttl=DEFAULT_TTL, created_at=time.time(), ), @@ -1083,7 +1123,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 13 + assert len(result) == 15 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -1108,7 +1148,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 13 + assert len(result) == 15 if __name__ == "__main__": diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index bcf0b319f307..89d00528fa56 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -19,13 +19,14 @@ import json import re import sqlite3 +import threading import time from collections.abc import Sequence from logging import DEBUG, ERROR, WARNING from typing import Any, Optional, Union, cast from uuid import UUID, uuid4 -from flwr.common import log, now +from flwr.common import Context, log, now from flwr.common.constant import ( MESSAGE_TTL_TOLERANCE, NODE_ID_NUM_BYTES, @@ -33,13 +34,19 @@ Status, ) from flwr.common.typing import Run, RunStatus, UserConfig -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 + +# pylint: disable=E0611 +from flwr.proto.node_pb2 import Node +from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes + +# pylint: enable=E0611 from flwr.server.utils.validator import validate_task_ins_or_res from .linkstate import LinkState from .utils import ( + context_from_bytes, + context_to_bytes, convert_sint64_to_uint64, convert_sint64_values_in_dict_to_uint64, convert_uint64_to_sint64, @@ -92,6 +99,14 @@ ); """ +SQL_CREATE_TABLE_CONTEXT = """ +CREATE TABLE IF NOT EXISTS context( + run_id INTEGER UNIQUE, + context BLOB, + FOREIGN KEY(run_id) REFERENCES run(run_id) +); +""" + SQL_CREATE_TABLE_TASK_INS = """ CREATE TABLE IF NOT EXISTS task_ins( task_id TEXT UNIQUE, @@ -152,6 +167,7 @@ def __init__( """ self.database_path = database_path self.conn: Optional[sqlite3.Connection] = None + self.lock = threading.RLock() def initialize(self, log_queries: bool = False) -> list[tuple[str]]: """Create tables if they don't exist yet. @@ -175,6 +191,7 @@ def initialize(self, log_queries: bool = False) -> list[tuple[str]]: # Create each table if not exists queries cur.execute(SQL_CREATE_TABLE_RUN) + cur.execute(SQL_CREATE_TABLE_CONTEXT) cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) @@ -970,6 +987,34 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: log(ERROR, "`node_id` does not exist.") return False + def get_serverapp_context(self, run_id: int) -> Optional[Context]: + """Get the context for the specified `run_id`.""" + # Retrieve context if any + query = "SELECT context FROM context WHERE run_id = ?;" + rows = self.query(query, (convert_uint64_to_sint64(run_id),)) + context = context_from_bytes(rows[0]["context"]) if rows else None + return context + + def set_serverapp_context(self, run_id: int, context: Context) -> None: + """Set the context for the specified `run_id`.""" + # Convert context to bytes + context_bytes = context_to_bytes(context) + sint_run_id = convert_uint64_to_sint64(run_id) + + # Check if any existing Context assigned to the run_id + query = "SELECT COUNT(*) FROM context WHERE run_id = ?;" + if self.query(query, (sint_run_id,))[0]["COUNT(*)"] > 0: + # Update context + query = "UPDATE context SET context = ? WHERE run_id = ?;" + self.query(query, (context_bytes, sint_run_id)) + else: + try: + # Store context + query = "INSERT INTO context (run_id, context) VALUES (?, ?);" + self.query(query, (sint_run_id, context_bytes)) + except sqlite3.IntegrityError: + raise ValueError(f"Run {run_id} not found") from None + def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]: """Check if the TaskIns exists and is valid (not expired). @@ -1054,7 +1099,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]: def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns: """Turn task_dict into protobuf message.""" - recordset = RecordSet() + recordset = ProtoRecordSet() recordset.ParseFromString(task_dict["recordset"]) result = TaskIns( @@ -1084,7 +1129,7 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns: def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes: """Turn task_dict into protobuf message.""" - recordset = RecordSet() + recordset = ProtoRecordSet() recordset.ParseFromString(task_dict["recordset"]) result = TaskRes( diff --git a/src/py/flwr/server/superlink/linkstate/utils.py b/src/py/flwr/server/superlink/linkstate/utils.py index 1e5c5de612a5..4a18e8880c9d 100644 --- a/src/py/flwr/server/superlink/linkstate/utils.py +++ b/src/py/flwr/server/superlink/linkstate/utils.py @@ -20,10 +20,11 @@ from os import urandom from uuid import uuid4 -from flwr.common import log +from flwr.common import Context, log, serde from flwr.common.constant import ErrorCode, Status, SubStatus from flwr.common.typing import RunStatus from flwr.proto.error_pb2 import Error # pylint: disable=E0611 +from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -135,6 +136,16 @@ def convert_sint64_values_in_dict_to_uint64( data_dict[key] = convert_sint64_to_uint64(data_dict[key]) +def context_to_bytes(context: Context) -> bytes: + """Serialize `Context` to bytes.""" + return serde.context_to_proto(context).SerializeToString() + + +def context_from_bytes(context_bytes: bytes) -> Context: + """Deserialize `Context` from bytes.""" + return serde.context_from_proto(ProtoContext.FromString(context_bytes)) + + def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes: """Generate a TaskRes with a node unavailable error from a TaskIns.""" current_time = time.time()