Skip to content

Commit

Permalink
Apply suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Jun 13, 2024
1 parent 0db8ca4 commit 275483d
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 67 deletions.
125 changes: 63 additions & 62 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing_extensions import Annotated

from flwr.cli import config_utils
from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
Expand Down Expand Up @@ -51,70 +52,70 @@ def run(
) -> None:
"""Run Flower project."""
if use_superexec:
_start_superexec_run()
return

def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
log(DEBUG, channel_connectivity)
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

channel = create_channel(
server_address="127.0.0.1:9093",
insecure=True,
root_certificates=None,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
interceptors=None,
config, errors, warnings = config_utils.load_and_validate()

if config is None:
typer.secho(
"Project configuration could not be loaded.\n"
"pyproject.toml is invalid:\n"
+ "\n".join([f"- {line}" for line in errors]),
fg=typer.colors.RED,
bold=True,
)
sys.exit()

if warnings:
typer.secho(
"Project configuration is missing the following "
"recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]),
fg=typer.colors.RED,
bold=True,
)
channel.subscribe(on_channel_state_change)
stub = ExecStub(channel)

req = StartRunRequest()
res = stub.StartRun(req)
print(res)
typer.secho("Success", fg=typer.colors.GREEN)

server_app_ref = config["flower"]["components"]["serverapp"]
client_app_ref = config["flower"]["components"]["clientapp"]

if engine is None:
engine = config["flower"]["engine"]["name"]

if engine == Engine.SIMULATION:
num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]

typer.secho("Starting run... ", fg=typer.colors.BLUE)
_run_simulation(
server_app_attr=server_app_ref,
client_app_attr=client_app_ref,
num_supernodes=num_supernodes,
)
else:
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

config, errors, warnings = config_utils.load_and_validate()

if config is None:
typer.secho(
"Project configuration could not be loaded.\n"
"pyproject.toml is invalid:\n"
+ "\n".join([f"- {line}" for line in errors]),
fg=typer.colors.RED,
bold=True,
)
sys.exit()

if warnings:
typer.secho(
"Project configuration is missing the following "
"recommended properties:\n"
+ "\n".join([f"- {line}" for line in warnings]),
fg=typer.colors.RED,
bold=True,
)

typer.secho("Success", fg=typer.colors.GREEN)

server_app_ref = config["flower"]["components"]["serverapp"]
client_app_ref = config["flower"]["components"]["clientapp"]

if engine is None:
engine = config["flower"]["engine"]["name"]

if engine == Engine.SIMULATION:
num_supernodes = config["flower"]["engine"]["simulation"]["supernode"][
"num"
]

typer.secho("Starting run... ", fg=typer.colors.BLUE)
_run_simulation(
server_app_attr=server_app_ref,
client_app_attr=client_app_ref,
num_supernodes=num_supernodes,
)
else:
typer.secho(
f"Engine '{engine}' is not yet supported in `flwr run`",
fg=typer.colors.RED,
bold=True,
)
typer.secho(
f"Engine '{engine}' is not yet supported in `flwr run`",
fg=typer.colors.RED,
bold=True,
)


def _start_superexec_run():
def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
log(DEBUG, channel_connectivity)

channel = create_channel(
server_address=SUPEREXEC_DEFAULT_ADDRESS,
insecure=True,
root_certificates=None,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
interceptors=None,
)
channel.subscribe(on_channel_state_change)
stub = ExecStub(channel)

req = StartRunRequest()
stub.StartRun(req)
2 changes: 2 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
TRANSPORT_TYPE_VCE,
]

SUPEREXEC_DEFAULT_ADDRESS = "0.0.0.0:9093"

# Constants for ping
PING_DEFAULT_INTERVAL = 30
PING_CALL_TIMEOUT = 5
Expand Down
31 changes: 26 additions & 5 deletions src/py/flwr/superexec/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from flwr.common import EventType, event, log
from flwr.common.address import parse_address
from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.object_ref import load_app, validate

Expand Down Expand Up @@ -60,7 +61,7 @@ def run_superexec() -> None:

# Graceful shutdown
register_exit_handlers(
event_type=EventType.RUN_SUPEREXEC_ENTER,
event_type=EventType.RUN_SUPEREXEC_LEAVE,
grpc_servers=grpc_servers,
bckg_threads=None,
)
Expand All @@ -76,16 +77,17 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser:
description="Start a Flower SuperExec",
)
parser.add_argument(
"executor-plugin",
"executor",
help="For example: `deployment:exec` or `project.package.module:wrapper.exec`.",
required=True,
)
parser.add_argument(
"--address",
help="SuperExec (gRPC) server address (IPv4, IPv6, or a domain name)",
default="0.0.0.0:9093",
default=SUPEREXEC_DEFAULT_ADDRESS,
)
parser.add_argument(
"--dir",
"--executor-dir",
help="The directory for the plugin.",
default=".",
)
Expand All @@ -96,6 +98,25 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser:
"paths are provided. By default, the server runs with HTTPS enabled. "
"Use this flag only if you understand the risks.",
)
parser.add_argument(
"--ssl-certfile",
help="Fleet API server SSL certificate file (as a path str) "
"to create a secure connection.",
type=str,
default=None,
)
parser.add_argument(
"--ssl-keyfile",
help="Fleet API server SSL private key file (as a path str) "
"to create a secure connection.",
type=str,
)
parser.add_argument(
"--ssl-ca-certfile",
help="Fleet API server SSL CA certificate file (as a path str) "
"to create a secure connection.",
type=str,
)
return parser


Expand Down Expand Up @@ -142,7 +163,7 @@ def _get_exec_plugin(
if exec_plugin_dir is not None:
sys.path.insert(0, exec_plugin_dir)

plugin_ref: str = getattr(args, "executor-plugin")
plugin_ref: str = getattr(args, "executor")
valid, error_msg = validate(plugin_ref)
if not valid and error_msg:
raise LoadExecPluginError(error_msg) from None
Expand Down

0 comments on commit 275483d

Please sign in to comment.