diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 3396b5c580cb..15d384cb74a2 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -371,7 +371,7 @@ def _on_backoff(retry_state: RetryState) -> None: run_info[run_id] = get_run(run_id) # If get_run is None, i.e., in grpc-bidi mode else: - run_info[run_id] = Run(run_id, "", "") + run_info[run_id] = Run(run_id, "", "", {}) # Register context for this run node_state.register_context(run_id=run_id) diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 630d3090360d..8062ce28fcc7 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -280,6 +280,7 @@ def get_run(run_id: int) -> Run: run_id, get_run_response.run.fab_id, get_run_response.run.fab_version, + dict(get_run_response.run.override_config.items()), ) try: diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 4b45d7fc24a5..0efa5731ae51 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -352,12 +352,13 @@ def get_run(run_id: int) -> Run: # Send the request res = _request(req, GetRunResponse, PATH_GET_RUN) if res is None: - return Run(run_id, "", "") + return Run(run_id, "", "", {}) return Run( run_id, res.run.fab_id, res.run.fab_version, + dict(res.run.override_config.items()), ) try: diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index f51830955679..04d2cf5bbf7f 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -194,3 +194,4 @@ class Run: run_id: int fab_id: str fab_version: str + override_config: Dict[str, str] diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index cdb9dd1ee87d..84da5882eb73 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -127,6 +127,7 @@ def _init_run(self) -> None: run_id=res.run.run_id, fab_id=res.run.fab_id, fab_version=res.run.fab_version, + override_config=dict(res.run.override_config.items()), ) @property diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 0cc1c5a53e13..d0f32e830f7d 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -86,7 +86,10 @@ def setUp(self) -> None: for _ in range(self.num_nodes) ] self.state.get_run.return_value = Run( - run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" + run_id=61016, + fab_id="mock/mock", + fab_version="v1.0.0", + override_config={"test_key": "test_value"}, ) state_factory = MagicMock(state=lambda: self.state) self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory) @@ -98,6 +101,7 @@ def test_get_run(self) -> None: self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") + self.assertEqual(self.driver.run.override_config["test_key"], "test_value") def test_get_nodes(self) -> None: """Test retrieval of nodes.""" @@ -223,7 +227,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: # Prepare state = StateFactory("").state() self.driver = InMemoryDriver( - state.create_run("", ""), MagicMock(state=lambda: state) + state.create_run("", "", {}), MagicMock(state=lambda: state) ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, SqliteState) @@ -249,7 +253,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: # Prepare state_factory = StateFactory(":flwr-in-memory-state:") state = state_factory.state() - self.driver = InMemoryDriver(state.create_run("", ""), state_factory) + self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, InMemoryState) diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 03128f02158e..7f8ded3bdb85 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -69,7 +69,11 @@ def CreateRun( """Create run ID.""" log(DEBUG, "DriverServicer.CreateRun") state: State = self.state_factory.state() - run_id = state.create_run(request.fab_id, request.fab_version) + run_id = state.create_run( + request.fab_id, + request.fab_version, + dict(request.override_config.items()), + ) return CreateRunResponse(run_id=run_id) def PushTaskIns( diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 01499102b7d8..798e71435585 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) shared_secret = generate_shared_key( self._client_private_key, self._server_public_key @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) client_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(client_private_key, self._server_public_key) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index df9f2cc96f95..c0bf506fd2b6 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -82,7 +82,9 @@ def register_messages_into_state( ) -> Dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore - state.run_ids[run_id] = Run(run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0") + state.run_ids[run_id] = Run( + run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={} + ) # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 5a4e4eb0fd9a..bc4bd4478a23 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -275,7 +275,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" return self.public_key_to_node_id.get(client_public_key) - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id with self.lock: @@ -283,7 +288,10 @@ def create_run(self, fab_id: str, fab_version: str) -> int: if run_id not in self.run_ids: self.run_ids[run_id] = Run( - run_id=run_id, fab_id=fab_id, fab_version=fab_version + run_id=run_id, + fab_id=fab_id, + fab_version=fab_version, + override_config=override_config, ) return run_id log(ERROR, "Unexpected run creation failure.") diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 725f7c2dff4b..ea6f349b9f9a 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -15,6 +15,7 @@ """SQLite based implemenation of server state.""" +import json import re import sqlite3 import time @@ -61,9 +62,10 @@ SQL_CREATE_TABLE_RUN = """ CREATE TABLE IF NOT EXISTS run( - run_id INTEGER UNIQUE, - fab_id TEXT, - fab_version TEXT + run_id INTEGER UNIQUE, + fab_id TEXT, + fab_version TEXT, + override_config TEXT ); """ @@ -613,7 +615,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: return node_id return None - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) @@ -622,8 +629,13 @@ def create_run(self, fab_id: str, fab_version: str) -> int: query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" # If run_id does not exist if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: - query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);" - self.query(query, (run_id, fab_id, fab_version)) + query = ( + "INSERT INTO run (run_id, fab_id, fab_version, override_config)" + "VALUES (?, ?, ?, ?);" + ) + self.query( + query, (run_id, fab_id, fab_version, json.dumps(override_config)) + ) return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -687,7 +699,10 @@ def get_run(self, run_id: int) -> Optional[Run]: try: row = self.query(query, (run_id,))[0] return Run( - run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"] + run_id=run_id, + fab_id=row["fab_id"], + fab_version=row["fab_version"], + override_config=json.loads(row["override_config"]), ) except sqlite3.IntegrityError: log(ERROR, "`run_id` does not exist.") diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 65e2c63cab69..c93f6ba756b8 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -16,7 +16,7 @@ import abc -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set from uuid import UUID from flwr.common.typing import Run @@ -157,7 +157,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" @abc.abstractmethod - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 373202d5cde6..5f0d23ffc4d8 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -52,7 +52,7 @@ def test_create_and_get_run(self) -> None: """Test if create_run and get_run work correctly.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("Mock/mock", "v1.0.0") + run_id = state.create_run("Mock/mock", "v1.0.0", {"test_key": "test_value"}) # Execute run = state.get_run(run_id) @@ -62,6 +62,7 @@ def test_create_and_get_run(self) -> None: assert run.run_id == run_id assert run.fab_id == "Mock/mock" assert run.fab_version == "v1.0.0" + assert run.override_config["test_key"] == "test_value" def test_get_task_ins_empty(self) -> None: """Validate that a new state has no TaskIns.""" @@ -90,7 +91,7 @@ def test_store_task_ins_one(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -125,7 +126,7 @@ def test_store_and_delete_tasks(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_0 = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -199,7 +200,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -214,7 +215,7 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -228,7 +229,7 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -242,7 +243,7 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -259,7 +260,7 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -302,7 +303,7 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_id = uuid4() task_res = create_task_res( producer_node_id=0, @@ -323,7 +324,7 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute retrieved_node_ids = state.get_nodes(run_id) @@ -335,7 +336,7 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [] # Execute @@ -352,7 +353,7 @@ def test_create_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -368,7 +369,7 @@ def test_create_node_public_key_twice(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -390,7 +391,7 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10) # Execute @@ -405,7 +406,7 @@ def test_delete_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -422,7 +423,7 @@ def test_delete_node_public_key_none(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = 0 # Execute & Assert @@ -441,7 +442,7 @@ def test_delete_node_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute & Assert @@ -460,7 +461,7 @@ def test_get_node_id_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute state.create_node(ping_interval=10, public_key=public_key) @@ -475,7 +476,7 @@ def test_get_nodes_invalid_run_id(self) -> None: """Test retrieving all node_ids with invalid run_id.""" # Prepare state: State = self.state_factory() - state.create_run("mock/mock", "v1.0.0") + state.create_run("mock/mock", "v1.0.0", {}) invalid_run_id = 61016 state.create_node(ping_interval=10) @@ -489,7 +490,7 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -507,7 +508,7 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_res( producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) @@ -608,7 +609,7 @@ def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) @@ -627,7 +628,7 @@ def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id_0 = state.create_node(ping_interval=90) node_id_1 = state.create_node(ping_interval=30) # Create and store TaskIns diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 7c7a412a245b..91805dc5ed7b 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -209,7 +209,7 @@ def _main_loop( serverapp_th = None try: # Create run (with empty fab_id and fab_version) - run_id_ = state_factory.state().create_run("", "") + run_id_ = state_factory.state().create_run("", "", {}) if run_id: _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id)