Skip to content

Commit

Permalink
feat(framework) Introduce new client_fn signature passing the `Cont…
Browse files Browse the repository at this point in the history
…ext` (#3779)

Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
jafermarq and danieljanes authored Jul 13, 2024
1 parent ea8f940 commit 64dd325
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 44 deletions.
7 changes: 3 additions & 4 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from flwr.client.client import Client
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.typing import ClientFnExt
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
from flwr.common.address import parse_address
from flwr.common.constant import (
MISSING_EXTRA_REST,
Expand Down Expand Up @@ -138,7 +138,7 @@ class `flwr.client.Client` (default: None)
Starting an SSL-enabled gRPC client using system certificates:
>>> def client_fn(node_id: int, partition_id: Optional[int]):
>>> def client_fn(context: Context):
>>> return FlowerClient()
>>>
>>> start_client(
Expand Down Expand Up @@ -253,8 +253,7 @@ class `flwr.client.Client` (default: None)
if client_fn is None:
# Wrap `Client` instance in `client_fn`
def single_client_factory(
node_id: int, # pylint: disable=unused-argument
partition_id: Optional[int], # pylint: disable=unused-argument
context: Context, # pylint: disable=unused-argument
) -> Client:
if client is None: # Added this to keep mypy happy
raise ValueError(
Expand Down
34 changes: 27 additions & 7 deletions src/py/flwr/client/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,41 @@
from .typing import ClientAppCallable


def _alert_erroneous_client_fn() -> None:
raise ValueError(
"A `ClientApp` cannot make use of a `client_fn` that does "
"not have a signature in the form: `def client_fn(context: "
"Context)`. You can import the `Context` like this: "
"`from flwr.common import Context`"
)


def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
client_fn_args = inspect.signature(client_fn).parameters
first_arg = list(client_fn_args.keys())[0]

if len(client_fn_args) != 1:
_alert_erroneous_client_fn()

first_arg_type = client_fn_args[first_arg].annotation

if not all(key in client_fn_args for key in ["node_id", "partition_id"]):
if first_arg_type is str or first_arg == "cid":
# Warn previous signature for `client_fn` seems to be used
warn_deprecated_feature(
"`client_fn` now expects a signature `def client_fn(node_id: int, "
"partition_id: Optional[int])`.\nYou provided `client_fn` with signature: "
f"{dict(client_fn_args.items())}"
"`client_fn` now expects a signature `def client_fn(context: Context)`."
"The provided `client_fn` has signature: "
f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
" `from flwr.common import Context`"
)

# Wrap deprecated client_fn inside a function with the expected signature
def adaptor_fn(
node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
) -> Client:
return client_fn(str(partition_id)) # type: ignore
context: Context,
) -> Client: # pylint: disable=unused-argument
# If partition-id is defined, pass it. Else pass node_id that should
# always be defined during Context init.
cid = context.node_config.get("partition-id", context.node_id)
return client_fn(str(cid)) # type: ignore

return adaptor_fn

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def handle_legacy_message_from_msgtype(
client_fn: ClientFnExt, message: Message, context: Context
) -> Message:
"""Handle legacy message in the inner most mod."""
client = client_fn(message.metadata.dst_node_id, context.partition_id)
client = client_fn(context)

# Check if NumPyClient is returned
if isinstance(client, NumPyClient):
Expand Down
6 changes: 2 additions & 4 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import unittest
import uuid
from copy import copy
from typing import List, Optional
from typing import List

from flwr.client import Client
from flwr.client.typing import ClientFnExt
Expand Down Expand Up @@ -114,9 +114,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:


def _get_client_fn(client: Client) -> ClientFnExt:
def client_fn(
node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
) -> Client:
def client_fn(contex: Context) -> Client: # pylint: disable=unused-argument
return client

return client_fn
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
"""Custom types for Flower clients."""


from typing import Callable, Optional
from typing import Callable

from flwr.common import Context, Message

from .client import Client as Client

# Compatibility
ClientFn = Callable[[str], Client]
ClientFnExt = Callable[[int, Optional[int]], Client]
ClientFnExt = Callable[[Context], Client]

ClientAppCallable = Callable[[Message, Context], Message]
Mod = Callable[[Message, Context, ClientAppCallable], Message]
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
return {"result": result}


def get_dummy_client(
node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
) -> Client:
def get_dummy_client(context: Context) -> Client: # pylint: disable=unused-argument
"""Return a DummyClient converted to Client type."""
return DummyClient().to_client()

Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,9 @@ def start_vce(
node_states: Dict[int, NodeState] = {}
for node_id, partition_id in nodes_mapping.items():
node_states[node_id] = NodeState(
node_id=node_id, node_config={}, partition_id=partition_id
node_id=node_id,
node_config={"partition-id": str(partition_id)},
partition_id=None,
)

# Load backend config
Expand Down
17 changes: 10 additions & 7 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def _load_app() -> ClientApp:
self.app_fn = _load_app
self.actor_pool = actor_pool
self.proxy_state = NodeState(
node_id=node_id, node_config={}, partition_id=self.partition_id
node_id=node_id,
node_config={"partition-id": str(partition_id)},
partition_id=None,
)

def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
Expand All @@ -70,18 +72,19 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
# Register state
self.proxy_state.register_context(run_id=run_id)

# Retrieve state
state = self.proxy_state.retrieve_context(run_id=run_id)
# Retrieve context
context = self.proxy_state.retrieve_context(run_id=run_id)
partition_id_str = context.node_config["partition-id"]

try:
self.actor_pool.submit_client_job(
lambda a, a_fn, mssg, partition_id, state: a.run.remote(
a_fn, mssg, partition_id, state
lambda a, a_fn, mssg, partition_id, context: a.run.remote(
a_fn, mssg, partition_id, context
),
(self.app_fn, message, str(self.partition_id), state),
(self.app_fn, message, partition_id_str, context),
)
out_mssg, updated_context = self.actor_pool.get_client_result(
str(self.partition_id), timeout
partition_id_str, timeout
)

# Update state
Expand Down
32 changes: 17 additions & 15 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from math import pi
from random import shuffle
from typing import Dict, List, Optional, Tuple, Type
from typing import Dict, List, Tuple, Type

import ray

Expand All @@ -39,7 +39,10 @@
recordset_to_getpropertiesres,
)
from flwr.common.recordset_compat_test import _get_valid_getpropertiesins
from flwr.simulation.app import _create_node_id_to_partition_mapping
from flwr.simulation.app import (
NodeToPartitionMapping,
_create_node_id_to_partition_mapping,
)
from flwr.simulation.ray_transport.ray_actor import (
ClientAppActor,
VirtualClientEngineActor,
Expand All @@ -65,16 +68,16 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
return {"result": result}


def get_dummy_client(
node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
) -> Client:
def get_dummy_client(context: Context) -> Client:
"""Return a DummyClient converted to Client type."""
return DummyClient(node_id).to_client()
return DummyClient(context.node_id).to_client()


def prep(
actor_type: Type[VirtualClientEngineActor] = ClientAppActor,
) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]: # pragma: no cover
) -> Tuple[
List[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping
]: # pragma: no cover
"""Prepare ClientProxies and pool for tests."""
client_resources = {"num_cpus": 1, "num_gpus": 0.0}

Expand All @@ -101,15 +104,15 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]:
for node_id, partition_id in mapping.items()
]

return proxies, pool
return proxies, pool, mapping


def test_cid_consistency_one_at_a_time() -> None:
"""Test that ClientProxies get the result of client job they submit.
Submits one job and waits for completion, then submits the next, and so on.
"""
proxies, _ = prep()
proxies, _, _ = prep()

getproperties_ins = _get_valid_getpropertiesins()
recordset = getpropertiesins_to_recordset(getproperties_ins)
Expand Down Expand Up @@ -139,7 +142,7 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
All jobs are submitted at the same time. Then fetched one at a time. This also tests
NodeState (at each Proxy) and RunState basic functionality.
"""
proxies, _ = prep()
proxies, _, _ = prep()
run_id = 0

getproperties_ins = _get_valid_getpropertiesins()
Expand Down Expand Up @@ -186,9 +189,8 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:

def test_cid_consistency_without_proxies() -> None:
"""Test cid consistency of jobs submitted/retrieved to/from pool w/o ClientProxy."""
proxies, pool = prep()
num_clients = len(proxies)
node_ids = list(range(num_clients))
_, pool, mapping = prep()
node_ids = list(mapping.keys())

getproperties_ins = _get_valid_getpropertiesins()
recordset = getpropertiesins_to_recordset(getproperties_ins)
Expand Down Expand Up @@ -219,11 +221,11 @@ def _load_app() -> ClientApp:
message,
str(node_id),
Context(
node_id=0,
node_id=node_id,
node_config={},
state=RecordSet(),
run_config={},
partition_id=node_id,
partition_id=mapping[node_id],
),
),
)
Expand Down

0 comments on commit 64dd325

Please sign in to comment.