diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index a8b5a5a85136..1ff374749884 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -30,25 +30,47 @@ from .typing import ClientAppCallable +def _alert_erroneous_client_fn() -> None: + raise ValueError( + "A `ClientApp` cannot make use of a `client_fn` that does " + "not have a signature in the form: `def client_fn(context: " + "Context)`. You can import the `Context` type from `flwr.common`." + ) + + def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt: client_fn_args = inspect.signature(client_fn).parameters first_arg = list(client_fn_args.keys())[0] - if len(client_fn_args) != 1 or client_fn_args[first_arg].annotation is not Context: - warn_deprecated_feature( - "`client_fn` now expects a signature `def client_fn(context: Context)`." - "\The provided `client_fn` has signature: " - f"{dict(client_fn_args.items())}" - ) - - # Wrap depcreated client_fn inside a function with the expected signature - def adaptor_fn(context: Context) -> Client: # pylint: disable=unused-argument - # if patition-id is defined, pass it. Else pass node_id that should always - # be defined during Context init. - cid = context.node_config.get("partition-id", context.node_id) - return client_fn(str(cid)) # type: ignore - - return adaptor_fn + if len(client_fn_args) == 1: + + first_arg_type = client_fn_args[first_arg].annotation + + if first_arg_type is str: + # Warn previous signature for `client_fn` seems to be used + warn_deprecated_feature( + "`client_fn` now expects a signature `def client_fn(context: Context)`." + "The provided `client_fn` has signature: " + f"{dict(client_fn_args.items())}. You can import the `Context` type " + "from `flwr.common`." + ) + + # Wrap depcreated client_fn inside a function with the expected signature + def adaptor_fn( + context: Context, + ) -> Client: # pylint: disable=unused-argument + # if patition-id is defined, pass it. Else pass node_id that should + # always be defined during Context init. + cid = context.node_config.get("partition-id", context.node_id) + return client_fn(str(cid)) # type: ignore + + return adaptor_fn + + if first_arg_type is not Context: + _alert_erroneous_client_fn() + + else: + _alert_erroneous_client_fn() return client_fn