Skip to content

Commit

Permalink
feat(framework) Use InMemoryDriver in simulations (#3355)
Browse files Browse the repository at this point in the history
Co-authored-by: Heng Pan <[email protected]>
  • Loading branch information
jafermarq and panh99 authored May 19, 2024
1 parent 938b087 commit 4f0eec9
Showing 1 changed file with 2 additions and 31 deletions.
33 changes: 2 additions & 31 deletions src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,13 @@
from time import sleep
from typing import Dict, Optional

import grpc

from flwr.client import ClientApp
from flwr.common import EventType, event, log
from flwr.common.logger import set_logger_propagation, update_console_handler
from flwr.common.typing import ConfigsRecordValues
from flwr.server.driver import Driver, GrpcDriver
from flwr.server.driver import Driver, InMemoryDriver
from flwr.server.run_serverapp import run
from flwr.server.server_app import ServerApp
from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
from flwr.server.superlink.fleet import vce
from flwr.server.superlink.state import StateFactory
from flwr.simulation.ray_transport.utils import (
Expand All @@ -56,7 +53,6 @@ def run_simulation_from_cli() -> None:
backend_name=args.backend,
backend_config=backend_config_dict,
app_dir=args.app_dir,
driver_api_address=args.driver_api_address,
enable_tf_gpu_growth=args.enable_tf_gpu_growth,
verbose_logging=args.verbose,
)
Expand Down Expand Up @@ -177,7 +173,6 @@ def _main_loop(
num_supernodes: int,
backend_name: str,
backend_config_stream: str,
driver_api_address: str,
app_dir: str,
enable_tf_gpu_growth: bool,
client_app: Optional[ClientApp] = None,
Expand All @@ -194,21 +189,11 @@ def _main_loop(
# Initialize StateFactory
state_factory = StateFactory(":flwr-in-memory-state:")

# Start Driver API
driver_server: grpc.Server = run_driver_api_grpc(
address=driver_api_address,
state_factory=state_factory,
certificates=None,
)

f_stop = asyncio.Event()
serverapp_th = None
try:
# Initialize Driver
driver = GrpcDriver(
driver_service_address=driver_api_address,
root_certificates=None,
)
driver = InMemoryDriver(state_factory)

# Get and run ServerApp thread
serverapp_th = run_serverapp_th(
Expand Down Expand Up @@ -239,9 +224,6 @@ def _main_loop(
raise RuntimeError("An error was encountered. Ending simulation.") from ex

finally:
# Stop Driver
driver_server.stop(grace=0)
driver.close()
# Trigger stop event
f_stop.set()

Expand All @@ -262,7 +244,6 @@ def _run_simulation(
client_app_attr: Optional[str] = None,
server_app_attr: Optional[str] = None,
app_dir: str = "",
driver_api_address: str = "0.0.0.0:9091",
enable_tf_gpu_growth: bool = False,
verbose_logging: bool = False,
) -> None:
Expand Down Expand Up @@ -302,9 +283,6 @@ def _run_simulation(
Add specified directory to the PYTHONPATH and load `ClientApp` from there.
(Default: current working directory.)
driver_api_address : str (default: "0.0.0.0:9091")
Driver API (gRPC) server address (IPv4, IPv6, or a domain name)
enable_tf_gpu_growth : bool (default: False)
A boolean to indicate whether to enable GPU growth on the main thread. This is
desirable if you make use of a TensorFlow model on your `ServerApp` while
Expand Down Expand Up @@ -342,7 +320,6 @@ def _run_simulation(
num_supernodes,
backend_name,
backend_config_stream,
driver_api_address,
app_dir,
enable_tf_gpu_growth,
client_app,
Expand Down Expand Up @@ -399,12 +376,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
required=True,
help="Number of simulated SuperNodes.",
)
parser.add_argument(
"--driver-api-address",
default="0.0.0.0:9091",
type=str,
help="For example: `server:app` or `project.package.module:wrapper.app`",
)
parser.add_argument(
"--backend",
default="ray",
Expand Down

0 comments on commit 4f0eec9

Please sign in to comment.