Skip to content

Commit

Permalink
Merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng committed Sep 23, 2024
2 parents 57930df + b370b06 commit 25d5c1e
Show file tree
Hide file tree
Showing 11 changed files with 265 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def main(driver: Driver, context: Context) -> None:
strategy = FedAvg(
# Select all available clients
fraction_fit=1.0,
min_fit_clients=5,
# Disable evaluation in demo
fraction_evaluate=(0.0 if is_demo else context.run_config["fraction-evaluate"]),
min_available_clients=5,
Expand Down
7 changes: 7 additions & 0 deletions src/proto/flwr/proto/control.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,11 @@ import "flwr/proto/run.proto";
service Control {
// Request to create a new run
rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {}

// Get the status of a given run
rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {}

// Update the status of a given run
rpc UpdateRunStatus(UpdateRunStatusRequest)
returns (UpdateRunStatusResponse) {}
}
20 changes: 20 additions & 0 deletions src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ message Run {
string fab_hash = 5;
}

message RunStatus {
// "starting", "running", "finished"
string status = 1;
// "completed", "failed", "stopped" or "" (non-finished)
string sub_status = 2;
// failure details
string details = 3;
}

// CreateRun
message CreateRunRequest {
string fab_id = 1;
Expand All @@ -40,3 +49,14 @@ message CreateRunResponse { uint64 run_id = 1; }
// GetRun
message GetRunRequest { uint64 run_id = 1; }
message GetRunResponse { Run run = 1; }

// UpdateRunStatus
message UpdateRunStatusRequest {
uint64 run_id = 1;
RunStatus run_status = 2;
}
message UpdateRunStatusResponse {}

// GetRunStatus
message GetRunStatusRequest { repeated uint64 run_ids = 1; }
message GetRunStatusResponse { map<uint64, RunStatus> run_status_dict = 1; }
34 changes: 20 additions & 14 deletions src/py/flwr/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import time
from logging import DEBUG, ERROR, INFO
from pathlib import Path
from typing import Annotated, Any, Optional
from typing import Annotated, Optional

import grpc
import typer
Expand Down Expand Up @@ -79,14 +79,19 @@ def print_logs(
logger(DEBUG, "Channel closed")


def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
logger(DEBUG, channel_connectivity)


def log(
run_id: Annotated[
int,
typer.Argument(help="The Flower run ID to query"),
],
app: Annotated[
Path,
typer.Argument(help="Path of the Flower App to run"),
typer.Argument(help="Path of the Flower project to run"),
] = Path("."),
federation: Annotated[
Optional[str],
Expand All @@ -100,7 +105,7 @@ def log(
),
] = True,
) -> None:
"""Get logs from a Flower App run."""
"""Get logs from a Flower project run."""
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

pyproject_path = app / "pyproject.toml" if app else None
Expand Down Expand Up @@ -145,31 +150,32 @@ def log(
fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
}
typer.secho(
f"❌ There is no `{federation}` federation declared in "
f"❌ There is no `{federation}` federation declared in the "
"`pyproject.toml`.\n The following federations were found:\n\n"
+ "\n".join(available_feds),
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

if "address" in federation_config:
_log_with_superexec(federation_config, run_id, stream)
else:
pass
if "address" not in federation_config:
typer.secho(
"❌ `flwr log` currently works with `SuperExec`. Ensure that the correct"
"`SuperExec` address is provided in the `pyproject.toml`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

_log_with_superexec(federation_config, run_id, stream)


# pylint: disable-next=too-many-branches
def _log_with_superexec(
federation_config: dict[str, Any],
federation_config: dict[str, str],
run_id: int,
stream: bool,
) -> None:

def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
logger(DEBUG, channel_connectivity)

insecure_str = federation_config.get("insecure")
if root_certificates := federation_config.get("root-certificates"):
root_certificates_bytes = Path(root_certificates).read_bytes()
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/proto/control_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 68 additions & 0 deletions src/py/flwr/proto/control_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString,
)
self.GetRunStatus = channel.unary_unary(
'/flwr.proto.Control/GetRunStatus',
request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
)
self.UpdateRunStatus = channel.unary_unary(
'/flwr.proto.Control/UpdateRunStatus',
request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
)


class ControlServicer(object):
Expand All @@ -31,6 +41,20 @@ def CreateRun(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetRunStatus(self, request, context):
"""Get the status of a given run
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def UpdateRunStatus(self, request, context):
"""Update the status of a given run
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_ControlServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -39,6 +63,16 @@ def add_ControlServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString,
),
'GetRunStatus': grpc.unary_unary_rpc_method_handler(
servicer.GetRunStatus,
request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
),
'UpdateRunStatus': grpc.unary_unary_rpc_method_handler(
servicer.UpdateRunStatus,
request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'flwr.proto.Control', rpc_method_handlers)
Expand All @@ -65,3 +99,37 @@ def CreateRun(request,
flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetRunStatus(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/GetRunStatus',
flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def UpdateRunStatus(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/UpdateRunStatus',
flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
26 changes: 26 additions & 0 deletions src/py/flwr/proto/control_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ class ControlStub:
flwr.proto.run_pb2.CreateRunResponse]
"""Request to create a new run"""

GetRunStatus: grpc.UnaryUnaryMultiCallable[
flwr.proto.run_pb2.GetRunStatusRequest,
flwr.proto.run_pb2.GetRunStatusResponse]
"""Get the status of a given run"""

UpdateRunStatus: grpc.UnaryUnaryMultiCallable[
flwr.proto.run_pb2.UpdateRunStatusRequest,
flwr.proto.run_pb2.UpdateRunStatusResponse]
"""Update the status of a given run"""


class ControlServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand All @@ -23,5 +33,21 @@ class ControlServicer(metaclass=abc.ABCMeta):
"""Request to create a new run"""
pass

@abc.abstractmethod
def GetRunStatus(self,
request: flwr.proto.run_pb2.GetRunStatusRequest,
context: grpc.ServicerContext,
) -> flwr.proto.run_pb2.GetRunStatusResponse:
"""Get the status of a given run"""
pass

@abc.abstractmethod
def UpdateRunStatus(self,
request: flwr.proto.run_pb2.UpdateRunStatusRequest,
context: grpc.ServicerContext,
) -> flwr.proto.run_pb2.UpdateRunStatusResponse:
"""Update the status of a given run"""
pass


def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ...
32 changes: 23 additions & 9 deletions src/py/flwr/proto/run_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 25d5c1e

Please sign in to comment.