diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index fe7b2f32a104..2df14969e24e 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -35,6 +35,11 @@ from flwr.proto.exec_pb2_grpc import ExecStub +def on_channel_state_change(channel_connectivity: str) -> None: + """Log channel connectivity.""" + log(DEBUG, channel_connectivity) + + # pylint: disable-next=too-many-locals def run( app: Annotated[ @@ -122,10 +127,6 @@ def _run_with_superexec( config_overrides: Optional[List[str]], ) -> None: - def on_channel_state_change(channel_connectivity: str) -> None: - """Log channel connectivity.""" - log(DEBUG, channel_connectivity) - insecure_str = federation_config.get("insecure") if root_certificates := federation_config.get("root-certificates"): root_certificates_bytes = Path(root_certificates).read_bytes() diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index e20eee78e631..e42c4d462fed 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -22,6 +22,7 @@ from pathlib import Path from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union +import grpc from cryptography.hazmat.primitives.asymmetric import ec from grpc import RpcError @@ -43,6 +44,8 @@ from flwr.common.message import Error from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential from flwr.common.typing import Fab, Run, UserConfig +from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server +from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server from .grpc_adapter_client.connection import grpc_adapter from .grpc_client.connection import grpc_connection @@ -50,6 +53,9 @@ from .message_handler.message_handler import handle_control_message from .node_state import NodeState from .numpy_client import NumPyClient +from .process.clientappio_servicer import ClientAppIoServicer + +ADDRESS_CLIENTAPPIO_API_GRPC_RERE = "0.0.0.0:9094" def _check_actionable_client( @@ -667,3 +673,22 @@ def signal_handler(sig, frame): # type: ignore signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) + + +def run_clientappio_api_grpc( + address: str = ADDRESS_CLIENTAPPIO_API_GRPC_RERE, +) -> Tuple[grpc.Server, ClientAppIoServicer]: + """Run ClientAppIo API gRPC server.""" + clientappio_servicer: grpc.Server = ClientAppIoServicer() + clientappio_add_servicer_to_server_fn = add_ClientAppIoServicer_to_server + clientappio_grpc_server = generic_create_grpc_server( + servicer_and_add_fn=( + clientappio_servicer, + clientappio_add_servicer_to_server_fn, + ), + server_address=address, + max_message_length=GRPC_MAX_MESSAGE_LENGTH, + ) + log(INFO, "Starting Flower ClientAppIo gRPC server on %s", address) + clientappio_grpc_server.start() + return clientappio_grpc_server, clientappio_servicer diff --git a/src/py/flwr/client/process/process.py b/src/py/flwr/client/process/process.py new file mode 100644 index 000000000000..a1841940823c --- /dev/null +++ b/src/py/flwr/client/process/process.py @@ -0,0 +1,143 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower ClientApp process.""" + +from logging import DEBUG, ERROR, INFO +from typing import Tuple + +import grpc + +from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.common import Context, Message +from flwr.common.constant import ErrorCode +from flwr.common.grpc import create_channel +from flwr.common.logger import log +from flwr.common.message import Error +from flwr.common.serde import ( + context_from_proto, + context_to_proto, + message_from_proto, + message_to_proto, + run_from_proto, +) +from flwr.common.typing import Run + +# pylint: disable=E0611 +from flwr.proto.clientappio_pb2 import ( + PullClientAppInputsRequest, + PullClientAppInputsResponse, + PushClientAppOutputsRequest, + PushClientAppOutputsResponse, +) +from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub + +from .utils import _get_load_client_app_fn + + +def on_channel_state_change(channel_connectivity: str) -> None: + """Log channel connectivity.""" + log(DEBUG, channel_connectivity) + + +def _run_clientapp( # pylint: disable=R0914 + address: str, + token: int, +) -> None: + """Run Flower ClientApp process. + + Parameters + ---------- + address : str + Address of SuperNode + token : int + Unique SuperNode token for ClientApp-SuperNode authentication + """ + channel = create_channel( + server_address=address, + insecure=True, + ) + channel.subscribe(on_channel_state_change) + + try: + stub = ClientAppIoStub(channel) + + # Pull Message, Context, and Run from SuperNode + message, context, run = pull_message(stub=stub, token=token) + + load_client_app_fn = _get_load_client_app_fn( + default_app_ref="", + app_path=None, + multi_app=True, + flwr_dir=None, + ) + + try: + # Load ClientApp + client_app: ClientApp = load_client_app_fn(run.fab_id, run.fab_version) + + # Execute ClientApp + reply_message = client_app(message=message, context=context) + except Exception as ex: # pylint: disable=broad-exception-caught + # Don't update/change NodeState + + e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION + # Ex fmt: ":<'division by zero'>" + reason = str(type(ex)) + ":<'" + str(ex) + "'>" + exc_entity = "ClientApp" + if isinstance(ex, LoadClientAppError): + reason = "An exception was raised when attempting to load `ClientApp`" + e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION + + log(ERROR, "%s raised an exception", exc_entity, exc_info=ex) + + # Create error message + reply_message = message.create_error_reply( + error=Error(code=e_code, reason=reason) + ) + + # Push Message and Context to SuperNode + _ = push_message(stub=stub, token=token, message=reply_message, context=context) + + except KeyboardInterrupt: + log(INFO, "Closing connection") + except grpc.RpcError as e: + log(ERROR, "GRPC error occurred: %s", str(e)) + finally: + channel.close() + + +def pull_message(stub: grpc.Channel, token: int) -> Tuple[Message, Context, Run]: + """Pull message from SuperNode to ClientApp.""" + res: PullClientAppInputsResponse = stub.PullClientAppInputs( + PullClientAppInputsRequest(token=token) + ) + message = message_from_proto(res.message) + context = context_from_proto(res.context) + run = run_from_proto(res.run) + return message, context, run + + +def push_message( + stub: grpc.Channel, token: int, message: Message, context: Context +) -> PushClientAppOutputsResponse: + """Push message to SuperNode from ClientApp.""" + proto_message = message_to_proto(message) + proto_context = context_to_proto(context) + res: PushClientAppOutputsResponse = stub.PushClientAppOutputs( + PushClientAppOutputsRequest( + token=token, message=proto_message, context=proto_context + ) + ) + return res diff --git a/src/py/flwr/client/process/utils.py b/src/py/flwr/client/process/utils.py new file mode 100644 index 000000000000..e52eba93a92b --- /dev/null +++ b/src/py/flwr/client/process/utils.py @@ -0,0 +1,108 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower ClientApp loading utils.""" + +from logging import DEBUG +from pathlib import Path +from typing import Callable, Optional + +from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.common.config import ( + get_flwr_dir, + get_metadata_from_config, + get_project_config, + get_project_dir, +) +from flwr.common.logger import log +from flwr.common.object_ref import load_app, validate + + +def _get_load_client_app_fn( + default_app_ref: str, + app_path: Optional[str], + multi_app: bool, + flwr_dir: Optional[str] = None, +) -> Callable[[str, str], ClientApp]: + """Get the load_client_app_fn function. + + If `multi_app` is True, this function loads the specified ClientApp + based on `fab_id` and `fab_version`. If `fab_id` is empty, a default + ClientApp will be loaded. + + If `multi_app` is False, it ignores `fab_id` and `fab_version` and + loads a default ClientApp. + """ + if not multi_app: + log( + DEBUG, + "Flower SuperNode will load and validate ClientApp `%s`", + default_app_ref, + ) + + valid, error_msg = validate(default_app_ref, project_dir=app_path) + if not valid and error_msg: + raise LoadClientAppError(error_msg) from None + + def _load(fab_id: str, fab_version: str) -> ClientApp: + runtime_app_dir = Path(app_path if app_path else "").absolute() + # If multi-app feature is disabled + if not multi_app: + # Set app reference + client_app_ref = default_app_ref + # If multi-app feature is enabled but app directory is provided + elif app_path is not None: + config = get_project_config(runtime_app_dir) + this_fab_version, this_fab_id = get_metadata_from_config(config) + + if this_fab_version != fab_version or this_fab_id != fab_id: + raise LoadClientAppError( + f"FAB ID or version mismatch: Expected FAB ID '{this_fab_id}' and " + f"FAB version '{this_fab_version}', but received FAB ID '{fab_id}' " + f"and FAB version '{fab_version}'.", + ) from None + + # log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.") + + # Set app reference + client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"] + # If multi-app feature is enabled + else: + try: + runtime_app_dir = get_project_dir( + fab_id, fab_version, get_flwr_dir(flwr_dir) + ) + config = get_project_config(runtime_app_dir) + except Exception as e: + raise LoadClientAppError("Failed to load ClientApp") from e + + # Set app reference + client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"] + + # Load ClientApp + log( + DEBUG, + "Loading ClientApp `%s`", + client_app_ref, + ) + client_app = load_app(client_app_ref, LoadClientAppError, runtime_app_dir) + + if not isinstance(client_app, ClientApp): + raise LoadClientAppError( + f"Attribute {client_app_ref} is not of type {ClientApp}", + ) from None + + return client_app + + return _load diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 5840b57c0ab6..b2cb1b5f033e 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -18,7 +18,7 @@ import sys from logging import DEBUG, INFO, WARN from pathlib import Path -from typing import Callable, Optional, Tuple +from typing import Optional, Tuple from cryptography.exceptions import UnsupportedAlgorithm from cryptography.hazmat.primitives.asymmetric import ec @@ -27,15 +27,8 @@ load_ssh_public_key, ) -from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import EventType, event -from flwr.common.config import ( - get_flwr_dir, - get_metadata_from_config, - get_project_config, - get_project_dir, - parse_config_args, -) +from flwr.common.config import get_flwr_dir, parse_config_args from flwr.common.constant import ( TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, @@ -43,9 +36,10 @@ ) from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature -from flwr.common.object_ref import load_app, validate from ..app import _start_client_internal +from ..process.process import _run_clientapp +from ..process.utils import _get_load_client_app_fn ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092" @@ -143,6 +137,7 @@ def flwr_clientapp() -> None: args.address, args.token, ) + _run_clientapp(address=args.address, token=int(args.token)) def _warn_deprecated_server_arg(args: argparse.Namespace) -> None: @@ -200,85 +195,6 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]: return root_certificates -def _get_load_client_app_fn( - default_app_ref: str, - app_path: Optional[str], - multi_app: bool, - flwr_dir: Optional[str] = None, -) -> Callable[[str, str], ClientApp]: - """Get the load_client_app_fn function. - - If `multi_app` is True, this function loads the specified ClientApp - based on `fab_id` and `fab_version`. If `fab_id` is empty, a default - ClientApp will be loaded. - - If `multi_app` is False, it ignores `fab_id` and `fab_version` and - loads a default ClientApp. - """ - if not multi_app: - log( - DEBUG, - "Flower SuperNode will load and validate ClientApp `%s`", - default_app_ref, - ) - - valid, error_msg = validate(default_app_ref, project_dir=app_path) - if not valid and error_msg: - raise LoadClientAppError(error_msg) from None - - def _load(fab_id: str, fab_version: str) -> ClientApp: - runtime_app_dir = Path(app_path if app_path else "").absolute() - # If multi-app feature is disabled - if not multi_app: - # Set app reference - client_app_ref = default_app_ref - # If multi-app feature is enabled but app directory is provided - elif app_path is not None: - config = get_project_config(runtime_app_dir) - this_fab_version, this_fab_id = get_metadata_from_config(config) - - if this_fab_version != fab_version or this_fab_id != fab_id: - raise LoadClientAppError( - f"FAB ID or version mismatch: Expected FAB ID '{this_fab_id}' and " - f"FAB version '{this_fab_version}', but received FAB ID '{fab_id}' " - f"and FAB version '{fab_version}'.", - ) from None - - # log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.") - - # Set app reference - client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"] - # If multi-app feature is enabled - else: - try: - runtime_app_dir = get_project_dir( - fab_id, fab_version, get_flwr_dir(flwr_dir) - ) - config = get_project_config(runtime_app_dir) - except Exception as e: - raise LoadClientAppError("Failed to load ClientApp") from e - - # Set app reference - client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"] - - # Load ClientApp - log( - DEBUG, - "Loading ClientApp `%s`", - client_app_ref, - ) - client_app = load_app(client_app_ref, LoadClientAppError, runtime_app_dir) - - if not isinstance(client_app, ClientApp): - raise LoadClientAppError( - f"Attribute {client_app_ref} is not of type {ClientApp}", - ) from None - - return client_app - - return _load - - def _parse_args_run_supernode() -> argparse.ArgumentParser: """Parse flower-supernode command line arguments.""" parser = argparse.ArgumentParser( diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index a8d02802a8b1..1df821d0f495 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -28,7 +28,7 @@ from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState -from flwr.client.supernode.app import _get_load_client_app_fn +from flwr.client.process.utils import _get_load_client_app_fn from flwr.common.constant import ( NUM_PARTITIONS_KEY, PARTITION_ID_KEY,