Skip to content

Commit

Permalink
updated logic for new client_fn wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jul 12, 2024
1 parent cd4df61 commit 72a7b04
Showing 1 changed file with 37 additions and 15 deletions.
52 changes: 37 additions & 15 deletions src/py/flwr/client/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,47 @@
from .typing import ClientAppCallable


def _alert_erroneous_client_fn() -> None:
raise ValueError(
"A `ClientApp` cannot make use of a `client_fn` that does "
"not have a signature in the form: `def client_fn(context: "
"Context)`. You can import the `Context` type from `flwr.common`."
)


def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
client_fn_args = inspect.signature(client_fn).parameters
first_arg = list(client_fn_args.keys())[0]

if len(client_fn_args) != 1 or client_fn_args[first_arg].annotation is not Context:
warn_deprecated_feature(
"`client_fn` now expects a signature `def client_fn(context: Context)`."
"\The provided `client_fn` has signature: "
f"{dict(client_fn_args.items())}"
)

# Wrap depcreated client_fn inside a function with the expected signature
def adaptor_fn(context: Context) -> Client: # pylint: disable=unused-argument
# if patition-id is defined, pass it. Else pass node_id that should always
# be defined during Context init.
cid = context.node_config.get("partition-id", context.node_id)
return client_fn(str(cid)) # type: ignore

return adaptor_fn
if len(client_fn_args) == 1:

first_arg_type = client_fn_args[first_arg].annotation

if first_arg_type is str:
# Warn previous signature for `client_fn` seems to be used
warn_deprecated_feature(
"`client_fn` now expects a signature `def client_fn(context: Context)`."
"The provided `client_fn` has signature: "
f"{dict(client_fn_args.items())}. You can import the `Context` type "
"from `flwr.common`."
)

# Wrap depcreated client_fn inside a function with the expected signature
def adaptor_fn(
context: Context,
) -> Client: # pylint: disable=unused-argument
# if patition-id is defined, pass it. Else pass node_id that should
# always be defined during Context init.
cid = context.node_config.get("partition-id", context.node_id)
return client_fn(str(cid)) # type: ignore

return adaptor_fn

if first_arg_type is not Context:
_alert_erroneous_client_fn()

else:
_alert_erroneous_client_fn()

return client_fn

Expand Down

0 comments on commit 72a7b04

Please sign in to comment.