diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index bfe5147f78e1..fa17ba9a8481 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -319,7 +319,13 @@ def _on_backoff(retry_state: RetryState) -> None: on_backoff=_on_backoff, ) - node_state = NodeState(partition_id=partition_id) + # Empty dict (for now) + # This will be removed once users can pass node_config via flower-supernode + node_config: Dict[str, str] = {} + + # NodeState gets initialized when the first connection is established + node_state: Optional[NodeState] = None + runs: Dict[int, Run] = {} while not app_state_tracker.interrupt: @@ -334,9 +340,33 @@ def _on_backoff(retry_state: RetryState) -> None: ) as conn: receive, send, create_node, delete_node, get_run = conn - # Register node - if create_node is not None: - create_node() # pylint: disable=not-callable + # Register node when connecting the first time + if node_state is None: + if create_node is None: + if transport not in ["grpc-bidi", None]: + raise NotImplementedError( + "All transports except `grpc-bidi` require " + "an implementation for `create_node()`.'" + ) + # gRPC-bidi doesn't have the concept of node_id, + # so we set it to -1 + node_state = NodeState( + node_id=-1, + node_config={}, + partition_id=partition_id, + ) + else: + # Call create_node fn to register node + node_id: Optional[int] = ( # pylint: disable=assignment-from-none + create_node() + ) # pylint: disable=not-callable + if node_id is None: + raise ValueError("Node registration failed") + node_state = NodeState( + node_id=node_id, + node_config=node_config, + partition_id=partition_id, + ) app_state_tracker.register_signal_handler() while not app_state_tracker.interrupt: @@ -580,7 +610,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index 971b630e470b..80a5cf0b4656 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -44,7 +44,7 @@ def grpc_adapter( # pylint: disable=R0913 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 3e9f261c1ecf..a6417106d51b 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -72,7 +72,7 @@ def grpc_connection( # pylint: disable=R0913, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 8062ce28fcc7..e573df6854bc 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -79,7 +79,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] @@ -176,7 +176,7 @@ def ping() -> None: if not ping_stop_event.is_set(): ping_stop_event.wait(next_interval) - def create_node() -> None: + def create_node() -> Optional[int]: """Set create_node.""" # Call FleetAPI create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL) @@ -189,6 +189,7 @@ def create_node() -> None: nonlocal node, ping_thread node = cast(Node, create_node_response.node) ping_thread = start_ping_loop(ping, ping_stop_event) + return node.node_id def delete_node() -> None: """Set delete_node.""" diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 9ce4c9620c43..96de7ce0c2cb 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -145,7 +145,7 @@ def test_client_without_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet(), run_config={}), + context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), ) # Assert @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(state=RecordSet(), run_config={}), + context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), ) # Assert diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index 5e4c4411e1f7..2832576fb4fc 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -73,7 +73,12 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: def _make_ctxt() -> Context: cfg = ConfigsRecord(SecAggPlusState().to_dict()) - return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg}), run_config={}) + return Context( + node_id=123, + node_config={}, + state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}), + run_config={}, + ) def _make_set_state_fn( diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index 7a1dd8988399..a5bbd0a0bb4d 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -104,7 +104,7 @@ def test_multiple_mods(self) -> None: state = RecordSet() state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0}) - context = Context(state=state, run_config={}) + context = Context(node_id=0, node_config={}, state=state, run_config={}) message = _get_dummy_flower_message() # Execute @@ -129,7 +129,7 @@ def test_filter(self) -> None: # Prepare footprint: List[str] = [] mock_app = make_mock_app("app", footprint) - context = Context(state=RecordSet(), run_config={}) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) message = _get_dummy_flower_message() def filter_mod( diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index 2b090eba9720..d0a349b0cae0 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional +from typing import Dict, Optional from flwr.common import Context, RecordSet from flwr.common.config import get_fused_config @@ -35,8 +35,11 @@ class RunInfo: class NodeState: """State of a node where client nodes execute runs.""" - def __init__(self, partition_id: Optional[int]) -> None: - self._meta: Dict[str, Any] = {} # holds metadata about the node + def __init__( + self, node_id: int, node_config: Dict[str, str], partition_id: Optional[int] + ) -> None: + self.node_id = node_id + self.node_config = node_config self.run_infos: Dict[int, RunInfo] = {} self._partition_id = partition_id @@ -52,6 +55,8 @@ def register_context( self.run_infos[run_id] = RunInfo( initial_run_config=initial_run_config, context=Context( + node_id=self.node_id, + node_config=self.node_config, state=RecordSet(), run_config=initial_run_config.copy(), partition_id=self._partition_id, diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index effd64a3ae7a..8d7971fa5280 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None: expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} # NodeState - node_state = NodeState(partition_id=None) + node_state = NodeState(node_id=0, node_config={}, partition_id=None) for task in tasks: run_id = task.run_id diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 0efa5731ae51..3e81969d898c 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -90,7 +90,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Tuple[ Callable[[], Optional[Message]], Callable[[Message], None], - Optional[Callable[[], None]], + Optional[Callable[[], Optional[int]]], Optional[Callable[[], None]], Optional[Callable[[int], Run]], ] @@ -237,19 +237,20 @@ def ping() -> None: if not ping_stop_event.is_set(): ping_stop_event.wait(next_interval) - def create_node() -> None: + def create_node() -> Optional[int]: """Set create_node.""" req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL) # Send the request res = _request(req, CreateNodeResponse, PATH_CREATE_NODE) if res is None: - return + return None # Remember the node and the ping-loop thread nonlocal node, ping_thread node = res.node ping_thread = start_ping_loop(ping, ping_stop_event) + return node.node_id def delete_node() -> None: """Set delete_node.""" diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index 8120723ce9e9..e65300278c84 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -27,6 +27,11 @@ class Context: Parameters ---------- + node_id : int + The ID that identifies the node. + node_config : Dict[str, str] + A config (key/value mapping) unique to the node and independent of the + `run_config`. This config persists across all runs this node participates in. state : RecordSet Holds records added by the entity in a given run and that will stay local. This means that the data it holds will never leave the system it's running from. @@ -44,16 +49,22 @@ class Context: simulation or proto typing setups. """ + node_id: int + node_config: Dict[str, str] state: RecordSet - partition_id: Optional[int] run_config: Dict[str, str] + partition_id: Optional[int] - def __init__( + def __init__( # pylint: disable=too-many-arguments self, + node_id: int, + node_config: Dict[str, str], state: RecordSet, run_config: Dict[str, str], partition_id: Optional[int] = None, ) -> None: + self.node_id = node_id + self.node_config = node_config self.state = state self.run_config = run_config self.partition_id = partition_id diff --git a/src/py/flwr/server/compat/legacy_context.py b/src/py/flwr/server/compat/legacy_context.py index 9e120c824103..ee09d79012dc 100644 --- a/src/py/flwr/server/compat/legacy_context.py +++ b/src/py/flwr/server/compat/legacy_context.py @@ -52,4 +52,4 @@ def __init__( self.strategy = strategy self.client_manager = client_manager self.history = History() - super().__init__(state, run_config={}) + super().__init__(node_id=0, node_config={}, state=state, run_config={}) diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index b4697e99913f..4cc25feb7e0e 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -78,7 +78,9 @@ def _load() -> ServerApp: server_app = _load() # Initialize Context - context = Context(state=RecordSet(), run_config=server_app_run_config) + context = Context( + node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config + ) # Call ServerApp server_app(driver=driver, context=context) diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py index 7de8774d4c81..b0672b3202ed 100644 --- a/src/py/flwr/server/server_app_test.py +++ b/src/py/flwr/server/server_app_test.py @@ -29,7 +29,7 @@ def test_server_app_custom_mode() -> None: # Prepare app = ServerApp() driver = MagicMock() - context = Context(state=RecordSet(), run_config={}) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) called = {"called": False} diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index 287983003f8c..da4390194d05 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -120,7 +120,7 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: ) # Construct emtpy Context - context = Context(state=RecordSet(), run_config={}) + context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) # Expected output expected_output = pi * mult_factor diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 3c0b36e1ca3c..134fd34ed8f0 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -284,7 +284,9 @@ def start_vce( # Construct mapping of NodeStates node_states: Dict[int, NodeState] = {} for node_id, partition_id in nodes_mapping.items(): - node_states[node_id] = NodeState(partition_id=partition_id) + node_states[node_id] = NodeState( + node_id=node_id, node_config={}, partition_id=partition_id + ) # Load backend config log(DEBUG, "Supported backends: %s", list(supported_backends.keys())) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 31bc22c84bd5..f2684016048e 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -59,7 +59,9 @@ def _load_app() -> ClientApp: self.app_fn = _load_app self.actor_pool = actor_pool - self.proxy_state = NodeState(partition_id=self.partition_id) + self.proxy_state = NodeState( + node_id=node_id, node_config={}, partition_id=self.partition_id + ) def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: """Sumbit a message to the ActorPool.""" diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 83f6cfe05313..8831e5f475ea 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -218,7 +218,13 @@ def _load_app() -> ClientApp: _load_app, message, str(node_id), - Context(state=RecordSet(), run_config={}, partition_id=node_id), + Context( + node_id=0, + node_config={}, + state=RecordSet(), + run_config={}, + partition_id=node_id, + ), ), )