diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 380185ed26e7..700ac85f341f 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -348,7 +348,6 @@ def _on_backoff(retry_state: RetryState) -> None: node_state = NodeState( node_id=-1, node_config={}, - partition_id=None, ) else: # Call create_node fn to register node @@ -360,7 +359,6 @@ def _on_backoff(retry_state: RetryState) -> None: node_state = NodeState( node_id=node_id, node_config=node_config, - partition_id=None, ) app_state_tracker.register_signal_handler() diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index d0a349b0cae0..393ca4564a35 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -36,12 +36,13 @@ class NodeState: """State of a node where client nodes execute runs.""" def __init__( - self, node_id: int, node_config: Dict[str, str], partition_id: Optional[int] + self, + node_id: int, + node_config: Dict[str, str], ) -> None: self.node_id = node_id self.node_config = node_config self.run_infos: Dict[int, RunInfo] = {} - self._partition_id = partition_id def register_context( self, @@ -59,7 +60,6 @@ def register_context( node_config=self.node_config, state=RecordSet(), run_config=initial_run_config.copy(), - partition_id=self._partition_id, ), ) diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index 8d7971fa5280..26ac4fea6855 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None: expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} # NodeState - node_state = NodeState(node_id=0, node_config={}, partition_id=None) + node_state = NodeState(node_id=0, node_config={}) for task in tasks: run_id = task.run_id diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index f14959589458..72256a62add7 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -57,6 +57,9 @@ FAB_CONFIG_FILE = "pyproject.toml" FLWR_HOME = "FLWR_HOME" +# Constants entries in Node config for Simulation +PARTITION_ID_KEY = "partition-id" +NUM_PARTITIONS_KEY = "num-partitions" GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit" diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index e65300278c84..4da52ba44481 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -16,7 +16,7 @@ from dataclasses import dataclass -from typing import Dict, Optional +from typing import Dict from .record import RecordSet @@ -43,17 +43,12 @@ class Context: A config (key/value mapping) held by the entity in a given run and that will stay local. It can be used at any point during the lifecycle of this entity (e.g. across multiple rounds) - partition_id : Optional[int] (default: None) - An index that specifies the data partition that the ClientApp using this Context - object should make use of. Setting this attribute is better suited for - simulation or proto typing setups. """ node_id: int node_config: Dict[str, str] state: RecordSet run_config: Dict[str, str] - partition_id: Optional[int] def __init__( # pylint: disable=too-many-arguments self, @@ -61,10 +56,8 @@ def __init__( # pylint: disable=too-many-arguments node_config: Dict[str, str], state: RecordSet, run_config: Dict[str, str], - partition_id: Optional[int] = None, ) -> None: self.node_id = node_id self.node_config = node_config self.state = state self.run_config = run_config - self.partition_id = partition_id diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index 0d2f4d193f0b..0ab29a234f88 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -21,6 +21,7 @@ import ray from flwr.client.client_app import ClientApp +from flwr.common.constant import PARTITION_ID_KEY from flwr.common.context import Context from flwr.common.logger import log from flwr.common.message import Message @@ -168,7 +169,7 @@ def process_message( Return output message and updated context. """ - partition_id = context.partition_id + partition_id = context.node_config[PARTITION_ID_KEY] try: # Submit a task to the pool diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index 3abdac7a232b..a38cff96ceef 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -23,6 +23,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.node_state import NodeState from flwr.common import ( DEFAULT_TTL, Config, @@ -32,9 +33,9 @@ Message, MessageTypeLegacy, Metadata, - RecordSet, Scalar, ) +from flwr.common.constant import PARTITION_ID_KEY from flwr.common.object_ref import load_app from flwr.common.recordset_compat import getpropertiesins_to_recordset from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig @@ -101,12 +102,13 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: # Construct a Message mult_factor = 2024 + run_id = 0 getproperties_ins = GetPropertiesIns(config={"factor": mult_factor}) recordset = getpropertiesins_to_recordset(getproperties_ins) message = Message( content=recordset, metadata=Metadata( - run_id=0, + run_id=run_id, message_id="", group_id="", src_node_id=0, @@ -117,8 +119,10 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: ), ) - # Construct emtpy Context - context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) + # Construct NodeState and retrieve context + node_state = NodeState(node_id=run_id, node_config={PARTITION_ID_KEY: str(0)}) + node_state.register_context(run_id=run_id) + context = node_state.retrieve_context(run_id=run_id) # Expected output expected_output = pi * mult_factor diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 422148489b4a..cd30c40167c5 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -29,7 +29,12 @@ from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState -from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode +from flwr.common.constant import ( + NUM_PARTITIONS_KEY, + PARTITION_ID_KEY, + PING_MAX_INTERVAL, + ErrorCode, +) from flwr.common.logger import log from flwr.common.message import Error from flwr.common.object_ref import load_app @@ -73,7 +78,7 @@ def worker( task_ins: TaskIns = taskins_queue.get(timeout=1.0) node_id = task_ins.task.consumer.node_id - # Register and retrieve runstate + # Register and retrieve context node_states[node_id].register_context(run_id=task_ins.run_id) context = node_states[node_id].retrieve_context(run_id=task_ins.run_id) @@ -283,11 +288,15 @@ def start_vce( # Construct mapping of NodeStates node_states: Dict[int, NodeState] = {} + # Number of unique partitions + num_partitions = len(set(nodes_mapping.values())) for node_id, partition_id in nodes_mapping.items(): node_states[node_id] = NodeState( node_id=node_id, - node_config={"partition-id": str(partition_id)}, - partition_id=None, + node_config={ + PARTITION_ID_KEY: str(partition_id), + NUM_PARTITIONS_KEY: str(num_partitions), + }, ) # Load backend config diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 446b0bdeba38..fc52267f9a04 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -327,6 +327,7 @@ def update_resources(f_stop: threading.Event) -> None: client_fn=client_fn, node_id=node_id, partition_id=partition_id, + num_partitions=num_clients, actor_pool=pool, ) initialized_server.client_manager().register(client=client_proxy) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index b62e04aeed79..895272c2fd79 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -24,7 +24,12 @@ from flwr.client.client_app import ClientApp from flwr.client.node_state import NodeState from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet -from flwr.common.constant import MessageType, MessageTypeLegacy +from flwr.common.constant import ( + NUM_PARTITIONS_KEY, + PARTITION_ID_KEY, + MessageType, + MessageTypeLegacy, +) from flwr.common.logger import log from flwr.common.recordset_compat import ( evaluateins_to_recordset, @@ -43,11 +48,12 @@ class RayActorClientProxy(ClientProxy): """Flower client proxy which delegates work using Ray.""" - def __init__( + def __init__( # pylint: disable=too-many-arguments self, client_fn: ClientFnExt, node_id: int, partition_id: int, + num_partitions: int, actor_pool: VirtualClientEngineActorPool, ): super().__init__(cid=str(node_id)) @@ -61,8 +67,10 @@ def _load_app() -> ClientApp: self.actor_pool = actor_pool self.proxy_state = NodeState( node_id=node_id, - node_config={"partition-id": str(partition_id)}, - partition_id=None, + node_config={ + PARTITION_ID_KEY: str(partition_id), + NUM_PARTITIONS_KEY: str(num_partitions), + }, ) def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: @@ -74,7 +82,7 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: # Retrieve context context = self.proxy_state.retrieve_context(run_id=run_id) - partition_id_str = context.node_config["partition-id"] + partition_id_str = context.node_config[PARTITION_ID_KEY] try: self.actor_pool.submit_client_job( diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 1d44ea1d8d2b..62e0cfd61c99 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -23,6 +23,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp +from flwr.client.node_state import NodeState from flwr.common import ( DEFAULT_TTL, Config, @@ -31,9 +32,9 @@ Message, MessageTypeLegacy, Metadata, - RecordSet, Scalar, ) +from flwr.common.constant import NUM_PARTITIONS_KEY, PARTITION_ID_KEY from flwr.common.recordset_compat import ( getpropertiesins_to_recordset, recordset_to_getpropertiesres, @@ -99,6 +100,7 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]: client_fn=get_dummy_client, node_id=node_id, partition_id=partition_id, + num_partitions=num_proxies, actor_pool=pool, ) for node_id, partition_id in mapping.items() @@ -192,6 +194,17 @@ def test_cid_consistency_without_proxies() -> None: _, pool, mapping = prep() node_ids = list(mapping.keys()) + # register node states + node_states: Dict[int, NodeState] = {} + for node_id, partition_id in mapping.items(): + node_states[node_id] = NodeState( + node_id=node_id, + node_config={ + PARTITION_ID_KEY: str(partition_id), + NUM_PARTITIONS_KEY: str(len(node_ids)), + }, + ) + getproperties_ins = _get_valid_getpropertiesins() recordset = getpropertiesins_to_recordset(getproperties_ins) @@ -200,11 +213,12 @@ def _load_app() -> ClientApp: # submit all jobs (collect later) shuffle(node_ids) + run_id = 0 for node_id in node_ids: message = Message( content=recordset, metadata=Metadata( - run_id=0, + run_id=run_id, message_id="", group_id=str(0), src_node_id=0, @@ -214,26 +228,20 @@ def _load_app() -> ClientApp: message_type=MessageTypeLegacy.GET_PROPERTIES, ), ) + # register and retrieve context + node_states[node_id].register_context(run_id=run_id) + context = node_states[node_id].retrieve_context(run_id=run_id) + partition_id_str = context.node_config[PARTITION_ID_KEY] pool.submit_client_job( lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state), - ( - _load_app, - message, - str(node_id), - Context( - node_id=node_id, - node_config={}, - state=RecordSet(), - run_config={}, - partition_id=mapping[node_id], - ), - ), + (_load_app, message, partition_id_str, context), ) # fetch results one at a time shuffle(node_ids) for node_id in node_ids: - message_out, _ = pool.get_client_result(str(node_id), timeout=None) + partition_id_str = str(mapping[node_id]) + message_out, _ = pool.get_client_result(partition_id_str, timeout=None) res = recordset_to_getpropertiesres(message_out.content) assert node_id * pi == res.properties["result"]