From 8587e73442c088d6d45122b7f413519a6fe37c94 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sun, 7 Jul 2024 14:46:47 +0200 Subject: [PATCH] feat(framework) Add override config to Run --- src/py/flwr/client/app.py | 25 +++++++--- .../client/grpc_adapter_client/connection.py | 3 +- src/py/flwr/client/grpc_client/connection.py | 3 +- .../client/grpc_rere_client/connection.py | 12 +++-- src/py/flwr/client/rest_client/connection.py | 14 ++++-- src/py/flwr/server/driver/grpc_driver.py | 1 + .../server/driver/inmemory_driver_test.py | 10 ++-- .../superlink/driver/driver_servicer.py | 6 ++- .../grpc_rere/server_interceptor_test.py | 4 +- .../superlink/fleet/vce/vce_api_test.py | 4 +- .../server/superlink/state/in_memory_state.py | 12 ++++- .../server/superlink/state/sqlite_state.py | 23 +++++++-- src/py/flwr/server/superlink/state/state.py | 9 +++- .../flwr/server/superlink/state/state_test.py | 47 ++++++++++--------- src/py/flwr/simulation/run_simulation.py | 2 +- 15 files changed, 120 insertions(+), 55 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index d2d5a79f32f3..36757da18960 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -19,6 +19,7 @@ import time from dataclasses import dataclass from logging import DEBUG, ERROR, INFO, WARN +from pathlib import Path from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union from cryptography.hazmat.primitives.asymmetric import ec @@ -29,6 +30,7 @@ from flwr.client.typing import ClientFnExt from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event from flwr.common.address import parse_address +from flwr.common.config import get_fused_config from flwr.common.constant import ( MISSING_EXTRA_REST, TRANSPORT_TYPE_GRPC_ADAPTER, @@ -41,6 +43,7 @@ from flwr.common.logger import log, warn_deprecated_feature from flwr.common.message import Error from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential +from flwr.common.typing import Run from .grpc_adapter_client.connection import grpc_adapter from .grpc_client.connection import grpc_connection @@ -192,6 +195,7 @@ def _start_client_internal( max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, partition_id: Optional[int] = None, + flwr_dir: Optional[Path] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -235,9 +239,16 @@ class `flwr.client.Client` (default: None) The maximum duration before the client stops trying to connect to the server in case of connection error. If set to None, there is no limit to the total time. - partitioni_id: Optional[int] (default: None) + partition_id: Optional[int] (default: None) The data partition index associated with this node. Better suited for prototyping purposes. + flwr_dir: Optional[Path] (default: None) + The path containing installed Flower Apps. + By default, this value is equal to: + + - `$FLWR_HOME/` if `$FLWR_HOME` is defined + - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined + - `$HOME/.flwr/` in all other cases """ if insecure is None: insecure = root_certificates is None @@ -315,8 +326,7 @@ def _on_backoff(retry_state: RetryState) -> None: ) node_state = NodeState(partition_id=partition_id) - # run_id -> (fab_id, fab_version) - run_info: Dict[int, Tuple[str, str]] = {} + run_info: Dict[int, Run] = {} while not app_state_tracker.interrupt: sleep_duration: int = 0 @@ -371,13 +381,14 @@ def _on_backoff(retry_state: RetryState) -> None: run_info[run_id] = get_run(run_id) # If get_run is None, i.e., in grpc-bidi mode else: - run_info[run_id] = ("", "") + run_info[run_id] = Run(run_id, "", "", {}) # Register context for this run node_state.register_context(run_id=run_id) # Retrieve context for this run context = node_state.retrieve_context(run_id=run_id) + context.run_config = get_fused_config(run_info[run_id], flwr_dir) # Create an error reply message that will never be used to prevent # the used-before-assignment linting error @@ -388,7 +399,9 @@ def _on_backoff(retry_state: RetryState) -> None: # Handle app loading and task message try: # Load ClientApp instance - client_app: ClientApp = load_client_app_fn(*run_info[run_id]) + client_app: ClientApp = load_client_app_fn( + run_info[run_id].fab_id, run_info[run_id].fab_version + ) # Execute ClientApp reply_message = client_app(message=message, context=context) @@ -573,7 +586,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ], ], diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index e4e32b3accd0..971b630e470b 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -27,6 +27,7 @@ from flwr.common.logger import log from flwr.common.message import Message from flwr.common.retry_invoker import RetryInvoker +from flwr.common.typing import Run @contextmanager @@ -45,7 +46,7 @@ def grpc_adapter( # pylint: disable=R0913 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server via GrpcAdapter. diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 8c049861c672..3e9f261c1ecf 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -38,6 +38,7 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.common.retry_invoker import RetryInvoker +from flwr.common.typing import Run from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, Reason, @@ -73,7 +74,7 @@ def grpc_connection( # pylint: disable=R0913, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Establish a gRPC connection to a gRPC server. diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 34dc0e417383..8062ce28fcc7 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -41,6 +41,7 @@ from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, @@ -80,7 +81,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -266,7 +267,7 @@ def send(message: Message) -> None: # Cleanup metadata = None - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Call FleetAPI get_run_request = GetRunRequest(run_id=run_id) get_run_response: GetRunResponse = retry_invoker.invoke( @@ -275,7 +276,12 @@ def get_run(run_id: int) -> Tuple[str, str]: ) # Return fab_id and fab_version - return get_run_response.run.fab_id, get_run_response.run.fab_version + return Run( + run_id, + get_run_response.run.fab_id, + get_run_response.run.fab_version, + dict(get_run_response.run.override_config.items()), + ) try: # Yield methods diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index db5bd7eb6770..0efa5731ae51 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -41,6 +41,7 @@ from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -91,7 +92,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Callable[[Message], None], Optional[Callable[[], None]], Optional[Callable[[], None]], - Optional[Callable[[int], Tuple[str, str]]], + Optional[Callable[[int], Run]], ] ]: """Primitives for request/response-based interaction with a server. @@ -344,16 +345,21 @@ def send(message: Message) -> None: res.results, # pylint: disable=no-member ) - def get_run(run_id: int) -> Tuple[str, str]: + def get_run(run_id: int) -> Run: # Construct the request req = GetRunRequest(run_id=run_id) # Send the request res = _request(req, GetRunResponse, PATH_GET_RUN) if res is None: - return "", "" + return Run(run_id, "", "", {}) - return res.run.fab_id, res.run.fab_version + return Run( + run_id, + res.run.fab_id, + res.run.fab_version, + dict(res.run.override_config.items()), + ) try: # Yield methods diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index e614df659e3f..ae4b3d2519fb 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -206,6 +206,7 @@ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]: run_id=res.run.run_id, fab_id=res.run.fab_id, fab_version=res.run.fab_version, + override_config=dict(res.run.override_config.items()), ) return self.stub, self._run.run_id diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 0cc1c5a53e13..d0f32e830f7d 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -86,7 +86,10 @@ def setUp(self) -> None: for _ in range(self.num_nodes) ] self.state.get_run.return_value = Run( - run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" + run_id=61016, + fab_id="mock/mock", + fab_version="v1.0.0", + override_config={"test_key": "test_value"}, ) state_factory = MagicMock(state=lambda: self.state) self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory) @@ -98,6 +101,7 @@ def test_get_run(self) -> None: self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") + self.assertEqual(self.driver.run.override_config["test_key"], "test_value") def test_get_nodes(self) -> None: """Test retrieval of nodes.""" @@ -223,7 +227,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: # Prepare state = StateFactory("").state() self.driver = InMemoryDriver( - state.create_run("", ""), MagicMock(state=lambda: state) + state.create_run("", "", {}), MagicMock(state=lambda: state) ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, SqliteState) @@ -249,7 +253,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: # Prepare state_factory = StateFactory(":flwr-in-memory-state:") state = state_factory.state() - self.driver = InMemoryDriver(state.create_run("", ""), state_factory) + self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, InMemoryState) diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 03128f02158e..7f8ded3bdb85 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -69,7 +69,11 @@ def CreateRun( """Create run ID.""" log(DEBUG, "DriverServicer.CreateRun") state: State = self.state_factory.state() - run_id = state.create_run(request.fab_id, request.fab_version) + run_id = state.create_run( + request.fab_id, + request.fab_version, + dict(request.override_config.items()), + ) return CreateRunResponse(run_id=run_id) def PushTaskIns( diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 01499102b7d8..798e71435585 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) shared_secret = generate_shared_key( self._client_private_key, self._server_public_key @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._client_public_key) ) - run_id = self.state.create_run("", "") + run_id = self.state.create_run("", "", {}) request = GetRunRequest(run_id=run_id) client_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(client_private_key, self._server_public_key) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index df9f2cc96f95..c0bf506fd2b6 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -82,7 +82,9 @@ def register_messages_into_state( ) -> Dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore - state.run_ids[run_id] = Run(run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0") + state.run_ids[run_id] = Run( + run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={} + ) # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 5a4e4eb0fd9a..bc4bd4478a23 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -275,7 +275,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" return self.public_key_to_node_id.get(client_public_key) - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id with self.lock: @@ -283,7 +288,10 @@ def create_run(self, fab_id: str, fab_version: str) -> int: if run_id not in self.run_ids: self.run_ids[run_id] = Run( - run_id=run_id, fab_id=fab_id, fab_version=fab_version + run_id=run_id, + fab_id=fab_id, + fab_version=fab_version, + override_config=override_config, ) return run_id log(ERROR, "Unexpected run creation failure.") diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 725f7c2dff4b..49f40653750e 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -18,6 +18,7 @@ import re import sqlite3 import time +from ast import literal_eval from logging import DEBUG, ERROR from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast from uuid import UUID, uuid4 @@ -63,7 +64,8 @@ CREATE TABLE IF NOT EXISTS run( run_id INTEGER UNIQUE, fab_id TEXT, - fab_version TEXT + fab_version TEXT, + overrides TEXT ); """ @@ -613,7 +615,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: return node_id return None - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) @@ -622,8 +629,11 @@ def create_run(self, fab_id: str, fab_version: str) -> int: query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" # If run_id does not exist if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: - query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);" - self.query(query, (run_id, fab_id, fab_version)) + query = ( + "INSERT INTO run (run_id, fab_id, fab_version, overrides)" + "VALUES (?, ?, ?, ?);" + ) + self.query(query, (run_id, fab_id, fab_version, str(override_config))) return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -687,7 +697,10 @@ def get_run(self, run_id: int) -> Optional[Run]: try: row = self.query(query, (run_id,))[0] return Run( - run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"] + run_id=run_id, + fab_id=row["fab_id"], + fab_version=row["fab_version"], + override_config=literal_eval(row["overrides"]), ) except sqlite3.IntegrityError: log(ERROR, "`run_id` does not exist.") diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 65e2c63cab69..c93f6ba756b8 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -16,7 +16,7 @@ import abc -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set from uuid import UUID from flwr.common.typing import Run @@ -157,7 +157,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `client_public_keys`.""" @abc.abstractmethod - def create_run(self, fab_id: str, fab_version: str) -> int: + def create_run( + self, + fab_id: str, + fab_version: str, + override_config: Dict[str, str], + ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 373202d5cde6..5f0d23ffc4d8 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -52,7 +52,7 @@ def test_create_and_get_run(self) -> None: """Test if create_run and get_run work correctly.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("Mock/mock", "v1.0.0") + run_id = state.create_run("Mock/mock", "v1.0.0", {"test_key": "test_value"}) # Execute run = state.get_run(run_id) @@ -62,6 +62,7 @@ def test_create_and_get_run(self) -> None: assert run.run_id == run_id assert run.fab_id == "Mock/mock" assert run.fab_version == "v1.0.0" + assert run.override_config["test_key"] == "test_value" def test_get_task_ins_empty(self) -> None: """Validate that a new state has no TaskIns.""" @@ -90,7 +91,7 @@ def test_store_task_ins_one(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -125,7 +126,7 @@ def test_store_and_delete_tasks(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_0 = create_task_ins( consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) @@ -199,7 +200,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -214,7 +215,7 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -228,7 +229,7 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -242,7 +243,7 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -259,7 +260,7 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute @@ -302,7 +303,7 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_ins_id = uuid4() task_res = create_task_res( producer_node_id=0, @@ -323,7 +324,7 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute retrieved_node_ids = state.get_nodes(run_id) @@ -335,7 +336,7 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [] # Execute @@ -352,7 +353,7 @@ def test_create_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -368,7 +369,7 @@ def test_create_node_public_key_twice(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -390,7 +391,7 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10) # Execute @@ -405,7 +406,7 @@ def test_delete_node_public_key(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -422,7 +423,7 @@ def test_delete_node_public_key_none(self) -> None: # Prepare state: State = self.state_factory() public_key = b"mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = 0 # Execute & Assert @@ -441,7 +442,7 @@ def test_delete_node_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute & Assert @@ -460,7 +461,7 @@ def test_get_node_id_wrong_public_key(self) -> None: state: State = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) # Execute state.create_node(ping_interval=10, public_key=public_key) @@ -475,7 +476,7 @@ def test_get_nodes_invalid_run_id(self) -> None: """Test retrieving all node_ids with invalid run_id.""" # Prepare state: State = self.state_factory() - state.create_run("mock/mock", "v1.0.0") + state.create_run("mock/mock", "v1.0.0", {}) invalid_run_id = 61016 state.create_node(ping_interval=10) @@ -489,7 +490,7 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -507,7 +508,7 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) task_0 = create_task_res( producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) @@ -608,7 +609,7 @@ def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) @@ -627,7 +628,7 @@ def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare state: State = self.state_factory() - run_id = state.create_run("mock/mock", "v1.0.0") + run_id = state.create_run("mock/mock", "v1.0.0", {}) node_id_0 = state.create_node(ping_interval=90) node_id_1 = state.create_node(ping_interval=30) # Create and store TaskIns diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 7c7a412a245b..91805dc5ed7b 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -209,7 +209,7 @@ def _main_loop( serverapp_th = None try: # Create run (with empty fab_id and fab_version) - run_id_ = state_factory.state().create_run("", "") + run_id_ = state_factory.state().create_run("", "", {}) if run_id: _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id)