Skip to content

Commit

Permalink
num-paritions in old simulationengin context
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jul 13, 2024
1 parent 432e78b commit c03e68c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/py/flwr/simulation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def update_resources(f_stop: threading.Event) -> None:
client_fn=client_fn,
node_id=node_id,
partition_id=partition_id,
num_partitions=num_clients,
actor_pool=pool,
)
initialized_server.client_manager().register(client=client_proxy)
Expand Down
15 changes: 12 additions & 3 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from flwr.client.client_app import ClientApp
from flwr.client.node_state import NodeState
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
from flwr.common.constant import PARTITION_ID_KEY, MessageType, MessageTypeLegacy
from flwr.common.constant import (
NUM_PARTITIONS_KEY,
PARTITION_ID_KEY,
MessageType,
MessageTypeLegacy,
)
from flwr.common.logger import log
from flwr.common.recordset_compat import (
evaluateins_to_recordset,
Expand All @@ -43,11 +48,12 @@
class RayActorClientProxy(ClientProxy):
"""Flower client proxy which delegates work using Ray."""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
client_fn: ClientFnExt,
node_id: int,
partition_id: int,
num_partitions: int,
actor_pool: VirtualClientEngineActorPool,
):
super().__init__(cid=str(node_id))
Expand All @@ -61,7 +67,10 @@ def _load_app() -> ClientApp:
self.actor_pool = actor_pool
self.proxy_state = NodeState(
node_id=node_id,
node_config={PARTITION_ID_KEY: str(partition_id)},
node_config={
PARTITION_ID_KEY: str(partition_id),
NUM_PARTITIONS_KEY: str(num_partitions),
},
)

def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Metadata,
Scalar,
)
from flwr.common.constant import PARTITION_ID_KEY
from flwr.common.constant import NUM_PARTITIONS_KEY, PARTITION_ID_KEY
from flwr.common.recordset_compat import (
getpropertiesins_to_recordset,
recordset_to_getpropertiesres,
Expand Down Expand Up @@ -100,6 +100,7 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]:
client_fn=get_dummy_client,
node_id=node_id,
partition_id=partition_id,
num_partitions=num_proxies,
actor_pool=pool,
)
for node_id, partition_id in mapping.items()
Expand Down Expand Up @@ -198,7 +199,10 @@ def test_cid_consistency_without_proxies() -> None:
for node_id, partition_id in mapping.items():
node_states[node_id] = NodeState(
node_id=node_id,
node_config={PARTITION_ID_KEY: str(partition_id)},
node_config={
PARTITION_ID_KEY: str(partition_id),
NUM_PARTITIONS_KEY: str(len(node_ids)),
},
)

getproperties_ins = _get_valid_getpropertiesins()
Expand Down

0 comments on commit c03e68c

Please sign in to comment.