From 4f0eec924bef1cc9a252b0635464396c79323b5d Mon Sep 17 00:00:00 2001 From: Javier Date: Sun, 19 May 2024 16:53:33 +0100 Subject: [PATCH] feat(framework) Use `InMemoryDriver` in simulations (#3355) Co-authored-by: Heng Pan --- src/py/flwr/simulation/run_simulation.py | 33 ++---------------------- 1 file changed, 2 insertions(+), 31 deletions(-) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index cd475cf7ebdf..2dbeef1a261c 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -24,16 +24,13 @@ from time import sleep from typing import Dict, Optional -import grpc - from flwr.client import ClientApp from flwr.common import EventType, event, log from flwr.common.logger import set_logger_propagation, update_console_handler from flwr.common.typing import ConfigsRecordValues -from flwr.server.driver import Driver, GrpcDriver +from flwr.server.driver import Driver, InMemoryDriver from flwr.server.run_serverapp import run from flwr.server.server_app import ServerApp -from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc from flwr.server.superlink.fleet import vce from flwr.server.superlink.state import StateFactory from flwr.simulation.ray_transport.utils import ( @@ -56,7 +53,6 @@ def run_simulation_from_cli() -> None: backend_name=args.backend, backend_config=backend_config_dict, app_dir=args.app_dir, - driver_api_address=args.driver_api_address, enable_tf_gpu_growth=args.enable_tf_gpu_growth, verbose_logging=args.verbose, ) @@ -177,7 +173,6 @@ def _main_loop( num_supernodes: int, backend_name: str, backend_config_stream: str, - driver_api_address: str, app_dir: str, enable_tf_gpu_growth: bool, client_app: Optional[ClientApp] = None, @@ -194,21 +189,11 @@ def _main_loop( # Initialize StateFactory state_factory = StateFactory(":flwr-in-memory-state:") - # Start Driver API - driver_server: grpc.Server = run_driver_api_grpc( - address=driver_api_address, - state_factory=state_factory, - certificates=None, - ) - f_stop = asyncio.Event() serverapp_th = None try: # Initialize Driver - driver = GrpcDriver( - driver_service_address=driver_api_address, - root_certificates=None, - ) + driver = InMemoryDriver(state_factory) # Get and run ServerApp thread serverapp_th = run_serverapp_th( @@ -239,9 +224,6 @@ def _main_loop( raise RuntimeError("An error was encountered. Ending simulation.") from ex finally: - # Stop Driver - driver_server.stop(grace=0) - driver.close() # Trigger stop event f_stop.set() @@ -262,7 +244,6 @@ def _run_simulation( client_app_attr: Optional[str] = None, server_app_attr: Optional[str] = None, app_dir: str = "", - driver_api_address: str = "0.0.0.0:9091", enable_tf_gpu_growth: bool = False, verbose_logging: bool = False, ) -> None: @@ -302,9 +283,6 @@ def _run_simulation( Add specified directory to the PYTHONPATH and load `ClientApp` from there. (Default: current working directory.) - driver_api_address : str (default: "0.0.0.0:9091") - Driver API (gRPC) server address (IPv4, IPv6, or a domain name) - enable_tf_gpu_growth : bool (default: False) A boolean to indicate whether to enable GPU growth on the main thread. This is desirable if you make use of a TensorFlow model on your `ServerApp` while @@ -342,7 +320,6 @@ def _run_simulation( num_supernodes, backend_name, backend_config_stream, - driver_api_address, app_dir, enable_tf_gpu_growth, client_app, @@ -399,12 +376,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser: required=True, help="Number of simulated SuperNodes.", ) - parser.add_argument( - "--driver-api-address", - default="0.0.0.0:9091", - type=str, - help="For example: `server:app` or `project.package.module:wrapper.app`", - ) parser.add_argument( "--backend", default="ray",